diff --git a/src/crypto/verification/Base.js b/src/crypto/verification/Base.js index 1bcba5ad6..9213c3b49 100644 --- a/src/crypto/verification/Base.js +++ b/src/crypto/verification/Base.js @@ -27,6 +27,13 @@ import {newTimeoutError} from "./Error"; const timeoutException = new Error("Verification timed out"); +export class SwitchStartEventError extends Error { + constructor(startEvent) { + super(); + this.startEvent = startEvent; + } +} + export class VerificationBase extends EventEmitter { /** * Base class for verification methods. @@ -69,6 +76,19 @@ export class VerificationBase extends EventEmitter { this._transactionTimeoutTimer = null; } + get initiatedByMe() { + // if there is no start event yet, + // we probably want to send it, + // which happens if we initiate + if (!this.startEvent) { + return true; + } + const sender = this.startEvent.getSender(); + const content = this.startEvent.getContent(); + return sender === this._baseApis.getUserId() && + content.from_device === this._baseApis.getDeviceId(); + } + _resetTimer() { logger.info("Refreshing/starting the verification transaction timeout timer"); if (this._transactionTimeoutTimer !== null) { @@ -104,6 +124,22 @@ export class VerificationBase extends EventEmitter { }); } + canSwitchStartEvent() { + return false; + } + + switchStartEvent(event) { + if (this.canSwitchStartEvent(event)) { + if (this._rejectEvent) { + const reject = this._rejectEvent; + this._rejectEvent = undefined; + reject(new SwitchStartEventError(event)); + } else { + this.startEvent = event; + } + } + } + handleEvent(e) { if (this._done) { return; diff --git a/src/crypto/verification/SAS.js b/src/crypto/verification/SAS.js index f840b1abf..367ee1463 100644 --- a/src/crypto/verification/SAS.js +++ b/src/crypto/verification/SAS.js @@ -19,7 +19,7 @@ limitations under the License. * @module crypto/verification/SAS */ -import {VerificationBase as Base} from "./Base"; +import {VerificationBase as Base, SwitchStartEventError} from "./Base"; import anotherjson from 'another-json'; import { errorFactory, @@ -29,6 +29,8 @@ import { newUserCancelledError, } from './Error'; +const START_TYPE = "m.key.verification.start"; + const EVENTS = [ "m.key.verification.accept", "m.key.verification.key", @@ -201,16 +203,37 @@ export class SAS extends Base { // make sure user's keys are downloaded await this._baseApis.downloadKeys([this.userId]); - if (this.startEvent) { - return await this._doRespondVerification(); - } else { - return await this._doSendVerification(); - } + let retry = false; + do { + try { + if (this.initiatedByMe) { + return await this._doSendVerification(); + } else { + return await this._doRespondVerification(); + } + } catch (err) { + if (err instanceof SwitchStartEventError) { + // this changes what initiatedByMe returns + this.startEvent = err.startEvent; + retry = true; + } else { + throw err; + } + } + } while (retry); } - async _doSendVerification() { - const type = "m.key.verification.start"; - const initialMessage = this._channel.completeContent(type, { + canSwitchStartEvent(event) { + if (event.getType() !== START_TYPE) { + return false; + } + const content = event.getContent(); + return content && content.method === SAS.NAME && + this._waitingForAccept; + } + + async _sendStart() { + const startContent = this._channel.completeContent(START_TYPE, { method: SAS.NAME, from_device: this._baseApis.deviceId, key_agreement_protocols: KEY_AGREEMENT_LIST, @@ -219,11 +242,33 @@ export class SAS extends Base { // FIXME: allow app to specify what SAS methods can be used short_authentication_string: SAS_LIST, }); - // add the transaction id to the message beforehand because - // it needs to be included in the commitment hash later on - this._channel.sendCompleted(type, initialMessage); + await this._channel.sendCompleted(START_TYPE, startContent); + return startContent; + } - let e = await this._waitForEvent("m.key.verification.accept"); + async _doSendVerification() { + this._waitingForAccept = true; + let startContent; + if (this.startEvent) { + startContent = this._channel.completedContentFromEvent(this.startEvent); + } else { + startContent = await this._sendStart(); + } + + // we might have switched to a different start event, + // but was we didn't call _waitForEvent there was no + // call that could throw yet. So check manually that + // we're still on the initiator side + if (!this.initiatedByMe) { + throw new SwitchStartEventError(this.startEvent); + } + + let e; + try { + e = await this._waitForEvent("m.key.verification.accept"); + } finally { + this._waitingForAccept = false; + } let content = e.getContent(); const sasMethods = intersection(content.short_authentication_string, SAS_SET); @@ -248,7 +293,7 @@ export class SAS extends Base { e = await this._waitForEvent("m.key.verification.key"); // FIXME: make sure event is properly formed content = e.getContent(); - const commitmentStr = content.key + anotherjson.stringify(initialMessage); + const commitmentStr = content.key + anotherjson.stringify(startContent); // TODO: use selected hash function (when we support multiple) if (olmutil.sha256(commitmentStr) !== hashCommitment) { throw newMismatchedCommitmentError();