1
0
mirror of https://github.com/matrix-org/matrix-js-sdk.git synced 2025-07-31 15:24:23 +03:00

Merge branch 'master' into develop

This commit is contained in:
RiotRobot
2023-03-28 14:15:09 +01:00
42 changed files with 718 additions and 559 deletions

View File

@ -1,3 +1,9 @@
Changes in [24.0.0](https://github.com/matrix-org/matrix-js-sdk/releases/tag/v24.0.0) (2023-03-28)
==================================================================================================
## 🐛 Bug Fixes
* Changes for matrix-js-sdk v24.0.0
Changes in [23.5.0](https://github.com/matrix-org/matrix-js-sdk/releases/tag/v23.5.0) (2023-03-15) Changes in [23.5.0](https://github.com/matrix-org/matrix-js-sdk/releases/tag/v23.5.0) (2023-03-15)
================================================================================================== ==================================================================================================

View File

@ -1,6 +1,6 @@
{ {
"name": "matrix-js-sdk", "name": "matrix-js-sdk",
"version": "23.5.0", "version": "24.0.0",
"description": "Matrix Client-Server SDK for Javascript", "description": "Matrix Client-Server SDK for Javascript",
"engines": { "engines": {
"node": ">=16.0.0" "node": ">=16.0.0"

View File

@ -560,7 +560,7 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("crypto (%s)", (backend: string,
// if we're using the old crypto impl, stub out some methods in the device manager. // if we're using the old crypto impl, stub out some methods in the device manager.
// TODO: replace this with intercepts of the /keys/query endpoint to make it impl agnostic. // TODO: replace this with intercepts of the /keys/query endpoint to make it impl agnostic.
if (aliceClient.crypto) { if (aliceClient.crypto) {
aliceClient.crypto.deviceList.downloadKeys = () => Promise.resolve({}); aliceClient.crypto.deviceList.downloadKeys = () => Promise.resolve(new Map());
aliceClient.crypto.deviceList.getUserByIdentityKey = () => "@bob:xyz"; aliceClient.crypto.deviceList.getUserByIdentityKey = () => "@bob:xyz";
} }
@ -620,7 +620,7 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("crypto (%s)", (backend: string,
// if we're using the old crypto impl, stub out some methods in the device manager. // if we're using the old crypto impl, stub out some methods in the device manager.
// TODO: replace this with intercepts of the /keys/query endpoint to make it impl agnostic. // TODO: replace this with intercepts of the /keys/query endpoint to make it impl agnostic.
if (aliceClient.crypto) { if (aliceClient.crypto) {
aliceClient.crypto.deviceList.downloadKeys = () => Promise.resolve({}); aliceClient.crypto.deviceList.downloadKeys = () => Promise.resolve(new Map());
aliceClient.crypto.deviceList.getUserByIdentityKey = () => "@bob:xyz"; aliceClient.crypto.deviceList.getUserByIdentityKey = () => "@bob:xyz";
} }
@ -688,7 +688,7 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("crypto (%s)", (backend: string,
// if we're using the old crypto impl, stub out some methods in the device manager. // if we're using the old crypto impl, stub out some methods in the device manager.
// TODO: replace this with intercepts of the /keys/query endpoint to make it impl agnostic. // TODO: replace this with intercepts of the /keys/query endpoint to make it impl agnostic.
if (aliceClient.crypto) { if (aliceClient.crypto) {
aliceClient.crypto.deviceList.downloadKeys = () => Promise.resolve({}); aliceClient.crypto.deviceList.downloadKeys = () => Promise.resolve(new Map());
aliceClient.crypto.deviceList.getUserByIdentityKey = () => "@bob:xyz"; aliceClient.crypto.deviceList.getUserByIdentityKey = () => "@bob:xyz";
} }
@ -1071,8 +1071,8 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("crypto (%s)", (backend: string,
throw new Error("sendTextMessage succeeded on an unknown device"); throw new Error("sendTextMessage succeeded on an unknown device");
} catch (e) { } catch (e) {
expect((e as any).name).toEqual("UnknownDeviceError"); expect((e as any).name).toEqual("UnknownDeviceError");
expect(Object.keys((e as any).devices)).toEqual([aliceClient.getUserId()!]); expect([...(e as any).devices.keys()]).toEqual([aliceClient.getUserId()!]);
expect(Object.keys((e as any)?.devices[aliceClient.getUserId()!])).toEqual(["DEVICE_ID"]); expect((e as any).devices.get(aliceClient.getUserId()!).has("DEVICE_ID"));
} }
// mark the device as known, and resend. // mark the device as known, and resend.
@ -1140,7 +1140,7 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("crypto (%s)", (backend: string,
// if we're using the old crypto impl, stub out some methods in the device manager. // if we're using the old crypto impl, stub out some methods in the device manager.
// TODO: replace this with intercepts of the /keys/query endpoint to make it impl agnostic. // TODO: replace this with intercepts of the /keys/query endpoint to make it impl agnostic.
if (aliceClient.crypto) { if (aliceClient.crypto) {
aliceClient.crypto.deviceList.downloadKeys = () => Promise.resolve({}); aliceClient.crypto.deviceList.downloadKeys = () => Promise.resolve(new Map());
aliceClient.crypto.deviceList.getUserByIdentityKey = () => "@bob:xyz"; aliceClient.crypto.deviceList.getUserByIdentityKey = () => "@bob:xyz";
} }
@ -1296,7 +1296,7 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("crypto (%s)", (backend: string,
// if we're using the old crypto impl, stub out some methods in the device manager. // if we're using the old crypto impl, stub out some methods in the device manager.
// TODO: replace this with intercepts of the /keys/query endpoint to make it impl agnostic. // TODO: replace this with intercepts of the /keys/query endpoint to make it impl agnostic.
if (aliceClient.crypto) { if (aliceClient.crypto) {
aliceClient.crypto.deviceList.downloadKeys = () => Promise.resolve({}); aliceClient.crypto.deviceList.downloadKeys = () => Promise.resolve(new Map());
aliceClient.crypto.deviceList.getUserByIdentityKey = () => "@bob:xyz"; aliceClient.crypto.deviceList.getUserByIdentityKey = () => "@bob:xyz";
} }
@ -1363,7 +1363,7 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("crypto (%s)", (backend: string,
// if we're using the old crypto impl, stub out some methods in the device manager. // if we're using the old crypto impl, stub out some methods in the device manager.
// TODO: replace this with intercepts of the /keys/query endpoint to make it impl agnostic. // TODO: replace this with intercepts of the /keys/query endpoint to make it impl agnostic.
if (aliceClient.crypto) { if (aliceClient.crypto) {
aliceClient.crypto!.deviceList.downloadKeys = () => Promise.resolve({}); aliceClient.crypto!.deviceList.downloadKeys = () => Promise.resolve(new Map());
aliceClient.crypto!.deviceList.getDeviceByIdentityKey = () => device; aliceClient.crypto!.deviceList.getDeviceByIdentityKey = () => device;
aliceClient.crypto!.deviceList.getUserByIdentityKey = () => beccaTestClient.client.getUserId()!; aliceClient.crypto!.deviceList.getUserByIdentityKey = () => beccaTestClient.client.getUserId()!;
} }

View File

@ -603,14 +603,14 @@ describe("MatrixClient", function () {
}); });
const prom = client!.downloadKeys(["boris", "chaz"]).then(function (res) { const prom = client!.downloadKeys(["boris", "chaz"]).then(function (res) {
assertObjectContains(res.boris.dev1, { assertObjectContains(res.get("boris")!.get("dev1")!, {
verified: 0, // DeviceVerification.UNVERIFIED verified: 0, // DeviceVerification.UNVERIFIED
keys: { "ed25519:dev1": ed25519key }, keys: { "ed25519:dev1": ed25519key },
algorithms: ["1"], algorithms: ["1"],
unsigned: { abc: "def" }, unsigned: { abc: "def" },
}); });
assertObjectContains(res.chaz.dev2, { assertObjectContains(res.get("chaz")!.get("dev2")!, {
verified: 0, // DeviceVerification.UNVERIFIED verified: 0, // DeviceVerification.UNVERIFIED
keys: { "ed25519:dev2": ed25519key }, keys: { "ed25519:dev2": ed25519key },
algorithms: ["2"], algorithms: ["2"],

View File

@ -472,7 +472,7 @@ describe("MatrixClient crypto", () => {
aliTestClient.expectKeyQuery({ device_keys: { [aliUserId]: {} }, failures: {} }); aliTestClient.expectKeyQuery({ device_keys: { [aliUserId]: {} }, failures: {} });
await aliTestClient.start(); await aliTestClient.start();
await bobTestClient.start(); await bobTestClient.start();
bobTestClient.client.crypto!.deviceList.downloadKeys = () => Promise.resolve({}); bobTestClient.client.crypto!.deviceList.downloadKeys = () => Promise.resolve(new Map());
await firstSync(aliTestClient); await firstSync(aliTestClient);
await aliEnablesEncryption(); await aliEnablesEncryption();
await aliSendsFirstMessage(); await aliSendsFirstMessage();
@ -483,7 +483,7 @@ describe("MatrixClient crypto", () => {
aliTestClient.expectKeyQuery({ device_keys: { [aliUserId]: {} }, failures: {} }); aliTestClient.expectKeyQuery({ device_keys: { [aliUserId]: {} }, failures: {} });
await aliTestClient.start(); await aliTestClient.start();
await bobTestClient.start(); await bobTestClient.start();
bobTestClient.client.crypto!.deviceList.downloadKeys = () => Promise.resolve({}); bobTestClient.client.crypto!.deviceList.downloadKeys = () => Promise.resolve(new Map());
await firstSync(aliTestClient); await firstSync(aliTestClient);
await aliEnablesEncryption(); await aliEnablesEncryption();
await aliSendsFirstMessage(); await aliSendsFirstMessage();
@ -545,7 +545,7 @@ describe("MatrixClient crypto", () => {
aliTestClient.expectKeyQuery({ device_keys: { [aliUserId]: {} }, failures: {} }); aliTestClient.expectKeyQuery({ device_keys: { [aliUserId]: {} }, failures: {} });
await aliTestClient.start(); await aliTestClient.start();
await bobTestClient.start(); await bobTestClient.start();
bobTestClient.client.crypto!.deviceList.downloadKeys = () => Promise.resolve({}); bobTestClient.client.crypto!.deviceList.downloadKeys = () => Promise.resolve(new Map());
await firstSync(aliTestClient); await firstSync(aliTestClient);
await aliEnablesEncryption(); await aliEnablesEncryption();
await aliSendsFirstMessage(); await aliSendsFirstMessage();

View File

@ -30,6 +30,7 @@ import {
RoomState, RoomState,
RoomStateEvent, RoomStateEvent,
RoomStateEventHandlerMap, RoomStateEventHandlerMap,
SendToDeviceContentMap,
} from "../../src"; } from "../../src";
import { TypedEventEmitter } from "../../src/models/typed-event-emitter"; import { TypedEventEmitter } from "../../src/models/typed-event-emitter";
import { ReEmitter } from "../../src/ReEmitter"; import { ReEmitter } from "../../src/ReEmitter";
@ -443,11 +444,7 @@ export class MockCallMatrixClient extends TypedEventEmitter<EmittedEvents, Emitt
>(); >();
public sendToDevice = jest.fn< public sendToDevice = jest.fn<
Promise<{}>, Promise<{}>,
[ [eventType: string, contentMap: SendToDeviceContentMap, txnId?: string]
eventType: string,
contentMap: { [userId: string]: { [deviceId: string]: Record<string, any> } },
txnId?: string,
]
>(); >();
public isInitialSyncComplete(): boolean { public isInitialSyncComplete(): boolean {

View File

@ -405,7 +405,7 @@ describe("Crypto", function () {
// the first message can't be decrypted yet, but the second one // the first message can't be decrypted yet, but the second one
// can // can
let ksEvent = await keyshareEventForEvent(aliceClient, events[1], 1); let ksEvent = await keyshareEventForEvent(aliceClient, events[1], 1);
bobClient.crypto!.deviceList.downloadKeys = () => Promise.resolve({}); bobClient.crypto!.deviceList.downloadKeys = () => Promise.resolve(new Map());
bobClient.crypto!.deviceList.getUserByIdentityKey = () => "@alice:example.com"; bobClient.crypto!.deviceList.getUserByIdentityKey = () => "@alice:example.com";
await bobDecryptor.onRoomKeyEvent(ksEvent); await bobDecryptor.onRoomKeyEvent(ksEvent);
await decryptEventsPromise; await decryptEventsPromise;
@ -1039,7 +1039,7 @@ describe("Crypto", function () {
beforeEach(async () => { beforeEach(async () => {
ensureOlmSessionsForDevices = jest.spyOn(olmlib, "ensureOlmSessionsForDevices"); ensureOlmSessionsForDevices = jest.spyOn(olmlib, "ensureOlmSessionsForDevices");
ensureOlmSessionsForDevices.mockResolvedValue({}); ensureOlmSessionsForDevices.mockResolvedValue(new Map());
encryptMessageForDevice = jest.spyOn(olmlib, "encryptMessageForDevice"); encryptMessageForDevice = jest.spyOn(olmlib, "encryptMessageForDevice");
encryptMessageForDevice.mockImplementation(async (...[result, , , , , , payload]) => { encryptMessageForDevice.mockImplementation(async (...[result, , , , , , payload]) => {
result.plaintext = { type: 0, body: JSON.stringify(payload) }; result.plaintext = { type: 0, body: JSON.stringify(payload) };

View File

@ -34,6 +34,7 @@ import { ClientEvent, MatrixClient, RoomMember } from "../../../../src";
import { DeviceInfo, IDevice } from "../../../../src/crypto/deviceinfo"; import { DeviceInfo, IDevice } from "../../../../src/crypto/deviceinfo";
import { DeviceTrustLevel } from "../../../../src/crypto/CrossSigning"; import { DeviceTrustLevel } from "../../../../src/crypto/CrossSigning";
import { MegolmEncryption as MegolmEncryptionClass } from "../../../../src/crypto/algorithms/megolm"; import { MegolmEncryption as MegolmEncryptionClass } from "../../../../src/crypto/algorithms/megolm";
import { recursiveMapToObject } from "../../../../src/utils";
import { sleep } from "../../../../src/utils"; import { sleep } from "../../../../src/utils";
const MegolmDecryption = algorithms.DECRYPTION_CLASSES.get("m.megolm.v1.aes-sha2")!; const MegolmDecryption = algorithms.DECRYPTION_CLASSES.get("m.megolm.v1.aes-sha2")!;
@ -183,14 +184,22 @@ describe("MegolmDecryption", function () {
const deviceInfo = {} as DeviceInfo; const deviceInfo = {} as DeviceInfo;
mockCrypto.getStoredDevice.mockReturnValue(deviceInfo); mockCrypto.getStoredDevice.mockReturnValue(deviceInfo);
mockOlmLib.ensureOlmSessionsForDevices.mockResolvedValue({ mockOlmLib.ensureOlmSessionsForDevices.mockResolvedValue(
"@alice:foo": { new Map([
alidevice: { [
sessionId: "alisession", "@alice:foo",
device: new DeviceInfo("alidevice"), new Map([
}, [
}, "alidevice",
}); {
sessionId: "alisession",
device: new DeviceInfo("alidevice"),
},
],
]),
],
]),
);
const awaitEncryptForDevice = new Promise<void>((res, rej) => { const awaitEncryptForDevice = new Promise<void>((res, rej) => {
mockOlmLib.encryptMessageForDevice.mockImplementation(() => { mockOlmLib.encryptMessageForDevice.mockImplementation(() => {
@ -357,11 +366,7 @@ describe("MegolmDecryption", function () {
} as unknown as DeviceInfo; } as unknown as DeviceInfo;
mockCrypto.downloadKeys.mockReturnValue( mockCrypto.downloadKeys.mockReturnValue(
Promise.resolve({ Promise.resolve(new Map([["@alice:home.server", new Map([["aliceDevice", aliceDeviceInfo]])]])),
"@alice:home.server": {
aliceDevice: aliceDeviceInfo,
},
}),
); );
mockCrypto.checkDeviceTrust.mockReturnValue({ mockCrypto.checkDeviceTrust.mockReturnValue({
@ -523,23 +528,32 @@ describe("MegolmDecryption", function () {
let megolm: MegolmEncryptionClass; let megolm: MegolmEncryptionClass;
let room: jest.Mocked<Room>; let room: jest.Mocked<Room>;
const deviceMap: DeviceInfoMap = { const deviceMap: DeviceInfoMap = new Map([
"user-a": { [
"device-a": new DeviceInfo("device-a"), "user-a",
"device-b": new DeviceInfo("device-b"), new Map([
"device-c": new DeviceInfo("device-c"), ["device-a", new DeviceInfo("device-a")],
}, ["device-b", new DeviceInfo("device-b")],
"user-b": { ["device-c", new DeviceInfo("device-c")],
"device-d": new DeviceInfo("device-d"), ]),
"device-e": new DeviceInfo("device-e"), ],
"device-f": new DeviceInfo("device-f"), [
}, "user-b",
"user-c": { new Map([
"device-g": new DeviceInfo("device-g"), ["device-d", new DeviceInfo("device-d")],
"device-h": new DeviceInfo("device-h"), ["device-e", new DeviceInfo("device-e")],
"device-i": new DeviceInfo("device-i"), ["device-f", new DeviceInfo("device-f")],
}, ]),
}; ],
[
"user-c",
new Map([
["device-g", new DeviceInfo("device-g")],
["device-h", new DeviceInfo("device-h")],
["device-i", new DeviceInfo("device-i")],
]),
],
]);
beforeEach(() => { beforeEach(() => {
room = testUtils.mock(Room, "Room") as jest.Mocked<Room>; room = testUtils.mock(Room, "Room") as jest.Mocked<Room>;
@ -572,8 +586,8 @@ describe("MegolmDecryption", function () {
//@ts-ignore private member access, gross //@ts-ignore private member access, gross
await megolm.encryptionPreparation?.promise; await megolm.encryptionPreparation?.promise;
for (const userId in deviceMap) { for (const [userId, devices] of deviceMap) {
for (const deviceId in deviceMap[userId]) { for (const deviceId of devices.keys()) {
expect(mockCrypto.checkDeviceTrust).toHaveBeenCalledWith(userId, deviceId); expect(mockCrypto.checkDeviceTrust).toHaveBeenCalledWith(userId, deviceId);
} }
} }
@ -658,20 +672,20 @@ describe("MegolmDecryption", function () {
expect(aliceClient.sendToDevice).toHaveBeenCalled(); expect(aliceClient.sendToDevice).toHaveBeenCalled();
const [msgtype, contentMap] = mocked(aliceClient.sendToDevice).mock.calls[0]; const [msgtype, contentMap] = mocked(aliceClient.sendToDevice).mock.calls[0];
expect(msgtype).toMatch(/^(org.matrix|m).room_key.withheld$/); expect(msgtype).toMatch(/^(org.matrix|m).room_key.withheld$/);
delete contentMap["@bob:example.com"].bobdevice1.session_id; delete contentMap.get("@bob:example.com")?.get("bobdevice1")?.["session_id"];
delete contentMap["@bob:example.com"].bobdevice1["org.matrix.msgid"]; delete contentMap.get("@bob:example.com")?.get("bobdevice1")?.["org.matrix.msgid"];
delete contentMap["@bob:example.com"].bobdevice2.session_id; delete contentMap.get("@bob:example.com")?.get("bobdevice2")?.["session_id"];
delete contentMap["@bob:example.com"].bobdevice2["org.matrix.msgid"]; delete contentMap.get("@bob:example.com")?.get("bobdevice2")?.["org.matrix.msgid"];
expect(contentMap).toStrictEqual({ expect(recursiveMapToObject(contentMap)).toStrictEqual({
"@bob:example.com": { ["@bob:example.com"]: {
bobdevice1: { ["bobdevice1"]: {
algorithm: "m.megolm.v1.aes-sha2", algorithm: "m.megolm.v1.aes-sha2",
room_id: roomId, room_id: roomId,
code: "m.unverified", code: "m.unverified",
reason: "The sender has disabled encrypting to unverified devices.", reason: "The sender has disabled encrypting to unverified devices.",
sender_key: aliceDevice.deviceCurve25519Key, sender_key: aliceDevice.deviceCurve25519Key,
}, },
bobdevice2: { ["bobdevice2"]: {
algorithm: "m.megolm.v1.aes-sha2", algorithm: "m.megolm.v1.aes-sha2",
room_id: roomId, room_id: roomId,
code: "m.blacklisted", code: "m.blacklisted",
@ -839,10 +853,10 @@ describe("MegolmDecryption", function () {
expect(aliceClient.sendToDevice).toHaveBeenCalled(); expect(aliceClient.sendToDevice).toHaveBeenCalled();
const [msgtype, contentMap] = mocked(aliceClient.sendToDevice).mock.calls[0]; const [msgtype, contentMap] = mocked(aliceClient.sendToDevice).mock.calls[0];
expect(msgtype).toMatch(/^(org.matrix|m).room_key.withheld$/); expect(msgtype).toMatch(/^(org.matrix|m).room_key.withheld$/);
delete contentMap["@bob:example.com"]["bobdevice"]["org.matrix.msgid"]; delete contentMap.get("@bob:example.com")?.get("bobdevice")?.["org.matrix.msgid"];
expect(contentMap).toStrictEqual({ expect(recursiveMapToObject(contentMap)).toStrictEqual({
"@bob:example.com": { ["@bob:example.com"]: {
bobdevice: { ["bobdevice"]: {
algorithm: "m.megolm.v1.aes-sha2", algorithm: "m.megolm.v1.aes-sha2",
code: "m.no_olm", code: "m.no_olm",
reason: "Unable to establish a secure channel.", reason: "Unable to establish a secure channel.",

View File

@ -146,18 +146,21 @@ describe("OlmDevice", function () {
}); });
}, },
} as unknown as MockedObject<MatrixClient>; } as unknown as MockedObject<MatrixClient>;
const devicesByUser = { const devicesByUser = new Map([
"@bob:example.com": [ [
DeviceInfo.fromStorage( "@bob:example.com",
{ [
keys: { DeviceInfo.fromStorage(
"curve25519:ABCDEFG": "akey", {
keys: {
"curve25519:ABCDEFG": "akey",
},
}, },
}, "ABCDEFG",
"ABCDEFG", ),
), ],
], ],
}; ]);
// start two tasks that try to ensure that there's an olm session // start two tasks that try to ensure that there's an olm session
const promises = Promise.all([ const promises = Promise.all([
@ -218,12 +221,8 @@ describe("OlmDevice", function () {
// There's no required ordering of devices per user, so here we // There's no required ordering of devices per user, so here we
// create two different orderings so that each task reserves a // create two different orderings so that each task reserves a
// device the other task needs before continuing. // device the other task needs before continuing.
const devicesByUserAB = { const devicesByUserAB = new Map([["@bob:example.com", [deviceBobA, deviceBobB]]]);
"@bob:example.com": [deviceBobA, deviceBobB], const devicesByUserBA = new Map([["@bob:example.com", [deviceBobB, deviceBobA]]]);
};
const devicesByUserBA = {
"@bob:example.com": [deviceBobB, deviceBobA],
};
const task1 = alwaysSucceed(olmlib.ensureOlmSessionsForDevices(aliceOlmDevice, baseApis, devicesByUserAB)); const task1 = alwaysSucceed(olmlib.ensureOlmSessionsForDevices(aliceOlmDevice, baseApis, devicesByUserAB));

View File

@ -45,7 +45,7 @@ async function makeTestClient(
await client.initCrypto(); await client.initCrypto();
// No need to download keys for these tests // No need to download keys for these tests
jest.spyOn(client.crypto!, "downloadKeys").mockResolvedValue({}); jest.spyOn(client.crypto!, "downloadKeys").mockResolvedValue(new Map());
return client; return client;
} }
@ -274,7 +274,7 @@ describe("Secrets", function () {
Object.values(otks)[0], Object.values(otks)[0],
); );
osborne2.client.crypto!.deviceList.downloadKeys = () => Promise.resolve({}); osborne2.client.crypto!.deviceList.downloadKeys = () => Promise.resolve(new Map());
osborne2.client.crypto!.deviceList.getUserByIdentityKey = () => "@alice:example.com"; osborne2.client.crypto!.deviceList.getUserByIdentityKey = () => "@alice:example.com";
const request = await secretStorage.request("foo", ["VAX"]); const request = await secretStorage.request("foo", ["VAX"]);

View File

@ -121,12 +121,12 @@ describe("SAS verification", function () {
alice.client.crypto!.deviceList.storeDevicesForUser("@bob:example.com", BOB_DEVICES); alice.client.crypto!.deviceList.storeDevicesForUser("@bob:example.com", BOB_DEVICES);
alice.client.downloadKeys = () => { alice.client.downloadKeys = () => {
return Promise.resolve({}); return Promise.resolve(new Map());
}; };
bob.client.crypto!.deviceList.storeDevicesForUser("@alice:example.com", ALICE_DEVICES); bob.client.crypto!.deviceList.storeDevicesForUser("@alice:example.com", ALICE_DEVICES);
bob.client.downloadKeys = () => { bob.client.downloadKeys = () => {
return Promise.resolve({}); return Promise.resolve(new Map());
}; };
aliceSasEvent = null; aliceSasEvent = null;
@ -176,6 +176,7 @@ describe("SAS verification", function () {
} }
}); });
}); });
afterEach(async () => { afterEach(async () => {
await Promise.all([alice.stop(), bob.stop()]); await Promise.all([alice.stop(), bob.stop()]);
@ -186,10 +187,14 @@ describe("SAS verification", function () {
let macMethod; let macMethod;
let keyAgreement; let keyAgreement;
const origSendToDevice = bob.client.sendToDevice.bind(bob.client); const origSendToDevice = bob.client.sendToDevice.bind(bob.client);
bob.client.sendToDevice = function (type, map) { bob.client.sendToDevice = async (type, map) => {
if (type === "m.key.verification.accept") { if (type === "m.key.verification.accept") {
macMethod = map[alice.client.getUserId()!][alice.client.deviceId!].message_authentication_code; macMethod = map
keyAgreement = map[alice.client.getUserId()!][alice.client.deviceId!].key_agreement_protocol; .get(alice.client.getUserId()!)
?.get(alice.client.deviceId!)?.message_authentication_code;
keyAgreement = map
.get(alice.client.getUserId()!)
?.get(alice.client.deviceId!)?.key_agreement_protocol;
} }
return origSendToDevice(type, map); return origSendToDevice(type, map);
}; };
@ -237,7 +242,7 @@ describe("SAS verification", function () {
// has, since it is the same object. If this does not // has, since it is the same object. If this does not
// happen, the verification will fail due to a hash // happen, the verification will fail due to a hash
// commitment mismatch. // commitment mismatch.
map[bob.client.getUserId()!][bob.client.deviceId!].message_authentication_codes = [ map.get(bob.client.getUserId()!)!.get(bob.client.deviceId!)!.message_authentication_codes = [
"hkdf-hmac-sha256", "hkdf-hmac-sha256",
]; ];
} }
@ -246,7 +251,9 @@ describe("SAS verification", function () {
const bobOrigSendToDevice = bob.client.sendToDevice.bind(bob.client); const bobOrigSendToDevice = bob.client.sendToDevice.bind(bob.client);
bob.client.sendToDevice = (type, map) => { bob.client.sendToDevice = (type, map) => {
if (type === "m.key.verification.accept") { if (type === "m.key.verification.accept") {
macMethod = map[alice.client.getUserId()!][alice.client.deviceId!].message_authentication_code; macMethod = map
.get(alice.client.getUserId()!)!
.get(alice.client.deviceId!)!.message_authentication_code;
} }
return bobOrigSendToDevice(type, map); return bobOrigSendToDevice(type, map);
}; };
@ -291,14 +298,18 @@ describe("SAS verification", function () {
// has, since it is the same object. If this does not // has, since it is the same object. If this does not
// happen, the verification will fail due to a hash // happen, the verification will fail due to a hash
// commitment mismatch. // commitment mismatch.
map[bob.client.getUserId()!][bob.client.deviceId!].message_authentication_codes = ["hmac-sha256"]; map.get(bob.client.getUserId()!)!.get(bob.client.deviceId!)!.message_authentication_codes = [
"hmac-sha256",
];
} }
return aliceOrigSendToDevice(type, map); return aliceOrigSendToDevice(type, map);
}; };
const bobOrigSendToDevice = bob.client.sendToDevice.bind(bob.client); const bobOrigSendToDevice = bob.client.sendToDevice.bind(bob.client);
bob.client.sendToDevice = (type, map) => { bob.client.sendToDevice = (type, map) => {
if (type === "m.key.verification.accept") { if (type === "m.key.verification.accept") {
macMethod = map[alice.client.getUserId()!][alice.client.deviceId!].message_authentication_code; macMethod = map
.get(alice.client.getUserId()!)!
.get(alice.client.deviceId!)!.message_authentication_code;
} }
return bobOrigSendToDevice(type, map); return bobOrigSendToDevice(type, map);
}; };
@ -454,7 +465,7 @@ describe("SAS verification", function () {
); );
}; };
alice.client.downloadKeys = () => { alice.client.downloadKeys = () => {
return Promise.resolve({}); return Promise.resolve(new Map());
}; };
bob.client.crypto!.setDeviceVerification = jest.fn(); bob.client.crypto!.setDeviceVerification = jest.fn();
@ -472,7 +483,7 @@ describe("SAS verification", function () {
return "bob+base64+ed25519+key"; return "bob+base64+ed25519+key";
}; };
bob.client.downloadKeys = () => { bob.client.downloadKeys = () => {
return Promise.resolve({}); return Promise.resolve(new Map());
}; };
aliceSasEvent = null; aliceSasEvent = null;

View File

@ -20,7 +20,7 @@ import { IContent, MatrixEvent } from "../../../../src/models/event";
import { IRoomTimelineData } from "../../../../src/models/event-timeline-set"; import { IRoomTimelineData } from "../../../../src/models/event-timeline-set";
import { Room, RoomEvent } from "../../../../src/models/room"; import { Room, RoomEvent } from "../../../../src/models/room";
import { logger } from "../../../../src/logger"; import { logger } from "../../../../src/logger";
import { MatrixClient, ClientEvent, ICreateClientOpts } from "../../../../src/client"; import { MatrixClient, ClientEvent, ICreateClientOpts, SendToDeviceContentMap } from "../../../../src/client";
interface UserInfo { interface UserInfo {
userId: string; userId: string;
@ -36,16 +36,16 @@ export async function makeTestClients(
const clientMap: Record<string, Record<string, MatrixClient>> = {}; const clientMap: Record<string, Record<string, MatrixClient>> = {};
const makeSendToDevice = const makeSendToDevice =
(matrixClient: MatrixClient): MatrixClient["sendToDevice"] => (matrixClient: MatrixClient): MatrixClient["sendToDevice"] =>
async (type, map) => { async (type: string, contentMap: SendToDeviceContentMap) => {
// logger.log(this.getUserId(), "sends", type, map); // logger.log(this.getUserId(), "sends", type, map);
for (const [userId, devMap] of Object.entries(map)) { for (const [userId, deviceMessages] of contentMap) {
if (userId in clientMap) { if (userId in clientMap) {
for (const [deviceId, msg] of Object.entries(devMap)) { for (const [deviceId, message] of deviceMessages) {
if (deviceId in clientMap[userId]) { if (deviceId in clientMap[userId]) {
const event = new MatrixEvent({ const event = new MatrixEvent({
sender: matrixClient.getUserId()!, sender: matrixClient.getUserId()!,
type: type, type: type,
content: msg, content: message,
}); });
const client = clientMap[userId][deviceId]; const client = clientMap[userId][deviceId];
const decryptionPromise = event.isEncrypted() const decryptionPromise = event.isEncrypted()

View File

@ -25,6 +25,7 @@ import { IContent, MatrixEvent } from "../../../../src/models/event";
import { MatrixClient } from "../../../../src/client"; import { MatrixClient } from "../../../../src/client";
import { IVerificationChannel } from "../../../../src/crypto/verification/request/Channel"; import { IVerificationChannel } from "../../../../src/crypto/verification/request/Channel";
import { VerificationBase } from "../../../../src/crypto/verification/Base"; import { VerificationBase } from "../../../../src/crypto/verification/Base";
import { MapWithDefault } from "../../../../src/utils";
type MockClient = MatrixClient & { type MockClient = MatrixClient & {
popEvents: () => MatrixEvent[]; popEvents: () => MatrixEvent[];
@ -33,7 +34,9 @@ type MockClient = MatrixClient & {
function makeMockClient(userId: string, deviceId: string): MockClient { function makeMockClient(userId: string, deviceId: string): MockClient {
let counter = 1; let counter = 1;
let events: MatrixEvent[] = []; let events: MatrixEvent[] = [];
const deviceEvents: Record<string, Record<string, MatrixEvent[]>> = {}; const deviceEvents: MapWithDefault<string, MapWithDefault<string, MatrixEvent[]>> = new MapWithDefault(
() => new MapWithDefault(() => []),
);
return { return {
getUserId() { getUserId() {
return userId; return userId;
@ -58,15 +61,11 @@ function makeMockClient(userId: string, deviceId: string): MockClient {
return Promise.resolve({ event_id: eventId }); return Promise.resolve({ event_id: eventId });
}, },
sendToDevice(type: string, msgMap: Record<string, Record<string, IContent>>) { sendToDevice(type: string, msgMap: Map<string, Map<string, IContent>>) {
for (const userId of Object.keys(msgMap)) { for (const [userId, deviceMessages] of msgMap) {
const deviceMap = msgMap[userId]; for (const [deviceId, content] of deviceMessages) {
for (const deviceId of Object.keys(deviceMap)) {
const content = deviceMap[deviceId];
const event = new MatrixEvent({ content, type }); const event = new MatrixEvent({ content, type });
deviceEvents[userId] = deviceEvents[userId] || {}; deviceEvents.getOrCreate(userId).getOrCreate(deviceId).push(event);
deviceEvents[userId][deviceId] = deviceEvents[userId][deviceId] || [];
deviceEvents[userId][deviceId].push(event);
} }
} }
return Promise.resolve({}); return Promise.resolve({});
@ -79,14 +78,9 @@ function makeMockClient(userId: string, deviceId: string): MockClient {
return e; return e;
}, },
// @ts-ignore special testing fn
popDeviceEvents(userId: string, deviceId: string): MatrixEvent[] { popDeviceEvents(userId: string, deviceId: string): MatrixEvent[] {
const forDevice = deviceEvents[userId]; const result = deviceEvents.get(userId)?.get(deviceId) || [];
const events = forDevice && forDevice[deviceId]; deviceEvents?.get(userId)?.delete(deviceId);
const result = events || [];
if (events) {
delete forDevice[deviceId];
}
return result; return result;
}, },
} as unknown as MockClient; } as unknown as MockClient;

View File

@ -204,9 +204,14 @@ describe("RoomWidgetClient", () => {
}); });
describe("to-device messages", () => { describe("to-device messages", () => {
const unencryptedContentMap = { const unencryptedContentMap = new Map([
"@alice:example.org": { "*": { hello: "alice!" } }, ["@alice:example.org", new Map([["*", { hello: "alice!" }]])],
"@bob:example.org": { bobDesktop: { hello: "bob!" } }, ["@bob:example.org", new Map([["bobDesktop", { hello: "bob!" }]])],
]);
const expectedRequestData = {
["@alice:example.org"]: { ["*"]: { hello: "alice!" } },
["@bob:example.org"]: { ["bobDesktop"]: { hello: "bob!" } },
}; };
it("sends unencrypted (sendToDevice)", async () => { it("sends unencrypted (sendToDevice)", async () => {
@ -214,7 +219,7 @@ describe("RoomWidgetClient", () => {
expect(widgetApi.requestCapabilityToSendToDevice).toHaveBeenCalledWith("org.example.foo"); expect(widgetApi.requestCapabilityToSendToDevice).toHaveBeenCalledWith("org.example.foo");
await client.sendToDevice("org.example.foo", unencryptedContentMap); await client.sendToDevice("org.example.foo", unencryptedContentMap);
expect(widgetApi.sendToDevice).toHaveBeenCalledWith("org.example.foo", false, unencryptedContentMap); expect(widgetApi.sendToDevice).toHaveBeenCalledWith("org.example.foo", false, expectedRequestData);
}); });
it("sends unencrypted (queueToDevice)", async () => { it("sends unencrypted (queueToDevice)", async () => {
@ -229,7 +234,7 @@ describe("RoomWidgetClient", () => {
], ],
}; };
await client.queueToDevice(batch); await client.queueToDevice(batch);
expect(widgetApi.sendToDevice).toHaveBeenCalledWith("org.example.foo", false, unencryptedContentMap); expect(widgetApi.sendToDevice).toHaveBeenCalledWith("org.example.foo", false, expectedRequestData);
}); });
it("sends encrypted (encryptAndSendToDevices)", async () => { it("sends encrypted (encryptAndSendToDevices)", async () => {

View File

@ -59,7 +59,7 @@ describe("MemoryStore", () => {
await store.deleteAllData(); await store.deleteAllData();
// empty object // empty object
expect(store.accountData).toEqual({}); expect(store.accountData).toEqual(new Map());
}); });
}); });
}); });

View File

@ -24,9 +24,12 @@ import {
lexicographicCompare, lexicographicCompare,
nextString, nextString,
prevString, prevString,
recursiveMapToObject,
simpleRetryOperation, simpleRetryOperation,
stringToBase, stringToBase,
sortEventsByLatestContentTimestamp, sortEventsByLatestContentTimestamp,
safeSet,
MapWithDefault,
} from "../../src/utils"; } from "../../src/utils";
import { logger } from "../../src/logger"; import { logger } from "../../src/logger";
import { mkMessage } from "../test-utils/test-utils"; import { mkMessage } from "../test-utils/test-utils";
@ -606,6 +609,105 @@ describe("utils", function () {
}); });
}); });
describe("recursiveMapToObject", () => {
it.each([
// empty map
{
map: new Map(),
expected: {},
},
// one level map
{
map: new Map<any, any>([
["key1", "value 1"],
["key2", 23],
["key3", undefined],
["key4", null],
["key5", [1, 2, 3]],
]),
expected: { key1: "value 1", key2: 23, key3: undefined, key4: null, key5: [1, 2, 3] },
},
// two level map
{
map: new Map<any, any>([
[
"key1",
new Map<any, any>([
["key1_1", "value 1"],
["key1_2", "value 1.2"],
]),
],
["key2", "value 2"],
]),
expected: { key1: { key1_1: "value 1", key1_2: "value 1.2" }, key2: "value 2" },
},
// multi level map
{
map: new Map<any, any>([
["key1", new Map<any, any>([["key1_1", new Map<any, any>([["key1_1_1", "value 1.1.1"]])]])],
]),
expected: { key1: { key1_1: { key1_1_1: "value 1.1.1" } } },
},
// list of maps
{
map: new Map<any, any>([
[
"key1",
[new Map<any, any>([["key1_1", "value 1.1"]]), new Map<any, any>([["key1_2", "value 1.2"]])],
],
]),
expected: { key1: [{ key1_1: "value 1.1" }, { key1_2: "value 1.2" }] },
},
// map → array → array → map
{
map: new Map<any, any>([["key1", [[new Map<any, any>([["key2", "value 2"]])]]]]),
expected: {
key1: [
[
{
key2: "value 2",
},
],
],
},
},
])("%# should convert the value", ({ map, expected }) => {
expect(recursiveMapToObject(map)).toStrictEqual(expected);
});
});
describe("safeSet", () => {
it("should set a value", () => {
const obj = {};
safeSet(obj, "testProp", "test value");
expect(obj).toEqual({ testProp: "test value" });
});
it.each(["__proto__", "prototype", "constructor"])("should raise an error when setting »%s«", (prop) => {
expect(() => {
safeSet({}, prop, "teset value");
}).toThrow("Trying to modify prototype or constructor");
});
});
describe("MapWithDefault", () => {
it("getOrCreate should create the value if it does not exist", () => {
const newValue = {};
const map = new MapWithDefault(() => newValue);
// undefined before getOrCreate
expect(map.get("test")).toBeUndefined();
expect(map.getOrCreate("test")).toBe(newValue);
// default value after getOrCreate
expect(map.get("test")).toBe(newValue);
// test that it always returns the same value
expect(map.getOrCreate("test")).toBe(newValue);
});
});
describe("sleep", () => { describe("sleep", () => {
it("resolves", async () => { it("resolves", async () => {
await utils.sleep(0); await utils.sleep(0);

View File

@ -688,15 +688,15 @@ describe("Group Call", function () {
expect(client1.sendToDevice.mock.calls[0][0]).toBe("m.call.invite"); expect(client1.sendToDevice.mock.calls[0][0]).toBe("m.call.invite");
const toDeviceCallContent = client1.sendToDevice.mock.calls[0][1]; const toDeviceCallContent = client1.sendToDevice.mock.calls[0][1];
expect(Object.keys(toDeviceCallContent).length).toBe(1); expect(toDeviceCallContent.size).toBe(1);
expect(Object.keys(toDeviceCallContent)[0]).toBe(FAKE_USER_ID_2); expect(toDeviceCallContent.has(FAKE_USER_ID_2)).toBe(true);
const toDeviceBobDevices = toDeviceCallContent[FAKE_USER_ID_2]; const toDeviceBobDevices = toDeviceCallContent.get(FAKE_USER_ID_2);
expect(Object.keys(toDeviceBobDevices).length).toBe(1); expect(toDeviceBobDevices?.size).toBe(1);
expect(Object.keys(toDeviceBobDevices)[0]).toBe(FAKE_DEVICE_ID_2); expect(toDeviceBobDevices?.has(FAKE_DEVICE_ID_2)).toBe(true);
const bobDeviceMessage = toDeviceBobDevices[FAKE_DEVICE_ID_2]; const bobDeviceMessage = toDeviceBobDevices?.get(FAKE_DEVICE_ID_2);
expect(bobDeviceMessage.conf_id).toBe(FAKE_CONF_ID); expect(bobDeviceMessage?.conf_id).toBe(FAKE_CONF_ID);
} finally { } finally {
await Promise.all([groupCall1.leave(), groupCall2.leave()]); await Promise.all([groupCall1.leave(), groupCall2.leave()]);
} }

View File

@ -38,7 +38,7 @@ export interface CachedReceipt {
data: Receipt; data: Receipt;
} }
export type ReceiptCache = { [eventId: string]: CachedReceipt[] }; export type ReceiptCache = Map<string, CachedReceipt[]>;
export interface ReceiptContent { export interface ReceiptContent {
[eventId: string]: { [eventId: string]: {
@ -49,11 +49,8 @@ export interface ReceiptContent {
} }
// We will only hold a synthetic receipt if we do not have a real receipt or the synthetic is newer. // We will only hold a synthetic receipt if we do not have a real receipt or the synthetic is newer.
export type Receipts = { // map: receipt type → user Id → receipt
[receiptType: string]: { export type Receipts = Map<string, Map<string, [real: WrappedReceipt | null, synthetic: WrappedReceipt | null]>>;
[userId: string]: [WrappedReceipt | null, WrappedReceipt | null]; // Pair<real receipt, synthetic receipt> (both nullable)
};
};
export type CachedReceiptStructure = { export type CachedReceiptStructure = {
eventId: string; eventId: string;

View File

@ -21,6 +21,7 @@ import { MatrixError } from "./http-api";
import { IndexedToDeviceBatch, ToDeviceBatch, ToDeviceBatchWithTxnId, ToDevicePayload } from "./models/ToDeviceMessage"; import { IndexedToDeviceBatch, ToDeviceBatch, ToDeviceBatchWithTxnId, ToDevicePayload } from "./models/ToDeviceMessage";
import { MatrixScheduler } from "./scheduler"; import { MatrixScheduler } from "./scheduler";
import { SyncState } from "./sync"; import { SyncState } from "./sync";
import { MapWithDefault } from "./utils";
const MAX_BATCH_SIZE = 20; const MAX_BATCH_SIZE = 20;
@ -122,12 +123,9 @@ export class ToDeviceMessageQueue {
* Attempts to send a batch of to-device messages. * Attempts to send a batch of to-device messages.
*/ */
private async sendBatch(batch: IndexedToDeviceBatch): Promise<void> { private async sendBatch(batch: IndexedToDeviceBatch): Promise<void> {
const contentMap: Record<string, Record<string, ToDevicePayload>> = {}; const contentMap: MapWithDefault<string, Map<string, ToDevicePayload>> = new MapWithDefault(() => new Map());
for (const item of batch.batch) { for (const item of batch.batch) {
if (!contentMap[item.userId]) { contentMap.getOrCreate(item.userId).set(item.deviceId, item.payload);
contentMap[item.userId] = {};
}
contentMap[item.userId][item.deviceId] = item.payload;
} }
logger.info( logger.info(

View File

@ -37,7 +37,7 @@ import { Filter, IFilterDefinition, IRoomEventFilter } from "./filter";
import { CallEventHandlerEvent, CallEventHandler, CallEventHandlerEventHandlerMap } from "./webrtc/callEventHandler"; import { CallEventHandlerEvent, CallEventHandler, CallEventHandlerEventHandlerMap } from "./webrtc/callEventHandler";
import { GroupCallEventHandlerEvent, GroupCallEventHandlerEventHandlerMap } from "./webrtc/groupCallEventHandler"; import { GroupCallEventHandlerEvent, GroupCallEventHandlerEventHandlerMap } from "./webrtc/groupCallEventHandler";
import * as utils from "./utils"; import * as utils from "./utils";
import { replaceParam, QueryDict, sleep } from "./utils"; import { replaceParam, QueryDict, sleep, noUnsafeEventProps } from "./utils";
import { Direction, EventTimeline } from "./models/event-timeline"; import { Direction, EventTimeline } from "./models/event-timeline";
import { IActionsObject, PushProcessor } from "./pushprocessor"; import { IActionsObject, PushProcessor } from "./pushprocessor";
import { AutoDiscovery, AutoDiscoveryAction } from "./autodiscovery"; import { AutoDiscovery, AutoDiscoveryAction } from "./autodiscovery";
@ -79,7 +79,7 @@ import {
VerificationMethod, VerificationMethod,
IRoomKeyRequestBody, IRoomKeyRequestBody,
} from "./crypto"; } from "./crypto";
import { DeviceInfo, IDevice } from "./crypto/deviceinfo"; import { DeviceInfo } from "./crypto/deviceinfo";
import { decodeRecoveryKey } from "./crypto/recoverykey"; import { decodeRecoveryKey } from "./crypto/recoverykey";
import { keyFromAuthData } from "./crypto/key_passphrase"; import { keyFromAuthData } from "./crypto/key_passphrase";
import { User, UserEvent, UserEventHandlerMap } from "./models/user"; import { User, UserEvent, UserEventHandlerMap } from "./models/user";
@ -207,6 +207,7 @@ import { buildFeatureSupportMap, Feature, ServerSupport } from "./feature";
import { CryptoBackend } from "./common-crypto/CryptoBackend"; import { CryptoBackend } from "./common-crypto/CryptoBackend";
import { RUST_SDK_STORE_PREFIX } from "./rust-crypto/constants"; import { RUST_SDK_STORE_PREFIX } from "./rust-crypto/constants";
import { CryptoApi } from "./crypto-api"; import { CryptoApi } from "./crypto-api";
import { DeviceInfoMap } from "./crypto/DeviceList";
export type Store = IStore; export type Store = IStore;
@ -511,6 +512,8 @@ enum CrossSigningKeyType {
export type CrossSigningKeys = Record<CrossSigningKeyType, ICrossSigningKey>; export type CrossSigningKeys = Record<CrossSigningKeyType, ICrossSigningKey>;
export type SendToDeviceContentMap = Map<string, Map<string, Record<string, any>>>;
export interface ISignedKey { export interface ISignedKey {
keys: Record<string, string>; keys: Record<string, string>;
signatures: ISignatures; signatures: ISignatures;
@ -2268,7 +2271,7 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
* *
* @returns A promise which resolves to a map userId-\>deviceId-\>{@link DeviceInfo} * @returns A promise which resolves to a map userId-\>deviceId-\>{@link DeviceInfo}
*/ */
public downloadKeys(userIds: string[], forceDownload?: boolean): Promise<Record<string, Record<string, IDevice>>> { public downloadKeys(userIds: string[], forceDownload?: boolean): Promise<DeviceInfoMap> {
if (!this.crypto) { if (!this.crypto) {
return Promise.reject(new Error("End-to-end encryption disabled")); return Promise.reject(new Error("End-to-end encryption disabled"));
} }
@ -3818,9 +3821,9 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
} }
const deviceInfos = await this.crypto.downloadKeys(userIds); const deviceInfos = await this.crypto.downloadKeys(userIds);
const devicesByUser: Record<string, DeviceInfo[]> = {}; const devicesByUser: Map<string, DeviceInfo[]> = new Map();
for (const [userId, devices] of Object.entries(deviceInfos)) { for (const [userId, devices] of deviceInfos) {
devicesByUser[userId] = Object.values(devices); devicesByUser.set(userId, Array.from(devices.values()));
} }
// XXX: Private member access // XXX: Private member access
@ -6035,6 +6038,8 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
const token = res.next_token; const token = res.next_token;
const matrixEvents: MatrixEvent[] = []; const matrixEvents: MatrixEvent[] = [];
res.notifications = res.notifications.filter(noUnsafeEventProps);
for (let i = 0; i < res.notifications.length; i++) { for (let i = 0; i < res.notifications.length; i++) {
const notification = res.notifications[i]; const notification = res.notifications[i];
const event = this.getEventMapper()(notification.event); const event = this.getEventMapper()(notification.event);
@ -6081,11 +6086,11 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
.then((res) => { .then((res) => {
if (res.state) { if (res.state) {
const roomState = eventTimeline.getState(dir)!; const roomState = eventTimeline.getState(dir)!;
const stateEvents = res.state.map(this.getEventMapper()); const stateEvents = res.state.filter(noUnsafeEventProps).map(this.getEventMapper());
roomState.setUnknownStateEvents(stateEvents); roomState.setUnknownStateEvents(stateEvents);
} }
const token = res.end; const token = res.end;
const matrixEvents = res.chunk.map(this.getEventMapper()); const matrixEvents = res.chunk.filter(noUnsafeEventProps).map(this.getEventMapper());
const timelineSet = eventTimeline.getTimelineSet(); const timelineSet = eventTimeline.getTimelineSet();
timelineSet.addEventsToTimeline(matrixEvents, backwards, eventTimeline, token); timelineSet.addEventsToTimeline(matrixEvents, backwards, eventTimeline, token);
@ -6117,7 +6122,7 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
}) })
.then(async (res) => { .then(async (res) => {
const mapper = this.getEventMapper(); const mapper = this.getEventMapper();
const matrixEvents = res.chunk.map(mapper); const matrixEvents = res.chunk.filter(noUnsafeEventProps).map(mapper);
// Process latest events first // Process latest events first
for (const event of matrixEvents.slice().reverse()) { for (const event of matrixEvents.slice().reverse()) {
@ -6165,11 +6170,11 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
.then((res) => { .then((res) => {
if (res.state) { if (res.state) {
const roomState = eventTimeline.getState(dir)!; const roomState = eventTimeline.getState(dir)!;
const stateEvents = res.state.map(this.getEventMapper()); const stateEvents = res.state.filter(noUnsafeEventProps).map(this.getEventMapper());
roomState.setUnknownStateEvents(stateEvents); roomState.setUnknownStateEvents(stateEvents);
} }
const token = res.end; const token = res.end;
const matrixEvents = res.chunk.map(this.getEventMapper()); const matrixEvents = res.chunk.filter(noUnsafeEventProps).map(this.getEventMapper());
const timelineSet = eventTimeline.getTimelineSet(); const timelineSet = eventTimeline.getTimelineSet();
const [timelineEvents] = room.partitionThreadedEvents(matrixEvents); const [timelineEvents] = room.partitionThreadedEvents(matrixEvents);
@ -9187,24 +9192,22 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
* supplied. * supplied.
* @returns Promise which resolves: to an empty object `{}` * @returns Promise which resolves: to an empty object `{}`
*/ */
public sendToDevice( public sendToDevice(eventType: string, contentMap: SendToDeviceContentMap, txnId?: string): Promise<{}> {
eventType: string,
contentMap: { [userId: string]: { [deviceId: string]: Record<string, any> } },
txnId?: string,
): Promise<{}> {
const path = utils.encodeUri("/sendToDevice/$eventType/$txnId", { const path = utils.encodeUri("/sendToDevice/$eventType/$txnId", {
$eventType: eventType, $eventType: eventType,
$txnId: txnId ? txnId : this.makeTxnId(), $txnId: txnId ? txnId : this.makeTxnId(),
}); });
const body = { const body = {
messages: contentMap, messages: utils.recursiveMapToObject(contentMap),
}; };
const targets = Object.keys(contentMap).reduce<Record<string, string[]>>((obj, key) => { const targets = new Map<string, string[]>();
obj[key] = Object.keys(contentMap[key]);
return obj; for (const [userId, deviceMessages] of contentMap) {
}, {}); targets.set(userId, Array.from(deviceMessages.keys()));
}
logger.log(`PUT ${path}`, targets); logger.log(`PUT ${path}`, targets);
return this.http.authedRequest(Method.Put, path, undefined, body); return this.http.authedRequest(Method.Put, path, undefined, body);

View File

@ -58,7 +58,8 @@ export enum TrackingStatus {
UpToDate, UpToDate,
} }
export type DeviceInfoMap = Record<string, Record<string, DeviceInfo>>; // user-Id → device-Id → DeviceInfo
export type DeviceInfoMap = Map<string, Map<string, DeviceInfo>>;
type EmittedEvents = CryptoEvent.WillUpdateDevices | CryptoEvent.DevicesUpdated | CryptoEvent.UserCrossSigningUpdated; type EmittedEvents = CryptoEvent.WillUpdateDevices | CryptoEvent.DevicesUpdated | CryptoEvent.UserCrossSigningUpdated;
@ -301,13 +302,13 @@ export class DeviceList extends TypedEventEmitter<EmittedEvents, CryptoEventHand
* @returns userId-\>deviceId-\>{@link DeviceInfo}. * @returns userId-\>deviceId-\>{@link DeviceInfo}.
*/ */
private getDevicesFromStore(userIds: string[]): DeviceInfoMap { private getDevicesFromStore(userIds: string[]): DeviceInfoMap {
const stored: DeviceInfoMap = {}; const stored: DeviceInfoMap = new Map();
userIds.forEach((u) => { userIds.forEach((userId) => {
stored[u] = {}; const deviceMap = new Map();
const devices = this.getStoredDevicesForUser(u) || []; this.getStoredDevicesForUser(userId)?.forEach(function (device) {
devices.forEach(function (dev) { deviceMap.set(device.deviceId, device);
stored[u][dev.deviceId] = dev;
}); });
stored.set(userId, deviceMap);
}); });
return stored; return stored;
} }

View File

@ -61,7 +61,7 @@ export class EncryptionSetupBuilder {
* @param accountData - pre-existing account data, will only be read, not written. * @param accountData - pre-existing account data, will only be read, not written.
* @param delegateCryptoCallbacks - crypto callbacks to delegate to if the key isn't in cache yet * @param delegateCryptoCallbacks - crypto callbacks to delegate to if the key isn't in cache yet
*/ */
public constructor(accountData: Record<string, MatrixEvent>, delegateCryptoCallbacks?: ICryptoCallbacks) { public constructor(accountData: Map<string, MatrixEvent>, delegateCryptoCallbacks?: ICryptoCallbacks) {
this.accountDataClientAdapter = new AccountDataClientAdapter(accountData); this.accountDataClientAdapter = new AccountDataClientAdapter(accountData);
this.crossSigningCallbacks = new CrossSigningCallbacks(); this.crossSigningCallbacks = new CrossSigningCallbacks();
this.ssssCryptoCallbacks = new SSSSCryptoCallbacks(delegateCryptoCallbacks); this.ssssCryptoCallbacks = new SSSSCryptoCallbacks(delegateCryptoCallbacks);
@ -246,7 +246,7 @@ class AccountDataClientAdapter
/** /**
* @param existingValues - existing account data * @param existingValues - existing account data
*/ */
public constructor(private readonly existingValues: Record<string, MatrixEvent>) { public constructor(private readonly existingValues: Map<string, MatrixEvent>) {
super(); super();
} }
@ -265,7 +265,7 @@ class AccountDataClientAdapter
if (modifiedValue) { if (modifiedValue) {
return modifiedValue; return modifiedValue;
} }
const existingValue = this.existingValues[type]; const existingValue = this.existingValues.get(type);
if (existingValue) { if (existingValue) {
return existingValue.getContent(); return existingValue.getContent();
} }

View File

@ -21,6 +21,7 @@ import { MatrixClient } from "../client";
import { IRoomKeyRequestBody, IRoomKeyRequestRecipient } from "./index"; import { IRoomKeyRequestBody, IRoomKeyRequestRecipient } from "./index";
import { CryptoStore, OutgoingRoomKeyRequest } from "./store/base"; import { CryptoStore, OutgoingRoomKeyRequest } from "./store/base";
import { EventType, ToDeviceMessageId } from "../@types/event"; import { EventType, ToDeviceMessageId } from "../@types/event";
import { MapWithDefault } from "../utils";
/** /**
* Internal module. Management of outgoing room key requests. * Internal module. Management of outgoing room key requests.
@ -460,15 +461,13 @@ export class OutgoingRoomKeyRequestManager {
recipients: IRoomKeyRequestRecipient[], recipients: IRoomKeyRequestRecipient[],
txnId?: string, txnId?: string,
): Promise<{}> { ): Promise<{}> {
const contentMap: Record<string, Record<string, Record<string, any>>> = {}; const contentMap = new MapWithDefault<string, Map<string, Record<string, any>>>(() => new Map());
for (const recip of recipients) { for (const recip of recipients) {
if (!contentMap[recip.userId]) { const userDeviceMap = contentMap.getOrCreate(recip.userId);
contentMap[recip.userId] = {}; userDeviceMap.set(recip.deviceId, {
}
contentMap[recip.userId][recip.deviceId] = {
...message, ...message,
[ToDeviceMessageId]: uuidv4(), [ToDeviceMessageId]: uuidv4(),
}; });
} }
return this.baseApis.sendToDevice(EventType.RoomKeyRequest, contentMap, txnId); return this.baseApis.sendToDevice(EventType.RoomKeyRequest, contentMap, txnId);

View File

@ -367,13 +367,11 @@ export class SecretStorage<B extends MatrixClient | undefined = MatrixClient> {
requesting_device_id: this.baseApis.deviceId, requesting_device_id: this.baseApis.deviceId,
request_id: requestId, request_id: requestId,
}; };
const toDevice: Record<string, typeof cancelData> = {}; const toDevice: Map<string, typeof cancelData> = new Map();
for (const device of devices) { for (const device of devices) {
toDevice[device] = cancelData; toDevice.set(device, cancelData);
} }
this.baseApis.sendToDevice("m.secret.request", { this.baseApis.sendToDevice("m.secret.request", new Map([[this.baseApis.getUserId()!, toDevice]]));
[this.baseApis.getUserId()!]: toDevice,
});
// and reject the promise so that anyone waiting on it will be // and reject the promise so that anyone waiting on it will be
// notified // notified
@ -388,14 +386,12 @@ export class SecretStorage<B extends MatrixClient | undefined = MatrixClient> {
request_id: requestId, request_id: requestId,
[ToDeviceMessageId]: uuidv4(), [ToDeviceMessageId]: uuidv4(),
}; };
const toDevice: Record<string, typeof requestData> = {}; const toDevice: Map<string, typeof requestData> = new Map();
for (const device of devices) { for (const device of devices) {
toDevice[device] = requestData; toDevice.set(device, requestData);
} }
logger.info(`Request secret ${name} from ${devices}, id ${requestId}`); logger.info(`Request secret ${name} from ${devices}, id ${requestId}`);
this.baseApis.sendToDevice("m.secret.request", { this.baseApis.sendToDevice("m.secret.request", new Map([[this.baseApis.getUserId()!, toDevice]]));
[this.baseApis.getUserId()!]: toDevice,
});
return { return {
requestId, requestId,
@ -469,9 +465,11 @@ export class SecretStorage<B extends MatrixClient | undefined = MatrixClient> {
ciphertext: {}, ciphertext: {},
[ToDeviceMessageId]: uuidv4(), [ToDeviceMessageId]: uuidv4(),
}; };
await olmlib.ensureOlmSessionsForDevices(this.baseApis.crypto!.olmDevice, this.baseApis, { await olmlib.ensureOlmSessionsForDevices(
[sender]: [this.baseApis.getStoredDevice(sender, deviceId)!], this.baseApis.crypto!.olmDevice,
}); this.baseApis,
new Map([[sender, [this.baseApis.getStoredDevice(sender, deviceId)!]]]),
);
await olmlib.encryptMessageForDevice( await olmlib.encryptMessageForDevice(
encryptedContent.ciphertext, encryptedContent.ciphertext,
this.baseApis.getUserId()!, this.baseApis.getUserId()!,
@ -481,11 +479,7 @@ export class SecretStorage<B extends MatrixClient | undefined = MatrixClient> {
this.baseApis.getStoredDevice(sender, deviceId)!, this.baseApis.getStoredDevice(sender, deviceId)!,
payload, payload,
); );
const contentMap = { const contentMap = new Map([[sender, new Map([[deviceId, encryptedContent]])]]);
[sender]: {
[deviceId]: encryptedContent,
},
};
logger.info(`Sending ${content.name} secret for ${deviceId}`); logger.info(`Sending ${content.name} secret for ${deviceId}`);
this.baseApis.sendToDevice("m.room.encrypted", contentMap); this.baseApis.sendToDevice("m.room.encrypted", contentMap);

View File

@ -26,6 +26,7 @@ import { IContent, MatrixEvent, RoomMember } from "../../matrix";
import { Crypto, IEncryptedContent, IEventDecryptionResult, IncomingRoomKeyRequest } from ".."; import { Crypto, IEncryptedContent, IEventDecryptionResult, IncomingRoomKeyRequest } from "..";
import { DeviceInfo } from "../deviceinfo"; import { DeviceInfo } from "../deviceinfo";
import { IRoomEncryption } from "../RoomList"; import { IRoomEncryption } from "../RoomList";
import { DeviceInfoMap } from "../DeviceList";
/** /**
* Map of registered encryption algorithm classes. A map from string to {@link EncryptionAlgorithm} class * Map of registered encryption algorithm classes. A map from string to {@link EncryptionAlgorithm} class
@ -195,7 +196,7 @@ export abstract class DecryptionAlgorithm {
} }
public onRoomKeyWithheldEvent?(event: MatrixEvent): Promise<void>; public onRoomKeyWithheldEvent?(event: MatrixEvent): Promise<void>;
public sendSharedHistoryInboundSessions?(devicesByUser: Record<string, DeviceInfo[]>): Promise<void>; public sendSharedHistoryInboundSessions?(devicesByUser: Map<string, DeviceInfo[]>): Promise<void>;
} }
/** /**
@ -241,11 +242,7 @@ export class UnknownDeviceError extends Error {
* @param msg - message describing the problem * @param msg - message describing the problem
* @param devices - set of unknown devices per user we're warning about * @param devices - set of unknown devices per user we're warning about
*/ */
public constructor( public constructor(msg: string, public readonly devices: DeviceInfoMap, public event?: MatrixEvent) {
msg: string,
public readonly devices: Record<string, Record<string, object>>,
public event?: MatrixEvent,
) {
super(msg); super(msg);
this.name = "UnknownDeviceError"; this.name = "UnknownDeviceError";
this.devices = devices; this.devices = devices;

View File

@ -43,7 +43,7 @@ import { IMegolmEncryptedContent, IncomingRoomKeyRequest, IEncryptedContent } fr
import { RoomKeyRequestState } from "../OutgoingRoomKeyRequestManager"; import { RoomKeyRequestState } from "../OutgoingRoomKeyRequestManager";
import { OlmGroupSessionExtraData } from "../../@types/crypto"; import { OlmGroupSessionExtraData } from "../../@types/crypto";
import { MatrixError } from "../../http-api"; import { MatrixError } from "../../http-api";
import { immediate } from "../../utils"; import { immediate, MapWithDefault } from "../../utils";
// determine whether the key can be shared with invitees // determine whether the key can be shared with invitees
export function isRoomSharedHistory(room: Room): boolean { export function isRoomSharedHistory(room: Room): boolean {
@ -63,17 +63,27 @@ interface IBlockedDevice {
deviceInfo: DeviceInfo; deviceInfo: DeviceInfo;
} }
interface IBlockedMap { // map user Id → device Id → IBlockedDevice
[userId: string]: { type BlockedMap = Map<string, Map<string, IBlockedDevice>>;
[deviceId: string]: IBlockedDevice;
};
}
export interface IOlmDevice<T = DeviceInfo> { export interface IOlmDevice<T = DeviceInfo> {
userId: string; userId: string;
deviceInfo: T; deviceInfo: T;
} }
/**
* Tests whether an encrypted content has a ciphertext.
* Ciphertext can be a string or object depending on the content type {@link IEncryptedContent}.
*
* @param content - Encrypted content
* @returns true: has ciphertext, else false
*/
const hasCiphertext = (content: IEncryptedContent): boolean => {
return typeof content.ciphertext === "string"
? !!content.ciphertext.length
: !!Object.keys(content.ciphertext).length;
};
/** The result of parsing the an `m.room_key` or `m.forwarded_room_key` to-device event */ /** The result of parsing the an `m.room_key` or `m.forwarded_room_key` to-device event */
interface RoomKey { interface RoomKey {
/** /**
@ -147,8 +157,8 @@ class OutboundSessionInfo {
/** when the session was created (ms since the epoch) */ /** when the session was created (ms since the epoch) */
public creationTime: number; public creationTime: number;
/** devices with which we have shared the session key `userId -> {deviceId -> SharedWithData}` */ /** devices with which we have shared the session key `userId -> {deviceId -> SharedWithData}` */
public sharedWithDevices: Record<string, Record<string, SharedWithData>> = {}; public sharedWithDevices: MapWithDefault<string, Map<string, SharedWithData>> = new MapWithDefault(() => new Map());
public blockedDevicesNotified: Record<string, Record<string, boolean>> = {}; public blockedDevicesNotified: MapWithDefault<string, Map<string, boolean>> = new MapWithDefault(() => new Map());
/** /**
* @param sharedHistory - whether the session can be freely shared with * @param sharedHistory - whether the session can be freely shared with
@ -173,17 +183,11 @@ class OutboundSessionInfo {
} }
public markSharedWithDevice(userId: string, deviceId: string, deviceKey: string, chainIndex: number): void { public markSharedWithDevice(userId: string, deviceId: string, deviceKey: string, chainIndex: number): void {
if (!this.sharedWithDevices[userId]) { this.sharedWithDevices.getOrCreate(userId).set(deviceId, { deviceKey, messageIndex: chainIndex });
this.sharedWithDevices[userId] = {};
}
this.sharedWithDevices[userId][deviceId] = { deviceKey, messageIndex: chainIndex };
} }
public markNotifiedBlockedDevice(userId: string, deviceId: string): void { public markNotifiedBlockedDevice(userId: string, deviceId: string): void {
if (!this.blockedDevicesNotified[userId]) { this.blockedDevicesNotified.getOrCreate(userId).set(deviceId, true);
this.blockedDevicesNotified[userId] = {};
}
this.blockedDevicesNotified[userId][deviceId] = true;
} }
/** /**
@ -196,23 +200,15 @@ class OutboundSessionInfo {
* @returns true if we have shared the session with devices which aren't * @returns true if we have shared the session with devices which aren't
* in devicesInRoom. * in devicesInRoom.
*/ */
public sharedWithTooManyDevices(devicesInRoom: Record<string, Record<string, object>>): boolean { public sharedWithTooManyDevices(devicesInRoom: DeviceInfoMap): boolean {
for (const userId in this.sharedWithDevices) { for (const [userId, devices] of this.sharedWithDevices) {
if (!this.sharedWithDevices.hasOwnProperty(userId)) { if (!devicesInRoom.has(userId)) {
continue;
}
if (!devicesInRoom.hasOwnProperty(userId)) {
logger.log("Starting new megolm session because we shared with " + userId); logger.log("Starting new megolm session because we shared with " + userId);
return true; return true;
} }
for (const deviceId in this.sharedWithDevices[userId]) { for (const [deviceId] of devices) {
if (!this.sharedWithDevices[userId].hasOwnProperty(deviceId)) { if (!devicesInRoom.get(userId)?.get(deviceId)) {
continue;
}
if (!devicesInRoom[userId].hasOwnProperty(deviceId)) {
logger.log("Starting new megolm session because we shared with " + userId + ":" + deviceId); logger.log("Starting new megolm session because we shared with " + userId + ":" + deviceId);
return true; return true;
} }
@ -292,7 +288,7 @@ export class MegolmEncryption extends EncryptionAlgorithm {
private async ensureOutboundSession( private async ensureOutboundSession(
room: Room, room: Room,
devicesInRoom: DeviceInfoMap, devicesInRoom: DeviceInfoMap,
blocked: IBlockedMap, blocked: BlockedMap,
singleOlmCreationPhase = false, singleOlmCreationPhase = false,
): Promise<OutboundSessionInfo> { ): Promise<OutboundSessionInfo> {
// takes the previous OutboundSessionInfo, and considers whether to create // takes the previous OutboundSessionInfo, and considers whether to create
@ -360,21 +356,21 @@ export class MegolmEncryption extends EncryptionAlgorithm {
devicesInRoom: DeviceInfoMap, devicesInRoom: DeviceInfoMap,
sharedHistory: boolean, sharedHistory: boolean,
singleOlmCreationPhase: boolean, singleOlmCreationPhase: boolean,
blocked: IBlockedMap, blocked: BlockedMap,
session: OutboundSessionInfo, session: OutboundSessionInfo,
): Promise<void> { ): Promise<void> {
// now check if we need to share with any devices // now check if we need to share with any devices
const shareMap: Record<string, DeviceInfo[]> = {}; const shareMap: Record<string, DeviceInfo[]> = {};
for (const [userId, userDevices] of Object.entries(devicesInRoom)) { for (const [userId, userDevices] of devicesInRoom) {
for (const [deviceId, deviceInfo] of Object.entries(userDevices)) { for (const [deviceId, deviceInfo] of userDevices) {
const key = deviceInfo.getIdentityKey(); const key = deviceInfo.getIdentityKey();
if (key == this.olmDevice.deviceCurve25519Key) { if (key == this.olmDevice.deviceCurve25519Key) {
// don't bother sending to ourself // don't bother sending to ourself
continue; continue;
} }
if (!session.sharedWithDevices[userId] || session.sharedWithDevices[userId][deviceId] === undefined) { if (!session.sharedWithDevices.get(userId)?.get(deviceId)) {
shareMap[userId] = shareMap[userId] || []; shareMap[userId] = shareMap[userId] || [];
shareMap[userId].push(deviceInfo); shareMap[userId].push(deviceInfo);
} }
@ -402,9 +398,9 @@ export class MegolmEncryption extends EncryptionAlgorithm {
await Promise.all([ await Promise.all([
(async (): Promise<void> => { (async (): Promise<void> => {
// share keys with devices that we already have a session for // share keys with devices that we already have a session for
const olmSessionList = Object.entries(olmSessions) const olmSessionList = Array.from(olmSessions.entries())
.map(([userId, sessionsByUser]) => .map(([userId, sessionsByUser]) =>
Object.entries(sessionsByUser).map( Array.from(sessionsByUser.entries()).map(
([deviceId, session]) => `${userId}/${deviceId}: ${session.sessionId}`, ([deviceId, session]) => `${userId}/${deviceId}: ${session.sessionId}`,
), ),
) )
@ -414,7 +410,7 @@ export class MegolmEncryption extends EncryptionAlgorithm {
this.prefixedLogger.debug("Shared keys with existing Olm sessions"); this.prefixedLogger.debug("Shared keys with existing Olm sessions");
})(), })(),
(async (): Promise<void> => { (async (): Promise<void> => {
const deviceList = Object.entries(devicesWithoutSession) const deviceList = Array.from(devicesWithoutSession.entries())
.map(([userId, devicesByUser]) => devicesByUser.map((device) => `${userId}/${device.deviceId}`)) .map(([userId, devicesByUser]) => devicesByUser.map((device) => `${userId}/${device.deviceId}`))
.flat(1); .flat(1);
this.prefixedLogger.debug( this.prefixedLogger.debug(
@ -450,7 +446,7 @@ export class MegolmEncryption extends EncryptionAlgorithm {
// do this in the background and don't block anything else while we // do this in the background and don't block anything else while we
// do this. We only need to retry users from servers that didn't // do this. We only need to retry users from servers that didn't
// respond the first time. // respond the first time.
const retryDevices: Record<string, DeviceInfo[]> = {}; const retryDevices: MapWithDefault<string, DeviceInfo[]> = new MapWithDefault(() => []);
const failedServerMap = new Set(); const failedServerMap = new Set();
for (const server of failedServers) { for (const server of failedServers) {
failedServerMap.add(server); failedServerMap.add(server);
@ -459,8 +455,7 @@ export class MegolmEncryption extends EncryptionAlgorithm {
for (const { userId, deviceInfo } of errorDevices) { for (const { userId, deviceInfo } of errorDevices) {
const userHS = userId.slice(userId.indexOf(":") + 1); const userHS = userId.slice(userId.indexOf(":") + 1);
if (failedServerMap.has(userHS)) { if (failedServerMap.has(userHS)) {
retryDevices[userId] = retryDevices[userId] || []; retryDevices.getOrCreate(userId).push(deviceInfo);
retryDevices[userId].push(deviceInfo);
} else { } else {
// if we aren't going to retry, then handle it // if we aren't going to retry, then handle it
// as a failed device // as a failed device
@ -468,7 +463,7 @@ export class MegolmEncryption extends EncryptionAlgorithm {
} }
} }
const retryDeviceList = Object.entries(retryDevices) const retryDeviceList = Array.from(retryDevices.entries())
.map(([userId, devicesByUser]) => .map(([userId, devicesByUser]) =>
devicesByUser.map((device) => `${userId}/${device.deviceId}`), devicesByUser.map((device) => `${userId}/${device.deviceId}`),
) )
@ -493,25 +488,25 @@ export class MegolmEncryption extends EncryptionAlgorithm {
})(), })(),
(async (): Promise<void> => { (async (): Promise<void> => {
this.prefixedLogger.debug( this.prefixedLogger.debug(
`There are ${Object.entries(blocked).length} blocked devices:`, `There are ${blocked.size} blocked devices:`,
Object.entries(blocked) Array.from(blocked.entries())
.map(([userId, blockedByUser]) => .map(([userId, blockedByUser]) =>
Object.entries(blockedByUser).map(([deviceId, _deviceInfo]) => `${userId}/${deviceId}`), Array.from(blockedByUser.entries()).map(
([deviceId, _deviceInfo]) => `${userId}/${deviceId}`,
),
) )
.flat(1), .flat(1),
); );
// also, notify newly blocked devices that they're blocked // also, notify newly blocked devices that they're blocked
const blockedMap: Record<string, Record<string, { device: IBlockedDevice }>> = {}; const blockedMap: MapWithDefault<string, Map<string, { device: IBlockedDevice }>> = new MapWithDefault(
() => new Map(),
);
let blockedCount = 0; let blockedCount = 0;
for (const [userId, userBlockedDevices] of Object.entries(blocked)) { for (const [userId, userBlockedDevices] of blocked) {
for (const [deviceId, device] of Object.entries(userBlockedDevices)) { for (const [deviceId, device] of userBlockedDevices) {
if ( if (session.blockedDevicesNotified.get(userId)?.get(deviceId) === undefined) {
!session.blockedDevicesNotified[userId] || blockedMap.getOrCreate(userId).set(deviceId, { device });
session.blockedDevicesNotified[userId][deviceId] === undefined
) {
blockedMap[userId] = blockedMap[userId] || {};
blockedMap[userId][deviceId] = { device };
blockedCount++; blockedCount++;
} }
} }
@ -520,7 +515,7 @@ export class MegolmEncryption extends EncryptionAlgorithm {
if (blockedCount) { if (blockedCount) {
this.prefixedLogger.debug( this.prefixedLogger.debug(
`Notifying ${blockedCount} newly blocked devices:`, `Notifying ${blockedCount} newly blocked devices:`,
Object.entries(blockedMap) Array.from(blockedMap.entries())
.map(([userId, blockedByUser]) => .map(([userId, blockedByUser]) =>
Object.entries(blockedByUser).map(([deviceId, _deviceInfo]) => `${userId}/${deviceId}`), Object.entries(blockedByUser).map(([deviceId, _deviceInfo]) => `${userId}/${deviceId}`),
) )
@ -566,7 +561,7 @@ export class MegolmEncryption extends EncryptionAlgorithm {
* *
* @internal * @internal
* *
* @param devicemap - the devices that have olm sessions, as returned by * @param deviceMap - the devices that have olm sessions, as returned by
* olmlib.ensureOlmSessionsForDevices. * olmlib.ensureOlmSessionsForDevices.
* @param devicesByUser - a map of user IDs to array of deviceInfo * @param devicesByUser - a map of user IDs to array of deviceInfo
* @param noOlmDevices - an array to fill with devices that don't have * @param noOlmDevices - an array to fill with devices that don't have
@ -576,23 +571,23 @@ export class MegolmEncryption extends EncryptionAlgorithm {
* noOlmDevices is specified, then noOlmDevices will be returned. * noOlmDevices is specified, then noOlmDevices will be returned.
*/ */
private getDevicesWithoutSessions( private getDevicesWithoutSessions(
devicemap: Record<string, Record<string, IOlmSessionResult>>, deviceMap: Map<string, Map<string, IOlmSessionResult>>,
devicesByUser: Record<string, DeviceInfo[]>, devicesByUser: Map<string, DeviceInfo[]>,
noOlmDevices: IOlmDevice[] = [], noOlmDevices: IOlmDevice[] = [],
): IOlmDevice[] { ): IOlmDevice[] {
for (const [userId, devicesToShareWith] of Object.entries(devicesByUser)) { for (const [userId, devicesToShareWith] of devicesByUser) {
const sessionResults = devicemap[userId]; const sessionResults = deviceMap.get(userId);
for (const deviceInfo of devicesToShareWith) { for (const deviceInfo of devicesToShareWith) {
const deviceId = deviceInfo.deviceId; const deviceId = deviceInfo.deviceId;
const sessionResult = sessionResults[deviceId]; const sessionResult = sessionResults?.get(deviceId);
if (!sessionResult.sessionId) { if (!sessionResult?.sessionId) {
// no session with this device, probably because there // no session with this device, probably because there
// were no one-time keys. // were no one-time keys.
noOlmDevices.push({ userId, deviceInfo }); noOlmDevices.push({ userId, deviceInfo });
delete sessionResults[deviceId]; sessionResults?.delete(deviceId);
// ensureOlmSessionsForUsers has already done the logging, // ensureOlmSessionsForUsers has already done the logging,
// so just skip it. // so just skip it.
@ -615,7 +610,7 @@ export class MegolmEncryption extends EncryptionAlgorithm {
* @returns the blocked devices, split into chunks * @returns the blocked devices, split into chunks
*/ */
private splitDevices<T extends DeviceInfo | IBlockedDevice>( private splitDevices<T extends DeviceInfo | IBlockedDevice>(
devicesByUser: Record<string, Record<string, { device: T }>>, devicesByUser: Map<string, Map<string, { device: T }>>,
): IOlmDevice<T>[][] { ): IOlmDevice<T>[][] {
const maxDevicesPerRequest = 20; const maxDevicesPerRequest = 20;
@ -623,8 +618,8 @@ export class MegolmEncryption extends EncryptionAlgorithm {
let currentSlice: IOlmDevice<T>[] = []; let currentSlice: IOlmDevice<T>[] = [];
const mapSlices = [currentSlice]; const mapSlices = [currentSlice];
for (const [userId, userDevices] of Object.entries(devicesByUser)) { for (const [userId, userDevices] of devicesByUser) {
for (const deviceInfo of Object.values(userDevices)) { for (const deviceInfo of userDevices.values()) {
currentSlice.push({ currentSlice.push({
userId: userId, userId: userId,
deviceInfo: deviceInfo.device, deviceInfo: deviceInfo.device,
@ -702,7 +697,7 @@ export class MegolmEncryption extends EncryptionAlgorithm {
userDeviceMap: IOlmDevice<IBlockedDevice>[], userDeviceMap: IOlmDevice<IBlockedDevice>[],
payload: IPayload, payload: IPayload,
): Promise<void> { ): Promise<void> {
const contentMap: Record<string, Record<string, IPayload>> = {}; const contentMap: MapWithDefault<string, Map<string, IPayload>> = new MapWithDefault(() => new Map());
for (const val of userDeviceMap) { for (const val of userDeviceMap) {
const userId = val.userId; const userId = val.userId;
@ -722,17 +717,14 @@ export class MegolmEncryption extends EncryptionAlgorithm {
delete message.session_id; delete message.session_id;
} }
if (!contentMap[userId]) { contentMap.getOrCreate(userId).set(deviceId, message);
contentMap[userId] = {};
}
contentMap[userId][deviceId] = message;
} }
await this.baseApis.sendToDevice("m.room_key.withheld", contentMap); await this.baseApis.sendToDevice("m.room_key.withheld", contentMap);
// record the fact that we notified these blocked devices // record the fact that we notified these blocked devices
for (const userId of Object.keys(contentMap)) { for (const [userId, userDeviceMap] of contentMap) {
for (const deviceId of Object.keys(contentMap[userId])) { for (const deviceId of userDeviceMap.keys()) {
session.markNotifiedBlockedDevice(userId, deviceId); session.markNotifiedBlockedDevice(userId, deviceId);
} }
} }
@ -760,11 +752,11 @@ export class MegolmEncryption extends EncryptionAlgorithm {
} }
// The chain index of the key we previously sent this device // The chain index of the key we previously sent this device
if (obSessionInfo.sharedWithDevices[userId] === undefined) { if (!obSessionInfo.sharedWithDevices.has(userId)) {
this.prefixedLogger.debug(`megolm session ${senderKey}|${sessionId} never shared with user ${userId}`); this.prefixedLogger.debug(`megolm session ${senderKey}|${sessionId} never shared with user ${userId}`);
return; return;
} }
const sessionSharedData = obSessionInfo.sharedWithDevices[userId][device.deviceId]; const sessionSharedData = obSessionInfo.sharedWithDevices.get(userId)?.get(device.deviceId);
if (sessionSharedData === undefined) { if (sessionSharedData === undefined) {
this.prefixedLogger.debug( this.prefixedLogger.debug(
`megolm session ${senderKey}|${sessionId} never shared with device ${userId}:${device.deviceId}`, `megolm session ${senderKey}|${sessionId} never shared with device ${userId}:${device.deviceId}`,
@ -796,9 +788,7 @@ export class MegolmEncryption extends EncryptionAlgorithm {
return; return;
} }
await olmlib.ensureOlmSessionsForDevices(this.olmDevice, this.baseApis, { await olmlib.ensureOlmSessionsForDevices(this.olmDevice, this.baseApis, new Map([[userId, [device]]]));
[userId]: [device],
});
const payload = { const payload = {
type: "m.forwarded_room_key", type: "m.forwarded_room_key",
@ -831,11 +821,10 @@ export class MegolmEncryption extends EncryptionAlgorithm {
payload, payload,
); );
await this.baseApis.sendToDevice("m.room.encrypted", { await this.baseApis.sendToDevice(
[userId]: { "m.room.encrypted",
[device.deviceId]: encryptedContent, new Map([[userId, new Map([[device.deviceId, encryptedContent]])]]),
}, );
});
this.prefixedLogger.debug( this.prefixedLogger.debug(
`Re-shared key for megolm session ${senderKey}|${sessionId} with ${userId}:${device.deviceId}`, `Re-shared key for megolm session ${senderKey}|${sessionId} with ${userId}:${device.deviceId}`,
); );
@ -865,7 +854,7 @@ export class MegolmEncryption extends EncryptionAlgorithm {
session: OutboundSessionInfo, session: OutboundSessionInfo,
key: IOutboundGroupSessionKey, key: IOutboundGroupSessionKey,
payload: IPayload, payload: IPayload,
devicesByUser: Record<string, DeviceInfo[]>, devicesByUser: Map<string, DeviceInfo[]>,
errorDevices: IOlmDevice[], errorDevices: IOlmDevice[],
otkTimeout: number, otkTimeout: number,
failedServers?: string[], failedServers?: string[],
@ -887,9 +876,9 @@ export class MegolmEncryption extends EncryptionAlgorithm {
session: OutboundSessionInfo, session: OutboundSessionInfo,
key: IOutboundGroupSessionKey, key: IOutboundGroupSessionKey,
payload: IPayload, payload: IPayload,
devicemap: Record<string, Record<string, IOlmSessionResult>>, deviceMap: Map<string, Map<string, IOlmSessionResult>>,
): Promise<void> { ): Promise<void> {
const userDeviceMaps = this.splitDevices(devicemap); const userDeviceMaps = this.splitDevices(deviceMap);
for (let i = 0; i < userDeviceMaps.length; i++) { for (let i = 0; i < userDeviceMaps.length; i++) {
const taskDetail = `megolm keys for ${session.sessionId} (slice ${i + 1}/${userDeviceMaps.length})`; const taskDetail = `megolm keys for ${session.sessionId} (slice ${i + 1}/${userDeviceMaps.length})`;
@ -934,19 +923,20 @@ export class MegolmEncryption extends EncryptionAlgorithm {
this.prefixedLogger.debug( this.prefixedLogger.debug(
`Need to notify ${unnotifiedFailedDevices.length} failed devices which haven't been notified before`, `Need to notify ${unnotifiedFailedDevices.length} failed devices which haven't been notified before`,
); );
const blockedMap: Record<string, Record<string, { device: IBlockedDevice }>> = {}; const blockedMap: MapWithDefault<string, Map<string, { device: IBlockedDevice }>> = new MapWithDefault(
() => new Map(),
);
for (const { userId, deviceInfo } of unnotifiedFailedDevices) { for (const { userId, deviceInfo } of unnotifiedFailedDevices) {
blockedMap[userId] = blockedMap[userId] || {};
// we use a similar format to what // we use a similar format to what
// olmlib.ensureOlmSessionsForDevices returns, so that // olmlib.ensureOlmSessionsForDevices returns, so that
// we can use the same function to split // we can use the same function to split
blockedMap[userId][deviceInfo.deviceId] = { blockedMap.getOrCreate(userId).set(deviceInfo.deviceId, {
device: { device: {
code: "m.no_olm", code: "m.no_olm",
reason: WITHHELD_MESSAGES["m.no_olm"], reason: WITHHELD_MESSAGES["m.no_olm"],
deviceInfo, deviceInfo,
}, },
}; });
} }
// send the notifications // send the notifications
@ -964,7 +954,7 @@ export class MegolmEncryption extends EncryptionAlgorithm {
*/ */
private async notifyBlockedDevices( private async notifyBlockedDevices(
session: OutboundSessionInfo, session: OutboundSessionInfo,
devicesByUser: Record<string, Record<string, { device: IBlockedDevice }>>, devicesByUser: Map<string, Map<string, { device: IBlockedDevice }>>,
): Promise<void> { ): Promise<void> {
const payload: IPayload = { const payload: IPayload = {
room_id: this.roomId, room_id: this.roomId,
@ -1154,21 +1144,17 @@ export class MegolmEncryption extends EncryptionAlgorithm {
* devices we should shared the session with. * devices we should shared the session with.
*/ */
private checkForUnknownDevices(devicesInRoom: DeviceInfoMap): void { private checkForUnknownDevices(devicesInRoom: DeviceInfoMap): void {
const unknownDevices: Record<string, Record<string, DeviceInfo>> = {}; const unknownDevices: MapWithDefault<string, Map<string, DeviceInfo>> = new MapWithDefault(() => new Map());
Object.keys(devicesInRoom).forEach((userId) => { for (const [userId, userDevices] of devicesInRoom) {
Object.keys(devicesInRoom[userId]).forEach((deviceId) => { for (const [deviceId, device] of userDevices) {
const device = devicesInRoom[userId][deviceId];
if (device.isUnverified() && !device.isKnown()) { if (device.isUnverified() && !device.isKnown()) {
if (!unknownDevices[userId]) { unknownDevices.getOrCreate(userId).set(deviceId, device);
unknownDevices[userId] = {};
}
unknownDevices[userId][deviceId] = device;
} }
}); }
}); }
if (Object.keys(unknownDevices).length) { if (unknownDevices.size) {
// it'd be kind to pass unknownDevices up to the user in this error // it'd be kind to pass unknownDevices up to the user in this error
throw new UnknownDeviceError( throw new UnknownDeviceError(
"This room contains unknown devices which have not been verified. " + "This room contains unknown devices which have not been verified. " +
@ -1186,15 +1172,15 @@ export class MegolmEncryption extends EncryptionAlgorithm {
* devices we should shared the session with. * devices we should shared the session with.
*/ */
private removeUnknownDevices(devicesInRoom: DeviceInfoMap): void { private removeUnknownDevices(devicesInRoom: DeviceInfoMap): void {
for (const [userId, userDevices] of Object.entries(devicesInRoom)) { for (const [userId, userDevices] of devicesInRoom) {
for (const [deviceId, device] of Object.entries(userDevices)) { for (const [deviceId, device] of userDevices) {
if (device.isUnverified() && !device.isKnown()) { if (device.isUnverified() && !device.isKnown()) {
delete userDevices[deviceId]; userDevices.delete(deviceId);
} }
} }
if (Object.keys(userDevices).length === 0) { if (userDevices.size === 0) {
delete devicesInRoom[userId]; devicesInRoom.delete(userId);
} }
} }
} }
@ -1219,17 +1205,17 @@ export class MegolmEncryption extends EncryptionAlgorithm {
private async getDevicesInRoom( private async getDevicesInRoom(
room: Room, room: Room,
forceDistributeToUnverified?: boolean, forceDistributeToUnverified?: boolean,
): Promise<[DeviceInfoMap, IBlockedMap]>; ): Promise<[DeviceInfoMap, BlockedMap]>;
private async getDevicesInRoom( private async getDevicesInRoom(
room: Room, room: Room,
forceDistributeToUnverified?: boolean, forceDistributeToUnverified?: boolean,
isCancelled?: () => boolean, isCancelled?: () => boolean,
): Promise<null | [DeviceInfoMap, IBlockedMap]>; ): Promise<null | [DeviceInfoMap, BlockedMap]>;
private async getDevicesInRoom( private async getDevicesInRoom(
room: Room, room: Room,
forceDistributeToUnverified = false, forceDistributeToUnverified = false,
isCancelled?: () => boolean, isCancelled?: () => boolean,
): Promise<null | [DeviceInfoMap, IBlockedMap]> { ): Promise<null | [DeviceInfoMap, BlockedMap]> {
const members = await room.getEncryptionTargetMembers(); const members = await room.getEncryptionTargetMembers();
this.prefixedLogger.debug( this.prefixedLogger.debug(
`Encrypting for users (shouldEncryptForInvitedMembers: ${room.shouldEncryptForInvitedMembers()}):`, `Encrypting for users (shouldEncryptForInvitedMembers: ${room.shouldEncryptForInvitedMembers()}):`,
@ -1254,24 +1240,15 @@ export class MegolmEncryption extends EncryptionAlgorithm {
// using all the device_lists changes and left fields. // using all the device_lists changes and left fields.
// See https://github.com/vector-im/element-web/issues/2305 for details. // See https://github.com/vector-im/element-web/issues/2305 for details.
const devices = await this.crypto.downloadKeys(roomMembers, false); const devices = await this.crypto.downloadKeys(roomMembers, false);
const blocked: IBlockedMap = {};
if (isCancelled?.() === true) { if (isCancelled?.() === true) {
return null; return null;
} }
const blocked = new MapWithDefault<string, Map<string, IBlockedDevice>>(() => new Map());
// remove any blocked devices // remove any blocked devices
for (const userId in devices) { for (const [userId, userDevices] of devices) {
if (!devices.hasOwnProperty(userId)) { for (const [deviceId, userDevice] of userDevices) {
continue;
}
const userDevices = devices[userId];
for (const deviceId in userDevices) {
if (!userDevices.hasOwnProperty(deviceId)) {
continue;
}
// Yield prior to checking each device so that we don't block // Yield prior to checking each device so that we don't block
// updating/rendering for too long. // updating/rendering for too long.
// See https://github.com/vector-im/element-web/issues/21612 // See https://github.com/vector-im/element-web/issues/21612
@ -1280,19 +1257,17 @@ export class MegolmEncryption extends EncryptionAlgorithm {
const deviceTrust = this.crypto.checkDeviceTrust(userId, deviceId); const deviceTrust = this.crypto.checkDeviceTrust(userId, deviceId);
if ( if (
userDevices[deviceId].isBlocked() || userDevice.isBlocked() ||
(!deviceTrust.isVerified() && isBlacklisting && !forceDistributeToUnverified) (!deviceTrust.isVerified() && isBlacklisting && !forceDistributeToUnverified)
) { ) {
if (!blocked[userId]) { const blockedDevices = blocked.getOrCreate(userId);
blocked[userId] = {}; const isBlocked = userDevice.isBlocked();
} blockedDevices.set(deviceId, {
const isBlocked = userDevices[deviceId].isBlocked();
blocked[userId][deviceId] = {
code: isBlocked ? "m.blacklisted" : "m.unverified", code: isBlocked ? "m.blacklisted" : "m.unverified",
reason: WITHHELD_MESSAGES[isBlocked ? "m.blacklisted" : "m.unverified"], reason: WITHHELD_MESSAGES[isBlocked ? "m.blacklisted" : "m.unverified"],
deviceInfo: userDevices[deviceId], deviceInfo: userDevice,
}; });
delete userDevices[deviceId]; userDevices.delete(deviceId);
} }
} }
} }
@ -1923,7 +1898,7 @@ export class MegolmDecryption extends DecryptionAlgorithm {
// XXX: switch this to use encryptAndSendToDevices() rather than duplicating it? // XXX: switch this to use encryptAndSendToDevices() rather than duplicating it?
await olmlib.ensureOlmSessionsForDevices(this.olmDevice, this.baseApis, { [sender]: [device] }, false); await olmlib.ensureOlmSessionsForDevices(this.olmDevice, this.baseApis, new Map([[sender, [device]]]), false);
const encryptedContent: IEncryptedContent = { const encryptedContent: IEncryptedContent = {
algorithm: olmlib.OLM_ALGORITHM, algorithm: olmlib.OLM_ALGORITHM,
sender_key: this.olmDevice.deviceCurve25519Key!, sender_key: this.olmDevice.deviceCurve25519Key!,
@ -1942,11 +1917,10 @@ export class MegolmDecryption extends DecryptionAlgorithm {
await this.olmDevice.recordSessionProblem(senderKey, "no_olm", true); await this.olmDevice.recordSessionProblem(senderKey, "no_olm", true);
await this.baseApis.sendToDevice("m.room.encrypted", { await this.baseApis.sendToDevice(
[sender]: { "m.room.encrypted",
[device.deviceId]: encryptedContent, new Map([[sender, new Map([[device.deviceId, encryptedContent]])]]),
}, );
});
} }
public hasKeysForKeyRequest(keyRequest: IncomingRoomKeyRequest): Promise<boolean> { public hasKeysForKeyRequest(keyRequest: IncomingRoomKeyRequest): Promise<boolean> {
@ -1969,12 +1943,10 @@ export class MegolmDecryption extends DecryptionAlgorithm {
// XXX: switch this to use encryptAndSendToDevices()? // XXX: switch this to use encryptAndSendToDevices()?
this.olmlib this.olmlib
.ensureOlmSessionsForDevices(this.olmDevice, this.baseApis, { .ensureOlmSessionsForDevices(this.olmDevice, this.baseApis, new Map([[userId, [deviceInfo]]]))
[userId]: [deviceInfo],
})
.then((devicemap) => { .then((devicemap) => {
const olmSessionResult = devicemap[userId][deviceId]; const olmSessionResult = devicemap.get(userId)?.get(deviceId);
if (!olmSessionResult.sessionId) { if (!olmSessionResult?.sessionId) {
// no session with this device, probably because there // no session with this device, probably because there
// were no one-time keys. // were no one-time keys.
// //
@ -2015,14 +1987,11 @@ export class MegolmDecryption extends DecryptionAlgorithm {
payload!, payload!,
) )
.then(() => { .then(() => {
const contentMap = {
[userId]: {
[deviceId]: encryptedContent,
},
};
// TODO: retries // TODO: retries
return this.baseApis.sendToDevice("m.room.encrypted", contentMap); return this.baseApis.sendToDevice(
"m.room.encrypted",
new Map([[userId, new Map([[deviceId, encryptedContent]])]]),
);
}); });
}); });
} }
@ -2162,12 +2131,12 @@ export class MegolmDecryption extends DecryptionAlgorithm {
return !this.pendingEvents.has(senderKey); return !this.pendingEvents.has(senderKey);
} }
public async sendSharedHistoryInboundSessions(devicesByUser: Record<string, DeviceInfo[]>): Promise<void> { public async sendSharedHistoryInboundSessions(devicesByUser: Map<string, DeviceInfo[]>): Promise<void> {
await olmlib.ensureOlmSessionsForDevices(this.olmDevice, this.baseApis, devicesByUser); await olmlib.ensureOlmSessionsForDevices(this.olmDevice, this.baseApis, devicesByUser);
const sharedHistorySessions = await this.olmDevice.getSharedHistoryInboundGroupSessions(this.roomId); const sharedHistorySessions = await this.olmDevice.getSharedHistoryInboundGroupSessions(this.roomId);
this.prefixedLogger.log( this.prefixedLogger.log(
`Sharing history in with users ${Object.keys(devicesByUser)}`, `Sharing history in with users ${Array.from(devicesByUser.keys())}`,
sharedHistorySessions.map(([senderKey, sessionId]) => `${senderKey}|${sessionId}`), sharedHistorySessions.map(([senderKey, sessionId]) => `${senderKey}|${sessionId}`),
); );
for (const [senderKey, sessionId] of sharedHistorySessions) { for (const [senderKey, sessionId] of sharedHistorySessions) {
@ -2175,9 +2144,10 @@ export class MegolmDecryption extends DecryptionAlgorithm {
// FIXME: use encryptAndSendToDevices() rather than duplicating it here. // FIXME: use encryptAndSendToDevices() rather than duplicating it here.
const promises: Promise<unknown>[] = []; const promises: Promise<unknown>[] = [];
const contentMap: Record<string, Record<string, IEncryptedContent>> = {}; const contentMap: Map<string, Map<string, IEncryptedContent>> = new Map();
for (const [userId, devices] of Object.entries(devicesByUser)) { for (const [userId, devices] of devicesByUser) {
contentMap[userId] = {}; const deviceMessages = new Map();
contentMap.set(userId, deviceMessages);
for (const deviceInfo of devices) { for (const deviceInfo of devices) {
const encryptedContent: IEncryptedContent = { const encryptedContent: IEncryptedContent = {
algorithm: olmlib.OLM_ALGORITHM, algorithm: olmlib.OLM_ALGORITHM,
@ -2185,7 +2155,7 @@ export class MegolmDecryption extends DecryptionAlgorithm {
ciphertext: {}, ciphertext: {},
[ToDeviceMessageId]: uuidv4(), [ToDeviceMessageId]: uuidv4(),
}; };
contentMap[userId][deviceInfo.deviceId] = encryptedContent; deviceMessages.set(deviceInfo.deviceId, encryptedContent);
promises.push( promises.push(
olmlib.encryptMessageForDevice( olmlib.encryptMessageForDevice(
encryptedContent.ciphertext, encryptedContent.ciphertext,
@ -2205,22 +2175,22 @@ export class MegolmDecryption extends DecryptionAlgorithm {
// in which case it will have just not added anything to the ciphertext object. // in which case it will have just not added anything to the ciphertext object.
// There's no point sending messages to devices if we couldn't encrypt to them, // There's no point sending messages to devices if we couldn't encrypt to them,
// since that's effectively a blank message. // since that's effectively a blank message.
for (const userId of Object.keys(contentMap)) { for (const [userId, deviceMessages] of contentMap) {
for (const deviceId of Object.keys(contentMap[userId])) { for (const [deviceId, content] of deviceMessages) {
if (Object.keys(contentMap[userId][deviceId].ciphertext).length === 0) { if (!hasCiphertext(content)) {
this.prefixedLogger.log("No ciphertext for device " + userId + ":" + deviceId + ": pruning"); this.prefixedLogger.log("No ciphertext for device " + userId + ":" + deviceId + ": pruning");
delete contentMap[userId][deviceId]; deviceMessages.delete(deviceId);
} }
} }
// No devices left for that user? Strip that too. // No devices left for that user? Strip that too.
if (Object.keys(contentMap[userId]).length === 0) { if (deviceMessages.size === 0) {
this.prefixedLogger.log("Pruned all devices for user " + userId); this.prefixedLogger.log("Pruned all devices for user " + userId);
delete contentMap[userId]; contentMap.delete(userId);
} }
} }
// Is there anything left? // Is there anything left?
if (Object.keys(contentMap).length === 0) { if (contentMap.size === 0) {
this.prefixedLogger.log("No users left to send to: aborting"); this.prefixedLogger.log("No users left to send to: aborting");
return; return;
} }

View File

@ -25,7 +25,7 @@ import { MEGOLM_ALGORITHM, verifySignature } from "./olmlib";
import { DeviceInfo } from "./deviceinfo"; import { DeviceInfo } from "./deviceinfo";
import { DeviceTrustLevel } from "./CrossSigning"; import { DeviceTrustLevel } from "./CrossSigning";
import { keyFromPassphrase } from "./key_passphrase"; import { keyFromPassphrase } from "./key_passphrase";
import { sleep } from "../utils"; import { safeSet, sleep } from "../utils";
import { IndexedDBCryptoStore } from "./store/indexeddb-crypto-store"; import { IndexedDBCryptoStore } from "./store/indexeddb-crypto-store";
import { encodeRecoveryKey } from "./recoverykey"; import { encodeRecoveryKey } from "./recoverykey";
import { calculateKeyCheck, decryptAES, encryptAES, IEncryptedPayload } from "./aes"; import { calculateKeyCheck, decryptAES, encryptAES, IEncryptedPayload } from "./aes";
@ -498,9 +498,7 @@ export class BackupManager {
const rooms: IKeyBackup["rooms"] = {}; const rooms: IKeyBackup["rooms"] = {};
for (const session of sessions) { for (const session of sessions) {
const roomId = session.sessionData!.room_id; const roomId = session.sessionData!.room_id;
if (rooms[roomId] === undefined) { safeSet(rooms, roomId, rooms[roomId] || { sessions: {} });
rooms[roomId] = { sessions: {} };
}
const sessionData = this.baseApis.crypto!.olmDevice.exportInboundGroupSession( const sessionData = this.baseApis.crypto!.olmDevice.exportInboundGroupSession(
session.senderKey, session.senderKey,
@ -517,12 +515,12 @@ export class BackupManager {
undefined; undefined;
const verified = this.baseApis.crypto!.checkDeviceInfoTrust(userId!, device).isVerified(); const verified = this.baseApis.crypto!.checkDeviceInfoTrust(userId!, device).isVerified();
rooms[roomId]["sessions"][session.sessionId] = { safeSet(rooms[roomId]["sessions"], session.sessionId, {
first_message_index: sessionData.first_known_index, first_message_index: sessionData.first_known_index,
forwarded_count: forwardedCount, forwarded_count: forwardedCount,
is_verified: verified, is_verified: verified,
session_data: await this.algorithm!.encryptSession(sessionData), session_data: await this.algorithm!.encryptSession(sessionData),
}; });
} }
await this.baseApis.sendKeyBackup(undefined, undefined, this.backupInfo!.version, { rooms }); await this.baseApis.sendKeyBackup(undefined, undefined, this.backupInfo!.version, { rooms });

View File

@ -90,6 +90,7 @@ import { ISignatures } from "../@types/signed";
import { IMessage } from "./algorithms/olm"; import { IMessage } from "./algorithms/olm";
import { CryptoBackend, OnSyncCompletedData } from "../common-crypto/CryptoBackend"; import { CryptoBackend, OnSyncCompletedData } from "../common-crypto/CryptoBackend";
import { RoomState, RoomStateEvent } from "../models/room-state"; import { RoomState, RoomStateEvent } from "../models/room-state";
import { MapWithDefault, recursiveMapToObject } from "../utils";
const DeviceVerification = DeviceInfo.DeviceVerification; const DeviceVerification = DeviceInfo.DeviceVerification;
@ -399,7 +400,10 @@ export class Crypto extends TypedEventEmitter<CryptoEvent, CryptoEventHandlerMap
// deviceId: 1234567890000, // deviceId: 1234567890000,
// }, // },
// } // }
private lastNewSessionForced: Record<string, Record<string, number>> = {}; // Map: user Id → device Id → timestamp
private lastNewSessionForced: MapWithDefault<string, MapWithDefault<string, number>> = new MapWithDefault(
() => new MapWithDefault(() => 0),
);
// This flag will be unset whilst the client processes a sync response // This flag will be unset whilst the client processes a sync response
// so that we don't start requesting keys until we've actually finished // so that we don't start requesting keys until we've actually finished
@ -2690,11 +2694,13 @@ export class Crypto extends TypedEventEmitter<CryptoEvent, CryptoEventHandlerMap
public ensureOlmSessionsForUsers( public ensureOlmSessionsForUsers(
users: string[], users: string[],
force?: boolean, force?: boolean,
): Promise<Record<string, Record<string, olmlib.IOlmSessionResult>>> { ): Promise<Map<string, Map<string, olmlib.IOlmSessionResult>>> {
const devicesByUser: Record<string, DeviceInfo[]> = {}; // map user Id → DeviceInfo[]
const devicesByUser: Map<string, DeviceInfo[]> = new Map();
for (const userId of users) { for (const userId of users) {
devicesByUser[userId] = []; const userDevices: DeviceInfo[] = [];
devicesByUser.set(userId, userDevices);
const devices = this.getStoredDevicesForUser(userId) || []; const devices = this.getStoredDevicesForUser(userId) || [];
for (const deviceInfo of devices) { for (const deviceInfo of devices) {
@ -2708,7 +2714,7 @@ export class Crypto extends TypedEventEmitter<CryptoEvent, CryptoEventHandlerMap
continue; continue;
} }
devicesByUser[userId].push(deviceInfo); userDevices.push(deviceInfo);
} }
} }
@ -3146,7 +3152,11 @@ export class Crypto extends TypedEventEmitter<CryptoEvent, CryptoEventHandlerMap
payload: encryptedContent, payload: encryptedContent,
}); });
await olmlib.ensureOlmSessionsForDevices(this.olmDevice, this.baseApis, { [userId]: [deviceInfo] }); await olmlib.ensureOlmSessionsForDevices(
this.olmDevice,
this.baseApis,
new Map([[userId, [deviceInfo]]]),
);
await olmlib.encryptMessageForDevice( await olmlib.encryptMessageForDevice(
encryptedContent.ciphertext, encryptedContent.ciphertext,
this.userId, this.userId,
@ -3459,8 +3469,8 @@ export class Crypto extends TypedEventEmitter<CryptoEvent, CryptoEventHandlerMap
// check when we last forced a new session with this device: if we've already done so // check when we last forced a new session with this device: if we've already done so
// recently, don't do it again. // recently, don't do it again.
this.lastNewSessionForced[sender] = this.lastNewSessionForced[sender] || {}; const lastNewSessionDevices = this.lastNewSessionForced.getOrCreate(sender);
const lastNewSessionForced = this.lastNewSessionForced[sender][deviceKey] || 0; const lastNewSessionForced = lastNewSessionDevices.getOrCreate(deviceKey);
if (lastNewSessionForced + MIN_FORCE_SESSION_INTERVAL_MS > Date.now()) { if (lastNewSessionForced + MIN_FORCE_SESSION_INTERVAL_MS > Date.now()) {
logger.debug( logger.debug(
"New session already forced with device " + "New session already forced with device " +
@ -3493,11 +3503,10 @@ export class Crypto extends TypedEventEmitter<CryptoEvent, CryptoEventHandlerMap
return; return;
} }
} }
const devicesByUser: Record<string, DeviceInfo[]> = {}; const devicesByUser = new Map([[sender, [device]]]);
devicesByUser[sender] = [device];
await olmlib.ensureOlmSessionsForDevices(this.olmDevice, this.baseApis, devicesByUser, true); await olmlib.ensureOlmSessionsForDevices(this.olmDevice, this.baseApis, devicesByUser, true);
this.lastNewSessionForced[sender][deviceKey] = Date.now(); lastNewSessionDevices.set(deviceKey, Date.now());
// Now send a blank message on that session so the other side knows about it. // Now send a blank message on that session so the other side knows about it.
// (The keyshare request is sent in the clear so that won't do) // (The keyshare request is sent in the clear so that won't do)
@ -3524,11 +3533,10 @@ export class Crypto extends TypedEventEmitter<CryptoEvent, CryptoEventHandlerMap
await this.olmDevice.recordSessionProblem(deviceKey, "wedged", true); await this.olmDevice.recordSessionProblem(deviceKey, "wedged", true);
retryDecryption(); retryDecryption();
await this.baseApis.sendToDevice("m.room.encrypted", { await this.baseApis.sendToDevice(
[sender]: { "m.room.encrypted",
[device.deviceId]: encryptedContent, new Map([[sender, new Map([[device.deviceId, encryptedContent]])]]),
}, );
});
// Most of the time this probably won't be necessary since we'll have queued up a key request when // Most of the time this probably won't be necessary since we'll have queued up a key request when
// we failed to decrypt the message and will be waiting a bit for the key to arrive before sending // we failed to decrypt the message and will be waiting a bit for the key to arrive before sending
@ -3835,15 +3843,16 @@ export class Crypto extends TypedEventEmitter<CryptoEvent, CryptoEventHandlerMap
* @param obj - Object to which we will add a 'signatures' property * @param obj - Object to which we will add a 'signatures' property
*/ */
public async signObject<T extends ISignableObject & object>(obj: T): Promise<void> { public async signObject<T extends ISignableObject & object>(obj: T): Promise<void> {
const sigs = obj.signatures || {}; const sigs = new Map(Object.entries(obj.signatures || {}));
const unsigned = obj.unsigned; const unsigned = obj.unsigned;
delete obj.signatures; delete obj.signatures;
delete obj.unsigned; delete obj.unsigned;
sigs[this.userId] = sigs[this.userId] || {}; const userSignatures = sigs.get(this.userId) || {};
sigs[this.userId]["ed25519:" + this.deviceId] = await this.olmDevice.sign(anotherjson.stringify(obj)); sigs.set(this.userId, userSignatures);
obj.signatures = sigs; userSignatures["ed25519:" + this.deviceId] = await this.olmDevice.sign(anotherjson.stringify(obj));
obj.signatures = recursiveMapToObject(sigs);
if (unsigned !== undefined) obj.unsigned = unsigned; if (unsigned !== undefined) obj.unsigned = unsigned;
} }
} }

View File

@ -30,6 +30,7 @@ import { ISignatures } from "../@types/signed";
import { MatrixEvent } from "../models/event"; import { MatrixEvent } from "../models/event";
import { EventType } from "../@types/event"; import { EventType } from "../@types/event";
import { IMessage } from "./algorithms/olm"; import { IMessage } from "./algorithms/olm";
import { MapWithDefault } from "../utils";
enum Algorithm { enum Algorithm {
Olm = "m.olm.v1.curve25519-aes-sha2", Olm = "m.olm.v1.curve25519-aes-sha2",
@ -154,9 +155,11 @@ export async function getExistingOlmSessions(
olmDevice: OlmDevice, olmDevice: OlmDevice,
baseApis: MatrixClient, baseApis: MatrixClient,
devicesByUser: Record<string, DeviceInfo[]>, devicesByUser: Record<string, DeviceInfo[]>,
): Promise<[Record<string, DeviceInfo[]>, Record<string, Record<string, IExistingOlmSession>>]> { ): Promise<[Map<string, DeviceInfo[]>, Map<string, Map<string, IExistingOlmSession>>]> {
const devicesWithoutSession: { [userId: string]: DeviceInfo[] } = {}; // map user Id → DeviceInfo[]
const sessions: { [userId: string]: { [deviceId: string]: IExistingOlmSession } } = {}; const devicesWithoutSession: MapWithDefault<string, DeviceInfo[]> = new MapWithDefault(() => []);
// map user Id → device Id → IExistingOlmSession
const sessions: MapWithDefault<string, Map<string, IExistingOlmSession>> = new MapWithDefault(() => new Map());
const promises: Promise<void>[] = []; const promises: Promise<void>[] = [];
@ -168,14 +171,12 @@ export async function getExistingOlmSessions(
(async (): Promise<void> => { (async (): Promise<void> => {
const sessionId = await olmDevice.getSessionIdForDevice(key, true); const sessionId = await olmDevice.getSessionIdForDevice(key, true);
if (sessionId === null) { if (sessionId === null) {
devicesWithoutSession[userId] = devicesWithoutSession[userId] || []; devicesWithoutSession.getOrCreate(userId).push(deviceInfo);
devicesWithoutSession[userId].push(deviceInfo);
} else { } else {
sessions[userId] = sessions[userId] || {}; sessions.getOrCreate(userId).set(deviceId, {
sessions[userId][deviceId] = {
device: deviceInfo, device: deviceInfo,
sessionId: sessionId, sessionId: sessionId,
}; });
} }
})(), })(),
); );
@ -210,24 +211,26 @@ export async function getExistingOlmSessions(
export async function ensureOlmSessionsForDevices( export async function ensureOlmSessionsForDevices(
olmDevice: OlmDevice, olmDevice: OlmDevice,
baseApis: MatrixClient, baseApis: MatrixClient,
devicesByUser: Record<string, DeviceInfo[]>, devicesByUser: Map<string, DeviceInfo[]>,
force = false, force = false,
otkTimeout?: number, otkTimeout?: number,
failedServers?: string[], failedServers?: string[],
log = logger, log = logger,
): Promise<Record<string, Record<string, IOlmSessionResult>>> { ): Promise<Map<string, Map<string, IOlmSessionResult>>> {
const devicesWithoutSession: [string, string][] = [ const devicesWithoutSession: [string, string][] = [
// [userId, deviceId], ... // [userId, deviceId], ...
]; ];
const result: { [userId: string]: { [deviceId: string]: IExistingOlmSession } } = {}; // map user Iddevice Id → IExistingOlmSession
const resolveSession: Record<string, (sessionId?: string) => void> = {}; const result: Map<string, Map<string, IExistingOlmSession>> = new Map();
// map device key → resolve session fn
const resolveSession: Map<string, (sessionId?: string) => void> = new Map();
// Mark all sessions this task intends to update as in progress. It is // Mark all sessions this task intends to update as in progress. It is
// important to do this for all devices this task cares about in a single // important to do this for all devices this task cares about in a single
// synchronous operation, as otherwise it is possible to have deadlocks // synchronous operation, as otherwise it is possible to have deadlocks
// where multiple tasks wait indefinitely on another task to update some set // where multiple tasks wait indefinitely on another task to update some set
// of common devices. // of common devices.
for (const [, devices] of Object.entries(devicesByUser)) { for (const devices of devicesByUser.values()) {
for (const deviceInfo of devices) { for (const deviceInfo of devices) {
const key = deviceInfo.getIdentityKey(); const key = deviceInfo.getIdentityKey();
@ -242,17 +245,19 @@ export async function ensureOlmSessionsForDevices(
// conditions. If we find that we already have a session, then // conditions. If we find that we already have a session, then
// we'll resolve // we'll resolve
olmDevice.sessionsInProgress[key] = new Promise((resolve) => { olmDevice.sessionsInProgress[key] = new Promise((resolve) => {
resolveSession[key] = (v: any): void => { resolveSession.set(key, (v: any): void => {
delete olmDevice.sessionsInProgress[key]; delete olmDevice.sessionsInProgress[key];
resolve(v); resolve(v);
}; });
}); });
} }
} }
} }
for (const [userId, devices] of Object.entries(devicesByUser)) { for (const [userId, devices] of devicesByUser) {
result[userId] = {}; const resultDevices = new Map();
result.set(userId, resultDevices);
for (const deviceInfo of devices) { for (const deviceInfo of devices) {
const deviceId = deviceInfo.deviceId; const deviceId = deviceInfo.deviceId;
const key = deviceInfo.getIdentityKey(); const key = deviceInfo.getIdentityKey();
@ -268,20 +273,21 @@ export async function ensureOlmSessionsForDevices(
log.info("Attempted to start session with ourself! Ignoring"); log.info("Attempted to start session with ourself! Ignoring");
// We must fill in the section in the return value though, as callers // We must fill in the section in the return value though, as callers
// expect it to be there. // expect it to be there.
result[userId][deviceId] = { resultDevices.set(deviceId, {
device: deviceInfo, device: deviceInfo,
sessionId: null, sessionId: null,
}; });
continue; continue;
} }
const forWhom = `for ${key} (${userId}:${deviceId})`; const forWhom = `for ${key} (${userId}:${deviceId})`;
const sessionId = await olmDevice.getSessionIdForDevice(key, !!resolveSession[key], log); const sessionId = await olmDevice.getSessionIdForDevice(key, !!resolveSession.get(key), log);
if (sessionId !== null && resolveSession[key]) { const resolveSessionFn = resolveSession.get(key);
if (sessionId !== null && resolveSessionFn) {
// we found a session, but we had marked the session as // we found a session, but we had marked the session as
// in-progress, so resolve it now, which will unmark it and // in-progress, so resolve it now, which will unmark it and
// unblock anything that was waiting // unblock anything that was waiting
resolveSession[key](); resolveSessionFn();
} }
if (sessionId === null || force) { if (sessionId === null || force) {
if (force) { if (force) {
@ -291,10 +297,10 @@ export async function ensureOlmSessionsForDevices(
} }
devicesWithoutSession.push([userId, deviceId]); devicesWithoutSession.push([userId, deviceId]);
} }
result[userId][deviceId] = { resultDevices.set(deviceId, {
device: deviceInfo, device: deviceInfo,
sessionId: sessionId, sessionId: sessionId,
}; });
} }
} }
@ -310,7 +316,7 @@ export async function ensureOlmSessionsForDevices(
res = await baseApis.claimOneTimeKeys(devicesWithoutSession, oneTimeKeyAlgorithm, otkTimeout); res = await baseApis.claimOneTimeKeys(devicesWithoutSession, oneTimeKeyAlgorithm, otkTimeout);
log.debug(`Claimed ${taskDetail}`); log.debug(`Claimed ${taskDetail}`);
} catch (e) { } catch (e) {
for (const resolver of Object.values(resolveSession)) { for (const resolver of resolveSession.values()) {
resolver(); resolver();
} }
log.log(`Failed to claim ${taskDetail}`, e, devicesWithoutSession); log.log(`Failed to claim ${taskDetail}`, e, devicesWithoutSession);
@ -323,7 +329,7 @@ export async function ensureOlmSessionsForDevices(
const otkResult = res.one_time_keys || ({} as IClaimOTKsResult["one_time_keys"]); const otkResult = res.one_time_keys || ({} as IClaimOTKsResult["one_time_keys"]);
const promises: Promise<void>[] = []; const promises: Promise<void>[] = [];
for (const [userId, devices] of Object.entries(devicesByUser)) { for (const [userId, devices] of devicesByUser) {
const userRes = otkResult[userId] || {}; const userRes = otkResult[userId] || {};
for (const deviceInfo of devices) { for (const deviceInfo of devices) {
const deviceId = deviceInfo.deviceId; const deviceId = deviceInfo.deviceId;
@ -336,7 +342,7 @@ export async function ensureOlmSessionsForDevices(
continue; continue;
} }
if (result[userId][deviceId].sessionId && !force) { if (result.get(userId)?.get(deviceId)?.sessionId && !force) {
// we already have a result for this device // we already have a result for this device
continue; continue;
} }
@ -351,24 +357,19 @@ export async function ensureOlmSessionsForDevices(
if (!oneTimeKey) { if (!oneTimeKey) {
log.warn(`No one-time keys (alg=${oneTimeKeyAlgorithm}) ` + `for device ${userId}:${deviceId}`); log.warn(`No one-time keys (alg=${oneTimeKeyAlgorithm}) ` + `for device ${userId}:${deviceId}`);
if (resolveSession[key]) { resolveSession.get(key)?.();
resolveSession[key]();
}
continue; continue;
} }
promises.push( promises.push(
_verifyKeyAndStartSession(olmDevice, oneTimeKey, userId, deviceInfo).then( _verifyKeyAndStartSession(olmDevice, oneTimeKey, userId, deviceInfo).then(
(sid) => { (sid) => {
if (resolveSession[key]) { resolveSession.get(key)?.(sid ?? undefined);
resolveSession[key](sid ?? undefined); const deviceInfo = result.get(userId)?.get(deviceId);
} if (deviceInfo) deviceInfo.sessionId = sid;
result[userId][deviceId].sessionId = sid;
}, },
(e) => { (e) => {
if (resolveSession[key]) { resolveSession.get(key)?.();
resolveSession[key]();
}
throw e; throw e;
}, },
), ),

View File

@ -21,6 +21,7 @@ import { IOlmDevice } from "../algorithms/megolm";
import { IRoomEncryption } from "../RoomList"; import { IRoomEncryption } from "../RoomList";
import { ICrossSigningKey } from "../../client"; import { ICrossSigningKey } from "../../client";
import { InboundGroupSessionData } from "../OlmDevice"; import { InboundGroupSessionData } from "../OlmDevice";
import { safeSet } from "../../utils";
/** /**
* Internal module. Partial localStorage backed storage for e2e. * Internal module. Partial localStorage backed storage for e2e.
@ -178,11 +179,11 @@ export class LocalStorageCryptoStore extends MemoryCryptoStore {
if (userId in notifiedErrorDevices) { if (userId in notifiedErrorDevices) {
if (!(deviceInfo.deviceId in notifiedErrorDevices[userId])) { if (!(deviceInfo.deviceId in notifiedErrorDevices[userId])) {
ret.push(device); ret.push(device);
notifiedErrorDevices[userId][deviceInfo.deviceId] = true; safeSet(notifiedErrorDevices[userId], deviceInfo.deviceId, true);
} }
} else { } else {
ret.push(device); ret.push(device);
notifiedErrorDevices[userId] = { [deviceInfo.deviceId]: true }; safeSet(notifiedErrorDevices, userId, { [deviceInfo.deviceId]: true });
} }
} }

View File

@ -33,6 +33,7 @@ import { ICrossSigningKey } from "../../client";
import { IOlmDevice } from "../algorithms/megolm"; import { IOlmDevice } from "../algorithms/megolm";
import { IRoomEncryption } from "../RoomList"; import { IRoomEncryption } from "../RoomList";
import { InboundGroupSessionData } from "../OlmDevice"; import { InboundGroupSessionData } from "../OlmDevice";
import { safeSet } from "../../utils";
/** /**
* Internal module. in-memory storage for e2e. * Internal module. in-memory storage for e2e.
@ -375,11 +376,11 @@ export class MemoryCryptoStore implements CryptoStore {
if (userId in notifiedErrorDevices) { if (userId in notifiedErrorDevices) {
if (!(deviceInfo.deviceId in notifiedErrorDevices[userId])) { if (!(deviceInfo.deviceId in notifiedErrorDevices[userId])) {
ret.push(device); ret.push(device);
notifiedErrorDevices[userId][deviceInfo.deviceId] = true; safeSet(notifiedErrorDevices[userId], deviceInfo.deviceId, true);
} }
} else { } else {
ret.push(device); ret.push(device);
notifiedErrorDevices[userId] = { [deviceInfo.deviceId]: true }; safeSet(notifiedErrorDevices, userId, { [deviceInfo.deviceId]: true });
} }
} }

View File

@ -269,12 +269,12 @@ export class ToDeviceChannel implements IVerificationChannel {
private async sendToDevices(type: string, content: Record<string, any>, devices: string[]): Promise<void> { private async sendToDevices(type: string, content: Record<string, any>, devices: string[]): Promise<void> {
if (devices.length) { if (devices.length) {
const msgMap: Record<string, Record<string, any>> = {}; const deviceMessages: Map<string, Record<string, any>> = new Map();
for (const deviceId of devices) { for (const deviceId of devices) {
msgMap[deviceId] = content; deviceMessages.set(deviceId, content);
} }
await this.client.sendToDevice(type, { [this.userId]: msgMap }); await this.client.sendToDevice(type, new Map([[this.userId, deviceMessages]]));
} }
} }

View File

@ -29,15 +29,16 @@ import { IEvent, IContent, EventStatus } from "./models/event";
import { ISendEventResponse } from "./@types/requests"; import { ISendEventResponse } from "./@types/requests";
import { EventType } from "./@types/event"; import { EventType } from "./@types/event";
import { logger } from "./logger"; import { logger } from "./logger";
import { MatrixClient, ClientEvent, IMatrixClientCreateOpts, IStartClientOpts } from "./client"; import { MatrixClient, ClientEvent, IMatrixClientCreateOpts, IStartClientOpts, SendToDeviceContentMap } from "./client";
import { SyncApi, SyncState } from "./sync"; import { SyncApi, SyncState } from "./sync";
import { SlidingSyncSdk } from "./sliding-sync-sdk"; import { SlidingSyncSdk } from "./sliding-sync-sdk";
import { MatrixEvent } from "./models/event"; import { MatrixEvent } from "./models/event";
import { User } from "./models/user"; import { User } from "./models/user";
import { Room } from "./models/room"; import { Room } from "./models/room";
import { ToDeviceBatch } from "./models/ToDeviceMessage"; import { ToDeviceBatch, ToDevicePayload } from "./models/ToDeviceMessage";
import { DeviceInfo } from "./crypto/deviceinfo"; import { DeviceInfo } from "./crypto/deviceinfo";
import { IOlmDevice } from "./crypto/algorithms/megolm"; import { IOlmDevice } from "./crypto/algorithms/megolm";
import { MapWithDefault, recursiveMapToObject } from "./utils";
interface IStateEventRequest { interface IStateEventRequest {
eventType: string; eventType: string;
@ -234,35 +235,32 @@ export class RoomWidgetClient extends MatrixClient {
return await this.widgetApi.sendStateEvent(eventType, stateKey, content, roomId); return await this.widgetApi.sendStateEvent(eventType, stateKey, content, roomId);
} }
public async sendToDevice( public async sendToDevice(eventType: string, contentMap: SendToDeviceContentMap): Promise<{}> {
eventType: string, await this.widgetApi.sendToDevice(eventType, false, recursiveMapToObject(contentMap));
contentMap: { [userId: string]: { [deviceId: string]: Record<string, any> } },
): Promise<{}> {
await this.widgetApi.sendToDevice(eventType, false, contentMap);
return {}; return {};
} }
public async queueToDevice({ eventType, batch }: ToDeviceBatch): Promise<void> { public async queueToDevice({ eventType, batch }: ToDeviceBatch): Promise<void> {
const contentMap: { [userId: string]: { [deviceId: string]: object } } = {}; // map: user Iddevice Id → payload
const contentMap: MapWithDefault<string, Map<string, ToDevicePayload>> = new MapWithDefault(() => new Map());
for (const { userId, deviceId, payload } of batch) { for (const { userId, deviceId, payload } of batch) {
if (!contentMap[userId]) contentMap[userId] = {}; contentMap.getOrCreate(userId).set(deviceId, payload);
contentMap[userId][deviceId] = payload;
} }
await this.widgetApi.sendToDevice(eventType, false, contentMap); await this.widgetApi.sendToDevice(eventType, false, recursiveMapToObject(contentMap));
} }
public async encryptAndSendToDevices(userDeviceInfoArr: IOlmDevice<DeviceInfo>[], payload: object): Promise<void> { public async encryptAndSendToDevices(userDeviceInfoArr: IOlmDevice<DeviceInfo>[], payload: object): Promise<void> {
const contentMap: { [userId: string]: { [deviceId: string]: object } } = {}; // map: user Iddevice Id → payload
const contentMap: MapWithDefault<string, Map<string, object>> = new MapWithDefault(() => new Map());
for (const { for (const {
userId, userId,
deviceInfo: { deviceId }, deviceInfo: { deviceId },
} of userDeviceInfoArr) { } of userDeviceInfoArr) {
if (!contentMap[userId]) contentMap[userId] = {}; contentMap.getOrCreate(userId).set(deviceId, payload);
contentMap[userId][deviceId] = payload;
} }
await this.widgetApi.sendToDevice((payload as { type: string }).type, true, contentMap); await this.widgetApi.sendToDevice((payload as { type: string }).type, true, recursiveMapToObject(contentMap));
} }
// Overridden since we get TURN servers automatically over the widget API, // Overridden since we get TURN servers automatically over the widget API,

View File

@ -16,7 +16,6 @@ import {
MAIN_ROOM_TIMELINE, MAIN_ROOM_TIMELINE,
Receipt, Receipt,
ReceiptCache, ReceiptCache,
Receipts,
ReceiptType, ReceiptType,
WrappedReceipt, WrappedReceipt,
} from "../@types/read_receipts"; } from "../@types/read_receipts";
@ -25,6 +24,7 @@ import * as utils from "../utils";
import { MatrixEvent } from "./event"; import { MatrixEvent } from "./event";
import { EventType } from "../@types/event"; import { EventType } from "../@types/event";
import { EventTimelineSet } from "./event-timeline-set"; import { EventTimelineSet } from "./event-timeline-set";
import { MapWithDefault } from "../utils";
import { NotificationCountType } from "./room"; import { NotificationCountType } from "./room";
export function synthesizeReceipt(userId: string, event: MatrixEvent, receiptType: ReceiptType): MatrixEvent { export function synthesizeReceipt(userId: string, event: MatrixEvent, receiptType: ReceiptType): MatrixEvent {
@ -56,8 +56,11 @@ export abstract class ReadReceipt<
// the form of this structure. This is sub-optimal for the exposed APIs // the form of this structure. This is sub-optimal for the exposed APIs
// which pass in an event ID and get back some receipts, so we also store // which pass in an event ID and get back some receipts, so we also store
// a pre-cached list for this purpose. // a pre-cached list for this purpose.
private receipts: Receipts = {}; // { receipt_type: { user_id: Receipt } } // Map: receipt type user Id → receipt
private receiptCacheByEventId: ReceiptCache = {}; // { event_id: CachedReceipt[] } private receipts = new MapWithDefault<string, Map<string, [WrappedReceipt | null, WrappedReceipt | null]>>(
() => new Map(),
);
private receiptCacheByEventId: ReceiptCache = new Map();
public abstract getUnfilteredTimelineSet(): EventTimelineSet; public abstract getUnfilteredTimelineSet(): EventTimelineSet;
public abstract timeline: MatrixEvent[]; public abstract timeline: MatrixEvent[];
@ -74,7 +77,7 @@ export abstract class ReadReceipt<
ignoreSynthesized = false, ignoreSynthesized = false,
receiptType = ReceiptType.Read, receiptType = ReceiptType.Read,
): WrappedReceipt | null { ): WrappedReceipt | null {
const [realReceipt, syntheticReceipt] = this.receipts[receiptType]?.[userId] ?? []; const [realReceipt, syntheticReceipt] = this.receipts.get(receiptType)?.get(userId) ?? [null, null];
if (ignoreSynthesized) { if (ignoreSynthesized) {
return realReceipt; return realReceipt;
} }
@ -126,14 +129,13 @@ export abstract class ReadReceipt<
receipt: Receipt, receipt: Receipt,
synthetic: boolean, synthetic: boolean,
): void { ): void {
if (!this.receipts[receiptType]) { const receiptTypesMap = this.receipts.getOrCreate(receiptType);
this.receipts[receiptType] = {}; let pair = receiptTypesMap.get(userId);
}
if (!this.receipts[receiptType][userId]) {
this.receipts[receiptType][userId] = [null, null];
}
const pair = this.receipts[receiptType][userId]; if (!pair) {
pair = [null, null];
receiptTypesMap.set(userId, pair);
}
let existingReceipt = pair[ReceiptPairRealIndex]; let existingReceipt = pair[ReceiptPairRealIndex];
if (synthetic) { if (synthetic) {
@ -185,23 +187,26 @@ export abstract class ReadReceipt<
if (cachedReceipt === newCachedReceipt) return; if (cachedReceipt === newCachedReceipt) return;
// clean up any previous cache entry // clean up any previous cache entry
if (cachedReceipt && this.receiptCacheByEventId[cachedReceipt.eventId]) { if (cachedReceipt && this.receiptCacheByEventId.get(cachedReceipt.eventId)) {
const previousEventId = cachedReceipt.eventId; const previousEventId = cachedReceipt.eventId;
// Remove the receipt we're about to clobber out of existence from the cache // Remove the receipt we're about to clobber out of existence from the cache
this.receiptCacheByEventId[previousEventId] = this.receiptCacheByEventId[previousEventId].filter((r) => { this.receiptCacheByEventId.set(
return r.type !== receiptType || r.userId !== userId; previousEventId,
}); this.receiptCacheByEventId.get(previousEventId)!.filter((r) => {
return r.type !== receiptType || r.userId !== userId;
}),
);
if (this.receiptCacheByEventId[previousEventId].length < 1) { if (this.receiptCacheByEventId.get(previousEventId)!.length < 1) {
delete this.receiptCacheByEventId[previousEventId]; // clean up the cache keys this.receiptCacheByEventId.delete(previousEventId); // clean up the cache keys
} }
} }
// cache the new one // cache the new one
if (!this.receiptCacheByEventId[eventId]) { if (!this.receiptCacheByEventId.get(eventId)) {
this.receiptCacheByEventId[eventId] = []; this.receiptCacheByEventId.set(eventId, []);
} }
this.receiptCacheByEventId[eventId].push({ this.receiptCacheByEventId.get(eventId)!.push({
userId: userId, userId: userId,
type: receiptType as ReceiptType, type: receiptType as ReceiptType,
data: receipt, data: receipt,
@ -215,7 +220,7 @@ export abstract class ReadReceipt<
* an empty list. * an empty list.
*/ */
public getReceiptsForEvent(event: MatrixEvent): CachedReceipt[] { public getReceiptsForEvent(event: MatrixEvent): CachedReceipt[] {
return this.receiptCacheByEventId[event.getId()!] || []; return this.receiptCacheByEventId.get(event.getId()!) || [];
} }
public abstract addReceipt(event: MatrixEvent, synthetic: boolean): void; public abstract addReceipt(event: MatrixEvent, synthetic: boolean): void;

View File

@ -25,7 +25,7 @@ import {
import { Direction, EventTimeline } from "./event-timeline"; import { Direction, EventTimeline } from "./event-timeline";
import { getHttpUriForMxc } from "../content-repo"; import { getHttpUriForMxc } from "../content-repo";
import * as utils from "../utils"; import * as utils from "../utils";
import { normalize } from "../utils"; import { normalize, noUnsafeEventProps } from "../utils";
import { IEvent, IThreadBundledRelationship, MatrixEvent, MatrixEventEvent, MatrixEventHandlerMap } from "./event"; import { IEvent, IThreadBundledRelationship, MatrixEvent, MatrixEventEvent, MatrixEventHandlerMap } from "./event";
import { EventStatus } from "./event-status"; import { EventStatus } from "./event-status";
import { RoomMember } from "./room-member"; import { RoomMember } from "./room-member";
@ -311,7 +311,7 @@ export type RoomEventHandlerMap = {
export class Room extends ReadReceipt<RoomEmittedEvents, RoomEventHandlerMap> { export class Room extends ReadReceipt<RoomEmittedEvents, RoomEventHandlerMap> {
public readonly reEmitter: TypedReEmitter<RoomEmittedEvents, RoomEventHandlerMap>; public readonly reEmitter: TypedReEmitter<RoomEmittedEvents, RoomEventHandlerMap>;
private txnToEvent: Record<string, MatrixEvent> = {}; // Pending in-flight requests { string: MatrixEvent } private txnToEvent: Map<string, MatrixEvent> = new Map(); // Pending in-flight requests { string: MatrixEvent }
private notificationCounts: NotificationCount = {}; private notificationCounts: NotificationCount = {};
private readonly threadNotifications = new Map<string, NotificationCount>(); private readonly threadNotifications = new Map<string, NotificationCount>();
public readonly cachedThreadReadReceipts = new Map<string, CachedReceiptStructure[]>(); public readonly cachedThreadReadReceipts = new Map<string, CachedReceiptStructure[]>();
@ -356,7 +356,7 @@ export class Room extends ReadReceipt<RoomEmittedEvents, RoomEventHandlerMap> {
* accountData Dict of per-room account_data events; the keys are the * accountData Dict of per-room account_data events; the keys are the
* event type and the values are the events. * event type and the values are the events.
*/ */
public accountData: Record<string, MatrixEvent> = {}; // $eventType: $event public accountData: Map<string, MatrixEvent> = new Map(); // $eventType: $event
/** /**
* The room summary. * The room summary.
*/ */
@ -902,7 +902,7 @@ export class Room extends ReadReceipt<RoomEmittedEvents, RoomEventHandlerMap> {
rawMembersEvents = await this.loadMembersFromServer(); rawMembersEvents = await this.loadMembersFromServer();
logger.log(`LL: got ${rawMembersEvents.length} ` + `members from server for room ${this.roomId}`); logger.log(`LL: got ${rawMembersEvents.length} ` + `members from server for room ${this.roomId}`);
} }
const memberEvents = rawMembersEvents.map(this.client.getEventMapper()); const memberEvents = rawMembersEvents.filter(noUnsafeEventProps).map(this.client.getEventMapper());
return { memberEvents, fromServer }; return { memberEvents, fromServer };
} }
@ -2255,8 +2255,7 @@ export class Room extends ReadReceipt<RoomEmittedEvents, RoomEventHandlerMap> {
const txnId = event.getUnsigned().transaction_id; const txnId = event.getUnsigned().transaction_id;
if (!txnId && event.getSender() === this.myUserId) { if (!txnId && event.getSender() === this.myUserId) {
// check the txn map for a matching event ID // check the txn map for a matching event ID
for (const tid in this.txnToEvent) { for (const [tid, localEvent] of this.txnToEvent) {
const localEvent = this.txnToEvent[tid];
if (localEvent.getId() === event.getId()) { if (localEvent.getId() === event.getId()) {
logger.debug("processLiveEvent: found sent event without txn ID: ", tid, event.getId()); logger.debug("processLiveEvent: found sent event without txn ID: ", tid, event.getId());
// update the unsigned field so we can re-use the same codepaths // update the unsigned field so we can re-use the same codepaths
@ -2331,7 +2330,7 @@ export class Room extends ReadReceipt<RoomEmittedEvents, RoomEventHandlerMap> {
throw new Error("addPendingEvent called on an event with status " + event.status); throw new Error("addPendingEvent called on an event with status " + event.status);
} }
if (this.txnToEvent[txnId]) { if (this.txnToEvent.get(txnId)) {
throw new Error("addPendingEvent called on an event with known txnId " + txnId); throw new Error("addPendingEvent called on an event with known txnId " + txnId);
} }
@ -2340,7 +2339,7 @@ export class Room extends ReadReceipt<RoomEmittedEvents, RoomEventHandlerMap> {
// on the unfiltered timelineSet. // on the unfiltered timelineSet.
EventTimeline.setEventMetadata(event, this.getLiveTimeline().getState(EventTimeline.FORWARDS)!, false); EventTimeline.setEventMetadata(event, this.getLiveTimeline().getState(EventTimeline.FORWARDS)!, false);
this.txnToEvent[txnId] = event; this.txnToEvent.set(txnId, event);
if (this.pendingEventList) { if (this.pendingEventList) {
if (this.pendingEventList.some((e) => e.status === EventStatus.NOT_SENT)) { if (this.pendingEventList.some((e) => e.status === EventStatus.NOT_SENT)) {
logger.warn("Setting event as NOT_SENT due to messages in the same state"); logger.warn("Setting event as NOT_SENT due to messages in the same state");
@ -2432,8 +2431,8 @@ export class Room extends ReadReceipt<RoomEmittedEvents, RoomEventHandlerMap> {
this.relations.aggregateChildEvent(event); this.relations.aggregateChildEvent(event);
} }
public getEventForTxnId(txnId: string): MatrixEvent { public getEventForTxnId(txnId: string): MatrixEvent | undefined {
return this.txnToEvent[txnId]; return this.txnToEvent.get(txnId);
} }
/** /**
@ -2460,7 +2459,7 @@ export class Room extends ReadReceipt<RoomEmittedEvents, RoomEventHandlerMap> {
logger.debug(`Got remote echo for event ${oldEventId} -> ${newEventId} old status ${oldStatus}`); logger.debug(`Got remote echo for event ${oldEventId} -> ${newEventId} old status ${oldStatus}`);
// no longer pending // no longer pending
delete this.txnToEvent[remoteEvent.getUnsigned().transaction_id!]; this.txnToEvent.delete(remoteEvent.getUnsigned().transaction_id!);
// if it's in the pending list, remove it // if it's in the pending list, remove it
if (this.pendingEventList) { if (this.pendingEventList) {
@ -2673,7 +2672,7 @@ export class Room extends ReadReceipt<RoomEmittedEvents, RoomEventHandlerMap> {
this.processLiveEvent(event); this.processLiveEvent(event);
if (event.getUnsigned().transaction_id) { if (event.getUnsigned().transaction_id) {
const existingEvent = this.txnToEvent[event.getUnsigned().transaction_id!]; const existingEvent = this.txnToEvent.get(event.getUnsigned().transaction_id!);
if (existingEvent) { if (existingEvent) {
// remote echo of an event we sent earlier // remote echo of an event we sent earlier
this.handleRemoteEcho(event, existingEvent); this.handleRemoteEcho(event, existingEvent);
@ -2942,8 +2941,9 @@ export class Room extends ReadReceipt<RoomEmittedEvents, RoomEventHandlerMap> {
if (event.getType() === "m.tag") { if (event.getType() === "m.tag") {
this.addTags(event); this.addTags(event);
} }
const lastEvent = this.accountData[event.getType()]; const eventType = event.getType();
this.accountData[event.getType()] = event; const lastEvent = this.accountData.get(eventType);
this.accountData.set(eventType, event);
this.emit(RoomEvent.AccountData, event, this, lastEvent); this.emit(RoomEvent.AccountData, event, this, lastEvent);
} }
} }
@ -2954,7 +2954,7 @@ export class Room extends ReadReceipt<RoomEmittedEvents, RoomEventHandlerMap> {
* @returns the account_data event in question * @returns the account_data event in question
*/ */
public getAccountData(type: EventType | string): MatrixEvent | undefined { public getAccountData(type: EventType | string): MatrixEvent | undefined {
return this.accountData[type]; return this.accountData.get(type);
} }
/** /**

View File

@ -36,7 +36,7 @@ export interface ISavedSync {
* A store for most of the data js-sdk needs to store, apart from crypto data * A store for most of the data js-sdk needs to store, apart from crypto data
*/ */
export interface IStore { export interface IStore {
readonly accountData: Record<string, MatrixEvent>; // type : content readonly accountData: Map<string, MatrixEvent>; // type : content
// XXX: The indexeddb store exposes a non-standard emitter for the "degraded" event // XXX: The indexeddb store exposes a non-standard emitter for the "degraded" event
// for when it falls back to being a memory store due to errors. // for when it falls back to being a memory store due to errors.

View File

@ -31,6 +31,7 @@ import { ISyncResponse } from "../sync-accumulator";
import { IStateEventWithRoomId } from "../@types/search"; import { IStateEventWithRoomId } from "../@types/search";
import { IndexedToDeviceBatch, ToDeviceBatchWithTxnId } from "../models/ToDeviceMessage"; import { IndexedToDeviceBatch, ToDeviceBatchWithTxnId } from "../models/ToDeviceMessage";
import { IStoredClientOpts } from "../client"; import { IStoredClientOpts } from "../client";
import { MapWithDefault } from "../utils";
function isValidFilterId(filterId?: string | number | null): boolean { function isValidFilterId(filterId?: string | number | null): boolean {
const isValidStr = const isValidStr =
@ -54,10 +55,10 @@ export class MemoryStore implements IStore {
// userId: { // userId: {
// filterId: Filter // filterId: Filter
// } // }
private filters: Record<string, Record<string, Filter>> = {}; private filters: MapWithDefault<string, Map<string, Filter>> = new MapWithDefault(() => new Map());
public accountData: Record<string, MatrixEvent> = {}; // type : content public accountData: Map<string, MatrixEvent> = new Map(); // type: content
protected readonly localStorage?: Storage; protected readonly localStorage?: Storage;
private oobMembers: Record<string, IStateEventWithRoomId[]> = {}; // roomId: [member events] private oobMembers: Map<string, IStateEventWithRoomId[]> = new Map(); // roomId: [member events]
private pendingEvents: { [roomId: string]: Partial<IEvent>[] } = {}; private pendingEvents: { [roomId: string]: Partial<IEvent>[] } = {};
private clientOptions?: IStoredClientOpts; private clientOptions?: IStoredClientOpts;
private pendingToDeviceBatches: IndexedToDeviceBatch[] = []; private pendingToDeviceBatches: IndexedToDeviceBatch[] = [];
@ -220,10 +221,7 @@ export class MemoryStore implements IStore {
*/ */
public storeFilter(filter: Filter): void { public storeFilter(filter: Filter): void {
if (!filter?.userId || !filter?.filterId) return; if (!filter?.userId || !filter?.filterId) return;
if (!this.filters[filter.userId]) { this.filters.getOrCreate(filter.userId).set(filter.filterId, filter);
this.filters[filter.userId] = {};
}
this.filters[filter.userId][filter.filterId] = filter;
} }
/** /**
@ -231,10 +229,7 @@ export class MemoryStore implements IStore {
* @returns A filter or null. * @returns A filter or null.
*/ */
public getFilter(userId: string, filterId: string): Filter | null { public getFilter(userId: string, filterId: string): Filter | null {
if (!this.filters[userId] || !this.filters[userId][filterId]) { return this.filters.get(userId)?.get(filterId) || null;
return null;
}
return this.filters[userId][filterId];
} }
/** /**
@ -289,9 +284,9 @@ export class MemoryStore implements IStore {
// MSC3391: an event with content of {} should be interpreted as deleted // MSC3391: an event with content of {} should be interpreted as deleted
const isDeleted = !Object.keys(event.getContent()).length; const isDeleted = !Object.keys(event.getContent()).length;
if (isDeleted) { if (isDeleted) {
delete this.accountData[event.getType()]; this.accountData.delete(event.getType());
} else { } else {
this.accountData[event.getType()] = event; this.accountData.set(event.getType(), event);
} }
}); });
} }
@ -302,7 +297,7 @@ export class MemoryStore implements IStore {
* @returns the user account_data event of given type, if any * @returns the user account_data event of given type, if any
*/ */
public getAccountData(eventType: EventType | string): MatrixEvent | undefined { public getAccountData(eventType: EventType | string): MatrixEvent | undefined {
return this.accountData[eventType]; return this.accountData.get(eventType);
} }
/** /**
@ -368,14 +363,8 @@ export class MemoryStore implements IStore {
// userId: User // userId: User
}; };
this.syncToken = null; this.syncToken = null;
this.filters = { this.filters = new MapWithDefault(() => new Map());
// userId: { this.accountData = new Map(); // type : content
// filterId: Filter
// }
};
this.accountData = {
// type : content
};
return Promise.resolve(); return Promise.resolve();
} }
@ -386,7 +375,7 @@ export class MemoryStore implements IStore {
* @returns in case the members for this room haven't been stored yet * @returns in case the members for this room haven't been stored yet
*/ */
public getOutOfBandMembers(roomId: string): Promise<IStateEventWithRoomId[] | null> { public getOutOfBandMembers(roomId: string): Promise<IStateEventWithRoomId[] | null> {
return Promise.resolve(this.oobMembers[roomId] || null); return Promise.resolve(this.oobMembers.get(roomId) || null);
} }
/** /**
@ -397,12 +386,12 @@ export class MemoryStore implements IStore {
* @returns when all members have been stored * @returns when all members have been stored
*/ */
public setOutOfBandMembers(roomId: string, membershipEvents: IStateEventWithRoomId[]): Promise<void> { public setOutOfBandMembers(roomId: string, membershipEvents: IStateEventWithRoomId[]): Promise<void> {
this.oobMembers[roomId] = membershipEvents; this.oobMembers.set(roomId, membershipEvents);
return Promise.resolve(); return Promise.resolve();
} }
public clearOutOfBandMembers(roomId: string): Promise<void> { public clearOutOfBandMembers(roomId: string): Promise<void> {
this.oobMembers = {}; this.oobMembers.delete(roomId);
return Promise.resolve(); return Promise.resolve();
} }

View File

@ -34,7 +34,7 @@ import { IStoredClientOpts } from "../client";
* Construct a stub store. This does no-ops on most store methods. * Construct a stub store. This does no-ops on most store methods.
*/ */
export class StubStore implements IStore { export class StubStore implements IStore {
public readonly accountData = {}; // stub public readonly accountData = new Map(); // stub
private fromToken: string | null = null; private fromToken: string | null = null;
/** @returns whether or not the database was newly created in this session. */ /** @returns whether or not the database was newly created in this session. */

View File

@ -19,7 +19,7 @@ limitations under the License.
*/ */
import { logger } from "./logger"; import { logger } from "./logger";
import { deepCopy, isSupportedReceiptType } from "./utils"; import { deepCopy, isSupportedReceiptType, MapWithDefault, recursiveMapToObject } from "./utils";
import { IContent, IUnsigned } from "./models/event"; import { IContent, IUnsigned } from "./models/event";
import { IRoomSummary } from "./models/room-summary"; import { IRoomSummary } from "./models/room-summary";
import { EventType } from "./@types/event"; import { EventType } from "./@types/event";
@ -585,29 +585,31 @@ export class SyncAccumulator {
} as IContent, } as IContent,
}; };
const receiptEventContent: MapWithDefault<
string,
MapWithDefault<ReceiptType, Map<string, object>>
> = new MapWithDefault(() => new MapWithDefault(() => new Map()));
for (const [userId, receiptData] of Object.entries(roomData._readReceipts)) { for (const [userId, receiptData] of Object.entries(roomData._readReceipts)) {
if (!receiptEvent.content[receiptData.eventId]) { receiptEventContent
receiptEvent.content[receiptData.eventId] = {}; .getOrCreate(receiptData.eventId)
} .getOrCreate(receiptData.type)
if (!receiptEvent.content[receiptData.eventId][receiptData.type]) { .set(userId, receiptData.data);
receiptEvent.content[receiptData.eventId][receiptData.type] = {};
}
receiptEvent.content[receiptData.eventId][receiptData.type][userId] = receiptData.data;
} }
for (const threadReceipts of Object.values(roomData._threadReadReceipts)) { for (const threadReceipts of Object.values(roomData._threadReadReceipts)) {
for (const [userId, receiptData] of Object.entries(threadReceipts)) { for (const [userId, receiptData] of Object.entries(threadReceipts)) {
if (!receiptEvent.content[receiptData.eventId]) { receiptEventContent
receiptEvent.content[receiptData.eventId] = {}; .getOrCreate(receiptData.eventId)
} .getOrCreate(receiptData.type)
if (!receiptEvent.content[receiptData.eventId][receiptData.type]) { .set(userId, receiptData.data);
receiptEvent.content[receiptData.eventId][receiptData.type] = {};
}
receiptEvent.content[receiptData.eventId][receiptData.type][userId] = receiptData.data;
} }
} }
receiptEvent.content = recursiveMapToObject(receiptEventContent);
// add only if we have some receipt data // add only if we have some receipt data
if (Object.keys(receiptEvent.content).length > 0) { if (receiptEventContent.size > 0) {
roomJson.ephemeral.events.push(receiptEvent as IMinimalEvent); roomJson.ephemeral.events.push(receiptEvent as IMinimalEvent);
} }

View File

@ -29,7 +29,7 @@ import type { SyncCryptoCallbacks } from "./common-crypto/CryptoBackend";
import { User, UserEvent } from "./models/user"; import { User, UserEvent } from "./models/user";
import { NotificationCountType, Room, RoomEvent } from "./models/room"; import { NotificationCountType, Room, RoomEvent } from "./models/room";
import * as utils from "./utils"; import * as utils from "./utils";
import { IDeferred } from "./utils"; import { IDeferred, noUnsafeEventProps, unsafeProp } from "./utils";
import { Filter } from "./filter"; import { Filter } from "./filter";
import { EventTimeline } from "./models/event-timeline"; import { EventTimeline } from "./models/event-timeline";
import { logger } from "./logger"; import { logger } from "./logger";
@ -1133,22 +1133,24 @@ export class SyncApi {
// handle presence events (User objects) // handle presence events (User objects)
if (Array.isArray(data.presence?.events)) { if (Array.isArray(data.presence?.events)) {
data.presence!.events.map(client.getEventMapper()).forEach(function (presenceEvent) { data.presence!.events.filter(noUnsafeEventProps)
let user = client.store.getUser(presenceEvent.getSender()!); .map(client.getEventMapper())
if (user) { .forEach(function (presenceEvent) {
user.setPresenceEvent(presenceEvent); let user = client.store.getUser(presenceEvent.getSender()!);
} else { if (user) {
user = createNewUser(client, presenceEvent.getSender()!); user.setPresenceEvent(presenceEvent);
user.setPresenceEvent(presenceEvent); } else {
client.store.storeUser(user); user = createNewUser(client, presenceEvent.getSender()!);
} user.setPresenceEvent(presenceEvent);
client.emit(ClientEvent.Event, presenceEvent); client.store.storeUser(user);
}); }
client.emit(ClientEvent.Event, presenceEvent);
});
} }
// handle non-room account_data // handle non-room account_data
if (Array.isArray(data.account_data?.events)) { if (Array.isArray(data.account_data?.events)) {
const events = data.account_data.events.map(client.getEventMapper()); const events = data.account_data.events.filter(noUnsafeEventProps).map(client.getEventMapper());
const prevEventsMap = events.reduce<Record<string, MatrixEvent | undefined>>((m, c) => { const prevEventsMap = events.reduce<Record<string, MatrixEvent | undefined>>((m, c) => {
m[c.getType()!] = client.store.getAccountData(c.getType()); m[c.getType()!] = client.store.getAccountData(c.getType());
return m; return m;
@ -1171,7 +1173,7 @@ export class SyncApi {
// handle to-device events // handle to-device events
if (data.to_device && Array.isArray(data.to_device.events) && data.to_device.events.length > 0) { if (data.to_device && Array.isArray(data.to_device.events) && data.to_device.events.length > 0) {
let toDeviceMessages: IToDeviceEvent[] = data.to_device.events; let toDeviceMessages: IToDeviceEvent[] = data.to_device.events.filter(noUnsafeEventProps);
if (this.syncOpts.cryptoCallbacks) { if (this.syncOpts.cryptoCallbacks) {
toDeviceMessages = await this.syncOpts.cryptoCallbacks.preprocessToDeviceMessages(toDeviceMessages); toDeviceMessages = await this.syncOpts.cryptoCallbacks.preprocessToDeviceMessages(toDeviceMessages);
@ -1630,18 +1632,20 @@ export class SyncApi {
// to // to
// [{stuff+Room+isBrandNewRoom}, {stuff+Room+isBrandNewRoom}] // [{stuff+Room+isBrandNewRoom}, {stuff+Room+isBrandNewRoom}]
const client = this.client; const client = this.client;
return Object.keys(obj).map((roomId) => { return Object.keys(obj)
const arrObj = obj[roomId] as T & { room: Room; isBrandNewRoom: boolean }; .filter((k) => !unsafeProp(k))
let room = client.store.getRoom(roomId); .map((roomId) => {
let isBrandNewRoom = false; const arrObj = obj[roomId] as T & { room: Room; isBrandNewRoom: boolean };
if (!room) { let room = client.store.getRoom(roomId);
room = this.createRoom(roomId); let isBrandNewRoom = false;
isBrandNewRoom = true; if (!room) {
} room = this.createRoom(roomId);
arrObj.room = room; isBrandNewRoom = true;
arrObj.isBrandNewRoom = isBrandNewRoom; }
return arrObj; arrObj.room = room;
}); arrObj.isBrandNewRoom = isBrandNewRoom;
return arrObj;
});
} }
private mapSyncEventsFormat( private mapSyncEventsFormat(
@ -1654,7 +1658,7 @@ export class SyncApi {
} }
const mapper = this.client.getEventMapper({ decrypt }); const mapper = this.client.getEventMapper({ decrypt });
type TaggedEvent = (IStrippedState | IRoomEvent | IStateEvent | IMinimalEvent) & { room_id?: string }; type TaggedEvent = (IStrippedState | IRoomEvent | IStateEvent | IMinimalEvent) & { room_id?: string };
return (obj.events as TaggedEvent[]).map(function (e) { return (obj.events as TaggedEvent[]).filter(noUnsafeEventProps).map(function (e) {
if (room) { if (room) {
e.room_id = room.roomId; e.room_id = room.roomId;
} }

View File

@ -22,7 +22,7 @@ import unhomoglyph from "unhomoglyph";
import promiseRetry from "p-retry"; import promiseRetry from "p-retry";
import { Optional } from "matrix-events-sdk"; import { Optional } from "matrix-events-sdk";
import { MatrixEvent } from "./models/event"; import { IEvent, MatrixEvent } from "./models/event";
import { M_TIMESTAMP } from "./@types/location"; import { M_TIMESTAMP } from "./@types/location";
import { ReceiptType } from "./@types/read_receipts"; import { ReceiptType } from "./@types/read_receipts";
@ -703,3 +703,68 @@ export function mapsEqual<K, V>(x: Map<K, V>, y: Map<K, V>, eq = (v1: V, v2: V):
} }
return true; return true;
} }
function processMapToObjectValue(value: any): any {
if (value instanceof Map) {
// Value is a Map. Recursively map it to an object.
return recursiveMapToObject(value);
} else if (Array.isArray(value)) {
// Value is an Array. Recursively map the value (e.g. to cover Array of Arrays).
return value.map((v) => processMapToObjectValue(v));
} else {
return value;
}
}
/**
* Recursively converts Maps to plain objects.
* Also supports sub-lists of Maps.
*/
export function recursiveMapToObject(map: Map<any, any>): any {
const targetMap = new Map();
for (const [key, value] of map) {
targetMap.set(key, processMapToObjectValue(value));
}
return Object.fromEntries(targetMap.entries());
}
export function unsafeProp<K extends keyof any | undefined>(prop: K): boolean {
return prop === "__proto__" || prop === "prototype" || prop === "constructor";
}
export function safeSet<K extends keyof any>(obj: Record<any, any>, prop: K, value: any): void {
if (unsafeProp(prop)) {
throw new Error("Trying to modify prototype or constructor");
}
obj[prop] = value;
}
export function noUnsafeEventProps(event: Partial<IEvent>): boolean {
return !(
unsafeProp(event.room_id) ||
unsafeProp(event.sender) ||
unsafeProp(event.user_id) ||
unsafeProp(event.event_id)
);
}
export class MapWithDefault<K, V> extends Map<K, V> {
public constructor(private createDefault: () => V) {
super();
}
/**
* Returns the value if the key already exists.
* If not, it creates a new value under that key using the ctor callback and returns it.
*/
public getOrCreate(key: K): V {
if (!this.has(key)) {
this.set(key, this.createDefault());
}
return this.get(key)!;
}
}

View File

@ -609,7 +609,7 @@ export class MatrixCall extends TypedEventEmitter<CallEvent, CallEventHandlerMap
if (!userId) throw new Error("Couldn't find opponent user ID to init crypto"); if (!userId) throw new Error("Couldn't find opponent user ID to init crypto");
const deviceInfoMap = await this.client.crypto.deviceList.downloadKeys([userId], false); const deviceInfoMap = await this.client.crypto.deviceList.downloadKeys([userId], false);
this.opponentDeviceInfo = deviceInfoMap[userId][this.opponentDeviceId]; this.opponentDeviceInfo = deviceInfoMap.get(userId)?.get(this.opponentDeviceId);
if (this.opponentDeviceInfo === undefined) { if (this.opponentDeviceInfo === undefined) {
throw new GroupCallUnknownDeviceError(userId); throw new GroupCallUnknownDeviceError(userId);
} }
@ -2408,11 +2408,10 @@ export class MatrixCall extends TypedEventEmitter<CallEvent, CallEventHandlerMap
}, },
); );
} else { } else {
await this.client.sendToDevice(eventType, { await this.client.sendToDevice(
[userId]: { eventType,
[this.opponentDeviceId]: content, new Map<string, any>([[userId, new Map([[this.opponentDeviceId, content]])]]),
}, );
});
} }
} else { } else {
this.emit(CallEvent.SendVoipEvent, { this.emit(CallEvent.SendVoipEvent, {