diff --git a/spec/unit/http-api/fetch.spec.ts b/spec/unit/http-api/fetch.spec.ts index cc5224777..ffca79866 100644 --- a/spec/unit/http-api/fetch.spec.ts +++ b/spec/unit/http-api/fetch.spec.ts @@ -29,13 +29,18 @@ import { Method, } from "../../../src"; import { emitPromise } from "../../test-utils/test-utils"; -import { defer, type QueryDict } from "../../../src/utils"; +import { defer, type QueryDict, sleep } from "../../../src/utils"; import { type Logger } from "../../../src/logger"; describe("FetchHttpApi", () => { const baseUrl = "http://baseUrl"; const idBaseUrl = "http://idBaseUrl"; const prefix = ClientPrefix.V3; + const tokenInactiveError = new MatrixError({ errcode: "M_UNKNOWN_TOKEN", error: "Token is not active" }, 401); + + beforeEach(() => { + jest.useRealTimers(); + }); it("should support aborting multiple times", () => { const fetchFn = jest.fn().mockResolvedValue({ ok: true }); @@ -492,8 +497,6 @@ describe("FetchHttpApi", () => { }); it("should not make multiple concurrent refresh token requests", async () => { - const tokenInactiveError = new MatrixError({ errcode: "M_UNKNOWN_TOKEN", error: "Token is not active" }, 401); - const deferredTokenRefresh = defer<{ accessToken: string; refreshToken: string }>(); const fetchFn = jest.fn().mockResolvedValue({ ok: false, @@ -523,7 +526,7 @@ describe("FetchHttpApi", () => { const prom1 = api.authedRequest(Method.Get, "/path1"); const prom2 = api.authedRequest(Method.Get, "/path2"); - await jest.advanceTimersByTimeAsync(10); // wait for requests to fire + await sleep(0); // wait for requests to fire expect(fetchFn).toHaveBeenCalledTimes(2); fetchFn.mockResolvedValue({ ok: true, @@ -547,4 +550,66 @@ describe("FetchHttpApi", () => { expect(api.opts.accessToken).toBe("NEW_ACCESS_TOKEN"); expect(api.opts.refreshToken).toBe("NEW_REFRESH_TOKEN"); }); + + it("should use newly refreshed token if request starts mid-refresh", async () => { + const deferredTokenRefresh = defer<{ accessToken: string; refreshToken: string }>(); + const fetchFn = jest.fn().mockResolvedValue({ + ok: false, + status: tokenInactiveError.httpStatus, + async text() { + return JSON.stringify(tokenInactiveError.data); + }, + async json() { + return tokenInactiveError.data; + }, + headers: { + get: jest.fn().mockReturnValue("application/json"), + }, + }); + const tokenRefreshFunction = jest.fn().mockReturnValue(deferredTokenRefresh.promise); + + const api = new FetchHttpApi(new TypedEventEmitter(), { + baseUrl, + prefix, + fetchFn, + doNotAttemptTokenRefresh: false, + tokenRefreshFunction, + accessToken: "ACCESS_TOKEN", + refreshToken: "REFRESH_TOKEN", + }); + + const prom1 = api.authedRequest(Method.Get, "/path1"); + await sleep(0); // wait for request to fire + + const prom2 = api.authedRequest(Method.Get, "/path2"); + await sleep(0); // wait for request to fire + + deferredTokenRefresh.resolve({ accessToken: "NEW_ACCESS_TOKEN", refreshToken: "NEW_REFRESH_TOKEN" }); + fetchFn.mockResolvedValue({ + ok: true, + status: 200, + async text() { + return "{}"; + }, + async json() { + return {}; + }, + headers: { + get: jest.fn().mockReturnValue("application/json"), + }, + }); + + await prom1; + await prom2; + expect(fetchFn).toHaveBeenCalledTimes(3); // 2 original calls + 1 retry + expect(fetchFn.mock.calls[0][1]).toEqual( + expect.objectContaining({ headers: expect.objectContaining({ Authorization: "Bearer ACCESS_TOKEN" }) }), + ); + expect(fetchFn.mock.calls[2][1]).toEqual( + expect.objectContaining({ headers: expect.objectContaining({ Authorization: "Bearer NEW_ACCESS_TOKEN" }) }), + ); + expect(tokenRefreshFunction).toHaveBeenCalledTimes(1); + expect(api.opts.accessToken).toBe("NEW_ACCESS_TOKEN"); + expect(api.opts.refreshToken).toBe("NEW_REFRESH_TOKEN"); + }); }); diff --git a/src/http-api/fetch.ts b/src/http-api/fetch.ts index 9a54b7360..502b2bdb3 100644 --- a/src/http-api/fetch.ts +++ b/src/http-api/fetch.ts @@ -158,25 +158,28 @@ export class FetchHttpApi { // avoid mutating paramOpts so they can be used on retry const opts = { ...paramOpts }; - if (this.opts.accessToken) { + // Await any ongoing token refresh before we build the headers/params + await this.tokenRefreshPromise; + + // Take a copy of the access token so we have a record of the token we used for this request if it fails + const accessToken = this.opts.accessToken; + if (accessToken) { if (this.opts.useAuthorizationHeader) { if (!opts.headers) { opts.headers = {}; } if (!opts.headers.Authorization) { - opts.headers.Authorization = "Bearer " + this.opts.accessToken; + opts.headers.Authorization = `Bearer ${accessToken}`; } if (queryParams.access_token) { delete queryParams.access_token; } } else if (!queryParams.access_token) { - queryParams.access_token = this.opts.accessToken; + queryParams.access_token = accessToken; } } try { - // Await any ongoing token refresh - await this.tokenRefreshPromise; const response = await this.request(method, path, queryParams, body, opts); return response; } catch (error) { @@ -185,15 +188,21 @@ export class FetchHttpApi { } if (error.errcode === "M_UNKNOWN_TOKEN" && !opts.doNotAttemptTokenRefresh) { - const tokenRefreshPromise = this.tryRefreshToken(); - this.tokenRefreshPromise = Promise.allSettled([tokenRefreshPromise]); - const outcome = await tokenRefreshPromise; + // If the access token has changed since we started the request, but before we refreshed it, + // then it was refreshed due to another request failing, so retry before refreshing again. + let outcome: TokenRefreshOutcome | null = null; + if (accessToken === this.opts.accessToken) { + const tokenRefreshPromise = this.tryRefreshToken(); + this.tokenRefreshPromise = tokenRefreshPromise; + outcome = await tokenRefreshPromise; + } - if (outcome === TokenRefreshOutcome.Success) { + if (outcome === TokenRefreshOutcome.Success || outcome === null) { // if we got a new token retry the request return this.authedRequest(method, path, queryParams, body, { ...paramOpts, - doNotAttemptTokenRefresh: true, + // Only attempt token refresh once for each failed request + doNotAttemptTokenRefresh: outcome !== null, }); } if (outcome === TokenRefreshOutcome.Failure) {