diff --git a/src/matrix/e2ee/megolm/Decryption.js b/src/matrix/e2ee/megolm/Decryption.js index 68bca6fe..6d7941a0 100644 --- a/src/matrix/e2ee/megolm/Decryption.js +++ b/src/matrix/e2ee/megolm/Decryption.js @@ -14,12 +14,97 @@ See the License for the specific language governing permissions and limitations under the License. */ +import {DecryptionError} from "../common.js"; + +const CACHE_MAX_SIZE = 10; + export class Decryption { constructor({pickleKey, olm}) { this._pickleKey = pickleKey; this._olm = olm; } + createSessionCache() { + return new SessionCache(); + } + + async decryptNewEvent(roomId, event, sessionCache, txn) { + const {payload, messageIndex} = this._decrypt(roomId, event, sessionCache, txn); + const sessionId = event.content?.["session_id"]; + this._handleReplayAttacks(roomId, sessionId, messageIndex, event, txn); + return payload; + } + + async decryptStoredEvent(roomId, event, sessionCache, txn) { + const {payload} = this._decrypt(roomId, event, sessionCache, txn); + return payload; + } + + async _decrypt(roomId, event, sessionCache, txn) { + const senderKey = event.content?.["sender_key"]; + const sessionId = event.content?.["session_id"]; + const ciphertext = event.content?.ciphertext; + + if ( + typeof senderKey !== "string" || + typeof sessionId !== "string" || + typeof ciphertext !== "string" + ) { + throw new DecryptionError("MEGOLM_INVALID_EVENT", event); + } + + let session = sessionCache.get(roomId, senderKey, sessionId); + if (!session) { + const sessionEntry = await txn.inboundGroupSessions.get(roomId, senderKey, sessionId); + if (sessionEntry) { + session = new this._olm.InboundGroupSession(); + try { + session.unpickle(this._pickleKey, sessionEntry.session); + } catch (err) { + session.free(); + throw err; + } + sessionCache.add(roomId, senderKey, session); + } + } + if (!session) { + return; + } + const {plaintext, message_index: messageIndex} = session.decrypt(ciphertext); + let payload; + try { + payload = JSON.parse(plaintext); + } catch (err) { + throw new DecryptionError("NOT_JSON", event, {plaintext, err}); + } + if (payload.room_id !== roomId) { + throw new DecryptionError("MEGOLM_WRONG_ROOM", event, + {encryptedRoomId: payload.room_id, eventRoomId: roomId}); + } + return {payload, messageIndex}; + } + + async _handleReplayAttacks(roomId, sessionId, messageIndex, event, txn) { + const eventId = event.event_id; + const timestamp = event.origin_server_ts; + const decryption = await txn.groupSessionDecryptions.get(roomId, sessionId, messageIndex); + if (decryption && decryption.eventId !== eventId) { + // the one with the newest timestamp should be the attack + const decryptedEventIsBad = decryption.timestamp < timestamp; + const badEventId = decryptedEventIsBad ? eventId : decryption.eventId; + throw new DecryptionError("MEGOLM_REPLAY_ATTACK", event, {badEventId, otherEventId: decryption.eventId}); + } + if (!decryption) { + txn.groupSessionDecryptions.set({ + roomId, + sessionId, + messageIndex, + eventId, + timestamp + }); + } + } + async addRoomKeys(payloads, txn) { const newSessions = []; for (const {senderKey, event} of payloads) { @@ -56,13 +141,49 @@ export class Decryption { } } + // this will be passed to the Room in notifyRoomKeys return newSessions; } +} - applyRoomKeyChanges(newSessions) { - // retry decryption with the new sessions - if (newSessions.length) { - console.log(`I have ${newSessions.length} new inbound group sessions`, newSessions) +class SessionCache { + constructor() { + this._sessions = []; + } + + get(roomId, senderKey, sessionId) { + const idx = this._sessions.findIndex(s => { + return s.roomId === roomId && + s.senderKey === senderKey && + sessionId === s.session.session_id(); + }); + if (idx !== -1) { + const entry = this._sessions[idx]; + // move to top + if (idx > 0) { + this._sessions.splice(idx, 1); + this._sessions.unshift(entry); + } + return entry.session; } } + + add(roomId, senderKey, session) { + // add new at top + this._sessions.unshift({roomId, senderKey, session}); + if (this._sessions.length > CACHE_MAX_SIZE) { + // free sessions we're about to remove + for (let i = CACHE_MAX_SIZE; i < this._sessions.length; i += 1) { + this._sessions[i].session.free(); + } + this._sessions = this._sessions.slice(0, CACHE_MAX_SIZE); + } + } + + dispose() { + for (const entry of this._sessions) { + entry.session.free(); + } + + } }