diff --git a/src/Lifecycle.ts b/src/Lifecycle.ts index 2387314066..9f2fcaa197 100644 --- a/src/Lifecycle.ts +++ b/src/Lifecycle.ts @@ -65,6 +65,7 @@ import { OverwriteLoginPayload } from "./dispatcher/payloads/OverwriteLoginPaylo import { SdkContextClass } from "./contexts/SDKContext"; import { messageForLoginError } from "./utils/ErrorUtils"; import { completeOidcLogin } from "./utils/oidc/authorize"; +import { OidcClientStore } from "./stores/oidc/OidcClientStore"; import { getStoredOidcClientId, getStoredOidcIdTokenClaims, @@ -922,9 +923,28 @@ async function persistCredentials(credentials: IMatrixClientCreds): Promise { + if (oidcClientStore?.isUserAuthenticatedWithOidc) { + const accessToken = client.getAccessToken() ?? undefined; + const refreshToken = client.getRefreshToken() ?? undefined; + + await oidcClientStore.revokeTokens(accessToken, refreshToken); + } else { + await client.logout(true); + } +} + +/** + * Logs the current session out and transitions to the logged-out state + * @param oidcClientStore store instance from SDKContext + */ +export function logout(oidcClientStore?: OidcClientStore): void { const client = MatrixClientPeg.get(); if (!client) return; @@ -940,7 +960,8 @@ export function logout(): void { _isLoggingOut = true; PlatformPeg.get()?.destroyPickleKey(client.getSafeUserId(), client.getDeviceId() ?? ""); - client.logout(true).then(onLoggedOut, (err) => { + + doLogout(client, oidcClientStore).then(onLoggedOut, (err) => { // Just throwing an error here is going to be very unhelpful // if you're trying to log out because your server's down and // you want to log into a different server, so just forget the diff --git a/src/components/structures/MatrixChat.tsx b/src/components/structures/MatrixChat.tsx index dde4768e8e..3656ecfeac 100644 --- a/src/components/structures/MatrixChat.tsx +++ b/src/components/structures/MatrixChat.tsx @@ -650,7 +650,7 @@ export default class MatrixChat extends React.PureComponent { Promise.all([ ...[...CallStore.instance.activeCalls].map((call) => call.disconnect()), cleanUpBroadcasts(this.stores), - ]).finally(() => Lifecycle.logout()); + ]).finally(() => Lifecycle.logout(this.stores.oidcClientStore)); break; case "require_registration": startAnyRegistrationFlow(payload as any); diff --git a/src/components/structures/auth/SoftLogout.tsx b/src/components/structures/auth/SoftLogout.tsx index dede3b612b..250c6a6a1d 100644 --- a/src/components/structures/auth/SoftLogout.tsx +++ b/src/components/structures/auth/SoftLogout.tsx @@ -34,6 +34,7 @@ import AccessibleButton from "../../views/elements/AccessibleButton"; import Spinner from "../../views/elements/Spinner"; import AuthHeader from "../../views/auth/AuthHeader"; import AuthBody from "../../views/auth/AuthBody"; +import { SDKContext } from "../../../contexts/SDKContext"; enum LoginView { Loading, @@ -70,8 +71,13 @@ interface IState { } export default class SoftLogout extends React.Component { - public constructor(props: IProps) { - super(props); + public static contextType = SDKContext; + public context!: React.ContextType; + + public constructor(props: IProps, context: React.ContextType) { + super(props, context); + + this.context = context; this.state = { loginView: LoginView.Loading, @@ -98,7 +104,7 @@ export default class SoftLogout extends React.Component { if (!wipeData) return; logger.log("Clearing data from soft-logged-out session"); - Lifecycle.logout(); + Lifecycle.logout(this.context.oidcClientStore); }, }); }; diff --git a/src/stores/oidc/OidcClientStore.ts b/src/stores/oidc/OidcClientStore.ts index 8393bd5053..e4f452fbaf 100644 --- a/src/stores/oidc/OidcClientStore.ts +++ b/src/stores/oidc/OidcClientStore.ts @@ -50,6 +50,54 @@ export class OidcClientStore { return this._accountManagementEndpoint; } + /** + * Revokes provided access and refresh tokens with the configured OIDC provider + * @param accessToken + * @param refreshToken + * @returns Promise that resolves when tokens have been revoked + * @throws when OidcClient cannot be initialised, or revoking either token fails + */ + public async revokeTokens(accessToken?: string, refreshToken?: string): Promise { + const client = await this.getOidcClient(); + + if (!client) { + throw new Error("No OIDC client"); + } + + const results = await Promise.all([ + this.tryRevokeToken(client, accessToken, "access_token"), + this.tryRevokeToken(client, refreshToken, "refresh_token"), + ]); + + if (results.some((success) => !success)) { + throw new Error("Failed to revoke tokens"); + } + } + + /** + * Try to revoke a given token + * @param oidcClient + * @param token + * @param tokenType passed to revocation endpoint as token type hint + * @returns Promise that resolved with boolean whether the token revocation succeeded or not + */ + private async tryRevokeToken( + oidcClient: OidcClient, + token: string | undefined, + tokenType: "access_token" | "refresh_token", + ): Promise { + try { + if (!token) { + return false; + } + await oidcClient.revokeToken(token, tokenType); + return true; + } catch (error) { + logger.error(`Failed to revoke ${tokenType}`, error); + return false; + } + } + private async getOidcClient(): Promise { if (!this.oidcClient && !this.initialisingOidcClientPromise) { this.initialisingOidcClientPromise = this.initOidcClient(); @@ -59,18 +107,27 @@ export class OidcClientStore { return this.oidcClient; } + /** + * Tries to initialise an OidcClient using stored clientId and OIDC discovery. + * Assigns this.oidcClient and accountManagement endpoint. + * Logs errors and does not throw when oidc client cannot be initialised. + * @returns promise that resolves when initialising OidcClient succeeds or fails + */ private async initOidcClient(): Promise { - const wellKnown = this.matrixClient.getClientWellKnown(); - if (!wellKnown) { - logger.error("Cannot initialise OidcClientStore: client well known required."); + const wellKnown = await this.matrixClient.waitForClientWellKnown(); + if (!wellKnown && !this.authenticatedIssuer) { + logger.error("Cannot initialise OIDC client without issuer."); return; } + const delegatedAuthConfig = + (wellKnown && M_AUTHENTICATION.findIn(wellKnown)) ?? undefined; - const delegatedAuthConfig = M_AUTHENTICATION.findIn(wellKnown) ?? undefined; try { const clientId = getStoredOidcClientId(); const { account, metadata, signingKeys } = await discoverAndValidateAuthenticationConfig( - delegatedAuthConfig, + // if HS has valid delegated auth config in .well-known, use it + // otherwise fallback to the known issuer + delegatedAuthConfig ?? { issuer: this.authenticatedIssuer! }, ); // if no account endpoint is configured default to the issuer this._accountManagementEndpoint = account ?? metadata.issuer; diff --git a/test/Lifecycle-test.ts b/test/Lifecycle-test.ts index 802556f868..b984549e9d 100644 --- a/test/Lifecycle-test.ts +++ b/test/Lifecycle-test.ts @@ -19,15 +19,17 @@ import { logger } from "matrix-js-sdk/src/logger"; import * as MatrixJs from "matrix-js-sdk/src/matrix"; import { setCrypto } from "matrix-js-sdk/src/crypto/crypto"; import * as MatrixCryptoAes from "matrix-js-sdk/src/crypto/aes"; +import { MockedObject } from "jest-mock"; import fetchMock from "fetch-mock-jest"; import StorageEvictedDialog from "../src/components/views/dialogs/StorageEvictedDialog"; -import { restoreFromLocalStorage, setLoggedIn } from "../src/Lifecycle"; +import { logout, restoreFromLocalStorage, setLoggedIn } from "../src/Lifecycle"; import { MatrixClientPeg } from "../src/MatrixClientPeg"; import Modal from "../src/Modal"; import * as StorageManager from "../src/utils/StorageManager"; -import { getMockClientWithEventEmitter, mockPlatformPeg } from "./test-utils"; +import { flushPromises, getMockClientWithEventEmitter, mockClientMethodsUser, mockPlatformPeg } from "./test-utils"; import ToastStore from "../src/stores/ToastStore"; +import { OidcClientStore } from "../src/stores/oidc/OidcClientStore"; import { makeDelegatedAuthConfig } from "./test-utils/oidc"; import { persistOidcAuthenticatedSettings } from "../src/utils/oidc/persistOidcSettings"; @@ -40,24 +42,29 @@ describe("Lifecycle", () => { const realLocalStorage = global.localStorage; - const mockClient = getMockClientWithEventEmitter({ - stopClient: jest.fn(), - removeAllListeners: jest.fn(), - clearStores: jest.fn(), - getAccountData: jest.fn(), - getUserId: jest.fn(), - getDeviceId: jest.fn(), - isVersionSupported: jest.fn().mockResolvedValue(true), - getCrypto: jest.fn(), - getClientWellKnown: jest.fn(), - getThirdpartyProtocols: jest.fn(), - store: { - destroy: jest.fn(), - }, - getVersions: jest.fn().mockResolvedValue({ versions: ["v1.1"] }), - }); + let mockClient!: MockedObject; beforeEach(() => { + mockClient = getMockClientWithEventEmitter({ + ...mockClientMethodsUser(), + stopClient: jest.fn(), + removeAllListeners: jest.fn(), + clearStores: jest.fn(), + getAccountData: jest.fn(), + getDeviceId: jest.fn(), + isVersionSupported: jest.fn().mockResolvedValue(true), + getCrypto: jest.fn(), + getClientWellKnown: jest.fn(), + waitForClientWellKnown: jest.fn(), + getThirdpartyProtocols: jest.fn(), + store: { + destroy: jest.fn(), + }, + getVersions: jest.fn().mockResolvedValue({ versions: ["v1.1"] }), + logout: jest.fn().mockResolvedValue(undefined), + getAccessToken: jest.fn(), + getRefreshToken: jest.fn(), + }); // stub this jest.spyOn(MatrixClientPeg, "replaceUsingCreds").mockImplementation(() => {}); jest.spyOn(MatrixClientPeg, "start").mockResolvedValue(undefined); @@ -692,7 +699,7 @@ describe("Lifecycle", () => { beforeEach(() => { // mock oidc config for oidc client initialisation - mockClient.getClientWellKnown.mockReturnValue({ + mockClient.waitForClientWellKnown.mockResolvedValue({ "m.authentication": { issuer: issuer, }, @@ -776,4 +783,47 @@ describe("Lifecycle", () => { }); }); }); + + describe("logout()", () => { + let oidcClientStore!: OidcClientStore; + const accessToken = "test-access-token"; + const refreshToken = "test-refresh-token"; + + beforeEach(() => { + oidcClientStore = new OidcClientStore(mockClient); + // stub + jest.spyOn(oidcClientStore, "revokeTokens").mockResolvedValue(undefined); + + mockClient.getAccessToken.mockReturnValue(accessToken); + mockClient.getRefreshToken.mockReturnValue(refreshToken); + }); + + it("should call logout on the client when oidcClientStore is falsy", async () => { + logout(); + + await flushPromises(); + + expect(mockClient.logout).toHaveBeenCalledWith(true); + }); + + it("should call logout on the client when oidcClientStore.isUserAuthenticatedWithOidc is falsy", async () => { + jest.spyOn(oidcClientStore, "isUserAuthenticatedWithOidc", "get").mockReturnValue(false); + logout(oidcClientStore); + + await flushPromises(); + + expect(mockClient.logout).toHaveBeenCalledWith(true); + expect(oidcClientStore.revokeTokens).not.toHaveBeenCalled(); + }); + + it("should revoke tokens when user is authenticated with oidc", async () => { + jest.spyOn(oidcClientStore, "isUserAuthenticatedWithOidc", "get").mockReturnValue(true); + logout(oidcClientStore); + + await flushPromises(); + + expect(mockClient.logout).not.toHaveBeenCalled(); + expect(oidcClientStore.revokeTokens).toHaveBeenCalledWith(accessToken, refreshToken); + }); + }); }); diff --git a/test/stores/oidc/OidcClientStore-test.ts b/test/stores/oidc/OidcClientStore-test.ts index ea2dcac997..d25b0fd541 100644 --- a/test/stores/oidc/OidcClientStore-test.ts +++ b/test/stores/oidc/OidcClientStore-test.ts @@ -14,7 +14,9 @@ See the License for the specific language governing permissions and limitations under the License. */ +import fetchMock from "fetch-mock-jest"; import { mocked } from "jest-mock"; +import { OidcClient } from "oidc-client-ts"; import { M_AUTHENTICATION } from "matrix-js-sdk/src/matrix"; import { logger } from "matrix-js-sdk/src/logger"; import { discoverAndValidateAuthenticationConfig } from "matrix-js-sdk/src/oidc/discovery"; @@ -38,7 +40,7 @@ describe("OidcClientStore", () => { }; const mockClient = getMockClientWithEventEmitter({ - getClientWellKnown: jest.fn().mockReturnValue({}), + waitForClientWellKnown: jest.fn().mockResolvedValue({}), }); beforeEach(() => { @@ -50,13 +52,15 @@ describe("OidcClientStore", () => { account, issuer: metadata.issuer, }); - mockClient.getClientWellKnown.mockReturnValue({ + mockClient.waitForClientWellKnown.mockResolvedValue({ [M_AUTHENTICATION.stable!]: { issuer: metadata.issuer, account, }, }); jest.spyOn(logger, "error").mockClear(); + + fetchMock.get(`${metadata.issuer}.well-known/openid-configuration`, metadata); }); describe("isUserAuthenticatedWithOidc()", () => { @@ -76,7 +80,7 @@ describe("OidcClientStore", () => { describe("initialising oidcClient", () => { it("should initialise oidc client from constructor", () => { - mockClient.getClientWellKnown.mockReturnValue(undefined); + mockClient.waitForClientWellKnown.mockResolvedValue(undefined as any); const store = new OidcClientStore(mockClient); // started initialising @@ -84,30 +88,33 @@ describe("OidcClientStore", () => { expect(store.initialisingOidcClientPromise).toBeTruthy(); }); - it("should log and return when no client well known is available", async () => { - mockClient.getClientWellKnown.mockReturnValue(undefined); + it("should fallback to stored issuer when no client well known is available", async () => { + mockClient.waitForClientWellKnown.mockResolvedValue(undefined as any); const store = new OidcClientStore(mockClient); - expect(logger.error).toHaveBeenCalledWith("Cannot initialise OidcClientStore: client well known required."); - // no oidc client + // successfully created oidc client // @ts-ignore private property - expect(await store.getOidcClient()).toEqual(undefined); + expect(await store.getOidcClient()).toBeTruthy(); }); it("should log and return when no clientId is found in storage", async () => { - jest.spyOn(sessionStorage.__proto__, "getItem").mockImplementation((key) => - key === "mx_oidc_token_issuer" ? metadata.issuer : null, + const sessionStorageWithoutClientId: Record = { + ...mockSessionStorage, + mx_oidc_client_id: null, + }; + jest.spyOn(sessionStorage.__proto__, "getItem").mockImplementation( + (key) => sessionStorageWithoutClientId[key as string] ?? null, ); const store = new OidcClientStore(mockClient); + // no oidc client + // @ts-ignore private property + expect(await store.getOidcClient()).toEqual(undefined); expect(logger.error).toHaveBeenCalledWith( "Failed to initialise OidcClientStore", new Error("Oidc client id not found in storage"), ); - // no oidc client - // @ts-ignore private property - expect(await store.getOidcClient()).toEqual(undefined); }); it("should log and return when discovery and validation fails", async () => { @@ -180,4 +187,77 @@ describe("OidcClientStore", () => { expect(discoverAndValidateAuthenticationConfig).toHaveBeenCalledTimes(1); }); }); + + describe("revokeTokens()", () => { + const accessToken = "test-access-token"; + const refreshToken = "test-refresh-token"; + + beforeEach(() => { + // spy and call through + jest.spyOn(OidcClient.prototype, "revokeToken").mockClear(); + + fetchMock.resetHistory(); + fetchMock.post( + metadata.revocation_endpoint, + { + status: 200, + }, + { sendAsJson: true }, + ); + }); + + it("should throw when oidcClient could not be initialised", async () => { + // make oidcClient initialisation fail + mockClient.waitForClientWellKnown.mockResolvedValue(undefined as any); + const sessionStorageWithoutIssuer: Record = { + ...mockSessionStorage, + mx_oidc_token_issuer: null, + }; + jest.spyOn(sessionStorage.__proto__, "getItem").mockImplementation( + (key) => sessionStorageWithoutIssuer[key as string] ?? null, + ); + + const store = new OidcClientStore(mockClient); + + await expect(() => store.revokeTokens(accessToken, refreshToken)).rejects.toThrow("No OIDC client"); + }); + + it("should revoke access and refresh tokens", async () => { + const store = new OidcClientStore(mockClient); + + await store.revokeTokens(accessToken, refreshToken); + + expect(fetchMock).toHaveFetchedTimes(2, metadata.revocation_endpoint); + expect(OidcClient.prototype.revokeToken).toHaveBeenCalledWith(accessToken, "access_token"); + expect(OidcClient.prototype.revokeToken).toHaveBeenCalledWith(refreshToken, "refresh_token"); + }); + + it("should still attempt to revoke refresh token when access token revocation fails", async () => { + // fail once, then succeed + fetchMock + .postOnce( + metadata.revocation_endpoint, + { + status: 404, + }, + { overwriteRoutes: true, sendAsJson: true }, + ) + .post( + metadata.revocation_endpoint, + { + status: 200, + }, + { sendAsJson: true }, + ); + + const store = new OidcClientStore(mockClient); + + await expect(() => store.revokeTokens(accessToken, refreshToken)).rejects.toThrow( + "Failed to revoke tokens", + ); + + expect(fetchMock).toHaveFetchedTimes(2, metadata.revocation_endpoint); + expect(OidcClient.prototype.revokeToken).toHaveBeenCalledWith(accessToken, "access_token"); + }); + }); });