From 720585b8f243f2108d7c0622240a5e58cbf3ff3d Mon Sep 17 00:00:00 2001 From: RMidhunSuresh Date: Mon, 13 Mar 2023 21:17:22 +0530 Subject: [PATCH] Write unit tests --- src/matrix/verification/CrossSigning.ts | 3 +- .../verification/SAS/SASVerification.ts | 455 +++++++++++++++++- .../verification/SAS/channel/MockChannel.ts | 133 +++++ .../SAS/stages/BaseSASVerificationStage.ts | 6 +- .../stages/SelectVerificationMethodStage.ts | 4 +- .../SAS/stages/SendAcceptVerificationStage.ts | 1 - .../verification/SAS/stages/VerifyMacStage.ts | 7 +- src/matrix/verification/SAS/types.ts | 20 + 8 files changed, 608 insertions(+), 21 deletions(-) create mode 100644 src/matrix/verification/SAS/channel/MockChannel.ts create mode 100644 src/matrix/verification/SAS/types.ts diff --git a/src/matrix/verification/CrossSigning.ts b/src/matrix/verification/CrossSigning.ts index 2e4ebd26..fcb0e1c7 100644 --- a/src/matrix/verification/CrossSigning.ts +++ b/src/matrix/verification/CrossSigning.ts @@ -21,7 +21,6 @@ import type {DeviceTracker} from "../e2ee/DeviceTracker"; import type * as OlmNamespace from "@matrix-org/olm"; import type {HomeServerApi} from "../net/HomeServerApi"; import type {Account} from "../e2ee/Account"; -import type {Room} from "../room/Room.js"; import { ILogItem } from "../../lib"; import {pkSign} from "./common"; import type {ISignatures} from "./common"; @@ -166,7 +165,7 @@ export class CrossSigning { e2eeAccount: this.e2eeAccount, deviceTracker: this.deviceTracker, hsApi: this.hsApi, - platform: this.platform, + clock: this.platform.clock, }); return this.sasVerificationInProgress; } diff --git a/src/matrix/verification/SAS/SASVerification.ts b/src/matrix/verification/SAS/SASVerification.ts index cab57283..ffec1f90 100644 --- a/src/matrix/verification/SAS/SASVerification.ts +++ b/src/matrix/verification/SAS/SASVerification.ts @@ -26,7 +26,9 @@ import {SendReadyStage} from "./stages/SendReadyStage"; import {SelectVerificationMethodStage} from "./stages/SelectVerificationMethodStage"; import {VerificationCancelledError} from "./VerificationCancelledError"; import {Timeout} from "../../../platform/types/types"; -import {Platform} from "../../../platform/web/Platform.js"; +import {Clock} from "../../../platform/web/dom/Clock.js"; +import {EventEmitter} from "../../../utils/EventEmitter"; +import {SASProgressEvents} from "./types"; type Olm = typeof OlmNamespace; @@ -40,7 +42,7 @@ type Options = { e2eeAccount: Account; deviceTracker: DeviceTracker; hsApi: HomeServerApi; - platform: Platform; + clock: Clock; } export class SASVerification { @@ -49,19 +51,20 @@ export class SASVerification { public finished: boolean = false; public readonly channel: IChannel; private readonly timeout: Timeout; + public readonly eventEmitter: EventEmitter = new EventEmitter(); constructor(options: Options) { - const { olm, channel, platform } = options; + const { olm, channel, clock } = options; const olmSas = new olm.SAS(); this.olmSas = olmSas; this.channel = channel; - this.timeout = platform.clock.createTimeout(10 * 60 * 1000); + this.timeout = clock.createTimeout(10 * 60 * 1000); this.timeout.elapsed().then(() => { - // Cancel verification after 10 minutes - // todo: catch error here? - channel.cancelVerification(CancelTypes.TimedOut); - }); - const stageOptions = {...options, olmSas}; + return channel.cancelVerification(CancelTypes.TimedOut); + }).catch(() => { + // todo: why do we do nothing here? + }); + const stageOptions = {...options, olmSas, eventEmitter: this.eventEmitter}; if (channel.receivedMessages.get(VerificationEventTypes.Start)) { this.startStage = new SelectVerificationMethodStage(stageOptions); } @@ -78,7 +81,7 @@ export class SASVerification { try { let stage = this.startStage; do { - console.log("Running next stage"); + console.log("Running stage", stage.constructor.name); await stage.completeStage(); stage = stage.nextStage; } while (stage); @@ -87,12 +90,440 @@ export class SASVerification { if (!(e instanceof VerificationCancelledError)) { throw e; } - console.log("Caught error in start()"); } finally { this.olmSas.free(); - this.finished = true; this.timeout.abort(); + this.finished = true; } } } + +import {HomeServer} from "../../../mocks/HomeServer.js"; +import Olm from "@matrix-org/olm/olm.js"; +import {MockChannel} from "./channel/MockChannel"; +import {Clock as MockClock} from "../../../mocks/Clock.js"; +import {NullLogger} from "../../../logging/NullLogger"; +import {SASFixtures} from "../../../fixtures/matrix/sas/events"; +import {SendKeyStage} from "./stages/SendKeyStage"; +import {CalculateSASStage} from "./stages/CalculateSASStage"; +import {SendMacStage} from "./stages/SendMacStage"; +import {VerifyMacStage} from "./stages/VerifyMacStage"; +import {SendDoneStage} from "./stages/SendDoneStage"; +import {SendAcceptVerificationStage} from "./stages/SendAcceptVerificationStage"; + +export function tests() { + + async function createSASRequest( + ourUserId: string, + ourDeviceId: string, + theirUserId: string, + theirDeviceId: string, + txnId: string, + receivedMessages, + startingMessage?: any + ) { + const homeserverMock = new HomeServer(); + const hsApi = homeserverMock.api; + const olm = Olm; + await olm.init(); + const olmUtil = new Olm.Utility(); + const e2eeAccount = { + getDeviceKeysToSignWithCrossSigning: () => { + return { + keys: { + [`ed25519:${ourDeviceId}`]: + "srsWWbrnQFIOmUSdrt3cS/unm03qAIgXcWwQg9BegKs", + }, + }; + }, + }; + const deviceTracker = { + getCrossSigningKeysForUser: (userId, _hsApi, _) => { + let masterKey = + userId === ourUserId + ? "5HIrEawRiiQioViNfezPDWfPWH2pdaw3pbQNHEVN2jM" + : "Ot8Y58PueQ7hJVpYWAJkg2qaREJAY/UhGZYOrsd52oo"; + return { masterKey }; + }, + deviceForId: (_userId, _deviceId, _hsApi, _log) => { + return { + ed25519Key: "D8w9mrokGdEZPdPgrU0kQkYi4vZyzKEBfvGyZsGK7+Q", + }; + }, + }; + const channel = new MockChannel( + theirDeviceId, + theirUserId, + ourDeviceId, + ourUserId, + receivedMessages, + deviceTracker, + txnId, + olm, + startingMessage, + ); + const clock = new MockClock(); + const logger = new NullLogger(); + return logger.run("log", (log) => { + // @ts-ignore + const sas = new SASVerification({ + channel, + clock, + hsApi, + deviceTracker, + e2eeAccount, + olm, + olmUtil, + otherUserId: theirUserId!, + ourUser: { deviceId: ourDeviceId!, userId: ourUserId! }, + log, + }); + // @ts-ignore + channel.setOlmSas(sas.olmSas); + return { sas, clock, logger }; + }); + } + + return { + "Order of stages created matches expected order when I sent request, they sent start": async (assert) => { + const ourDeviceId = "ILQHOACESQ"; + const ourUserId = "@foobaraccount:matrix.org"; + const theirUserId = "@foobaraccount3:matrix.org"; + const theirDeviceId = "FWKXUYUHTF"; + const txnId = "t150836b91a7bed"; + const receivedMessages = new SASFixtures(theirUserId, theirDeviceId, txnId) + .youSentRequest() + .theySentStart() + .fixtures(); + const { sas } = await createSASRequest( + ourUserId, + ourDeviceId, + theirUserId, + theirDeviceId, + txnId, + receivedMessages + ); + await sas.start(); + const expectedOrder = [ + RequestVerificationStage, + SelectVerificationMethodStage, + SendAcceptVerificationStage, + SendKeyStage, + CalculateSASStage, + SendMacStage, + VerifyMacStage, + SendDoneStage + ] + //@ts-ignore + let stage = sas.startStage; + for (const stageClass of expectedOrder) { + assert.strictEqual(stage instanceof stageClass, true); + stage = stage.nextStage; + } + assert.strictEqual(sas.finished, true); + }, + "Order of stages created matches expected order when I sent request, I sent start": async (assert) => { + const ourDeviceId = "ILQHOACESQ"; + const ourUserId = "@foobaraccount:matrix.org"; + const theirUserId = "@foobaraccount3:matrix.org"; + const theirDeviceId = "FWKXUYUHTF"; + const txnId = "t150836b91a7bed"; + const receivedMessages = new SASFixtures(theirUserId, theirDeviceId, txnId) + .youSentRequest() + .youSentStart() + .fixtures(); + const { sas, logger } = await createSASRequest( + ourUserId, + ourDeviceId, + theirUserId, + theirDeviceId, + txnId, + receivedMessages + ); + sas.eventEmitter.on("SelectVerificationStage", (stage) => { + logger.run("send start", async (log) => { + await stage?.selectEmojiMethod(log); + }); + }); + await sas.start(); + const expectedOrder = [ + RequestVerificationStage, + SelectVerificationMethodStage, + SendKeyStage, + CalculateSASStage, + SendMacStage, + VerifyMacStage, + SendDoneStage + ] + //@ts-ignore + let stage = sas.startStage; + for (const stageClass of expectedOrder) { + assert.strictEqual(stage instanceof stageClass, true); + stage = stage.nextStage; + } + assert.strictEqual(sas.finished, true); + }, + "Order of stages created matches expected order when request is received": async (assert) => { + const ourDeviceId = "ILQHOACESQ"; + const ourUserId = "@foobaraccount:matrix.org"; + const theirUserId = "@foobaraccount3:matrix.org"; + const theirDeviceId = "FWKXUYUHTF"; + const txnId = "t150836b91a7bed"; + const receivedMessages = new SASFixtures(theirUserId, theirDeviceId, txnId) + .theySentStart() + .fixtures(); + const startingMessage = receivedMessages.get(VerificationEventTypes.Start); + const { sas } = await createSASRequest( + ourUserId, + ourDeviceId, + theirUserId, + theirDeviceId, + txnId, + receivedMessages, + startingMessage, + ); + await sas.start(); + const expectedOrder = [ + SelectVerificationMethodStage, + SendAcceptVerificationStage, + SendKeyStage, + CalculateSASStage, + SendMacStage, + VerifyMacStage, + SendDoneStage + ] + //@ts-ignore + let stage = sas.startStage; + for (const stageClass of expectedOrder) { + assert.strictEqual(stage instanceof stageClass, true); + stage = stage.nextStage; + } + assert.strictEqual(sas.finished, true); + }, + "Order of stages created matches expected order when request is sent with start conflict (they win)": async (assert) => { + const ourDeviceId = "ILQHOACESQ"; + const ourUserId = "@foobaraccount:matrix.org"; + const theirUserId = "@foobaraccount3:matrix.org"; + const theirDeviceId = "FWKXUYUHTF"; + const txnId = "t150836b91a7bed"; + const receivedMessages = new SASFixtures(theirUserId, theirDeviceId, txnId) + .youSentRequest() + .theySentStart() + .youSentStart() + .theyWinConflict() + .fixtures(); + const { sas } = await createSASRequest( + ourUserId, + ourDeviceId, + theirUserId, + theirDeviceId, + txnId, + receivedMessages + ); + await sas.start(); + const expectedOrder = [ + RequestVerificationStage, + SelectVerificationMethodStage, + SendAcceptVerificationStage, + SendKeyStage, + CalculateSASStage, + SendMacStage, + VerifyMacStage, + SendDoneStage + ] + //@ts-ignore + let stage = sas.startStage; + for (const stageClass of expectedOrder) { + assert.strictEqual(stage instanceof stageClass, true); + stage = stage.nextStage; + } + assert.strictEqual(sas.finished, true); + }, + "Order of stages created matches expected order when request is sent with start conflict (I win)": async (assert) => { + const ourDeviceId = "ILQHOACESQ"; + const ourUserId = "@foobaraccount3:matrix.org"; + const theirUserId = "@foobaraccount:matrix.org"; + const theirDeviceId = "FWKXUYUHTF"; + const txnId = "t150836b91a7bed"; + const receivedMessages = new SASFixtures(theirUserId, theirDeviceId, txnId) + .youSentRequest() + .theySentStart() + .youSentStart() + .youWinConflict() + .fixtures(); + const { sas, logger } = await createSASRequest( + ourUserId, + ourDeviceId, + theirUserId, + theirDeviceId, + txnId, + receivedMessages + ); + sas.eventEmitter.on("SelectVerificationStage", (stage) => { + logger.run("send start", async (log) => { + await stage?.selectEmojiMethod(log); + }); + }); + await sas.start(); + const expectedOrder = [ + RequestVerificationStage, + SelectVerificationMethodStage, + SendKeyStage, + CalculateSASStage, + SendMacStage, + VerifyMacStage, + SendDoneStage + ] + //@ts-ignore + let stage = sas.startStage; + for (const stageClass of expectedOrder) { + assert.strictEqual(stage instanceof stageClass, true); + stage = stage.nextStage; + } + assert.strictEqual(sas.finished, true); + }, + "Order of stages created matches expected order when request is received with start conflict (they win)": async (assert) => { + const ourDeviceId = "ILQHOACESQ"; + const ourUserId = "@foobaraccount:matrix.org"; + const theirUserId = "@foobaraccount3:matrix.org"; + const theirDeviceId = "FWKXUYUHTF"; + const txnId = "t150836b91a7bed"; + const receivedMessages = new SASFixtures(theirUserId, theirDeviceId, txnId) + .theySentStart() + .youSentStart() + .theyWinConflict() + .fixtures(); + const startingMessage = receivedMessages.get(VerificationEventTypes.Start); + console.log(receivedMessages); + const { sas } = await createSASRequest( + ourUserId, + ourDeviceId, + theirUserId, + theirDeviceId, + txnId, + receivedMessages, + startingMessage, + ); + await sas.start(); + const expectedOrder = [ + SelectVerificationMethodStage, + SendAcceptVerificationStage, + SendKeyStage, + CalculateSASStage, + SendMacStage, + VerifyMacStage, + SendDoneStage + ] + //@ts-ignore + let stage = sas.startStage; + for (const stageClass of expectedOrder) { + console.log("Checking", stageClass.constructor.name, stage.constructor.name); + assert.strictEqual(stage instanceof stageClass, true); + stage = stage.nextStage; + } + assert.strictEqual(sas.finished, true); + }, + "Order of stages created matches expected order when request is received with start conflict (I win)": async (assert) => { + const ourDeviceId = "ILQHOACESQ"; + const ourUserId = "@foobaraccount3:matrix.org"; + const theirUserId = "@foobaraccount:matrix.org"; + const theirDeviceId = "FWKXUYUHTF"; + const txnId = "t150836b91a7bed"; + const receivedMessages = new SASFixtures(theirUserId, theirDeviceId, txnId) + .theySentStart() + .youSentStart() + .youWinConflict() + .fixtures(); + const startingMessage = receivedMessages.get(VerificationEventTypes.Start); + console.log(receivedMessages); + const { sas, logger } = await createSASRequest( + ourUserId, + ourDeviceId, + theirUserId, + theirDeviceId, + txnId, + receivedMessages, + startingMessage, + ); + sas.eventEmitter.on("SelectVerificationStage", (stage) => { + logger.run("send start", async (log) => { + await stage?.selectEmojiMethod(log); + }); + }); + await sas.start(); + const expectedOrder = [ + SelectVerificationMethodStage, + SendKeyStage, + CalculateSASStage, + SendMacStage, + VerifyMacStage, + SendDoneStage + ] + //@ts-ignore + let stage = sas.startStage; + for (const stageClass of expectedOrder) { + console.log("Checking", stageClass.constructor.name, stage.constructor.name); + assert.strictEqual(stage instanceof stageClass, true); + stage = stage.nextStage; + } + assert.strictEqual(sas.finished, true); + }, + "Verification is cancelled after 10 minutes": async (assert) => { + const ourDeviceId = "ILQHOACESQ"; + const ourUserId = "@foobaraccount:matrix.org"; + const theirUserId = "@foobaraccount3:matrix.org"; + const theirDeviceId = "FWKXUYUHTF"; + const txnId = "t150836b91a7bed"; + const receivedMessages = new SASFixtures(theirUserId, theirDeviceId, txnId) + .youSentRequest() + .theySentStart() + .fixtures(); + console.log("receivedMessages", receivedMessages); + const { sas, clock } = await createSASRequest( + ourUserId, + ourDeviceId, + theirUserId, + theirDeviceId, + txnId, + receivedMessages + ); + const promise = sas.start(); + clock.elapse(10 * 60 * 1000); + try { + await promise; + } + catch (e) { + assert.strictEqual(e instanceof VerificationCancelledError, true); + } + assert.strictEqual(sas.finished, true); + }, + "Verification is cancelled when there's no common hash algorithm": async (assert) => { + const ourDeviceId = "ILQHOACESQ"; + const ourUserId = "@foobaraccount:matrix.org"; + const theirUserId = "@foobaraccount3:matrix.org"; + const theirDeviceId = "FWKXUYUHTF"; + const txnId = "t150836b91a7bed"; + const receivedMessages = new SASFixtures(theirUserId, theirDeviceId, txnId) + .youSentRequest() + .theySentStart() + .fixtures(); + receivedMessages.get(VerificationEventTypes.Start).content.key_agreement_protocols = ["foo"]; + const { sas } = await createSASRequest( + ourUserId, + ourDeviceId, + theirUserId, + theirDeviceId, + txnId, + receivedMessages + ); + try { + await sas.start() + } + catch (e) { + assert.strictEqual(e instanceof VerificationCancelledError, true); + } + assert.strictEqual(sas.finished, true); + }, + } +} diff --git a/src/matrix/verification/SAS/channel/MockChannel.ts b/src/matrix/verification/SAS/channel/MockChannel.ts new file mode 100644 index 00000000..7f4a766a --- /dev/null +++ b/src/matrix/verification/SAS/channel/MockChannel.ts @@ -0,0 +1,133 @@ +import type {ILogItem} from "../../../../lib"; +import {createCalculateMAC} from "../mac"; +import {VerificationCancelledError} from "../VerificationCancelledError"; +import {IChannel} from "./Channel"; +import {CancelTypes, VerificationEventTypes} from "./types"; +import anotherjson from "another-json"; + +interface ITestChannel extends IChannel { + setOlmSas(olmSas): void; +} + +export class MockChannel implements ITestChannel { + public sentMessages: Map = new Map(); + public receivedMessages: Map = new Map(); + public initiatedByUs: boolean; + public startMessage: any; + public isCancelled: boolean = false; + private olmSas: any; + + constructor( + public otherUserDeviceId: string, + public otherUserId: string, + public ourUserDeviceId: string, + public ourUserId: string, + private fixtures: Map, + private deviceTracker: any, + public id: string, + private olm: any, + startingMessage?: any, + ) { + if (startingMessage) { + const eventType = startingMessage.content.method ? VerificationEventTypes.Start : VerificationEventTypes.Request; + this.id = startingMessage.content.transaction_id; + this.receivedMessages.set(eventType, startingMessage); + } + } + + async send(eventType: string, content: any, _: ILogItem) { + if (this.isCancelled) { + throw new VerificationCancelledError(); + } + Object.assign(content, { transaction_id: this.id }); + this.sentMessages.set(eventType, {content}); + } + + async waitForEvent(eventType: string): Promise { + if (this.isCancelled) { + throw new VerificationCancelledError(); + } + const event = this.fixtures.get(eventType); + if (event) { + this.receivedMessages.set(eventType, event); + } + else { + await new Promise(() => {}); + } + if (eventType === VerificationEventTypes.Mac) { + await this.recalculateMAC(); + } + if(eventType === VerificationEventTypes.Accept && this.startMessage) { + } + return event; + } + + private recalculateCommitment() { + const acceptMessage = this.getEvent(VerificationEventTypes.Accept)?.content; + if (!acceptMessage) { + return; + } + const {content} = this.startMessage; + const {content: keyMessage} = this.fixtures.get(VerificationEventTypes.Key); + const key = keyMessage.key; + const commitmentStr = key + anotherjson.stringify(content); + const olmUtil = new this.olm.Utility(); + const commitment = olmUtil.sha256(commitmentStr); + olmUtil.free(); + acceptMessage.commitment = commitment; + } + + private async recalculateMAC() { + // We need to replace the mac with calculated mac + const baseInfo = + "MATRIX_KEY_VERIFICATION_MAC" + + this.otherUserId + + this.otherUserDeviceId + + this.ourUserId + + this.ourUserDeviceId + + this.id; + const { content: macContent } = this.receivedMessages.get(VerificationEventTypes.Mac); + const macMethod = this.getEvent(VerificationEventTypes.Accept).content.message_authentication_code; + const calculateMac = createCalculateMAC(this.olmSas, macMethod); + const input = Object.keys(macContent.mac).sort().join(","); + const properMac = calculateMac(input, baseInfo + "KEY_IDS"); + macContent.keys = properMac; + for (const keyId of Object.keys(macContent.mac)) { + const deviceId = keyId.split(":", 2)[1]; + const device = await this.deviceTracker.deviceForId(this.otherUserDeviceId, deviceId); + if (device) { + macContent.mac[keyId] = calculateMac(device.ed25519Key, baseInfo + keyId); + } + else { + const {masterKey} = await this.deviceTracker.getCrossSigningKeysForUser(this.otherUserId); + macContent.mac[keyId] = calculateMac(masterKey, baseInfo + keyId); + } + } + } + + setStartMessage(event: any): void { + this.startMessage = event; + this.recalculateCommitment(); + } + + setInitiatedByUs(value: boolean): void { + this.initiatedByUs = value; + } + + async cancelVerification(_: CancelTypes): Promise { + console.log("MockChannel.cancelVerification()"); + this.isCancelled = true; + } + + getEvent(eventType: VerificationEventTypes.Accept): any { + return this.receivedMessages.get(eventType) ?? this.sentMessages.get(eventType); + } + + setOlmSas(olmSas: any): void { + this.olmSas = olmSas; + } + + get type() { + return 0; + } +} diff --git a/src/matrix/verification/SAS/stages/BaseSASVerificationStage.ts b/src/matrix/verification/SAS/stages/BaseSASVerificationStage.ts index a61a2ce9..a923376b 100644 --- a/src/matrix/verification/SAS/stages/BaseSASVerificationStage.ts +++ b/src/matrix/verification/SAS/stages/BaseSASVerificationStage.ts @@ -14,13 +14,14 @@ See the License for the specific language governing permissions and limitations under the License. */ import type {ILogItem} from "../../../../lib.js"; -import type {Room} from "../../../room/Room.js"; import type * as OlmNamespace from "@matrix-org/olm"; import type {Account} from "../../../e2ee/Account.js"; import type {DeviceTracker} from "../../../e2ee/DeviceTracker.js"; import {Disposables} from "../../../../utils/Disposables"; import {IChannel} from "../channel/Channel.js"; import {HomeServerApi} from "../../../net/HomeServerApi.js"; +import {SASProgressEvents} from "../types.js"; +import {EventEmitter} from "../../../../utils/EventEmitter"; type Olm = typeof OlmNamespace; @@ -39,6 +40,7 @@ export type Options = { e2eeAccount: Account; deviceTracker: DeviceTracker; hsApi: HomeServerApi; + eventEmitter: EventEmitter } export abstract class BaseSASVerificationStage extends Disposables { @@ -55,6 +57,7 @@ export abstract class BaseSASVerificationStage extends Disposables { protected e2eeAccount: Account; protected deviceTracker: DeviceTracker; protected hsApi: HomeServerApi; + protected eventEmitter: EventEmitter; constructor(options: Options) { super(); @@ -68,6 +71,7 @@ export abstract class BaseSASVerificationStage extends Disposables { this.e2eeAccount = options.e2eeAccount; this.deviceTracker = options.deviceTracker; this.hsApi = options.hsApi; + this.eventEmitter = options.eventEmitter; } setNextStage(stage: BaseSASVerificationStage) { diff --git a/src/matrix/verification/SAS/stages/SelectVerificationMethodStage.ts b/src/matrix/verification/SAS/stages/SelectVerificationMethodStage.ts index af95d70e..a6cef8b9 100644 --- a/src/matrix/verification/SAS/stages/SelectVerificationMethodStage.ts +++ b/src/matrix/verification/SAS/stages/SelectVerificationMethodStage.ts @@ -27,7 +27,7 @@ export class SelectVerificationMethodStage extends BaseSASVerificationStage { async completeStage() { await this.log.wrap("SelectVerificationMethodStage.completeStage", async (log) => { - (window as any).select = () => this.selectEmojiMethod(log); + this.eventEmitter.emit("SelectVerificationStage", this); const startMessage = this.channel.waitForEvent(VerificationEventTypes.Start); const acceptMessage = this.channel.waitForEvent(VerificationEventTypes.Accept); const { content } = await Promise.race([startMessage, acceptMessage]); @@ -59,7 +59,7 @@ export class SelectVerificationMethodStage extends BaseSASVerificationStage { }); } - async resolveStartConflict() { + private async resolveStartConflict() { const receivedStartMessage = this.channel.receivedMessages.get(VerificationEventTypes.Start); const sentStartMessage = this.channel.sentMessages.get(VerificationEventTypes.Start); if (receivedStartMessage.content.method !== sentStartMessage.content.method) { diff --git a/src/matrix/verification/SAS/stages/SendAcceptVerificationStage.ts b/src/matrix/verification/SAS/stages/SendAcceptVerificationStage.ts index b9112bbc..57ec1fbe 100644 --- a/src/matrix/verification/SAS/stages/SendAcceptVerificationStage.ts +++ b/src/matrix/verification/SAS/stages/SendAcceptVerificationStage.ts @@ -28,7 +28,6 @@ export class SendAcceptVerificationStage extends BaseSASVerificationStage { const macMethod = intersection(MAC_LIST, new Set(content.message_authentication_codes))[0]; const sasMethods = intersection(content.short_authentication_string, SAS_SET); if (!(keyAgreement !== undefined && hashMethod !== undefined && macMethod !== undefined && sasMethods.length)) { - // todo: ensure this cancels the verification await this.channel.cancelVerification(CancelTypes.UnknownMethod); return; } diff --git a/src/matrix/verification/SAS/stages/VerifyMacStage.ts b/src/matrix/verification/SAS/stages/VerifyMacStage.ts index 2eb37018..441d1bc4 100644 --- a/src/matrix/verification/SAS/stages/VerifyMacStage.ts +++ b/src/matrix/verification/SAS/stages/VerifyMacStage.ts @@ -54,14 +54,15 @@ export class VerifyMacStage extends BaseSASVerificationStage { this.ourUser.deviceId + this.channel.id; - if ( content.keys !== this.calculateMAC(Object.keys(content.mac).sort().join(","), baseInfo + "KEY_IDS")) { - // cancel when MAC does not match! + const calculatedMAC = this.calculateMAC(Object.keys(content.mac).sort().join(","), baseInfo + "KEY_IDS"); + if (content.keys !== calculatedMAC) { + // todo: cancel when MAC does not match! console.log("Keys MAC Verification failed"); } await this.verifyKeys(content.mac, (keyId, key, keyInfo) => { if (keyInfo !== this.calculateMAC(key, baseInfo + keyId)) { - // cancel when MAC does not match! + // todo: cancel when MAC does not match! console.log("mac obj MAC Verification failed"); } }, log); diff --git a/src/matrix/verification/SAS/types.ts b/src/matrix/verification/SAS/types.ts new file mode 100644 index 00000000..52e7c97a --- /dev/null +++ b/src/matrix/verification/SAS/types.ts @@ -0,0 +1,20 @@ +/* +Copyright 2023 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +import {SelectVerificationMethodStage} from "./stages/SelectVerificationMethodStage"; + +export type SASProgressEvents = { + SelectVerificationStage: SelectVerificationMethodStage; +}