diff --git a/src/matrix/e2ee/megolm/decryption/KeyLoader.ts b/src/matrix/e2ee/megolm/decryption/KeyLoader.ts index 3aca957d..c1925e84 100644 --- a/src/matrix/e2ee/megolm/decryption/KeyLoader.ts +++ b/src/matrix/e2ee/megolm/decryption/KeyLoader.ts @@ -17,25 +17,14 @@ limitations under the License. import {isBetterThan, IncomingRoomKey} from "./RoomKey"; import {BaseLRUCache} from "../../../../utils/LRUCache"; import type {RoomKey} from "./RoomKey"; +import type * as OlmNamespace from "@matrix-org/olm"; +type Olm = typeof OlmNamespace; export declare class OlmDecryptionResult { readonly plaintext: string; readonly message_index: number; } -export declare class OlmInboundGroupSession { - constructor(); - free(): void; - pickle(key: string | Uint8Array): string; - unpickle(key: string | Uint8Array, pickle: string); - create(session_key: string): string; - import_session(session_key: string): string; - decrypt(message: string): OlmDecryptionResult; - session_id(): string; - first_known_index(): number; - export_session(message_index: number): string; -} - /* Because Olm only has very limited memory available when compiled to wasm, we limit the amount of sessions held in memory. @@ -43,11 +32,11 @@ we limit the amount of sessions held in memory. export class KeyLoader extends BaseLRUCache { private pickleKey: string; - private olm: any; + private olm: Olm; private resolveUnusedOperation?: () => void; private operationBecomesUnusedPromise?: Promise; - constructor(olm: any, pickleKey: string, limit: number) { + constructor(olm: Olm, pickleKey: string, limit: number) { super(limit); this.pickleKey = pickleKey; this.olm = olm; @@ -60,7 +49,7 @@ export class KeyLoader extends BaseLRUCache { } } - async useKey(key: RoomKey, callback: (session: OlmInboundGroupSession, pickleKey: string) => Promise | T): Promise { + async useKey(key: RoomKey, callback: (session: Olm.InboundGroupSession, pickleKey: string) => Promise | T): Promise { const keyOp = await this.allocateOperation(key); try { return await callback(keyOp.session, this.pickleKey); @@ -186,11 +175,11 @@ export class KeyLoader extends BaseLRUCache { } class KeyOperation { - session: OlmInboundGroupSession; + session: Olm.InboundGroupSession; key: RoomKey; refCount: number; - constructor(key: RoomKey, session: OlmInboundGroupSession) { + constructor(key: RoomKey, session: Olm.InboundGroupSession) { this.key = key; this.session = session; this.refCount = 1; @@ -248,7 +237,7 @@ export function tests() { get serializationKey(): string { return `key-${this.sessionId}-${this._firstKnownIndex}`; } get serializationType(): string { return "type"; } get eventIds(): string[] | undefined { return undefined; } - loadInto(session: OlmInboundGroupSession) { + loadInto(session: Olm.InboundGroupSession) { const mockSession = session as MockInboundSession; mockSession.sessionId = this.sessionId; mockSession.firstKnownIndex = this._firstKnownIndex; @@ -284,7 +273,7 @@ export function tests() { return { "load key gives correct session": async assert => { - const loader = new KeyLoader(olm, PICKLE_KEY, 2); + const loader = new KeyLoader(olm as any as Olm, PICKLE_KEY, 2); let callback1Called = false; let callback2Called = false; const p1 = loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1), async session => { @@ -305,7 +294,7 @@ export function tests() { assert(callback2Called); }, "keys with different first index are kept separate": async assert => { - const loader = new KeyLoader(olm, PICKLE_KEY, 2); + const loader = new KeyLoader(olm as any as Olm, PICKLE_KEY, 2); let callback1Called = false; let callback2Called = false; const p1 = loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1), async session => { @@ -326,7 +315,7 @@ export function tests() { assert(callback2Called); }, "useKey blocks as long as no free sessions are available": async assert => { - const loader = new KeyLoader(olm, PICKLE_KEY, 1); + const loader = new KeyLoader(olm as any as Olm, PICKLE_KEY, 1); let resolve; let callbackCalled = false; loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1), async session => { @@ -343,7 +332,7 @@ export function tests() { assert.equal(callbackCalled, true); }, "cache hit while key in use, then replace (check refCount works properly)": async assert => { - const loader = new KeyLoader(olm, PICKLE_KEY, 1); + const loader = new KeyLoader(olm as any as Olm, PICKLE_KEY, 1); let resolve1, resolve2; const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1); const p1 = loader.useKey(key1, async session => { @@ -371,7 +360,7 @@ export function tests() { assert.equal(callbackCalled, true); }, "cache hit while key not in use": async assert => { - const loader = new KeyLoader(olm, PICKLE_KEY, 2); + const loader = new KeyLoader(olm as any as Olm, PICKLE_KEY, 2); let resolve1, resolve2, invocations = 0; const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1); await loader.useKey(key1, async session => { invocations += 1; }); @@ -385,7 +374,7 @@ export function tests() { }, "dispose calls free on all sessions": async assert => { instances = 0; - const loader = new KeyLoader(olm, PICKLE_KEY, 2); + const loader = new KeyLoader(olm as any as Olm, PICKLE_KEY, 2); await loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1), async session => {}); await loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId2, 1), async session => {}); assert.equal(instances, 2); @@ -395,7 +384,7 @@ export function tests() { assert.strictEqual(loader.size, 0, "loader.size"); }, "checkBetterThanKeyInStorage false with cache": async assert => { - const loader = new KeyLoader(olm, PICKLE_KEY, 2); + const loader = new KeyLoader(olm as any as Olm, PICKLE_KEY, 2); const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 2); await loader.useKey(key1, async session => {}); // fake we've checked with storage that this is the best key, @@ -409,7 +398,7 @@ export function tests() { assert.strictEqual(key2.isBetter, false); }, "checkBetterThanKeyInStorage true with cache": async assert => { - const loader = new KeyLoader(olm, PICKLE_KEY, 2); + const loader = new KeyLoader(olm as any as Olm, PICKLE_KEY, 2); const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 2); key1.isBetter = true; // fake we've check with storage so far (not including key2) this is the best key await loader.useKey(key1, async session => {}); @@ -420,7 +409,7 @@ export function tests() { assert.strictEqual(key2.isBetter, true); }, "prefer to remove worst key for a session from cache": async assert => { - const loader = new KeyLoader(olm, PICKLE_KEY, 2); + const loader = new KeyLoader(olm as any as Olm, PICKLE_KEY, 2); const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 2); await loader.useKey(key1, async session => {}); key1.isBetter = true; // set to true just so it gets returned from getCachedKey diff --git a/src/matrix/e2ee/megolm/decryption/RoomKey.ts b/src/matrix/e2ee/megolm/decryption/RoomKey.ts index 81f1a9be..2cb65b33 100644 --- a/src/matrix/e2ee/megolm/decryption/RoomKey.ts +++ b/src/matrix/e2ee/megolm/decryption/RoomKey.ts @@ -17,7 +17,9 @@ limitations under the License. import type {InboundGroupSessionEntry} from "../../../storage/idb/stores/InboundGroupSessionStore"; import type {Transaction} from "../../../storage/idb/Transaction"; import type {DecryptionResult} from "../../DecryptionResult"; -import type {KeyLoader, OlmInboundGroupSession} from "./KeyLoader"; +import type {KeyLoader} from "./KeyLoader"; +import type * as OlmNamespace from "@matrix-org/olm"; +type Olm = typeof OlmNamespace; export abstract class RoomKey { private _isBetter: boolean | undefined; @@ -33,7 +35,7 @@ export abstract class RoomKey { abstract get serializationKey(): string; abstract get serializationType(): string; abstract get eventIds(): string[] | undefined; - abstract loadInto(session: OlmInboundGroupSession, pickleKey: string): void; + abstract loadInto(session: Olm.InboundGroupSession, pickleKey: string): void; /* Whether the key has been checked against storage (or is from storage) * to be the better key for a given session. Given that all keys are checked to be better * as part of writing, we can trust that when this returns true, it really is the best key @@ -44,7 +46,7 @@ export abstract class RoomKey { set isBetter(value: boolean | undefined) { this._isBetter = value; } } -export function isBetterThan(newSession: OlmInboundGroupSession, existingSession: OlmInboundGroupSession) { +export function isBetterThan(newSession: Olm.InboundGroupSession, existingSession: Olm.InboundGroupSession) { return newSession.first_known_index() < existingSession.first_known_index(); } @@ -87,7 +89,7 @@ export abstract class IncomingRoomKey extends RoomKey { get eventIds() { return this._eventIds; } - private async _checkBetterThanKeyInStorage(loader: KeyLoader, callback: (((session: OlmInboundGroupSession, pickleKey: string) => void) | undefined), txn: Transaction): Promise { + private async _checkBetterThanKeyInStorage(loader: KeyLoader, callback: (((session: Olm.InboundGroupSession, pickleKey: string) => void) | undefined), txn: Transaction): Promise { if (this.isBetter !== undefined) { return this.isBetter; } diff --git a/src/matrix/e2ee/megolm/decryption/SessionDecryption.ts b/src/matrix/e2ee/megolm/decryption/SessionDecryption.ts index 7e466806..f56feb47 100644 --- a/src/matrix/e2ee/megolm/decryption/SessionDecryption.ts +++ b/src/matrix/e2ee/megolm/decryption/SessionDecryption.ts @@ -17,7 +17,7 @@ limitations under the License. import {DecryptionResult} from "../../DecryptionResult.js"; import {DecryptionError} from "../../common.js"; import {ReplayDetectionEntry} from "./ReplayDetectionEntry"; -import type {RoomKey} from "./RoomKey.js"; +import type {RoomKey} from "./RoomKey"; import type {KeyLoader, OlmDecryptionResult} from "./KeyLoader"; import type {OlmWorker} from "../../OlmWorker"; import type {TimelineEvent} from "../../../storage/types"; @@ -61,7 +61,7 @@ export class SessionDecryption { this.decryptionRequests!.push(request); decryptionResult = await request.response(); } else { - decryptionResult = session.decrypt(ciphertext); + decryptionResult = session.decrypt(ciphertext) as OlmDecryptionResult; } const {plaintext} = decryptionResult!; let payload;