Throw specific error when cancelled

This commit is contained in:
RMidhunSuresh 2023-03-07 17:27:27 +05:30
parent b3cc07cf1e
commit 0b51fc0168
No known key found for this signature in database
15 changed files with 125 additions and 206 deletions

View File

@ -45,6 +45,7 @@ export class CrossSigning {
private readonly deviceMessageHandler: DeviceMessageHandler; private readonly deviceMessageHandler: DeviceMessageHandler;
private _isMasterKeyTrusted: boolean = false; private _isMasterKeyTrusted: boolean = false;
private readonly deviceId: string; private readonly deviceId: string;
private sasVerificationInProgress?: SASVerification;
constructor(options: { constructor(options: {
storage: Storage, storage: Storage,
@ -72,12 +73,20 @@ export class CrossSigning {
this.deviceMessageHandler = options.deviceMessageHandler; this.deviceMessageHandler = options.deviceMessageHandler;
this.deviceMessageHandler.on("message", async ({ unencrypted: unencryptedEvent }) => { this.deviceMessageHandler.on("message", async ({ unencrypted: unencryptedEvent }) => {
if (this.sasVerificationInProgress &&
(
!this.sasVerificationInProgress.finished ||
// If the start message is for the previous sasverification, ignore it.
this.sasVerificationInProgress.channel.id === unencryptedEvent.content.transaction_id
)) {
return;
}
console.log("unencrypted event", unencryptedEvent); console.log("unencrypted event", unencryptedEvent);
if (unencryptedEvent.type === VerificationEventTypes.Request || if (unencryptedEvent.type === VerificationEventTypes.Request ||
unencryptedEvent.type === VerificationEventTypes.Start) { unencryptedEvent.type === VerificationEventTypes.Start) {
await this.platform.logger.run("Start verification from request", async (log) => { await this.platform.logger.run("Start verification from request", async (log) => {
const sas = this.startVerification(unencryptedEvent.sender, log, unencryptedEvent); const sas = this.startVerification(unencryptedEvent.sender, log, unencryptedEvent);
await sas.start(); await sas?.start();
}); });
} }
}) })
@ -134,7 +143,10 @@ export class CrossSigning {
return this._isMasterKeyTrusted; return this._isMasterKeyTrusted;
} }
startVerification(userId: string, log: ILogItem, event?: any): SASVerification { startVerification(userId: string, log: ILogItem, event?: any): SASVerification | undefined {
if (this.sasVerificationInProgress && !this.sasVerificationInProgress.finished) {
return;
}
const channel = new ToDeviceChannel({ const channel = new ToDeviceChannel({
deviceTracker: this.deviceTracker, deviceTracker: this.deviceTracker,
hsApi: this.hsApi, hsApi: this.hsApi,
@ -143,7 +155,8 @@ export class CrossSigning {
deviceMessageHandler: this.deviceMessageHandler, deviceMessageHandler: this.deviceMessageHandler,
log log
}, event); }, event);
return new SASVerification({
this.sasVerificationInProgress = new SASVerification({
olm: this.olm, olm: this.olm,
olmUtil: this.olmUtil, olmUtil: this.olmUtil,
ourUser: { userId: this.ownUserId, deviceId: this.deviceId }, ourUser: { userId: this.ownUserId, deviceId: this.deviceId },
@ -154,6 +167,7 @@ export class CrossSigning {
deviceTracker: this.deviceTracker, deviceTracker: this.deviceTracker,
hsApi: this.hsApi, hsApi: this.hsApi,
}); });
return this.sasVerificationInProgress;
} }
} }

View File

@ -24,6 +24,7 @@ import {HomeServerApi} from "../../net/HomeServerApi";
import {VerificationEventTypes} from "./channel/types"; import {VerificationEventTypes} from "./channel/types";
import {SendReadyStage} from "./stages/SendReadyStage"; import {SendReadyStage} from "./stages/SendReadyStage";
import {SelectVerificationMethodStage} from "./stages/SelectVerificationMethodStage"; import {SelectVerificationMethodStage} from "./stages/SelectVerificationMethodStage";
import {VerificationCancelledError} from "./VerificationCancelledError";
type Olm = typeof OlmNamespace; type Olm = typeof OlmNamespace;
@ -42,12 +43,14 @@ type Options = {
export class SASVerification { export class SASVerification {
private startStage: BaseSASVerificationStage; private startStage: BaseSASVerificationStage;
private olmSas: Olm.SAS; private olmSas: Olm.SAS;
public finished: boolean = false;
public readonly channel: IChannel;
constructor(options: Options) { constructor(options: Options) {
const { ourUser, otherUserId, log, olmUtil, olm, channel, e2eeAccount, deviceTracker, hsApi } = options; const { ourUser, otherUserId, log, olmUtil, olm, channel, e2eeAccount, deviceTracker, hsApi } = options;
const olmSas = new olm.SAS(); const olmSas = new olm.SAS();
this.olmSas = olmSas; this.olmSas = olmSas;
// channel.send("m.key.verification.request", {}, log); this.channel = channel;
try { try {
const options = { ourUser, otherUserId, log, olmSas, olmUtil, channel, e2eeAccount, deviceTracker, hsApi}; const options = { ourUser, otherUserId, log, olmSas, olmUtil, channel, e2eeAccount, deviceTracker, hsApi};
let stage: BaseSASVerificationStage; let stage: BaseSASVerificationStage;
@ -71,12 +74,20 @@ export class SASVerification {
try { try {
let stage = this.startStage; let stage = this.startStage;
do { do {
console.log("Running next stage");
await stage.completeStage(); await stage.completeStage();
stage = stage.nextStage; stage = stage.nextStage;
} while (stage); } while (stage);
} }
catch (e) {
if (!(e instanceof VerificationCancelledError)) {
throw e;
}
console.log("Caught error in start()");
}
finally { finally {
this.olmSas.free(); this.olmSas.free();
this.finished = true;
} }
} }
} }

View File

@ -0,0 +1,25 @@
/*
Copyright 2023 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
export class VerificationCancelledError extends Error {
get name(): string {
return "VerificationCancelledError";
}
get message(): string {
return "Verification is cancelled!";
}
}

View File

@ -21,9 +21,11 @@ import type {Platform} from "../../../../platform/web/Platform.js";
import type {DeviceMessageHandler} from "../../../DeviceMessageHandler.js"; import type {DeviceMessageHandler} from "../../../DeviceMessageHandler.js";
import {makeTxnId} from "../../../common.js"; import {makeTxnId} from "../../../common.js";
import {CancelTypes, VerificationEventTypes} from "./types"; import {CancelTypes, VerificationEventTypes} from "./types";
import {Disposables} from "../../../../lib";
import {VerificationCancelledError} from "../VerificationCancelledError";
const messageFromErrorType = { const messageFromErrorType = {
[CancelTypes.UserCancelled]: "User cancelled this request.", [CancelTypes.UserCancelled]: "User declined",
[CancelTypes.InvalidMessage]: "Invalid Message.", [CancelTypes.InvalidMessage]: "Invalid Message.",
[CancelTypes.KeyMismatch]: "Key Mismatch.", [CancelTypes.KeyMismatch]: "Key Mismatch.",
[CancelTypes.OtherUserAccepted]: "Another device has accepted this request.", [CancelTypes.OtherUserAccepted]: "Another device has accepted this request.",
@ -49,12 +51,12 @@ export interface IChannel {
otherUserDeviceId: string; otherUserDeviceId: string;
sentMessages: Map<string, any>; sentMessages: Map<string, any>;
receivedMessages: Map<string, any>; receivedMessages: Map<string, any>;
localMessages: Map<string, any>;
setStartMessage(content: any): void; setStartMessage(content: any): void;
setInitiatedByUs(value: boolean): void; setInitiatedByUs(value: boolean): void;
initiatedByUs: boolean; initiatedByUs: boolean;
startMessage: any; startMessage: any;
cancelVerification(cancellationType: CancelTypes): Promise<void>; cancelVerification(cancellationType: CancelTypes): Promise<void>;
getEvent(eventType: VerificationEventTypes.Accept): any;
} }
type Options = { type Options = {
@ -66,7 +68,7 @@ type Options = {
log: ILogItem; log: ILogItem;
} }
export class ToDeviceChannel implements IChannel { export class ToDeviceChannel extends Disposables implements IChannel {
private readonly hsApi: HomeServerApi; private readonly hsApi: HomeServerApi;
private readonly deviceTracker: DeviceTracker; private readonly deviceTracker: DeviceTracker;
private readonly otherUserId: string; private readonly otherUserId: string;
@ -74,19 +76,20 @@ export class ToDeviceChannel implements IChannel {
private readonly deviceMessageHandler: DeviceMessageHandler; private readonly deviceMessageHandler: DeviceMessageHandler;
public readonly sentMessages: Map<string, any> = new Map(); public readonly sentMessages: Map<string, any> = new Map();
public readonly receivedMessages: Map<string, any> = new Map(); public readonly receivedMessages: Map<string, any> = new Map();
public readonly localMessages: Map<string, any> = new Map(); private readonly waitMap: Map<string, {resolve: any, reject: any, promise: Promise<any>}> = new Map();
private readonly waitMap: Map<string, {resolve: any, promise: Promise<any>}> = new Map();
private readonly log: ILogItem; private readonly log: ILogItem;
public otherUserDeviceId: string; public otherUserDeviceId: string;
public startMessage: any; public startMessage: any;
public id: string; public id: string;
private _initiatedByUs: boolean; private _initiatedByUs: boolean;
private _isCancelled = false;
/** /**
* *
* @param startingMessage Create the channel with existing message in the receivedMessage buffer * @param startingMessage Create the channel with existing message in the receivedMessage buffer
*/ */
constructor(options: Options, startingMessage?: any) { constructor(options: Options, startingMessage?: any) {
super();
this.hsApi = options.hsApi; this.hsApi = options.hsApi;
this.deviceTracker = options.deviceTracker; this.deviceTracker = options.deviceTracker;
this.otherUserId = options.otherUserId; this.otherUserId = options.otherUserId;
@ -94,7 +97,10 @@ export class ToDeviceChannel implements IChannel {
this.log = options.log; this.log = options.log;
this.deviceMessageHandler = options.deviceMessageHandler; this.deviceMessageHandler = options.deviceMessageHandler;
// todo: find a way to dispose this subscription // todo: find a way to dispose this subscription
this.deviceMessageHandler.on("message", ({unencrypted}) => this.handleDeviceMessage(unencrypted)) this.track(this.deviceMessageHandler.disposableOn("message", ({ unencrypted }) => this.handleDeviceMessage(unencrypted)));
this.track(() => {
this.waitMap.forEach((value) => { value.reject(new VerificationCancelledError()); });
});
// Copy over request message // Copy over request message
if (startingMessage) { if (startingMessage) {
/** /**
@ -105,14 +111,22 @@ export class ToDeviceChannel implements IChannel {
this.receivedMessages.set(eventType, startingMessage); this.receivedMessages.set(eventType, startingMessage);
this.otherUserDeviceId = startingMessage.content.from_device; this.otherUserDeviceId = startingMessage.content.from_device;
} }
(window as any).foo = () => this.cancelVerification(CancelTypes.OtherUserAccepted);
} }
get type() { get type() {
return ChannelType.ToDeviceMessage; return ChannelType.ToDeviceMessage;
} }
get isCancelled(): boolean {
return this._isCancelled;
}
async send(eventType: string, content: any, log: ILogItem): Promise<void> { async send(eventType: string, content: any, log: ILogItem): Promise<void> {
await log.wrap("ToDeviceChannel.send", async () => { await log.wrap("ToDeviceChannel.send", async () => {
if (this.isCancelled) {
throw new VerificationCancelledError();
}
if (eventType === VerificationEventTypes.Request) { if (eventType === VerificationEventTypes.Request) {
// Handle this case specially // Handle this case specially
await this.handleRequestEventSpecially(eventType, content, log); await this.handleRequestEventSpecially(eventType, content, log);
@ -128,12 +142,12 @@ export class ToDeviceChannel implements IChannel {
} }
} }
} }
await this.hsApi.sendToDevice(eventType, payload, this.id, { log }).response(); await this.hsApi.sendToDevice(eventType, payload, makeTxnId(), { log }).response();
this.sentMessages.set(eventType, {content}); this.sentMessages.set(eventType, {content});
}); });
} }
async handleRequestEventSpecially(eventType: string, content: any, log: ILogItem) { private async handleRequestEventSpecially(eventType: string, content: any, log: ILogItem) {
await log.wrap("ToDeviceChannel.handleRequestEventSpecially", async () => { await log.wrap("ToDeviceChannel.handleRequestEventSpecially", async () => {
const timestamp = this.platform.clock.now(); const timestamp = this.platform.clock.now();
const txnId = makeTxnId(); const txnId = makeTxnId();
@ -146,10 +160,14 @@ export class ToDeviceChannel implements IChannel {
} }
} }
} }
await this.hsApi.sendToDevice(eventType, payload, txnId, { log }).response(); await this.hsApi.sendToDevice(eventType, payload, makeTxnId(), { log }).response();
}); });
} }
getEvent(eventType: VerificationEventTypes.Accept) {
return this.receivedMessages.get(eventType) ?? this.sentMessages.get(eventType);
}
private handleDeviceMessage(event) { private handleDeviceMessage(event) {
this.log.wrap("ToDeviceChannel.handleDeviceMessage", (log) => { this.log.wrap("ToDeviceChannel.handleDeviceMessage", (log) => {
@ -159,6 +177,11 @@ export class ToDeviceChannel implements IChannel {
this.receivedMessages.set(event.type, event); this.receivedMessages.set(event.type, event);
if (event.type === VerificationEventTypes.Ready) { if (event.type === VerificationEventTypes.Ready) {
this.handleReadyMessage(event, log); this.handleReadyMessage(event, log);
return;
}
if (event.type === VerificationEventTypes.Cancel) {
this.dispose();
return;
} }
}); });
} }
@ -181,7 +204,7 @@ export class ToDeviceChannel implements IChannel {
[this.otherUserId]: deviceMessages [this.otherUserId]: deviceMessages
} }
} }
await this.hsApi.sendToDevice(VerificationEventTypes.Cancel, payload, this.id, { log }).response(); await this.hsApi.sendToDevice(VerificationEventTypes.Cancel, payload, makeTxnId(), { log }).response();
} }
catch (e) { catch (e) {
console.log(e); console.log(e);
@ -191,6 +214,9 @@ export class ToDeviceChannel implements IChannel {
async cancelVerification(cancellationType: CancelTypes) { async cancelVerification(cancellationType: CancelTypes) {
await this.log.wrap("Channel.cancelVerification", async log => { await this.log.wrap("Channel.cancelVerification", async log => {
if (this.isCancelled) {
throw new VerificationCancelledError();
}
const payload = { const payload = {
messages: { messages: {
[this.otherUserId]: { [this.otherUserId]: {
@ -202,7 +228,9 @@ export class ToDeviceChannel implements IChannel {
} }
} }
} }
await this.hsApi.sendToDevice(VerificationEventTypes.Cancel, payload, this.id, { log }).response(); await this.hsApi.sendToDevice(VerificationEventTypes.Cancel, payload, makeTxnId(), { log }).response();
this._isCancelled = true;
this.dispose();
}); });
} }
@ -226,12 +254,13 @@ export class ToDeviceChannel implements IChannel {
if (existingWait) { if (existingWait) {
return existingWait.promise; return existingWait.promise;
} }
let resolve; let resolve, reject;
// Add to wait map // Add to wait map
const promise = new Promise(r => { const promise = new Promise((_resolve, _reject) => {
resolve = r; resolve = _resolve;
reject = _reject;
}); });
this.waitMap.set(eventType, { resolve, promise }); this.waitMap.set(eventType, { resolve, reject, promise });
return promise; return promise;
} }

View File

@ -70,16 +70,6 @@ export abstract class BaseSASVerificationStage extends Disposables {
this.hsApi = options.hsApi; this.hsApi = options.hsApi;
} }
setRequestEventId(id: string) {
this.requestEventId = id;
// todo: can this race with incoming message?
this.nextStage?.setRequestEventId(id);
}
setResultFromPreviousStage(result?: any) {
this.previousResult = result;
}
setNextStage(stage: BaseSASVerificationStage) { setNextStage(stage: BaseSASVerificationStage) {
this._nextStage = stage; this._nextStage = stage;
} }
@ -88,6 +78,5 @@ export abstract class BaseSASVerificationStage extends Disposables {
return this._nextStage; return this._nextStage;
} }
abstract get type(): string;
abstract completeStage(): Promise<any>; abstract completeStage(): Promise<any>;
} }

View File

@ -88,8 +88,8 @@ export class CalculateSASStage extends BaseSASVerificationStage {
this.olmSAS.set_their_key(this.theirKey); this.olmSAS.set_their_key(this.theirKey);
const sasBytes = this.generateSASBytes(); const sasBytes = this.generateSASBytes();
const emoji = generateEmojiSas(Array.from(sasBytes)); const emoji = generateEmojiSas(Array.from(sasBytes));
console.log("emoji", emoji); console.log("Emoji calculated:", emoji);
this._nextStage = new SendMacStage(this.options); this.setNextStage(new SendMacStage(this.options));
this.dispose(); this.dispose();
}); });
} }
@ -98,9 +98,10 @@ export class CalculateSASStage extends BaseSASVerificationStage {
return await log.wrap("CalculateSASStage.verifyHashCommitment", async () => { return await log.wrap("CalculateSASStage.verifyHashCommitment", async () => {
const acceptMessage = this.channel.receivedMessages.get(VerificationEventTypes.Accept).content; const acceptMessage = this.channel.receivedMessages.get(VerificationEventTypes.Accept).content;
const keyMessage = this.channel.receivedMessages.get(VerificationEventTypes.Key).content; const keyMessage = this.channel.receivedMessages.get(VerificationEventTypes.Key).content;
const commitmentStr = keyMessage.key + anotherjson.stringify(acceptMessage); const commitmentStr = keyMessage.key + anotherjson.stringify(this.channel.startMessage.content);
const receivedCommitment = acceptMessage.commitment; const receivedCommitment = acceptMessage.commitment;
if (this.olmUtil.sha256(commitmentStr) !== receivedCommitment) { const hash = this.olmUtil.sha256(commitmentStr);
if (hash !== receivedCommitment) {
log.set("Commitment mismatched!", {}); log.set("Commitment mismatched!", {});
// cancel the process! // cancel the process!
await this.channel.cancelVerification(CancelTypes.MismatchedCommitment); await this.channel.cancelVerification(CancelTypes.MismatchedCommitment);
@ -120,8 +121,8 @@ export class CalculateSASStage extends BaseSASVerificationStage {
} }
private generateSASBytes(): Uint8Array { private generateSASBytes(): Uint8Array {
const keyAgreement = this.channel.sentMessages.get(VerificationEventTypes.Accept).content.key_agreement_protocol; const keyAgreement = this.channel.getEvent(VerificationEventTypes.Accept).content.key_agreement_protocol;
const otherUserDeviceId = this.channel.startMessage.content.from_device; const otherUserDeviceId = this.channel.otherUserDeviceId;
const sasBytes = calculateKeyAgreement[keyAgreement]({ const sasBytes = calculateKeyAgreement[keyAgreement]({
our: { our: {
userId: this.ourUser.userId, userId: this.ourUser.userId,
@ -150,8 +151,4 @@ export class CalculateSASStage extends BaseSASVerificationStage {
const { content } = this.channel.receivedMessages.get(VerificationEventTypes.Key); const { content } = this.channel.receivedMessages.get(VerificationEventTypes.Key);
return content.key; return content.key;
} }
get type() {
return "m.key.verification.accept";
}
} }

View File

@ -19,51 +19,16 @@ import {SelectVerificationMethodStage} from "./SelectVerificationMethodStage";
import {VerificationEventTypes} from "../channel/types"; import {VerificationEventTypes} from "../channel/types";
export class RequestVerificationStage extends BaseSASVerificationStage { export class RequestVerificationStage extends BaseSASVerificationStage {
async completeStage() { async completeStage() {
await this.log.wrap("StartVerificationStage.completeStage", async (log) => { await this.log.wrap("StartVerificationStage.completeStage", async (log) => {
const content = { const content = {
// "body": `${this.ourUser.userId} is requesting to verify your device, but your client does not support verification, so you may need to use a different verification method.`,
"from_device": this.ourUser.deviceId, "from_device": this.ourUser.deviceId,
"methods": ["m.sas.v1"], "methods": ["m.sas.v1"],
// "msgtype": "m.key.verification.request",
// "to": this.otherUserId,
}; };
// const promise = this.trackEventId();
// await this.room.sendEvent("m.room.message", content, null, log);
await this.channel.send(VerificationEventTypes.Request, content, log); await this.channel.send(VerificationEventTypes.Request, content, log);
this._nextStage = new SelectVerificationMethodStage(this.options); this.setNextStage(new SelectVerificationMethodStage(this.options));
const readyContent = await this.channel.waitForEvent("m.key.verification.ready"); await this.channel.waitForEvent("m.key.verification.ready");
// const eventId = await promise;
// console.log("eventId", eventId);
// this.setRequestEventId(eventId);
this.dispose(); this.dispose();
}); });
} }
// private trackEventId(): Promise<string> {
// return new Promise(resolve => {
// this.track(
// this.room._timeline.entries.subscribe({
// onAdd: (_, entry) => {
// if (entry instanceof FragmentBoundaryEntry) {
// return;
// }
// if (!entry.isPending &&
// entry.content["msgtype"] === "m.key.verification.request" &&
// entry.content["from_device"] === this.ourUser.deviceId) {
// console.log("found event", entry);
// resolve(entry.id);
// }
// },
// onRemove: () => { /**noop*/ },
// onUpdate: () => { /**noop*/ },
// })
// );
// });
// }
get type() {
return "m.key.verification.request";
}
} }

View File

@ -18,6 +18,7 @@ import {KEY_AGREEMENT_LIST, HASHES_LIST, MAC_LIST, SAS_LIST} from "./constants";
import {CancelTypes, VerificationEventTypes} from "../channel/types"; import {CancelTypes, VerificationEventTypes} from "../channel/types";
import type {ILogItem} from "../../../../logging/types"; import type {ILogItem} from "../../../../logging/types";
import {SendAcceptVerificationStage} from "./SendAcceptVerificationStage"; import {SendAcceptVerificationStage} from "./SendAcceptVerificationStage";
import {SendKeyStage} from "./SendKeyStage";
export class SelectVerificationMethodStage extends BaseSASVerificationStage { export class SelectVerificationMethodStage extends BaseSASVerificationStage {
private hasSentStartMessage = false; private hasSentStartMessage = false;
@ -26,6 +27,7 @@ export class SelectVerificationMethodStage extends BaseSASVerificationStage {
async completeStage() { async completeStage() {
await this.log.wrap("SelectVerificationMethodStage.completeStage", async (log) => { await this.log.wrap("SelectVerificationMethodStage.completeStage", async (log) => {
(window as any).select = () => this.selectEmojiMethod(log);
const startMessage = this.channel.waitForEvent(VerificationEventTypes.Start); const startMessage = this.channel.waitForEvent(VerificationEventTypes.Start);
const acceptMessage = this.channel.waitForEvent(VerificationEventTypes.Accept); const acceptMessage = this.channel.waitForEvent(VerificationEventTypes.Accept);
const { content } = await Promise.race([startMessage, acceptMessage]); const { content } = await Promise.race([startMessage, acceptMessage]);
@ -45,7 +47,11 @@ export class SelectVerificationMethodStage extends BaseSASVerificationStage {
this.channel.setStartMessage(this.channel.sentMessages.get(VerificationEventTypes.Start)); this.channel.setStartMessage(this.channel.sentMessages.get(VerificationEventTypes.Start));
this.channel.setInitiatedByUs(true); this.channel.setInitiatedByUs(true);
} }
if (!this.channel.initiatedByUs) { if (this.channel.initiatedByUs) {
await acceptMessage;
this.setNextStage(new SendKeyStage(this.options));
}
else {
// We need to send the accept message next // We need to send the accept message next
this.setNextStage(new SendAcceptVerificationStage(this.options)); this.setNextStage(new SendAcceptVerificationStage(this.options));
} }
@ -86,8 +92,4 @@ export class SelectVerificationMethodStage extends BaseSASVerificationStage {
await this.channel.send(VerificationEventTypes.Start, content, log); await this.channel.send(VerificationEventTypes.Start, content, log);
this.hasSentStartMessage = true; this.hasSentStartMessage = true;
} }
get type() {
return "SelectVerificationStage";
}
} }

View File

@ -15,7 +15,6 @@ limitations under the License.
*/ */
import {BaseSASVerificationStage} from "./BaseSASVerificationStage"; import {BaseSASVerificationStage} from "./BaseSASVerificationStage";
import anotherjson from "another-json"; import anotherjson from "another-json";
import type { KeyAgreement, MacMethod } from "./constants";
import {HASHES_LIST, MAC_LIST, SAS_SET, KEY_AGREEMENT_LIST} from "./constants"; import {HASHES_LIST, MAC_LIST, SAS_SET, KEY_AGREEMENT_LIST} from "./constants";
import {VerificationEventTypes} from "../channel/types"; import {VerificationEventTypes} from "../channel/types";
import {SendKeyStage} from "./SendKeyStage"; import {SendKeyStage} from "./SendKeyStage";
@ -23,11 +22,7 @@ export class SendAcceptVerificationStage extends BaseSASVerificationStage {
async completeStage() { async completeStage() {
await this.log.wrap("SendAcceptVerificationStage.completeStage", async (log) => { await this.log.wrap("SendAcceptVerificationStage.completeStage", async (log) => {
const event = this.channel.startMessage; const { content } = this.channel.startMessage;
const content = {
...event.content,
// "m.relates_to": event.relation,
};
const keyAgreement = intersection(KEY_AGREEMENT_LIST, new Set(content.key_agreement_protocols))[0]; const keyAgreement = intersection(KEY_AGREEMENT_LIST, new Set(content.key_agreement_protocols))[0];
const hashMethod = intersection(HASHES_LIST, new Set(content.hashes))[0]; const hashMethod = intersection(HASHES_LIST, new Set(content.hashes))[0];
const macMethod = intersection(MAC_LIST, new Set(content.message_authentication_codes))[0]; const macMethod = intersection(MAC_LIST, new Set(content.message_authentication_codes))[0];
@ -50,24 +45,12 @@ export class SendAcceptVerificationStage extends BaseSASVerificationStage {
rel_type: "m.reference", rel_type: "m.reference",
} }
}; };
// await this.room.sendEvent("m.key.verification.accept", contentToSend, null, log);
await this.channel.send(VerificationEventTypes.Accept, contentToSend, log); await this.channel.send(VerificationEventTypes.Accept, contentToSend, log);
this.channel.localMessages.set("our_pub_key", ourPubKey);
await this.channel.waitForEvent(VerificationEventTypes.Key); await this.channel.waitForEvent(VerificationEventTypes.Key);
this._nextStage = new SendKeyStage(this.options); this.setNextStage(new SendKeyStage(this.options));
// this.nextStage?.setResultFromPreviousStage({
// ...this.previousResult,
// [this.type]: contentToSend,
// "our_pub_key": ourPubKey,
// });
this.dispose(); this.dispose();
}); });
} }
get type() {
return "m.key.verification.accept";
}
} }
function intersection<T>(anArray: T[], aSet: Set<T>): T[] { function intersection<T>(anArray: T[], aSet: Set<T>): T[] {

View File

@ -16,17 +16,11 @@ limitations under the License.
import {BaseSASVerificationStage} from "./BaseSASVerificationStage"; import {BaseSASVerificationStage} from "./BaseSASVerificationStage";
import {VerificationEventTypes} from "../channel/types"; import {VerificationEventTypes} from "../channel/types";
export class SendDoneStage extends BaseSASVerificationStage { export class SendDoneStage extends BaseSASVerificationStage {
async completeStage() { async completeStage() {
await this.log.wrap("VerifyMacStage.completeStage", async (log) => { await this.log.wrap("VerifyMacStage.completeStage", async (log) => {
await this.channel.send(VerificationEventTypes.Done, {}, log); await this.channel.send(VerificationEventTypes.Done, {}, log);
this.dispose(); this.dispose();
}); });
} }
get type() {
return "m.key.verification.accept";
}
} }

View File

@ -18,7 +18,6 @@ import {VerificationEventTypes} from "../channel/types";
import {CalculateSASStage} from "./CalculateSASStage"; import {CalculateSASStage} from "./CalculateSASStage";
export class SendKeyStage extends BaseSASVerificationStage { export class SendKeyStage extends BaseSASVerificationStage {
async completeStage() { async completeStage() {
await this.log.wrap("SendKeyStage.completeStage", async (log) => { await this.log.wrap("SendKeyStage.completeStage", async (log) => {
const ourSasKey = this.olmSAS.get_pubkey(); const ourSasKey = this.olmSAS.get_pubkey();
@ -30,12 +29,8 @@ export class SendKeyStage extends BaseSASVerificationStage {
* key. * key.
*/ */
await this.channel.waitForEvent(VerificationEventTypes.Key); await this.channel.waitForEvent(VerificationEventTypes.Key);
this._nextStage = new CalculateSASStage(this.options) this.setNextStage(new CalculateSASStage(this.options));
this.dispose(); this.dispose();
}); });
} }
get type() {
return "m.key.verification.accept";
}
} }

View File

@ -37,7 +37,7 @@ export class SendMacStage extends BaseSASVerificationStage {
this.calculateMAC = createCalculateMAC(this.olmSAS, macMethod); this.calculateMAC = createCalculateMAC(this.olmSAS, macMethod);
await this.sendMAC(log); await this.sendMAC(log);
await this.channel.waitForEvent(VerificationEventTypes.Mac); await this.channel.waitForEvent(VerificationEventTypes.Mac);
this._nextStage = new VerifyMacStage(this.options); this.setNextStage(new VerifyMacStage(this.options));
this.dispose(); this.dispose();
}); });
} }
@ -70,9 +70,5 @@ export class SendMacStage extends BaseSASVerificationStage {
console.log("result", mac, keys); console.log("result", mac, keys);
await this.channel.send(VerificationEventTypes.Mac, { mac, keys }, log); await this.channel.send(VerificationEventTypes.Mac, { mac, keys }, log);
} }
get type() {
return "m.key.verification.accept";
}
} }

View File

@ -18,23 +18,15 @@ import {VerificationEventTypes} from "../channel/types";
import {SelectVerificationMethodStage} from "./SelectVerificationMethodStage"; import {SelectVerificationMethodStage} from "./SelectVerificationMethodStage";
export class SendReadyStage extends BaseSASVerificationStage { export class SendReadyStage extends BaseSASVerificationStage {
async completeStage() { async completeStage() {
await this.log.wrap("StartVerificationStage.completeStage", async (log) => { await this.log.wrap("StartVerificationStage.completeStage", async (log) => {
const content = { const content = {
// "body": `${this.ourUser.userId} is requesting to verify your device, but your client does not support verification, so you may need to use a different verification method.`,
"from_device": this.ourUser.deviceId, "from_device": this.ourUser.deviceId,
"methods": ["m.sas.v1"], "methods": ["m.sas.v1"],
// "msgtype": "m.key.verification.request",
// "to": this.otherUserId,
}; };
await this.channel.send(VerificationEventTypes.Ready, content, log); await this.channel.send(VerificationEventTypes.Ready, content, log);
this._nextStage = new SelectVerificationMethodStage(this.options); this.setNextStage(new SelectVerificationMethodStage(this.options));
this.dispose(); this.dispose();
}); });
} }
get type() {
return "m.key.verification.request";
}
} }

View File

@ -39,7 +39,7 @@ export class VerifyMacStage extends BaseSASVerificationStage {
this.calculateMAC = createCalculateMAC(this.olmSAS, macMethod); this.calculateMAC = createCalculateMAC(this.olmSAS, macMethod);
await this.checkMAC(log); await this.checkMAC(log);
await this.channel.waitForEvent(VerificationEventTypes.Done); await this.channel.waitForEvent(VerificationEventTypes.Done);
this._nextStage = new SendDoneStage(this.options); this.setNextStage(new SendDoneStage(this.options));
this.dispose(); this.dispose();
}); });
} }
@ -88,8 +88,4 @@ export class VerifyMacStage extends BaseSASVerificationStage {
} }
} }
} }
get type() {
return "m.key.verification.accept";
}
} }

View File

@ -1,69 +0,0 @@
/*
Copyright 2023 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import {BaseSASVerificationStage, Options} from "./BaseSASVerificationStage";
import {FragmentBoundaryEntry} from "../../../room/timeline/entries/FragmentBoundaryEntry.js";
export class WaitForIncomingMessageStage extends BaseSASVerificationStage {
constructor(private messageType: string, options: Options) {
super(options);
}
async completeStage() {
await this.log.wrap("WaitForIncomingMessageStage.completeStage", async (log) => {
const entry = await this.fetchMessageEventsFromTimeline();
console.log("content", entry);
this.nextStage?.setResultFromPreviousStage({
...this.previousResult,
[this.messageType]: entry
});
this.dispose();
});
}
private fetchMessageEventsFromTimeline() {
// todo: add timeout after 10 mins
return new Promise(resolve => {
this.track(
this.room._timeline.entries.subscribe({
onAdd: (_, entry) => {
if (entry.eventType === this.messageType &&
entry.relatedEventId === this.requestEventId) {
resolve(entry);
}
},
onRemove: () => { },
onUpdate: () => { },
})
);
const remoteEntries = this.room._timeline.remoteEntries;
// In case we were slow and the event is already added to the timeline,
for (const entry of remoteEntries) {
if (entry instanceof FragmentBoundaryEntry) {
return;
}
if (entry.eventType === this.messageType &&
entry.relatedEventId === this.requestEventId) {
resolve(entry);
}
}
});
}
get type() {
return this.messageType;
}
}