diff --git a/spec/unit/http-api/fetch.spec.ts b/spec/unit/http-api/fetch.spec.ts index ffca79866..ad30d8672 100644 --- a/spec/unit/http-api/fetch.spec.ts +++ b/spec/unit/http-api/fetch.spec.ts @@ -356,7 +356,9 @@ describe("FetchHttpApi", () => { accessToken, refreshToken, }); - const result = await api.authedRequest(Method.Post, "/account/password"); + const result = await api.authedRequest(Method.Post, "/account/password", undefined, undefined, { + headers: {}, + }); expect(result).toEqual(okayResponse); expect(tokenRefreshFunction).toHaveBeenCalledWith(refreshToken); @@ -372,6 +374,7 @@ describe("FetchHttpApi", () => { const tokenRefreshFunction = jest.fn().mockResolvedValue({ accessToken: newAccessToken, refreshToken: newRefreshToken, + expiry: new Date(Date.now() + 1000), }); // fetch doesn't like our new or old tokens diff --git a/spec/unit/oidc/tokenRefresher.spec.ts b/spec/unit/oidc/tokenRefresher.spec.ts index f2230b9ce..48c944161 100644 --- a/spec/unit/oidc/tokenRefresher.spec.ts +++ b/spec/unit/oidc/tokenRefresher.spec.ts @@ -130,10 +130,12 @@ describe("OidcTokenRefresher", () => { method: "POST", }); - expect(result).toEqual({ - accessToken: "new-access-token", - refreshToken: "new-refresh-token", - }); + expect(result).toEqual( + expect.objectContaining({ + accessToken: "new-access-token", + refreshToken: "new-refresh-token", + }), + ); }); it("should persist the new tokens", async () => { @@ -144,10 +146,12 @@ describe("OidcTokenRefresher", () => { await refresher.doRefreshAccessToken("refresh-token"); - expect(refresher.persistTokens).toHaveBeenCalledWith({ - accessToken: "new-access-token", - refreshToken: "new-refresh-token", - }); + expect(refresher.persistTokens).toHaveBeenCalledWith( + expect.objectContaining({ + accessToken: "new-access-token", + refreshToken: "new-refresh-token", + }), + ); }); it("should only have one inflight refresh request at once", async () => { @@ -189,10 +193,12 @@ describe("OidcTokenRefresher", () => { // only one call to token endpoint expect(fetchMock).toHaveFetchedTimes(1, config.token_endpoint); - expect(result1).toEqual({ - accessToken: "first-new-access-token", - refreshToken: "first-new-refresh-token", - }); + expect(result1).toEqual( + expect.objectContaining({ + accessToken: "first-new-access-token", + refreshToken: "first-new-refresh-token", + }), + ); // same response expect(result1).toEqual(result2); @@ -200,10 +206,12 @@ describe("OidcTokenRefresher", () => { const third = await refresher.doRefreshAccessToken("first-new-refresh-token"); // called token endpoint, got new tokens - expect(third).toEqual({ - accessToken: "second-new-access-token", - refreshToken: "second-new-refresh-token", - }); + expect(third).toEqual( + expect.objectContaining({ + accessToken: "second-new-access-token", + refreshToken: "second-new-refresh-token", + }), + ); }); it("should log and rethrow when token refresh fails", async () => { @@ -261,10 +269,12 @@ describe("OidcTokenRefresher", () => { const result = await refresher.doRefreshAccessToken("first-new-refresh-token"); // called token endpoint, got new tokens - expect(result).toEqual({ - accessToken: "second-new-access-token", - refreshToken: "second-new-refresh-token", - }); + expect(result).toEqual( + expect.objectContaining({ + accessToken: "second-new-access-token", + refreshToken: "second-new-refresh-token", + }), + ); }); it("should throw TokenRefreshLogoutError when expired", async () => { diff --git a/src/http-api/fetch.ts b/src/http-api/fetch.ts index 502b2bdb3..6dc5c78de 100644 --- a/src/http-api/fetch.ts +++ b/src/http-api/fetch.ts @@ -18,10 +18,10 @@ limitations under the License. * This is an internal module. See {@link MatrixHttpApi} for the public class. */ -import { checkObjectHasKeys, encodeParams } from "../utils.ts"; +import { checkObjectHasKeys, deepCopy, encodeParams } from "../utils.ts"; import { type TypedEventEmitter } from "../models/typed-event-emitter.ts"; import { Method } from "./method.ts"; -import { ConnectionError, MatrixError, TokenRefreshError, TokenRefreshLogoutError } from "./errors.ts"; +import { ConnectionError, MatrixError, TokenRefreshError } from "./errors.ts"; import { HttpApiEvent, type HttpApiEventHandlerMap, @@ -31,7 +31,7 @@ import { } from "./interface.ts"; import { anySignal, parseErrorResponse, timeoutSignal } from "./utils.ts"; import { type QueryDict } from "../utils.ts"; -import { singleAsyncExecution } from "../utils/decorators.ts"; +import { TokenRefresher, TokenRefreshOutcome } from "./refresh.ts"; interface TypedResponse extends Response { json(): Promise; @@ -43,14 +43,9 @@ export type ResponseType = O extends { json: false } ? T : TypedResponse; -const enum TokenRefreshOutcome { - Success = "success", - Failure = "failure", - Logout = "logout", -} - export class FetchHttpApi { private abortController = new AbortController(); + private readonly tokenRefresher: TokenRefresher; public constructor( private eventEmitter: TypedEventEmitter, @@ -59,6 +54,8 @@ export class FetchHttpApi { checkObjectHasKeys(opts, ["baseUrl", "prefix"]); opts.onlyData = !!opts.onlyData; opts.useAuthorizationHeader = opts.useAuthorizationHeader ?? true; + + this.tokenRefresher = new TokenRefresher(opts); } public abort(): void { @@ -113,12 +110,6 @@ export class FetchHttpApi { return this.requestOtherUrl(method, fullUri, body, opts); } - /** - * Promise used to block authenticated requests during a token refresh to avoid repeated expected errors. - * @private - */ - private tokenRefreshPromise?: Promise; - /** * Perform an authorised request to the homeserver. * @param method - The HTTP method e.g. "GET". @@ -146,36 +137,45 @@ export class FetchHttpApi { * @returns Rejects with an error if a problem occurred. * This includes network problems and Matrix-specific error JSON. */ - public async authedRequest( + public authedRequest( method: Method, path: string, - queryParams?: QueryDict, + queryParams: QueryDict = {}, body?: Body, - paramOpts: IRequestOpts & { doNotAttemptTokenRefresh?: boolean } = {}, + paramOpts: IRequestOpts = {}, ): Promise> { - if (!queryParams) queryParams = {}; + return this.doAuthedRequest(1, method, path, queryParams, body, paramOpts); + } + // Wrapper around public method authedRequest to allow for tracking retry attempt counts + private async doAuthedRequest( + attempt: number, + method: Method, + path: string, + queryParams: QueryDict, + body?: Body, + paramOpts: IRequestOpts = {}, + ): Promise> { // avoid mutating paramOpts so they can be used on retry - const opts = { ...paramOpts }; + const opts = deepCopy(paramOpts); + // we have to manually copy the abortSignal over as it is not a plain object + opts.abortSignal = paramOpts.abortSignal; - // Await any ongoing token refresh before we build the headers/params - await this.tokenRefreshPromise; - - // Take a copy of the access token so we have a record of the token we used for this request if it fails - const accessToken = this.opts.accessToken; - if (accessToken) { + // Take a snapshot of the current token state before we start the request so we can reference it if we error + const requestSnapshot = await this.tokenRefresher.prepareForRequest(); + if (requestSnapshot.accessToken) { if (this.opts.useAuthorizationHeader) { if (!opts.headers) { opts.headers = {}; } if (!opts.headers.Authorization) { - opts.headers.Authorization = `Bearer ${accessToken}`; + opts.headers.Authorization = `Bearer ${requestSnapshot.accessToken}`; } if (queryParams.access_token) { delete queryParams.access_token; } } else if (!queryParams.access_token) { - queryParams.access_token = accessToken; + queryParams.access_token = requestSnapshot.accessToken; } } @@ -187,33 +187,19 @@ export class FetchHttpApi { throw error; } - if (error.errcode === "M_UNKNOWN_TOKEN" && !opts.doNotAttemptTokenRefresh) { - // If the access token has changed since we started the request, but before we refreshed it, - // then it was refreshed due to another request failing, so retry before refreshing again. - let outcome: TokenRefreshOutcome | null = null; - if (accessToken === this.opts.accessToken) { - const tokenRefreshPromise = this.tryRefreshToken(); - this.tokenRefreshPromise = tokenRefreshPromise; - outcome = await tokenRefreshPromise; - } - - if (outcome === TokenRefreshOutcome.Success || outcome === null) { + if (error.errcode === "M_UNKNOWN_TOKEN") { + const outcome = await this.tokenRefresher.handleUnknownToken(requestSnapshot, attempt); + if (outcome === TokenRefreshOutcome.Success) { // if we got a new token retry the request - return this.authedRequest(method, path, queryParams, body, { - ...paramOpts, - // Only attempt token refresh once for each failed request - doNotAttemptTokenRefresh: outcome !== null, - }); + return this.doAuthedRequest(attempt + 1, method, path, queryParams, body, paramOpts); } if (outcome === TokenRefreshOutcome.Failure) { throw new TokenRefreshError(error); } - // Fall through to SessionLoggedOut handler below - } - // otherwise continue with error handling - if (error.errcode == "M_UNKNOWN_TOKEN" && !opts?.inhibitLogoutEmit) { - this.eventEmitter.emit(HttpApiEvent.SessionLoggedOut, error); + if (!opts?.inhibitLogoutEmit) { + this.eventEmitter.emit(HttpApiEvent.SessionLoggedOut, error); + } } else if (error.errcode == "M_CONSENT_NOT_GIVEN") { this.eventEmitter.emit(HttpApiEvent.NoConsent, error.message, error.data.consent_uri); } @@ -222,33 +208,6 @@ export class FetchHttpApi { } } - /** - * Attempt to refresh access tokens. - * On success, sets new access and refresh tokens in opts. - * @returns Promise that resolves to a boolean - true when token was refreshed successfully - */ - @singleAsyncExecution - private async tryRefreshToken(): Promise { - if (!this.opts.refreshToken || !this.opts.tokenRefreshFunction) { - return TokenRefreshOutcome.Logout; - } - - try { - const { accessToken, refreshToken } = await this.opts.tokenRefreshFunction(this.opts.refreshToken); - this.opts.accessToken = accessToken; - this.opts.refreshToken = refreshToken; - // successfully got new tokens - return TokenRefreshOutcome.Success; - } catch (error) { - this.opts.logger?.warn("Failed to refresh token", error); - // If we get a TokenError or MatrixError, we should log out, otherwise assume transient - if (error instanceof TokenRefreshLogoutError || error instanceof MatrixError) { - return TokenRefreshOutcome.Logout; - } - return TokenRefreshOutcome.Failure; - } - } - /** * Perform a request to the homeserver without any credentials. * @param method - The HTTP method e.g. "GET". diff --git a/src/http-api/interface.ts b/src/http-api/interface.ts index bfa604c04..d5a01deb2 100644 --- a/src/http-api/interface.ts +++ b/src/http-api/interface.ts @@ -24,9 +24,20 @@ export type Body = Record | BodyInit; * Unencrypted access and (optional) refresh token */ export type AccessTokens = { + /** + * The new access token to use for authenticated requests + */ accessToken: string; + /** + * The new refresh token to use for refreshing tokens, optional + */ refreshToken?: string; + /** + * Approximate date when the access token will expire, optional + */ + expiry?: Date; }; + /** * @experimental * Function that performs token refresh using the given refreshToken. diff --git a/src/http-api/refresh.ts b/src/http-api/refresh.ts new file mode 100644 index 000000000..4bde53c1b --- /dev/null +++ b/src/http-api/refresh.ts @@ -0,0 +1,165 @@ +/* +Copyright 2025 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +import { MatrixError, TokenRefreshLogoutError } from "./errors.ts"; +import { type IHttpOpts } from "./interface.ts"; +import { sleep } from "../utils.ts"; + +/** + * This is an internal module. See {@link MatrixHttpApi} for the public class. + */ + +export const enum TokenRefreshOutcome { + Success = "success", + Failure = "failure", + Logout = "logout", +} + +interface Snapshot { + accessToken: string; + refreshToken?: string; + expiry?: Date; +} + +// If the token expires in less than this time amount of time, we will eagerly refresh it before making the intended request. +const REFRESH_IF_TOKEN_EXPIRES_WITHIN_MS = 500; +// If we get an unknown token error and the token expires in less than this time amount of time, we will refresh it before making the intended request. +// Otherwise, we will error as the token should not have expired yet and we need to avoid retrying indefinitely. +const REFRESH_ON_ERROR_IF_TOKEN_EXPIRES_WITHIN_MS = 60 * 1000; + +type Opts = Pick; + +/** + * This class is responsible for managing the access token and refresh token for authenticated requests. + * It will automatically refresh the access token when it is about to expire, and will handle unknown token errors. + */ +export class TokenRefresher { + public constructor(private readonly opts: Opts) {} + + /** + * Promise used to block authenticated requests during a token refresh to avoid repeated expected errors. + * @private + */ + private tokenRefreshPromise?: Promise; + + private latestTokenRefreshExpiry?: Date; + + /** + * This function is called before every request to ensure that the access token is valid. + * @returns a snapshot containing the access token and other properties which must be passed to the handleUnknownToken + * handler if an M_UNKNOWN_TOKEN error is encountered. + */ + public async prepareForRequest(): Promise { + // Ensure our token is refreshed before we build the headers/params + await this.refreshIfNeeded(); + + return { + accessToken: this.opts.accessToken!, + refreshToken: this.opts.refreshToken, + expiry: this.latestTokenRefreshExpiry, + }; + } + + private async refreshIfNeeded(): Promise { + if (this.tokenRefreshPromise) { + return this.tokenRefreshPromise; + } + // If we don't know the token expiry, we can't eagerly refresh + if (!this.latestTokenRefreshExpiry) return; + + const expiresIn = this.latestTokenRefreshExpiry.getTime() - Date.now(); + if (expiresIn <= REFRESH_IF_TOKEN_EXPIRES_WITHIN_MS) { + await this._handleUnknownToken(); + } + } + + /** + * This function is called when an M_UNKNOWN_TOKEN error is encountered. + * It will attempt to refresh the access token if it is unknown, and will return a TokenRefreshOutcome. + * @param snapshot - the snapshot returned by prepareForRequest + * @param attempt - the number of attempts made for this request so far + * @returns a TokenRefreshOutcome indicating the result of the refresh attempt + */ + public async handleUnknownToken(snapshot: Snapshot, attempt: number): Promise { + return this._handleUnknownToken(snapshot, attempt); + } + + /* eslint-disable @typescript-eslint/naming-convention */ + private async _handleUnknownToken(): Promise; + private async _handleUnknownToken(snapshot: Snapshot, attempt: number): Promise; + private async _handleUnknownToken(snapshot?: Snapshot, attempt?: number): Promise { + if (snapshot?.expiry) { + // If our token is unknown, but it should not have expired yet, then we should not refresh + const expiresIn = snapshot.expiry.getTime() - Date.now(); + if (expiresIn <= REFRESH_ON_ERROR_IF_TOKEN_EXPIRES_WITHIN_MS) { + return TokenRefreshOutcome.Logout; + } + } + + if (!snapshot || snapshot?.accessToken === this.opts.accessToken) { + // If we have a snapshot, but the access token is the same as the current one then a refresh + // did not happen behind us but one may be ongoing anyway + this.tokenRefreshPromise ??= this.doTokenRefresh(attempt); + + try { + return await this.tokenRefreshPromise; + } finally { + this.tokenRefreshPromise = undefined; + } + } + + // We may end up here if the token was refreshed in the background due to another request + return TokenRefreshOutcome.Success; + } + + /** + * Attempt to refresh access tokens. + * On success, sets new access and refresh tokens in opts. + * @returns Promise that resolves to a boolean - true when token was refreshed successfully + */ + private async doTokenRefresh(attempt?: number): Promise { + if (!this.opts.refreshToken || !this.opts.tokenRefreshFunction) { + this.opts.logger?.error("Unable to refresh token - no refresh token or refresh function"); + return TokenRefreshOutcome.Logout; + } + + if (attempt && attempt > 1) { + // Exponential backoff to ensure we don't trash the server, up to 2^5 seconds + await sleep(1000 * Math.min(32, 2 ** attempt)); + } + + try { + this.opts.logger?.debug("Attempting to refresh token"); + const { accessToken, refreshToken, expiry } = await this.opts.tokenRefreshFunction(this.opts.refreshToken); + this.opts.accessToken = accessToken; + this.opts.refreshToken = refreshToken; + this.latestTokenRefreshExpiry = expiry; + this.opts.logger?.debug("... token refresh complete, new token expiry:", expiry); + + // successfully got new tokens + return TokenRefreshOutcome.Success; + } catch (error) { + // If we get a TokenError or MatrixError, we should log out, otherwise assume transient + if (error instanceof TokenRefreshLogoutError || error instanceof MatrixError) { + this.opts.logger?.error("Failed to refresh token", error); + return TokenRefreshOutcome.Logout; + } + + this.opts.logger?.warn("Failed to refresh token", error); + return TokenRefreshOutcome.Failure; + } + } +} diff --git a/src/oidc/tokenRefresher.ts b/src/oidc/tokenRefresher.ts index 8bae99ce0..34c75e375 100644 --- a/src/oidc/tokenRefresher.ts +++ b/src/oidc/tokenRefresher.ts @@ -139,6 +139,7 @@ export class OidcTokenRefresher { profile: this.idTokenClaims, }; + const requestStart = Date.now(); const response = await this.oidcClient.useRefreshToken({ state: refreshTokenState, timeoutInSeconds: 300, @@ -147,7 +148,9 @@ export class OidcTokenRefresher { const tokens = { accessToken: response.access_token, refreshToken: response.refresh_token, - }; + // We use the request start time to calculate the expiry time as we don't know when the server received our request + expiry: response.expires_in ? new Date(requestStart + response.expires_in * 1000) : undefined, + } satisfies AccessTokens; await this.persistTokens(tokens); diff --git a/src/utils/decorators.ts b/src/utils/decorators.ts deleted file mode 100644 index f95391b8b..000000000 --- a/src/utils/decorators.ts +++ /dev/null @@ -1,39 +0,0 @@ -/* -Copyright 2025 The Matrix.org Foundation C.I.C. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -/** - * Method decorator to ensure that only one instance of the method is running at a time, - * and any concurrent calls will return the same promise as the original call. - * After execution is complete a new call will be able to run the method again. - */ -export function singleAsyncExecution( - target: (this: This, ...args: Args) => Promise, -): (this: This, ...args: Args) => Promise { - let promise: Promise | undefined; - - async function replacementMethod(this: This, ...args: Args): Promise { - if (promise) return promise; - try { - promise = target.call(this, ...args); - await promise; - return promise; - } finally { - promise = undefined; - } - } - - return replacementMethod; -}