Refactor code

1. Remove unused properties from base stage
2. Split UserData into fields
3. Write getter for channel prop
This commit is contained in:
RMidhunSuresh 2023-03-14 15:42:02 +05:30
parent dedf64d011
commit d70dd660c5
No known key found for this signature in database
10 changed files with 41 additions and 43 deletions

View File

@ -158,7 +158,8 @@ export class CrossSigning {
this.sasVerificationInProgress = new SASVerification({ this.sasVerificationInProgress = new SASVerification({
olm: this.olm, olm: this.olm,
olmUtil: this.olmUtil, olmUtil: this.olmUtil,
ourUser: { userId: this.ownUserId, deviceId: this.deviceId }, ourUserId: this.ownUserId,
ourUserDeviceId: this.deviceId,
otherUserId: userId, otherUserId: userId,
log, log,
channel, channel,

View File

@ -15,7 +15,7 @@ limitations under the License.
*/ */
import {RequestVerificationStage} from "./stages/RequestVerificationStage"; import {RequestVerificationStage} from "./stages/RequestVerificationStage";
import type {ILogItem} from "../../../logging/types"; import type {ILogItem} from "../../../logging/types";
import type {BaseSASVerificationStage, UserData} from "./stages/BaseSASVerificationStage"; import type {BaseSASVerificationStage} from "./stages/BaseSASVerificationStage";
import type {Account} from "../../e2ee/Account.js"; import type {Account} from "../../e2ee/Account.js";
import type {DeviceTracker} from "../../e2ee/DeviceTracker.js"; import type {DeviceTracker} from "../../e2ee/DeviceTracker.js";
import type * as OlmNamespace from "@matrix-org/olm"; import type * as OlmNamespace from "@matrix-org/olm";
@ -35,7 +35,8 @@ type Olm = typeof OlmNamespace;
type Options = { type Options = {
olm: Olm; olm: Olm;
olmUtil: Olm.Utility; olmUtil: Olm.Utility;
ourUser: UserData; ourUserId: string;
ourUserDeviceId: string;
otherUserId: string; otherUserId: string;
channel: IChannel; channel: IChannel;
log: ILogItem; log: ILogItem;
@ -176,7 +177,8 @@ export function tests() {
olm, olm,
olmUtil, olmUtil,
otherUserId: theirUserId!, otherUserId: theirUserId!,
ourUser: { deviceId: ourDeviceId!, userId: ourUserId! }, ourUserId,
ourUserDeviceId: ourDeviceId,
log, log,
}); });
// @ts-ignore // @ts-ignore

View File

@ -13,25 +13,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
import type {ILogItem} from "../../../../lib.js"; import type {ILogItem} from "../../../../logging/types";
import type * as OlmNamespace from "@matrix-org/olm";
import type {Account} from "../../../e2ee/Account.js"; import type {Account} from "../../../e2ee/Account.js";
import type {DeviceTracker} from "../../../e2ee/DeviceTracker.js"; import type {DeviceTracker} from "../../../e2ee/DeviceTracker.js";
import {Disposables} from "../../../../utils/Disposables"; import {IChannel} from "../channel/Channel";
import {IChannel} from "../channel/Channel.js"; import {HomeServerApi} from "../../../net/HomeServerApi";
import {HomeServerApi} from "../../../net/HomeServerApi.js"; import {SASProgressEvents} from "../types";
import {SASProgressEvents} from "../types.js";
import {EventEmitter} from "../../../../utils/EventEmitter"; import {EventEmitter} from "../../../../utils/EventEmitter";
type Olm = typeof OlmNamespace;
export type UserData = {
userId: string;
deviceId: string;
}
export type Options = { export type Options = {
ourUser: UserData; ourUserId: string;
ourUserDeviceId: string;
otherUserId: string; otherUserId: string;
log: ILogItem; log: ILogItem;
olmSas: Olm.SAS; olmSas: Olm.SAS;
@ -44,13 +36,12 @@ export type Options = {
} }
export abstract class BaseSASVerificationStage { export abstract class BaseSASVerificationStage {
protected ourUser: UserData; protected ourUserId: string;
protected ourUserDeviceId: string;
protected otherUserId: string; protected otherUserId: string;
protected log: ILogItem; protected log: ILogItem;
protected olmSAS: Olm.SAS; protected olmSAS: Olm.SAS;
protected olmUtil: Olm.Utility; protected olmUtil: Olm.Utility;
protected requestEventId: string;
protected previousResult: undefined | any;
protected _nextStage: BaseSASVerificationStage; protected _nextStage: BaseSASVerificationStage;
protected channel: IChannel; protected channel: IChannel;
protected options: Options; protected options: Options;
@ -61,7 +52,8 @@ export abstract class BaseSASVerificationStage {
constructor(options: Options) { constructor(options: Options) {
this.options = options; this.options = options;
this.ourUser = options.ourUser; this.ourUserId = options.ourUserId;
this.ourUserDeviceId = options.ourUserDeviceId
this.otherUserId = options.otherUserId; this.otherUserId = options.otherUserId;
this.log = options.log; this.log = options.log;
this.olmSAS = options.olmSas; this.olmSAS = options.olmSas;
@ -81,5 +73,13 @@ export abstract class BaseSASVerificationStage {
return this._nextStage; return this._nextStage;
} }
get otherUserDeviceId(): string {
const id = this.channel.otherUserDeviceId;
if (!id) {
throw new Error("Accessed otherUserDeviceId before it was set in channel!");
}
return id;
}
abstract completeStage(): Promise<any>; abstract completeStage(): Promise<any>;
} }

View File

@ -121,11 +121,11 @@ export class CalculateSASStage extends BaseSASVerificationStage {
private generateSASBytes(): Uint8Array { private generateSASBytes(): Uint8Array {
const keyAgreement = this.channel.getEvent(VerificationEventTypes.Accept).content.key_agreement_protocol; const keyAgreement = this.channel.getEvent(VerificationEventTypes.Accept).content.key_agreement_protocol;
const otherUserDeviceId = this.channel.otherUserDeviceId; const otherUserDeviceId = this.otherUserDeviceId;
const sasBytes = calculateKeyAgreement[keyAgreement]({ const sasBytes = calculateKeyAgreement[keyAgreement]({
our: { our: {
userId: this.ourUser.userId, userId: this.ourUserId,
deviceId: this.ourUser.deviceId, deviceId: this.ourUserDeviceId,
publicKey: this.olmSAS.get_pubkey(), publicKey: this.olmSAS.get_pubkey(),
}, },
their: { their: {

View File

@ -21,7 +21,7 @@ export class RequestVerificationStage extends BaseSASVerificationStage {
async completeStage() { async completeStage() {
await this.log.wrap("StartVerificationStage.completeStage", async (log) => { await this.log.wrap("StartVerificationStage.completeStage", async (log) => {
const content = { const content = {
"from_device": this.ourUser.deviceId, "from_device": this.ourUserDeviceId,
"methods": ["m.sas.v1"], "methods": ["m.sas.v1"],
}; };
await this.channel.send(VerificationEventTypes.Request, content, log); await this.channel.send(VerificationEventTypes.Request, content, log);

View File

@ -64,8 +64,8 @@ export class SelectVerificationMethodStage extends BaseSASVerificationStage {
return; return;
} }
// In the case of conflict, the lexicographically smaller id wins // In the case of conflict, the lexicographically smaller id wins
const our = this.ourUser.userId === this.otherUserId ? this.ourUser.deviceId : this.ourUser.userId; const our = this.ourUserId === this.otherUserId ? this.ourUserDeviceId : this.ourUserId;
const their = this.ourUser.userId === this.otherUserId ? this.channel.otherUserDeviceId : this.otherUserId; const their = this.ourUserId === this.otherUserId ? this.otherUserDeviceId : this.otherUserId;
const startMessageToUse = our < their ? sentStartMessage : receivedStartMessage; const startMessageToUse = our < their ? sentStartMessage : receivedStartMessage;
this.channel.setStartMessage(startMessageToUse); this.channel.setStartMessage(startMessageToUse);
} }
@ -74,7 +74,7 @@ export class SelectVerificationMethodStage extends BaseSASVerificationStage {
if (!this.allowSelection) { return; } if (!this.allowSelection) { return; }
const content = { const content = {
method: "m.sas.v1", method: "m.sas.v1",
from_device: this.ourUser.deviceId, from_device: this.ourUserDeviceId,
key_agreement_protocols: KEY_AGREEMENT_LIST, key_agreement_protocols: KEY_AGREEMENT_LIST,
hashes: HASHES_LIST, hashes: HASHES_LIST,
message_authentication_codes: MAC_LIST, message_authentication_codes: MAC_LIST,

View File

@ -38,12 +38,7 @@ export class SendAcceptVerificationStage extends BaseSASVerificationStage {
hash: hashMethod, hash: hashMethod,
message_authentication_code: macMethod, message_authentication_code: macMethod,
short_authentication_string: sasMethods, short_authentication_string: sasMethods,
// TODO: use selected hash function (when we support multiple)
commitment: this.olmUtil.sha256(commitmentStr), commitment: this.olmUtil.sha256(commitmentStr),
"m.relates_to": {
event_id: this.requestEventId,
rel_type: "m.reference",
}
}; };
await this.channel.send(VerificationEventTypes.Accept, contentToSend, log); await this.channel.send(VerificationEventTypes.Accept, contentToSend, log);
await this.channel.waitForEvent(VerificationEventTypes.Key); await this.channel.waitForEvent(VerificationEventTypes.Key);

View File

@ -46,18 +46,18 @@ export class SendMacStage extends BaseSASVerificationStage {
const keyList: string[] = []; const keyList: string[] = [];
const baseInfo = const baseInfo =
"MATRIX_KEY_VERIFICATION_MAC" + "MATRIX_KEY_VERIFICATION_MAC" +
this.ourUser.userId + this.ourUserId +
this.ourUser.deviceId + this.ourUserDeviceId +
this.otherUserId + this.otherUserId +
this.channel.otherUserDeviceId + this.otherUserDeviceId +
this.channel.id; this.channel.id;
const deviceKeyId = `ed25519:${this.ourUser.deviceId}`; const deviceKeyId = `ed25519:${this.ourUserDeviceId}`;
const deviceKeys = this.e2eeAccount.getDeviceKeysToSignWithCrossSigning(); const deviceKeys = this.e2eeAccount.getDeviceKeysToSignWithCrossSigning();
mac[deviceKeyId] = this.calculateMAC(deviceKeys.keys[deviceKeyId], baseInfo + deviceKeyId); mac[deviceKeyId] = this.calculateMAC(deviceKeys.keys[deviceKeyId], baseInfo + deviceKeyId);
keyList.push(deviceKeyId); keyList.push(deviceKeyId);
const {masterKey: crossSigningKey} = await this.deviceTracker.getCrossSigningKeysForUser(this.ourUser.userId, this.hsApi, log); const {masterKey: crossSigningKey} = await this.deviceTracker.getCrossSigningKeysForUser(this.ourUserId, this.hsApi, log);
console.log("masterKey", crossSigningKey); console.log("masterKey", crossSigningKey);
if (crossSigningKey) { if (crossSigningKey) {
const crossSigningKeyId = `ed25519:${crossSigningKey}`; const crossSigningKeyId = `ed25519:${crossSigningKey}`;

View File

@ -21,7 +21,7 @@ export class SendReadyStage extends BaseSASVerificationStage {
async completeStage() { async completeStage() {
await this.log.wrap("StartVerificationStage.completeStage", async (log) => { await this.log.wrap("StartVerificationStage.completeStage", async (log) => {
const content = { const content = {
"from_device": this.ourUser.deviceId, "from_device": this.ourUserDeviceId,
"methods": ["m.sas.v1"], "methods": ["m.sas.v1"],
}; };
await this.channel.send(VerificationEventTypes.Ready, content, log); await this.channel.send(VerificationEventTypes.Ready, content, log);

View File

@ -48,9 +48,9 @@ export class VerifyMacStage extends BaseSASVerificationStage {
const baseInfo = const baseInfo =
"MATRIX_KEY_VERIFICATION_MAC" + "MATRIX_KEY_VERIFICATION_MAC" +
this.otherUserId + this.otherUserId +
this.channel.otherUserDeviceId + this.otherUserDeviceId +
this.ourUser.userId + this.ourUserId +
this.ourUser.deviceId + this.ourUserDeviceId +
this.channel.id; this.channel.id;
const calculatedMAC = this.calculateMAC(Object.keys(content.mac).sort().join(","), baseInfo + "KEY_IDS"); const calculatedMAC = this.calculateMAC(Object.keys(content.mac).sort().join(","), baseInfo + "KEY_IDS");