use backup flag in key backup rather than separate store

This commit is contained in:
Bruno Windels 2022-01-27 16:07:18 +01:00
parent 48e72f9b69
commit dd2b41ff95
4 changed files with 35 additions and 49 deletions

View File

@ -57,7 +57,8 @@ export class DeviceMessageHandler {
async writeSync(prep, txn) { async writeSync(prep, txn) {
// write olm changes // write olm changes
prep.olmDecryptChanges.write(txn); prep.olmDecryptChanges.write(txn);
await Promise.all(prep.newRoomKeys.map(key => this._megolmDecryption.writeRoomKey(key, txn))); const didWriteValues = await Promise.all(prep.newRoomKeys.map(key => this._megolmDecryption.writeRoomKey(key, txn)));
return didWriteValues.some(didWrite => !!didWrite);
} }
} }

View File

@ -597,12 +597,7 @@ export class Session {
} }
if (preparation) { if (preparation) {
await log.wrap("deviceMsgs", log => this._deviceMessageHandler.writeSync(preparation, txn, log)); changes.hasNewRoomKeys = await log.wrap("deviceMsgs", log => this._deviceMessageHandler.writeSync(preparation, txn, log));
// this should come after the deviceMessageHandler, so the room keys are already written and their
// isBetter property has been checked
if (this._keyBackup) {
changes.shouldFlushKeyBackup = this._keyBackup.writeKeys(preparation.newRoomKeys, txn, log);
}
} }
// store account data // store account data
@ -641,7 +636,7 @@ export class Session {
} }
} }
// should flush and not already flushing // should flush and not already flushing
if (changes.shouldFlushKeyBackup && this._keyBackup && !this._keyBackupOperation.get()) { if (changes.hasNewRoomKeys && this._keyBackup && !this._keyBackupOperation.get()) {
log.wrapDetached("flush key backup", async log => { log.wrapDetached("flush key backup", async log => {
const operation = this._keyBackup.flush(log); const operation = this._keyBackup.flush(log);
this._keyBackupOperation.set(operation); this._keyBackupOperation.set(operation);

View File

@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
import {BackupStatus} from "../../../storage/idb/stores/InboundGroupSessionStore";
import type {InboundGroupSessionEntry} from "../../../storage/idb/stores/InboundGroupSessionStore"; import type {InboundGroupSessionEntry} from "../../../storage/idb/stores/InboundGroupSessionStore";
import type {Transaction} from "../../../storage/idb/Transaction"; import type {Transaction} from "../../../storage/idb/Transaction";
import type {DecryptionResult} from "../../DecryptionResult"; import type {DecryptionResult} from "../../DecryptionResult";
@ -81,6 +82,7 @@ export abstract class IncomingRoomKey extends RoomKey {
senderKey: this.senderKey, senderKey: this.senderKey,
sessionId: this.sessionId, sessionId: this.sessionId,
session: pickledSession, session: pickledSession,
backup: this.backupStatus,
claimedKeys: {"ed25519": this.claimedEd25519Key}, claimedKeys: {"ed25519": this.claimedEd25519Key},
}; };
txn.inboundGroupSessions.set(sessionEntry); txn.inboundGroupSessions.set(sessionEntry);
@ -125,6 +127,10 @@ export abstract class IncomingRoomKey extends RoomKey {
} }
return this.isBetter!; return this.isBetter!;
} }
protected get backupStatus(): BackupStatus {
return BackupStatus.NotBackedUp;
}
} }
class DeviceMessageRoomKey extends IncomingRoomKey { class DeviceMessageRoomKey extends IncomingRoomKey {
@ -162,9 +168,13 @@ class BackupRoomKey extends IncomingRoomKey {
loadInto(session) { loadInto(session) {
session.import_session(this.serializationKey); session.import_session(this.serializationKey);
} }
protected get backupStatus(): BackupStatus {
return BackupStatus.BackedUp;
}
} }
class StoredRoomKey extends RoomKey { export class StoredRoomKey extends RoomKey {
private storageEntry: InboundGroupSessionEntry; private storageEntry: InboundGroupSessionEntry;
constructor(storageEntry: InboundGroupSessionEntry) { constructor(storageEntry: InboundGroupSessionEntry) {

View File

@ -15,7 +15,7 @@ limitations under the License.
*/ */
import {StoreNames} from "../../../storage/common"; import {StoreNames} from "../../../storage/common";
import {keyFromStorage, keyFromBackup} from "../decryption/RoomKey"; import {StoredRoomKey, keyFromBackup} from "../decryption/RoomKey";
import {MEGOLM_ALGORITHM} from "../../common"; import {MEGOLM_ALGORITHM} from "../../common";
import * as Curve25519 from "./Curve25519"; import * as Curve25519 from "./Curve25519";
import {AbortableOperation} from "../../../../utils/AbortableOperation"; import {AbortableOperation} from "../../../../utils/AbortableOperation";
@ -30,7 +30,6 @@ import type {Storage} from "../../../storage/idb/Storage";
import type {ILogItem} from "../../../../logging/types"; import type {ILogItem} from "../../../../logging/types";
import type {Platform} from "../../../../platform/web/Platform"; import type {Platform} from "../../../../platform/web/Platform";
import type {Transaction} from "../../../storage/idb/Transaction"; import type {Transaction} from "../../../storage/idb/Transaction";
import type {BackupEntry} from "../../../storage/idb/stores/SessionNeedingBackupStore";
import type * as OlmNamespace from "@matrix-org/olm"; import type * as OlmNamespace from "@matrix-org/olm";
type Olm = typeof OlmNamespace; type Olm = typeof OlmNamespace;
@ -57,17 +56,6 @@ export class KeyBackup {
} }
} }
writeKeys(roomKeys: IncomingRoomKey[], txn: Transaction): boolean {
let hasBetter = false;
for (const key of roomKeys) {
if (key.isBetter) {
txn.sessionsNeedingBackup.set(key.roomId, key.senderKey, key.sessionId);
hasBetter = true;
}
}
return hasBetter;
}
// TODO: protect against having multiple concurrent flushes // TODO: protect against having multiple concurrent flushes
flush(log: ILogItem): AbortableOperation<Promise<boolean>, Progress> { flush(log: ILogItem): AbortableOperation<Promise<boolean>, Progress> {
return new AbortableOperation(async (setAbortable, setProgress) => { return new AbortableOperation(async (setAbortable, setProgress) => {
@ -77,36 +65,30 @@ export class KeyBackup {
const timeout = this.platform.clock.createTimeout(this.platform.random() * 10000); const timeout = this.platform.clock.createTimeout(this.platform.random() * 10000);
setAbortable(timeout); setAbortable(timeout);
await timeout.elapsed(); await timeout.elapsed();
const txn = await this.storage.readTxn([ const txn = await this.storage.readTxn([StoreNames.inboundGroupSessions]);
StoreNames.sessionsNeedingBackup,
StoreNames.inboundGroupSessions,
]);
setAbortable(txn); setAbortable(txn);
// fetch total again on each iteration as while we are flushing, sync might be adding keys // fetch total again on each iteration as while we are flushing, sync might be adding keys
total = await txn.sessionsNeedingBackup.count(); total = await txn.inboundGroupSessions.countNonBackedUpSessions();
setProgress(new Progress(total, amountFinished)); setProgress(new Progress(total, amountFinished));
const keysNeedingBackup = await txn.sessionsNeedingBackup.getFirstEntries(20); const keysNeedingBackup = (await txn.inboundGroupSessions.getFirstNonBackedUpSessions(20))
.map(entry => new StoredRoomKey(entry));
if (keysNeedingBackup.length === 0) { if (keysNeedingBackup.length === 0) {
return true; return true;
} }
const roomKeysOrNotFound = await Promise.all(keysNeedingBackup.map(k => keyFromStorage(k.roomId, k.senderKey, k.sessionId, txn))); const payload = await this.encodeKeysForBackup(keysNeedingBackup);
const roomKeys = roomKeysOrNotFound.filter(k => !!k) as RoomKey[]; const uploadRequest = this.hsApi.uploadRoomKeysToBackup(this.backupInfo.version, payload, {log});
if (roomKeys.length) { setAbortable(uploadRequest);
const payload = await this.encodeKeysForBackup(roomKeys); try {
const uploadRequest = this.hsApi.uploadRoomKeysToBackup(this.backupInfo.version, payload, {log}); await uploadRequest.response();
setAbortable(uploadRequest); } catch (err) {
try { if (err.name === "HomeServerError" && err.errcode === "M_WRONG_ROOM_KEYS_VERSION") {
await uploadRequest.response(); log.set("wrong_version", true);
} catch (err) { return false;
if (err.name === "HomeServerError" && err.errcode === "M_WRONG_ROOM_KEYS_VERSION") { } else {
log.set("wrong_version", true); throw err;
return false;
} else {
throw err;
}
} }
} }
this.removeBackedUpKeys(keysNeedingBackup, setAbortable); this.markKeysAsBackedUp(keysNeedingBackup, setAbortable);
amountFinished += keysNeedingBackup.length; amountFinished += keysNeedingBackup.length;
setProgress(new Progress(total, amountFinished)); setProgress(new Progress(total, amountFinished));
} }
@ -126,15 +108,13 @@ export class KeyBackup {
return payload; return payload;
} }
private async removeBackedUpKeys(keysNeedingBackup: BackupEntry[], setAbortable: SetAbortableFn) { private async markKeysAsBackedUp(roomKeys: RoomKey[], setAbortable: SetAbortableFn) {
const txn = await this.storage.readWriteTxn([ const txn = await this.storage.readWriteTxn([
StoreNames.sessionsNeedingBackup, StoreNames.inboundGroupSessions,
]); ]);
setAbortable(txn); setAbortable(txn);
try { try {
for (const key of keysNeedingBackup) { await Promise.all(roomKeys.map(key => txn.inboundGroupSessions.markAsBackedUp(key.roomId, key.senderKey, key.sessionId)));
txn.sessionsNeedingBackup.remove(key.roomId, key.senderKey, key.sessionId);
}
} catch (err) { } catch (err) {
txn.abort(); txn.abort();
throw err; throw err;