diff --git a/spec/unit/crypto/algorithms/megolm.spec.js b/spec/unit/crypto/algorithms/megolm.spec.js index 7f86f86ff..79257e6d7 100644 --- a/spec/unit/crypto/algorithms/megolm.spec.js +++ b/spec/unit/crypto/algorithms/megolm.spec.js @@ -129,12 +129,16 @@ describe("MegolmDecryption", function() { const deviceInfo = {}; mockCrypto.getStoredDevice.andReturn(deviceInfo); - const awaitEnsureSessions = new Promise((res, rej) => { - mockOlmLib.ensureOlmSessionsForDevices.andCall(() => { + mockOlmLib.ensureOlmSessionsForDevices.andReturn( + Promise.resolve({'@alice:foo': {'alidevice': { + sessionId: 'alisession', + }}}), + ); + + const awaitEncryptForDevice = new Promise((res, rej) => { + mockOlmLib.encryptMessageForDevice.andCall(() => { res(); - return Promise.resolve({'@alice:foo': {'alidevice': { - sessionId: 'alisession', - }}}); + return Promise.resolve(); }); }); @@ -144,7 +148,7 @@ describe("MegolmDecryption", function() { megolmDecryption.shareKeysWithDevice(keyRequest); // it's asynchronous, so we have to wait a bit - return awaitEnsureSessions; + return awaitEncryptForDevice; }).then(() => { // check that it called encryptMessageForDevice with // appropriate args. diff --git a/src/crypto/OlmDevice.js b/src/crypto/OlmDevice.js index f53ada53b..184810822 100644 --- a/src/crypto/OlmDevice.js +++ b/src/crypto/OlmDevice.js @@ -849,9 +849,9 @@ OlmDevice.prototype.decryptGroupMessage = async function( * @param {string} senderKey base64-encoded curve25519 key of the sender * @param {sring} sessionId session identifier * - * @returns {boolean} true if we have the keys to this session + * @returns {Promise} true if we have the keys to this session */ -OlmDevice.prototype.hasInboundSessionKeys = function(roomId, senderKey, sessionId) { +OlmDevice.prototype.hasInboundSessionKeys = async function(roomId, senderKey, sessionId) { const s = this._sessionStore.getEndToEndInboundGroupSession( senderKey, sessionId, ); @@ -880,14 +880,16 @@ OlmDevice.prototype.hasInboundSessionKeys = function(roomId, senderKey, sessionI * @param {string} senderKey base64-encoded curve25519 key of the sender * @param {string} sessionId session identifier * - * @returns {{chain_index: number, key: string, + * @returns {Promise<{chain_index: number, key: string, * forwarding_curve25519_key_chain: Array, * sender_claimed_ed25519_key: string - * }} + * }>} * details of the session key. The key is a base64-encoded megolm key in * export format. */ -OlmDevice.prototype.getInboundGroupSessionKey = function(roomId, senderKey, sessionId) { +OlmDevice.prototype.getInboundGroupSessionKey = async function( + roomId, senderKey, sessionId, +) { function getKey(session, sessionData) { const messageIndex = session.first_known_index(); diff --git a/src/crypto/algorithms/megolm.js b/src/crypto/algorithms/megolm.js index 16cc2e3b0..7c9a6a1f4 100644 --- a/src/crypto/algorithms/megolm.js +++ b/src/crypto/algorithms/megolm.js @@ -725,7 +725,7 @@ MegolmDecryption.prototype.onRoomKeyEvent = function(event) { /** * @inheritdoc */ -MegolmDecryption.prototype.hasKeysForKeyRequest = async function(keyRequest) { +MegolmDecryption.prototype.hasKeysForKeyRequest = function(keyRequest) { const body = keyRequest.requestBody; return this._olmDevice.hasInboundSessionKeys( @@ -766,10 +766,10 @@ MegolmDecryption.prototype.shareKeysWithDevice = function(keyRequest) { + userId + ":" + deviceId, ); - const payload = this._buildKeyForwardingMessage( + return this._buildKeyForwardingMessage( body.room_id, body.sender_key, body.session_id, ); - + }).then((payload) => { const encryptedContent = { algorithm: olmlib.OLM_ALGORITHM, sender_key: this._olmDevice.deviceCurve25519Key, @@ -797,10 +797,10 @@ MegolmDecryption.prototype.shareKeysWithDevice = function(keyRequest) { }).done(); }; -MegolmDecryption.prototype._buildKeyForwardingMessage = function( +MegolmDecryption.prototype._buildKeyForwardingMessage = async function( roomId, senderKey, sessionId, ) { - const key = this._olmDevice.getInboundGroupSessionKey( + const key = await this._olmDevice.getInboundGroupSessionKey( roomId, senderKey, sessionId, );