offload olm account creation in worker

This commit is contained in:
Bruno Windels 2020-09-11 10:43:17 +02:00
parent 0b26e6f53a
commit e0d9d703b7
8 changed files with 134 additions and 61 deletions

View File

@ -26,6 +26,7 @@ import {BrawlView} from "./ui/web/BrawlView.js";
import {Clock} from "./ui/web/dom/Clock.js"; import {Clock} from "./ui/web/dom/Clock.js";
import {OnlineStatus} from "./ui/web/dom/OnlineStatus.js"; import {OnlineStatus} from "./ui/web/dom/OnlineStatus.js";
import {WorkerPool} from "./utils/WorkerPool.js"; import {WorkerPool} from "./utils/WorkerPool.js";
import {OlmWorker} from "./matrix/e2ee/OlmWorker.js";
function addScript(src) { function addScript(src) {
return new Promise(function (resolve, reject) { return new Promise(function (resolve, reject) {
@ -65,12 +66,13 @@ function relPath(path, basePath) {
return "../".repeat(dirCount) + path; return "../".repeat(dirCount) + path;
} }
async function loadWorker(paths) { async function loadOlmWorker(paths) {
const workerPool = new WorkerPool(paths.worker, 4); const workerPool = new WorkerPool(paths.worker, 4);
await workerPool.init(); await workerPool.init();
const path = relPath(paths.olm.legacyBundle, paths.worker); const path = relPath(paths.olm.legacyBundle, paths.worker);
await workerPool.sendAll({type: "load_olm", path}); await workerPool.sendAll({type: "load_olm", path});
return workerPool; const olmWorker = new OlmWorker(workerPool);
return olmWorker;
} }
// Don't use a default export here, as we use multiple entries during legacy build, // Don't use a default export here, as we use multiple entries during legacy build,
@ -100,9 +102,9 @@ export async function main(container, paths) {
// if wasm is not supported, we'll want // if wasm is not supported, we'll want
// to run some olm operations in a worker (mainly for IE11) // to run some olm operations in a worker (mainly for IE11)
let workerPromise; let workerPromise;
if (!window.WebAssembly) { // if (!window.WebAssembly) {
workerPromise = loadWorker(paths); workerPromise = loadOlmWorker(paths);
} // }
const vm = new BrawlViewModel({ const vm = new BrawlViewModel({
createSessionContainer: () => { createSessionContainer: () => {

View File

@ -33,7 +33,7 @@ const PICKLE_KEY = "DEFAULT_KEY";
export class Session { export class Session {
// sessionInfo contains deviceId, userId and homeServer // sessionInfo contains deviceId, userId and homeServer
constructor({clock, storage, hsApi, sessionInfo, olm, workerPool}) { constructor({clock, storage, hsApi, sessionInfo, olm, olmWorker}) {
this._clock = clock; this._clock = clock;
this._storage = storage; this._storage = storage;
this._hsApi = hsApi; this._hsApi = hsApi;
@ -52,7 +52,7 @@ export class Session {
this._megolmEncryption = null; this._megolmEncryption = null;
this._megolmDecryption = null; this._megolmDecryption = null;
this._getSyncToken = () => this.syncToken; this._getSyncToken = () => this.syncToken;
this._workerPool = workerPool; this._olmWorker = olmWorker;
if (olm) { if (olm) {
this._olmUtil = new olm.Utility(); this._olmUtil = new olm.Utility();
@ -101,7 +101,7 @@ export class Session {
this._megolmDecryption = new MegOlmDecryption({ this._megolmDecryption = new MegOlmDecryption({
pickleKey: PICKLE_KEY, pickleKey: PICKLE_KEY,
olm: this._olm, olm: this._olm,
workerPool: this._workerPool, olmWorker: this._olmWorker,
}); });
this._deviceMessageHandler.enableEncryption({olmDecryption, megolmDecryption: this._megolmDecryption}); this._deviceMessageHandler.enableEncryption({olmDecryption, megolmDecryption: this._megolmDecryption});
} }
@ -140,23 +140,15 @@ export class Session {
throw new Error("there should not be an e2ee account already on a fresh login"); throw new Error("there should not be an e2ee account already on a fresh login");
} }
if (!this._e2eeAccount) { if (!this._e2eeAccount) {
const txn = await this._storage.readWriteTxn([
this._storage.storeNames.session
]);
try {
this._e2eeAccount = await E2EEAccount.create({ this._e2eeAccount = await E2EEAccount.create({
hsApi: this._hsApi, hsApi: this._hsApi,
olm: this._olm, olm: this._olm,
pickleKey: PICKLE_KEY, pickleKey: PICKLE_KEY,
userId: this._sessionInfo.userId, userId: this._sessionInfo.userId,
deviceId: this._sessionInfo.deviceId, deviceId: this._sessionInfo.deviceId,
txn olmWorker: this._olmWorker,
storage: this._storage,
}); });
} catch (err) {
txn.abort();
throw err;
}
await txn.complete();
this._setupEncryption(); this._setupEncryption();
} }
await this._e2eeAccount.generateOTKsIfNeeded(this._storage); await this._e2eeAccount.generateOTKsIfNeeded(this._storage);
@ -184,6 +176,7 @@ export class Session {
pickleKey: PICKLE_KEY, pickleKey: PICKLE_KEY,
userId: this._sessionInfo.userId, userId: this._sessionInfo.userId,
deviceId: this._sessionInfo.deviceId, deviceId: this._sessionInfo.deviceId,
olmWorker: this._olmWorker,
txn txn
}); });
if (this._e2eeAccount) { if (this._e2eeAccount) {
@ -204,7 +197,7 @@ export class Session {
} }
stop() { stop() {
this._workerPool?.dispose(); this._olmWorker?.dispose();
this._sendScheduler.stop(); this._sendScheduler.stop();
} }

View File

@ -153,13 +153,13 @@ export class SessionContainer {
homeServer: sessionInfo.homeServer, homeServer: sessionInfo.homeServer,
}; };
const olm = await this._olmPromise; const olm = await this._olmPromise;
let workerPool = null; let olmWorker = null;
if (this._workerPromise) { if (this._workerPromise) {
workerPool = await this._workerPromise; olmWorker = await this._workerPromise;
} }
this._session = new Session({storage: this._storage, this._session = new Session({storage: this._storage,
sessionInfo: filteredSessionInfo, hsApi, olm, sessionInfo: filteredSessionInfo, hsApi, olm,
clock: this._clock, workerPool}); clock: this._clock, olmWorker});
await this._session.load(); await this._session.load();
this._status.set(LoadStatus.SessionSetup); this._status.set(LoadStatus.SessionSetup);
await this._session.beforeFirstSync(isNewLogin); await this._session.beforeFirstSync(isNewLogin);

View File

@ -23,7 +23,7 @@ const DEVICE_KEY_FLAG_SESSION_KEY = SESSION_KEY_PREFIX + "areDeviceKeysUploaded"
const SERVER_OTK_COUNT_SESSION_KEY = SESSION_KEY_PREFIX + "serverOTKCount"; const SERVER_OTK_COUNT_SESSION_KEY = SESSION_KEY_PREFIX + "serverOTKCount";
export class Account { export class Account {
static async load({olm, pickleKey, hsApi, userId, deviceId, txn}) { static async load({olm, pickleKey, hsApi, userId, deviceId, olmWorker, txn}) {
const pickledAccount = await txn.session.get(ACCOUNT_SESSION_KEY); const pickledAccount = await txn.session.get(ACCOUNT_SESSION_KEY);
if (pickledAccount) { if (pickledAccount) {
const account = new olm.Account(); const account = new olm.Account();
@ -31,26 +31,39 @@ export class Account {
account.unpickle(pickleKey, pickledAccount); account.unpickle(pickleKey, pickledAccount);
const serverOTKCount = await txn.session.get(SERVER_OTK_COUNT_SESSION_KEY); const serverOTKCount = await txn.session.get(SERVER_OTK_COUNT_SESSION_KEY);
return new Account({pickleKey, hsApi, account, userId, return new Account({pickleKey, hsApi, account, userId,
deviceId, areDeviceKeysUploaded, serverOTKCount, olm}); deviceId, areDeviceKeysUploaded, serverOTKCount, olm, olmWorker});
} }
} }
static async create({olm, pickleKey, hsApi, userId, deviceId, txn}) { static async create({olm, pickleKey, hsApi, userId, deviceId, olmWorker, storage}) {
const account = new olm.Account(); const account = new olm.Account();
if (olmWorker) {
await olmWorker.createAccountAndOTKs(account, account.max_number_of_one_time_keys());
} else {
account.create(); account.create();
account.generate_one_time_keys(account.max_number_of_one_time_keys()); account.generate_one_time_keys(account.max_number_of_one_time_keys());
}
const pickledAccount = account.pickle(pickleKey); const pickledAccount = account.pickle(pickleKey);
const areDeviceKeysUploaded = false;
const txn = await storage.readWriteTxn([
storage.storeNames.session
]);
try {
// add will throw if the key already exists // add will throw if the key already exists
// we would not want to overwrite olmAccount here // we would not want to overwrite olmAccount here
const areDeviceKeysUploaded = false; txn.session.add(ACCOUNT_SESSION_KEY, pickledAccount);
await txn.session.add(ACCOUNT_SESSION_KEY, pickledAccount); txn.session.add(DEVICE_KEY_FLAG_SESSION_KEY, areDeviceKeysUploaded);
await txn.session.add(DEVICE_KEY_FLAG_SESSION_KEY, areDeviceKeysUploaded); txn.session.add(SERVER_OTK_COUNT_SESSION_KEY, 0);
await txn.session.add(SERVER_OTK_COUNT_SESSION_KEY, 0); } catch (err) {
txn.abort();
throw err;
}
await txn.complete();
return new Account({pickleKey, hsApi, account, userId, return new Account({pickleKey, hsApi, account, userId,
deviceId, areDeviceKeysUploaded, serverOTKCount: 0, olm}); deviceId, areDeviceKeysUploaded, serverOTKCount: 0, olm, olmWorker});
} }
constructor({pickleKey, hsApi, account, userId, deviceId, areDeviceKeysUploaded, serverOTKCount, olm}) { constructor({pickleKey, hsApi, account, userId, deviceId, areDeviceKeysUploaded, serverOTKCount, olm, olmWorker}) {
this._olm = olm; this._olm = olm;
this._pickleKey = pickleKey; this._pickleKey = pickleKey;
this._hsApi = hsApi; this._hsApi = hsApi;
@ -59,6 +72,7 @@ export class Account {
this._deviceId = deviceId; this._deviceId = deviceId;
this._areDeviceKeysUploaded = areDeviceKeysUploaded; this._areDeviceKeysUploaded = areDeviceKeysUploaded;
this._serverOTKCount = serverOTKCount; this._serverOTKCount = serverOTKCount;
this._olmWorker = olmWorker;
this._identityKeys = JSON.parse(this._account.identity_keys()); this._identityKeys = JSON.parse(this._account.identity_keys());
} }

View File

@ -14,13 +14,30 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
export class DecryptionWorker { export class OlmWorker {
constructor(workerPool) { constructor(workerPool) {
this._workerPool = workerPool; this._workerPool = workerPool;
} }
decrypt(session, ciphertext) { megolmDecrypt(session, ciphertext) {
const sessionKey = session.export_session(session.first_known_index()); const sessionKey = session.export_session(session.first_known_index());
return this._workerPool.send({type: "megolm_decrypt", ciphertext, sessionKey}); return this._workerPool.send({type: "megolm_decrypt", ciphertext, sessionKey});
} }
async createAccountAndOTKs(account, otkAmount) {
// IE11 does not support getRandomValues in a worker, so we have to generate the values upfront.
let randomValues;
if (window.msCrypto) {
randomValues = [
window.msCrypto.getRandomValues(new Uint8Array(64)),
window.msCrypto.getRandomValues(new Uint8Array(otkAmount * 32)),
];
}
const pickle = await this._workerPool.send({type: "olm_create_account_otks", randomValues, otkAmount}).response();
account.unpickle("", pickle);
}
dispose() {
this._workerPool.dispose();
}
} }

View File

@ -21,7 +21,6 @@ import {SessionInfo} from "./decryption/SessionInfo.js";
import {DecryptionPreparation} from "./decryption/DecryptionPreparation.js"; import {DecryptionPreparation} from "./decryption/DecryptionPreparation.js";
import {SessionDecryption} from "./decryption/SessionDecryption.js"; import {SessionDecryption} from "./decryption/SessionDecryption.js";
import {SessionCache} from "./decryption/SessionCache.js"; import {SessionCache} from "./decryption/SessionCache.js";
import {DecryptionWorker} from "./decryption/DecryptionWorker.js";
function getSenderKey(event) { function getSenderKey(event) {
return event.content?.["sender_key"]; return event.content?.["sender_key"];
@ -36,10 +35,10 @@ function getCiphertext(event) {
} }
export class Decryption { export class Decryption {
constructor({pickleKey, olm, workerPool}) { constructor({pickleKey, olm, olmWorker}) {
this._pickleKey = pickleKey; this._pickleKey = pickleKey;
this._olm = olm; this._olm = olm;
this._decryptor = workerPool ? new DecryptionWorker(workerPool) : null; this._olmWorker = olmWorker;
} }
createSessionCache(fallback) { createSessionCache(fallback) {
@ -86,7 +85,7 @@ export class Decryption {
errors.set(event.event_id, new DecryptionError("MEGOLM_NO_SESSION", event)); errors.set(event.event_id, new DecryptionError("MEGOLM_NO_SESSION", event));
} }
} else { } else {
sessionDecryptions.push(new SessionDecryption(sessionInfo, eventsForSession, this._decryptor)); sessionDecryptions.push(new SessionDecryption(sessionInfo, eventsForSession, this._olmWorker));
} }
})); }));

View File

@ -22,12 +22,12 @@ import {ReplayDetectionEntry} from "./ReplayDetectionEntry.js";
* Does the actual decryption of all events for a given megolm session in a batch * Does the actual decryption of all events for a given megolm session in a batch
*/ */
export class SessionDecryption { export class SessionDecryption {
constructor(sessionInfo, events, decryptor) { constructor(sessionInfo, events, olmWorker) {
sessionInfo.retain(); sessionInfo.retain();
this._sessionInfo = sessionInfo; this._sessionInfo = sessionInfo;
this._events = events; this._events = events;
this._decryptor = decryptor; this._olmWorker = olmWorker;
this._decryptionRequests = decryptor ? [] : null; this._decryptionRequests = olmWorker ? [] : null;
} }
async decryptAll() { async decryptAll() {
@ -41,8 +41,8 @@ export class SessionDecryption {
const {session} = this._sessionInfo; const {session} = this._sessionInfo;
const ciphertext = event.content.ciphertext; const ciphertext = event.content.ciphertext;
let decryptionResult; let decryptionResult;
if (this._decryptor) { if (this._olmWorker) {
const request = this._decryptor.decrypt(session, ciphertext); const request = this._olmWorker.megolmDecrypt(session, ciphertext);
this._decryptionRequests.push(request); this._decryptionRequests.push(request);
decryptionResult = await request.response(); decryptionResult = await request.response();
} else { } else {

View File

@ -32,6 +32,44 @@ function asSuccessMessage(payload) {
class MessageHandler { class MessageHandler {
constructor() { constructor() {
this._olm = null; this._olm = null;
this._randomValues = self.crypto ? null : [];
}
_feedRandomValues(randomValues) {
if (this._randomValues) {
this._randomValues.push(...randomValues);
}
}
_checkRandomValuesUsed() {
if (this._randomValues && this._randomValues.length !== 0) {
throw new Error(`${this._randomValues.length} random values left`);
}
}
_getRandomValues(typedArray) {
if (!(typedArray instanceof Uint8Array)) {
throw new Error("only Uint8Array is supported: " + JSON.stringify({
Int8Array: typedArray instanceof Int8Array,
Uint8Array: typedArray instanceof Uint8Array,
Int16Array: typedArray instanceof Int16Array,
Uint16Array: typedArray instanceof Uint16Array,
Int32Array: typedArray instanceof Int32Array,
Uint32Array: typedArray instanceof Uint32Array,
}));
}
if (this._randomValues.length === 0) {
throw new Error("no more random values, needed one of length " + typedArray.length);
}
const precalculated = this._randomValues.shift();
if (precalculated.length !== typedArray.length) {
throw new Error(`typedArray length (${typedArray.length}) does not match precalculated length (${precalculated.length})`);
}
// copy values
for (let i = 0; i < typedArray.length; ++i) {
typedArray[i] = precalculated[i];
}
return typedArray;
} }
handleEvent(e) { handleEvent(e) {
@ -47,7 +85,7 @@ class MessageHandler {
_toMessage(fn) { _toMessage(fn) {
try { try {
let payload = fn(); const payload = fn();
if (payload instanceof Promise) { if (payload instanceof Promise) {
return payload.then( return payload.then(
payload => asSuccessMessage(payload), payload => asSuccessMessage(payload),
@ -63,18 +101,15 @@ class MessageHandler {
_loadOlm(path) { _loadOlm(path) {
return this._toMessage(async () => { return this._toMessage(async () => {
// might have some problems here with window vs self as global object? if (!self.crypto) {
if (self.msCrypto && !self.crypto) { self.crypto = {getRandomValues: this._getRandomValues.bind(this)};
self.crypto = self.msCrypto;
} }
self.importScripts(path); // mangle the globals enough to make olm believe it is running in a browser
const olm = self.olm_exports;
// mangle the globals enough to make olm load believe it is running in a browser
self.window = self; self.window = self;
self.document = {}; self.document = {};
self.importScripts(path);
const olm = self.olm_exports;
await olm.init(); await olm.init();
delete self.document;
delete self.window;
this._olm = olm; this._olm = olm;
}); });
} }
@ -93,6 +128,17 @@ class MessageHandler {
}); });
} }
_olmCreateAccountAndOTKs(randomValues, otkAmount) {
return this._toMessage(() => {
this._feedRandomValues(randomValues);
const account = new this._olm.Account();
account.create();
account.generate_one_time_keys(otkAmount);
this._checkRandomValuesUsed();
return account.pickle("");
});
}
async _handleMessage(message) { async _handleMessage(message) {
const {type} = message; const {type} = message;
if (type === "ping") { if (type === "ping") {
@ -101,6 +147,8 @@ class MessageHandler {
this._sendReply(message, await this._loadOlm(message.path)); this._sendReply(message, await this._loadOlm(message.path));
} else if (type === "megolm_decrypt") { } else if (type === "megolm_decrypt") {
this._sendReply(message, this._megolmDecrypt(message.sessionKey, message.ciphertext)); this._sendReply(message, this._megolmDecrypt(message.sessionKey, message.ciphertext));
} else if (type === "olm_create_account_otks") {
this._sendReply(message, this._olmCreateAccountAndOTKs(message.randomValues, message.otkAmount));
} }
} }
} }