diff --git a/src/matrix/Session.js b/src/matrix/Session.js index 5e4caa8c..99081cf6 100644 --- a/src/matrix/Session.js +++ b/src/matrix/Session.js @@ -101,9 +101,6 @@ export class Session { ownDeviceId: sessionInfo.deviceId, ownUserId: sessionInfo.userId, logger: this._platform.logger, - turnServers: [{ - urls: ["stun:turn.matrix.org"], - }], forceTURN: false, }); this._roomStateHandler = new RoomStateHandlerSet(); @@ -499,6 +496,8 @@ export class Session { this._megolmDecryption = undefined; this._e2eeAccount?.dispose(); this._e2eeAccount = undefined; + this._callHandler?.dispose(); + this._callHandler = undefined; for (const room of this._rooms.values()) { room.dispose(); } diff --git a/src/matrix/calls/CallHandler.ts b/src/matrix/calls/CallHandler.ts index f24d2ea5..ab37a01a 100644 --- a/src/matrix/calls/CallHandler.ts +++ b/src/matrix/calls/CallHandler.ts @@ -23,6 +23,7 @@ import {GroupCall} from "./group/GroupCall"; import {makeId} from "../common"; import {CALL_LOG_TYPE} from "./common"; import {EVENT_TYPE as MEMBER_EVENT_TYPE, RoomMember} from "../room/members/RoomMember"; +import {TurnServerSource} from "./TurnServerSource"; import type {LocalMedia} from "./LocalMedia"; import type {Room} from "../room/Room"; @@ -39,7 +40,7 @@ import type {Clock} from "../../platform/web/dom/Clock"; import type {RoomStateHandler} from "../room/state/types"; import type {MemberSync} from "../room/timeline/persistence/MemberWriter"; -export type Options = Omit & { +export type Options = Omit & { clock: Clock }; @@ -54,9 +55,12 @@ export class CallHandler implements RoomStateHandler { private roomMemberToCallIds: Map> = new Map(); private groupCallOptions: GroupCallOptions; private sessionId = makeId("s"); + private turnServerSource: TurnServerSource; constructor(private readonly options: Options) { + this.turnServerSource = new TurnServerSource(this.options.hsApi, this.options.clock); this.groupCallOptions = Object.assign({}, this.options, { + turnServerSource: this.turnServerSource, emitUpdate: (groupCall, params) => this._calls.update(groupCall.id, params), createTimeout: this.options.clock.createTimeout, sessionId: this.sessionId @@ -242,5 +246,11 @@ export class CallHandler implements RoomStateHandler { this.roomMemberToCallIds.set(roomMemberKey, newCallIdsMemberOf); } } + + dispose() { + this.turnServerSource.dispose(); + const joinedCalls = Array.from(this._calls.values()).filter(c => c.hasJoined); + Promise.all(joinedCalls.map(c => c.leave())).then(() => {}, () => {}); + } } diff --git a/src/matrix/calls/PeerCall.ts b/src/matrix/calls/PeerCall.ts index 625f1487..17941995 100644 --- a/src/matrix/calls/PeerCall.ts +++ b/src/matrix/calls/PeerCall.ts @@ -15,6 +15,7 @@ limitations under the License. */ import {ObservableMap} from "../../observable/map/ObservableMap"; +import {BaseObservableValue} from "../../observable/value/BaseObservableValue"; import {recursivelyAssign} from "../../utils/recursivelyAssign"; import {Disposables, Disposable, IDisposable} from "../../utils/Disposables"; import {WebRTC, PeerConnection, Transceiver, TransceiverDirection, Sender, Receiver, PeerConnectionEventMap} from "../../platform/types/WebRTC"; @@ -47,7 +48,7 @@ import type { export type Options = { webRTC: WebRTC, forceTURN: boolean, - turnServers: RTCIceServer[], + turnServer: BaseObservableValue, createTimeout: TimeoutCreator, emitUpdate: (peerCall: PeerCall, params: any, log: ILogItem) => void; sendSignallingMessage: (message: SignallingMessage, log: ILogItem) => Promise; @@ -114,8 +115,16 @@ export class PeerCall implements IDisposable { ) { logItem.log({l: "create PeerCall", id: callId}); this._remoteMedia = new RemoteMedia(); - this.peerConnection = options.webRTC.createPeerConnection(this.options.forceTURN, this.options.turnServers, 0); - + this.peerConnection = options.webRTC.createPeerConnection( + this.options.forceTURN, + [this.options.turnServer.get()], + 0 + ); + // update turn servers when they change (see TurnServerSource) + this.disposables.track(this.options.turnServer.subscribe(turnServer => { + this.logItem.log({l: "updating turn server", turnServer}) + this.peerConnection.setConfiguration({iceServers: [turnServer]}); + })); const listen = (type: K, listener: (this: PeerConnection, ev: PeerConnectionEventMap[K]) => any, options?: boolean | EventListenerOptions): void => { this.peerConnection.addEventListener(type, listener); const dispose = () => { diff --git a/src/matrix/calls/TurnServerSource.ts b/src/matrix/calls/TurnServerSource.ts new file mode 100644 index 00000000..1066f7a6 --- /dev/null +++ b/src/matrix/calls/TurnServerSource.ts @@ -0,0 +1,223 @@ +/* +Copyright 2022 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +import {RetainedObservableValue} from "../../observable/value/RetainedObservableValue"; + +import type {HomeServerApi} from "../net/HomeServerApi"; +import type {IHomeServerRequest} from "../net/HomeServerRequest"; +import type {BaseObservableValue} from "../../observable/value/BaseObservableValue"; +import type {ObservableValue} from "../../observable/value/ObservableValue"; +import type {Clock, Timeout} from "../../platform/web/dom/Clock"; +import type {ILogItem} from "../../logging/types"; + +type TurnServerSettings = { + uris: string[], + username: string, + password: string, + ttl: number +}; + +const DEFAULT_TTL = 5 * 60; // 5min +const DEFAULT_SETTINGS: RTCIceServer = { + urls: ["stun:turn.matrix.org"], + username: "", + credential: "", +}; + +export class TurnServerSource { + private currentObservable?: ObservableValue; + private pollTimeout?: Timeout; + private pollRequest?: IHomeServerRequest; + private isPolling = false; + + constructor( + private hsApi: HomeServerApi, + private clock: Clock, + private defaultSettings: RTCIceServer = DEFAULT_SETTINGS + ) {} + + getSettings(log: ILogItem): Promise> { + return log.wrap("get turn server", async log => { + if (!this.isPolling) { + const settings = await this.doRequest(log); + const iceServer = settings ? toIceServer(settings) : this.defaultSettings; + log.set("iceServer", iceServer); + if (this.currentObservable) { + this.currentObservable.set(iceServer); + } else { + this.currentObservable = new RetainedObservableValue(iceServer, + () => { + this.stopPollLoop(); + }, + () => { + // start loop on first subscribe + this.runLoop(settings?.ttl ?? DEFAULT_TTL); + }); + } + } + return this.currentObservable!; + }); + } + + private async runLoop(initialTtl: number): Promise { + let ttl = initialTtl; + this.isPolling = true; + while(this.isPolling) { + try { + this.pollTimeout = this.clock.createTimeout(ttl * 1000); + await this.pollTimeout.elapsed(); + this.pollTimeout = undefined; + const settings = await this.doRequest(undefined); + if (settings) { + const iceServer = toIceServer(settings); + if (shouldUpdate(this.currentObservable!, iceServer)) { + this.currentObservable!.set(iceServer); + } + if (settings.ttl > 0) { + ttl = settings.ttl; + } else { + // stop polling is settings are good indefinitely + this.stopPollLoop(); + } + } else { + ttl = DEFAULT_TTL; + } + } catch (err) { + if (err.name === "AbortError") { + /* ignore, the loop will exit because isPolling is false */ + } else { + // TODO: log error + } + } + } + } + + private async doRequest(log: ILogItem | undefined): Promise { + try { + this.pollRequest = this.hsApi.getTurnServer({log}); + const settings = await this.pollRequest.response(); + return settings; + } catch (err) { + if (err.name === "HomeServerError") { + return undefined; + } + throw err; + } finally { + this.pollRequest = undefined; + } + } + + private stopPollLoop() { + this.isPolling = false; + this.currentObservable = undefined; + this.pollTimeout?.dispose(); + this.pollTimeout = undefined; + this.pollRequest?.abort(); + this.pollRequest = undefined; + } + + dispose() { + this.stopPollLoop(); + } +} + +function shouldUpdate(observable: BaseObservableValue, settings: RTCIceServer): boolean { + const currentSettings = observable.get(); + if (!currentSettings) { + return true; + } + // same length and new settings doesn't contain any uri the old settings don't contain + const currentUrls = Array.isArray(currentSettings.urls) ? currentSettings.urls : [currentSettings.urls]; + const newUrls = Array.isArray(settings.urls) ? settings.urls : [settings.urls]; + const arraysEqual = currentUrls.length === newUrls.length && + !newUrls.some(uri => !currentUrls.includes(uri)); + return !arraysEqual || settings.username !== currentSettings.username || + settings.credential !== currentSettings.credential; +} + +function toIceServer(settings: TurnServerSettings): RTCIceServer { + return { + urls: settings.uris, + username: settings.username, + credential: settings.password, + credentialType: "password" + } +} + +export function tests() { + return { + "shouldUpdate returns false for same object": assert => { + const observable = {get() { + return { + urls: ["a", "b"], + username: "alice", + credential: "f00", + }; + }}; + const same = { + urls: ["a", "b"], + username: "alice", + credential: "f00", + }; + assert.equal(false, shouldUpdate(observable as any as BaseObservableValue, same)); + }, + "shouldUpdate returns true for 1 different uri": assert => { + const observable = {get() { + return { + urls: ["a", "c"], + username: "alice", + credential: "f00", + }; + }}; + const same = { + urls: ["a", "b"], + username: "alice", + credential: "f00", + }; + assert.equal(true, shouldUpdate(observable as any as BaseObservableValue, same)); + }, + "shouldUpdate returns true for different user": assert => { + const observable = {get() { + return { + urls: ["a", "b"], + username: "alice", + credential: "f00", + }; + }}; + const same = { + urls: ["a", "b"], + username: "bob", + credential: "f00", + }; + assert.equal(true, shouldUpdate(observable as any as BaseObservableValue, same)); + }, + "shouldUpdate returns true for different password": assert => { + const observable = {get() { + return { + urls: ["a", "b"], + username: "alice", + credential: "f00", + }; + }}; + const same = { + urls: ["a", "b"], + username: "alice", + credential: "b4r", + }; + assert.equal(true, shouldUpdate(observable as any as BaseObservableValue, same)); + } + } +} diff --git a/src/matrix/calls/group/GroupCall.ts b/src/matrix/calls/group/GroupCall.ts index a04b9d49..ccdc2a3c 100644 --- a/src/matrix/calls/group/GroupCall.ts +++ b/src/matrix/calls/group/GroupCall.ts @@ -23,6 +23,7 @@ import {EventEmitter} from "../../../utils/EventEmitter"; import {EventType, CallIntent} from "../callEventTypes"; import type {Options as MemberOptions} from "./Member"; +import type {TurnServerSource} from "../TurnServerSource"; import type {BaseObservableMap} from "../../../observable/map/BaseObservableMap"; import type {Track} from "../../../platform/types/MediaDevices"; import type {SignallingMessage, MGroupCallBase, CallMembership} from "../callEventTypes"; @@ -32,6 +33,7 @@ import type {Platform} from "../../../platform/web/Platform"; import type {EncryptedMessage} from "../../e2ee/olm/Encryption"; import type {ILogItem, ILogger} from "../../../logging/types"; import type {Storage} from "../../storage/idb/Storage"; +import type {BaseObservableValue} from "../../../observable/value/BaseObservableValue"; export enum GroupCallState { Fledgling = "fledgling", @@ -53,11 +55,12 @@ function getDeviceFromMemberKey(key: string): string { return JSON.parse(`[${key}]`)[1]; } -export type Options = Omit & { +export type Options = Omit & { emitUpdate: (call: GroupCall, params?: any) => void; encryptDeviceMessage: (roomId: string, userId: string, deviceId: string, message: SignallingMessage, log: ILogItem) => Promise, storage: Storage, logger: ILogger, + turnServerSource: TurnServerSource }; class JoinedData { @@ -65,7 +68,8 @@ class JoinedData { public readonly logItem: ILogItem, public readonly membersLogItem: ILogItem, public localMedia: LocalMedia, - public localMuteSettings: MuteSettings + public localMuteSettings: MuteSettings, + public turnServer: BaseObservableValue ) {} dispose() { @@ -147,6 +151,7 @@ export class GroupCall extends EventEmitter<{change: never}> { id: this.id, ownSessionId: this.options.sessionId }); + const turnServer = await this.options.turnServerSource.getSettings(logItem); const membersLogItem = logItem.child("member connections"); const localMuteSettings = new MuteSettings(); localMuteSettings.updateTrackInfo(localMedia.userMedia); @@ -154,7 +159,8 @@ export class GroupCall extends EventEmitter<{change: never}> { logItem, membersLogItem, localMedia, - localMuteSettings + localMuteSettings, + turnServer ); this.joinedData = joinedData; await joinedData.logItem.wrap("join", async log => { @@ -529,7 +535,12 @@ export class GroupCall extends EventEmitter<{change: never}> { const logItem = joinedData.membersLogItem.child({l: "member", id: memberKey}); logItem.set("sessionId", member.sessionId); log.wrap({l: "connect", id: memberKey}, log => { - const connectItem = member.connect(joinedData.localMedia, joinedData.localMuteSettings, logItem); + const connectItem = member.connect( + joinedData.localMedia, + joinedData.localMuteSettings, + joinedData.turnServer, + logItem + ); if (connectItem) { log.refDetached(connectItem); } diff --git a/src/matrix/calls/group/Member.ts b/src/matrix/calls/group/Member.ts index 9d9f39e1..de868dd9 100644 --- a/src/matrix/calls/group/Member.ts +++ b/src/matrix/calls/group/Member.ts @@ -19,6 +19,7 @@ import {makeTxnId, makeId} from "../../common"; import {EventType, CallErrorCode} from "../callEventTypes"; import {formatToDeviceMessagesPayload} from "../../common"; import {sortedIndex} from "../../../utils/sortedIndex"; + import type {MuteSettings} from "../common"; import type {Options as PeerCallOptions, RemoteMedia} from "../PeerCall"; import type {LocalMedia} from "../LocalMedia"; @@ -28,8 +29,9 @@ import type {GroupCall} from "./GroupCall"; import type {RoomMember} from "../../room/members/RoomMember"; import type {EncryptedMessage} from "../../e2ee/olm/Encryption"; import type {ILogItem} from "../../../logging/types"; +import type {BaseObservableValue} from "../../../observable/value/BaseObservableValue"; -export type Options = Omit & { +export type Options = Omit & { confId: string, ownUserId: string, ownDeviceId: string, @@ -60,6 +62,7 @@ class MemberConnection { constructor( public localMedia: LocalMedia, public localMuteSettings: MuteSettings, + public turnServer: BaseObservableValue, public readonly logItem: ILogItem ) {} } @@ -122,12 +125,17 @@ export class Member { } /** @internal */ - connect(localMedia: LocalMedia, localMuteSettings: MuteSettings, memberLogItem: ILogItem): ILogItem | undefined { + connect(localMedia: LocalMedia, localMuteSettings: MuteSettings, turnServer: BaseObservableValue, memberLogItem: ILogItem): ILogItem | undefined { if (this.connection) { return; } // Safari can't send a MediaStream to multiple sources, so clone it - const connection = new MemberConnection(localMedia.clone(), localMuteSettings, memberLogItem); + const connection = new MemberConnection( + localMedia.clone(), + localMuteSettings, + turnServer, + memberLogItem + ); this.connection = connection; let connectLogItem; connection.logItem.wrap("connect", async log => { @@ -355,9 +363,11 @@ export class Member { } private _createPeerCall(callId: string): PeerCall { + const connection = this.connection!; return new PeerCall(callId, Object.assign({}, this.options, { emitUpdate: this.emitUpdateFromPeerCall, - sendSignallingMessage: this.sendSignallingMessage - }), this.connection!.logItem); + sendSignallingMessage: this.sendSignallingMessage, + turnServer: connection.turnServer + }), connection.logItem); } } diff --git a/src/matrix/net/HomeServerApi.ts b/src/matrix/net/HomeServerApi.ts index d9bd8f50..86ec999a 100644 --- a/src/matrix/net/HomeServerApi.ts +++ b/src/matrix/net/HomeServerApi.ts @@ -373,6 +373,10 @@ export class HomeServerApi { setAccountData(ownUserId: string, type: string, content: Record, options?: BaseRequestOptions): IHomeServerRequest { return this._put(`/user/${encodeURIComponent(ownUserId)}/account_data/${encodeURIComponent(type)}`, {}, content, options); } + + getTurnServer(options?: BaseRequestOptions): IHomeServerRequest { + return this._get(`/voip/turnServer`, undefined, undefined, options); + } } import {Request as MockRequest} from "../../mocks/Request.js"; diff --git a/src/observable/value/RetainedObservableValue.ts b/src/observable/value/RetainedObservableValue.ts index edfb6c15..16058f8e 100644 --- a/src/observable/value/RetainedObservableValue.ts +++ b/src/observable/value/RetainedObservableValue.ts @@ -17,15 +17,17 @@ limitations under the License. import {ObservableValue} from "./ObservableValue"; export class RetainedObservableValue extends ObservableValue { - private _freeCallback: () => void; - constructor(initialValue: T, freeCallback: () => void) { + constructor(initialValue: T, private freeCallback: () => void, private startCallback: () => void = () => {}) { super(initialValue); - this._freeCallback = freeCallback; + } + + onSubscribeFirst() { + this.startCallback(); } onUnsubscribeLast() { super.onUnsubscribeLast(); - this._freeCallback(); + this.freeCallback(); } } diff --git a/src/platform/types/WebRTC.ts b/src/platform/types/WebRTC.ts index 39ad49c5..236e8354 100644 --- a/src/platform/types/WebRTC.ts +++ b/src/platform/types/WebRTC.ts @@ -148,6 +148,7 @@ export interface PeerConnection { addEventListener(type: K, listener: (this: PeerConnection, ev: PeerConnectionEventMap[K]) => any, options?: boolean | AddEventListenerOptions): void; removeEventListener(type: K, listener: (this: PeerConnection, ev: PeerConnectionEventMap[K]) => any, options?: boolean | EventListenerOptions): void; getStats(selector?: Track | null): Promise; + setConfiguration(configuration?: RTCConfiguration): void; }