Refactor to avoid passing crosssigning

This commit is contained in:
RMidhunSuresh 2023-04-10 19:44:44 +05:30
parent 1f2e8332fe
commit 660db4ced3
5 changed files with 47 additions and 46 deletions

View File

@ -55,7 +55,7 @@ export class DeviceVerificationViewModel extends ErrorReportViewModel<SegmentTyp
if (typeof requestOrUserId === "string") {
this.updateCurrentStageViewModel(new WaitingForOtherUserViewModel(this.childOptions({ sas: this.sas })));
}
return this.sas.start();
return crossSigning.signDevice(this.sas, log);
});
}

View File

@ -78,6 +78,11 @@ enum MSKVerification {
Valid
}
export interface IVerificationMethod {
verify(): Promise<boolean>;
otherDeviceId: string;
}
export class CrossSigning {
private readonly storage: Storage;
private readonly secretStorage: SecretStorage;
@ -202,7 +207,6 @@ export class CrossSigning {
deviceTracker: this.deviceTracker,
hsApi: this.hsApi,
clock: this.platform.clock,
crossSigning: this,
});
return this.sasVerificationInProgress;
}
@ -249,13 +253,19 @@ export class CrossSigning {
}
/** @return the signed device key for the given device id */
async signDevice(deviceId: string, log: ILogItem): Promise<DeviceKey | undefined> {
async signDevice(verification: IVerificationMethod, log: ILogItem): Promise<DeviceKey | undefined> {
return log.wrap("CrossSigning.signDevice", async log => {
log.set("id", deviceId);
if (!this._isMasterKeyTrusted) {
log.set("mskNotTrusted", true);
return;
}
const shouldSign = await verification.verify();
log.set("shouldSign", shouldSign);
if (!shouldSign) {
return;
}
const deviceId = verification.otherDeviceId;
log.set("id", deviceId);
const keyToSign = await this.deviceTracker.deviceForId(this.ownUserId, deviceId, this.hsApi, log);
if (!keyToSign) {
return undefined;
@ -266,7 +276,7 @@ export class CrossSigning {
}
/** @return the signed MSK for the given user id */
async signUser(userId: string, log: ILogItem): Promise<CrossSigningKey | undefined> {
async signUser(userId: string, verification: IVerificationMethod, log: ILogItem): Promise<CrossSigningKey | undefined> {
return log.wrap("CrossSigning.signUser", async log => {
log.set("id", userId);
if (!this._isMasterKeyTrusted) {
@ -277,6 +287,11 @@ export class CrossSigning {
if (userId === this.ownUserId) {
return;
}
const shouldSign = await verification.verify();
log.set("shouldSign", shouldSign);
if (!shouldSign) {
return;
}
const keyToSign = await this.deviceTracker.getCrossSigningKeyForUser(userId, KeyUsage.Master, this.hsApi, log);
if (!keyToSign) {
return;

View File

@ -23,13 +23,13 @@ import type {IChannel} from "./channel/IChannel";
import type {HomeServerApi} from "../../net/HomeServerApi";
import type {Timeout} from "../../../platform/types/types";
import type {Clock} from "../../../platform/web/dom/Clock.js";
import type {IVerificationMethod} from "../CrossSigning";
import {CancelReason, VerificationEventType} from "./channel/types";
import {SendReadyStage} from "./stages/SendReadyStage";
import {SelectVerificationMethodStage} from "./stages/SelectVerificationMethodStage";
import {VerificationCancelledError} from "./VerificationCancelledError";
import {EventEmitter} from "../../../utils/EventEmitter";
import {SASProgressEvents} from "./types";
import type {CrossSigning} from "../CrossSigning";
type Olm = typeof OlmNamespace;
@ -45,10 +45,9 @@ type Options = {
deviceTracker: DeviceTracker;
hsApi: HomeServerApi;
clock: Clock;
crossSigning: CrossSigning
}
export class SASVerification extends EventEmitter<SASProgressEvents> {
export class SASVerification extends EventEmitter<SASProgressEvents> implements IVerificationMethod {
private startStage: BaseSASVerificationStage;
private olmSas: Olm.SAS;
public finished: boolean = false;
@ -74,7 +73,7 @@ export class SASVerification extends EventEmitter<SASProgressEvents> {
}
}
private async setupCancelAfterTimeout(clock: Clock) {
private async setupCancelAfterTimeout(clock: Clock): Promise<void> {
try {
const tenMinutes = 10 * 60 * 1000;
this.timeout = clock.createTimeout(tenMinutes);
@ -86,11 +85,12 @@ export class SASVerification extends EventEmitter<SASProgressEvents> {
}
}
async abort() {
async abort(): Promise<void> {
await this.channel.cancelVerification(CancelReason.UserCancelled);
}
async start() {
async verify(): Promise<boolean> {
let success = true;
try {
let stage = this.startStage;
do {
@ -102,6 +102,7 @@ export class SASVerification extends EventEmitter<SASProgressEvents> {
if (!(e instanceof VerificationCancelledError)) {
throw e;
}
success = false;
}
finally {
if (this.channel.isCancelled) {
@ -111,6 +112,11 @@ export class SASVerification extends EventEmitter<SASProgressEvents> {
this.timeout.abort();
this.finished = true;
}
return success;
}
get otherDeviceId(): string {
return this.channel.otherUserDeviceId;
}
}
@ -189,7 +195,6 @@ export function tests() {
olm,
startingMessage,
);
const crossSigning = new MockCrossSigning() as unknown as CrossSigning;
const clock = new MockClock();
const logger = new NullLogger();
return logger.run("log", (log) => {
@ -207,7 +212,6 @@ export function tests() {
ourUserId,
ourUserDeviceId: ourDeviceId,
log,
crossSigning
});
// @ts-ignore
channel.setOlmSas(sas.olmSas);
@ -218,16 +222,6 @@ export function tests() {
});
}
class MockCrossSigning {
signDevice(deviceId: string, log: ILogItem) {
return Promise.resolve({}); // device keys, means signing succeeded
}
signUser(userId: string, log: ILogItem) {
return Promise.resolve({}); // cross-signing keys, means signing succeeded
}
}
return {
"Order of stages created matches expected order when I sent request, they sent start": async (assert) => {
const ourDeviceId = "ILQHOACESQ";
@ -247,7 +241,7 @@ export function tests() {
txnId,
receivedMessages
);
await sas.start();
await sas.verify();
const expectedOrder = [
SendRequestVerificationStage,
SelectVerificationMethodStage,
@ -289,7 +283,7 @@ export function tests() {
await stage?.selectEmojiMethod(log);
});
});
await sas.start();
await sas.verify();
const expectedOrder = [
SendRequestVerificationStage,
SelectVerificationMethodStage,
@ -326,7 +320,7 @@ export function tests() {
receivedMessages,
startingMessage,
);
await sas.start();
await sas.verify();
const expectedOrder = [
SelectVerificationMethodStage,
SendAcceptVerificationStage,
@ -364,7 +358,7 @@ export function tests() {
txnId,
receivedMessages
);
await sas.start();
await sas.verify();
const expectedOrder = [
SendRequestVerificationStage,
SelectVerificationMethodStage,
@ -408,7 +402,7 @@ export function tests() {
await stage?.selectEmojiMethod(log);
});
});
await sas.start();
await sas.verify();
const expectedOrder = [
SendRequestVerificationStage,
SelectVerificationMethodStage,
@ -448,7 +442,7 @@ export function tests() {
receivedMessages,
startingMessage,
);
await sas.start();
await sas.verify();
const expectedOrder = [
SelectVerificationMethodStage,
SendAcceptVerificationStage,
@ -494,7 +488,7 @@ export function tests() {
await stage?.selectEmojiMethod(log);
});
});
await sas.start();
await sas.verify();
const expectedOrder = [
SelectVerificationMethodStage,
SendKeyStage,
@ -537,7 +531,7 @@ export function tests() {
await stage?.selectEmojiMethod(log);
});
});
await sas.start();
await sas.verify();
const expectedOrder = [
SendRequestVerificationStage,
SelectVerificationMethodStage,
@ -575,7 +569,7 @@ export function tests() {
txnId,
receivedMessages
);
await sas.start();
await sas.verify();
const expectedOrder = [
SendRequestVerificationStage,
SelectVerificationMethodStage,
@ -613,7 +607,7 @@ export function tests() {
txnId,
receivedMessages
);
const promise = sas.start();
const promise = sas.verify();
clock.elapse(10 * 60 * 1000);
try {
await promise;
@ -643,7 +637,7 @@ export function tests() {
receivedMessages
);
try {
await sas.start()
await sas.verify()
}
catch (e) {
assert.strictEqual(e instanceof VerificationCancelledError, true);

View File

@ -16,7 +16,6 @@ limitations under the License.
import type {ILogItem} from "../../../../logging/types";
import type {Account} from "../../../e2ee/Account.js";
import type {DeviceTracker} from "../../../e2ee/DeviceTracker.js";
import type {CrossSigning} from "../../CrossSigning";
import {IChannel} from "../channel/IChannel";
import {HomeServerApi} from "../../../net/HomeServerApi";
import {SASProgressEvents} from "../types";
@ -34,7 +33,6 @@ export type Options = {
deviceTracker: DeviceTracker;
hsApi: HomeServerApi;
eventEmitter: EventEmitter<SASProgressEvents>
crossSigning: CrossSigning
}
export abstract class BaseSASVerificationStage {

View File

@ -68,11 +68,8 @@ export class VerifyMacStage extends BaseSASVerificationStage {
const deviceIdOrMSK = keyId.split(":", 2)[1];
const device = await this.deviceTracker.deviceForId(userId, deviceIdOrMSK, this.hsApi, log);
if (device) {
if (verifier(keyId, getDeviceEd25519Key(device), keyInfo)) {
await log.wrap("signing device", async log => {
const signedKey = await this.options.crossSigning.signDevice(device.device_id, log);
log.set("success", !!signedKey);
});
if (!verifier(keyId, getDeviceEd25519Key(device), keyInfo)) {
throw new Error(`MAC verification failed for key ${keyInfo}`);
}
} else {
// If we were not able to find the device, then deviceIdOrMSK is actually the MSK!
@ -82,11 +79,8 @@ export class VerifyMacStage extends BaseSASVerificationStage {
throw new Error("Fetching MSK for user failed!");
}
const masterKey = getKeyEd25519Key(key);
if(masterKey && verifier(keyId, masterKey, keyInfo)) {
await log.wrap("signing user", async log => {
const signedKey = await this.options.crossSigning.signUser(userId, log);
log.set("success", !!signedKey);
});
if(!(masterKey && verifier(keyId, masterKey, keyInfo))) {
throw new Error(`MAC verification failed for key ${keyInfo}`);
}
}
}