diff --git a/src/matrix/registration/Registration.ts b/src/matrix/registration/Registration.ts index 7fb263f3..15006ce2 100644 --- a/src/matrix/registration/Registration.ts +++ b/src/matrix/registration/Registration.ts @@ -17,15 +17,19 @@ limitations under the License. import type {HomeServerApi} from "../net/HomeServerApi"; import {registrationStageFromType} from "./registrationStageFromType"; import type {BaseRegistrationStage} from "./stages/BaseRegistrationStage"; -import type {RegistrationDetails, RegistrationResponse} from "./types/type"; +import type {RegistrationDetails, RegistrationResponse, RegistrationFlow} from "./types/types"; + +type FlowSelector = (flows: RegistrationFlow[]) => RegistrationFlow | void; export class Registration { private _hsApi: HomeServerApi; private _data: RegistrationDetails; + private _flowSelector: FlowSelector; - constructor(hsApi: HomeServerApi, data: RegistrationDetails) { + constructor(hsApi: HomeServerApi, data: RegistrationDetails, flowSelector?: FlowSelector) { this._hsApi = hsApi; this._data = data; + this._flowSelector = flowSelector ?? (flows => flows.pop()); } async start(): Promise<BaseRegistrationStage> { @@ -40,16 +44,16 @@ export class Registration { parseStagesFromResponse(response: RegistrationResponse): BaseRegistrationStage { const { session, params } = response; - const flow = response.flows.pop(); + const flow = this._flowSelector(response.flows); if (!flow) { - throw new Error("No registration flows available!"); + throw new Error("flowSelector did not return any flow!"); } let firstStage: BaseRegistrationStage | undefined; let lastStage: BaseRegistrationStage; for (const stage of flow.stages) { const stageClass = registrationStageFromType(stage); if (!stageClass) { - throw new Error("Unknown stage"); + throw new Error(`Unknown stage: ${stage}`); } const registrationStage = new stageClass(this._hsApi, this._data, session, params?.[stage]); if (!firstStage) { diff --git a/src/matrix/registration/registrationStageFromType.ts b/src/matrix/registration/registrationStageFromType.ts index 28d5c2b0..9552d049 100644 --- a/src/matrix/registration/registrationStageFromType.ts +++ b/src/matrix/registration/registrationStageFromType.ts @@ -16,7 +16,7 @@ limitations under the License. import type {BaseRegistrationStage} from "./stages/BaseRegistrationStage"; import type {HomeServerApi} from "../net/HomeServerApi"; -import type {RegistrationDetails} from "./types/type"; +import type {RegistrationDetails} from "./types/types"; import {DummyAuth} from "./stages/DummyAuth"; import {TermsAuth} from "./stages/TermsAuth"; diff --git a/src/matrix/registration/stages/BaseRegistrationStage.ts b/src/matrix/registration/stages/BaseRegistrationStage.ts index 2dfc0628..62325b4c 100644 --- a/src/matrix/registration/stages/BaseRegistrationStage.ts +++ b/src/matrix/registration/stages/BaseRegistrationStage.ts @@ -15,7 +15,7 @@ limitations under the License. */ import type {HomeServerApi} from "../../net/HomeServerApi"; -import type {RegistrationDetails, RegistrationResponse, AuthenticationData, RegistrationParams} from "../types/type"; +import type {RegistrationDetails, RegistrationResponse, AuthenticationData, RegistrationParams} from "../types/types"; export abstract class BaseRegistrationStage { protected _hsApi: HomeServerApi; diff --git a/src/matrix/registration/types/type.ts b/src/matrix/registration/types/types.ts similarity index 94% rename from src/matrix/registration/types/type.ts rename to src/matrix/registration/types/types.ts index 2863ec28..97dca1be 100644 --- a/src/matrix/registration/types/type.ts +++ b/src/matrix/registration/types/types.ts @@ -25,7 +25,7 @@ export type RegistrationResponse = RegistrationResponse401 & RegistrationRespons type RegistrationResponse401 = { completed: string[]; - flows: Record<string, any>[]; + flows: RegistrationFlow[]; params: Record<string, any>; session: string; } @@ -41,6 +41,10 @@ type RegistrationResponseSuccess = { access_token?: string; } +export type RegistrationFlow = { + stages: string[]; +} + /* Types for Registration Stage */ export type AuthenticationData = {