Refactor code

1. Remove unused properties from base stage
2. Split UserData into fields
3. Write getter for channel prop
This commit is contained in:
RMidhunSuresh 2023-03-14 15:42:02 +05:30
parent dedf64d011
commit d70dd660c5
No known key found for this signature in database
10 changed files with 41 additions and 43 deletions

View File

@ -158,7 +158,8 @@ export class CrossSigning {
this.sasVerificationInProgress = new SASVerification({
olm: this.olm,
olmUtil: this.olmUtil,
ourUser: { userId: this.ownUserId, deviceId: this.deviceId },
ourUserId: this.ownUserId,
ourUserDeviceId: this.deviceId,
otherUserId: userId,
log,
channel,

View File

@ -15,7 +15,7 @@ limitations under the License.
*/
import {RequestVerificationStage} from "./stages/RequestVerificationStage";
import type {ILogItem} from "../../../logging/types";
import type {BaseSASVerificationStage, UserData} from "./stages/BaseSASVerificationStage";
import type {BaseSASVerificationStage} from "./stages/BaseSASVerificationStage";
import type {Account} from "../../e2ee/Account.js";
import type {DeviceTracker} from "../../e2ee/DeviceTracker.js";
import type * as OlmNamespace from "@matrix-org/olm";
@ -35,7 +35,8 @@ type Olm = typeof OlmNamespace;
type Options = {
olm: Olm;
olmUtil: Olm.Utility;
ourUser: UserData;
ourUserId: string;
ourUserDeviceId: string;
otherUserId: string;
channel: IChannel;
log: ILogItem;
@ -176,7 +177,8 @@ export function tests() {
olm,
olmUtil,
otherUserId: theirUserId!,
ourUser: { deviceId: ourDeviceId!, userId: ourUserId! },
ourUserId,
ourUserDeviceId: ourDeviceId,
log,
});
// @ts-ignore

View File

@ -13,25 +13,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import type {ILogItem} from "../../../../lib.js";
import type * as OlmNamespace from "@matrix-org/olm";
import type {ILogItem} from "../../../../logging/types";
import type {Account} from "../../../e2ee/Account.js";
import type {DeviceTracker} from "../../../e2ee/DeviceTracker.js";
import {Disposables} from "../../../../utils/Disposables";
import {IChannel} from "../channel/Channel.js";
import {HomeServerApi} from "../../../net/HomeServerApi.js";
import {SASProgressEvents} from "../types.js";
import {IChannel} from "../channel/Channel";
import {HomeServerApi} from "../../../net/HomeServerApi";
import {SASProgressEvents} from "../types";
import {EventEmitter} from "../../../../utils/EventEmitter";
type Olm = typeof OlmNamespace;
export type UserData = {
userId: string;
deviceId: string;
}
export type Options = {
ourUser: UserData;
ourUserId: string;
ourUserDeviceId: string;
otherUserId: string;
log: ILogItem;
olmSas: Olm.SAS;
@ -44,13 +36,12 @@ export type Options = {
}
export abstract class BaseSASVerificationStage {
protected ourUser: UserData;
protected ourUserId: string;
protected ourUserDeviceId: string;
protected otherUserId: string;
protected log: ILogItem;
protected olmSAS: Olm.SAS;
protected olmUtil: Olm.Utility;
protected requestEventId: string;
protected previousResult: undefined | any;
protected _nextStage: BaseSASVerificationStage;
protected channel: IChannel;
protected options: Options;
@ -61,7 +52,8 @@ export abstract class BaseSASVerificationStage {
constructor(options: Options) {
this.options = options;
this.ourUser = options.ourUser;
this.ourUserId = options.ourUserId;
this.ourUserDeviceId = options.ourUserDeviceId
this.otherUserId = options.otherUserId;
this.log = options.log;
this.olmSAS = options.olmSas;
@ -81,5 +73,13 @@ export abstract class BaseSASVerificationStage {
return this._nextStage;
}
get otherUserDeviceId(): string {
const id = this.channel.otherUserDeviceId;
if (!id) {
throw new Error("Accessed otherUserDeviceId before it was set in channel!");
}
return id;
}
abstract completeStage(): Promise<any>;
}

View File

@ -121,11 +121,11 @@ export class CalculateSASStage extends BaseSASVerificationStage {
private generateSASBytes(): Uint8Array {
const keyAgreement = this.channel.getEvent(VerificationEventTypes.Accept).content.key_agreement_protocol;
const otherUserDeviceId = this.channel.otherUserDeviceId;
const otherUserDeviceId = this.otherUserDeviceId;
const sasBytes = calculateKeyAgreement[keyAgreement]({
our: {
userId: this.ourUser.userId,
deviceId: this.ourUser.deviceId,
userId: this.ourUserId,
deviceId: this.ourUserDeviceId,
publicKey: this.olmSAS.get_pubkey(),
},
their: {

View File

@ -21,7 +21,7 @@ export class RequestVerificationStage extends BaseSASVerificationStage {
async completeStage() {
await this.log.wrap("StartVerificationStage.completeStage", async (log) => {
const content = {
"from_device": this.ourUser.deviceId,
"from_device": this.ourUserDeviceId,
"methods": ["m.sas.v1"],
};
await this.channel.send(VerificationEventTypes.Request, content, log);

View File

@ -64,8 +64,8 @@ export class SelectVerificationMethodStage extends BaseSASVerificationStage {
return;
}
// In the case of conflict, the lexicographically smaller id wins
const our = this.ourUser.userId === this.otherUserId ? this.ourUser.deviceId : this.ourUser.userId;
const their = this.ourUser.userId === this.otherUserId ? this.channel.otherUserDeviceId : this.otherUserId;
const our = this.ourUserId === this.otherUserId ? this.ourUserDeviceId : this.ourUserId;
const their = this.ourUserId === this.otherUserId ? this.otherUserDeviceId : this.otherUserId;
const startMessageToUse = our < their ? sentStartMessage : receivedStartMessage;
this.channel.setStartMessage(startMessageToUse);
}
@ -74,7 +74,7 @@ export class SelectVerificationMethodStage extends BaseSASVerificationStage {
if (!this.allowSelection) { return; }
const content = {
method: "m.sas.v1",
from_device: this.ourUser.deviceId,
from_device: this.ourUserDeviceId,
key_agreement_protocols: KEY_AGREEMENT_LIST,
hashes: HASHES_LIST,
message_authentication_codes: MAC_LIST,

View File

@ -38,12 +38,7 @@ export class SendAcceptVerificationStage extends BaseSASVerificationStage {
hash: hashMethod,
message_authentication_code: macMethod,
short_authentication_string: sasMethods,
// TODO: use selected hash function (when we support multiple)
commitment: this.olmUtil.sha256(commitmentStr),
"m.relates_to": {
event_id: this.requestEventId,
rel_type: "m.reference",
}
};
await this.channel.send(VerificationEventTypes.Accept, contentToSend, log);
await this.channel.waitForEvent(VerificationEventTypes.Key);

View File

@ -46,18 +46,18 @@ export class SendMacStage extends BaseSASVerificationStage {
const keyList: string[] = [];
const baseInfo =
"MATRIX_KEY_VERIFICATION_MAC" +
this.ourUser.userId +
this.ourUser.deviceId +
this.ourUserId +
this.ourUserDeviceId +
this.otherUserId +
this.channel.otherUserDeviceId +
this.otherUserDeviceId +
this.channel.id;
const deviceKeyId = `ed25519:${this.ourUser.deviceId}`;
const deviceKeyId = `ed25519:${this.ourUserDeviceId}`;
const deviceKeys = this.e2eeAccount.getDeviceKeysToSignWithCrossSigning();
mac[deviceKeyId] = this.calculateMAC(deviceKeys.keys[deviceKeyId], baseInfo + deviceKeyId);
keyList.push(deviceKeyId);
const {masterKey: crossSigningKey} = await this.deviceTracker.getCrossSigningKeysForUser(this.ourUser.userId, this.hsApi, log);
const {masterKey: crossSigningKey} = await this.deviceTracker.getCrossSigningKeysForUser(this.ourUserId, this.hsApi, log);
console.log("masterKey", crossSigningKey);
if (crossSigningKey) {
const crossSigningKeyId = `ed25519:${crossSigningKey}`;

View File

@ -21,7 +21,7 @@ export class SendReadyStage extends BaseSASVerificationStage {
async completeStage() {
await this.log.wrap("StartVerificationStage.completeStage", async (log) => {
const content = {
"from_device": this.ourUser.deviceId,
"from_device": this.ourUserDeviceId,
"methods": ["m.sas.v1"],
};
await this.channel.send(VerificationEventTypes.Ready, content, log);

View File

@ -48,9 +48,9 @@ export class VerifyMacStage extends BaseSASVerificationStage {
const baseInfo =
"MATRIX_KEY_VERIFICATION_MAC" +
this.otherUserId +
this.channel.otherUserDeviceId +
this.ourUser.userId +
this.ourUser.deviceId +
this.otherUserDeviceId +
this.ourUserId +
this.ourUserDeviceId +
this.channel.id;
const calculatedMAC = this.calculateMAC(Object.keys(content.mac).sort().join(","), baseInfo + "KEY_IDS");