diff --git a/spec/unit/matrixrtc/MatrixRTCSession.spec.ts b/spec/unit/matrixrtc/MatrixRTCSession.spec.ts index 1d9feea13..16f55386f 100644 --- a/spec/unit/matrixrtc/MatrixRTCSession.spec.ts +++ b/spec/unit/matrixrtc/MatrixRTCSession.spec.ts @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -import { EventTimeline, EventType, MatrixClient, MatrixError, MatrixEvent, Room } from "../../../src"; +import { encodeBase64, EventTimeline, EventType, MatrixClient, MatrixError, MatrixEvent, Room } from "../../../src"; import { KnownMembership } from "../../../src/@types/membership"; import { CallMembershipData, @@ -1145,6 +1145,7 @@ describe("MatrixRTCSession", () => { ], }), getSender: jest.fn().mockReturnValue("@bob:example.org"), + getTs: jest.fn().mockReturnValue(Date.now()), } as unknown as MatrixEvent); const bobKeys = sess.getKeysForParticipant("@bob:example.org", "bobsphone")!; @@ -1168,6 +1169,7 @@ describe("MatrixRTCSession", () => { ], }), getSender: jest.fn().mockReturnValue("@bob:example.org"), + getTs: jest.fn().mockReturnValue(Date.now()), } as unknown as MatrixEvent); const bobKeys = sess.getKeysForParticipant("@bob:example.org", "bobsphone")!; @@ -1179,6 +1181,130 @@ describe("MatrixRTCSession", () => { expect(bobKeys[4]).toEqual(Buffer.from("this is the key", "utf-8")); }); + it("collects keys by merging", () => { + const mockRoom = makeMockRoom([membershipTemplate]); + sess = MatrixRTCSession.roomSessionForRoom(client, mockRoom); + sess.onCallEncryption({ + getType: jest.fn().mockReturnValue("io.element.call.encryption_keys"), + getContent: jest.fn().mockReturnValue({ + device_id: "bobsphone", + call_id: "", + keys: [ + { + index: 0, + key: "dGhpcyBpcyB0aGUga2V5", + }, + ], + }), + getSender: jest.fn().mockReturnValue("@bob:example.org"), + getTs: jest.fn().mockReturnValue(Date.now()), + } as unknown as MatrixEvent); + + let bobKeys = sess.getKeysForParticipant("@bob:example.org", "bobsphone")!; + expect(bobKeys).toHaveLength(1); + expect(bobKeys[0]).toEqual(Buffer.from("this is the key", "utf-8")); + + sess.onCallEncryption({ + getType: jest.fn().mockReturnValue("io.element.call.encryption_keys"), + getContent: jest.fn().mockReturnValue({ + device_id: "bobsphone", + call_id: "", + keys: [ + { + index: 4, + key: "dGhpcyBpcyB0aGUga2V5", + }, + ], + }), + getSender: jest.fn().mockReturnValue("@bob:example.org"), + getTs: jest.fn().mockReturnValue(Date.now()), + } as unknown as MatrixEvent); + + bobKeys = sess.getKeysForParticipant("@bob:example.org", "bobsphone")!; + expect(bobKeys).toHaveLength(5); + expect(bobKeys[4]).toEqual(Buffer.from("this is the key", "utf-8")); + }); + + it("ignores older keys at same index", () => { + const mockRoom = makeMockRoom([membershipTemplate]); + sess = MatrixRTCSession.roomSessionForRoom(client, mockRoom); + sess.onCallEncryption({ + getType: jest.fn().mockReturnValue("io.element.call.encryption_keys"), + getContent: jest.fn().mockReturnValue({ + device_id: "bobsphone", + call_id: "", + keys: [ + { + index: 0, + key: encodeBase64(Buffer.from("newer key", "utf-8")), + }, + ], + }), + getSender: jest.fn().mockReturnValue("@bob:example.org"), + getTs: jest.fn().mockReturnValue(2000), + } as unknown as MatrixEvent); + + sess.onCallEncryption({ + getType: jest.fn().mockReturnValue("io.element.call.encryption_keys"), + getContent: jest.fn().mockReturnValue({ + device_id: "bobsphone", + call_id: "", + keys: [ + { + index: 0, + key: encodeBase64(Buffer.from("older key", "utf-8")), + }, + ], + }), + getSender: jest.fn().mockReturnValue("@bob:example.org"), + getTs: jest.fn().mockReturnValue(1000), // earlier timestamp than the newer key + } as unknown as MatrixEvent); + + const bobKeys = sess.getKeysForParticipant("@bob:example.org", "bobsphone")!; + expect(bobKeys).toHaveLength(1); + expect(bobKeys[0]).toEqual(Buffer.from("newer key", "utf-8")); + }); + + it("key timestamps are treated as monotonic", () => { + const mockRoom = makeMockRoom([membershipTemplate]); + sess = MatrixRTCSession.roomSessionForRoom(client, mockRoom); + sess.onCallEncryption({ + getType: jest.fn().mockReturnValue("io.element.call.encryption_keys"), + getContent: jest.fn().mockReturnValue({ + device_id: "bobsphone", + call_id: "", + keys: [ + { + index: 0, + key: encodeBase64(Buffer.from("first key", "utf-8")), + }, + ], + }), + getSender: jest.fn().mockReturnValue("@bob:example.org"), + getTs: jest.fn().mockReturnValue(1000), + } as unknown as MatrixEvent); + + sess.onCallEncryption({ + getType: jest.fn().mockReturnValue("io.element.call.encryption_keys"), + getContent: jest.fn().mockReturnValue({ + device_id: "bobsphone", + call_id: "", + keys: [ + { + index: 0, + key: encodeBase64(Buffer.from("second key", "utf-8")), + }, + ], + }), + getSender: jest.fn().mockReturnValue("@bob:example.org"), + getTs: jest.fn().mockReturnValue(1000), // same timestamp as the first key + } as unknown as MatrixEvent); + + const bobKeys = sess.getKeysForParticipant("@bob:example.org", "bobsphone")!; + expect(bobKeys).toHaveLength(1); + expect(bobKeys[0]).toEqual(Buffer.from("second key", "utf-8")); + }); + it("ignores keys event for the local participant", () => { const mockRoom = makeMockRoom([membershipTemplate]); sess = MatrixRTCSession.roomSessionForRoom(client, mockRoom); @@ -1195,6 +1321,7 @@ describe("MatrixRTCSession", () => { ], }), getSender: jest.fn().mockReturnValue(client.getUserId()), + getTs: jest.fn().mockReturnValue(Date.now()), } as unknown as MatrixEvent); const myKeys = sess.getKeysForParticipant(client.getUserId()!, client.getDeviceId()!)!; diff --git a/src/matrixrtc/MatrixRTCSession.ts b/src/matrixrtc/MatrixRTCSession.ts index ed3d7a44f..c5b73a3fd 100644 --- a/src/matrixrtc/MatrixRTCSession.ts +++ b/src/matrixrtc/MatrixRTCSession.ts @@ -56,9 +56,9 @@ const USE_KEY_DELAY = 5000; const getParticipantId = (userId: string, deviceId: string): string => `${userId}:${deviceId}`; const getParticipantIdFromMembership = (m: CallMembership): string => getParticipantId(m.sender!, m.deviceId); -function keysEqual(a: Uint8Array, b: Uint8Array): boolean { +function keysEqual(a: Uint8Array | undefined, b: Uint8Array | undefined): boolean { if (a === b) return true; - return a && b && a.length === b.length && a.every((x, i) => x === b[i]); + return !!a && !!b && a.length === b.length && a.every((x, i) => x === b[i]); } export enum MatrixRTCSessionEvent { @@ -134,8 +134,8 @@ export class MatrixRTCSession extends TypedEventEmitter array of keys - private encryptionKeys = new Map>(); + // userId:deviceId => array of (key, timestamp) + private encryptionKeys = new Map>(); private lastEncryptionKeyUpdateRequest?: number; // We use this to store the last membership fingerprints we saw, so we can proactively re-send encryption keys @@ -378,8 +378,15 @@ export class MatrixRTCSession extends TypedEventEmitter | undefined { - return this.encryptionKeys.get(getParticipantId(userId, deviceId)); + return this.encryptionKeys.get(getParticipantId(userId, deviceId))?.map((entry) => entry.key); } /** @@ -387,7 +394,10 @@ export class MatrixRTCSession extends TypedEventEmitter]> { - return this.encryptionKeys.entries(); + // the returned array doesn't contain the timestamps + return Array.from(this.encryptionKeys.entries()) + .map(([participantId, keys]): [string, Uint8Array[]] => [participantId, keys.map((k) => k.key)]) + .values(); } private getNewEncryptionKeyIndex(): number { @@ -402,12 +412,14 @@ export class MatrixRTCSession extends TypedEventEmitter timestamp) { + logger.info( + `Ignoring new key at index ${encryptionKeyIndex} for ${participantId} as it is older than existing known key`, + ); + return; + } + + if (keysEqual(existingKeyAtIndex.key, keyBin)) { + existingKeyAtIndex.timestamp = timestamp; + return; + } + } + + participantKeys[encryptionKeyIndex] = { + key: keyBin, + timestamp, + }; - encryptionKeys[encryptionKeyIndex] = keyBin; - this.encryptionKeys.set(participantId, encryptionKeys); if (delayBeforeUse) { const useKeyTimeout = setTimeout(() => { this.setNewKeyTimeouts.delete(useKeyTimeout); @@ -455,7 +488,7 @@ export class MatrixRTCSession extends TypedEventEmitter { const userId = event.getSender(); const content = event.getContent(); @@ -635,7 +675,7 @@ export class MatrixRTCSession extends TypedEventEmitter