diff --git a/src/matrix/e2ee/megolm/decryption/RoomKey.ts b/src/matrix/e2ee/megolm/decryption/RoomKey.ts index 8815a4e6..ed799759 100644 --- a/src/matrix/e2ee/megolm/decryption/RoomKey.ts +++ b/src/matrix/e2ee/megolm/decryption/RoomKey.ts @@ -41,47 +41,38 @@ export interface IRoomKey { } export interface IIncomingRoomKey extends IRoomKey { - copyEventIds(value: string[]): void; + get isBetter(): boolean | undefined; + checkIsBetterThanStorage(keyDeserialization: KeyDeserialization, txn: Transaction): Promise; + write(keyDeserialization: KeyDeserialization, txn: Transaction): Promise; } -export async function checkBetterKeyInStorage(key: IIncomingRoomKey, keyDeserialization: KeyDeserialization, txn: Transaction) { - let existingKey = keyDeserialization.cache.get(key.roomId, key.senderKey, key.sessionId); - if (!existingKey) { - const storageKey = await fromStorage(key.roomId, key.senderKey, key.sessionId, txn); - // store the event ids that can be decrypted with this key - // before we overwrite them if called from `write`. - if (storageKey) { - if (storageKey.eventIds) { - key.copyEventIds(storageKey.eventIds); - } - if (storageKey.hasSession) { - existingKey = storageKey; - } +abstract class BaseIncomingRoomKey implements IIncomingRoomKey { + private _eventIds?: string[]; + private _isBetter?: boolean; + + checkBetterKeyInStorage(keyDeserialization: KeyDeserialization, txn: Transaction): Promise { + return this._checkBetterKeyInStorage(keyDeserialization, undefined, txn); + } + + async write(keyDeserialization: KeyDeserialization, pickleKey: string, txn: Transaction): Promise { + // we checked already and we had a better session in storage, so don't write + let pickledSession; + if (this._isBetter === undefined) { + // if this key wasn't used to decrypt any messages in the same sync, + // we haven't checked if this is the best key yet, + // so do that now to not overwrite a better key. + // while we have the key deserialized, also pickle it to store it later on here. + await this._checkBetterKeyInStorage(keyDeserialization, session => { + pickledSession = session.pickle(pickleKey); + }, txn); + } + if (this._isBetter === false) { + return false; } - } - if (existingKey) { - const isBetter = await keyDeserialization.useKey(key, newSession => { - return keyDeserialization.useKey(existingKey, existingSession => { - return newSession.first_known_index() < existingSession.first_known_index(); - }); - }); - return isBetter ? key : existingKey; - } else { - return key; - } -} - -async function write(olm, pickleKey, keyDeserialization, txn) { - // we checked already and we had a better session in storage, so don't write - if (this._isBetter === false) { - return false; - } - if (!this._sessionInfo) { - await this.createSessionInfo(olm, pickleKey, txn); - } - if (this._sessionInfo) { // before calling write in parallel, we need to check keyDeserialization.running is false so we are sure our transaction will not be closed - const pickledSession = await keyDeserialization.useKey(this, session => session.pickle(pickleKey)); + if (!pickledSession) { + pickledSession = await keyDeserialization.useKey(this, session => session.pickle(pickleKey)); + } const sessionEntry = { roomId: this.roomId, senderKey: this.senderKey, @@ -90,19 +81,44 @@ async function write(olm, pickleKey, keyDeserialization, txn) { claimedKeys: this._sessionInfo.claimedKeys, }; txn.inboundGroupSessions.set(sessionEntry); - this.dispose(); return true; } - return false; -} - -class BaseIncomingRoomKey { - private _eventIds?: string[]; get eventIds() { return this._eventIds; } + get isBetter() { return this._isBetter; } - copyEventIds(eventIds: string[]): void { - this._eventIds = eventIds; + private async _checkBetterKeyInStorage(keyDeserialization: KeyDeserialization, callback?: (session: OlmInboundGroupSession) => void, txn: Transaction): Promise { + if (this._isBetter !== undefined) { + return this._isBetter; + } + let existingKey = keyDeserialization.cache.get(this.roomId, this.senderKey, this.sessionId); + if (!existingKey) { + const storageKey = await fromStorage(this.roomId, this.senderKey, this.sessionId, txn); + // store the event ids that can be decrypted with this key + // before we overwrite them if called from `write`. + if (storageKey) { + if (storageKey.hasSession) { + existingKey = storageKey; + } else if (storageKey.eventIds) { + this._eventIds = storageKey.eventIds; + } + } + } + if (existingKey) { + this._isBetter = await keyDeserialization.useKey(key, newSession => { + return keyDeserialization.useKey(existingKey, existingSession => { + const isBetter = newSession.first_known_index() < existingSession.first_known_index(); + if (isBetter && callback) { + callback(newSession); + } + return isBetter; + }); + }); + } else { + // no previous key, so we're the best \o/ + this._isBetter = true; + } + return this._isBetter; } }