1
0
mirror of https://github.com/matrix-org/matrix-js-sdk.git synced 2025-07-31 15:24:23 +03:00

Fix token refresh racing with other requests and not using new token (#4798)

* Fix token refresh racing with other requests and not using new token

Signed-off-by: Michael Telatynski <7t3chguy@gmail.com>

* Iterate

Signed-off-by: Michael Telatynski <7t3chguy@gmail.com>

---------

Signed-off-by: Michael Telatynski <7t3chguy@gmail.com>
This commit is contained in:
Michael Telatynski
2025-04-14 10:11:55 +01:00
committed by GitHub
parent 1ba4412260
commit 480c8e86a4
2 changed files with 88 additions and 14 deletions

View File

@ -29,13 +29,18 @@ import {
Method, Method,
} from "../../../src"; } from "../../../src";
import { emitPromise } from "../../test-utils/test-utils"; import { emitPromise } from "../../test-utils/test-utils";
import { defer, type QueryDict } from "../../../src/utils"; import { defer, type QueryDict, sleep } from "../../../src/utils";
import { type Logger } from "../../../src/logger"; import { type Logger } from "../../../src/logger";
describe("FetchHttpApi", () => { describe("FetchHttpApi", () => {
const baseUrl = "http://baseUrl"; const baseUrl = "http://baseUrl";
const idBaseUrl = "http://idBaseUrl"; const idBaseUrl = "http://idBaseUrl";
const prefix = ClientPrefix.V3; const prefix = ClientPrefix.V3;
const tokenInactiveError = new MatrixError({ errcode: "M_UNKNOWN_TOKEN", error: "Token is not active" }, 401);
beforeEach(() => {
jest.useRealTimers();
});
it("should support aborting multiple times", () => { it("should support aborting multiple times", () => {
const fetchFn = jest.fn().mockResolvedValue({ ok: true }); const fetchFn = jest.fn().mockResolvedValue({ ok: true });
@ -492,8 +497,6 @@ describe("FetchHttpApi", () => {
}); });
it("should not make multiple concurrent refresh token requests", async () => { it("should not make multiple concurrent refresh token requests", async () => {
const tokenInactiveError = new MatrixError({ errcode: "M_UNKNOWN_TOKEN", error: "Token is not active" }, 401);
const deferredTokenRefresh = defer<{ accessToken: string; refreshToken: string }>(); const deferredTokenRefresh = defer<{ accessToken: string; refreshToken: string }>();
const fetchFn = jest.fn().mockResolvedValue({ const fetchFn = jest.fn().mockResolvedValue({
ok: false, ok: false,
@ -523,7 +526,7 @@ describe("FetchHttpApi", () => {
const prom1 = api.authedRequest(Method.Get, "/path1"); const prom1 = api.authedRequest(Method.Get, "/path1");
const prom2 = api.authedRequest(Method.Get, "/path2"); const prom2 = api.authedRequest(Method.Get, "/path2");
await jest.advanceTimersByTimeAsync(10); // wait for requests to fire await sleep(0); // wait for requests to fire
expect(fetchFn).toHaveBeenCalledTimes(2); expect(fetchFn).toHaveBeenCalledTimes(2);
fetchFn.mockResolvedValue({ fetchFn.mockResolvedValue({
ok: true, ok: true,
@ -547,4 +550,66 @@ describe("FetchHttpApi", () => {
expect(api.opts.accessToken).toBe("NEW_ACCESS_TOKEN"); expect(api.opts.accessToken).toBe("NEW_ACCESS_TOKEN");
expect(api.opts.refreshToken).toBe("NEW_REFRESH_TOKEN"); expect(api.opts.refreshToken).toBe("NEW_REFRESH_TOKEN");
}); });
it("should use newly refreshed token if request starts mid-refresh", async () => {
const deferredTokenRefresh = defer<{ accessToken: string; refreshToken: string }>();
const fetchFn = jest.fn().mockResolvedValue({
ok: false,
status: tokenInactiveError.httpStatus,
async text() {
return JSON.stringify(tokenInactiveError.data);
},
async json() {
return tokenInactiveError.data;
},
headers: {
get: jest.fn().mockReturnValue("application/json"),
},
});
const tokenRefreshFunction = jest.fn().mockReturnValue(deferredTokenRefresh.promise);
const api = new FetchHttpApi(new TypedEventEmitter<any, any>(), {
baseUrl,
prefix,
fetchFn,
doNotAttemptTokenRefresh: false,
tokenRefreshFunction,
accessToken: "ACCESS_TOKEN",
refreshToken: "REFRESH_TOKEN",
});
const prom1 = api.authedRequest(Method.Get, "/path1");
await sleep(0); // wait for request to fire
const prom2 = api.authedRequest(Method.Get, "/path2");
await sleep(0); // wait for request to fire
deferredTokenRefresh.resolve({ accessToken: "NEW_ACCESS_TOKEN", refreshToken: "NEW_REFRESH_TOKEN" });
fetchFn.mockResolvedValue({
ok: true,
status: 200,
async text() {
return "{}";
},
async json() {
return {};
},
headers: {
get: jest.fn().mockReturnValue("application/json"),
},
});
await prom1;
await prom2;
expect(fetchFn).toHaveBeenCalledTimes(3); // 2 original calls + 1 retry
expect(fetchFn.mock.calls[0][1]).toEqual(
expect.objectContaining({ headers: expect.objectContaining({ Authorization: "Bearer ACCESS_TOKEN" }) }),
);
expect(fetchFn.mock.calls[2][1]).toEqual(
expect.objectContaining({ headers: expect.objectContaining({ Authorization: "Bearer NEW_ACCESS_TOKEN" }) }),
);
expect(tokenRefreshFunction).toHaveBeenCalledTimes(1);
expect(api.opts.accessToken).toBe("NEW_ACCESS_TOKEN");
expect(api.opts.refreshToken).toBe("NEW_REFRESH_TOKEN");
});
}); });

View File

@ -158,25 +158,28 @@ export class FetchHttpApi<O extends IHttpOpts> {
// avoid mutating paramOpts so they can be used on retry // avoid mutating paramOpts so they can be used on retry
const opts = { ...paramOpts }; const opts = { ...paramOpts };
if (this.opts.accessToken) { // Await any ongoing token refresh before we build the headers/params
await this.tokenRefreshPromise;
// Take a copy of the access token so we have a record of the token we used for this request if it fails
const accessToken = this.opts.accessToken;
if (accessToken) {
if (this.opts.useAuthorizationHeader) { if (this.opts.useAuthorizationHeader) {
if (!opts.headers) { if (!opts.headers) {
opts.headers = {}; opts.headers = {};
} }
if (!opts.headers.Authorization) { if (!opts.headers.Authorization) {
opts.headers.Authorization = "Bearer " + this.opts.accessToken; opts.headers.Authorization = `Bearer ${accessToken}`;
} }
if (queryParams.access_token) { if (queryParams.access_token) {
delete queryParams.access_token; delete queryParams.access_token;
} }
} else if (!queryParams.access_token) { } else if (!queryParams.access_token) {
queryParams.access_token = this.opts.accessToken; queryParams.access_token = accessToken;
} }
} }
try { try {
// Await any ongoing token refresh
await this.tokenRefreshPromise;
const response = await this.request<T>(method, path, queryParams, body, opts); const response = await this.request<T>(method, path, queryParams, body, opts);
return response; return response;
} catch (error) { } catch (error) {
@ -185,15 +188,21 @@ export class FetchHttpApi<O extends IHttpOpts> {
} }
if (error.errcode === "M_UNKNOWN_TOKEN" && !opts.doNotAttemptTokenRefresh) { if (error.errcode === "M_UNKNOWN_TOKEN" && !opts.doNotAttemptTokenRefresh) {
const tokenRefreshPromise = this.tryRefreshToken(); // If the access token has changed since we started the request, but before we refreshed it,
this.tokenRefreshPromise = Promise.allSettled([tokenRefreshPromise]); // then it was refreshed due to another request failing, so retry before refreshing again.
const outcome = await tokenRefreshPromise; let outcome: TokenRefreshOutcome | null = null;
if (accessToken === this.opts.accessToken) {
const tokenRefreshPromise = this.tryRefreshToken();
this.tokenRefreshPromise = tokenRefreshPromise;
outcome = await tokenRefreshPromise;
}
if (outcome === TokenRefreshOutcome.Success) { if (outcome === TokenRefreshOutcome.Success || outcome === null) {
// if we got a new token retry the request // if we got a new token retry the request
return this.authedRequest(method, path, queryParams, body, { return this.authedRequest(method, path, queryParams, body, {
...paramOpts, ...paramOpts,
doNotAttemptTokenRefresh: true, // Only attempt token refresh once for each failed request
doNotAttemptTokenRefresh: outcome !== null,
}); });
} }
if (outcome === TokenRefreshOutcome.Failure) { if (outcome === TokenRefreshOutcome.Failure) {