1
0
mirror of https://github.com/matrix-org/matrix-js-sdk.git synced 2026-01-03 23:22:30 +03:00

only create one session at a time per device

This commit is contained in:
Hubert Chathi
2019-03-12 16:04:26 -04:00
parent 4570fcaa8a
commit 79ca235e7c
3 changed files with 149 additions and 19 deletions

View File

@@ -1,5 +1,5 @@
/*
Copyright 2018 New Vector Ltd
Copyright 2018,2019 New Vector Ltd
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -23,6 +23,8 @@ import MockStorageApi from '../../../MockStorageApi';
import testUtils from '../../../test-utils';
import OlmDevice from '../../../../lib/crypto/OlmDevice';
import olmlib from '../../../../lib/crypto/olmlib';
import DeviceInfo from '../../../../lib/crypto/deviceinfo';
function makeOlmDevice() {
const mockStorage = new MockStorageApi();
@@ -82,5 +84,61 @@ describe("OlmDecryption", function() {
"The olm or proteus is an aquatic salamander in the family Proteidae",
);
});
it("creates only one session at a time", async function() {
// if we call ensureOlmSessionsForDevices multiple times, it should
// only try to create one session at a time, even if the server is
// slow
let count = 0;
const baseApis = {
claimOneTimeKeys: () => {
// simulate a very slow server (.5 seconds to respond)
count++;
return new Promise((resolve, reject) => {
setTimeout(reject, 500);
});
},
};
const devicesByUser = {
"@bob:example.com": [
DeviceInfo.fromStorage({
keys: {
"curve25519:ABCDEFG": "akey",
},
}, "ABCDEFG"),
],
};
function alwaysSucceed(promise) {
// swallow any exception thrown by a promise, so that
// Promise.all doesn't abort
return promise.catch(() => {});
}
// start two tasks that try to ensure that there's an olm session
const promises = Promise.all([
alwaysSucceed(olmlib.ensureOlmSessionsForDevices(
aliceOlmDevice, baseApis, devicesByUser,
)),
alwaysSucceed(olmlib.ensureOlmSessionsForDevices(
aliceOlmDevice, baseApis, devicesByUser,
)),
]);
await new Promise((resolve) => {
setTimeout(resolve, 200);
});
// after .2s, both tasks should have started, but one should be
// waiting on the other before trying to create a session, so
// claimOneTimeKeys should have only been called once
expect(count).toBe(1);
await promises;
// after waiting for both tasks to complete, the first task should
// have failed, so the second task should have tried to create a
// new session and will have called claimOneTimeKeys
expect(count).toBe(2);
});
});
});

View File

@@ -1,6 +1,6 @@
/*
Copyright 2016 OpenMarket Ltd
Copyright 2017 New Vector Ltd
Copyright 2017, 2019 New Vector Ltd
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -102,6 +102,10 @@ function OlmDevice(sessionStore, cryptoStore) {
// Keys are strings of form "<senderKey>|<session_id>|<message_index>"
// Values are objects of the form "{id: <event id>, timestamp: <ts>}"
this._inboundGroupSessionMessageIndexes = {};
// Keep track of sessions that we're starting, so that we don't start
// multiple sessions for the same device at the same time.
this._sessionsInProgress = {};
}
/**
@@ -553,6 +557,15 @@ OlmDevice.prototype.createInboundSession = async function(
* @return {Promise<string[]>} a list of known session ids for the device
*/
OlmDevice.prototype.getSessionIdsForDevice = async function(theirDeviceIdentityKey) {
if (this._sessionsInProgress[theirDeviceIdentityKey]) {
console.log("waiting for session to be created");
try {
await this._sessionsInProgress[theirDeviceIdentityKey];
} catch (e) {
// if the session failed to be created, just fall through and
// return an empty result
}
}
let sessionIds;
await this._cryptoStore.doTxn(
'readonly', [IndexedDBCryptoStore.STORE_SESSIONS],
@@ -573,10 +586,18 @@ OlmDevice.prototype.getSessionIdsForDevice = async function(theirDeviceIdentityK
*
* @param {string} theirDeviceIdentityKey Curve25519 identity key for the
* remote device
* @param {boolean} nowait Don't wait for an in-progress session to complete.
* This should only be set to true of the calling function is the function
* that marked the session as being in-progress.
* @return {Promise<?string>} session id, or null if no established session
*/
OlmDevice.prototype.getSessionIdForDevice = async function(theirDeviceIdentityKey) {
const sessionInfos = await this.getSessionInfoForDevice(theirDeviceIdentityKey);
OlmDevice.prototype.getSessionIdForDevice = async function(
theirDeviceIdentityKey, nowait,
) {
const sessionInfos = await this.getSessionInfoForDevice(
theirDeviceIdentityKey, nowait,
);
if (sessionInfos.length === 0) {
return null;
}
@@ -611,9 +632,21 @@ OlmDevice.prototype.getSessionIdForDevice = async function(theirDeviceIdentityKe
* message and is therefore past the pre-key stage), and 'sessionId'.
*
* @param {string} deviceIdentityKey Curve25519 identity key for the device
* @param {boolean} nowait Don't wait for an in-progress session to complete.
* This should only be set to true of the calling function is the function
* that marked the session as being in-progress.
* @return {Array.<{sessionId: string, hasReceivedMessage: Boolean}>}
*/
OlmDevice.prototype.getSessionInfoForDevice = async function(deviceIdentityKey) {
OlmDevice.prototype.getSessionInfoForDevice = async function(deviceIdentityKey, nowait) {
if (this._sessionsInProgress[deviceIdentityKey] && !nowait) {
logger.log("waiting for session to be created");
try {
await this._sessionsInProgress[deviceIdentityKey];
} catch (e) {
// if the session failed to be created, then just fall through and
// return an empty result
}
}
const info = [];
await this._cryptoStore.doTxn(

View File

@@ -1,5 +1,6 @@
/*
Copyright 2016 OpenMarket Ltd
Copyright 2019 New Vector Ltd
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -137,6 +138,7 @@ module.exports.ensureOlmSessionsForDevices = async function(
// [userId, deviceId], ...
];
const result = {};
const resolveSession = {};
for (const userId in devicesByUser) {
if (!devicesByUser.hasOwnProperty(userId)) {
@@ -148,7 +150,36 @@ module.exports.ensureOlmSessionsForDevices = async function(
const deviceInfo = devices[j];
const deviceId = deviceInfo.deviceId;
const key = deviceInfo.getIdentityKey();
const sessionId = await olmDevice.getSessionIdForDevice(key);
if (!olmDevice._sessionsInProgress[key]) {
// pre-emptively mark the session as in-progress to avoid race
// conditions. If we find that we already have a session, then
// we'll resolve
olmDevice._sessionsInProgress[key] = new Promise(
(resolve, reject) => {
resolveSession[key] = {
resolve: (...args) => {
delete olmDevice._sessionsInProgress[key];
resolve(...args);
},
reject: (...args) => {
delete olmDevice._sessionsInProgress[key];
reject(...args);
},
};
},
);
}
const sessionId = await olmDevice.getSessionIdForDevice(
key, resolveSession[key],
);
if (sessionId !== null && resolveSession[key]) {
// we found a session, but we had marked the session as
// in-progress, so unmark it and unblock anything that was
// waiting
delete olmDevice._sessionsInProgress[key];
resolveSession[key].resolve();
delete resolveSession[key];
}
if (sessionId === null || force) {
devicesWithoutSession.push([userId, deviceId]);
}
@@ -163,16 +194,19 @@ module.exports.ensureOlmSessionsForDevices = async function(
return result;
}
// TODO: this has a race condition - if we try to send another message
// while we are claiming a key, we will end up claiming two and setting up
// two sessions.
//
// That should eventually resolve itself, but it's poor form.
const oneTimeKeyAlgorithm = "signed_curve25519";
const res = await baseApis.claimOneTimeKeys(
devicesWithoutSession, oneTimeKeyAlgorithm,
);
let res;
try {
res = await baseApis.claimOneTimeKeys(
devicesWithoutSession, oneTimeKeyAlgorithm,
);
} catch (e) {
for (const resolver of Object.values(resolveSession)) {
resolver.reject(e);
}
logger.log("failed to claim one-time keys", e, devicesWithoutSession);
throw e;
}
const otk_res = res.one_time_keys || {};
const promises = [];
@@ -185,6 +219,7 @@ module.exports.ensureOlmSessionsForDevices = async function(
for (let j = 0; j < devices.length; j++) {
const deviceInfo = devices[j];
const deviceId = deviceInfo.deviceId;
const key = deviceInfo.getIdentityKey();
if (result[userId][deviceId].sessionId && !force) {
// we already have a result for this device
continue;
@@ -199,10 +234,10 @@ module.exports.ensureOlmSessionsForDevices = async function(
}
if (!oneTimeKey) {
logger.warn(
"No one-time keys (alg=" + oneTimeKeyAlgorithm +
") for device " + userId + ":" + deviceId,
);
const msg = "No one-time keys (alg=" + oneTimeKeyAlgorithm +
") for device " + userId + ":" + deviceId;
logger.warn(msg);
resolveSession[key].reject(new Error(msg));
continue;
}
@@ -210,7 +245,11 @@ module.exports.ensureOlmSessionsForDevices = async function(
_verifyKeyAndStartSession(
olmDevice, oneTimeKey, userId, deviceInfo,
).then((sid) => {
resolveSession[key].resolve(sid);
result[userId][deviceId].sessionId = sid;
}, (e) => {
resolveSession[key].reject(e);
throw e;
}),
);
}