diff --git a/src/domain/RootViewModel.js b/src/domain/RootViewModel.js index 4094d864..6cc17f35 100644 --- a/src/domain/RootViewModel.js +++ b/src/domain/RootViewModel.js @@ -38,6 +38,8 @@ export class RootViewModel extends ViewModel { this.track(this.navigation.observe("login").subscribe(() => this._applyNavigation())); this.track(this.navigation.observe("session").subscribe(() => this._applyNavigation())); this.track(this.navigation.observe("sso").subscribe(() => this._applyNavigation())); + this.track(this.navigation.observe("oidc-callback").subscribe(() => this._applyNavigation())); + this.track(this.navigation.observe("oidc-error").subscribe(() => this._applyNavigation())); this._applyNavigation(true); } @@ -46,6 +48,8 @@ export class RootViewModel extends ViewModel { const logoutSessionId = this.navigation.path.get("logout")?.value; const sessionId = this.navigation.path.get("session")?.value; const loginToken = this.navigation.path.get("sso")?.value; + const oidcCallback = this.navigation.path.get("oidc-callback")?.value; + const oidcError = this.navigation.path.get("oidc-error")?.value; if (isLogin) { if (this.activeSection !== "login") { this._showLogin(); @@ -77,7 +81,20 @@ export class RootViewModel extends ViewModel { } else if (loginToken) { this.urlCreator.normalizeUrl(); if (this.activeSection !== "login") { - this._showLogin(loginToken); + this._showLogin({loginToken}); + } + } else if (oidcError) { + this._setSection(() => this._error = new Error(`OIDC error: ${oidcError[1]}`)); + } else if (oidcCallback) { + this._setSection(() => this._error = new Error(`OIDC callback: state=${oidcCallback[0]}, code=${oidcCallback[1]}`)); + this.urlCreator.normalizeUrl(); + if (this.activeSection !== "login") { + this._showLogin({ + oidc: { + state: oidcCallback[0], + code: oidcCallback[1], + } + }); } } else { @@ -109,7 +126,7 @@ export class RootViewModel extends ViewModel { } } - _showLogin(loginToken) { + _showLogin({loginToken, oidc} = {}) { this._setSection(() => { this._loginViewModel = new LoginViewModel(this.childOptions({ defaultHomeserver: this.platform.config["defaultHomeServer"], @@ -125,7 +142,8 @@ export class RootViewModel extends ViewModel { this._pendingClient = client; this.navigation.push("session", client.sessionId); }, - loginToken + loginToken, + oidc, })); }); } diff --git a/src/domain/login/CompleteOIDCLoginViewModel.js b/src/domain/login/CompleteOIDCLoginViewModel.js new file mode 100644 index 00000000..fa0b665e --- /dev/null +++ b/src/domain/login/CompleteOIDCLoginViewModel.js @@ -0,0 +1,84 @@ +/* +Copyright 2021 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +import {OidcApi} from "../../matrix/net/OidcApi"; +import {ViewModel} from "../ViewModel"; +import {OIDCLoginMethod} from "../../matrix/login/OIDCLoginMethod"; +import {LoginFailure} from "../../matrix/Client"; + +export class CompleteOIDCLoginViewModel extends ViewModel { + constructor(options) { + super(options); + const { + state, + code, + attemptLogin, + } = options; + this._request = options.platform.request; + this._encoding = options.platform.encoding; + this._state = state; + this._code = code; + this._attemptLogin = attemptLogin; + this._errorMessage = ""; + this.performOIDCLoginCompletion(); + } + + get errorMessage() { return this._errorMessage; } + + _showError(message) { + this._errorMessage = message; + this.emitChange("errorMessage"); + } + + async performOIDCLoginCompletion() { + if (!this._state || !this._code) { + return; + } + const code = this._code; + // TODO: cleanup settings storage + const [startedAt, nonce, codeVerifier, homeserver, issuer] = await Promise.all([ + this.platform.settingsStorage.getInt(`oidc_${this._state}_started_at`), + this.platform.settingsStorage.getString(`oidc_${this._state}_nonce`), + this.platform.settingsStorage.getString(`oidc_${this._state}_code_verifier`), + this.platform.settingsStorage.getString(`oidc_${this._state}_homeserver`), + this.platform.settingsStorage.getString(`oidc_${this._state}_issuer`), + ]); + + const oidcApi = new OidcApi({ + issuer, + clientId: "hydrogen-web", + request: this._request, + encoding: this._encoding, + }); + const method = new OIDCLoginMethod({oidcApi, nonce, codeVerifier, code, homeserver, startedAt}); + const status = await this._attemptLogin(method); + let error = ""; + switch (status) { + case LoginFailure.Credentials: + error = this.i18n`Your login token is invalid.`; + break; + case LoginFailure.Connection: + error = this.i18n`Can't connect to ${homeserver}.`; + break; + case LoginFailure.Unknown: + error = this.i18n`Something went wrong while checking your login token.`; + break; + } + if (error) { + this._showError(error); + } + } +} diff --git a/src/domain/login/LoginViewModel.ts b/src/domain/login/LoginViewModel.ts index 8eb11a9e..55e503a9 100644 --- a/src/domain/login/LoginViewModel.ts +++ b/src/domain/login/LoginViewModel.ts @@ -15,19 +15,24 @@ limitations under the License. */ import {Client} from "../../matrix/Client.js"; +import {OidcApi} from "../../matrix/net/OidcApi.js"; import {Options as BaseOptions, ViewModel} from "../ViewModel"; import {PasswordLoginViewModel} from "./PasswordLoginViewModel.js"; import {StartSSOLoginViewModel} from "./StartSSOLoginViewModel.js"; import {CompleteSSOLoginViewModel} from "./CompleteSSOLoginViewModel.js"; +import {StartOIDCLoginViewModel} from "./StartOIDCLoginViewModel.js"; +import {CompleteOIDCLoginViewModel} from "./CompleteOIDCLoginViewModel.js"; import {LoadStatus} from "../../matrix/Client.js"; import {SessionLoadViewModel} from "../SessionLoadViewModel.js"; import {SegmentType} from "../navigation/index"; import type {PasswordLoginMethod, SSOLoginHelper, TokenLoginMethod, ILoginMethod} from "../../matrix/login"; +import { OIDCLoginMethod } from "../../matrix/login/OIDCLoginMethod.js"; type Options = { defaultHomeserver: string; ready: ReadyFn; + oidc?: { state: string, code: string }; loginToken?: string; } & BaseOptions; @@ -35,10 +40,13 @@ export class LoginViewModel extends ViewModel { private _ready: ReadyFn; private _loginToken?: string; private _client: Client; + private _oidc?: { state: string, code: string }; private _loginOptions?: LoginOptions; private _passwordLoginViewModel?: PasswordLoginViewModel; private _startSSOLoginViewModel?: StartSSOLoginViewModel; private _completeSSOLoginViewModel?: CompleteSSOLoginViewModel; + private _startOIDCLoginViewModel?: StartOIDCLoginViewModel; + private _completeOIDCLoginViewModel?: CompleteOIDCLoginViewModel; private _loadViewModel?: SessionLoadViewModel; private _loadViewModelSubscription?: () => void; private _homeserver: string; @@ -52,9 +60,10 @@ export class LoginViewModel extends ViewModel { constructor(options: Readonly) { super(options); - const {ready, defaultHomeserver, loginToken} = options; + const {ready, defaultHomeserver, loginToken, oidc} = options; this._ready = ready; this._loginToken = loginToken; + this._oidc = oidc; this._client = new Client(this.platform); this._homeserver = defaultHomeserver; this._initViewModels(); @@ -72,6 +81,15 @@ export class LoginViewModel extends ViewModel { return this._completeSSOLoginViewModel; } + get startOIDCLoginViewModel(): StartOIDCLoginViewModel { + return this._startOIDCLoginViewModel; + } + + get completeOIDCLoginViewModel(): CompleteOIDCLoginViewModel { + return this._completeOIDCLoginViewModel; + } + + get homeserver(): string { return this._homeserver; } @@ -116,6 +134,18 @@ export class LoginViewModel extends ViewModel { }))); this.emitChange("completeSSOLoginViewModel"); } + else if (this._oidc) { + this._hideHomeserver = true; + this._completeOIDCLoginViewModel = this.track(new CompleteOIDCLoginViewModel( + this.childOptions( + { + client: this._client, + attemptLogin: (loginMethod: OIDCLoginMethod) => this.attemptLogin(loginMethod), + state: this._oidc.state, + code: this._oidc.code, + }))); + this.emitChange("completeOIDCLoginViewModel"); + } else { void this.queryHomeserver(); } @@ -137,6 +167,14 @@ export class LoginViewModel extends ViewModel { this.emitChange("startSSOLoginViewModel"); } + private async _showOIDCLogin(): Promise { + this._startOIDCLoginViewModel = this.track( + new StartOIDCLoginViewModel(this.childOptions({loginOptions: this._loginOptions})) + ); + await this._startOIDCLoginViewModel.start(); + this.emitChange("startOIDCLoginViewModel"); + } + private _showError(message: string): void { this._errorMessage = message; this.emitChange("errorMessage"); @@ -263,7 +301,8 @@ export class LoginViewModel extends ViewModel { if (this._loginOptions) { if (this._loginOptions.sso) { this._showSSOLogin(); } if (this._loginOptions.password) { this._showPasswordLogin(); } - if (!this._loginOptions.sso && !this._loginOptions.password) { + if (this._loginOptions.oidc) { await this._showOIDCLogin(); } + if (!this._loginOptions.sso && !this._loginOptions.password && !this._loginOptions.oidc) { this._showError("This homeserver supports neither SSO nor password based login flows"); } } @@ -289,5 +328,6 @@ type LoginOptions = { homeserver: string; password?: (username: string, password: string) => PasswordLoginMethod; sso?: SSOLoginHelper; + oidc?: { issuer: string }; token?: (loginToken: string) => TokenLoginMethod; }; diff --git a/src/domain/login/StartOIDCLoginViewModel.js b/src/domain/login/StartOIDCLoginViewModel.js new file mode 100644 index 00000000..e742fe1c --- /dev/null +++ b/src/domain/login/StartOIDCLoginViewModel.js @@ -0,0 +1,55 @@ +/* +Copyright 2021 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +import {OidcApi} from "../../matrix/net/OidcApi"; +import {ViewModel} from "../ViewModel"; + +export class StartOIDCLoginViewModel extends ViewModel { + constructor(options) { + super(options); + this._isBusy = true; + this._authorizationEndpoint = null; + this._api = new OidcApi({ + clientId: "hydrogen-web", + issuer: options.loginOptions.oidc.issuer, + request: this.platform.request, + encoding: this.platform.encoding, + }); + this._homeserver = options.loginOptions.homeserver; + } + + get isBusy() { return this._isBusy; } + get authorizationEndpoint() { return this._authorizationEndpoint; } + + async start() { + const p = this._api.generateParams("openid"); + await Promise.all([ + this.platform.settingsStorage.setInt(`oidc_${p.state}_started_at`, Date.now()), + this.platform.settingsStorage.setString(`oidc_${p.state}_nonce`, p.nonce), + this.platform.settingsStorage.setString(`oidc_${p.state}_code_verifier`, p.codeVerifier), + this.platform.settingsStorage.setString(`oidc_${p.state}_homeserver`, this._homeserver), + this.platform.settingsStorage.setString(`oidc_${p.state}_issuer`, this._api.issuer), + ]); + + this._authorizationEndpoint = await this._api.authorizationEndpoint(p); + this._isBusy = false; + } + + setBusy(status) { + this._isBusy = status; + this.emitChange("isBusy"); + } +} diff --git a/src/domain/navigation/index.ts b/src/domain/navigation/index.ts index afba0d86..6cfa479c 100644 --- a/src/domain/navigation/index.ts +++ b/src/domain/navigation/index.ts @@ -33,6 +33,8 @@ export type SegmentType = { "details": true; "members": true; "member": string; + "oidc-callback": (string | null)[]; + "oidc-error": (string | null)[]; }; export function createNavigation(): Navigation { @@ -48,7 +50,7 @@ function allowsChild(parent: Segment | undefined, child: Segment | undefined, child: Segment(navigation: Navigation, defaultSessionId?: string): Segment[] { + const segments: Segment[] = []; + + // Special case for OIDC callback + if (urlPath.includes("state")) { + const params = new URLSearchParams(urlPath); + if (params.has("state")) { + // This is a proper OIDC callback + if (params.has("code")) { + segments.push(new Segment("oidc-callback", [ + params.get("state"), + params.get("code"), + ])); + return segments; + } else if (params.has("error")) { + segments.push(new Segment("oidc-error", [ + params.get("state"), + params.get("error"), + params.get("error_description"), + params.get("error_uri"), + ])); + return segments; + } + } + } + // substring(1) to take of initial / const parts = urlPath.substring(1).split("/"); const iterator = parts[Symbol.iterator](); - const segments: Segment[] = []; let next; while (!(next = iterator.next()).done) { const type = next.value; @@ -210,6 +236,8 @@ export function stringifyPath(path: Path): string { break; case "right-panel": case "sso": + case "oidc-callback": + case "oidc-error": // Do not put these segments in URL continue; default: @@ -485,6 +513,23 @@ export function tests() { assert.equal(newPath?.segments[1].type, "room"); assert.equal(newPath?.segments[1].value, "b"); }, - + "Parse OIDC callback": assert => { + const segments = parseUrlPath("state=tc9CnLU7&code=cnmUnwIYtY7V8RrWUyhJa4yvX72jJ5Yx"); + assert.equal(segments.length, 1); + assert.equal(segments[0].type, "oidc-callback"); + assert.deepEqual(segments[0].value, ["tc9CnLU7", "cnmUnwIYtY7V8RrWUyhJa4yvX72jJ5Yx"]); + }, + "Parse OIDC error": assert => { + const segments = parseUrlPath("state=tc9CnLU7&error=invalid_request"); + assert.equal(segments.length, 1); + assert.equal(segments[0].type, "oidc-error"); + assert.deepEqual(segments[0].value, ["tc9CnLU7", "invalid_request", null, null]); + }, + "Parse OIDC error with description": assert => { + const segments = parseUrlPath("state=tc9CnLU7&error=invalid_request&error_description=Unsupported%20response_type%20value"); + assert.equal(segments.length, 1); + assert.equal(segments[0].type, "oidc-error"); + assert.deepEqual(segments[0].value, ["tc9CnLU7", "invalid_request", "Unsupported response_type value", null]); + }, } } diff --git a/src/matrix/Client.js b/src/matrix/Client.js index 44643cc1..64daf727 100644 --- a/src/matrix/Client.js +++ b/src/matrix/Client.js @@ -20,6 +20,8 @@ import {lookupHomeserver} from "./well-known.js"; import {AbortableOperation} from "../utils/AbortableOperation"; import {ObservableValue} from "../observable/ObservableValue"; import {HomeServerApi} from "./net/HomeServerApi"; +import {OidcApi} from "./net/OidcApi"; +import {TokenRefresher} from "./net/TokenRefresher"; import {Reconnector, ConnectionStatus} from "./net/Reconnector"; import {ExponentialRetryDelay} from "./net/ExponentialRetryDelay"; import {MediaRepository} from "./net/MediaRepository"; @@ -123,11 +125,29 @@ export class Client { return result; } - queryLogin(homeserver) { + queryLogin(initialHomeserver) { return new AbortableOperation(async setAbortable => { - homeserver = await lookupHomeserver(homeserver, (url, options) => { + const { homeserver, issuer } = await lookupHomeserver(initialHomeserver, (url, options) => { return setAbortable(this._platform.request(url, options)); }); + if (issuer) { + try { + const oidcApi = new OidcApi({ + issuer, + clientId: "hydrogen-web", + request: this._platform.request, + encoding: this._platform.encoding, + }); + await oidcApi.validate(); + + return { + homeserver, + oidc: { issuer }, + }; + } catch (e) { + console.log(e); + } + } const hsApi = new HomeServerApi({homeserver, request: this._platform.request}); const response = await setAbortable(hsApi.getLoginFlows()).response(); return this._parseLoginOptions(response, homeserver); @@ -172,6 +192,19 @@ export class Client { accessToken: loginData.access_token, lastUsed: clock.now() }; + + if (loginData.refresh_token) { + sessionInfo.refreshToken = loginData.refresh_token; + } + + if (loginData.expires_in) { + sessionInfo.accessTokenExpiresAt = clock.now() + loginData.expires_in * 1000; + } + + if (loginData.oidc_issuer) { + sessionInfo.oidcIssuer = loginData.oidc_issuer; + } + log.set("id", sessionId); } catch (err) { this._error = err; @@ -225,9 +258,41 @@ export class Client { retryDelay: new ExponentialRetryDelay(clock.createTimeout), createMeasure: clock.createMeasure }); + + let accessToken; + + if (sessionInfo.oidcIssuer) { + const oidcApi = new OidcApi({ + issuer: sessionInfo.oidcIssuer, + clientId: "hydrogen-web", + request: this._platform.request, + encoding: this._platform.encoding, + }); + + // TODO: stop/pause the refresher? + const tokenRefresher = new TokenRefresher({ + oidcApi, + clock: this._platform.clock, + accessToken: sessionInfo.accessToken, + accessTokenExpiresAt: sessionInfo.accessTokenExpiresAt, + refreshToken: sessionInfo.refreshToken, + anticipation: 30 * 1000, + }); + + tokenRefresher.token.subscribe(t => { + this._platform.sessionInfoStorage.updateToken(sessionInfo.id, t.accessToken, t.accessTokenExpiresAt, t.refreshToken); + }); + + await tokenRefresher.start(); + + accessToken = tokenRefresher.accessToken; + } else { + accessToken = new ObservableValue(sessionInfo.accessToken); + } + const hsApi = new HomeServerApi({ homeserver: sessionInfo.homeServer, - accessToken: sessionInfo.accessToken, + accessToken, request: this._platform.request, reconnector: this._reconnector, }); diff --git a/src/matrix/login/OIDCLoginMethod.ts b/src/matrix/login/OIDCLoginMethod.ts new file mode 100644 index 00000000..0226877a --- /dev/null +++ b/src/matrix/login/OIDCLoginMethod.ts @@ -0,0 +1,67 @@ +/* +Copyright 2021 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +import {ILogItem} from "../../logging/types"; +import {ILoginMethod} from "./LoginMethod"; +import {HomeServerApi} from "../net/HomeServerApi.js"; +import {OidcApi} from "../net/OidcApi"; + +export class OIDCLoginMethod implements ILoginMethod { + private readonly _code: string; + private readonly _codeVerifier: string; + private readonly _nonce: string; + private readonly _oidcApi: OidcApi; + public readonly homeserver: string; + + constructor({ + nonce, + codeVerifier, + code, + homeserver, + oidcApi, + }: { + nonce: string, + code: string, + codeVerifier: string, + homeserver: string, + oidcApi: OidcApi, + }) { + this._oidcApi = oidcApi; + this._code = code; + this._codeVerifier = codeVerifier; + this._nonce = nonce; + this.homeserver = homeserver; + } + + async login(hsApi: HomeServerApi, _deviceName: string, log: ILogItem): Promise> { + const { access_token, refresh_token, expires_in } = await this._oidcApi.completeAuthorizationCodeGrant({ + code: this._code, + codeVerifier: this._codeVerifier, + }); + + // TODO: validate the id_token and the nonce claim + + // Do a "whoami" request to find out the user_id and device_id + const { user_id, device_id } = await hsApi.whoami({ + log, + accessTokenOverride: access_token, + }).response(); + + const oidc_issuer = this._oidcApi.issuer; + + return { oidc_issuer, access_token, refresh_token, expires_in, user_id, device_id }; + } +} diff --git a/src/matrix/net/HomeServerApi.ts b/src/matrix/net/HomeServerApi.ts index e9902ef8..d97d9eae 100644 --- a/src/matrix/net/HomeServerApi.ts +++ b/src/matrix/net/HomeServerApi.ts @@ -31,7 +31,7 @@ const DEHYDRATION_PREFIX = "/_matrix/client/unstable/org.matrix.msc2697.v2"; type Options = { homeserver: string; - accessToken: string; + accessToken: BaseObservableValue; request: RequestFunction; reconnector: Reconnector; }; @@ -42,11 +42,12 @@ type BaseRequestOptions = { uploadProgress?: (loadedBytes: number) => void; timeout?: number; prefix?: string; + accessTokenOverride?: string; }; export class HomeServerApi { private readonly _homeserver: string; - private readonly _accessToken: string; + private readonly _accessToken: BaseObservableValue; private readonly _requestFn: RequestFunction; private readonly _reconnector: Reconnector; @@ -63,11 +64,19 @@ export class HomeServerApi { return this._homeserver + prefix + csPath; } - private _baseRequest(method: RequestMethod, url: string, queryParams?: Record, body?: Record, options?: BaseRequestOptions, accessToken?: string): IHomeServerRequest { + private _baseRequest(method: RequestMethod, url: string, queryParams?: Record, body?: Record, options?: BaseRequestOptions, accessTokenSource?: BaseObservableValue): IHomeServerRequest { const queryString = encodeQueryParams(queryParams); url = `${url}?${queryString}`; let encodedBody: EncodedBody["body"]; const headers: Map = new Map(); + + let accessToken: string | null = null; + if (options?.accessTokenOverride) { + accessToken = options.accessTokenOverride; + } else if (accessTokenSource) { + accessToken = accessTokenSource.get(); + } + if (accessToken) { headers.set("Authorization", `Bearer ${accessToken}`); } @@ -279,6 +288,10 @@ export class HomeServerApi { return this._post(`/logout`, {}, {}, options); } + whoami(options?: BaseRequestOptions): IHomeServerRequest { + return this._get(`/account/whoami`, undefined, undefined, options); + } + getDehydratedDevice(options: BaseRequestOptions = {}): IHomeServerRequest { options.prefix = DEHYDRATION_PREFIX; return this._get(`/dehydrated_device`, undefined, undefined, options); @@ -308,6 +321,7 @@ export class HomeServerApi { } import {Request as MockRequest} from "../../mocks/Request.js"; +import {BaseObservableValue} from "../../observable/ObservableValue"; export function tests() { return { diff --git a/src/matrix/net/OidcApi.ts b/src/matrix/net/OidcApi.ts new file mode 100644 index 00000000..3111d65f --- /dev/null +++ b/src/matrix/net/OidcApi.ts @@ -0,0 +1,221 @@ +/* +Copyright 2021 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +const WELL_KNOWN = ".well-known/openid-configuration"; + +const RANDOM_CHARSET = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; +const randomChar = () => RANDOM_CHARSET.charAt(Math.floor(Math.random() * 1e10) % RANDOM_CHARSET.length); +const randomString = (length: number) => + Array.from({ length }, randomChar).join(""); + +type BearerToken = { + token_type: "Bearer", + access_token: string, + refresh_token?: string, + expires_in?: number, +} + +const isValidBearerToken = (t: any): t is BearerToken => + typeof t == "object" && + t["token_type"] === "Bearer" && + typeof t["access_token"] === "string" && + (!("refresh_token" in t) || typeof t["refresh_token"] === "string") && + (!("expires_in" in t) || typeof t["expires_in"] === "number"); + + +type AuthorizationParams = { + state: string, + scope: string, + nonce?: string, + codeVerifier?: string, +}; + +function assert(condition: any, message: string): asserts condition { + if (!condition) { + throw new Error(`Assertion failed: ${message}`); + } +}; + +export class OidcApi { + _issuer: string; + _clientId: string; + _requestFn: any; + _base64: any; + _metadataPromise: Promise; + + constructor({ issuer, clientId, request, encoding }) { + this._issuer = issuer; + this._clientId = clientId; + this._requestFn = request; + this._base64 = encoding.base64; + } + + get metadataUrl() { + return new URL(WELL_KNOWN, this._issuer).toString(); + } + + get issuer() { + return this._issuer; + } + + get redirectUri() { + return window.location.origin; + } + + metadata() { + if (!this._metadataPromise) { + this._metadataPromise = (async () => { + const headers = new Map(); + headers.set("Accept", "application/json"); + const req = this._requestFn(this.metadataUrl, { + method: "GET", + headers, + format: "json", + }); + const res = await req.response(); + if (res.status >= 400) { + throw new Error("failed to request metadata"); + } + + return res.body; + })(); + } + return this._metadataPromise; + } + + async validate() { + const m = await this.metadata(); + assert(typeof m.authorization_endpoint === "string", "Has an authorization endpoint"); + assert(typeof m.token_endpoint === "string", "Has a token endpoint"); + assert(Array.isArray(m.response_types_supported) && m.response_types_supported.includes("code"), "Supports the code response type"); + assert(Array.isArray(m.response_modes_supported) && m.response_modes_supported.includes("fragment"), "Supports the fragment response mode"); + assert(Array.isArray(m.grant_types_supported) && m.grant_types_supported.includes("authorization_code"), "Supports the authorization_code grant type"); + assert(Array.isArray(m.code_challenge_methods_supported) && m.code_challenge_methods_supported.includes("S256"), "Supports the authorization_code grant type"); + } + + async _generateCodeChallenge( + codeVerifier: string + ): Promise { + const encoder = new TextEncoder(); + const data = encoder.encode(codeVerifier); + const digest = await window.crypto.subtle.digest("SHA-256", data); + const base64Digest = this._base64.encode(digest); + return base64Digest.replace(/\+/g, "-").replace(/\//g, "_").replace(/=/g, ""); + } + + async authorizationEndpoint({ + state, + scope, + nonce, + codeVerifier, + }: AuthorizationParams) { + const metadata = await this.metadata(); + const url = new URL(metadata["authorization_endpoint"]); + url.searchParams.append("response_mode", "fragment"); + url.searchParams.append("response_type", "code"); + url.searchParams.append("redirect_uri", this.redirectUri); + url.searchParams.append("client_id", this._clientId); + url.searchParams.append("state", state); + url.searchParams.append("scope", scope); + if (nonce) { + url.searchParams.append("nonce", nonce); + } + + if (codeVerifier) { + url.searchParams.append("code_challenge_method", "S256"); + url.searchParams.append("code_challenge", await this._generateCodeChallenge(codeVerifier)); + } + + return url.toString(); + } + + async tokenEndpoint() { + const metadata = await this.metadata(); + return metadata["token_endpoint"]; + } + + generateParams(scope: string): AuthorizationParams { + return { + scope, + state: randomString(8), + nonce: randomString(8), + codeVerifier: randomString(32), + }; + } + + async completeAuthorizationCodeGrant({ + codeVerifier, + code, + }: { codeVerifier: string, code: string }): Promise { + const params = new URLSearchParams(); + params.append("grant_type", "authorization_code"); + params.append("client_id", this._clientId); + params.append("code_verifier", codeVerifier); + params.append("redirect_uri", this.redirectUri); + params.append("code", code); + const body = params.toString(); + + const headers = new Map(); + headers.set("Content-Type", "application/x-www-form-urlencoded"); + + const req = this._requestFn(await this.tokenEndpoint(), { + method: "POST", + headers, + format: "json", + body, + }); + + const res = await req.response(); + if (res.status >= 400) { + throw new Error("failed to exchange authorization code"); + } + + const token = res.body; + assert(isValidBearerToken(token), "Got back a valid bearer token"); + + return token; + } + + async refreshToken({ + refreshToken, + }: { refreshToken: string }): Promise { + const params = new URLSearchParams(); + params.append("grant_type", "refresh_token"); + params.append("client_id", this._clientId); + params.append("refresh_token", refreshToken); + const body = params.toString(); + + const headers = new Map(); + headers.set("Content-Type", "application/x-www-form-urlencoded"); + + const req = this._requestFn(await this.tokenEndpoint(), { + method: "POST", + headers, + format: "json", + body, + }); + + const res = await req.response(); + if (res.status >= 400) { + throw new Error("failed to use refresh token"); + } + + const token = res.body; + assert(isValidBearerToken(token), "Got back a valid bearer token"); + + return token; + } +} diff --git a/src/matrix/net/TokenRefresher.ts b/src/matrix/net/TokenRefresher.ts new file mode 100644 index 00000000..489dfb11 --- /dev/null +++ b/src/matrix/net/TokenRefresher.ts @@ -0,0 +1,125 @@ +/* +Copyright 2022 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +import {BaseObservableValue, ObservableValue} from "../../observable/ObservableValue"; +import type {Clock, Timeout} from "../../platform/web/dom/Clock"; +import {OidcApi} from "./OidcApi"; + +type Token = { + accessToken: string, + accessTokenExpiresAt: number, + refreshToken: string, +}; + + +export class TokenRefresher { + private _token: ObservableValue; + private _accessToken: BaseObservableValue; + private _anticipation: number; + private _clock: Clock; + private _oidcApi: OidcApi; + private _timeout: Timeout + + constructor({ + oidcApi, + refreshToken, + accessToken, + accessTokenExpiresAt, + anticipation, + clock, + }: { + oidcApi: OidcApi, + refreshToken: string, + accessToken: string, + accessTokenExpiresAt: number, + anticipation: number, + clock: Clock, + }) { + this._token = new ObservableValue({ + accessToken, + accessTokenExpiresAt, + refreshToken, + }); + this._accessToken = this._token.map(t => t.accessToken); + + this._anticipation = anticipation; + this._oidcApi = oidcApi; + this._clock = clock; + } + + async start() { + if (this.needsRenewing) { + await this.renew(); + } + + this._renewingLoop(); + } + + stop() { + // TODO + } + + get needsRenewing() { + const remaining = this._token.get().accessTokenExpiresAt - this._clock.now(); + const anticipated = remaining - this._anticipation; + return anticipated < 0; + } + + async _renewingLoop() { + while (true) { + const remaining = + this._token.get().accessTokenExpiresAt - this._clock.now(); + const anticipated = remaining - this._anticipation; + + if (anticipated > 0) { + this._timeout = this._clock.createTimeout(anticipated); + await this._timeout.elapsed(); + } + + await this.renew(); + } + } + + async renew() { + let refreshToken = this._token.get().refreshToken; + const response = await this._oidcApi + .refreshToken({ + refreshToken, + }); + + if (typeof response.expires_in !== "number") { + throw new Error("Refreshed access token does not expire"); + } + + if (response.refresh_token) { + refreshToken = response.refresh_token; + } + + this._token.set({ + refreshToken, + accessToken: response.access_token, + accessTokenExpiresAt: this._clock.now() + response.expires_in * 1000, + }); + } + + get accessToken(): BaseObservableValue { + return this._accessToken; + } + + get token(): BaseObservableValue { + return this._token; + } +} diff --git a/src/matrix/sessioninfo/localstorage/SessionInfoStorage.ts b/src/matrix/sessioninfo/localstorage/SessionInfoStorage.ts index ebe575f6..80443e83 100644 --- a/src/matrix/sessioninfo/localstorage/SessionInfoStorage.ts +++ b/src/matrix/sessioninfo/localstorage/SessionInfoStorage.ts @@ -21,6 +21,9 @@ interface ISessionInfo { homeserver: string; homeServer: string; // deprecate this over time accessToken: string; + accessTokenExpiresAt?: number; + refreshToken?: string; + oidcIssuer?: string; lastUsed: number; } @@ -28,6 +31,7 @@ interface ISessionInfo { interface ISessionInfoStorage { getAll(): Promise; updateLastUsed(id: string, timestamp: number): Promise; + updateToken(id: string, accessToken: string, accessTokenExpiresAt: number, refreshToken: string): Promise; get(id: string): Promise; add(sessionInfo: ISessionInfo): Promise; delete(sessionId: string): Promise; @@ -62,6 +66,19 @@ export class SessionInfoStorage implements ISessionInfoStorage { } } + async updateToken(id: string, accessToken: string, accessTokenExpiresAt: number, refreshToken: string): Promise { + const sessions = await this.getAll(); + if (sessions) { + const session = sessions.find(session => session.id === id); + if (session) { + session.accessToken = accessToken; + session.accessTokenExpiresAt = accessTokenExpiresAt; + session.refreshToken = refreshToken; + localStorage.setItem(this._name, JSON.stringify(sessions)); + } + } + } + async get(id: string): Promise { const sessions = await this.getAll(); if (sessions) { diff --git a/src/matrix/well-known.js b/src/matrix/well-known.js index 00c91f27..6e3bedbf 100644 --- a/src/matrix/well-known.js +++ b/src/matrix/well-known.js @@ -41,6 +41,7 @@ async function getWellKnownResponse(homeserver, request) { export async function lookupHomeserver(homeserver, request) { homeserver = normalizeHomeserver(homeserver); + let issuer = null; const wellKnownResponse = await getWellKnownResponse(homeserver, request); if (wellKnownResponse && wellKnownResponse.status === 200) { const {body} = wellKnownResponse; @@ -48,6 +49,11 @@ export async function lookupHomeserver(homeserver, request) { if (typeof wellKnownHomeserver === "string") { homeserver = normalizeHomeserver(wellKnownHomeserver); } + + const wellKnownIssuer = body["m.authentication"]?.["issuer"]; + if (typeof wellKnownIssuer === "string") { + issuer = wellKnownIssuer; + } } - return homeserver; + return {homeserver, issuer}; } diff --git a/src/observable/ObservableValue.ts b/src/observable/ObservableValue.ts index ad0a226d..8b9b3be6 100644 --- a/src/observable/ObservableValue.ts +++ b/src/observable/ObservableValue.ts @@ -39,6 +39,10 @@ export abstract class BaseObservableValue extends BaseObservable<(value: T) = flatMap(mapper: (value: T) => (BaseObservableValue | undefined)): BaseObservableValue { return new FlatMapObservableValue(this, mapper); } + + map(mapper: (value: T) => C): BaseObservableValue { + return new MappedObservableValue(this, mapper); + } } interface IWaitHandle { @@ -174,6 +178,34 @@ export class FlatMapObservableValue extends BaseObservableValue extends BaseObservableValue { + private sourceSubscription?: SubscriptionHandle; + + constructor( + private readonly source: BaseObservableValue

, + private readonly mapper: (value: P) => C + ) { + super(); + } + + onUnsubscribeLast() { + super.onUnsubscribeLast(); + this.sourceSubscription = this.sourceSubscription!(); + } + + onSubscribeFirst() { + super.onSubscribeFirst(); + this.sourceSubscription = this.source.subscribe(() => { + this.emit(this.get()); + }); + } + + get(): C { + const sourceValue = this.source.get(); + return this.mapper(sourceValue); + } +} + export function tests() { return { "set emits an update": assert => { diff --git a/src/platform/types/types.ts b/src/platform/types/types.ts index 1d359a09..6605f238 100644 --- a/src/platform/types/types.ts +++ b/src/platform/types/types.ts @@ -26,6 +26,7 @@ export interface IRequestOptions { cache?: boolean; method?: string; format?: string; + accessTokenOverride?: string; } export type RequestFunction = (url: string, options: IRequestOptions) => RequestResult; diff --git a/src/platform/web/ui/css/login.css b/src/platform/web/ui/css/login.css index deb16b02..ae706242 100644 --- a/src/platform/web/ui/css/login.css +++ b/src/platform/web/ui/css/login.css @@ -68,13 +68,13 @@ limitations under the License. --size: 20px; } -.StartSSOLoginView { +.StartSSOLoginView, .StartOIDCLoginView { display: flex; flex-direction: column; padding: 0 0.4em 0; } -.StartSSOLoginView_button { +.StartSSOLoginView_button, .StartOIDCLoginView_button { flex: 1; margin-top: 12px; } diff --git a/src/platform/web/ui/login/LoginView.js b/src/platform/web/ui/login/LoginView.js index 88002625..ee8bf169 100644 --- a/src/platform/web/ui/login/LoginView.js +++ b/src/platform/web/ui/login/LoginView.js @@ -57,6 +57,7 @@ export class LoginView extends TemplateView { t.mapView(vm => vm.passwordLoginViewModel, vm => vm ? new PasswordLoginView(vm): null), t.if(vm => vm.passwordLoginViewModel && vm.startSSOLoginViewModel, t => t.p({className: "LoginView_separator"}, vm.i18n`or`)), t.mapView(vm => vm.startSSOLoginViewModel, vm => vm ? new StartSSOLoginView(vm) : null), + t.mapView(vm => vm.startOIDCLoginViewModel, vm => vm ? new StartOIDCLoginView(vm) : null), t.mapView(vm => vm.loadViewModel, loadViewModel => loadViewModel ? new SessionLoadStatusView(loadViewModel) : null), // use t.mapView rather than t.if to create a new view when the view model changes too t.p(hydrogenGithubLink(t)) @@ -76,3 +77,14 @@ class StartSSOLoginView extends TemplateView { ); } } + +class StartOIDCLoginView extends TemplateView { + render(t, vm) { + return t.div({ className: "StartOIDCLoginView" }, + t.a({ + className: "StartOIDCLoginView_button button-action secondary", + href: vm => (vm.isBusy ? "#" : vm.authorizationEndpoint), + }, vm.i18n`Log in via OIDC`) + ); + } +}