diff --git a/src/matrix/DeviceMessageHandler.js b/src/matrix/DeviceMessageHandler.js index 9c81f11e..537b948d 100644 --- a/src/matrix/DeviceMessageHandler.js +++ b/src/matrix/DeviceMessageHandler.js @@ -64,6 +64,9 @@ export class DeviceMessageHandler { } const readTxn = await this._storage.readTxn([this._storage.storeNames.session]); const pendingEvents = await this._getPendingEvents(readTxn); + if (pendingEvents.length === 0) { + return; + } // only know olm for now const olmEvents = pendingEvents.filter(e => e.content?.algorithm === OLM_ALGORITHM); const decryptChanges = await this._olmDecryption.decryptAll(olmEvents); diff --git a/src/matrix/Session.js b/src/matrix/Session.js index 2752fca7..652ac18e 100644 --- a/src/matrix/Session.js +++ b/src/matrix/Session.js @@ -21,8 +21,11 @@ import {User} from "./User.js"; import {Account as E2EEAccount} from "./e2ee/Account.js"; import {DeviceMessageHandler} from "./DeviceMessageHandler.js"; import {Decryption as OlmDecryption} from "./e2ee/olm/Decryption.js"; +import {Encryption as OlmEncryption} from "./e2ee/olm/Encryption.js"; import {Decryption as MegOlmDecryption} from "./e2ee/megolm/Decryption.js"; import {DeviceTracker} from "./e2ee/DeviceTracker.js"; +import {LockMap} from "../utils/LockMap.js"; + const PICKLE_KEY = "DEFAULT_KEY"; export class Session { @@ -42,18 +45,23 @@ export class Session { this._olmUtil = null; this._e2eeAccount = null; this._deviceTracker = null; + this._olmEncryption = null; if (olm) { this._olmUtil = new olm.Utility(); this._deviceTracker = new DeviceTracker({ storage, getSyncToken: () => this.syncToken, olmUtil: this._olmUtil, + ownUserId: sessionInfo.userId, + ownDeviceId: sessionInfo.deviceId, }); } } // called once this._e2eeAccount is assigned _setupEncryption() { + console.log("loaded e2ee account with keys", this._e2eeAccount.identityKeys); + const senderKeyLock = new LockMap(); const olmDecryption = new OlmDecryption({ account: this._e2eeAccount, pickleKey: PICKLE_KEY, @@ -61,6 +69,17 @@ export class Session { ownUserId: this._user.id, storage: this._storage, olm: this._olm, + senderKeyLock + }); + this._olmEncryption = new OlmEncryption({ + account: this._e2eeAccount, + pickleKey: PICKLE_KEY, + now: this._clock.now, + ownUserId: this._user.id, + storage: this._storage, + olm: this._olm, + olmUtil: this._olmUtil, + senderKeyLock }); const megolmDecryption = new MegOlmDecryption({pickleKey: PICKLE_KEY, olm: this._olm}); this._deviceMessageHandler.enableEncryption({olmDecryption, megolmDecryption}); diff --git a/src/matrix/e2ee/Account.js b/src/matrix/e2ee/Account.js index 0478112b..9d83465c 100644 --- a/src/matrix/e2ee/Account.js +++ b/src/matrix/e2ee/Account.js @@ -126,8 +126,24 @@ export class Account { createInboundOlmSession(senderKey, body) { const newSession = new this._olm.Session(); - newSession.create_inbound_from(this._account, senderKey, body); - return newSession; + try { + newSession.create_inbound_from(this._account, senderKey, body); + return newSession; + } catch (err) { + newSession.free(); + throw err; + } + } + + createOutboundOlmSession(theirIdentityKey, theirOneTimeKey) { + const newSession = new this._olm.Session(); + try { + newSession.create_outbound(this._account, theirIdentityKey, theirOneTimeKey); + return newSession; + } catch (err) { + newSession.free(); + throw err; + } } writeRemoveOneTimeKey(session, txn) { diff --git a/src/matrix/e2ee/DeviceTracker.js b/src/matrix/e2ee/DeviceTracker.js index b085be80..84da2f37 100644 --- a/src/matrix/e2ee/DeviceTracker.js +++ b/src/matrix/e2ee/DeviceTracker.js @@ -14,13 +14,11 @@ See the License for the specific language governing permissions and limitations under the License. */ -import anotherjson from "../../../lib/another-json/index.js"; +import {verifyEd25519Signature, SIGNATURE_ALGORITHM} from "./common.js"; const TRACKING_STATUS_OUTDATED = 0; const TRACKING_STATUS_UPTODATE = 1; -const DEVICE_KEYS_SIGNATURE_ALGORITHM = "ed25519"; - // map 1 device from /keys/query response to DeviceIdentity function deviceKeysAsDeviceIdentity(deviceSection) { const deviceId = deviceSection["device_id"]; @@ -36,11 +34,13 @@ function deviceKeysAsDeviceIdentity(deviceSection) { } export class DeviceTracker { - constructor({storage, getSyncToken, olmUtil}) { + constructor({storage, getSyncToken, olmUtil, ownUserId, ownDeviceId}) { this._storage = storage; this._getSyncToken = getSyncToken; this._identityChangedForRoom = null; this._olmUtil = olmUtil; + this._ownUserId = ownUserId; + this._ownDeviceId = ownDeviceId; } async writeDeviceChanges(deviceLists, txn) { @@ -200,7 +200,11 @@ export class DeviceTracker { if (deviceIdOnKeys !== deviceId) { return false; } - return this._verifyUserDeviceKeys(deviceKeys); + // don't store our own device + if (userId === this._ownUserId && deviceId === this._ownDeviceId) { + return false; + } + return this._hasValidSignature(deviceKeys); }); const verifiedKeys = verifiedEntries.map(([, deviceKeys]) => deviceKeys); return {userId, verifiedKeys}; @@ -208,26 +212,11 @@ export class DeviceTracker { return verifiedKeys; } - _verifyUserDeviceKeys(deviceSection) { + _hasValidSignature(deviceSection) { const deviceId = deviceSection["device_id"]; const userId = deviceSection["user_id"]; - const clone = Object.assign({}, deviceSection); - delete clone.unsigned; - delete clone.signatures; - const canonicalJson = anotherjson.stringify(clone); - const key = deviceSection?.keys?.[`${DEVICE_KEYS_SIGNATURE_ALGORITHM}:${deviceId}`]; - const signature = deviceSection?.signatures?.[userId]?.[`${DEVICE_KEYS_SIGNATURE_ALGORITHM}:${deviceId}`]; - try { - if (!signature) { - throw new Error("no signature"); - } - // throws when signature is invalid - this._olmUtil.ed25519_verify(key, canonicalJson, signature); - return true; - } catch (err) { - console.warn("Invalid device signature, ignoring device.", key, canonicalJson, signature, err); - return false; - } + const ed25519Key = deviceSection?.keys?.[`${SIGNATURE_ALGORITHM}:${deviceId}`]; + return verifyEd25519Signature(this._olmUtil, userId, deviceId, ed25519Key, deviceSection); } /** @@ -275,6 +264,10 @@ export class DeviceTracker { if (queriedDevices && queriedDevices.length) { flattenedDevices = flattenedDevices.concat(queriedDevices); } - return flattenedDevices; + // filter out our own devices if it got in somehow (even though we should not store it) + const devices = flattenedDevices.filter(device => { + return !(device.userId === this._ownUserId && device.deviceId === this._ownDeviceId); + }); + return devices; } } diff --git a/src/matrix/e2ee/common.js b/src/matrix/e2ee/common.js index c5e7399f..3312032b 100644 --- a/src/matrix/e2ee/common.js +++ b/src/matrix/e2ee/common.js @@ -14,6 +14,8 @@ See the License for the specific language governing permissions and limitations under the License. */ +import anotherjson from "../../../lib/another-json/index.js"; + // use common prefix so it's easy to clear properties that are not e2ee related during session clear export const SESSION_KEY_PREFIX = "e2ee:"; export const OLM_ALGORITHM = "m.olm.v1.curve25519-aes-sha2"; @@ -27,3 +29,24 @@ export class DecryptionError extends Error { this.details = detailsObj; } } + +export const SIGNATURE_ALGORITHM = "ed25519"; + +export function verifyEd25519Signature(olmUtil, userId, deviceOrKeyId, ed25519Key, value) { + const clone = Object.assign({}, value); + delete clone.unsigned; + delete clone.signatures; + const canonicalJson = anotherjson.stringify(clone); + const signature = value?.signatures?.[userId]?.[`${SIGNATURE_ALGORITHM}:${deviceOrKeyId}`]; + try { + if (!signature) { + throw new Error("no signature"); + } + // throws when signature is invalid + olmUtil.ed25519_verify(ed25519Key, canonicalJson, signature); + return true; + } catch (err) { + console.warn("Invalid signature, ignoring.", ed25519Key, canonicalJson, signature, err); + return false; + } +} diff --git a/src/matrix/e2ee/olm/Decryption.js b/src/matrix/e2ee/olm/Decryption.js index f701f4df..dfde7674 100644 --- a/src/matrix/e2ee/olm/Decryption.js +++ b/src/matrix/e2ee/olm/Decryption.js @@ -15,6 +15,8 @@ limitations under the License. */ import {DecryptionError} from "../common.js"; +import {groupBy} from "../../../utils/groupBy.js"; +import {Session} from "./Session.js"; const SESSION_LIMIT_PER_SENDER_KEY = 4; @@ -29,14 +31,14 @@ function sortSessions(sessions) { } export class Decryption { - constructor({account, pickleKey, now, ownUserId, storage, olm}) { + constructor({account, pickleKey, now, ownUserId, storage, olm, senderKeyLock}) { this._account = account; this._pickleKey = pickleKey; this._now = now; this._ownUserId = ownUserId; this._storage = storage; this._olm = olm; - this._createOutboundSessionPromise = null; + this._senderKeyLock = senderKeyLock; } // we need decryptAll because there is some parallelization we can do for decrypting different sender keys at once @@ -49,26 +51,30 @@ export class Decryption { // // doing it one by one would be possible, but we would lose the opportunity for parallelization async decryptAll(events) { - const eventsPerSenderKey = events.reduce((map, event) => { - const senderKey = event.content?.["sender_key"]; - let list = map.get(senderKey); - if (!list) { - list = []; - map.set(senderKey, list); - } - list.push(event); - return map; - }, new Map()); + const eventsPerSenderKey = groupBy(events, event => event.content?.["sender_key"]); const timestamp = this._now(); - const readSessionsTxn = await this._storage.readTxn([this._storage.storeNames.olmSessions]); - // decrypt events for different sender keys in parallel - const results = await Promise.all(Array.from(eventsPerSenderKey.entries()).map(([senderKey, events]) => { - return this._decryptAllForSenderKey(senderKey, events, timestamp, readSessionsTxn); + // take a lock on all senderKeys so encryption or other calls to decryptAll (should not happen) + // don't modify the sessions at the same time + const locks = await Promise.all(Array.from(eventsPerSenderKey.keys()).map(senderKey => { + return this._senderKeyLock.takeLock(senderKey); })); - const payloads = results.reduce((all, r) => all.concat(r.payloads), []); - const errors = results.reduce((all, r) => all.concat(r.errors), []); - const senderKeyDecryptions = results.map(r => r.senderKeyDecryption); - return new DecryptionChanges(senderKeyDecryptions, payloads, errors); + try { + const readSessionsTxn = await this._storage.readTxn([this._storage.storeNames.olmSessions]); + // decrypt events for different sender keys in parallel + const results = await Promise.all(Array.from(eventsPerSenderKey.entries()).map(([senderKey, events]) => { + return this._decryptAllForSenderKey(senderKey, events, timestamp, readSessionsTxn); + })); + const payloads = results.reduce((all, r) => all.concat(r.payloads), []); + const errors = results.reduce((all, r) => all.concat(r.errors), []); + const senderKeyDecryptions = results.map(r => r.senderKeyDecryption); + return new DecryptionChanges(senderKeyDecryptions, payloads, errors, this._account, locks); + } catch (err) { + // make sure the locks are release if something throws + for (const lock of locks) { + lock.release(); + } + throw err; + } } async _decryptAllForSenderKey(senderKey, events, timestamp, readSessionsTxn) { @@ -105,7 +111,12 @@ export class Decryption { plaintext = createResult.plaintext; } if (typeof plaintext === "string") { - const payload = JSON.parse(plaintext); + let payload; + try { + payload = JSON.parse(plaintext); + } catch (err) { + throw new DecryptionError("Could not JSON decode plaintext", event, {plaintext, err}); + } this._validatePayload(payload, event); return {event: payload, senderKey}; } else { @@ -177,44 +188,6 @@ export class Decryption { } } -class Session { - constructor(data, pickleKey, olm, isNew = false) { - this.data = data; - this._olm = olm; - this._pickleKey = pickleKey; - this.isNew = isNew; - this.isModified = isNew; - } - - static create(senderKey, olmSession, olm, pickleKey, timestamp) { - return new Session({ - session: olmSession.pickle(pickleKey), - sessionId: olmSession.session_id(), - senderKey, - lastUsed: timestamp, - }, pickleKey, olm, true); - } - - get id() { - return this.data.sessionId; - } - - load() { - const session = new this._olm.Session(); - session.unpickle(this._pickleKey, this.data.session); - return session; - } - - unload(olmSession) { - olmSession.free(); - } - - save(olmSession) { - this.data.session = olmSession.pickle(this._pickleKey); - this.isModified = true; - } -} - // decryption helper for a single senderKey class SenderKeyDecryption { constructor(senderKey, sessions, olm, timestamp) { @@ -280,11 +253,12 @@ class SenderKeyDecryption { } class DecryptionChanges { - constructor(senderKeyDecryptions, payloads, errors, account) { + constructor(senderKeyDecryptions, payloads, errors, account, locks) { this._senderKeyDecryptions = senderKeyDecryptions; this._account = account; this.payloads = payloads; this.errors = errors; + this._locks = locks; } get hasNewSessions() { @@ -292,25 +266,31 @@ class DecryptionChanges { } write(txn) { - for (const senderKeyDecryption of this._senderKeyDecryptions) { - for (const session of senderKeyDecryption.getModifiedSessions()) { - txn.olmSessions.set(session.data); - if (session.isNew) { - const olmSession = session.load(); - try { - this._account.writeRemoveOneTimeKey(olmSession, txn); - } finally { - session.unload(olmSession); + try { + for (const senderKeyDecryption of this._senderKeyDecryptions) { + for (const session of senderKeyDecryption.getModifiedSessions()) { + txn.olmSessions.set(session.data); + if (session.isNew) { + const olmSession = session.load(); + try { + this._account.writeRemoveOneTimeKey(olmSession, txn); + } finally { + session.unload(olmSession); + } + } + } + if (senderKeyDecryption.sessions.length > SESSION_LIMIT_PER_SENDER_KEY) { + const {senderKey, sessions} = senderKeyDecryption; + // >= because index is zero-based + for (let i = sessions.length - 1; i >= SESSION_LIMIT_PER_SENDER_KEY ; i -= 1) { + const session = sessions[i]; + txn.olmSessions.remove(senderKey, session.id); } } } - if (senderKeyDecryption.sessions.length > SESSION_LIMIT_PER_SENDER_KEY) { - const {senderKey, sessions} = senderKeyDecryption; - // >= because index is zero-based - for (let i = sessions.length - 1; i >= SESSION_LIMIT_PER_SENDER_KEY ; i -= 1) { - const session = sessions[i]; - txn.olmSessions.remove(senderKey, session.id); - } + } finally { + for (const lock of this._locks) { + lock.release(); } } } diff --git a/src/matrix/e2ee/olm/Encryption.js b/src/matrix/e2ee/olm/Encryption.js new file mode 100644 index 00000000..680ce154 --- /dev/null +++ b/src/matrix/e2ee/olm/Encryption.js @@ -0,0 +1,274 @@ +/* +Copyright 2020 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +import {groupByWithCreator} from "../../../utils/groupBy.js"; +import {verifyEd25519Signature, OLM_ALGORITHM} from "../common.js"; +import {createSessionEntry} from "./Session.js"; + +function findFirstSessionId(sessionIds) { + return sessionIds.reduce((first, sessionId) => { + if (!first || sessionId < first) { + return sessionId; + } else { + return first; + } + }, null); +} + +const OTK_ALGORITHM = "signed_curve25519"; + +export class Encryption { + constructor({account, olm, olmUtil, ownUserId, storage, now, pickleKey, senderKeyLock}) { + this._account = account; + this._olm = olm; + this._olmUtil = olmUtil; + this._ownUserId = ownUserId; + this._storage = storage; + this._now = now; + this._pickleKey = pickleKey; + this._senderKeyLock = senderKeyLock; + } + + async encrypt(type, content, devices, hsApi) { + // TODO: see if we can only hold some of the locks until after the /keys/claim call (if needed) + // take a lock on all senderKeys so decryption and other calls to encrypt (should not happen) + // don't modify the sessions at the same time + const locks = await Promise.all(devices.map(device => { + return this._senderKeyLock.takeLock(device.curve25519Key); + })); + try { + const { + devicesWithoutSession, + existingEncryptionTargets, + } = await this._findExistingSessions(devices); + + const timestamp = this._now(); + + let encryptionTargets = []; + try { + if (devicesWithoutSession.length) { + const newEncryptionTargets = await this._createNewSessions( + devicesWithoutSession, hsApi, timestamp); + encryptionTargets = encryptionTargets.concat(newEncryptionTargets); + } + await this._loadSessions(existingEncryptionTargets); + encryptionTargets = encryptionTargets.concat(existingEncryptionTargets); + const messages = encryptionTargets.map(target => { + const encryptedContent = this._encryptForDevice(type, content, target); + return new EncryptedMessage(encryptedContent, target.device); + }); + await this._storeSessions(encryptionTargets, timestamp); + return messages; + } finally { + for (const target of encryptionTargets) { + target.dispose(); + } + } + } finally { + for (const lock of locks) { + lock.release(); + } + } + } + + async _findExistingSessions(devices) { + const txn = await this._storage.readTxn([this._storage.storeNames.olmSessions]); + const sessionIdsForDevice = await Promise.all(devices.map(async device => { + return await txn.olmSessions.getSessionIds(device.curve25519Key); + })); + const devicesWithoutSession = devices.filter((_, i) => { + const sessionIds = sessionIdsForDevice[i]; + return !(sessionIds?.length); + }); + + const existingEncryptionTargets = devices.map((device, i) => { + const sessionIds = sessionIdsForDevice[i]; + if (sessionIds?.length > 0) { + const sessionId = findFirstSessionId(sessionIds); + return EncryptionTarget.fromSessionId(device, sessionId); + } + }).filter(target => !!target); + + return {devicesWithoutSession, existingEncryptionTargets}; + } + + _encryptForDevice(type, content, target) { + const {session, device} = target; + const plaintext = JSON.stringify(this._buildPlainTextMessageForDevice(type, content, device)); + const message = session.encrypt(plaintext); + const encryptedContent = { + algorithm: OLM_ALGORITHM, + sender_key: this._account.identityKeys.curve25519, + ciphertext: { + [device.curve25519Key]: message + } + }; + return encryptedContent; + } + + _buildPlainTextMessageForDevice(type, content, device) { + return { + keys: { + "ed25519": this._account.identityKeys.ed25519 + }, + recipient_keys: { + "ed25519": device.ed25519Key + }, + recipient: device.userId, + sender: this._ownUserId, + content, + type + } + } + + async _createNewSessions(devicesWithoutSession, hsApi, timestamp) { + const newEncryptionTargets = await this._claimOneTimeKeys(hsApi, devicesWithoutSession); + try { + for (const target of newEncryptionTargets) { + const {device, oneTimeKey} = target; + target.session = this._account.createOutboundOlmSession(device.curve25519Key, oneTimeKey); + } + this._storeSessions(newEncryptionTargets, timestamp); + } catch (err) { + for (const target of newEncryptionTargets) { + target.dispose(); + } + throw err; + } + return newEncryptionTargets; + } + + async _claimOneTimeKeys(hsApi, deviceIdentities) { + // create a Map> + const devicesByUser = groupByWithCreator(deviceIdentities, + device => device.userId, + () => new Map(), + (deviceMap, device) => deviceMap.set(device.deviceId, device) + ); + const oneTimeKeys = Array.from(devicesByUser.entries()).reduce((usersObj, [userId, deviceMap]) => { + usersObj[userId] = Array.from(deviceMap.values()).reduce((devicesObj, device) => { + devicesObj[device.deviceId] = OTK_ALGORITHM; + return devicesObj; + }, {}); + return usersObj; + }, {}); + const claimResponse = await hsApi.claimKeys({ + timeout: 10000, + one_time_keys: oneTimeKeys + }).response(); + // TODO: log claimResponse.failures + const userKeyMap = claimResponse?.["one_time_keys"]; + return this._verifyAndCreateOTKTargets(userKeyMap, devicesByUser); + } + + _verifyAndCreateOTKTargets(userKeyMap, devicesByUser) { + const verifiedEncryptionTargets = []; + for (const [userId, userSection] of Object.entries(userKeyMap)) { + for (const [deviceId, deviceSection] of Object.entries(userSection)) { + const [firstPropName, keySection] = Object.entries(deviceSection)[0]; + const [keyAlgorithm] = firstPropName.split(":"); + if (keyAlgorithm === OTK_ALGORITHM) { + const device = devicesByUser.get(userId)?.get(deviceId); + if (device) { + const isValidSignature = verifyEd25519Signature( + this._olmUtil, userId, deviceId, device.ed25519Key, keySection); + if (isValidSignature) { + const target = EncryptionTarget.fromOTK(device, keySection.key); + verifiedEncryptionTargets.push(target); + } + } + } + } + } + return verifiedEncryptionTargets; + } + + async _loadSessions(encryptionTargets) { + const txn = await this._storage.readTxn([this._storage.storeNames.olmSessions]); + // given we run loading in parallel, there might still be some + // storage requests that will finish later once one has failed. + // those should not allocate a session anymore. + let failed = false; + try { + await Promise.all(encryptionTargets.map(async encryptionTarget => { + const sessionEntry = await txn.olmSessions.get( + encryptionTarget.device.curve25519Key, encryptionTarget.sessionId); + if (sessionEntry && !failed) { + const olmSession = new this._olm.Session(); + olmSession.unpickle(this._pickleKey, sessionEntry.session); + encryptionTarget.session = olmSession; + } + })); + } catch (err) { + failed = true; + // clean up the sessions that did load + for (const target of encryptionTargets) { + target.dispose(); + } + throw err; + } + } + + async _storeSessions(encryptionTargets, timestamp) { + const txn = await this._storage.readWriteTxn([this._storage.storeNames.olmSessions]); + try { + for (const target of encryptionTargets) { + const sessionEntry = createSessionEntry( + target.session, target.device.curve25519Key, timestamp, this._pickleKey); + txn.olmSessions.set(sessionEntry); + } + } catch (err) { + txn.abort(); + throw err; + } + await txn.complete(); + } +} + +// just a container needed to encrypt a message for a recipient device +// it is constructed with either a oneTimeKey +// (and later converted to a session) in case of a new session +// or an existing session +class EncryptionTarget { + constructor(device, oneTimeKey, sessionId) { + this.device = device; + this.oneTimeKey = oneTimeKey; + this.sessionId = sessionId; + // an olmSession, should probably be called olmSession + this.session = null; + } + + static fromOTK(device, oneTimeKey) { + return new EncryptionTarget(device, oneTimeKey, null); + } + + static fromSessionId(device, sessionId) { + return new EncryptionTarget(device, null, sessionId); + } + + dispose() { + if (this.session) { + this.session.free(); + } + } +} + +class EncryptedMessage { + constructor(content, device) { + this.content = content; + this.device = device; + } +} diff --git a/src/matrix/e2ee/olm/Session.js b/src/matrix/e2ee/olm/Session.js new file mode 100644 index 00000000..9b5f4db0 --- /dev/null +++ b/src/matrix/e2ee/olm/Session.js @@ -0,0 +1,58 @@ +/* +Copyright 2020 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +export function createSessionEntry(olmSession, senderKey, timestamp, pickleKey) { + return { + session: olmSession.pickle(pickleKey), + sessionId: olmSession.session_id(), + senderKey, + lastUsed: timestamp, + }; +} + +export class Session { + constructor(data, pickleKey, olm, isNew = false) { + this.data = data; + this._olm = olm; + this._pickleKey = pickleKey; + this.isNew = isNew; + this.isModified = isNew; + } + + static create(senderKey, olmSession, olm, pickleKey, timestamp) { + const data = createSessionEntry(olmSession, senderKey, timestamp, pickleKey); + return new Session(data, pickleKey, olm, true); + } + + get id() { + return this.data.sessionId; + } + + load() { + const session = new this._olm.Session(); + session.unpickle(this._pickleKey, this.data.session); + return session; + } + + unload(olmSession) { + olmSession.free(); + } + + save(olmSession) { + this.data.session = olmSession.pickle(this._pickleKey); + this.isModified = true; + } +} diff --git a/src/matrix/net/HomeServerApi.js b/src/matrix/net/HomeServerApi.js index 42d1b0e0..9ea7dc26 100644 --- a/src/matrix/net/HomeServerApi.js +++ b/src/matrix/net/HomeServerApi.js @@ -168,6 +168,10 @@ export class HomeServerApi { return this._post("/keys/query", null, queryRequest, options); } + claimKeys(payload, options = null) { + return this._post("/keys/claim", null, payload, options); + } + get mediaRepository() { return this._mediaRepository; } diff --git a/src/matrix/storage/idb/stores/OlmSessionStore.js b/src/matrix/storage/idb/stores/OlmSessionStore.js index c94b3bfd..4648f09c 100644 --- a/src/matrix/storage/idb/stores/OlmSessionStore.js +++ b/src/matrix/storage/idb/stores/OlmSessionStore.js @@ -18,11 +18,31 @@ function encodeKey(senderKey, sessionId) { return `${senderKey}|${sessionId}`; } +function decodeKey(key) { + const [senderKey, sessionId] = key.split("|"); + return {senderKey, sessionId}; +} + export class OlmSessionStore { constructor(store) { this._store = store; } + async getSessionIds(senderKey) { + const sessionIds = []; + const range = IDBKeyRange.lowerBound(encodeKey(senderKey, "")); + await this._store.iterateKeys(range, key => { + const decodedKey = decodeKey(key); + // prevent running into the next room + if (decodedKey.senderKey === senderKey) { + sessionIds.push(decodedKey.sessionId); + return false; // fetch more + } + return true; // done + }); + return sessionIds; + } + getAll(senderKey) { const range = IDBKeyRange.lowerBound(encodeKey(senderKey, "")); return this._store.selectWhile(range, session => { diff --git a/src/utils/Lock.js b/src/utils/Lock.js new file mode 100644 index 00000000..6a198097 --- /dev/null +++ b/src/utils/Lock.js @@ -0,0 +1,86 @@ +/* +Copyright 2020 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +export class Lock { + constructor() { + this._promise = null; + this._resolve = null; + } + + take() { + if (!this._promise) { + this._promise = new Promise(resolve => { + this._resolve = resolve; + }); + return true; + } + return false; + } + + get isTaken() { + return !!this._promise; + } + + release() { + if (this._resolve) { + this._promise = null; + const resolve = this._resolve; + this._resolve = null; + resolve(); + } + } + + released() { + return this._promise; + } +} + +export function tests() { + return { + "taking a lock twice returns false": assert => { + const lock = new Lock(); + assert.equal(lock.take(), true); + assert.equal(lock.isTaken, true); + assert.equal(lock.take(), false); + }, + "can take a released lock again": assert => { + const lock = new Lock(); + lock.take(); + lock.release(); + assert.equal(lock.isTaken, false); + assert.equal(lock.take(), true); + }, + "2 waiting for lock, only first one gets it": async assert => { + const lock = new Lock(); + lock.take(); + + let first; + lock.released().then(() => first = lock.take()); + let second; + lock.released().then(() => second = lock.take()); + const promise = lock.released(); + lock.release(); + await promise; + assert.strictEqual(first, true); + assert.strictEqual(second, false); + }, + "await non-taken lock": async assert => { + const lock = new Lock(); + await lock.released(); + assert(true); + } + } +} diff --git a/src/utils/LockMap.js b/src/utils/LockMap.js new file mode 100644 index 00000000..f99776cc --- /dev/null +++ b/src/utils/LockMap.js @@ -0,0 +1,93 @@ +/* +Copyright 2020 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +import {Lock} from "./Lock.js"; + +export class LockMap { + constructor() { + this._map = new Map(); + } + + async takeLock(key) { + let lock = this._map.get(key); + if (lock) { + while (!lock.take()) { + await lock.released(); + } + } else { + lock = new Lock(); + lock.take(); + this._map.set(key, lock); + } + // don't leave old locks lying around + lock.released().then(() => { + // give others a chance to take the lock first + Promise.resolve().then(() => { + if (!lock.isTaken) { + this._map.delete(key); + } + }); + }); + return lock; + } +} + +export function tests() { + return { + "taking a lock on the same key blocks": async assert => { + const lockMap = new LockMap(); + const lock = await lockMap.takeLock("foo"); + let second = false; + const prom = lockMap.takeLock("foo").then(() => { + second = true; + }); + assert.equal(second, false); + // do a delay to make sure prom does not resolve on its own + await Promise.resolve(); + lock.release(); + await prom; + assert.equal(second, true); + }, + "lock is not cleaned up with second request": async assert => { + const lockMap = new LockMap(); + const lock = await lockMap.takeLock("foo"); + let ranSecond = false; + const prom = lockMap.takeLock("foo").then(returnedLock => { + ranSecond = true; + assert.equal(returnedLock.isTaken, true); + // peek into internals, naughty + assert.equal(lockMap._map.get("foo"), returnedLock); + }); + lock.release(); + await prom; + // double delay to make sure cleanup logic ran + await Promise.resolve(); + await Promise.resolve(); + assert.equal(ranSecond, true); + }, + "lock is cleaned up without other request": async assert => { + const lockMap = new LockMap(); + const lock = await lockMap.takeLock("foo"); + await Promise.resolve(); + lock.release(); + // double delay to make sure cleanup logic ran + await Promise.resolve(); + await Promise.resolve(); + assert.equal(lockMap._map.has("foo"), false); + }, + + }; +} diff --git a/src/utils/groupBy.js b/src/utils/groupBy.js new file mode 100644 index 00000000..5df2f36d --- /dev/null +++ b/src/utils/groupBy.js @@ -0,0 +1,35 @@ +/* +Copyright 2020 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +export function groupBy(array, groupFn) { + return groupByWithCreator(array, groupFn, + () => {return [];}, + (array, value) => array.push(value) + ); +} + +export function groupByWithCreator(array, groupFn, createCollectionFn, addCollectionFn) { + return array.reduce((map, value) => { + const key = groupFn(value); + let collection = map.get(key); + if (!collection) { + collection = createCollectionFn(); + map.set(key, collection); + } + addCollectionFn(collection, value); + return map; + }, new Map()); +}