Write unit tests

This commit is contained in:
RMidhunSuresh 2023-03-13 21:17:22 +05:30
parent 2e653d5f76
commit 720585b8f2
No known key found for this signature in database
8 changed files with 608 additions and 21 deletions

View File

@ -21,7 +21,6 @@ import type {DeviceTracker} from "../e2ee/DeviceTracker";
import type * as OlmNamespace from "@matrix-org/olm"; import type * as OlmNamespace from "@matrix-org/olm";
import type {HomeServerApi} from "../net/HomeServerApi"; import type {HomeServerApi} from "../net/HomeServerApi";
import type {Account} from "../e2ee/Account"; import type {Account} from "../e2ee/Account";
import type {Room} from "../room/Room.js";
import { ILogItem } from "../../lib"; import { ILogItem } from "../../lib";
import {pkSign} from "./common"; import {pkSign} from "./common";
import type {ISignatures} from "./common"; import type {ISignatures} from "./common";
@ -166,7 +165,7 @@ export class CrossSigning {
e2eeAccount: this.e2eeAccount, e2eeAccount: this.e2eeAccount,
deviceTracker: this.deviceTracker, deviceTracker: this.deviceTracker,
hsApi: this.hsApi, hsApi: this.hsApi,
platform: this.platform, clock: this.platform.clock,
}); });
return this.sasVerificationInProgress; return this.sasVerificationInProgress;
} }

View File

@ -26,7 +26,9 @@ import {SendReadyStage} from "./stages/SendReadyStage";
import {SelectVerificationMethodStage} from "./stages/SelectVerificationMethodStage"; import {SelectVerificationMethodStage} from "./stages/SelectVerificationMethodStage";
import {VerificationCancelledError} from "./VerificationCancelledError"; import {VerificationCancelledError} from "./VerificationCancelledError";
import {Timeout} from "../../../platform/types/types"; import {Timeout} from "../../../platform/types/types";
import {Platform} from "../../../platform/web/Platform.js"; import {Clock} from "../../../platform/web/dom/Clock.js";
import {EventEmitter} from "../../../utils/EventEmitter";
import {SASProgressEvents} from "./types";
type Olm = typeof OlmNamespace; type Olm = typeof OlmNamespace;
@ -40,7 +42,7 @@ type Options = {
e2eeAccount: Account; e2eeAccount: Account;
deviceTracker: DeviceTracker; deviceTracker: DeviceTracker;
hsApi: HomeServerApi; hsApi: HomeServerApi;
platform: Platform; clock: Clock;
} }
export class SASVerification { export class SASVerification {
@ -49,19 +51,20 @@ export class SASVerification {
public finished: boolean = false; public finished: boolean = false;
public readonly channel: IChannel; public readonly channel: IChannel;
private readonly timeout: Timeout; private readonly timeout: Timeout;
public readonly eventEmitter: EventEmitter<SASProgressEvents> = new EventEmitter();
constructor(options: Options) { constructor(options: Options) {
const { olm, channel, platform } = options; const { olm, channel, clock } = options;
const olmSas = new olm.SAS(); const olmSas = new olm.SAS();
this.olmSas = olmSas; this.olmSas = olmSas;
this.channel = channel; this.channel = channel;
this.timeout = platform.clock.createTimeout(10 * 60 * 1000); this.timeout = clock.createTimeout(10 * 60 * 1000);
this.timeout.elapsed().then(() => { this.timeout.elapsed().then(() => {
// Cancel verification after 10 minutes return channel.cancelVerification(CancelTypes.TimedOut);
// todo: catch error here? }).catch(() => {
channel.cancelVerification(CancelTypes.TimedOut); // todo: why do we do nothing here?
}); });
const stageOptions = {...options, olmSas}; const stageOptions = {...options, olmSas, eventEmitter: this.eventEmitter};
if (channel.receivedMessages.get(VerificationEventTypes.Start)) { if (channel.receivedMessages.get(VerificationEventTypes.Start)) {
this.startStage = new SelectVerificationMethodStage(stageOptions); this.startStage = new SelectVerificationMethodStage(stageOptions);
} }
@ -78,7 +81,7 @@ export class SASVerification {
try { try {
let stage = this.startStage; let stage = this.startStage;
do { do {
console.log("Running next stage"); console.log("Running stage", stage.constructor.name);
await stage.completeStage(); await stage.completeStage();
stage = stage.nextStage; stage = stage.nextStage;
} while (stage); } while (stage);
@ -87,12 +90,440 @@ export class SASVerification {
if (!(e instanceof VerificationCancelledError)) { if (!(e instanceof VerificationCancelledError)) {
throw e; throw e;
} }
console.log("Caught error in start()");
} }
finally { finally {
this.olmSas.free(); this.olmSas.free();
this.finished = true;
this.timeout.abort(); this.timeout.abort();
this.finished = true;
} }
} }
} }
import {HomeServer} from "../../../mocks/HomeServer.js";
import Olm from "@matrix-org/olm/olm.js";
import {MockChannel} from "./channel/MockChannel";
import {Clock as MockClock} from "../../../mocks/Clock.js";
import {NullLogger} from "../../../logging/NullLogger";
import {SASFixtures} from "../../../fixtures/matrix/sas/events";
import {SendKeyStage} from "./stages/SendKeyStage";
import {CalculateSASStage} from "./stages/CalculateSASStage";
import {SendMacStage} from "./stages/SendMacStage";
import {VerifyMacStage} from "./stages/VerifyMacStage";
import {SendDoneStage} from "./stages/SendDoneStage";
import {SendAcceptVerificationStage} from "./stages/SendAcceptVerificationStage";
export function tests() {
async function createSASRequest(
ourUserId: string,
ourDeviceId: string,
theirUserId: string,
theirDeviceId: string,
txnId: string,
receivedMessages,
startingMessage?: any
) {
const homeserverMock = new HomeServer();
const hsApi = homeserverMock.api;
const olm = Olm;
await olm.init();
const olmUtil = new Olm.Utility();
const e2eeAccount = {
getDeviceKeysToSignWithCrossSigning: () => {
return {
keys: {
[`ed25519:${ourDeviceId}`]:
"srsWWbrnQFIOmUSdrt3cS/unm03qAIgXcWwQg9BegKs",
},
};
},
};
const deviceTracker = {
getCrossSigningKeysForUser: (userId, _hsApi, _) => {
let masterKey =
userId === ourUserId
? "5HIrEawRiiQioViNfezPDWfPWH2pdaw3pbQNHEVN2jM"
: "Ot8Y58PueQ7hJVpYWAJkg2qaREJAY/UhGZYOrsd52oo";
return { masterKey };
},
deviceForId: (_userId, _deviceId, _hsApi, _log) => {
return {
ed25519Key: "D8w9mrokGdEZPdPgrU0kQkYi4vZyzKEBfvGyZsGK7+Q",
};
},
};
const channel = new MockChannel(
theirDeviceId,
theirUserId,
ourDeviceId,
ourUserId,
receivedMessages,
deviceTracker,
txnId,
olm,
startingMessage,
);
const clock = new MockClock();
const logger = new NullLogger();
return logger.run("log", (log) => {
// @ts-ignore
const sas = new SASVerification({
channel,
clock,
hsApi,
deviceTracker,
e2eeAccount,
olm,
olmUtil,
otherUserId: theirUserId!,
ourUser: { deviceId: ourDeviceId!, userId: ourUserId! },
log,
});
// @ts-ignore
channel.setOlmSas(sas.olmSas);
return { sas, clock, logger };
});
}
return {
"Order of stages created matches expected order when I sent request, they sent start": async (assert) => {
const ourDeviceId = "ILQHOACESQ";
const ourUserId = "@foobaraccount:matrix.org";
const theirUserId = "@foobaraccount3:matrix.org";
const theirDeviceId = "FWKXUYUHTF";
const txnId = "t150836b91a7bed";
const receivedMessages = new SASFixtures(theirUserId, theirDeviceId, txnId)
.youSentRequest()
.theySentStart()
.fixtures();
const { sas } = await createSASRequest(
ourUserId,
ourDeviceId,
theirUserId,
theirDeviceId,
txnId,
receivedMessages
);
await sas.start();
const expectedOrder = [
RequestVerificationStage,
SelectVerificationMethodStage,
SendAcceptVerificationStage,
SendKeyStage,
CalculateSASStage,
SendMacStage,
VerifyMacStage,
SendDoneStage
]
//@ts-ignore
let stage = sas.startStage;
for (const stageClass of expectedOrder) {
assert.strictEqual(stage instanceof stageClass, true);
stage = stage.nextStage;
}
assert.strictEqual(sas.finished, true);
},
"Order of stages created matches expected order when I sent request, I sent start": async (assert) => {
const ourDeviceId = "ILQHOACESQ";
const ourUserId = "@foobaraccount:matrix.org";
const theirUserId = "@foobaraccount3:matrix.org";
const theirDeviceId = "FWKXUYUHTF";
const txnId = "t150836b91a7bed";
const receivedMessages = new SASFixtures(theirUserId, theirDeviceId, txnId)
.youSentRequest()
.youSentStart()
.fixtures();
const { sas, logger } = await createSASRequest(
ourUserId,
ourDeviceId,
theirUserId,
theirDeviceId,
txnId,
receivedMessages
);
sas.eventEmitter.on("SelectVerificationStage", (stage) => {
logger.run("send start", async (log) => {
await stage?.selectEmojiMethod(log);
});
});
await sas.start();
const expectedOrder = [
RequestVerificationStage,
SelectVerificationMethodStage,
SendKeyStage,
CalculateSASStage,
SendMacStage,
VerifyMacStage,
SendDoneStage
]
//@ts-ignore
let stage = sas.startStage;
for (const stageClass of expectedOrder) {
assert.strictEqual(stage instanceof stageClass, true);
stage = stage.nextStage;
}
assert.strictEqual(sas.finished, true);
},
"Order of stages created matches expected order when request is received": async (assert) => {
const ourDeviceId = "ILQHOACESQ";
const ourUserId = "@foobaraccount:matrix.org";
const theirUserId = "@foobaraccount3:matrix.org";
const theirDeviceId = "FWKXUYUHTF";
const txnId = "t150836b91a7bed";
const receivedMessages = new SASFixtures(theirUserId, theirDeviceId, txnId)
.theySentStart()
.fixtures();
const startingMessage = receivedMessages.get(VerificationEventTypes.Start);
const { sas } = await createSASRequest(
ourUserId,
ourDeviceId,
theirUserId,
theirDeviceId,
txnId,
receivedMessages,
startingMessage,
);
await sas.start();
const expectedOrder = [
SelectVerificationMethodStage,
SendAcceptVerificationStage,
SendKeyStage,
CalculateSASStage,
SendMacStage,
VerifyMacStage,
SendDoneStage
]
//@ts-ignore
let stage = sas.startStage;
for (const stageClass of expectedOrder) {
assert.strictEqual(stage instanceof stageClass, true);
stage = stage.nextStage;
}
assert.strictEqual(sas.finished, true);
},
"Order of stages created matches expected order when request is sent with start conflict (they win)": async (assert) => {
const ourDeviceId = "ILQHOACESQ";
const ourUserId = "@foobaraccount:matrix.org";
const theirUserId = "@foobaraccount3:matrix.org";
const theirDeviceId = "FWKXUYUHTF";
const txnId = "t150836b91a7bed";
const receivedMessages = new SASFixtures(theirUserId, theirDeviceId, txnId)
.youSentRequest()
.theySentStart()
.youSentStart()
.theyWinConflict()
.fixtures();
const { sas } = await createSASRequest(
ourUserId,
ourDeviceId,
theirUserId,
theirDeviceId,
txnId,
receivedMessages
);
await sas.start();
const expectedOrder = [
RequestVerificationStage,
SelectVerificationMethodStage,
SendAcceptVerificationStage,
SendKeyStage,
CalculateSASStage,
SendMacStage,
VerifyMacStage,
SendDoneStage
]
//@ts-ignore
let stage = sas.startStage;
for (const stageClass of expectedOrder) {
assert.strictEqual(stage instanceof stageClass, true);
stage = stage.nextStage;
}
assert.strictEqual(sas.finished, true);
},
"Order of stages created matches expected order when request is sent with start conflict (I win)": async (assert) => {
const ourDeviceId = "ILQHOACESQ";
const ourUserId = "@foobaraccount3:matrix.org";
const theirUserId = "@foobaraccount:matrix.org";
const theirDeviceId = "FWKXUYUHTF";
const txnId = "t150836b91a7bed";
const receivedMessages = new SASFixtures(theirUserId, theirDeviceId, txnId)
.youSentRequest()
.theySentStart()
.youSentStart()
.youWinConflict()
.fixtures();
const { sas, logger } = await createSASRequest(
ourUserId,
ourDeviceId,
theirUserId,
theirDeviceId,
txnId,
receivedMessages
);
sas.eventEmitter.on("SelectVerificationStage", (stage) => {
logger.run("send start", async (log) => {
await stage?.selectEmojiMethod(log);
});
});
await sas.start();
const expectedOrder = [
RequestVerificationStage,
SelectVerificationMethodStage,
SendKeyStage,
CalculateSASStage,
SendMacStage,
VerifyMacStage,
SendDoneStage
]
//@ts-ignore
let stage = sas.startStage;
for (const stageClass of expectedOrder) {
assert.strictEqual(stage instanceof stageClass, true);
stage = stage.nextStage;
}
assert.strictEqual(sas.finished, true);
},
"Order of stages created matches expected order when request is received with start conflict (they win)": async (assert) => {
const ourDeviceId = "ILQHOACESQ";
const ourUserId = "@foobaraccount:matrix.org";
const theirUserId = "@foobaraccount3:matrix.org";
const theirDeviceId = "FWKXUYUHTF";
const txnId = "t150836b91a7bed";
const receivedMessages = new SASFixtures(theirUserId, theirDeviceId, txnId)
.theySentStart()
.youSentStart()
.theyWinConflict()
.fixtures();
const startingMessage = receivedMessages.get(VerificationEventTypes.Start);
console.log(receivedMessages);
const { sas } = await createSASRequest(
ourUserId,
ourDeviceId,
theirUserId,
theirDeviceId,
txnId,
receivedMessages,
startingMessage,
);
await sas.start();
const expectedOrder = [
SelectVerificationMethodStage,
SendAcceptVerificationStage,
SendKeyStage,
CalculateSASStage,
SendMacStage,
VerifyMacStage,
SendDoneStage
]
//@ts-ignore
let stage = sas.startStage;
for (const stageClass of expectedOrder) {
console.log("Checking", stageClass.constructor.name, stage.constructor.name);
assert.strictEqual(stage instanceof stageClass, true);
stage = stage.nextStage;
}
assert.strictEqual(sas.finished, true);
},
"Order of stages created matches expected order when request is received with start conflict (I win)": async (assert) => {
const ourDeviceId = "ILQHOACESQ";
const ourUserId = "@foobaraccount3:matrix.org";
const theirUserId = "@foobaraccount:matrix.org";
const theirDeviceId = "FWKXUYUHTF";
const txnId = "t150836b91a7bed";
const receivedMessages = new SASFixtures(theirUserId, theirDeviceId, txnId)
.theySentStart()
.youSentStart()
.youWinConflict()
.fixtures();
const startingMessage = receivedMessages.get(VerificationEventTypes.Start);
console.log(receivedMessages);
const { sas, logger } = await createSASRequest(
ourUserId,
ourDeviceId,
theirUserId,
theirDeviceId,
txnId,
receivedMessages,
startingMessage,
);
sas.eventEmitter.on("SelectVerificationStage", (stage) => {
logger.run("send start", async (log) => {
await stage?.selectEmojiMethod(log);
});
});
await sas.start();
const expectedOrder = [
SelectVerificationMethodStage,
SendKeyStage,
CalculateSASStage,
SendMacStage,
VerifyMacStage,
SendDoneStage
]
//@ts-ignore
let stage = sas.startStage;
for (const stageClass of expectedOrder) {
console.log("Checking", stageClass.constructor.name, stage.constructor.name);
assert.strictEqual(stage instanceof stageClass, true);
stage = stage.nextStage;
}
assert.strictEqual(sas.finished, true);
},
"Verification is cancelled after 10 minutes": async (assert) => {
const ourDeviceId = "ILQHOACESQ";
const ourUserId = "@foobaraccount:matrix.org";
const theirUserId = "@foobaraccount3:matrix.org";
const theirDeviceId = "FWKXUYUHTF";
const txnId = "t150836b91a7bed";
const receivedMessages = new SASFixtures(theirUserId, theirDeviceId, txnId)
.youSentRequest()
.theySentStart()
.fixtures();
console.log("receivedMessages", receivedMessages);
const { sas, clock } = await createSASRequest(
ourUserId,
ourDeviceId,
theirUserId,
theirDeviceId,
txnId,
receivedMessages
);
const promise = sas.start();
clock.elapse(10 * 60 * 1000);
try {
await promise;
}
catch (e) {
assert.strictEqual(e instanceof VerificationCancelledError, true);
}
assert.strictEqual(sas.finished, true);
},
"Verification is cancelled when there's no common hash algorithm": async (assert) => {
const ourDeviceId = "ILQHOACESQ";
const ourUserId = "@foobaraccount:matrix.org";
const theirUserId = "@foobaraccount3:matrix.org";
const theirDeviceId = "FWKXUYUHTF";
const txnId = "t150836b91a7bed";
const receivedMessages = new SASFixtures(theirUserId, theirDeviceId, txnId)
.youSentRequest()
.theySentStart()
.fixtures();
receivedMessages.get(VerificationEventTypes.Start).content.key_agreement_protocols = ["foo"];
const { sas } = await createSASRequest(
ourUserId,
ourDeviceId,
theirUserId,
theirDeviceId,
txnId,
receivedMessages
);
try {
await sas.start()
}
catch (e) {
assert.strictEqual(e instanceof VerificationCancelledError, true);
}
assert.strictEqual(sas.finished, true);
},
}
}

View File

@ -0,0 +1,133 @@
import type {ILogItem} from "../../../../lib";
import {createCalculateMAC} from "../mac";
import {VerificationCancelledError} from "../VerificationCancelledError";
import {IChannel} from "./Channel";
import {CancelTypes, VerificationEventTypes} from "./types";
import anotherjson from "another-json";
interface ITestChannel extends IChannel {
setOlmSas(olmSas): void;
}
export class MockChannel implements ITestChannel {
public sentMessages: Map<string, any> = new Map();
public receivedMessages: Map<string, any> = new Map();
public initiatedByUs: boolean;
public startMessage: any;
public isCancelled: boolean = false;
private olmSas: any;
constructor(
public otherUserDeviceId: string,
public otherUserId: string,
public ourUserDeviceId: string,
public ourUserId: string,
private fixtures: Map<string, any>,
private deviceTracker: any,
public id: string,
private olm: any,
startingMessage?: any,
) {
if (startingMessage) {
const eventType = startingMessage.content.method ? VerificationEventTypes.Start : VerificationEventTypes.Request;
this.id = startingMessage.content.transaction_id;
this.receivedMessages.set(eventType, startingMessage);
}
}
async send(eventType: string, content: any, _: ILogItem) {
if (this.isCancelled) {
throw new VerificationCancelledError();
}
Object.assign(content, { transaction_id: this.id });
this.sentMessages.set(eventType, {content});
}
async waitForEvent(eventType: string): Promise<any> {
if (this.isCancelled) {
throw new VerificationCancelledError();
}
const event = this.fixtures.get(eventType);
if (event) {
this.receivedMessages.set(eventType, event);
}
else {
await new Promise(() => {});
}
if (eventType === VerificationEventTypes.Mac) {
await this.recalculateMAC();
}
if(eventType === VerificationEventTypes.Accept && this.startMessage) {
}
return event;
}
private recalculateCommitment() {
const acceptMessage = this.getEvent(VerificationEventTypes.Accept)?.content;
if (!acceptMessage) {
return;
}
const {content} = this.startMessage;
const {content: keyMessage} = this.fixtures.get(VerificationEventTypes.Key);
const key = keyMessage.key;
const commitmentStr = key + anotherjson.stringify(content);
const olmUtil = new this.olm.Utility();
const commitment = olmUtil.sha256(commitmentStr);
olmUtil.free();
acceptMessage.commitment = commitment;
}
private async recalculateMAC() {
// We need to replace the mac with calculated mac
const baseInfo =
"MATRIX_KEY_VERIFICATION_MAC" +
this.otherUserId +
this.otherUserDeviceId +
this.ourUserId +
this.ourUserDeviceId +
this.id;
const { content: macContent } = this.receivedMessages.get(VerificationEventTypes.Mac);
const macMethod = this.getEvent(VerificationEventTypes.Accept).content.message_authentication_code;
const calculateMac = createCalculateMAC(this.olmSas, macMethod);
const input = Object.keys(macContent.mac).sort().join(",");
const properMac = calculateMac(input, baseInfo + "KEY_IDS");
macContent.keys = properMac;
for (const keyId of Object.keys(macContent.mac)) {
const deviceId = keyId.split(":", 2)[1];
const device = await this.deviceTracker.deviceForId(this.otherUserDeviceId, deviceId);
if (device) {
macContent.mac[keyId] = calculateMac(device.ed25519Key, baseInfo + keyId);
}
else {
const {masterKey} = await this.deviceTracker.getCrossSigningKeysForUser(this.otherUserId);
macContent.mac[keyId] = calculateMac(masterKey, baseInfo + keyId);
}
}
}
setStartMessage(event: any): void {
this.startMessage = event;
this.recalculateCommitment();
}
setInitiatedByUs(value: boolean): void {
this.initiatedByUs = value;
}
async cancelVerification(_: CancelTypes): Promise<void> {
console.log("MockChannel.cancelVerification()");
this.isCancelled = true;
}
getEvent(eventType: VerificationEventTypes.Accept): any {
return this.receivedMessages.get(eventType) ?? this.sentMessages.get(eventType);
}
setOlmSas(olmSas: any): void {
this.olmSas = olmSas;
}
get type() {
return 0;
}
}

View File

@ -14,13 +14,14 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
import type {ILogItem} from "../../../../lib.js"; import type {ILogItem} from "../../../../lib.js";
import type {Room} from "../../../room/Room.js";
import type * as OlmNamespace from "@matrix-org/olm"; import type * as OlmNamespace from "@matrix-org/olm";
import type {Account} from "../../../e2ee/Account.js"; import type {Account} from "../../../e2ee/Account.js";
import type {DeviceTracker} from "../../../e2ee/DeviceTracker.js"; import type {DeviceTracker} from "../../../e2ee/DeviceTracker.js";
import {Disposables} from "../../../../utils/Disposables"; import {Disposables} from "../../../../utils/Disposables";
import {IChannel} from "../channel/Channel.js"; import {IChannel} from "../channel/Channel.js";
import {HomeServerApi} from "../../../net/HomeServerApi.js"; import {HomeServerApi} from "../../../net/HomeServerApi.js";
import {SASProgressEvents} from "../types.js";
import {EventEmitter} from "../../../../utils/EventEmitter";
type Olm = typeof OlmNamespace; type Olm = typeof OlmNamespace;
@ -39,6 +40,7 @@ export type Options = {
e2eeAccount: Account; e2eeAccount: Account;
deviceTracker: DeviceTracker; deviceTracker: DeviceTracker;
hsApi: HomeServerApi; hsApi: HomeServerApi;
eventEmitter: EventEmitter<SASProgressEvents>
} }
export abstract class BaseSASVerificationStage extends Disposables { export abstract class BaseSASVerificationStage extends Disposables {
@ -55,6 +57,7 @@ export abstract class BaseSASVerificationStage extends Disposables {
protected e2eeAccount: Account; protected e2eeAccount: Account;
protected deviceTracker: DeviceTracker; protected deviceTracker: DeviceTracker;
protected hsApi: HomeServerApi; protected hsApi: HomeServerApi;
protected eventEmitter: EventEmitter<SASProgressEvents>;
constructor(options: Options) { constructor(options: Options) {
super(); super();
@ -68,6 +71,7 @@ export abstract class BaseSASVerificationStage extends Disposables {
this.e2eeAccount = options.e2eeAccount; this.e2eeAccount = options.e2eeAccount;
this.deviceTracker = options.deviceTracker; this.deviceTracker = options.deviceTracker;
this.hsApi = options.hsApi; this.hsApi = options.hsApi;
this.eventEmitter = options.eventEmitter;
} }
setNextStage(stage: BaseSASVerificationStage) { setNextStage(stage: BaseSASVerificationStage) {

View File

@ -27,7 +27,7 @@ export class SelectVerificationMethodStage extends BaseSASVerificationStage {
async completeStage() { async completeStage() {
await this.log.wrap("SelectVerificationMethodStage.completeStage", async (log) => { await this.log.wrap("SelectVerificationMethodStage.completeStage", async (log) => {
(window as any).select = () => this.selectEmojiMethod(log); this.eventEmitter.emit("SelectVerificationStage", this);
const startMessage = this.channel.waitForEvent(VerificationEventTypes.Start); const startMessage = this.channel.waitForEvent(VerificationEventTypes.Start);
const acceptMessage = this.channel.waitForEvent(VerificationEventTypes.Accept); const acceptMessage = this.channel.waitForEvent(VerificationEventTypes.Accept);
const { content } = await Promise.race([startMessage, acceptMessage]); const { content } = await Promise.race([startMessage, acceptMessage]);
@ -59,7 +59,7 @@ export class SelectVerificationMethodStage extends BaseSASVerificationStage {
}); });
} }
async resolveStartConflict() { private async resolveStartConflict() {
const receivedStartMessage = this.channel.receivedMessages.get(VerificationEventTypes.Start); const receivedStartMessage = this.channel.receivedMessages.get(VerificationEventTypes.Start);
const sentStartMessage = this.channel.sentMessages.get(VerificationEventTypes.Start); const sentStartMessage = this.channel.sentMessages.get(VerificationEventTypes.Start);
if (receivedStartMessage.content.method !== sentStartMessage.content.method) { if (receivedStartMessage.content.method !== sentStartMessage.content.method) {

View File

@ -28,7 +28,6 @@ export class SendAcceptVerificationStage extends BaseSASVerificationStage {
const macMethod = intersection(MAC_LIST, new Set(content.message_authentication_codes))[0]; const macMethod = intersection(MAC_LIST, new Set(content.message_authentication_codes))[0];
const sasMethods = intersection(content.short_authentication_string, SAS_SET); const sasMethods = intersection(content.short_authentication_string, SAS_SET);
if (!(keyAgreement !== undefined && hashMethod !== undefined && macMethod !== undefined && sasMethods.length)) { if (!(keyAgreement !== undefined && hashMethod !== undefined && macMethod !== undefined && sasMethods.length)) {
// todo: ensure this cancels the verification
await this.channel.cancelVerification(CancelTypes.UnknownMethod); await this.channel.cancelVerification(CancelTypes.UnknownMethod);
return; return;
} }

View File

@ -54,14 +54,15 @@ export class VerifyMacStage extends BaseSASVerificationStage {
this.ourUser.deviceId + this.ourUser.deviceId +
this.channel.id; this.channel.id;
if ( content.keys !== this.calculateMAC(Object.keys(content.mac).sort().join(","), baseInfo + "KEY_IDS")) { const calculatedMAC = this.calculateMAC(Object.keys(content.mac).sort().join(","), baseInfo + "KEY_IDS");
// cancel when MAC does not match! if (content.keys !== calculatedMAC) {
// todo: cancel when MAC does not match!
console.log("Keys MAC Verification failed"); console.log("Keys MAC Verification failed");
} }
await this.verifyKeys(content.mac, (keyId, key, keyInfo) => { await this.verifyKeys(content.mac, (keyId, key, keyInfo) => {
if (keyInfo !== this.calculateMAC(key, baseInfo + keyId)) { if (keyInfo !== this.calculateMAC(key, baseInfo + keyId)) {
// cancel when MAC does not match! // todo: cancel when MAC does not match!
console.log("mac obj MAC Verification failed"); console.log("mac obj MAC Verification failed");
} }
}, log); }, log);

View File

@ -0,0 +1,20 @@
/*
Copyright 2023 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import {SelectVerificationMethodStage} from "./stages/SelectVerificationMethodStage";
export type SASProgressEvents = {
SelectVerificationStage: SelectVerificationMethodStage;
}