From ab2f15b5a2364ab110aabce1f52481d63c69f34f Mon Sep 17 00:00:00 2001 From: Bruno Windels Date: Mon, 25 Oct 2021 19:17:13 +0200 Subject: [PATCH] prevent cache hiding better keys in storage (+ tests) --- src/matrix/e2ee/megolm/Decryption.ts | 14 +-- .../e2ee/megolm/decryption/KeyLoader.ts | 103 +++++++++++++----- src/matrix/e2ee/megolm/decryption/RoomKey.ts | 80 +++++++------- .../megolm/decryption/SessionDecryption.ts | 6 +- 4 files changed, 125 insertions(+), 78 deletions(-) diff --git a/src/matrix/e2ee/megolm/Decryption.ts b/src/matrix/e2ee/megolm/Decryption.ts index 7ab0e15a..cc577f3d 100644 --- a/src/matrix/e2ee/megolm/Decryption.ts +++ b/src/matrix/e2ee/megolm/Decryption.ts @@ -20,7 +20,7 @@ import {SessionDecryption} from "./decryption/SessionDecryption"; import {MEGOLM_ALGORITHM} from "../common.js"; import {validateEvent, groupEventsBySession} from "./decryption/utils"; import {keyFromStorage, keyFromDeviceMessage, keyFromBackup} from "./decryption/RoomKey"; -import type {IRoomKey, IIncomingRoomKey} from "./decryption/RoomKey"; +import type {RoomKey, IncomingRoomKey} from "./decryption/RoomKey"; import type {KeyLoader} from "./decryption/KeyLoader"; import type {OlmWorker} from "../OlmWorker"; import type {Transaction} from "../../storage/idb/Transaction"; @@ -78,7 +78,7 @@ export class Decryption { * @param {[type]} txn [description] * @return {DecryptionPreparation} */ - async prepareDecryptAll(roomId: string, events: TimelineEvent[], newKeys: IIncomingRoomKey[] | undefined, txn: Transaction) { + async prepareDecryptAll(roomId: string, events: TimelineEvent[], newKeys: IncomingRoomKey[] | undefined, txn: Transaction) { const errors = new Map(); const validEvents: TimelineEvent[] = []; @@ -107,7 +107,7 @@ export class Decryption { return new DecryptionPreparation(roomId, sessionDecryptions, errors); } - private async getRoomKey(roomId: string, senderKey: string, sessionId: string, newKeys: IIncomingRoomKey[] | undefined, txn: Transaction): Promise { + private async getRoomKey(roomId: string, senderKey: string, sessionId: string, newKeys: IncomingRoomKey[] | undefined, txn: Transaction): Promise { if (newKeys) { const key = newKeys.find(k => k.roomId === roomId && k.senderKey === senderKey && k.sessionId === sessionId); if (key && await key.checkBetterThanKeyInStorage(this.keyLoader, txn)) { @@ -128,7 +128,7 @@ export class Decryption { /** * Writes the key as an inbound group session if there is not already a better key in the store */ - writeRoomKey(key: IIncomingRoomKey, txn: Transaction): Promise { + writeRoomKey(key: IncomingRoomKey, txn: Transaction): Promise { return key.write(this.keyLoader, txn); } @@ -136,8 +136,8 @@ export class Decryption { * Extracts room keys from decrypted device messages. * The key won't be persisted yet, you need to call RoomKey.write for that. */ - roomKeysFromDeviceMessages(decryptionResults: DecryptionResult[], log: LogItem): IIncomingRoomKey[] { - const keys: IIncomingRoomKey[] = []; + roomKeysFromDeviceMessages(decryptionResults: DecryptionResult[], log: LogItem): IncomingRoomKey[] { + const keys: IncomingRoomKey[] = []; for (const dr of decryptionResults) { if (dr.event?.type !== "m.room_key" || dr.event.content?.algorithm !== MEGOLM_ALGORITHM) { continue; @@ -157,7 +157,7 @@ export class Decryption { return keys; } - roomKeyFromBackup(roomId: string, sessionId: string, sessionInfo: string): IIncomingRoomKey | undefined { + roomKeyFromBackup(roomId: string, sessionId: string, sessionInfo: string): IncomingRoomKey | undefined { return keyFromBackup(roomId, sessionId, sessionInfo); } diff --git a/src/matrix/e2ee/megolm/decryption/KeyLoader.ts b/src/matrix/e2ee/megolm/decryption/KeyLoader.ts index cf276608..58f968c8 100644 --- a/src/matrix/e2ee/megolm/decryption/KeyLoader.ts +++ b/src/matrix/e2ee/megolm/decryption/KeyLoader.ts @@ -14,9 +14,9 @@ See the License for the specific language governing permissions and limitations under the License. */ -import {IRoomKey, isBetterThan} from "./RoomKey"; +import {isBetterThan, IncomingRoomKey} from "./RoomKey"; import {BaseLRUCache} from "../../../../utils/LRUCache"; - +import type {RoomKey} from "./RoomKey"; export declare class OlmDecryptionResult { readonly plaintext: string; @@ -53,14 +53,14 @@ export class KeyLoader extends BaseLRUCache { this.olm = olm; } - getCachedKey(roomId: string, senderKey: string, sessionId: string): IRoomKey | undefined { - const idx = this.findIndexBestForSession(roomId, senderKey, sessionId); + getCachedKey(roomId: string, senderKey: string, sessionId: string): RoomKey | undefined { + const idx = this.findCachedKeyIndex(roomId, senderKey, sessionId); if (idx !== -1) { return this._getByIndexAndMoveUp(idx)!.key; } } - async useKey(key: IRoomKey, callback: (session: OlmInboundGroupSession, pickleKey: string) => Promise | T): Promise { + async useKey(key: RoomKey, callback: (session: OlmInboundGroupSession, pickleKey: string) => Promise | T): Promise { const keyOp = await this.allocateOperation(key); try { return await callback(keyOp.session, this.pickleKey); @@ -81,7 +81,7 @@ export class KeyLoader extends BaseLRUCache { this._entries.splice(0, this._entries.length); } - private async allocateOperation(key: IRoomKey): Promise { + private async allocateOperation(key: RoomKey): Promise { let idx; while((idx = this.findIndexForAllocation(key)) === -1) { await this.operationBecomesUnused(); @@ -127,14 +127,14 @@ export class KeyLoader extends BaseLRUCache { return this.operationBecomesUnusedPromise; } - private findIndexForAllocation(key: IRoomKey) { + private findIndexForAllocation(key: RoomKey) { let idx = this.findIndexSameKey(key); // cache hit if (idx === -1) { - idx = this.findIndexSameSessionUnused(key); - if (idx === -1) { - if (this.size < this.limit) { - idx = this.size; - } else { + if (this.size < this.limit) { + idx = this.size; + } else { + idx = this.findIndexSameSessionUnused(key); + if (idx === -1) { idx = this.findIndexOldestUnused(); } } @@ -142,10 +142,11 @@ export class KeyLoader extends BaseLRUCache { return idx; } - private findIndexBestForSession(roomId: string, senderKey: string, sessionId: string): number { + private findCachedKeyIndex(roomId: string, senderKey: string, sessionId: string): number { return this._entries.reduce((bestIdx, op, i, arr) => { const bestOp = bestIdx === -1 ? undefined : arr[bestIdx]; - if (op.isForSameSession(roomId, senderKey, sessionId)) { + // only operations that are the "best" for their session can be used, see comment on isBest + if (op.isBest === true && op.isForSameSession(roomId, senderKey, sessionId)) { if (!bestOp || op.isBetter(bestOp)) { return i; } @@ -154,20 +155,23 @@ export class KeyLoader extends BaseLRUCache { }, -1); } - private findIndexSameKey(key: IRoomKey): number { + private findIndexSameKey(key: RoomKey): number { return this._entries.findIndex(op => { return op.isForSameSession(key.roomId, key.senderKey, key.sessionId) && op.isForKey(key); }); } - private findIndexSameSessionUnused(key: IRoomKey): number { - for (let i = this._entries.length - 1; i >= 0; i -= 1) { - const op = this._entries[i]; + private findIndexSameSessionUnused(key: RoomKey): number { + return this._entries.reduce((worstIdx, op, i, arr) => { + const worst = worstIdx === -1 ? undefined : arr[worstIdx]; + // we try to pick the worst operation to overwrite, so the best one stays in the cache if (op.refCount === 0 && op.isForSameSession(key.roomId, key.senderKey, key.sessionId)) { - return i; + if (!worst || !op.isBetter(worst)) { + return i; + } } - } - return -1; + return worstIdx; + }, -1); } private findIndexOldestUnused(): number { @@ -183,10 +187,10 @@ export class KeyLoader extends BaseLRUCache { class KeyOperation { session: OlmInboundGroupSession; - key: IRoomKey; + key: RoomKey; refCount: number; - constructor(key: IRoomKey, session: OlmInboundGroupSession) { + constructor(key: RoomKey, session: OlmInboundGroupSession) { this.key = key; this.session = session; this.refCount = 1; @@ -201,7 +205,7 @@ class KeyOperation { return isBetterThan(this.session, other.session); } - isForKey(key: IRoomKey) { + isForKey(key: RoomKey) { return this.key.serializationKey === key.serializationKey && this.key.serializationType === key.serializationType; } @@ -209,18 +213,27 @@ class KeyOperation { dispose() { this.session.free(); } + + /** returns whether the key for this operation has been checked at some point against storage + * and was determined to be the better key, undefined if it hasn't been checked yet. + * Only keys that are the best keys can be returned by getCachedKey as returning a cache hit + * will usually not check for a better session in storage. Also see RoomKey.isBetter. */ + get isBest(): boolean | undefined { + return this.key.isBetter; + } } export function tests() { let instances = 0; - class MockRoomKey implements IRoomKey { + class MockRoomKey extends IncomingRoomKey { private _roomId: string; private _senderKey: string; private _sessionId: string; private _firstKnownIndex: number; constructor(roomId: string, senderKey: string, sessionId: string, firstKnownIndex: number) { + super(); this._roomId = roomId; this._senderKey = senderKey; this._sessionId = sessionId; @@ -267,7 +280,6 @@ export function tests() { const bobSenderKey = "def"; const sessionId1 = "s123"; const sessionId2 = "s456"; - const sessionId3 = "s789"; return { "load key gives correct session": async assert => { @@ -362,6 +374,7 @@ export function tests() { let resolve1, resolve2, invocations = 0; const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1); await loader.useKey(key1, async session => { invocations += 1; }); + key1.isBetter = true; assert.equal(loader.size, 1); const cachedKey = loader.getCachedKey(roomId, aliceSenderKey, sessionId1)!; assert.equal(cachedKey, key1); @@ -379,6 +392,42 @@ export function tests() { loader.dispose(); assert.strictEqual(instances, 0, "instances"); assert.strictEqual(loader.size, 0, "loader.size"); - } + }, + "checkBetterThanKeyInStorage false with cache": async assert => { + const loader = new KeyLoader(olm, PICKLE_KEY, 2); + const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 2); + await loader.useKey(key1, async session => {}); + // fake we've checked with storage that this is the best key, + // and as long is it remains the best key with newly added keys, + // it will be returned from getCachedKey (as called from checkBetterThanKeyInStorage) + key1.isBetter = true; + const key2 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 3); + // this will hit cache of key 1 so we pass in null as txn + const isBetter = await key2.checkBetterThanKeyInStorage(loader, null as any); + assert.strictEqual(isBetter, false); + assert.strictEqual(key2.isBetter, false); + }, + "checkBetterThanKeyInStorage true with cache": async assert => { + const loader = new KeyLoader(olm, PICKLE_KEY, 2); + const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 2); + key1.isBetter = true; // fake we've check with storage so far (not including key2) this is the best key + await loader.useKey(key1, async session => {}); + const key2 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1); + // this will hit cache of key 1 so we pass in null as txn + const isBetter = await key2.checkBetterThanKeyInStorage(loader, null as any); + assert.strictEqual(isBetter, true); + assert.strictEqual(key2.isBetter, true); + }, + "prefer to remove worst key for a session from cache": async assert => { + const loader = new KeyLoader(olm, PICKLE_KEY, 2); + const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 2); + await loader.useKey(key1, async session => {}); + key1.isBetter = true; // set to true just so it gets returned from getCachedKey + const key2 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 4); + await loader.useKey(key2, async session => {}); + const key3 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 3); + await loader.useKey(key3, async session => {}); + assert.strictEqual(loader.getCachedKey(roomId, aliceSenderKey, sessionId1), key1); + }, } } diff --git a/src/matrix/e2ee/megolm/decryption/RoomKey.ts b/src/matrix/e2ee/megolm/decryption/RoomKey.ts index 48af45a7..ad1cd8a0 100644 --- a/src/matrix/e2ee/megolm/decryption/RoomKey.ts +++ b/src/matrix/e2ee/megolm/decryption/RoomKey.ts @@ -19,30 +19,33 @@ import type {Transaction} from "../../../storage/idb/Transaction"; import type {DecryptionResult} from "../../DecryptionResult"; import type {KeyLoader, OlmInboundGroupSession} from "./KeyLoader"; -export interface IRoomKey { - get roomId(): string; - get senderKey(): string; - get sessionId(): string; - get claimedEd25519Key(): string; - get serializationKey(): string; - get serializationType(): string; - get eventIds(): string[] | undefined; - loadInto(session: OlmInboundGroupSession, pickleKey: string): void; +export abstract class RoomKey { + private _isBetter: boolean | undefined; + + abstract get roomId(): string; + abstract get senderKey(): string; + abstract get sessionId(): string; + abstract get claimedEd25519Key(): string; + abstract get serializationKey(): string; + abstract get serializationType(): string; + abstract get eventIds(): string[] | undefined; + abstract loadInto(session: OlmInboundGroupSession, pickleKey: string): void; + /* Whether the key has been checked against storage (or is from storage) + * to be the better key for a given session. Given that all keys are checked to be better + * as part of writing, we can trust that when this returns true, it really is the best key + * available between storage and cached keys in memory. This is why keys with this field set to + * true are used by the key loader to return cached keys. Also see KeyOperation.isBest there. */ + get isBetter(): boolean | undefined { return this._isBetter; } + // should only be set in key.checkBetterThanKeyInStorage + set isBetter(value: boolean | undefined) { this._isBetter = value; } } export function isBetterThan(newSession: OlmInboundGroupSession, existingSession: OlmInboundGroupSession) { return newSession.first_known_index() < existingSession.first_known_index(); } -export interface IIncomingRoomKey extends IRoomKey { - get isBetter(): boolean | undefined; - checkBetterThanKeyInStorage(loader: KeyLoader, txn: Transaction): Promise; - write(loader: KeyLoader, txn: Transaction): Promise; -} - -abstract class BaseIncomingRoomKey implements IIncomingRoomKey { +export abstract class IncomingRoomKey extends RoomKey { private _eventIds?: string[]; - private _isBetter?: boolean; checkBetterThanKeyInStorage(loader: KeyLoader, txn: Transaction): Promise { return this._checkBetterThanKeyInStorage(loader, undefined, txn); @@ -51,7 +54,7 @@ abstract class BaseIncomingRoomKey implements IIncomingRoomKey { async write(loader: KeyLoader, txn: Transaction): Promise { // we checked already and we had a better session in storage, so don't write let pickledSession; - if (this._isBetter === undefined) { + if (this.isBetter === undefined) { // if this key wasn't used to decrypt any messages in the same sync, // we haven't checked if this is the best key yet, // so do that now to not overwrite a better key. @@ -60,7 +63,7 @@ abstract class BaseIncomingRoomKey implements IIncomingRoomKey { pickledSession = session.pickle(pickleKey); }, txn); } - if (this._isBetter === false) { + if (this.isBetter === false) { return false; } // before calling write in parallel, we need to check loader.running is false so we are sure our transaction will not be closed @@ -79,11 +82,10 @@ abstract class BaseIncomingRoomKey implements IIncomingRoomKey { } get eventIds() { return this._eventIds; } - get isBetter() { return this._isBetter; } private async _checkBetterThanKeyInStorage(loader: KeyLoader, callback: (((session: OlmInboundGroupSession, pickleKey: string) => void) | undefined), txn: Transaction): Promise { - if (this._isBetter !== undefined) { - return this._isBetter; + if (this.isBetter !== undefined) { + return this.isBetter; } let existingKey = loader.getCachedKey(this.roomId, this.senderKey, this.sessionId); if (!existingKey) { @@ -100,32 +102,26 @@ abstract class BaseIncomingRoomKey implements IIncomingRoomKey { } if (existingKey) { const key = existingKey; - this._isBetter = await loader.useKey(this, newSession => { - return loader.useKey(key, (existingSession, pickleKey) => { - const isBetter = isBetterThan(newSession, existingSession); - if (isBetter && callback) { + await loader.useKey(this, async newSession => { + await loader.useKey(key, (existingSession, pickleKey) => { + // set isBetter as soon as possible, on both keys compared, + // as it is is used to determine whether a key can be used for the cache + this.isBetter = isBetterThan(newSession, existingSession); + key.isBetter = !this.isBetter; + if (this.isBetter && callback) { callback(newSession, pickleKey); } - return isBetter; }); }); } else { // no previous key, so we're the best \o/ - this._isBetter = true; + this.isBetter = true; } - return this._isBetter!; + return this.isBetter!; } - - abstract get roomId(): string; - abstract get senderKey(): string; - abstract get sessionId(): string; - abstract get claimedEd25519Key(): string; - abstract get serializationKey(): string; - abstract get serializationType(): string; - abstract loadInto(session: OlmInboundGroupSession, pickleKey: string): void; } -class DeviceMessageRoomKey extends BaseIncomingRoomKey { +class DeviceMessageRoomKey extends IncomingRoomKey { private _decryptionResult: DecryptionResult; constructor(decryptionResult: DecryptionResult) { @@ -145,7 +141,7 @@ class DeviceMessageRoomKey extends BaseIncomingRoomKey { } } -class BackupRoomKey extends BaseIncomingRoomKey { +class BackupRoomKey extends IncomingRoomKey { private _roomId: string; private _sessionId: string; private _backupInfo: string; @@ -169,10 +165,12 @@ class BackupRoomKey extends BaseIncomingRoomKey { } } -class StoredRoomKey implements IRoomKey { +class StoredRoomKey extends RoomKey { private storageEntry: InboundGroupSessionEntry; constructor(storageEntry: InboundGroupSessionEntry) { + super(); + this.isBetter = true; // usually the key in storage is the best until checks prove otherwise this.storageEntry = storageEntry; } @@ -192,7 +190,7 @@ class StoredRoomKey implements IRoomKey { // sessions are stored before they are received // to keep track of events that need it to be decrypted. // This is used to retry decryption of those events once the session is received. - return !!this.storageEntry.session; + return !!this.serializationKey; } } diff --git a/src/matrix/e2ee/megolm/decryption/SessionDecryption.ts b/src/matrix/e2ee/megolm/decryption/SessionDecryption.ts index 3adf5bdb..7e466806 100644 --- a/src/matrix/e2ee/megolm/decryption/SessionDecryption.ts +++ b/src/matrix/e2ee/megolm/decryption/SessionDecryption.ts @@ -17,7 +17,7 @@ limitations under the License. import {DecryptionResult} from "../../DecryptionResult.js"; import {DecryptionError} from "../../common.js"; import {ReplayDetectionEntry} from "./ReplayDetectionEntry"; -import type {IRoomKey} from "./RoomKey.js"; +import type {RoomKey} from "./RoomKey.js"; import type {KeyLoader, OlmDecryptionResult} from "./KeyLoader"; import type {OlmWorker} from "../../OlmWorker"; import type {TimelineEvent} from "../../../storage/types"; @@ -31,13 +31,13 @@ interface DecryptAllResult { * Does the actual decryption of all events for a given megolm session in a batch */ export class SessionDecryption { - private key: IRoomKey; + private key: RoomKey; private events: TimelineEvent[]; private keyLoader: KeyLoader; private olmWorker?: OlmWorker; private decryptionRequests?: any[]; - constructor(key: IRoomKey, events: TimelineEvent[], olmWorker: OlmWorker | undefined, keyLoader: KeyLoader) { + constructor(key: RoomKey, events: TimelineEvent[], olmWorker: OlmWorker | undefined, keyLoader: KeyLoader) { this.key = key; this.events = events; this.olmWorker = olmWorker;