1
0
mirror of https://github.com/matrix-org/matrix-js-sdk.git synced 2026-01-03 23:22:30 +03:00

Decrypt and Import full backups in chunk with progress (#4005)

* Decrypt and Import full backups in chunk with progress

* backup chunk decryption jsdoc

* Review: fix capitalization

* review: better var name

* review: fix better iterate on object

* review: extract utility function

* review: Improve test, ensure mock calls

* review: Add more test for decryption or import failures

* Review: fix typo

Co-authored-by: Andy Balaam <andy.balaam@matrix.org>

---------

Co-authored-by: Andy Balaam <andy.balaam@matrix.org>
This commit is contained in:
Valere
2024-01-19 11:08:45 +01:00
committed by GitHub
parent 418b69914a
commit 4cddc7397d
6 changed files with 334 additions and 54 deletions

View File

@@ -17,8 +17,18 @@ limitations under the License.
import fetchMock from "fetch-mock-jest";
import "fake-indexeddb/auto";
import { IDBFactory } from "fake-indexeddb";
import { Mocked } from "jest-mock";
import { createClient, CryptoEvent, ICreateClientOpts, IEvent, MatrixClient, TypedEventEmitter } from "../../../src";
import {
createClient,
CryptoApi,
CryptoEvent,
ICreateClientOpts,
IEvent,
IMegolmSessionData,
MatrixClient,
TypedEventEmitter,
} from "../../../src";
import { SyncResponder } from "../../test-utils/SyncResponder";
import { E2EKeyReceiver } from "../../test-utils/E2EKeyReceiver";
import { E2EKeyResponder } from "../../test-utils/E2EKeyResponder";
@@ -31,7 +41,7 @@ import {
syncPromise,
} from "../../test-utils/test-utils";
import * as testData from "../../test-utils/test-data";
import { KeyBackupInfo } from "../../../src/crypto-api/keybackup";
import { KeyBackupInfo, KeyBackupSession } from "../../../src/crypto-api/keybackup";
import { IKeyBackup } from "../../../src/crypto/backup";
import { flushPromises } from "../../test-utils/flushPromises";
import { defer, IDeferred } from "../../../src/utils";
@@ -286,17 +296,21 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe
});
describe("recover from backup", () => {
it("can restore from backup (Curve25519 version)", async function () {
let aliceCrypto: CryptoApi;
beforeEach(async () => {
fetchMock.get("path:/_matrix/client/v3/room_keys/version", testData.SIGNED_BACKUP_DATA);
aliceClient = await initTestClient();
const aliceCrypto = aliceClient.getCrypto()!;
aliceCrypto = aliceClient.getCrypto()!;
await aliceClient.startClient();
// tell Alice to trust the dummy device that signed the backup
await waitForDeviceList();
await aliceCrypto.setDeviceVerified(testData.TEST_USER_ID, testData.TEST_DEVICE_ID);
});
it("can restore from backup (Curve25519 version)", async function () {
const fullBackup = {
rooms: {
[ROOM_ID]: {
@@ -340,17 +354,179 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe
expect(afterCache.imported).toStrictEqual(1);
});
/**
* Creates a mock backup response of a GET `room_keys/keys` with a given number of keys per room.
* @param keysPerRoom The number of keys per room
*/
function createBackupDownloadResponse(keysPerRoom: number[]) {
const response: {
rooms: {
[roomId: string]: {
sessions: {
[sessionId: string]: KeyBackupSession;
};
};
};
} = { rooms: {} };
const expectedTotal = keysPerRoom.reduce((a, b) => a + b, 0);
for (let i = 0; i < keysPerRoom.length; i++) {
const roomId = `!room${i}:example.com`;
response.rooms[roomId] = { sessions: {} };
for (let j = 0; j < keysPerRoom[i]; j++) {
const sessionId = `session${j}`;
// Put the same fake session data, not important for that test
response.rooms[roomId].sessions[sessionId] = testData.CURVE25519_KEY_BACKUP_DATA;
}
}
return { response, expectedTotal };
}
it("Should import full backup in chunks", async function () {
const importMockImpl = jest.fn();
// @ts-ignore - mock a private method for testing purpose
aliceCrypto.importBackedUpRoomKeys = importMockImpl;
// We need several rooms with several sessions to test chunking
const { response, expectedTotal } = createBackupDownloadResponse([45, 300, 345, 12, 130]);
fetchMock.get("express:/_matrix/client/v3/room_keys/keys", response);
const check = await aliceCrypto.checkKeyBackupAndEnable();
const progressCallback = jest.fn();
const result = await aliceClient.restoreKeyBackupWithRecoveryKey(
testData.BACKUP_DECRYPTION_KEY_BASE58,
undefined,
undefined,
check!.backupInfo!,
{
progressCallback,
},
);
expect(result.imported).toStrictEqual(expectedTotal);
// Should be called 5 times: 200*4 plus one chunk with the remaining 32
expect(importMockImpl).toHaveBeenCalledTimes(5);
for (let i = 0; i < 4; i++) {
expect(importMockImpl.mock.calls[i][0].length).toEqual(200);
}
expect(importMockImpl.mock.calls[4][0].length).toEqual(32);
expect(progressCallback).toHaveBeenCalledWith({
stage: "fetch",
});
// Should be called 4 times and report 200/400/600/800
for (let i = 0; i < 4; i++) {
expect(progressCallback).toHaveBeenCalledWith({
total: expectedTotal,
successes: (i + 1) * 200,
stage: "load_keys",
failures: 0,
});
}
// The last chunk
expect(progressCallback).toHaveBeenCalledWith({
total: expectedTotal,
successes: 832,
stage: "load_keys",
failures: 0,
});
});
it("Should continue to process backup if a chunk import fails and report failures", async function () {
// @ts-ignore - mock a private method for testing purpose
aliceCrypto.importBackedUpRoomKeys = jest
.fn()
.mockImplementationOnce(() => {
// Fail to import first chunk
throw new Error("test error");
})
// Ok for other chunks
.mockResolvedValue(undefined);
const { response, expectedTotal } = createBackupDownloadResponse([100, 300]);
fetchMock.get("express:/_matrix/client/v3/room_keys/keys", response);
const check = await aliceCrypto.checkKeyBackupAndEnable();
const progressCallback = jest.fn();
const result = await aliceClient.restoreKeyBackupWithRecoveryKey(
testData.BACKUP_DECRYPTION_KEY_BASE58,
undefined,
undefined,
check!.backupInfo!,
{
progressCallback,
},
);
expect(result.total).toStrictEqual(expectedTotal);
// A chunk failed to import
expect(result.imported).toStrictEqual(200);
expect(progressCallback).toHaveBeenCalledWith({
total: expectedTotal,
successes: 0,
stage: "load_keys",
failures: 200,
});
expect(progressCallback).toHaveBeenCalledWith({
total: expectedTotal,
successes: 200,
stage: "load_keys",
failures: 200,
});
});
it("Should continue if some keys fails to decrypt", async function () {
// @ts-ignore - mock a private method for testing purpose
aliceCrypto.importBackedUpRoomKeys = jest.fn();
const decryptionFailureCount = 2;
const mockDecryptor = {
// DecryptSessions does not reject on decryption failure, but just skip the key
decryptSessions: jest.fn().mockImplementation((sessions) => {
// simulate fail to decrypt 2 keys out of all
const decrypted = [];
const keys = Object.keys(sessions);
for (let i = 0; i < keys.length - decryptionFailureCount; i++) {
decrypted.push({
session_id: keys[i],
} as unknown as Mocked<IMegolmSessionData>);
}
return decrypted;
}),
free: jest.fn(),
};
// @ts-ignore - mock a private method for testing purpose
aliceCrypto.getBackupDecryptor = jest.fn().mockResolvedValue(mockDecryptor);
const { response, expectedTotal } = createBackupDownloadResponse([100]);
fetchMock.get("express:/_matrix/client/v3/room_keys/keys", response);
const check = await aliceCrypto.checkKeyBackupAndEnable();
const result = await aliceClient.restoreKeyBackupWithRecoveryKey(
testData.BACKUP_DECRYPTION_KEY_BASE58,
undefined,
undefined,
check!.backupInfo!,
);
expect(result.total).toStrictEqual(expectedTotal);
// A chunk failed to import
expect(result.imported).toStrictEqual(expectedTotal - decryptionFailureCount);
});
it("recover specific session from backup", async function () {
fetchMock.get("path:/_matrix/client/v3/room_keys/version", testData.SIGNED_BACKUP_DATA);
aliceClient = await initTestClient();
const aliceCrypto = aliceClient.getCrypto()!;
await aliceClient.startClient();
// tell Alice to trust the dummy device that signed the backup
await waitForDeviceList();
await aliceCrypto.setDeviceVerified(testData.TEST_USER_ID, testData.TEST_DEVICE_ID);
fetchMock.get(
"express:/_matrix/client/v3/room_keys/keys/:room_id/:session_id",
testData.CURVE25519_KEY_BACKUP_DATA,
@@ -371,16 +547,6 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe
});
it("Fails on bad recovery key", async function () {
fetchMock.get("path:/_matrix/client/v3/room_keys/version", testData.SIGNED_BACKUP_DATA);
aliceClient = await initTestClient();
const aliceCrypto = aliceClient.getCrypto()!;
await aliceClient.startClient();
// tell Alice to trust the dummy device that signed the backup
await waitForDeviceList();
await aliceCrypto.setDeviceVerified(testData.TEST_USER_ID, testData.TEST_DEVICE_ID);
const fullBackup = {
rooms: {
[ROOM_ID]: {

View File

@@ -341,7 +341,7 @@ describe("RustCrypto", () => {
let importTotal = 0;
const opt: ImportRoomKeysOpts = {
progressCallback: (stage) => {
importTotal = stage.total;
importTotal = stage.total ?? 0;
},
};
await rustCrypto.importRoomKeys(someRoomKeys, opt);

View File

@@ -209,7 +209,7 @@ import { IgnoredInvites } from "./models/invites-ignorer";
import { UIARequest, UIAResponse } from "./@types/uia";
import { LocalNotificationSettings } from "./@types/local_notifications";
import { buildFeatureSupportMap, Feature, ServerSupport } from "./feature";
import { CryptoBackend } from "./common-crypto/CryptoBackend";
import { BackupDecryptor, CryptoBackend } from "./common-crypto/CryptoBackend";
import { RUST_SDK_STORE_PREFIX } from "./rust-crypto/constants";
import { BootstrapCrossSigningOpts, CrossSigningKeyInfo, CryptoApi, ImportRoomKeysOpts } from "./crypto-api";
import { DeviceInfoMap } from "./crypto/DeviceList";
@@ -3905,7 +3905,8 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
}
let totalKeyCount = 0;
let keys: IMegolmSessionData[] = [];
let totalFailures = 0;
let totalImported = 0;
const path = this.makeKeyBackupPath(targetRoomId, targetSessionId, backupInfo.version);
@@ -3941,25 +3942,61 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
{ prefix: ClientPrefix.V3 },
);
if ((res as IRoomsKeysResponse).rooms) {
const rooms = (res as IRoomsKeysResponse).rooms;
for (const [roomId, roomData] of Object.entries(rooms)) {
if (!roomData.sessions) continue;
// We have finished fetching the backup, go to next step
if (progressCallback) {
progressCallback({
stage: "load_keys",
});
}
totalKeyCount += Object.keys(roomData.sessions).length;
const roomKeys = await backupDecryptor.decryptSessions(roomData.sessions);
for (const k of roomKeys) {
k.room_id = roomId;
keys.push(k);
}
}
if ((res as IRoomsKeysResponse).rooms) {
// We have a full backup here, it can get quite big, so we need to decrypt and import it in chunks.
// Get the total count as a first pass
totalKeyCount = this.getTotalKeyCount(res as IRoomsKeysResponse);
// Now decrypt and import the keys in chunks
await this.handleDecryptionOfAFullBackup(
res as IRoomsKeysResponse,
backupDecryptor,
200,
async (chunk) => {
// We have a chunk of decrypted keys: import them
try {
await this.cryptoBackend!.importBackedUpRoomKeys(chunk, {
untrusted,
});
totalImported += chunk.length;
} catch (e) {
totalFailures += chunk.length;
// We failed to import some keys, but we should still try to import the rest?
// Log the error and continue
logger.error("Error importing keys from backup", e);
}
if (progressCallback) {
progressCallback({
total: totalKeyCount,
successes: totalImported,
stage: "load_keys",
failures: totalFailures,
});
}
},
);
} else if ((res as IRoomKeysResponse).sessions) {
// For now we don't chunk for a single room backup, but we could in the future.
// Currently it is not used by the application.
const sessions = (res as IRoomKeysResponse).sessions;
totalKeyCount = Object.keys(sessions).length;
keys = await backupDecryptor.decryptSessions(sessions);
const keys = await backupDecryptor.decryptSessions(sessions);
for (const k of keys) {
k.room_id = targetRoomId!;
}
await this.cryptoBackend.importBackedUpRoomKeys(keys, {
progressCallback,
untrusted,
});
totalImported = keys.length;
} else {
totalKeyCount = 1;
try {
@@ -3968,7 +4005,12 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
});
key.room_id = targetRoomId!;
key.session_id = targetSessionId!;
keys.push(key);
await this.cryptoBackend.importBackedUpRoomKeys([key], {
progressCallback,
untrusted,
});
totalImported = 1;
} catch (e) {
this.logger.debug("Failed to decrypt megolm session from backup", e);
}
@@ -3977,15 +4019,88 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
backupDecryptor.free();
}
await this.cryptoBackend.importBackedUpRoomKeys(keys, {
progressCallback,
untrusted,
});
/// in case entering the passphrase would add a new signature?
await this.cryptoBackend.checkKeyBackupAndEnable();
return { total: totalKeyCount, imported: keys.length };
return { total: totalKeyCount, imported: totalImported };
}
/**
* This method calculates the total number of keys present in the response of a `/room_keys/keys` call.
*
* @param res - The response from the server containing the keys to be counted.
*
* @returns The total number of keys in the backup.
*/
private getTotalKeyCount(res: IRoomsKeysResponse): number {
const rooms = res.rooms;
let totalKeyCount = 0;
for (const roomData of Object.values(rooms)) {
if (!roomData.sessions) continue;
totalKeyCount += Object.keys(roomData.sessions).length;
}
return totalKeyCount;
}
/**
* This method handles the decryption of a full backup, i.e a call to `/room_keys/keys`.
* It will decrypt the keys in chunks and call the `block` callback for each chunk.
*
* @param res - The response from the server containing the keys to be decrypted.
* @param backupDecryptor - An instance of the BackupDecryptor class used to decrypt the keys.
* @param chunkSize - The size of the chunks to be processed at a time.
* @param block - A callback function that is called for each chunk of keys.
*
* @returns A promise that resolves when the decryption is complete.
*/
private async handleDecryptionOfAFullBackup(
res: IRoomsKeysResponse,
backupDecryptor: BackupDecryptor,
chunkSize: number,
block: (chunk: IMegolmSessionData[]) => Promise<void>,
): Promise<void> {
const rooms = (res as IRoomsKeysResponse).rooms;
let groupChunkCount = 0;
let chunkGroupByRoom: Map<string, IKeyBackupRoomSessions> = new Map();
const handleChunkCallback = async (roomChunks: Map<string, IKeyBackupRoomSessions>): Promise<void> => {
const currentChunk: IMegolmSessionData[] = [];
for (const roomId of roomChunks.keys()) {
const decryptedSessions = await backupDecryptor.decryptSessions(roomChunks.get(roomId)!);
for (const sessionId in decryptedSessions) {
const k = decryptedSessions[sessionId];
k.room_id = roomId;
currentChunk.push(k);
}
}
await block(currentChunk);
};
for (const [roomId, roomData] of Object.entries(rooms)) {
if (!roomData.sessions) continue;
chunkGroupByRoom.set(roomId, {});
for (const [sessionId, session] of Object.entries(roomData.sessions)) {
const sessionsForRoom = chunkGroupByRoom.get(roomId)!;
sessionsForRoom[sessionId] = session;
groupChunkCount += 1;
if (groupChunkCount >= chunkSize) {
// We have enough chunks to decrypt
await handleChunkCallback(chunkGroupByRoom);
chunkGroupByRoom = new Map();
// There might be remaining keys for that room, so add back an entry for the current room.
chunkGroupByRoom.set(roomId, {});
groupChunkCount = 0;
}
}
}
// Handle remaining chunk if needed
if (groupChunkCount > 0) {
await handleChunkCallback(chunkGroupByRoom);
}
}
public deleteKeysFromBackup(roomId: undefined, sessionId: undefined, version?: string): Promise<void>;

View File

@@ -586,9 +586,9 @@ export class DeviceVerificationStatus {
*/
export interface ImportRoomKeyProgressData {
stage: string; // TODO: Enum
successes: number;
failures: number;
total: number;
successes?: number;
failures?: number;
total?: number;
}
/**

View File

@@ -15,6 +15,8 @@ limitations under the License.
*/
// Export for backward compatibility
import { ImportRoomKeyProgressData } from "../crypto-api";
export type {
Curve25519AuthData as ICurve25519AuthData,
Aes256AuthData as IAes256AuthData,
@@ -41,5 +43,5 @@ export interface IKeyBackupRestoreResult {
export interface IKeyBackupRestoreOpts {
cacheCompleteCallback?: () => void;
progressCallback?: (progress: { stage: string }) => void;
progressCallback?: (progress: ImportRoomKeyProgressData) => void;
}

View File

@@ -29,7 +29,7 @@ import { logger } from "../logger";
import { ClientPrefix, IHttpOpts, MatrixError, MatrixHttpApi, Method } from "../http-api";
import { CryptoEvent, IMegolmSessionData } from "../crypto";
import { TypedEventEmitter } from "../models/typed-event-emitter";
import { encodeUri, immediate, logDuration } from "../utils";
import { encodeUri, logDuration } from "../utils";
import { OutgoingRequestProcessor } from "./OutgoingRequestProcessor";
import { sleep } from "../utils";
import { BackupDecryptor } from "../common-crypto/CryptoBackend";
@@ -534,9 +534,6 @@ export class RustBackupDecryptor implements BackupDecryptor {
);
decrypted.session_id = sessionId;
keys.push(decrypted);
// there might be lots of sessions, so don't hog the event loop
await immediate();
} catch (e) {
logger.log("Failed to decrypt megolm session from backup", e, sessionData);
}