From 4c651c15ea28657de55db85e3843dcb7de0928e0 Mon Sep 17 00:00:00 2001 From: David Baker Date: Mon, 11 Nov 2019 20:01:11 +0000 Subject: [PATCH] Convert secrets events to callbacks too --- spec/unit/crypto/secrets.spec.js | 40 ++++--- src/client.js | 14 +++ src/crypto/Secrets.js | 197 +++++++++++++++---------------- src/crypto/index.js | 4 +- 4 files changed, 138 insertions(+), 117 deletions(-) diff --git a/spec/unit/crypto/secrets.spec.js b/spec/unit/crypto/secrets.spec.js index 4024fc52f..be23c6d61 100644 --- a/spec/unit/crypto/secrets.spec.js +++ b/spec/unit/crypto/secrets.spec.js @@ -45,15 +45,25 @@ describe("Secrets", function() { }); it("should store and retrieve a secret", async function() { - const alice = await makeTestClient( - {userId: "@alice:example.com", deviceId: "Osborne2"}, - ); - const secretStorage = alice._crypto._secretStorage; - const decryption = new global.Olm.PkDecryption(); const pubkey = decryption.generate_key(); const privkey = decryption.get_private_key(); + const getKey = expect.createSpy().andCall(e => { + expect(Object.keys(e.keys)).toEqual(["abc"]); + return ['abc', privkey]; + }); + + const alice = await makeTestClient( + {userId: "@alice:example.com", deviceId: "Osborne2"}, + { + cryptoCallbacks: { + getSecretStorageKey: getKey, + }, + }, + ); + const secretStorage = alice._crypto._secretStorage; + alice.setAccountData = async function(eventType, contents, callback) { alice.store.storeAccountDataEvents([ new MatrixEvent({ @@ -81,13 +91,6 @@ describe("Secrets", function() { await secretStorage.store("foo", "bar", ["abc"]); expect(secretStorage.isStored("foo")).toBe(true); - - const getKey = expect.createSpy().andCall(function(e) { - expect(Object.keys(e.keys)).toEqual(["abc"]); - e.done("abc", privkey); - }); - alice.once("crypto.secrets.getKey", getKey); - expect(await secretStorage.get("foo")).toBe("bar"); expect(getKey).toHaveBeenCalled(); @@ -99,6 +102,14 @@ describe("Secrets", function() { {userId: "@alice:example.com", deviceId: "Osborne2"}, {userId: "@alice:example.com", deviceId: "VAX"}, ], + { + cryptoCallbacks: { + onSecretRequested: e => { + expect(e.name).toBe("foo"); + return "bar"; + }, + }, + }, ); const vaxDevice = vax.client._crypto._olmDevice; @@ -128,11 +139,6 @@ describe("Secrets", function() { }, }); - vax.client.once("crypto.secrets.request", function(e) { - expect(e.name).toBe("foo"); - e.send("bar"); - }); - await osborne2Device.generateOneTimeKeys(1); const otks = (await osborne2Device.getOneTimeKeys()).curve25519; await osborne2Device.markKeysAsPublished(); diff --git a/src/client.js b/src/client.js index 54d4f038a..8b23a9a75 100644 --- a/src/client.js +++ b/src/client.js @@ -210,6 +210,20 @@ function keyFromRecoverySession(session, decryptionKey) { * Should return a promise which resolves with an array of the user IDs who * should be cross-signed. * + * @param {function} [opts.cryptoCallbacks.getSecretStorageKey] + * Optional. Function called when an encryption key for secret storage + * is required. One or more keys will be described in the keys object. + * The callback function should return with an array of: + * [, ] or null if it cannot provide + * any of the keys. + * Args: + * {object} keys Information about the keys: + * { + * : { + * pubkey: {UInt8Array} + * } + * } + * * @param {function} [opts.cryptoCallbacks.onSecretRequested] * Optional. Function called when a request for a secret is received from another * device. diff --git a/src/crypto/Secrets.js b/src/crypto/Secrets.js index aebf869bf..738e9b024 100644 --- a/src/crypto/Secrets.js +++ b/src/crypto/Secrets.js @@ -26,11 +26,12 @@ import { encodeRecoveryKey } from './recoverykey'; * @module crypto/Secrets */ export default class SecretStorage extends EventEmitter { - constructor(baseApis) { + constructor(baseApis, cryptoCallbacks) { super(); this._baseApis = baseApis; this._requests = {}; this._incomingRequests = {}; + this._cryptoCallbacks = cryptoCallbacks; } /** @@ -159,7 +160,7 @@ export default class SecretStorage extends EventEmitter { const secretContent = secretInfo.getContent(); if (!secretContent.encrypted) { - return; + throw new Error("Content is not encrypted!"); } // get possible keys to decrypt @@ -182,63 +183,13 @@ export default class SecretStorage extends EventEmitter { } } - // fetch private key from app - let decryption; let keyName; - let cleanUp; - let error; - do { - [keyName, decryption, cleanUp] = await new Promise((resolve, reject) => { - this._baseApis.emit("crypto.secrets.getKey", { - keys, - error, - done: function(keyName, key) { - // FIXME: interpret key? - if (!keys[keyName]) { - error = "Unknown key (your app is broken)"; - resolve([]); - } - switch (keys[keyName].algorithm) { - case "m.secret_storage.v1.curve25519-aes-sha2": - { - const decryption = new global.Olm.PkDecryption(); - try { - const pubkey = decryption.init_with_private_key(key); - if (pubkey !== keys[keyName].pubkey) { - error = "Key does not match"; - resolve([]); - return; - } - } catch (e) { - decryption.free(); - error = "Invalid key"; - resolve([]); - return; - } - resolve([ - keyName, - decryption, - decryption.free.bind(decryption), - ]); - break; - } - default: - error = "The universe is broken"; - resolve([]); - } - }, - cancel: function(e) { - reject(e || new Error("Cancelled")); - }, - }); - }); - if (error) { - logger.error("Error getting private key:", error); - } - } while (!keyName); - - // decrypt secret + let decryption; try { + // fetch private key from app + [keyName, decryption] = await this._getSecretStorageKey(keys); + + // decrypt secret const encInfo = secretContent.encrypted[keyName]; switch (keys[keyName].algorithm) { case "m.secret_storage.v1.curve25519-aes-sha2": @@ -247,7 +198,7 @@ export default class SecretStorage extends EventEmitter { ); } } finally { - cleanUp(); + if (decryption) decryption.free(); } } @@ -358,7 +309,7 @@ export default class SecretStorage extends EventEmitter { }; } - _onRequestReceived(event) { + async _onRequestReceived(event) { const sender = event.getSender(); const content = event.getContent(); if (sender !== this._baseApis.getUserId() @@ -389,52 +340,55 @@ export default class SecretStorage extends EventEmitter { // check if we have the secret logger.info("received request for secret (" + sender + ", " + deviceId + ", " + content.request_id + ")"); - this._baseApis.emit("crypto.secrets.request", { + if (!this._cryptoCallbacks.onSecretRequested) { + return; + } + const secret = await this._cryptoCallbacks.onSecretRequested({ user_id: sender, device_id: deviceId, request_id: content.request_id, name: content.name, device_trust: this._baseApis.checkDeviceTrust(sender, deviceId), - send: async (secret) => { - const payload = { - type: "m.secret.send", - content: { - request_id: content.request_id, - secret: secret, - }, - }; - const encryptedContent = { - algorithm: olmlib.OLM_ALGORITHM, - sender_key: this._baseApis._crypto._olmDevice.deviceCurve25519Key, - ciphertext: {}, - }; - await olmlib.ensureOlmSessionsForDevices( - this._baseApis._crypto._olmDevice, - this._baseApis, - { - [sender]: [ - await this._baseApis.getStoredDevice(sender, deviceId), - ], - }, - ); - await olmlib.encryptMessageForDevice( - encryptedContent.ciphertext, - this._baseApis.getUserId(), - this._baseApis.deviceId, - this._baseApis._crypto._olmDevice, - sender, - this._baseApis._crypto.getStoredDevice(sender, deviceId), - payload, - ); - const contentMap = { - [sender]: { - [deviceId]: encryptedContent, - }, - }; - - this._baseApis.sendToDevice("m.room.encrypted", contentMap); - }, }); + if (secret) { + const payload = { + type: "m.secret.send", + content: { + request_id: content.request_id, + secret: secret, + }, + }; + const encryptedContent = { + algorithm: olmlib.OLM_ALGORITHM, + sender_key: this._baseApis._crypto._olmDevice.deviceCurve25519Key, + ciphertext: {}, + }; + await olmlib.ensureOlmSessionsForDevices( + this._baseApis._crypto._olmDevice, + this._baseApis, + { + [sender]: [ + await this._baseApis.getStoredDevice(sender, deviceId), + ], + }, + ); + await olmlib.encryptMessageForDevice( + encryptedContent.ciphertext, + this._baseApis.getUserId(), + this._baseApis.deviceId, + this._baseApis._crypto._olmDevice, + sender, + this._baseApis._crypto.getStoredDevice(sender, deviceId), + payload, + ); + const contentMap = { + [sender]: { + [deviceId]: encryptedContent, + }, + }; + + this._baseApis.sendToDevice("m.room.encrypted", contentMap); + } } } @@ -468,4 +422,49 @@ export default class SecretStorage extends EventEmitter { requestControl.resolve(content.secret); } } + + async _getSecretStorageKey(keys) { + if (!this._cryptoCallbacks.getSecretStorageKey) { + throw new Error("No getSecretStorageKey callback supplied"); + } + + const returned = await Promise.resolve( + this._cryptoCallbacks.getSecretStorageKey({keys}), + ); + + if (!returned) { + throw new Error("getSecretStorageKey callback returned falsey"); + } + if (returned.length < 2) { + throw new Error("getSecretStorageKey callback returned invalid data"); + } + + const [keyName, privateKey] = returned; + if (!keys[keyName]) { + throw new Error("App returned unknown key from getSecretStorageKey!"); + } + + switch (keys[keyName].algorithm) { + case "m.secret_storage.v1.curve25519-aes-sha2": + { + const decryption = new global.Olm.PkDecryption(); + let pubkey; + try { + pubkey = decryption.init_with_private_key(privateKey); + } catch (e) { + decryption.free(); + throw new Error("getSecretStorageKey callback returned invalid key"); + } + if (pubkey !== keys[keyName].pubkey) { + decryption.free(); + throw new Error( + "getSecretStorageKey callback returned incorrect key", + ); + } + return [keyName, decryption]; + } + default: + throw new Error("Unknown key type: " + keys[keyName].algorithm); + } + } } diff --git a/src/crypto/index.js b/src/crypto/index.js index bbe7c8be5..651f9df4d 100644 --- a/src/crypto/index.js +++ b/src/crypto/index.js @@ -207,7 +207,9 @@ export default function Crypto(baseApis, sessionStore, userId, deviceId, userId, this._baseApis._cryptoCallbacks, ); - this._secretStorage = new SecretStorage(baseApis); + this._secretStorage = new SecretStorage( + baseApis, this._baseApis._cryptoCallbacks, + ); } utils.inherits(Crypto, EventEmitter);