From 660db4ced3b0df25c22bc56130ee79bf78e4b53c Mon Sep 17 00:00:00 2001 From: RMidhunSuresh Date: Mon, 10 Apr 2023 19:44:44 +0530 Subject: [PATCH] Refactor to avoid passing crosssigning --- .../DeviceVerificationViewModel.ts | 2 +- src/matrix/verification/CrossSigning.ts | 23 ++++++-- .../verification/SAS/SASVerification.ts | 52 ++++++++----------- .../SAS/stages/BaseSASVerificationStage.ts | 2 - .../verification/SAS/stages/VerifyMacStage.ts | 14 ++--- 5 files changed, 47 insertions(+), 46 deletions(-) diff --git a/src/domain/session/verification/DeviceVerificationViewModel.ts b/src/domain/session/verification/DeviceVerificationViewModel.ts index 3257c784..35842c20 100644 --- a/src/domain/session/verification/DeviceVerificationViewModel.ts +++ b/src/domain/session/verification/DeviceVerificationViewModel.ts @@ -55,7 +55,7 @@ export class DeviceVerificationViewModel extends ErrorReportViewModel; + otherDeviceId: string; +} + export class CrossSigning { private readonly storage: Storage; private readonly secretStorage: SecretStorage; @@ -202,7 +207,6 @@ export class CrossSigning { deviceTracker: this.deviceTracker, hsApi: this.hsApi, clock: this.platform.clock, - crossSigning: this, }); return this.sasVerificationInProgress; } @@ -249,13 +253,19 @@ export class CrossSigning { } /** @return the signed device key for the given device id */ - async signDevice(deviceId: string, log: ILogItem): Promise { + async signDevice(verification: IVerificationMethod, log: ILogItem): Promise { return log.wrap("CrossSigning.signDevice", async log => { - log.set("id", deviceId); if (!this._isMasterKeyTrusted) { log.set("mskNotTrusted", true); return; } + const shouldSign = await verification.verify(); + log.set("shouldSign", shouldSign); + if (!shouldSign) { + return; + } + const deviceId = verification.otherDeviceId; + log.set("id", deviceId); const keyToSign = await this.deviceTracker.deviceForId(this.ownUserId, deviceId, this.hsApi, log); if (!keyToSign) { return undefined; @@ -266,7 +276,7 @@ export class CrossSigning { } /** @return the signed MSK for the given user id */ - async signUser(userId: string, log: ILogItem): Promise { + async signUser(userId: string, verification: IVerificationMethod, log: ILogItem): Promise { return log.wrap("CrossSigning.signUser", async log => { log.set("id", userId); if (!this._isMasterKeyTrusted) { @@ -277,6 +287,11 @@ export class CrossSigning { if (userId === this.ownUserId) { return; } + const shouldSign = await verification.verify(); + log.set("shouldSign", shouldSign); + if (!shouldSign) { + return; + } const keyToSign = await this.deviceTracker.getCrossSigningKeyForUser(userId, KeyUsage.Master, this.hsApi, log); if (!keyToSign) { return; diff --git a/src/matrix/verification/SAS/SASVerification.ts b/src/matrix/verification/SAS/SASVerification.ts index 57065218..b408dc25 100644 --- a/src/matrix/verification/SAS/SASVerification.ts +++ b/src/matrix/verification/SAS/SASVerification.ts @@ -23,13 +23,13 @@ import type {IChannel} from "./channel/IChannel"; import type {HomeServerApi} from "../../net/HomeServerApi"; import type {Timeout} from "../../../platform/types/types"; import type {Clock} from "../../../platform/web/dom/Clock.js"; +import type {IVerificationMethod} from "../CrossSigning"; import {CancelReason, VerificationEventType} from "./channel/types"; import {SendReadyStage} from "./stages/SendReadyStage"; import {SelectVerificationMethodStage} from "./stages/SelectVerificationMethodStage"; import {VerificationCancelledError} from "./VerificationCancelledError"; import {EventEmitter} from "../../../utils/EventEmitter"; import {SASProgressEvents} from "./types"; -import type {CrossSigning} from "../CrossSigning"; type Olm = typeof OlmNamespace; @@ -45,10 +45,9 @@ type Options = { deviceTracker: DeviceTracker; hsApi: HomeServerApi; clock: Clock; - crossSigning: CrossSigning } -export class SASVerification extends EventEmitter { +export class SASVerification extends EventEmitter implements IVerificationMethod { private startStage: BaseSASVerificationStage; private olmSas: Olm.SAS; public finished: boolean = false; @@ -74,7 +73,7 @@ export class SASVerification extends EventEmitter { } } - private async setupCancelAfterTimeout(clock: Clock) { + private async setupCancelAfterTimeout(clock: Clock): Promise { try { const tenMinutes = 10 * 60 * 1000; this.timeout = clock.createTimeout(tenMinutes); @@ -86,11 +85,12 @@ export class SASVerification extends EventEmitter { } } - async abort() { + async abort(): Promise { await this.channel.cancelVerification(CancelReason.UserCancelled); } - async start() { + async verify(): Promise { + let success = true; try { let stage = this.startStage; do { @@ -102,6 +102,7 @@ export class SASVerification extends EventEmitter { if (!(e instanceof VerificationCancelledError)) { throw e; } + success = false; } finally { if (this.channel.isCancelled) { @@ -111,6 +112,11 @@ export class SASVerification extends EventEmitter { this.timeout.abort(); this.finished = true; } + return success; + } + + get otherDeviceId(): string { + return this.channel.otherUserDeviceId; } } @@ -189,7 +195,6 @@ export function tests() { olm, startingMessage, ); - const crossSigning = new MockCrossSigning() as unknown as CrossSigning; const clock = new MockClock(); const logger = new NullLogger(); return logger.run("log", (log) => { @@ -207,7 +212,6 @@ export function tests() { ourUserId, ourUserDeviceId: ourDeviceId, log, - crossSigning }); // @ts-ignore channel.setOlmSas(sas.olmSas); @@ -218,16 +222,6 @@ export function tests() { }); } - class MockCrossSigning { - signDevice(deviceId: string, log: ILogItem) { - return Promise.resolve({}); // device keys, means signing succeeded - } - - signUser(userId: string, log: ILogItem) { - return Promise.resolve({}); // cross-signing keys, means signing succeeded - } - } - return { "Order of stages created matches expected order when I sent request, they sent start": async (assert) => { const ourDeviceId = "ILQHOACESQ"; @@ -247,7 +241,7 @@ export function tests() { txnId, receivedMessages ); - await sas.start(); + await sas.verify(); const expectedOrder = [ SendRequestVerificationStage, SelectVerificationMethodStage, @@ -289,7 +283,7 @@ export function tests() { await stage?.selectEmojiMethod(log); }); }); - await sas.start(); + await sas.verify(); const expectedOrder = [ SendRequestVerificationStage, SelectVerificationMethodStage, @@ -326,7 +320,7 @@ export function tests() { receivedMessages, startingMessage, ); - await sas.start(); + await sas.verify(); const expectedOrder = [ SelectVerificationMethodStage, SendAcceptVerificationStage, @@ -364,7 +358,7 @@ export function tests() { txnId, receivedMessages ); - await sas.start(); + await sas.verify(); const expectedOrder = [ SendRequestVerificationStage, SelectVerificationMethodStage, @@ -408,7 +402,7 @@ export function tests() { await stage?.selectEmojiMethod(log); }); }); - await sas.start(); + await sas.verify(); const expectedOrder = [ SendRequestVerificationStage, SelectVerificationMethodStage, @@ -448,7 +442,7 @@ export function tests() { receivedMessages, startingMessage, ); - await sas.start(); + await sas.verify(); const expectedOrder = [ SelectVerificationMethodStage, SendAcceptVerificationStage, @@ -494,7 +488,7 @@ export function tests() { await stage?.selectEmojiMethod(log); }); }); - await sas.start(); + await sas.verify(); const expectedOrder = [ SelectVerificationMethodStage, SendKeyStage, @@ -537,7 +531,7 @@ export function tests() { await stage?.selectEmojiMethod(log); }); }); - await sas.start(); + await sas.verify(); const expectedOrder = [ SendRequestVerificationStage, SelectVerificationMethodStage, @@ -575,7 +569,7 @@ export function tests() { txnId, receivedMessages ); - await sas.start(); + await sas.verify(); const expectedOrder = [ SendRequestVerificationStage, SelectVerificationMethodStage, @@ -613,7 +607,7 @@ export function tests() { txnId, receivedMessages ); - const promise = sas.start(); + const promise = sas.verify(); clock.elapse(10 * 60 * 1000); try { await promise; @@ -643,7 +637,7 @@ export function tests() { receivedMessages ); try { - await sas.start() + await sas.verify() } catch (e) { assert.strictEqual(e instanceof VerificationCancelledError, true); diff --git a/src/matrix/verification/SAS/stages/BaseSASVerificationStage.ts b/src/matrix/verification/SAS/stages/BaseSASVerificationStage.ts index 534a3ca8..ba6544ba 100644 --- a/src/matrix/verification/SAS/stages/BaseSASVerificationStage.ts +++ b/src/matrix/verification/SAS/stages/BaseSASVerificationStage.ts @@ -16,7 +16,6 @@ limitations under the License. import type {ILogItem} from "../../../../logging/types"; import type {Account} from "../../../e2ee/Account.js"; import type {DeviceTracker} from "../../../e2ee/DeviceTracker.js"; -import type {CrossSigning} from "../../CrossSigning"; import {IChannel} from "../channel/IChannel"; import {HomeServerApi} from "../../../net/HomeServerApi"; import {SASProgressEvents} from "../types"; @@ -34,7 +33,6 @@ export type Options = { deviceTracker: DeviceTracker; hsApi: HomeServerApi; eventEmitter: EventEmitter - crossSigning: CrossSigning } export abstract class BaseSASVerificationStage { diff --git a/src/matrix/verification/SAS/stages/VerifyMacStage.ts b/src/matrix/verification/SAS/stages/VerifyMacStage.ts index a1ef5515..c7785418 100644 --- a/src/matrix/verification/SAS/stages/VerifyMacStage.ts +++ b/src/matrix/verification/SAS/stages/VerifyMacStage.ts @@ -68,11 +68,8 @@ export class VerifyMacStage extends BaseSASVerificationStage { const deviceIdOrMSK = keyId.split(":", 2)[1]; const device = await this.deviceTracker.deviceForId(userId, deviceIdOrMSK, this.hsApi, log); if (device) { - if (verifier(keyId, getDeviceEd25519Key(device), keyInfo)) { - await log.wrap("signing device", async log => { - const signedKey = await this.options.crossSigning.signDevice(device.device_id, log); - log.set("success", !!signedKey); - }); + if (!verifier(keyId, getDeviceEd25519Key(device), keyInfo)) { + throw new Error(`MAC verification failed for key ${keyInfo}`); } } else { // If we were not able to find the device, then deviceIdOrMSK is actually the MSK! @@ -82,11 +79,8 @@ export class VerifyMacStage extends BaseSASVerificationStage { throw new Error("Fetching MSK for user failed!"); } const masterKey = getKeyEd25519Key(key); - if(masterKey && verifier(keyId, masterKey, keyInfo)) { - await log.wrap("signing user", async log => { - const signedKey = await this.options.crossSigning.signUser(userId, log); - log.set("success", !!signedKey); - }); + if(!(masterKey && verifier(keyId, masterKey, keyInfo))) { + throw new Error(`MAC verification failed for key ${keyInfo}`); } } }