1
0
mirror of https://github.com/matrix-org/matrix-js-sdk.git synced 2025-07-31 15:24:23 +03:00

Merge branch 'master' into develop

This commit is contained in:
RiotRobot
2023-03-28 14:15:09 +01:00
42 changed files with 718 additions and 559 deletions

View File

@ -1,3 +1,9 @@
Changes in [24.0.0](https://github.com/matrix-org/matrix-js-sdk/releases/tag/v24.0.0) (2023-03-28)
==================================================================================================
## 🐛 Bug Fixes
* Changes for matrix-js-sdk v24.0.0
Changes in [23.5.0](https://github.com/matrix-org/matrix-js-sdk/releases/tag/v23.5.0) (2023-03-15)
==================================================================================================

View File

@ -1,6 +1,6 @@
{
"name": "matrix-js-sdk",
"version": "23.5.0",
"version": "24.0.0",
"description": "Matrix Client-Server SDK for Javascript",
"engines": {
"node": ">=16.0.0"

View File

@ -560,7 +560,7 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("crypto (%s)", (backend: string,
// if we're using the old crypto impl, stub out some methods in the device manager.
// TODO: replace this with intercepts of the /keys/query endpoint to make it impl agnostic.
if (aliceClient.crypto) {
aliceClient.crypto.deviceList.downloadKeys = () => Promise.resolve({});
aliceClient.crypto.deviceList.downloadKeys = () => Promise.resolve(new Map());
aliceClient.crypto.deviceList.getUserByIdentityKey = () => "@bob:xyz";
}
@ -620,7 +620,7 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("crypto (%s)", (backend: string,
// if we're using the old crypto impl, stub out some methods in the device manager.
// TODO: replace this with intercepts of the /keys/query endpoint to make it impl agnostic.
if (aliceClient.crypto) {
aliceClient.crypto.deviceList.downloadKeys = () => Promise.resolve({});
aliceClient.crypto.deviceList.downloadKeys = () => Promise.resolve(new Map());
aliceClient.crypto.deviceList.getUserByIdentityKey = () => "@bob:xyz";
}
@ -688,7 +688,7 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("crypto (%s)", (backend: string,
// if we're using the old crypto impl, stub out some methods in the device manager.
// TODO: replace this with intercepts of the /keys/query endpoint to make it impl agnostic.
if (aliceClient.crypto) {
aliceClient.crypto.deviceList.downloadKeys = () => Promise.resolve({});
aliceClient.crypto.deviceList.downloadKeys = () => Promise.resolve(new Map());
aliceClient.crypto.deviceList.getUserByIdentityKey = () => "@bob:xyz";
}
@ -1071,8 +1071,8 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("crypto (%s)", (backend: string,
throw new Error("sendTextMessage succeeded on an unknown device");
} catch (e) {
expect((e as any).name).toEqual("UnknownDeviceError");
expect(Object.keys((e as any).devices)).toEqual([aliceClient.getUserId()!]);
expect(Object.keys((e as any)?.devices[aliceClient.getUserId()!])).toEqual(["DEVICE_ID"]);
expect([...(e as any).devices.keys()]).toEqual([aliceClient.getUserId()!]);
expect((e as any).devices.get(aliceClient.getUserId()!).has("DEVICE_ID"));
}
// mark the device as known, and resend.
@ -1140,7 +1140,7 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("crypto (%s)", (backend: string,
// if we're using the old crypto impl, stub out some methods in the device manager.
// TODO: replace this with intercepts of the /keys/query endpoint to make it impl agnostic.
if (aliceClient.crypto) {
aliceClient.crypto.deviceList.downloadKeys = () => Promise.resolve({});
aliceClient.crypto.deviceList.downloadKeys = () => Promise.resolve(new Map());
aliceClient.crypto.deviceList.getUserByIdentityKey = () => "@bob:xyz";
}
@ -1296,7 +1296,7 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("crypto (%s)", (backend: string,
// if we're using the old crypto impl, stub out some methods in the device manager.
// TODO: replace this with intercepts of the /keys/query endpoint to make it impl agnostic.
if (aliceClient.crypto) {
aliceClient.crypto.deviceList.downloadKeys = () => Promise.resolve({});
aliceClient.crypto.deviceList.downloadKeys = () => Promise.resolve(new Map());
aliceClient.crypto.deviceList.getUserByIdentityKey = () => "@bob:xyz";
}
@ -1363,7 +1363,7 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("crypto (%s)", (backend: string,
// if we're using the old crypto impl, stub out some methods in the device manager.
// TODO: replace this with intercepts of the /keys/query endpoint to make it impl agnostic.
if (aliceClient.crypto) {
aliceClient.crypto!.deviceList.downloadKeys = () => Promise.resolve({});
aliceClient.crypto!.deviceList.downloadKeys = () => Promise.resolve(new Map());
aliceClient.crypto!.deviceList.getDeviceByIdentityKey = () => device;
aliceClient.crypto!.deviceList.getUserByIdentityKey = () => beccaTestClient.client.getUserId()!;
}

View File

@ -603,14 +603,14 @@ describe("MatrixClient", function () {
});
const prom = client!.downloadKeys(["boris", "chaz"]).then(function (res) {
assertObjectContains(res.boris.dev1, {
assertObjectContains(res.get("boris")!.get("dev1")!, {
verified: 0, // DeviceVerification.UNVERIFIED
keys: { "ed25519:dev1": ed25519key },
algorithms: ["1"],
unsigned: { abc: "def" },
});
assertObjectContains(res.chaz.dev2, {
assertObjectContains(res.get("chaz")!.get("dev2")!, {
verified: 0, // DeviceVerification.UNVERIFIED
keys: { "ed25519:dev2": ed25519key },
algorithms: ["2"],

View File

@ -472,7 +472,7 @@ describe("MatrixClient crypto", () => {
aliTestClient.expectKeyQuery({ device_keys: { [aliUserId]: {} }, failures: {} });
await aliTestClient.start();
await bobTestClient.start();
bobTestClient.client.crypto!.deviceList.downloadKeys = () => Promise.resolve({});
bobTestClient.client.crypto!.deviceList.downloadKeys = () => Promise.resolve(new Map());
await firstSync(aliTestClient);
await aliEnablesEncryption();
await aliSendsFirstMessage();
@ -483,7 +483,7 @@ describe("MatrixClient crypto", () => {
aliTestClient.expectKeyQuery({ device_keys: { [aliUserId]: {} }, failures: {} });
await aliTestClient.start();
await bobTestClient.start();
bobTestClient.client.crypto!.deviceList.downloadKeys = () => Promise.resolve({});
bobTestClient.client.crypto!.deviceList.downloadKeys = () => Promise.resolve(new Map());
await firstSync(aliTestClient);
await aliEnablesEncryption();
await aliSendsFirstMessage();
@ -545,7 +545,7 @@ describe("MatrixClient crypto", () => {
aliTestClient.expectKeyQuery({ device_keys: { [aliUserId]: {} }, failures: {} });
await aliTestClient.start();
await bobTestClient.start();
bobTestClient.client.crypto!.deviceList.downloadKeys = () => Promise.resolve({});
bobTestClient.client.crypto!.deviceList.downloadKeys = () => Promise.resolve(new Map());
await firstSync(aliTestClient);
await aliEnablesEncryption();
await aliSendsFirstMessage();

View File

@ -30,6 +30,7 @@ import {
RoomState,
RoomStateEvent,
RoomStateEventHandlerMap,
SendToDeviceContentMap,
} from "../../src";
import { TypedEventEmitter } from "../../src/models/typed-event-emitter";
import { ReEmitter } from "../../src/ReEmitter";
@ -443,11 +444,7 @@ export class MockCallMatrixClient extends TypedEventEmitter<EmittedEvents, Emitt
>();
public sendToDevice = jest.fn<
Promise<{}>,
[
eventType: string,
contentMap: { [userId: string]: { [deviceId: string]: Record<string, any> } },
txnId?: string,
]
[eventType: string, contentMap: SendToDeviceContentMap, txnId?: string]
>();
public isInitialSyncComplete(): boolean {

View File

@ -405,7 +405,7 @@ describe("Crypto", function () {
// the first message can't be decrypted yet, but the second one
// can
let ksEvent = await keyshareEventForEvent(aliceClient, events[1], 1);
bobClient.crypto!.deviceList.downloadKeys = () => Promise.resolve({});
bobClient.crypto!.deviceList.downloadKeys = () => Promise.resolve(new Map());
bobClient.crypto!.deviceList.getUserByIdentityKey = () => "@alice:example.com";
await bobDecryptor.onRoomKeyEvent(ksEvent);
await decryptEventsPromise;
@ -1039,7 +1039,7 @@ describe("Crypto", function () {
beforeEach(async () => {
ensureOlmSessionsForDevices = jest.spyOn(olmlib, "ensureOlmSessionsForDevices");
ensureOlmSessionsForDevices.mockResolvedValue({});
ensureOlmSessionsForDevices.mockResolvedValue(new Map());
encryptMessageForDevice = jest.spyOn(olmlib, "encryptMessageForDevice");
encryptMessageForDevice.mockImplementation(async (...[result, , , , , , payload]) => {
result.plaintext = { type: 0, body: JSON.stringify(payload) };

View File

@ -34,6 +34,7 @@ import { ClientEvent, MatrixClient, RoomMember } from "../../../../src";
import { DeviceInfo, IDevice } from "../../../../src/crypto/deviceinfo";
import { DeviceTrustLevel } from "../../../../src/crypto/CrossSigning";
import { MegolmEncryption as MegolmEncryptionClass } from "../../../../src/crypto/algorithms/megolm";
import { recursiveMapToObject } from "../../../../src/utils";
import { sleep } from "../../../../src/utils";
const MegolmDecryption = algorithms.DECRYPTION_CLASSES.get("m.megolm.v1.aes-sha2")!;
@ -183,14 +184,22 @@ describe("MegolmDecryption", function () {
const deviceInfo = {} as DeviceInfo;
mockCrypto.getStoredDevice.mockReturnValue(deviceInfo);
mockOlmLib.ensureOlmSessionsForDevices.mockResolvedValue({
"@alice:foo": {
alidevice: {
mockOlmLib.ensureOlmSessionsForDevices.mockResolvedValue(
new Map([
[
"@alice:foo",
new Map([
[
"alidevice",
{
sessionId: "alisession",
device: new DeviceInfo("alidevice"),
},
},
});
],
]),
],
]),
);
const awaitEncryptForDevice = new Promise<void>((res, rej) => {
mockOlmLib.encryptMessageForDevice.mockImplementation(() => {
@ -357,11 +366,7 @@ describe("MegolmDecryption", function () {
} as unknown as DeviceInfo;
mockCrypto.downloadKeys.mockReturnValue(
Promise.resolve({
"@alice:home.server": {
aliceDevice: aliceDeviceInfo,
},
}),
Promise.resolve(new Map([["@alice:home.server", new Map([["aliceDevice", aliceDeviceInfo]])]])),
);
mockCrypto.checkDeviceTrust.mockReturnValue({
@ -523,23 +528,32 @@ describe("MegolmDecryption", function () {
let megolm: MegolmEncryptionClass;
let room: jest.Mocked<Room>;
const deviceMap: DeviceInfoMap = {
"user-a": {
"device-a": new DeviceInfo("device-a"),
"device-b": new DeviceInfo("device-b"),
"device-c": new DeviceInfo("device-c"),
},
"user-b": {
"device-d": new DeviceInfo("device-d"),
"device-e": new DeviceInfo("device-e"),
"device-f": new DeviceInfo("device-f"),
},
"user-c": {
"device-g": new DeviceInfo("device-g"),
"device-h": new DeviceInfo("device-h"),
"device-i": new DeviceInfo("device-i"),
},
};
const deviceMap: DeviceInfoMap = new Map([
[
"user-a",
new Map([
["device-a", new DeviceInfo("device-a")],
["device-b", new DeviceInfo("device-b")],
["device-c", new DeviceInfo("device-c")],
]),
],
[
"user-b",
new Map([
["device-d", new DeviceInfo("device-d")],
["device-e", new DeviceInfo("device-e")],
["device-f", new DeviceInfo("device-f")],
]),
],
[
"user-c",
new Map([
["device-g", new DeviceInfo("device-g")],
["device-h", new DeviceInfo("device-h")],
["device-i", new DeviceInfo("device-i")],
]),
],
]);
beforeEach(() => {
room = testUtils.mock(Room, "Room") as jest.Mocked<Room>;
@ -572,8 +586,8 @@ describe("MegolmDecryption", function () {
//@ts-ignore private member access, gross
await megolm.encryptionPreparation?.promise;
for (const userId in deviceMap) {
for (const deviceId in deviceMap[userId]) {
for (const [userId, devices] of deviceMap) {
for (const deviceId of devices.keys()) {
expect(mockCrypto.checkDeviceTrust).toHaveBeenCalledWith(userId, deviceId);
}
}
@ -658,20 +672,20 @@ describe("MegolmDecryption", function () {
expect(aliceClient.sendToDevice).toHaveBeenCalled();
const [msgtype, contentMap] = mocked(aliceClient.sendToDevice).mock.calls[0];
expect(msgtype).toMatch(/^(org.matrix|m).room_key.withheld$/);
delete contentMap["@bob:example.com"].bobdevice1.session_id;
delete contentMap["@bob:example.com"].bobdevice1["org.matrix.msgid"];
delete contentMap["@bob:example.com"].bobdevice2.session_id;
delete contentMap["@bob:example.com"].bobdevice2["org.matrix.msgid"];
expect(contentMap).toStrictEqual({
"@bob:example.com": {
bobdevice1: {
delete contentMap.get("@bob:example.com")?.get("bobdevice1")?.["session_id"];
delete contentMap.get("@bob:example.com")?.get("bobdevice1")?.["org.matrix.msgid"];
delete contentMap.get("@bob:example.com")?.get("bobdevice2")?.["session_id"];
delete contentMap.get("@bob:example.com")?.get("bobdevice2")?.["org.matrix.msgid"];
expect(recursiveMapToObject(contentMap)).toStrictEqual({
["@bob:example.com"]: {
["bobdevice1"]: {
algorithm: "m.megolm.v1.aes-sha2",
room_id: roomId,
code: "m.unverified",
reason: "The sender has disabled encrypting to unverified devices.",
sender_key: aliceDevice.deviceCurve25519Key,
},
bobdevice2: {
["bobdevice2"]: {
algorithm: "m.megolm.v1.aes-sha2",
room_id: roomId,
code: "m.blacklisted",
@ -839,10 +853,10 @@ describe("MegolmDecryption", function () {
expect(aliceClient.sendToDevice).toHaveBeenCalled();
const [msgtype, contentMap] = mocked(aliceClient.sendToDevice).mock.calls[0];
expect(msgtype).toMatch(/^(org.matrix|m).room_key.withheld$/);
delete contentMap["@bob:example.com"]["bobdevice"]["org.matrix.msgid"];
expect(contentMap).toStrictEqual({
"@bob:example.com": {
bobdevice: {
delete contentMap.get("@bob:example.com")?.get("bobdevice")?.["org.matrix.msgid"];
expect(recursiveMapToObject(contentMap)).toStrictEqual({
["@bob:example.com"]: {
["bobdevice"]: {
algorithm: "m.megolm.v1.aes-sha2",
code: "m.no_olm",
reason: "Unable to establish a secure channel.",

View File

@ -146,8 +146,10 @@ describe("OlmDevice", function () {
});
},
} as unknown as MockedObject<MatrixClient>;
const devicesByUser = {
"@bob:example.com": [
const devicesByUser = new Map([
[
"@bob:example.com",
[
DeviceInfo.fromStorage(
{
keys: {
@ -157,7 +159,8 @@ describe("OlmDevice", function () {
"ABCDEFG",
),
],
};
],
]);
// start two tasks that try to ensure that there's an olm session
const promises = Promise.all([
@ -218,12 +221,8 @@ describe("OlmDevice", function () {
// There's no required ordering of devices per user, so here we
// create two different orderings so that each task reserves a
// device the other task needs before continuing.
const devicesByUserAB = {
"@bob:example.com": [deviceBobA, deviceBobB],
};
const devicesByUserBA = {
"@bob:example.com": [deviceBobB, deviceBobA],
};
const devicesByUserAB = new Map([["@bob:example.com", [deviceBobA, deviceBobB]]]);
const devicesByUserBA = new Map([["@bob:example.com", [deviceBobB, deviceBobA]]]);
const task1 = alwaysSucceed(olmlib.ensureOlmSessionsForDevices(aliceOlmDevice, baseApis, devicesByUserAB));

View File

@ -45,7 +45,7 @@ async function makeTestClient(
await client.initCrypto();
// No need to download keys for these tests
jest.spyOn(client.crypto!, "downloadKeys").mockResolvedValue({});
jest.spyOn(client.crypto!, "downloadKeys").mockResolvedValue(new Map());
return client;
}
@ -274,7 +274,7 @@ describe("Secrets", function () {
Object.values(otks)[0],
);
osborne2.client.crypto!.deviceList.downloadKeys = () => Promise.resolve({});
osborne2.client.crypto!.deviceList.downloadKeys = () => Promise.resolve(new Map());
osborne2.client.crypto!.deviceList.getUserByIdentityKey = () => "@alice:example.com";
const request = await secretStorage.request("foo", ["VAX"]);

View File

@ -121,12 +121,12 @@ describe("SAS verification", function () {
alice.client.crypto!.deviceList.storeDevicesForUser("@bob:example.com", BOB_DEVICES);
alice.client.downloadKeys = () => {
return Promise.resolve({});
return Promise.resolve(new Map());
};
bob.client.crypto!.deviceList.storeDevicesForUser("@alice:example.com", ALICE_DEVICES);
bob.client.downloadKeys = () => {
return Promise.resolve({});
return Promise.resolve(new Map());
};
aliceSasEvent = null;
@ -176,6 +176,7 @@ describe("SAS verification", function () {
}
});
});
afterEach(async () => {
await Promise.all([alice.stop(), bob.stop()]);
@ -186,10 +187,14 @@ describe("SAS verification", function () {
let macMethod;
let keyAgreement;
const origSendToDevice = bob.client.sendToDevice.bind(bob.client);
bob.client.sendToDevice = function (type, map) {
bob.client.sendToDevice = async (type, map) => {
if (type === "m.key.verification.accept") {
macMethod = map[alice.client.getUserId()!][alice.client.deviceId!].message_authentication_code;
keyAgreement = map[alice.client.getUserId()!][alice.client.deviceId!].key_agreement_protocol;
macMethod = map
.get(alice.client.getUserId()!)
?.get(alice.client.deviceId!)?.message_authentication_code;
keyAgreement = map
.get(alice.client.getUserId()!)
?.get(alice.client.deviceId!)?.key_agreement_protocol;
}
return origSendToDevice(type, map);
};
@ -237,7 +242,7 @@ describe("SAS verification", function () {
// has, since it is the same object. If this does not
// happen, the verification will fail due to a hash
// commitment mismatch.
map[bob.client.getUserId()!][bob.client.deviceId!].message_authentication_codes = [
map.get(bob.client.getUserId()!)!.get(bob.client.deviceId!)!.message_authentication_codes = [
"hkdf-hmac-sha256",
];
}
@ -246,7 +251,9 @@ describe("SAS verification", function () {
const bobOrigSendToDevice = bob.client.sendToDevice.bind(bob.client);
bob.client.sendToDevice = (type, map) => {
if (type === "m.key.verification.accept") {
macMethod = map[alice.client.getUserId()!][alice.client.deviceId!].message_authentication_code;
macMethod = map
.get(alice.client.getUserId()!)!
.get(alice.client.deviceId!)!.message_authentication_code;
}
return bobOrigSendToDevice(type, map);
};
@ -291,14 +298,18 @@ describe("SAS verification", function () {
// has, since it is the same object. If this does not
// happen, the verification will fail due to a hash
// commitment mismatch.
map[bob.client.getUserId()!][bob.client.deviceId!].message_authentication_codes = ["hmac-sha256"];
map.get(bob.client.getUserId()!)!.get(bob.client.deviceId!)!.message_authentication_codes = [
"hmac-sha256",
];
}
return aliceOrigSendToDevice(type, map);
};
const bobOrigSendToDevice = bob.client.sendToDevice.bind(bob.client);
bob.client.sendToDevice = (type, map) => {
if (type === "m.key.verification.accept") {
macMethod = map[alice.client.getUserId()!][alice.client.deviceId!].message_authentication_code;
macMethod = map
.get(alice.client.getUserId()!)!
.get(alice.client.deviceId!)!.message_authentication_code;
}
return bobOrigSendToDevice(type, map);
};
@ -454,7 +465,7 @@ describe("SAS verification", function () {
);
};
alice.client.downloadKeys = () => {
return Promise.resolve({});
return Promise.resolve(new Map());
};
bob.client.crypto!.setDeviceVerification = jest.fn();
@ -472,7 +483,7 @@ describe("SAS verification", function () {
return "bob+base64+ed25519+key";
};
bob.client.downloadKeys = () => {
return Promise.resolve({});
return Promise.resolve(new Map());
};
aliceSasEvent = null;

View File

@ -20,7 +20,7 @@ import { IContent, MatrixEvent } from "../../../../src/models/event";
import { IRoomTimelineData } from "../../../../src/models/event-timeline-set";
import { Room, RoomEvent } from "../../../../src/models/room";
import { logger } from "../../../../src/logger";
import { MatrixClient, ClientEvent, ICreateClientOpts } from "../../../../src/client";
import { MatrixClient, ClientEvent, ICreateClientOpts, SendToDeviceContentMap } from "../../../../src/client";
interface UserInfo {
userId: string;
@ -36,16 +36,16 @@ export async function makeTestClients(
const clientMap: Record<string, Record<string, MatrixClient>> = {};
const makeSendToDevice =
(matrixClient: MatrixClient): MatrixClient["sendToDevice"] =>
async (type, map) => {
async (type: string, contentMap: SendToDeviceContentMap) => {
// logger.log(this.getUserId(), "sends", type, map);
for (const [userId, devMap] of Object.entries(map)) {
for (const [userId, deviceMessages] of contentMap) {
if (userId in clientMap) {
for (const [deviceId, msg] of Object.entries(devMap)) {
for (const [deviceId, message] of deviceMessages) {
if (deviceId in clientMap[userId]) {
const event = new MatrixEvent({
sender: matrixClient.getUserId()!,
type: type,
content: msg,
content: message,
});
const client = clientMap[userId][deviceId];
const decryptionPromise = event.isEncrypted()

View File

@ -25,6 +25,7 @@ import { IContent, MatrixEvent } from "../../../../src/models/event";
import { MatrixClient } from "../../../../src/client";
import { IVerificationChannel } from "../../../../src/crypto/verification/request/Channel";
import { VerificationBase } from "../../../../src/crypto/verification/Base";
import { MapWithDefault } from "../../../../src/utils";
type MockClient = MatrixClient & {
popEvents: () => MatrixEvent[];
@ -33,7 +34,9 @@ type MockClient = MatrixClient & {
function makeMockClient(userId: string, deviceId: string): MockClient {
let counter = 1;
let events: MatrixEvent[] = [];
const deviceEvents: Record<string, Record<string, MatrixEvent[]>> = {};
const deviceEvents: MapWithDefault<string, MapWithDefault<string, MatrixEvent[]>> = new MapWithDefault(
() => new MapWithDefault(() => []),
);
return {
getUserId() {
return userId;
@ -58,15 +61,11 @@ function makeMockClient(userId: string, deviceId: string): MockClient {
return Promise.resolve({ event_id: eventId });
},
sendToDevice(type: string, msgMap: Record<string, Record<string, IContent>>) {
for (const userId of Object.keys(msgMap)) {
const deviceMap = msgMap[userId];
for (const deviceId of Object.keys(deviceMap)) {
const content = deviceMap[deviceId];
sendToDevice(type: string, msgMap: Map<string, Map<string, IContent>>) {
for (const [userId, deviceMessages] of msgMap) {
for (const [deviceId, content] of deviceMessages) {
const event = new MatrixEvent({ content, type });
deviceEvents[userId] = deviceEvents[userId] || {};
deviceEvents[userId][deviceId] = deviceEvents[userId][deviceId] || [];
deviceEvents[userId][deviceId].push(event);
deviceEvents.getOrCreate(userId).getOrCreate(deviceId).push(event);
}
}
return Promise.resolve({});
@ -79,14 +78,9 @@ function makeMockClient(userId: string, deviceId: string): MockClient {
return e;
},
// @ts-ignore special testing fn
popDeviceEvents(userId: string, deviceId: string): MatrixEvent[] {
const forDevice = deviceEvents[userId];
const events = forDevice && forDevice[deviceId];
const result = events || [];
if (events) {
delete forDevice[deviceId];
}
const result = deviceEvents.get(userId)?.get(deviceId) || [];
deviceEvents?.get(userId)?.delete(deviceId);
return result;
},
} as unknown as MockClient;

View File

@ -204,9 +204,14 @@ describe("RoomWidgetClient", () => {
});
describe("to-device messages", () => {
const unencryptedContentMap = {
"@alice:example.org": { "*": { hello: "alice!" } },
"@bob:example.org": { bobDesktop: { hello: "bob!" } },
const unencryptedContentMap = new Map([
["@alice:example.org", new Map([["*", { hello: "alice!" }]])],
["@bob:example.org", new Map([["bobDesktop", { hello: "bob!" }]])],
]);
const expectedRequestData = {
["@alice:example.org"]: { ["*"]: { hello: "alice!" } },
["@bob:example.org"]: { ["bobDesktop"]: { hello: "bob!" } },
};
it("sends unencrypted (sendToDevice)", async () => {
@ -214,7 +219,7 @@ describe("RoomWidgetClient", () => {
expect(widgetApi.requestCapabilityToSendToDevice).toHaveBeenCalledWith("org.example.foo");
await client.sendToDevice("org.example.foo", unencryptedContentMap);
expect(widgetApi.sendToDevice).toHaveBeenCalledWith("org.example.foo", false, unencryptedContentMap);
expect(widgetApi.sendToDevice).toHaveBeenCalledWith("org.example.foo", false, expectedRequestData);
});
it("sends unencrypted (queueToDevice)", async () => {
@ -229,7 +234,7 @@ describe("RoomWidgetClient", () => {
],
};
await client.queueToDevice(batch);
expect(widgetApi.sendToDevice).toHaveBeenCalledWith("org.example.foo", false, unencryptedContentMap);
expect(widgetApi.sendToDevice).toHaveBeenCalledWith("org.example.foo", false, expectedRequestData);
});
it("sends encrypted (encryptAndSendToDevices)", async () => {

View File

@ -59,7 +59,7 @@ describe("MemoryStore", () => {
await store.deleteAllData();
// empty object
expect(store.accountData).toEqual({});
expect(store.accountData).toEqual(new Map());
});
});
});

View File

@ -24,9 +24,12 @@ import {
lexicographicCompare,
nextString,
prevString,
recursiveMapToObject,
simpleRetryOperation,
stringToBase,
sortEventsByLatestContentTimestamp,
safeSet,
MapWithDefault,
} from "../../src/utils";
import { logger } from "../../src/logger";
import { mkMessage } from "../test-utils/test-utils";
@ -606,6 +609,105 @@ describe("utils", function () {
});
});
describe("recursiveMapToObject", () => {
it.each([
// empty map
{
map: new Map(),
expected: {},
},
// one level map
{
map: new Map<any, any>([
["key1", "value 1"],
["key2", 23],
["key3", undefined],
["key4", null],
["key5", [1, 2, 3]],
]),
expected: { key1: "value 1", key2: 23, key3: undefined, key4: null, key5: [1, 2, 3] },
},
// two level map
{
map: new Map<any, any>([
[
"key1",
new Map<any, any>([
["key1_1", "value 1"],
["key1_2", "value 1.2"],
]),
],
["key2", "value 2"],
]),
expected: { key1: { key1_1: "value 1", key1_2: "value 1.2" }, key2: "value 2" },
},
// multi level map
{
map: new Map<any, any>([
["key1", new Map<any, any>([["key1_1", new Map<any, any>([["key1_1_1", "value 1.1.1"]])]])],
]),
expected: { key1: { key1_1: { key1_1_1: "value 1.1.1" } } },
},
// list of maps
{
map: new Map<any, any>([
[
"key1",
[new Map<any, any>([["key1_1", "value 1.1"]]), new Map<any, any>([["key1_2", "value 1.2"]])],
],
]),
expected: { key1: [{ key1_1: "value 1.1" }, { key1_2: "value 1.2" }] },
},
// map → array → array → map
{
map: new Map<any, any>([["key1", [[new Map<any, any>([["key2", "value 2"]])]]]]),
expected: {
key1: [
[
{
key2: "value 2",
},
],
],
},
},
])("%# should convert the value", ({ map, expected }) => {
expect(recursiveMapToObject(map)).toStrictEqual(expected);
});
});
describe("safeSet", () => {
it("should set a value", () => {
const obj = {};
safeSet(obj, "testProp", "test value");
expect(obj).toEqual({ testProp: "test value" });
});
it.each(["__proto__", "prototype", "constructor"])("should raise an error when setting »%s«", (prop) => {
expect(() => {
safeSet({}, prop, "teset value");
}).toThrow("Trying to modify prototype or constructor");
});
});
describe("MapWithDefault", () => {
it("getOrCreate should create the value if it does not exist", () => {
const newValue = {};
const map = new MapWithDefault(() => newValue);
// undefined before getOrCreate
expect(map.get("test")).toBeUndefined();
expect(map.getOrCreate("test")).toBe(newValue);
// default value after getOrCreate
expect(map.get("test")).toBe(newValue);
// test that it always returns the same value
expect(map.getOrCreate("test")).toBe(newValue);
});
});
describe("sleep", () => {
it("resolves", async () => {
await utils.sleep(0);

View File

@ -688,15 +688,15 @@ describe("Group Call", function () {
expect(client1.sendToDevice.mock.calls[0][0]).toBe("m.call.invite");
const toDeviceCallContent = client1.sendToDevice.mock.calls[0][1];
expect(Object.keys(toDeviceCallContent).length).toBe(1);
expect(Object.keys(toDeviceCallContent)[0]).toBe(FAKE_USER_ID_2);
expect(toDeviceCallContent.size).toBe(1);
expect(toDeviceCallContent.has(FAKE_USER_ID_2)).toBe(true);
const toDeviceBobDevices = toDeviceCallContent[FAKE_USER_ID_2];
expect(Object.keys(toDeviceBobDevices).length).toBe(1);
expect(Object.keys(toDeviceBobDevices)[0]).toBe(FAKE_DEVICE_ID_2);
const toDeviceBobDevices = toDeviceCallContent.get(FAKE_USER_ID_2);
expect(toDeviceBobDevices?.size).toBe(1);
expect(toDeviceBobDevices?.has(FAKE_DEVICE_ID_2)).toBe(true);
const bobDeviceMessage = toDeviceBobDevices[FAKE_DEVICE_ID_2];
expect(bobDeviceMessage.conf_id).toBe(FAKE_CONF_ID);
const bobDeviceMessage = toDeviceBobDevices?.get(FAKE_DEVICE_ID_2);
expect(bobDeviceMessage?.conf_id).toBe(FAKE_CONF_ID);
} finally {
await Promise.all([groupCall1.leave(), groupCall2.leave()]);
}

View File

@ -38,7 +38,7 @@ export interface CachedReceipt {
data: Receipt;
}
export type ReceiptCache = { [eventId: string]: CachedReceipt[] };
export type ReceiptCache = Map<string, CachedReceipt[]>;
export interface ReceiptContent {
[eventId: string]: {
@ -49,11 +49,8 @@ export interface ReceiptContent {
}
// We will only hold a synthetic receipt if we do not have a real receipt or the synthetic is newer.
export type Receipts = {
[receiptType: string]: {
[userId: string]: [WrappedReceipt | null, WrappedReceipt | null]; // Pair<real receipt, synthetic receipt> (both nullable)
};
};
// map: receipt type → user Id → receipt
export type Receipts = Map<string, Map<string, [real: WrappedReceipt | null, synthetic: WrappedReceipt | null]>>;
export type CachedReceiptStructure = {
eventId: string;

View File

@ -21,6 +21,7 @@ import { MatrixError } from "./http-api";
import { IndexedToDeviceBatch, ToDeviceBatch, ToDeviceBatchWithTxnId, ToDevicePayload } from "./models/ToDeviceMessage";
import { MatrixScheduler } from "./scheduler";
import { SyncState } from "./sync";
import { MapWithDefault } from "./utils";
const MAX_BATCH_SIZE = 20;
@ -122,12 +123,9 @@ export class ToDeviceMessageQueue {
* Attempts to send a batch of to-device messages.
*/
private async sendBatch(batch: IndexedToDeviceBatch): Promise<void> {
const contentMap: Record<string, Record<string, ToDevicePayload>> = {};
const contentMap: MapWithDefault<string, Map<string, ToDevicePayload>> = new MapWithDefault(() => new Map());
for (const item of batch.batch) {
if (!contentMap[item.userId]) {
contentMap[item.userId] = {};
}
contentMap[item.userId][item.deviceId] = item.payload;
contentMap.getOrCreate(item.userId).set(item.deviceId, item.payload);
}
logger.info(

View File

@ -37,7 +37,7 @@ import { Filter, IFilterDefinition, IRoomEventFilter } from "./filter";
import { CallEventHandlerEvent, CallEventHandler, CallEventHandlerEventHandlerMap } from "./webrtc/callEventHandler";
import { GroupCallEventHandlerEvent, GroupCallEventHandlerEventHandlerMap } from "./webrtc/groupCallEventHandler";
import * as utils from "./utils";
import { replaceParam, QueryDict, sleep } from "./utils";
import { replaceParam, QueryDict, sleep, noUnsafeEventProps } from "./utils";
import { Direction, EventTimeline } from "./models/event-timeline";
import { IActionsObject, PushProcessor } from "./pushprocessor";
import { AutoDiscovery, AutoDiscoveryAction } from "./autodiscovery";
@ -79,7 +79,7 @@ import {
VerificationMethod,
IRoomKeyRequestBody,
} from "./crypto";
import { DeviceInfo, IDevice } from "./crypto/deviceinfo";
import { DeviceInfo } from "./crypto/deviceinfo";
import { decodeRecoveryKey } from "./crypto/recoverykey";
import { keyFromAuthData } from "./crypto/key_passphrase";
import { User, UserEvent, UserEventHandlerMap } from "./models/user";
@ -207,6 +207,7 @@ import { buildFeatureSupportMap, Feature, ServerSupport } from "./feature";
import { CryptoBackend } from "./common-crypto/CryptoBackend";
import { RUST_SDK_STORE_PREFIX } from "./rust-crypto/constants";
import { CryptoApi } from "./crypto-api";
import { DeviceInfoMap } from "./crypto/DeviceList";
export type Store = IStore;
@ -511,6 +512,8 @@ enum CrossSigningKeyType {
export type CrossSigningKeys = Record<CrossSigningKeyType, ICrossSigningKey>;
export type SendToDeviceContentMap = Map<string, Map<string, Record<string, any>>>;
export interface ISignedKey {
keys: Record<string, string>;
signatures: ISignatures;
@ -2268,7 +2271,7 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
*
* @returns A promise which resolves to a map userId-\>deviceId-\>{@link DeviceInfo}
*/
public downloadKeys(userIds: string[], forceDownload?: boolean): Promise<Record<string, Record<string, IDevice>>> {
public downloadKeys(userIds: string[], forceDownload?: boolean): Promise<DeviceInfoMap> {
if (!this.crypto) {
return Promise.reject(new Error("End-to-end encryption disabled"));
}
@ -3818,9 +3821,9 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
}
const deviceInfos = await this.crypto.downloadKeys(userIds);
const devicesByUser: Record<string, DeviceInfo[]> = {};
for (const [userId, devices] of Object.entries(deviceInfos)) {
devicesByUser[userId] = Object.values(devices);
const devicesByUser: Map<string, DeviceInfo[]> = new Map();
for (const [userId, devices] of deviceInfos) {
devicesByUser.set(userId, Array.from(devices.values()));
}
// XXX: Private member access
@ -6035,6 +6038,8 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
const token = res.next_token;
const matrixEvents: MatrixEvent[] = [];
res.notifications = res.notifications.filter(noUnsafeEventProps);
for (let i = 0; i < res.notifications.length; i++) {
const notification = res.notifications[i];
const event = this.getEventMapper()(notification.event);
@ -6081,11 +6086,11 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
.then((res) => {
if (res.state) {
const roomState = eventTimeline.getState(dir)!;
const stateEvents = res.state.map(this.getEventMapper());
const stateEvents = res.state.filter(noUnsafeEventProps).map(this.getEventMapper());
roomState.setUnknownStateEvents(stateEvents);
}
const token = res.end;
const matrixEvents = res.chunk.map(this.getEventMapper());
const matrixEvents = res.chunk.filter(noUnsafeEventProps).map(this.getEventMapper());
const timelineSet = eventTimeline.getTimelineSet();
timelineSet.addEventsToTimeline(matrixEvents, backwards, eventTimeline, token);
@ -6117,7 +6122,7 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
})
.then(async (res) => {
const mapper = this.getEventMapper();
const matrixEvents = res.chunk.map(mapper);
const matrixEvents = res.chunk.filter(noUnsafeEventProps).map(mapper);
// Process latest events first
for (const event of matrixEvents.slice().reverse()) {
@ -6165,11 +6170,11 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
.then((res) => {
if (res.state) {
const roomState = eventTimeline.getState(dir)!;
const stateEvents = res.state.map(this.getEventMapper());
const stateEvents = res.state.filter(noUnsafeEventProps).map(this.getEventMapper());
roomState.setUnknownStateEvents(stateEvents);
}
const token = res.end;
const matrixEvents = res.chunk.map(this.getEventMapper());
const matrixEvents = res.chunk.filter(noUnsafeEventProps).map(this.getEventMapper());
const timelineSet = eventTimeline.getTimelineSet();
const [timelineEvents] = room.partitionThreadedEvents(matrixEvents);
@ -9187,24 +9192,22 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
* supplied.
* @returns Promise which resolves: to an empty object `{}`
*/
public sendToDevice(
eventType: string,
contentMap: { [userId: string]: { [deviceId: string]: Record<string, any> } },
txnId?: string,
): Promise<{}> {
public sendToDevice(eventType: string, contentMap: SendToDeviceContentMap, txnId?: string): Promise<{}> {
const path = utils.encodeUri("/sendToDevice/$eventType/$txnId", {
$eventType: eventType,
$txnId: txnId ? txnId : this.makeTxnId(),
});
const body = {
messages: contentMap,
messages: utils.recursiveMapToObject(contentMap),
};
const targets = Object.keys(contentMap).reduce<Record<string, string[]>>((obj, key) => {
obj[key] = Object.keys(contentMap[key]);
return obj;
}, {});
const targets = new Map<string, string[]>();
for (const [userId, deviceMessages] of contentMap) {
targets.set(userId, Array.from(deviceMessages.keys()));
}
logger.log(`PUT ${path}`, targets);
return this.http.authedRequest(Method.Put, path, undefined, body);

View File

@ -58,7 +58,8 @@ export enum TrackingStatus {
UpToDate,
}
export type DeviceInfoMap = Record<string, Record<string, DeviceInfo>>;
// user-Id → device-Id → DeviceInfo
export type DeviceInfoMap = Map<string, Map<string, DeviceInfo>>;
type EmittedEvents = CryptoEvent.WillUpdateDevices | CryptoEvent.DevicesUpdated | CryptoEvent.UserCrossSigningUpdated;
@ -301,13 +302,13 @@ export class DeviceList extends TypedEventEmitter<EmittedEvents, CryptoEventHand
* @returns userId-\>deviceId-\>{@link DeviceInfo}.
*/
private getDevicesFromStore(userIds: string[]): DeviceInfoMap {
const stored: DeviceInfoMap = {};
userIds.forEach((u) => {
stored[u] = {};
const devices = this.getStoredDevicesForUser(u) || [];
devices.forEach(function (dev) {
stored[u][dev.deviceId] = dev;
const stored: DeviceInfoMap = new Map();
userIds.forEach((userId) => {
const deviceMap = new Map();
this.getStoredDevicesForUser(userId)?.forEach(function (device) {
deviceMap.set(device.deviceId, device);
});
stored.set(userId, deviceMap);
});
return stored;
}

View File

@ -61,7 +61,7 @@ export class EncryptionSetupBuilder {
* @param accountData - pre-existing account data, will only be read, not written.
* @param delegateCryptoCallbacks - crypto callbacks to delegate to if the key isn't in cache yet
*/
public constructor(accountData: Record<string, MatrixEvent>, delegateCryptoCallbacks?: ICryptoCallbacks) {
public constructor(accountData: Map<string, MatrixEvent>, delegateCryptoCallbacks?: ICryptoCallbacks) {
this.accountDataClientAdapter = new AccountDataClientAdapter(accountData);
this.crossSigningCallbacks = new CrossSigningCallbacks();
this.ssssCryptoCallbacks = new SSSSCryptoCallbacks(delegateCryptoCallbacks);
@ -246,7 +246,7 @@ class AccountDataClientAdapter
/**
* @param existingValues - existing account data
*/
public constructor(private readonly existingValues: Record<string, MatrixEvent>) {
public constructor(private readonly existingValues: Map<string, MatrixEvent>) {
super();
}
@ -265,7 +265,7 @@ class AccountDataClientAdapter
if (modifiedValue) {
return modifiedValue;
}
const existingValue = this.existingValues[type];
const existingValue = this.existingValues.get(type);
if (existingValue) {
return existingValue.getContent();
}

View File

@ -21,6 +21,7 @@ import { MatrixClient } from "../client";
import { IRoomKeyRequestBody, IRoomKeyRequestRecipient } from "./index";
import { CryptoStore, OutgoingRoomKeyRequest } from "./store/base";
import { EventType, ToDeviceMessageId } from "../@types/event";
import { MapWithDefault } from "../utils";
/**
* Internal module. Management of outgoing room key requests.
@ -460,15 +461,13 @@ export class OutgoingRoomKeyRequestManager {
recipients: IRoomKeyRequestRecipient[],
txnId?: string,
): Promise<{}> {
const contentMap: Record<string, Record<string, Record<string, any>>> = {};
const contentMap = new MapWithDefault<string, Map<string, Record<string, any>>>(() => new Map());
for (const recip of recipients) {
if (!contentMap[recip.userId]) {
contentMap[recip.userId] = {};
}
contentMap[recip.userId][recip.deviceId] = {
const userDeviceMap = contentMap.getOrCreate(recip.userId);
userDeviceMap.set(recip.deviceId, {
...message,
[ToDeviceMessageId]: uuidv4(),
};
});
}
return this.baseApis.sendToDevice(EventType.RoomKeyRequest, contentMap, txnId);

View File

@ -367,13 +367,11 @@ export class SecretStorage<B extends MatrixClient | undefined = MatrixClient> {
requesting_device_id: this.baseApis.deviceId,
request_id: requestId,
};
const toDevice: Record<string, typeof cancelData> = {};
const toDevice: Map<string, typeof cancelData> = new Map();
for (const device of devices) {
toDevice[device] = cancelData;
toDevice.set(device, cancelData);
}
this.baseApis.sendToDevice("m.secret.request", {
[this.baseApis.getUserId()!]: toDevice,
});
this.baseApis.sendToDevice("m.secret.request", new Map([[this.baseApis.getUserId()!, toDevice]]));
// and reject the promise so that anyone waiting on it will be
// notified
@ -388,14 +386,12 @@ export class SecretStorage<B extends MatrixClient | undefined = MatrixClient> {
request_id: requestId,
[ToDeviceMessageId]: uuidv4(),
};
const toDevice: Record<string, typeof requestData> = {};
const toDevice: Map<string, typeof requestData> = new Map();
for (const device of devices) {
toDevice[device] = requestData;
toDevice.set(device, requestData);
}
logger.info(`Request secret ${name} from ${devices}, id ${requestId}`);
this.baseApis.sendToDevice("m.secret.request", {
[this.baseApis.getUserId()!]: toDevice,
});
this.baseApis.sendToDevice("m.secret.request", new Map([[this.baseApis.getUserId()!, toDevice]]));
return {
requestId,
@ -469,9 +465,11 @@ export class SecretStorage<B extends MatrixClient | undefined = MatrixClient> {
ciphertext: {},
[ToDeviceMessageId]: uuidv4(),
};
await olmlib.ensureOlmSessionsForDevices(this.baseApis.crypto!.olmDevice, this.baseApis, {
[sender]: [this.baseApis.getStoredDevice(sender, deviceId)!],
});
await olmlib.ensureOlmSessionsForDevices(
this.baseApis.crypto!.olmDevice,
this.baseApis,
new Map([[sender, [this.baseApis.getStoredDevice(sender, deviceId)!]]]),
);
await olmlib.encryptMessageForDevice(
encryptedContent.ciphertext,
this.baseApis.getUserId()!,
@ -481,11 +479,7 @@ export class SecretStorage<B extends MatrixClient | undefined = MatrixClient> {
this.baseApis.getStoredDevice(sender, deviceId)!,
payload,
);
const contentMap = {
[sender]: {
[deviceId]: encryptedContent,
},
};
const contentMap = new Map([[sender, new Map([[deviceId, encryptedContent]])]]);
logger.info(`Sending ${content.name} secret for ${deviceId}`);
this.baseApis.sendToDevice("m.room.encrypted", contentMap);

View File

@ -26,6 +26,7 @@ import { IContent, MatrixEvent, RoomMember } from "../../matrix";
import { Crypto, IEncryptedContent, IEventDecryptionResult, IncomingRoomKeyRequest } from "..";
import { DeviceInfo } from "../deviceinfo";
import { IRoomEncryption } from "../RoomList";
import { DeviceInfoMap } from "../DeviceList";
/**
* Map of registered encryption algorithm classes. A map from string to {@link EncryptionAlgorithm} class
@ -195,7 +196,7 @@ export abstract class DecryptionAlgorithm {
}
public onRoomKeyWithheldEvent?(event: MatrixEvent): Promise<void>;
public sendSharedHistoryInboundSessions?(devicesByUser: Record<string, DeviceInfo[]>): Promise<void>;
public sendSharedHistoryInboundSessions?(devicesByUser: Map<string, DeviceInfo[]>): Promise<void>;
}
/**
@ -241,11 +242,7 @@ export class UnknownDeviceError extends Error {
* @param msg - message describing the problem
* @param devices - set of unknown devices per user we're warning about
*/
public constructor(
msg: string,
public readonly devices: Record<string, Record<string, object>>,
public event?: MatrixEvent,
) {
public constructor(msg: string, public readonly devices: DeviceInfoMap, public event?: MatrixEvent) {
super(msg);
this.name = "UnknownDeviceError";
this.devices = devices;

View File

@ -43,7 +43,7 @@ import { IMegolmEncryptedContent, IncomingRoomKeyRequest, IEncryptedContent } fr
import { RoomKeyRequestState } from "../OutgoingRoomKeyRequestManager";
import { OlmGroupSessionExtraData } from "../../@types/crypto";
import { MatrixError } from "../../http-api";
import { immediate } from "../../utils";
import { immediate, MapWithDefault } from "../../utils";
// determine whether the key can be shared with invitees
export function isRoomSharedHistory(room: Room): boolean {
@ -63,17 +63,27 @@ interface IBlockedDevice {
deviceInfo: DeviceInfo;
}
interface IBlockedMap {
[userId: string]: {
[deviceId: string]: IBlockedDevice;
};
}
// map user Id → device Id → IBlockedDevice
type BlockedMap = Map<string, Map<string, IBlockedDevice>>;
export interface IOlmDevice<T = DeviceInfo> {
userId: string;
deviceInfo: T;
}
/**
* Tests whether an encrypted content has a ciphertext.
* Ciphertext can be a string or object depending on the content type {@link IEncryptedContent}.
*
* @param content - Encrypted content
* @returns true: has ciphertext, else false
*/
const hasCiphertext = (content: IEncryptedContent): boolean => {
return typeof content.ciphertext === "string"
? !!content.ciphertext.length
: !!Object.keys(content.ciphertext).length;
};
/** The result of parsing the an `m.room_key` or `m.forwarded_room_key` to-device event */
interface RoomKey {
/**
@ -147,8 +157,8 @@ class OutboundSessionInfo {
/** when the session was created (ms since the epoch) */
public creationTime: number;
/** devices with which we have shared the session key `userId -> {deviceId -> SharedWithData}` */
public sharedWithDevices: Record<string, Record<string, SharedWithData>> = {};
public blockedDevicesNotified: Record<string, Record<string, boolean>> = {};
public sharedWithDevices: MapWithDefault<string, Map<string, SharedWithData>> = new MapWithDefault(() => new Map());
public blockedDevicesNotified: MapWithDefault<string, Map<string, boolean>> = new MapWithDefault(() => new Map());
/**
* @param sharedHistory - whether the session can be freely shared with
@ -173,17 +183,11 @@ class OutboundSessionInfo {
}
public markSharedWithDevice(userId: string, deviceId: string, deviceKey: string, chainIndex: number): void {
if (!this.sharedWithDevices[userId]) {
this.sharedWithDevices[userId] = {};
}
this.sharedWithDevices[userId][deviceId] = { deviceKey, messageIndex: chainIndex };
this.sharedWithDevices.getOrCreate(userId).set(deviceId, { deviceKey, messageIndex: chainIndex });
}
public markNotifiedBlockedDevice(userId: string, deviceId: string): void {
if (!this.blockedDevicesNotified[userId]) {
this.blockedDevicesNotified[userId] = {};
}
this.blockedDevicesNotified[userId][deviceId] = true;
this.blockedDevicesNotified.getOrCreate(userId).set(deviceId, true);
}
/**
@ -196,23 +200,15 @@ class OutboundSessionInfo {
* @returns true if we have shared the session with devices which aren't
* in devicesInRoom.
*/
public sharedWithTooManyDevices(devicesInRoom: Record<string, Record<string, object>>): boolean {
for (const userId in this.sharedWithDevices) {
if (!this.sharedWithDevices.hasOwnProperty(userId)) {
continue;
}
if (!devicesInRoom.hasOwnProperty(userId)) {
public sharedWithTooManyDevices(devicesInRoom: DeviceInfoMap): boolean {
for (const [userId, devices] of this.sharedWithDevices) {
if (!devicesInRoom.has(userId)) {
logger.log("Starting new megolm session because we shared with " + userId);
return true;
}
for (const deviceId in this.sharedWithDevices[userId]) {
if (!this.sharedWithDevices[userId].hasOwnProperty(deviceId)) {
continue;
}
if (!devicesInRoom[userId].hasOwnProperty(deviceId)) {
for (const [deviceId] of devices) {
if (!devicesInRoom.get(userId)?.get(deviceId)) {
logger.log("Starting new megolm session because we shared with " + userId + ":" + deviceId);
return true;
}
@ -292,7 +288,7 @@ export class MegolmEncryption extends EncryptionAlgorithm {
private async ensureOutboundSession(
room: Room,
devicesInRoom: DeviceInfoMap,
blocked: IBlockedMap,
blocked: BlockedMap,
singleOlmCreationPhase = false,
): Promise<OutboundSessionInfo> {
// takes the previous OutboundSessionInfo, and considers whether to create
@ -360,21 +356,21 @@ export class MegolmEncryption extends EncryptionAlgorithm {
devicesInRoom: DeviceInfoMap,
sharedHistory: boolean,
singleOlmCreationPhase: boolean,
blocked: IBlockedMap,
blocked: BlockedMap,
session: OutboundSessionInfo,
): Promise<void> {
// now check if we need to share with any devices
const shareMap: Record<string, DeviceInfo[]> = {};
for (const [userId, userDevices] of Object.entries(devicesInRoom)) {
for (const [deviceId, deviceInfo] of Object.entries(userDevices)) {
for (const [userId, userDevices] of devicesInRoom) {
for (const [deviceId, deviceInfo] of userDevices) {
const key = deviceInfo.getIdentityKey();
if (key == this.olmDevice.deviceCurve25519Key) {
// don't bother sending to ourself
continue;
}
if (!session.sharedWithDevices[userId] || session.sharedWithDevices[userId][deviceId] === undefined) {
if (!session.sharedWithDevices.get(userId)?.get(deviceId)) {
shareMap[userId] = shareMap[userId] || [];
shareMap[userId].push(deviceInfo);
}
@ -402,9 +398,9 @@ export class MegolmEncryption extends EncryptionAlgorithm {
await Promise.all([
(async (): Promise<void> => {
// share keys with devices that we already have a session for
const olmSessionList = Object.entries(olmSessions)
const olmSessionList = Array.from(olmSessions.entries())
.map(([userId, sessionsByUser]) =>
Object.entries(sessionsByUser).map(
Array.from(sessionsByUser.entries()).map(
([deviceId, session]) => `${userId}/${deviceId}: ${session.sessionId}`,
),
)
@ -414,7 +410,7 @@ export class MegolmEncryption extends EncryptionAlgorithm {
this.prefixedLogger.debug("Shared keys with existing Olm sessions");
})(),
(async (): Promise<void> => {
const deviceList = Object.entries(devicesWithoutSession)
const deviceList = Array.from(devicesWithoutSession.entries())
.map(([userId, devicesByUser]) => devicesByUser.map((device) => `${userId}/${device.deviceId}`))
.flat(1);
this.prefixedLogger.debug(
@ -450,7 +446,7 @@ export class MegolmEncryption extends EncryptionAlgorithm {
// do this in the background and don't block anything else while we
// do this. We only need to retry users from servers that didn't
// respond the first time.
const retryDevices: Record<string, DeviceInfo[]> = {};
const retryDevices: MapWithDefault<string, DeviceInfo[]> = new MapWithDefault(() => []);
const failedServerMap = new Set();
for (const server of failedServers) {
failedServerMap.add(server);
@ -459,8 +455,7 @@ export class MegolmEncryption extends EncryptionAlgorithm {
for (const { userId, deviceInfo } of errorDevices) {
const userHS = userId.slice(userId.indexOf(":") + 1);
if (failedServerMap.has(userHS)) {
retryDevices[userId] = retryDevices[userId] || [];
retryDevices[userId].push(deviceInfo);
retryDevices.getOrCreate(userId).push(deviceInfo);
} else {
// if we aren't going to retry, then handle it
// as a failed device
@ -468,7 +463,7 @@ export class MegolmEncryption extends EncryptionAlgorithm {
}
}
const retryDeviceList = Object.entries(retryDevices)
const retryDeviceList = Array.from(retryDevices.entries())
.map(([userId, devicesByUser]) =>
devicesByUser.map((device) => `${userId}/${device.deviceId}`),
)
@ -493,25 +488,25 @@ export class MegolmEncryption extends EncryptionAlgorithm {
})(),
(async (): Promise<void> => {
this.prefixedLogger.debug(
`There are ${Object.entries(blocked).length} blocked devices:`,
Object.entries(blocked)
`There are ${blocked.size} blocked devices:`,
Array.from(blocked.entries())
.map(([userId, blockedByUser]) =>
Object.entries(blockedByUser).map(([deviceId, _deviceInfo]) => `${userId}/${deviceId}`),
Array.from(blockedByUser.entries()).map(
([deviceId, _deviceInfo]) => `${userId}/${deviceId}`,
),
)
.flat(1),
);
// also, notify newly blocked devices that they're blocked
const blockedMap: Record<string, Record<string, { device: IBlockedDevice }>> = {};
const blockedMap: MapWithDefault<string, Map<string, { device: IBlockedDevice }>> = new MapWithDefault(
() => new Map(),
);
let blockedCount = 0;
for (const [userId, userBlockedDevices] of Object.entries(blocked)) {
for (const [deviceId, device] of Object.entries(userBlockedDevices)) {
if (
!session.blockedDevicesNotified[userId] ||
session.blockedDevicesNotified[userId][deviceId] === undefined
) {
blockedMap[userId] = blockedMap[userId] || {};
blockedMap[userId][deviceId] = { device };
for (const [userId, userBlockedDevices] of blocked) {
for (const [deviceId, device] of userBlockedDevices) {
if (session.blockedDevicesNotified.get(userId)?.get(deviceId) === undefined) {
blockedMap.getOrCreate(userId).set(deviceId, { device });
blockedCount++;
}
}
@ -520,7 +515,7 @@ export class MegolmEncryption extends EncryptionAlgorithm {
if (blockedCount) {
this.prefixedLogger.debug(
`Notifying ${blockedCount} newly blocked devices:`,
Object.entries(blockedMap)
Array.from(blockedMap.entries())
.map(([userId, blockedByUser]) =>
Object.entries(blockedByUser).map(([deviceId, _deviceInfo]) => `${userId}/${deviceId}`),
)
@ -566,7 +561,7 @@ export class MegolmEncryption extends EncryptionAlgorithm {
*
* @internal
*
* @param devicemap - the devices that have olm sessions, as returned by
* @param deviceMap - the devices that have olm sessions, as returned by
* olmlib.ensureOlmSessionsForDevices.
* @param devicesByUser - a map of user IDs to array of deviceInfo
* @param noOlmDevices - an array to fill with devices that don't have
@ -576,23 +571,23 @@ export class MegolmEncryption extends EncryptionAlgorithm {
* noOlmDevices is specified, then noOlmDevices will be returned.
*/
private getDevicesWithoutSessions(
devicemap: Record<string, Record<string, IOlmSessionResult>>,
devicesByUser: Record<string, DeviceInfo[]>,
deviceMap: Map<string, Map<string, IOlmSessionResult>>,
devicesByUser: Map<string, DeviceInfo[]>,
noOlmDevices: IOlmDevice[] = [],
): IOlmDevice[] {
for (const [userId, devicesToShareWith] of Object.entries(devicesByUser)) {
const sessionResults = devicemap[userId];
for (const [userId, devicesToShareWith] of devicesByUser) {
const sessionResults = deviceMap.get(userId);
for (const deviceInfo of devicesToShareWith) {
const deviceId = deviceInfo.deviceId;
const sessionResult = sessionResults[deviceId];
if (!sessionResult.sessionId) {
const sessionResult = sessionResults?.get(deviceId);
if (!sessionResult?.sessionId) {
// no session with this device, probably because there
// were no one-time keys.
noOlmDevices.push({ userId, deviceInfo });
delete sessionResults[deviceId];
sessionResults?.delete(deviceId);
// ensureOlmSessionsForUsers has already done the logging,
// so just skip it.
@ -615,7 +610,7 @@ export class MegolmEncryption extends EncryptionAlgorithm {
* @returns the blocked devices, split into chunks
*/
private splitDevices<T extends DeviceInfo | IBlockedDevice>(
devicesByUser: Record<string, Record<string, { device: T }>>,
devicesByUser: Map<string, Map<string, { device: T }>>,
): IOlmDevice<T>[][] {
const maxDevicesPerRequest = 20;
@ -623,8 +618,8 @@ export class MegolmEncryption extends EncryptionAlgorithm {
let currentSlice: IOlmDevice<T>[] = [];
const mapSlices = [currentSlice];
for (const [userId, userDevices] of Object.entries(devicesByUser)) {
for (const deviceInfo of Object.values(userDevices)) {
for (const [userId, userDevices] of devicesByUser) {
for (const deviceInfo of userDevices.values()) {
currentSlice.push({
userId: userId,
deviceInfo: deviceInfo.device,
@ -702,7 +697,7 @@ export class MegolmEncryption extends EncryptionAlgorithm {
userDeviceMap: IOlmDevice<IBlockedDevice>[],
payload: IPayload,
): Promise<void> {
const contentMap: Record<string, Record<string, IPayload>> = {};
const contentMap: MapWithDefault<string, Map<string, IPayload>> = new MapWithDefault(() => new Map());
for (const val of userDeviceMap) {
const userId = val.userId;
@ -722,17 +717,14 @@ export class MegolmEncryption extends EncryptionAlgorithm {
delete message.session_id;
}
if (!contentMap[userId]) {
contentMap[userId] = {};
}
contentMap[userId][deviceId] = message;
contentMap.getOrCreate(userId).set(deviceId, message);
}
await this.baseApis.sendToDevice("m.room_key.withheld", contentMap);
// record the fact that we notified these blocked devices
for (const userId of Object.keys(contentMap)) {
for (const deviceId of Object.keys(contentMap[userId])) {
for (const [userId, userDeviceMap] of contentMap) {
for (const deviceId of userDeviceMap.keys()) {
session.markNotifiedBlockedDevice(userId, deviceId);
}
}
@ -760,11 +752,11 @@ export class MegolmEncryption extends EncryptionAlgorithm {
}
// The chain index of the key we previously sent this device
if (obSessionInfo.sharedWithDevices[userId] === undefined) {
if (!obSessionInfo.sharedWithDevices.has(userId)) {
this.prefixedLogger.debug(`megolm session ${senderKey}|${sessionId} never shared with user ${userId}`);
return;
}
const sessionSharedData = obSessionInfo.sharedWithDevices[userId][device.deviceId];
const sessionSharedData = obSessionInfo.sharedWithDevices.get(userId)?.get(device.deviceId);
if (sessionSharedData === undefined) {
this.prefixedLogger.debug(
`megolm session ${senderKey}|${sessionId} never shared with device ${userId}:${device.deviceId}`,
@ -796,9 +788,7 @@ export class MegolmEncryption extends EncryptionAlgorithm {
return;
}
await olmlib.ensureOlmSessionsForDevices(this.olmDevice, this.baseApis, {
[userId]: [device],
});
await olmlib.ensureOlmSessionsForDevices(this.olmDevice, this.baseApis, new Map([[userId, [device]]]));
const payload = {
type: "m.forwarded_room_key",
@ -831,11 +821,10 @@ export class MegolmEncryption extends EncryptionAlgorithm {
payload,
);
await this.baseApis.sendToDevice("m.room.encrypted", {
[userId]: {
[device.deviceId]: encryptedContent,
},
});
await this.baseApis.sendToDevice(
"m.room.encrypted",
new Map([[userId, new Map([[device.deviceId, encryptedContent]])]]),
);
this.prefixedLogger.debug(
`Re-shared key for megolm session ${senderKey}|${sessionId} with ${userId}:${device.deviceId}`,
);
@ -865,7 +854,7 @@ export class MegolmEncryption extends EncryptionAlgorithm {
session: OutboundSessionInfo,
key: IOutboundGroupSessionKey,
payload: IPayload,
devicesByUser: Record<string, DeviceInfo[]>,
devicesByUser: Map<string, DeviceInfo[]>,
errorDevices: IOlmDevice[],
otkTimeout: number,
failedServers?: string[],
@ -887,9 +876,9 @@ export class MegolmEncryption extends EncryptionAlgorithm {
session: OutboundSessionInfo,
key: IOutboundGroupSessionKey,
payload: IPayload,
devicemap: Record<string, Record<string, IOlmSessionResult>>,
deviceMap: Map<string, Map<string, IOlmSessionResult>>,
): Promise<void> {
const userDeviceMaps = this.splitDevices(devicemap);
const userDeviceMaps = this.splitDevices(deviceMap);
for (let i = 0; i < userDeviceMaps.length; i++) {
const taskDetail = `megolm keys for ${session.sessionId} (slice ${i + 1}/${userDeviceMaps.length})`;
@ -934,19 +923,20 @@ export class MegolmEncryption extends EncryptionAlgorithm {
this.prefixedLogger.debug(
`Need to notify ${unnotifiedFailedDevices.length} failed devices which haven't been notified before`,
);
const blockedMap: Record<string, Record<string, { device: IBlockedDevice }>> = {};
const blockedMap: MapWithDefault<string, Map<string, { device: IBlockedDevice }>> = new MapWithDefault(
() => new Map(),
);
for (const { userId, deviceInfo } of unnotifiedFailedDevices) {
blockedMap[userId] = blockedMap[userId] || {};
// we use a similar format to what
// olmlib.ensureOlmSessionsForDevices returns, so that
// we can use the same function to split
blockedMap[userId][deviceInfo.deviceId] = {
blockedMap.getOrCreate(userId).set(deviceInfo.deviceId, {
device: {
code: "m.no_olm",
reason: WITHHELD_MESSAGES["m.no_olm"],
deviceInfo,
},
};
});
}
// send the notifications
@ -964,7 +954,7 @@ export class MegolmEncryption extends EncryptionAlgorithm {
*/
private async notifyBlockedDevices(
session: OutboundSessionInfo,
devicesByUser: Record<string, Record<string, { device: IBlockedDevice }>>,
devicesByUser: Map<string, Map<string, { device: IBlockedDevice }>>,
): Promise<void> {
const payload: IPayload = {
room_id: this.roomId,
@ -1154,21 +1144,17 @@ export class MegolmEncryption extends EncryptionAlgorithm {
* devices we should shared the session with.
*/
private checkForUnknownDevices(devicesInRoom: DeviceInfoMap): void {
const unknownDevices: Record<string, Record<string, DeviceInfo>> = {};
const unknownDevices: MapWithDefault<string, Map<string, DeviceInfo>> = new MapWithDefault(() => new Map());
Object.keys(devicesInRoom).forEach((userId) => {
Object.keys(devicesInRoom[userId]).forEach((deviceId) => {
const device = devicesInRoom[userId][deviceId];
for (const [userId, userDevices] of devicesInRoom) {
for (const [deviceId, device] of userDevices) {
if (device.isUnverified() && !device.isKnown()) {
if (!unknownDevices[userId]) {
unknownDevices[userId] = {};
unknownDevices.getOrCreate(userId).set(deviceId, device);
}
}
unknownDevices[userId][deviceId] = device;
}
});
});
if (Object.keys(unknownDevices).length) {
if (unknownDevices.size) {
// it'd be kind to pass unknownDevices up to the user in this error
throw new UnknownDeviceError(
"This room contains unknown devices which have not been verified. " +
@ -1186,15 +1172,15 @@ export class MegolmEncryption extends EncryptionAlgorithm {
* devices we should shared the session with.
*/
private removeUnknownDevices(devicesInRoom: DeviceInfoMap): void {
for (const [userId, userDevices] of Object.entries(devicesInRoom)) {
for (const [deviceId, device] of Object.entries(userDevices)) {
for (const [userId, userDevices] of devicesInRoom) {
for (const [deviceId, device] of userDevices) {
if (device.isUnverified() && !device.isKnown()) {
delete userDevices[deviceId];
userDevices.delete(deviceId);
}
}
if (Object.keys(userDevices).length === 0) {
delete devicesInRoom[userId];
if (userDevices.size === 0) {
devicesInRoom.delete(userId);
}
}
}
@ -1219,17 +1205,17 @@ export class MegolmEncryption extends EncryptionAlgorithm {
private async getDevicesInRoom(
room: Room,
forceDistributeToUnverified?: boolean,
): Promise<[DeviceInfoMap, IBlockedMap]>;
): Promise<[DeviceInfoMap, BlockedMap]>;
private async getDevicesInRoom(
room: Room,
forceDistributeToUnverified?: boolean,
isCancelled?: () => boolean,
): Promise<null | [DeviceInfoMap, IBlockedMap]>;
): Promise<null | [DeviceInfoMap, BlockedMap]>;
private async getDevicesInRoom(
room: Room,
forceDistributeToUnverified = false,
isCancelled?: () => boolean,
): Promise<null | [DeviceInfoMap, IBlockedMap]> {
): Promise<null | [DeviceInfoMap, BlockedMap]> {
const members = await room.getEncryptionTargetMembers();
this.prefixedLogger.debug(
`Encrypting for users (shouldEncryptForInvitedMembers: ${room.shouldEncryptForInvitedMembers()}):`,
@ -1254,24 +1240,15 @@ export class MegolmEncryption extends EncryptionAlgorithm {
// using all the device_lists changes and left fields.
// See https://github.com/vector-im/element-web/issues/2305 for details.
const devices = await this.crypto.downloadKeys(roomMembers, false);
const blocked: IBlockedMap = {};
if (isCancelled?.() === true) {
return null;
}
const blocked = new MapWithDefault<string, Map<string, IBlockedDevice>>(() => new Map());
// remove any blocked devices
for (const userId in devices) {
if (!devices.hasOwnProperty(userId)) {
continue;
}
const userDevices = devices[userId];
for (const deviceId in userDevices) {
if (!userDevices.hasOwnProperty(deviceId)) {
continue;
}
for (const [userId, userDevices] of devices) {
for (const [deviceId, userDevice] of userDevices) {
// Yield prior to checking each device so that we don't block
// updating/rendering for too long.
// See https://github.com/vector-im/element-web/issues/21612
@ -1280,19 +1257,17 @@ export class MegolmEncryption extends EncryptionAlgorithm {
const deviceTrust = this.crypto.checkDeviceTrust(userId, deviceId);
if (
userDevices[deviceId].isBlocked() ||
userDevice.isBlocked() ||
(!deviceTrust.isVerified() && isBlacklisting && !forceDistributeToUnverified)
) {
if (!blocked[userId]) {
blocked[userId] = {};
}
const isBlocked = userDevices[deviceId].isBlocked();
blocked[userId][deviceId] = {
const blockedDevices = blocked.getOrCreate(userId);
const isBlocked = userDevice.isBlocked();
blockedDevices.set(deviceId, {
code: isBlocked ? "m.blacklisted" : "m.unverified",
reason: WITHHELD_MESSAGES[isBlocked ? "m.blacklisted" : "m.unverified"],
deviceInfo: userDevices[deviceId],
};
delete userDevices[deviceId];
deviceInfo: userDevice,
});
userDevices.delete(deviceId);
}
}
}
@ -1923,7 +1898,7 @@ export class MegolmDecryption extends DecryptionAlgorithm {
// XXX: switch this to use encryptAndSendToDevices() rather than duplicating it?
await olmlib.ensureOlmSessionsForDevices(this.olmDevice, this.baseApis, { [sender]: [device] }, false);
await olmlib.ensureOlmSessionsForDevices(this.olmDevice, this.baseApis, new Map([[sender, [device]]]), false);
const encryptedContent: IEncryptedContent = {
algorithm: olmlib.OLM_ALGORITHM,
sender_key: this.olmDevice.deviceCurve25519Key!,
@ -1942,11 +1917,10 @@ export class MegolmDecryption extends DecryptionAlgorithm {
await this.olmDevice.recordSessionProblem(senderKey, "no_olm", true);
await this.baseApis.sendToDevice("m.room.encrypted", {
[sender]: {
[device.deviceId]: encryptedContent,
},
});
await this.baseApis.sendToDevice(
"m.room.encrypted",
new Map([[sender, new Map([[device.deviceId, encryptedContent]])]]),
);
}
public hasKeysForKeyRequest(keyRequest: IncomingRoomKeyRequest): Promise<boolean> {
@ -1969,12 +1943,10 @@ export class MegolmDecryption extends DecryptionAlgorithm {
// XXX: switch this to use encryptAndSendToDevices()?
this.olmlib
.ensureOlmSessionsForDevices(this.olmDevice, this.baseApis, {
[userId]: [deviceInfo],
})
.ensureOlmSessionsForDevices(this.olmDevice, this.baseApis, new Map([[userId, [deviceInfo]]]))
.then((devicemap) => {
const olmSessionResult = devicemap[userId][deviceId];
if (!olmSessionResult.sessionId) {
const olmSessionResult = devicemap.get(userId)?.get(deviceId);
if (!olmSessionResult?.sessionId) {
// no session with this device, probably because there
// were no one-time keys.
//
@ -2015,14 +1987,11 @@ export class MegolmDecryption extends DecryptionAlgorithm {
payload!,
)
.then(() => {
const contentMap = {
[userId]: {
[deviceId]: encryptedContent,
},
};
// TODO: retries
return this.baseApis.sendToDevice("m.room.encrypted", contentMap);
return this.baseApis.sendToDevice(
"m.room.encrypted",
new Map([[userId, new Map([[deviceId, encryptedContent]])]]),
);
});
});
}
@ -2162,12 +2131,12 @@ export class MegolmDecryption extends DecryptionAlgorithm {
return !this.pendingEvents.has(senderKey);
}
public async sendSharedHistoryInboundSessions(devicesByUser: Record<string, DeviceInfo[]>): Promise<void> {
public async sendSharedHistoryInboundSessions(devicesByUser: Map<string, DeviceInfo[]>): Promise<void> {
await olmlib.ensureOlmSessionsForDevices(this.olmDevice, this.baseApis, devicesByUser);
const sharedHistorySessions = await this.olmDevice.getSharedHistoryInboundGroupSessions(this.roomId);
this.prefixedLogger.log(
`Sharing history in with users ${Object.keys(devicesByUser)}`,
`Sharing history in with users ${Array.from(devicesByUser.keys())}`,
sharedHistorySessions.map(([senderKey, sessionId]) => `${senderKey}|${sessionId}`),
);
for (const [senderKey, sessionId] of sharedHistorySessions) {
@ -2175,9 +2144,10 @@ export class MegolmDecryption extends DecryptionAlgorithm {
// FIXME: use encryptAndSendToDevices() rather than duplicating it here.
const promises: Promise<unknown>[] = [];
const contentMap: Record<string, Record<string, IEncryptedContent>> = {};
for (const [userId, devices] of Object.entries(devicesByUser)) {
contentMap[userId] = {};
const contentMap: Map<string, Map<string, IEncryptedContent>> = new Map();
for (const [userId, devices] of devicesByUser) {
const deviceMessages = new Map();
contentMap.set(userId, deviceMessages);
for (const deviceInfo of devices) {
const encryptedContent: IEncryptedContent = {
algorithm: olmlib.OLM_ALGORITHM,
@ -2185,7 +2155,7 @@ export class MegolmDecryption extends DecryptionAlgorithm {
ciphertext: {},
[ToDeviceMessageId]: uuidv4(),
};
contentMap[userId][deviceInfo.deviceId] = encryptedContent;
deviceMessages.set(deviceInfo.deviceId, encryptedContent);
promises.push(
olmlib.encryptMessageForDevice(
encryptedContent.ciphertext,
@ -2205,22 +2175,22 @@ export class MegolmDecryption extends DecryptionAlgorithm {
// in which case it will have just not added anything to the ciphertext object.
// There's no point sending messages to devices if we couldn't encrypt to them,
// since that's effectively a blank message.
for (const userId of Object.keys(contentMap)) {
for (const deviceId of Object.keys(contentMap[userId])) {
if (Object.keys(contentMap[userId][deviceId].ciphertext).length === 0) {
for (const [userId, deviceMessages] of contentMap) {
for (const [deviceId, content] of deviceMessages) {
if (!hasCiphertext(content)) {
this.prefixedLogger.log("No ciphertext for device " + userId + ":" + deviceId + ": pruning");
delete contentMap[userId][deviceId];
deviceMessages.delete(deviceId);
}
}
// No devices left for that user? Strip that too.
if (Object.keys(contentMap[userId]).length === 0) {
if (deviceMessages.size === 0) {
this.prefixedLogger.log("Pruned all devices for user " + userId);
delete contentMap[userId];
contentMap.delete(userId);
}
}
// Is there anything left?
if (Object.keys(contentMap).length === 0) {
if (contentMap.size === 0) {
this.prefixedLogger.log("No users left to send to: aborting");
return;
}

View File

@ -25,7 +25,7 @@ import { MEGOLM_ALGORITHM, verifySignature } from "./olmlib";
import { DeviceInfo } from "./deviceinfo";
import { DeviceTrustLevel } from "./CrossSigning";
import { keyFromPassphrase } from "./key_passphrase";
import { sleep } from "../utils";
import { safeSet, sleep } from "../utils";
import { IndexedDBCryptoStore } from "./store/indexeddb-crypto-store";
import { encodeRecoveryKey } from "./recoverykey";
import { calculateKeyCheck, decryptAES, encryptAES, IEncryptedPayload } from "./aes";
@ -498,9 +498,7 @@ export class BackupManager {
const rooms: IKeyBackup["rooms"] = {};
for (const session of sessions) {
const roomId = session.sessionData!.room_id;
if (rooms[roomId] === undefined) {
rooms[roomId] = { sessions: {} };
}
safeSet(rooms, roomId, rooms[roomId] || { sessions: {} });
const sessionData = this.baseApis.crypto!.olmDevice.exportInboundGroupSession(
session.senderKey,
@ -517,12 +515,12 @@ export class BackupManager {
undefined;
const verified = this.baseApis.crypto!.checkDeviceInfoTrust(userId!, device).isVerified();
rooms[roomId]["sessions"][session.sessionId] = {
safeSet(rooms[roomId]["sessions"], session.sessionId, {
first_message_index: sessionData.first_known_index,
forwarded_count: forwardedCount,
is_verified: verified,
session_data: await this.algorithm!.encryptSession(sessionData),
};
});
}
await this.baseApis.sendKeyBackup(undefined, undefined, this.backupInfo!.version, { rooms });

View File

@ -90,6 +90,7 @@ import { ISignatures } from "../@types/signed";
import { IMessage } from "./algorithms/olm";
import { CryptoBackend, OnSyncCompletedData } from "../common-crypto/CryptoBackend";
import { RoomState, RoomStateEvent } from "../models/room-state";
import { MapWithDefault, recursiveMapToObject } from "../utils";
const DeviceVerification = DeviceInfo.DeviceVerification;
@ -399,7 +400,10 @@ export class Crypto extends TypedEventEmitter<CryptoEvent, CryptoEventHandlerMap
// deviceId: 1234567890000,
// },
// }
private lastNewSessionForced: Record<string, Record<string, number>> = {};
// Map: user Id → device Id → timestamp
private lastNewSessionForced: MapWithDefault<string, MapWithDefault<string, number>> = new MapWithDefault(
() => new MapWithDefault(() => 0),
);
// This flag will be unset whilst the client processes a sync response
// so that we don't start requesting keys until we've actually finished
@ -2690,11 +2694,13 @@ export class Crypto extends TypedEventEmitter<CryptoEvent, CryptoEventHandlerMap
public ensureOlmSessionsForUsers(
users: string[],
force?: boolean,
): Promise<Record<string, Record<string, olmlib.IOlmSessionResult>>> {
const devicesByUser: Record<string, DeviceInfo[]> = {};
): Promise<Map<string, Map<string, olmlib.IOlmSessionResult>>> {
// map user Id → DeviceInfo[]
const devicesByUser: Map<string, DeviceInfo[]> = new Map();
for (const userId of users) {
devicesByUser[userId] = [];
const userDevices: DeviceInfo[] = [];
devicesByUser.set(userId, userDevices);
const devices = this.getStoredDevicesForUser(userId) || [];
for (const deviceInfo of devices) {
@ -2708,7 +2714,7 @@ export class Crypto extends TypedEventEmitter<CryptoEvent, CryptoEventHandlerMap
continue;
}
devicesByUser[userId].push(deviceInfo);
userDevices.push(deviceInfo);
}
}
@ -3146,7 +3152,11 @@ export class Crypto extends TypedEventEmitter<CryptoEvent, CryptoEventHandlerMap
payload: encryptedContent,
});
await olmlib.ensureOlmSessionsForDevices(this.olmDevice, this.baseApis, { [userId]: [deviceInfo] });
await olmlib.ensureOlmSessionsForDevices(
this.olmDevice,
this.baseApis,
new Map([[userId, [deviceInfo]]]),
);
await olmlib.encryptMessageForDevice(
encryptedContent.ciphertext,
this.userId,
@ -3459,8 +3469,8 @@ export class Crypto extends TypedEventEmitter<CryptoEvent, CryptoEventHandlerMap
// check when we last forced a new session with this device: if we've already done so
// recently, don't do it again.
this.lastNewSessionForced[sender] = this.lastNewSessionForced[sender] || {};
const lastNewSessionForced = this.lastNewSessionForced[sender][deviceKey] || 0;
const lastNewSessionDevices = this.lastNewSessionForced.getOrCreate(sender);
const lastNewSessionForced = lastNewSessionDevices.getOrCreate(deviceKey);
if (lastNewSessionForced + MIN_FORCE_SESSION_INTERVAL_MS > Date.now()) {
logger.debug(
"New session already forced with device " +
@ -3493,11 +3503,10 @@ export class Crypto extends TypedEventEmitter<CryptoEvent, CryptoEventHandlerMap
return;
}
}
const devicesByUser: Record<string, DeviceInfo[]> = {};
devicesByUser[sender] = [device];
const devicesByUser = new Map([[sender, [device]]]);
await olmlib.ensureOlmSessionsForDevices(this.olmDevice, this.baseApis, devicesByUser, true);
this.lastNewSessionForced[sender][deviceKey] = Date.now();
lastNewSessionDevices.set(deviceKey, Date.now());
// Now send a blank message on that session so the other side knows about it.
// (The keyshare request is sent in the clear so that won't do)
@ -3524,11 +3533,10 @@ export class Crypto extends TypedEventEmitter<CryptoEvent, CryptoEventHandlerMap
await this.olmDevice.recordSessionProblem(deviceKey, "wedged", true);
retryDecryption();
await this.baseApis.sendToDevice("m.room.encrypted", {
[sender]: {
[device.deviceId]: encryptedContent,
},
});
await this.baseApis.sendToDevice(
"m.room.encrypted",
new Map([[sender, new Map([[device.deviceId, encryptedContent]])]]),
);
// Most of the time this probably won't be necessary since we'll have queued up a key request when
// we failed to decrypt the message and will be waiting a bit for the key to arrive before sending
@ -3835,15 +3843,16 @@ export class Crypto extends TypedEventEmitter<CryptoEvent, CryptoEventHandlerMap
* @param obj - Object to which we will add a 'signatures' property
*/
public async signObject<T extends ISignableObject & object>(obj: T): Promise<void> {
const sigs = obj.signatures || {};
const sigs = new Map(Object.entries(obj.signatures || {}));
const unsigned = obj.unsigned;
delete obj.signatures;
delete obj.unsigned;
sigs[this.userId] = sigs[this.userId] || {};
sigs[this.userId]["ed25519:" + this.deviceId] = await this.olmDevice.sign(anotherjson.stringify(obj));
obj.signatures = sigs;
const userSignatures = sigs.get(this.userId) || {};
sigs.set(this.userId, userSignatures);
userSignatures["ed25519:" + this.deviceId] = await this.olmDevice.sign(anotherjson.stringify(obj));
obj.signatures = recursiveMapToObject(sigs);
if (unsigned !== undefined) obj.unsigned = unsigned;
}
}

View File

@ -30,6 +30,7 @@ import { ISignatures } from "../@types/signed";
import { MatrixEvent } from "../models/event";
import { EventType } from "../@types/event";
import { IMessage } from "./algorithms/olm";
import { MapWithDefault } from "../utils";
enum Algorithm {
Olm = "m.olm.v1.curve25519-aes-sha2",
@ -154,9 +155,11 @@ export async function getExistingOlmSessions(
olmDevice: OlmDevice,
baseApis: MatrixClient,
devicesByUser: Record<string, DeviceInfo[]>,
): Promise<[Record<string, DeviceInfo[]>, Record<string, Record<string, IExistingOlmSession>>]> {
const devicesWithoutSession: { [userId: string]: DeviceInfo[] } = {};
const sessions: { [userId: string]: { [deviceId: string]: IExistingOlmSession } } = {};
): Promise<[Map<string, DeviceInfo[]>, Map<string, Map<string, IExistingOlmSession>>]> {
// map user Id → DeviceInfo[]
const devicesWithoutSession: MapWithDefault<string, DeviceInfo[]> = new MapWithDefault(() => []);
// map user Id → device Id → IExistingOlmSession
const sessions: MapWithDefault<string, Map<string, IExistingOlmSession>> = new MapWithDefault(() => new Map());
const promises: Promise<void>[] = [];
@ -168,14 +171,12 @@ export async function getExistingOlmSessions(
(async (): Promise<void> => {
const sessionId = await olmDevice.getSessionIdForDevice(key, true);
if (sessionId === null) {
devicesWithoutSession[userId] = devicesWithoutSession[userId] || [];
devicesWithoutSession[userId].push(deviceInfo);
devicesWithoutSession.getOrCreate(userId).push(deviceInfo);
} else {
sessions[userId] = sessions[userId] || {};
sessions[userId][deviceId] = {
sessions.getOrCreate(userId).set(deviceId, {
device: deviceInfo,
sessionId: sessionId,
};
});
}
})(),
);
@ -210,24 +211,26 @@ export async function getExistingOlmSessions(
export async function ensureOlmSessionsForDevices(
olmDevice: OlmDevice,
baseApis: MatrixClient,
devicesByUser: Record<string, DeviceInfo[]>,
devicesByUser: Map<string, DeviceInfo[]>,
force = false,
otkTimeout?: number,
failedServers?: string[],
log = logger,
): Promise<Record<string, Record<string, IOlmSessionResult>>> {
): Promise<Map<string, Map<string, IOlmSessionResult>>> {
const devicesWithoutSession: [string, string][] = [
// [userId, deviceId], ...
];
const result: { [userId: string]: { [deviceId: string]: IExistingOlmSession } } = {};
const resolveSession: Record<string, (sessionId?: string) => void> = {};
// map user Iddevice Id → IExistingOlmSession
const result: Map<string, Map<string, IExistingOlmSession>> = new Map();
// map device key → resolve session fn
const resolveSession: Map<string, (sessionId?: string) => void> = new Map();
// Mark all sessions this task intends to update as in progress. It is
// important to do this for all devices this task cares about in a single
// synchronous operation, as otherwise it is possible to have deadlocks
// where multiple tasks wait indefinitely on another task to update some set
// of common devices.
for (const [, devices] of Object.entries(devicesByUser)) {
for (const devices of devicesByUser.values()) {
for (const deviceInfo of devices) {
const key = deviceInfo.getIdentityKey();
@ -242,17 +245,19 @@ export async function ensureOlmSessionsForDevices(
// conditions. If we find that we already have a session, then
// we'll resolve
olmDevice.sessionsInProgress[key] = new Promise((resolve) => {
resolveSession[key] = (v: any): void => {
resolveSession.set(key, (v: any): void => {
delete olmDevice.sessionsInProgress[key];
resolve(v);
};
});
});
}
}
}
for (const [userId, devices] of Object.entries(devicesByUser)) {
result[userId] = {};
for (const [userId, devices] of devicesByUser) {
const resultDevices = new Map();
result.set(userId, resultDevices);
for (const deviceInfo of devices) {
const deviceId = deviceInfo.deviceId;
const key = deviceInfo.getIdentityKey();
@ -268,20 +273,21 @@ export async function ensureOlmSessionsForDevices(
log.info("Attempted to start session with ourself! Ignoring");
// We must fill in the section in the return value though, as callers
// expect it to be there.
result[userId][deviceId] = {
resultDevices.set(deviceId, {
device: deviceInfo,
sessionId: null,
};
});
continue;
}
const forWhom = `for ${key} (${userId}:${deviceId})`;
const sessionId = await olmDevice.getSessionIdForDevice(key, !!resolveSession[key], log);
if (sessionId !== null && resolveSession[key]) {
const sessionId = await olmDevice.getSessionIdForDevice(key, !!resolveSession.get(key), log);
const resolveSessionFn = resolveSession.get(key);
if (sessionId !== null && resolveSessionFn) {
// we found a session, but we had marked the session as
// in-progress, so resolve it now, which will unmark it and
// unblock anything that was waiting
resolveSession[key]();
resolveSessionFn();
}
if (sessionId === null || force) {
if (force) {
@ -291,10 +297,10 @@ export async function ensureOlmSessionsForDevices(
}
devicesWithoutSession.push([userId, deviceId]);
}
result[userId][deviceId] = {
resultDevices.set(deviceId, {
device: deviceInfo,
sessionId: sessionId,
};
});
}
}
@ -310,7 +316,7 @@ export async function ensureOlmSessionsForDevices(
res = await baseApis.claimOneTimeKeys(devicesWithoutSession, oneTimeKeyAlgorithm, otkTimeout);
log.debug(`Claimed ${taskDetail}`);
} catch (e) {
for (const resolver of Object.values(resolveSession)) {
for (const resolver of resolveSession.values()) {
resolver();
}
log.log(`Failed to claim ${taskDetail}`, e, devicesWithoutSession);
@ -323,7 +329,7 @@ export async function ensureOlmSessionsForDevices(
const otkResult = res.one_time_keys || ({} as IClaimOTKsResult["one_time_keys"]);
const promises: Promise<void>[] = [];
for (const [userId, devices] of Object.entries(devicesByUser)) {
for (const [userId, devices] of devicesByUser) {
const userRes = otkResult[userId] || {};
for (const deviceInfo of devices) {
const deviceId = deviceInfo.deviceId;
@ -336,7 +342,7 @@ export async function ensureOlmSessionsForDevices(
continue;
}
if (result[userId][deviceId].sessionId && !force) {
if (result.get(userId)?.get(deviceId)?.sessionId && !force) {
// we already have a result for this device
continue;
}
@ -351,24 +357,19 @@ export async function ensureOlmSessionsForDevices(
if (!oneTimeKey) {
log.warn(`No one-time keys (alg=${oneTimeKeyAlgorithm}) ` + `for device ${userId}:${deviceId}`);
if (resolveSession[key]) {
resolveSession[key]();
}
resolveSession.get(key)?.();
continue;
}
promises.push(
_verifyKeyAndStartSession(olmDevice, oneTimeKey, userId, deviceInfo).then(
(sid) => {
if (resolveSession[key]) {
resolveSession[key](sid ?? undefined);
}
result[userId][deviceId].sessionId = sid;
resolveSession.get(key)?.(sid ?? undefined);
const deviceInfo = result.get(userId)?.get(deviceId);
if (deviceInfo) deviceInfo.sessionId = sid;
},
(e) => {
if (resolveSession[key]) {
resolveSession[key]();
}
resolveSession.get(key)?.();
throw e;
},
),

View File

@ -21,6 +21,7 @@ import { IOlmDevice } from "../algorithms/megolm";
import { IRoomEncryption } from "../RoomList";
import { ICrossSigningKey } from "../../client";
import { InboundGroupSessionData } from "../OlmDevice";
import { safeSet } from "../../utils";
/**
* Internal module. Partial localStorage backed storage for e2e.
@ -178,11 +179,11 @@ export class LocalStorageCryptoStore extends MemoryCryptoStore {
if (userId in notifiedErrorDevices) {
if (!(deviceInfo.deviceId in notifiedErrorDevices[userId])) {
ret.push(device);
notifiedErrorDevices[userId][deviceInfo.deviceId] = true;
safeSet(notifiedErrorDevices[userId], deviceInfo.deviceId, true);
}
} else {
ret.push(device);
notifiedErrorDevices[userId] = { [deviceInfo.deviceId]: true };
safeSet(notifiedErrorDevices, userId, { [deviceInfo.deviceId]: true });
}
}

View File

@ -33,6 +33,7 @@ import { ICrossSigningKey } from "../../client";
import { IOlmDevice } from "../algorithms/megolm";
import { IRoomEncryption } from "../RoomList";
import { InboundGroupSessionData } from "../OlmDevice";
import { safeSet } from "../../utils";
/**
* Internal module. in-memory storage for e2e.
@ -375,11 +376,11 @@ export class MemoryCryptoStore implements CryptoStore {
if (userId in notifiedErrorDevices) {
if (!(deviceInfo.deviceId in notifiedErrorDevices[userId])) {
ret.push(device);
notifiedErrorDevices[userId][deviceInfo.deviceId] = true;
safeSet(notifiedErrorDevices[userId], deviceInfo.deviceId, true);
}
} else {
ret.push(device);
notifiedErrorDevices[userId] = { [deviceInfo.deviceId]: true };
safeSet(notifiedErrorDevices, userId, { [deviceInfo.deviceId]: true });
}
}

View File

@ -269,12 +269,12 @@ export class ToDeviceChannel implements IVerificationChannel {
private async sendToDevices(type: string, content: Record<string, any>, devices: string[]): Promise<void> {
if (devices.length) {
const msgMap: Record<string, Record<string, any>> = {};
const deviceMessages: Map<string, Record<string, any>> = new Map();
for (const deviceId of devices) {
msgMap[deviceId] = content;
deviceMessages.set(deviceId, content);
}
await this.client.sendToDevice(type, { [this.userId]: msgMap });
await this.client.sendToDevice(type, new Map([[this.userId, deviceMessages]]));
}
}

View File

@ -29,15 +29,16 @@ import { IEvent, IContent, EventStatus } from "./models/event";
import { ISendEventResponse } from "./@types/requests";
import { EventType } from "./@types/event";
import { logger } from "./logger";
import { MatrixClient, ClientEvent, IMatrixClientCreateOpts, IStartClientOpts } from "./client";
import { MatrixClient, ClientEvent, IMatrixClientCreateOpts, IStartClientOpts, SendToDeviceContentMap } from "./client";
import { SyncApi, SyncState } from "./sync";
import { SlidingSyncSdk } from "./sliding-sync-sdk";
import { MatrixEvent } from "./models/event";
import { User } from "./models/user";
import { Room } from "./models/room";
import { ToDeviceBatch } from "./models/ToDeviceMessage";
import { ToDeviceBatch, ToDevicePayload } from "./models/ToDeviceMessage";
import { DeviceInfo } from "./crypto/deviceinfo";
import { IOlmDevice } from "./crypto/algorithms/megolm";
import { MapWithDefault, recursiveMapToObject } from "./utils";
interface IStateEventRequest {
eventType: string;
@ -234,35 +235,32 @@ export class RoomWidgetClient extends MatrixClient {
return await this.widgetApi.sendStateEvent(eventType, stateKey, content, roomId);
}
public async sendToDevice(
eventType: string,
contentMap: { [userId: string]: { [deviceId: string]: Record<string, any> } },
): Promise<{}> {
await this.widgetApi.sendToDevice(eventType, false, contentMap);
public async sendToDevice(eventType: string, contentMap: SendToDeviceContentMap): Promise<{}> {
await this.widgetApi.sendToDevice(eventType, false, recursiveMapToObject(contentMap));
return {};
}
public async queueToDevice({ eventType, batch }: ToDeviceBatch): Promise<void> {
const contentMap: { [userId: string]: { [deviceId: string]: object } } = {};
// map: user Iddevice Id → payload
const contentMap: MapWithDefault<string, Map<string, ToDevicePayload>> = new MapWithDefault(() => new Map());
for (const { userId, deviceId, payload } of batch) {
if (!contentMap[userId]) contentMap[userId] = {};
contentMap[userId][deviceId] = payload;
contentMap.getOrCreate(userId).set(deviceId, payload);
}
await this.widgetApi.sendToDevice(eventType, false, contentMap);
await this.widgetApi.sendToDevice(eventType, false, recursiveMapToObject(contentMap));
}
public async encryptAndSendToDevices(userDeviceInfoArr: IOlmDevice<DeviceInfo>[], payload: object): Promise<void> {
const contentMap: { [userId: string]: { [deviceId: string]: object } } = {};
// map: user Iddevice Id → payload
const contentMap: MapWithDefault<string, Map<string, object>> = new MapWithDefault(() => new Map());
for (const {
userId,
deviceInfo: { deviceId },
} of userDeviceInfoArr) {
if (!contentMap[userId]) contentMap[userId] = {};
contentMap[userId][deviceId] = payload;
contentMap.getOrCreate(userId).set(deviceId, payload);
}
await this.widgetApi.sendToDevice((payload as { type: string }).type, true, contentMap);
await this.widgetApi.sendToDevice((payload as { type: string }).type, true, recursiveMapToObject(contentMap));
}
// Overridden since we get TURN servers automatically over the widget API,

View File

@ -16,7 +16,6 @@ import {
MAIN_ROOM_TIMELINE,
Receipt,
ReceiptCache,
Receipts,
ReceiptType,
WrappedReceipt,
} from "../@types/read_receipts";
@ -25,6 +24,7 @@ import * as utils from "../utils";
import { MatrixEvent } from "./event";
import { EventType } from "../@types/event";
import { EventTimelineSet } from "./event-timeline-set";
import { MapWithDefault } from "../utils";
import { NotificationCountType } from "./room";
export function synthesizeReceipt(userId: string, event: MatrixEvent, receiptType: ReceiptType): MatrixEvent {
@ -56,8 +56,11 @@ export abstract class ReadReceipt<
// the form of this structure. This is sub-optimal for the exposed APIs
// which pass in an event ID and get back some receipts, so we also store
// a pre-cached list for this purpose.
private receipts: Receipts = {}; // { receipt_type: { user_id: Receipt } }
private receiptCacheByEventId: ReceiptCache = {}; // { event_id: CachedReceipt[] }
// Map: receipt type user Id → receipt
private receipts = new MapWithDefault<string, Map<string, [WrappedReceipt | null, WrappedReceipt | null]>>(
() => new Map(),
);
private receiptCacheByEventId: ReceiptCache = new Map();
public abstract getUnfilteredTimelineSet(): EventTimelineSet;
public abstract timeline: MatrixEvent[];
@ -74,7 +77,7 @@ export abstract class ReadReceipt<
ignoreSynthesized = false,
receiptType = ReceiptType.Read,
): WrappedReceipt | null {
const [realReceipt, syntheticReceipt] = this.receipts[receiptType]?.[userId] ?? [];
const [realReceipt, syntheticReceipt] = this.receipts.get(receiptType)?.get(userId) ?? [null, null];
if (ignoreSynthesized) {
return realReceipt;
}
@ -126,14 +129,13 @@ export abstract class ReadReceipt<
receipt: Receipt,
synthetic: boolean,
): void {
if (!this.receipts[receiptType]) {
this.receipts[receiptType] = {};
}
if (!this.receipts[receiptType][userId]) {
this.receipts[receiptType][userId] = [null, null];
}
const receiptTypesMap = this.receipts.getOrCreate(receiptType);
let pair = receiptTypesMap.get(userId);
const pair = this.receipts[receiptType][userId];
if (!pair) {
pair = [null, null];
receiptTypesMap.set(userId, pair);
}
let existingReceipt = pair[ReceiptPairRealIndex];
if (synthetic) {
@ -185,23 +187,26 @@ export abstract class ReadReceipt<
if (cachedReceipt === newCachedReceipt) return;
// clean up any previous cache entry
if (cachedReceipt && this.receiptCacheByEventId[cachedReceipt.eventId]) {
if (cachedReceipt && this.receiptCacheByEventId.get(cachedReceipt.eventId)) {
const previousEventId = cachedReceipt.eventId;
// Remove the receipt we're about to clobber out of existence from the cache
this.receiptCacheByEventId[previousEventId] = this.receiptCacheByEventId[previousEventId].filter((r) => {
this.receiptCacheByEventId.set(
previousEventId,
this.receiptCacheByEventId.get(previousEventId)!.filter((r) => {
return r.type !== receiptType || r.userId !== userId;
});
}),
);
if (this.receiptCacheByEventId[previousEventId].length < 1) {
delete this.receiptCacheByEventId[previousEventId]; // clean up the cache keys
if (this.receiptCacheByEventId.get(previousEventId)!.length < 1) {
this.receiptCacheByEventId.delete(previousEventId); // clean up the cache keys
}
}
// cache the new one
if (!this.receiptCacheByEventId[eventId]) {
this.receiptCacheByEventId[eventId] = [];
if (!this.receiptCacheByEventId.get(eventId)) {
this.receiptCacheByEventId.set(eventId, []);
}
this.receiptCacheByEventId[eventId].push({
this.receiptCacheByEventId.get(eventId)!.push({
userId: userId,
type: receiptType as ReceiptType,
data: receipt,
@ -215,7 +220,7 @@ export abstract class ReadReceipt<
* an empty list.
*/
public getReceiptsForEvent(event: MatrixEvent): CachedReceipt[] {
return this.receiptCacheByEventId[event.getId()!] || [];
return this.receiptCacheByEventId.get(event.getId()!) || [];
}
public abstract addReceipt(event: MatrixEvent, synthetic: boolean): void;

View File

@ -25,7 +25,7 @@ import {
import { Direction, EventTimeline } from "./event-timeline";
import { getHttpUriForMxc } from "../content-repo";
import * as utils from "../utils";
import { normalize } from "../utils";
import { normalize, noUnsafeEventProps } from "../utils";
import { IEvent, IThreadBundledRelationship, MatrixEvent, MatrixEventEvent, MatrixEventHandlerMap } from "./event";
import { EventStatus } from "./event-status";
import { RoomMember } from "./room-member";
@ -311,7 +311,7 @@ export type RoomEventHandlerMap = {
export class Room extends ReadReceipt<RoomEmittedEvents, RoomEventHandlerMap> {
public readonly reEmitter: TypedReEmitter<RoomEmittedEvents, RoomEventHandlerMap>;
private txnToEvent: Record<string, MatrixEvent> = {}; // Pending in-flight requests { string: MatrixEvent }
private txnToEvent: Map<string, MatrixEvent> = new Map(); // Pending in-flight requests { string: MatrixEvent }
private notificationCounts: NotificationCount = {};
private readonly threadNotifications = new Map<string, NotificationCount>();
public readonly cachedThreadReadReceipts = new Map<string, CachedReceiptStructure[]>();
@ -356,7 +356,7 @@ export class Room extends ReadReceipt<RoomEmittedEvents, RoomEventHandlerMap> {
* accountData Dict of per-room account_data events; the keys are the
* event type and the values are the events.
*/
public accountData: Record<string, MatrixEvent> = {}; // $eventType: $event
public accountData: Map<string, MatrixEvent> = new Map(); // $eventType: $event
/**
* The room summary.
*/
@ -902,7 +902,7 @@ export class Room extends ReadReceipt<RoomEmittedEvents, RoomEventHandlerMap> {
rawMembersEvents = await this.loadMembersFromServer();
logger.log(`LL: got ${rawMembersEvents.length} ` + `members from server for room ${this.roomId}`);
}
const memberEvents = rawMembersEvents.map(this.client.getEventMapper());
const memberEvents = rawMembersEvents.filter(noUnsafeEventProps).map(this.client.getEventMapper());
return { memberEvents, fromServer };
}
@ -2255,8 +2255,7 @@ export class Room extends ReadReceipt<RoomEmittedEvents, RoomEventHandlerMap> {
const txnId = event.getUnsigned().transaction_id;
if (!txnId && event.getSender() === this.myUserId) {
// check the txn map for a matching event ID
for (const tid in this.txnToEvent) {
const localEvent = this.txnToEvent[tid];
for (const [tid, localEvent] of this.txnToEvent) {
if (localEvent.getId() === event.getId()) {
logger.debug("processLiveEvent: found sent event without txn ID: ", tid, event.getId());
// update the unsigned field so we can re-use the same codepaths
@ -2331,7 +2330,7 @@ export class Room extends ReadReceipt<RoomEmittedEvents, RoomEventHandlerMap> {
throw new Error("addPendingEvent called on an event with status " + event.status);
}
if (this.txnToEvent[txnId]) {
if (this.txnToEvent.get(txnId)) {
throw new Error("addPendingEvent called on an event with known txnId " + txnId);
}
@ -2340,7 +2339,7 @@ export class Room extends ReadReceipt<RoomEmittedEvents, RoomEventHandlerMap> {
// on the unfiltered timelineSet.
EventTimeline.setEventMetadata(event, this.getLiveTimeline().getState(EventTimeline.FORWARDS)!, false);
this.txnToEvent[txnId] = event;
this.txnToEvent.set(txnId, event);
if (this.pendingEventList) {
if (this.pendingEventList.some((e) => e.status === EventStatus.NOT_SENT)) {
logger.warn("Setting event as NOT_SENT due to messages in the same state");
@ -2432,8 +2431,8 @@ export class Room extends ReadReceipt<RoomEmittedEvents, RoomEventHandlerMap> {
this.relations.aggregateChildEvent(event);
}
public getEventForTxnId(txnId: string): MatrixEvent {
return this.txnToEvent[txnId];
public getEventForTxnId(txnId: string): MatrixEvent | undefined {
return this.txnToEvent.get(txnId);
}
/**
@ -2460,7 +2459,7 @@ export class Room extends ReadReceipt<RoomEmittedEvents, RoomEventHandlerMap> {
logger.debug(`Got remote echo for event ${oldEventId} -> ${newEventId} old status ${oldStatus}`);
// no longer pending
delete this.txnToEvent[remoteEvent.getUnsigned().transaction_id!];
this.txnToEvent.delete(remoteEvent.getUnsigned().transaction_id!);
// if it's in the pending list, remove it
if (this.pendingEventList) {
@ -2673,7 +2672,7 @@ export class Room extends ReadReceipt<RoomEmittedEvents, RoomEventHandlerMap> {
this.processLiveEvent(event);
if (event.getUnsigned().transaction_id) {
const existingEvent = this.txnToEvent[event.getUnsigned().transaction_id!];
const existingEvent = this.txnToEvent.get(event.getUnsigned().transaction_id!);
if (existingEvent) {
// remote echo of an event we sent earlier
this.handleRemoteEcho(event, existingEvent);
@ -2942,8 +2941,9 @@ export class Room extends ReadReceipt<RoomEmittedEvents, RoomEventHandlerMap> {
if (event.getType() === "m.tag") {
this.addTags(event);
}
const lastEvent = this.accountData[event.getType()];
this.accountData[event.getType()] = event;
const eventType = event.getType();
const lastEvent = this.accountData.get(eventType);
this.accountData.set(eventType, event);
this.emit(RoomEvent.AccountData, event, this, lastEvent);
}
}
@ -2954,7 +2954,7 @@ export class Room extends ReadReceipt<RoomEmittedEvents, RoomEventHandlerMap> {
* @returns the account_data event in question
*/
public getAccountData(type: EventType | string): MatrixEvent | undefined {
return this.accountData[type];
return this.accountData.get(type);
}
/**

View File

@ -36,7 +36,7 @@ export interface ISavedSync {
* A store for most of the data js-sdk needs to store, apart from crypto data
*/
export interface IStore {
readonly accountData: Record<string, MatrixEvent>; // type : content
readonly accountData: Map<string, MatrixEvent>; // type : content
// XXX: The indexeddb store exposes a non-standard emitter for the "degraded" event
// for when it falls back to being a memory store due to errors.

View File

@ -31,6 +31,7 @@ import { ISyncResponse } from "../sync-accumulator";
import { IStateEventWithRoomId } from "../@types/search";
import { IndexedToDeviceBatch, ToDeviceBatchWithTxnId } from "../models/ToDeviceMessage";
import { IStoredClientOpts } from "../client";
import { MapWithDefault } from "../utils";
function isValidFilterId(filterId?: string | number | null): boolean {
const isValidStr =
@ -54,10 +55,10 @@ export class MemoryStore implements IStore {
// userId: {
// filterId: Filter
// }
private filters: Record<string, Record<string, Filter>> = {};
public accountData: Record<string, MatrixEvent> = {}; // type : content
private filters: MapWithDefault<string, Map<string, Filter>> = new MapWithDefault(() => new Map());
public accountData: Map<string, MatrixEvent> = new Map(); // type: content
protected readonly localStorage?: Storage;
private oobMembers: Record<string, IStateEventWithRoomId[]> = {}; // roomId: [member events]
private oobMembers: Map<string, IStateEventWithRoomId[]> = new Map(); // roomId: [member events]
private pendingEvents: { [roomId: string]: Partial<IEvent>[] } = {};
private clientOptions?: IStoredClientOpts;
private pendingToDeviceBatches: IndexedToDeviceBatch[] = [];
@ -220,10 +221,7 @@ export class MemoryStore implements IStore {
*/
public storeFilter(filter: Filter): void {
if (!filter?.userId || !filter?.filterId) return;
if (!this.filters[filter.userId]) {
this.filters[filter.userId] = {};
}
this.filters[filter.userId][filter.filterId] = filter;
this.filters.getOrCreate(filter.userId).set(filter.filterId, filter);
}
/**
@ -231,10 +229,7 @@ export class MemoryStore implements IStore {
* @returns A filter or null.
*/
public getFilter(userId: string, filterId: string): Filter | null {
if (!this.filters[userId] || !this.filters[userId][filterId]) {
return null;
}
return this.filters[userId][filterId];
return this.filters.get(userId)?.get(filterId) || null;
}
/**
@ -289,9 +284,9 @@ export class MemoryStore implements IStore {
// MSC3391: an event with content of {} should be interpreted as deleted
const isDeleted = !Object.keys(event.getContent()).length;
if (isDeleted) {
delete this.accountData[event.getType()];
this.accountData.delete(event.getType());
} else {
this.accountData[event.getType()] = event;
this.accountData.set(event.getType(), event);
}
});
}
@ -302,7 +297,7 @@ export class MemoryStore implements IStore {
* @returns the user account_data event of given type, if any
*/
public getAccountData(eventType: EventType | string): MatrixEvent | undefined {
return this.accountData[eventType];
return this.accountData.get(eventType);
}
/**
@ -368,14 +363,8 @@ export class MemoryStore implements IStore {
// userId: User
};
this.syncToken = null;
this.filters = {
// userId: {
// filterId: Filter
// }
};
this.accountData = {
// type : content
};
this.filters = new MapWithDefault(() => new Map());
this.accountData = new Map(); // type : content
return Promise.resolve();
}
@ -386,7 +375,7 @@ export class MemoryStore implements IStore {
* @returns in case the members for this room haven't been stored yet
*/
public getOutOfBandMembers(roomId: string): Promise<IStateEventWithRoomId[] | null> {
return Promise.resolve(this.oobMembers[roomId] || null);
return Promise.resolve(this.oobMembers.get(roomId) || null);
}
/**
@ -397,12 +386,12 @@ export class MemoryStore implements IStore {
* @returns when all members have been stored
*/
public setOutOfBandMembers(roomId: string, membershipEvents: IStateEventWithRoomId[]): Promise<void> {
this.oobMembers[roomId] = membershipEvents;
this.oobMembers.set(roomId, membershipEvents);
return Promise.resolve();
}
public clearOutOfBandMembers(roomId: string): Promise<void> {
this.oobMembers = {};
this.oobMembers.delete(roomId);
return Promise.resolve();
}

View File

@ -34,7 +34,7 @@ import { IStoredClientOpts } from "../client";
* Construct a stub store. This does no-ops on most store methods.
*/
export class StubStore implements IStore {
public readonly accountData = {}; // stub
public readonly accountData = new Map(); // stub
private fromToken: string | null = null;
/** @returns whether or not the database was newly created in this session. */

View File

@ -19,7 +19,7 @@ limitations under the License.
*/
import { logger } from "./logger";
import { deepCopy, isSupportedReceiptType } from "./utils";
import { deepCopy, isSupportedReceiptType, MapWithDefault, recursiveMapToObject } from "./utils";
import { IContent, IUnsigned } from "./models/event";
import { IRoomSummary } from "./models/room-summary";
import { EventType } from "./@types/event";
@ -585,29 +585,31 @@ export class SyncAccumulator {
} as IContent,
};
const receiptEventContent: MapWithDefault<
string,
MapWithDefault<ReceiptType, Map<string, object>>
> = new MapWithDefault(() => new MapWithDefault(() => new Map()));
for (const [userId, receiptData] of Object.entries(roomData._readReceipts)) {
if (!receiptEvent.content[receiptData.eventId]) {
receiptEvent.content[receiptData.eventId] = {};
}
if (!receiptEvent.content[receiptData.eventId][receiptData.type]) {
receiptEvent.content[receiptData.eventId][receiptData.type] = {};
}
receiptEvent.content[receiptData.eventId][receiptData.type][userId] = receiptData.data;
receiptEventContent
.getOrCreate(receiptData.eventId)
.getOrCreate(receiptData.type)
.set(userId, receiptData.data);
}
for (const threadReceipts of Object.values(roomData._threadReadReceipts)) {
for (const [userId, receiptData] of Object.entries(threadReceipts)) {
if (!receiptEvent.content[receiptData.eventId]) {
receiptEvent.content[receiptData.eventId] = {};
}
if (!receiptEvent.content[receiptData.eventId][receiptData.type]) {
receiptEvent.content[receiptData.eventId][receiptData.type] = {};
}
receiptEvent.content[receiptData.eventId][receiptData.type][userId] = receiptData.data;
receiptEventContent
.getOrCreate(receiptData.eventId)
.getOrCreate(receiptData.type)
.set(userId, receiptData.data);
}
}
receiptEvent.content = recursiveMapToObject(receiptEventContent);
// add only if we have some receipt data
if (Object.keys(receiptEvent.content).length > 0) {
if (receiptEventContent.size > 0) {
roomJson.ephemeral.events.push(receiptEvent as IMinimalEvent);
}

View File

@ -29,7 +29,7 @@ import type { SyncCryptoCallbacks } from "./common-crypto/CryptoBackend";
import { User, UserEvent } from "./models/user";
import { NotificationCountType, Room, RoomEvent } from "./models/room";
import * as utils from "./utils";
import { IDeferred } from "./utils";
import { IDeferred, noUnsafeEventProps, unsafeProp } from "./utils";
import { Filter } from "./filter";
import { EventTimeline } from "./models/event-timeline";
import { logger } from "./logger";
@ -1133,7 +1133,9 @@ export class SyncApi {
// handle presence events (User objects)
if (Array.isArray(data.presence?.events)) {
data.presence!.events.map(client.getEventMapper()).forEach(function (presenceEvent) {
data.presence!.events.filter(noUnsafeEventProps)
.map(client.getEventMapper())
.forEach(function (presenceEvent) {
let user = client.store.getUser(presenceEvent.getSender()!);
if (user) {
user.setPresenceEvent(presenceEvent);
@ -1148,7 +1150,7 @@ export class SyncApi {
// handle non-room account_data
if (Array.isArray(data.account_data?.events)) {
const events = data.account_data.events.map(client.getEventMapper());
const events = data.account_data.events.filter(noUnsafeEventProps).map(client.getEventMapper());
const prevEventsMap = events.reduce<Record<string, MatrixEvent | undefined>>((m, c) => {
m[c.getType()!] = client.store.getAccountData(c.getType());
return m;
@ -1171,7 +1173,7 @@ export class SyncApi {
// handle to-device events
if (data.to_device && Array.isArray(data.to_device.events) && data.to_device.events.length > 0) {
let toDeviceMessages: IToDeviceEvent[] = data.to_device.events;
let toDeviceMessages: IToDeviceEvent[] = data.to_device.events.filter(noUnsafeEventProps);
if (this.syncOpts.cryptoCallbacks) {
toDeviceMessages = await this.syncOpts.cryptoCallbacks.preprocessToDeviceMessages(toDeviceMessages);
@ -1630,7 +1632,9 @@ export class SyncApi {
// to
// [{stuff+Room+isBrandNewRoom}, {stuff+Room+isBrandNewRoom}]
const client = this.client;
return Object.keys(obj).map((roomId) => {
return Object.keys(obj)
.filter((k) => !unsafeProp(k))
.map((roomId) => {
const arrObj = obj[roomId] as T & { room: Room; isBrandNewRoom: boolean };
let room = client.store.getRoom(roomId);
let isBrandNewRoom = false;
@ -1654,7 +1658,7 @@ export class SyncApi {
}
const mapper = this.client.getEventMapper({ decrypt });
type TaggedEvent = (IStrippedState | IRoomEvent | IStateEvent | IMinimalEvent) & { room_id?: string };
return (obj.events as TaggedEvent[]).map(function (e) {
return (obj.events as TaggedEvent[]).filter(noUnsafeEventProps).map(function (e) {
if (room) {
e.room_id = room.roomId;
}

View File

@ -22,7 +22,7 @@ import unhomoglyph from "unhomoglyph";
import promiseRetry from "p-retry";
import { Optional } from "matrix-events-sdk";
import { MatrixEvent } from "./models/event";
import { IEvent, MatrixEvent } from "./models/event";
import { M_TIMESTAMP } from "./@types/location";
import { ReceiptType } from "./@types/read_receipts";
@ -703,3 +703,68 @@ export function mapsEqual<K, V>(x: Map<K, V>, y: Map<K, V>, eq = (v1: V, v2: V):
}
return true;
}
function processMapToObjectValue(value: any): any {
if (value instanceof Map) {
// Value is a Map. Recursively map it to an object.
return recursiveMapToObject(value);
} else if (Array.isArray(value)) {
// Value is an Array. Recursively map the value (e.g. to cover Array of Arrays).
return value.map((v) => processMapToObjectValue(v));
} else {
return value;
}
}
/**
* Recursively converts Maps to plain objects.
* Also supports sub-lists of Maps.
*/
export function recursiveMapToObject(map: Map<any, any>): any {
const targetMap = new Map();
for (const [key, value] of map) {
targetMap.set(key, processMapToObjectValue(value));
}
return Object.fromEntries(targetMap.entries());
}
export function unsafeProp<K extends keyof any | undefined>(prop: K): boolean {
return prop === "__proto__" || prop === "prototype" || prop === "constructor";
}
export function safeSet<K extends keyof any>(obj: Record<any, any>, prop: K, value: any): void {
if (unsafeProp(prop)) {
throw new Error("Trying to modify prototype or constructor");
}
obj[prop] = value;
}
export function noUnsafeEventProps(event: Partial<IEvent>): boolean {
return !(
unsafeProp(event.room_id) ||
unsafeProp(event.sender) ||
unsafeProp(event.user_id) ||
unsafeProp(event.event_id)
);
}
export class MapWithDefault<K, V> extends Map<K, V> {
public constructor(private createDefault: () => V) {
super();
}
/**
* Returns the value if the key already exists.
* If not, it creates a new value under that key using the ctor callback and returns it.
*/
public getOrCreate(key: K): V {
if (!this.has(key)) {
this.set(key, this.createDefault());
}
return this.get(key)!;
}
}

View File

@ -609,7 +609,7 @@ export class MatrixCall extends TypedEventEmitter<CallEvent, CallEventHandlerMap
if (!userId) throw new Error("Couldn't find opponent user ID to init crypto");
const deviceInfoMap = await this.client.crypto.deviceList.downloadKeys([userId], false);
this.opponentDeviceInfo = deviceInfoMap[userId][this.opponentDeviceId];
this.opponentDeviceInfo = deviceInfoMap.get(userId)?.get(this.opponentDeviceId);
if (this.opponentDeviceInfo === undefined) {
throw new GroupCallUnknownDeviceError(userId);
}
@ -2408,11 +2408,10 @@ export class MatrixCall extends TypedEventEmitter<CallEvent, CallEventHandlerMap
},
);
} else {
await this.client.sendToDevice(eventType, {
[userId]: {
[this.opponentDeviceId]: content,
},
});
await this.client.sendToDevice(
eventType,
new Map<string, any>([[userId, new Map([[this.opponentDeviceId, content]])]]),
);
}
} else {
this.emit(CallEvent.SendVoipEvent, {