diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index fa9b7179..f293aa10 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -18,7 +18,7 @@ use anyhow::Context; use clap::Parser; use itertools::Itertools; use mas_config::AppConfig; -use mas_handlers::{AppState, CookieManager, HttpClientFactory, MatrixHomeserver}; +use mas_handlers::{AppState, CookieManager, HttpClientFactory, MatrixHomeserver, MetadataCache}; use mas_listener::{server::Server, shutdown::ShutdownStream}; use mas_matrix_synapse::SynapseConnection; use mas_router::UrlBuilder; @@ -127,6 +127,9 @@ impl Options { // Maximum 50 outgoing HTTP requests at a time let http_client_factory = HttpClientFactory::new(50); + // The upstream OIDC metadata cache + let metadata_cache = MetadataCache::new(); + let conn = SynapseConnection::new( config.matrix.homeserver.clone(), config.matrix.endpoint.clone(), @@ -147,6 +150,7 @@ impl Options { pool, templates, key_store, + metadata_cache, cookie_manager, encrypter, url_builder, @@ -158,6 +162,8 @@ impl Options { conn_acquisition_histogram: None, }; s.init_metrics()?; + // XXX: this might panic + s.init_metadata_cache().await; s }; diff --git a/crates/handlers/src/app_state.rs b/crates/handlers/src/app_state.rs index d9a8e96b..772c4bf5 100644 --- a/crates/handlers/src/app_state.rs +++ b/crates/handlers/src/app_state.rs @@ -35,7 +35,7 @@ use rand::SeedableRng; use sqlx::PgPool; use thiserror::Error; -use crate::{passwords::PasswordManager, MatrixHomeserver}; +use crate::{passwords::PasswordManager, upstream_oauth2::cache::MetadataCache, MatrixHomeserver}; #[derive(Clone)] pub struct AppState { @@ -50,6 +50,7 @@ pub struct AppState { pub graphql_schema: mas_graphql::Schema, pub http_client_factory: HttpClientFactory, pub password_manager: PasswordManager, + pub metadata_cache: MetadataCache, pub conn_acquisition_histogram: Option>, } @@ -100,6 +101,37 @@ impl AppState { Ok(()) } + + /// Init the metadata cache. + /// + /// # Panics + /// + /// Panics if the metadata cache could not be initialized. + pub async fn init_metadata_cache(&self) { + // XXX: this panics because the error is annoying to propagate + let conn = self + .pool + .acquire() + .await + .expect("Failed to acquire a database connection"); + + let mut repo = PgRepository::from_conn(conn); + + let http_service = self + .http_client_factory + .http_service() + .await + .expect("Failed to create the HTTP service"); + + self.metadata_cache + .warm_up_and_run( + http_service, + std::time::Duration::from_secs(60 * 15), + &mut repo, + ) + .await + .expect("Failed to warm up the metadata cache"); + } } impl FromRef for PgPool { @@ -168,6 +200,12 @@ impl FromRef for CookieManager { } } +impl FromRef for MetadataCache { + fn from_ref(input: &AppState) -> Self { + input.metadata_cache.clone() + } +} + #[async_trait] impl FromRequestParts for BoxClock { type Rejection = Infallible; diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index db409d6e..dda27c57 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -65,7 +65,7 @@ mod graphql; mod health; mod oauth2; pub mod passwords; -mod upstream_oauth2; +pub mod upstream_oauth2; mod views; #[cfg(test)] @@ -90,6 +90,7 @@ macro_rules! impl_from_error_for_route { pub use mas_axum_utils::{cookies::CookieManager, http_client_factory::HttpClientFactory}; pub use self::{app_state::AppState, compat::MatrixHomeserver, graphql::schema as graphql_schema}; +pub use crate::upstream_oauth2::cache::MetadataCache; pub fn healthcheck_router() -> Router where @@ -274,6 +275,7 @@ where Keystore: FromRef, HttpClientFactory: FromRef, PasswordManager: FromRef, + MetadataCache: FromRef, BoxClock: FromRequestParts, BoxRng: FromRequestParts, { diff --git a/crates/handlers/src/test_utils.rs b/crates/handlers/src/test_utils.rs index 7ae54316..c06a8ba5 100644 --- a/crates/handlers/src/test_utils.rs +++ b/crates/handlers/src/test_utils.rs @@ -48,6 +48,7 @@ use url::Url; use crate::{ app_state::RepositoryError, passwords::{Hasher, PasswordManager}, + upstream_oauth2::cache::MetadataCache, MatrixHomeserver, }; @@ -67,6 +68,7 @@ pub(crate) struct TestState { pub templates: Templates, pub key_store: Keystore, pub cookie_manager: CookieManager, + pub metadata_cache: MetadataCache, pub encrypter: Encrypter, pub url_builder: UrlBuilder, pub homeserver: MatrixHomeserver, @@ -106,6 +108,8 @@ impl TestState { let cookie_manager = CookieManager::derive_from("https://example.com".parse()?, &[0x42; 32]); + let metadata_cache = MetadataCache::new(); + let password_manager = PasswordManager::new([(1, Hasher::argon2id(None))])?; let homeserver = MatrixHomeserver::new("example.com".to_owned()); @@ -146,6 +150,7 @@ impl TestState { templates, key_store, cookie_manager, + metadata_cache, encrypter, url_builder, homeserver, @@ -334,6 +339,12 @@ impl FromRef for CookieManager { } } +impl FromRef for MetadataCache { + fn from_ref(input: &TestState) -> Self { + input.metadata_cache.clone() + } +} + #[async_trait] impl FromRequestParts for BoxClock { type Rejection = Infallible; diff --git a/crates/handlers/src/upstream_oauth2/authorize.rs b/crates/handlers/src/upstream_oauth2/authorize.rs index 83588a5d..09ca2635 100644 --- a/crates/handlers/src/upstream_oauth2/authorize.rs +++ b/crates/handlers/src/upstream_oauth2/authorize.rs @@ -28,7 +28,10 @@ use thiserror::Error; use ulid::Ulid; use super::UpstreamSessionsCookie; -use crate::{impl_from_error_for_route, views::shared::OptionalPostAuthAction}; +use crate::{ + impl_from_error_for_route, upstream_oauth2::cache::MetadataCache, + views::shared::OptionalPostAuthAction, +}; #[derive(Debug, Error)] pub(crate) enum RouteError { @@ -64,6 +67,7 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(http_client_factory): State, + State(metadata_cache): State, mut repo: BoxRepository, State(url_builder): State, cookie_jar: CookieJar, @@ -79,8 +83,7 @@ pub(crate) async fn get( let http_service = http_client_factory.http_service().await?; // First, discover the provider - let metadata = - mas_oidc_client::requests::discovery::discover(&http_service, &provider.issuer).await?; + let metadata = metadata_cache.get(&http_service, &provider.issuer).await?; let redirect_uri = url_builder.upstream_oauth_callback(provider.id); diff --git a/crates/handlers/src/upstream_oauth2/cache.rs b/crates/handlers/src/upstream_oauth2/cache.rs new file mode 100644 index 00000000..cdb40fa8 --- /dev/null +++ b/crates/handlers/src/upstream_oauth2/cache.rs @@ -0,0 +1,117 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{collections::HashMap, sync::Arc}; + +use mas_http::HttpService; +use mas_oidc_client::error::DiscoveryError; +use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, RepositoryAccess}; +use oauth2_types::oidc::VerifiedProviderMetadata; +use tokio::sync::RwLock; + +/// A simple OIDC metadata cache +/// +/// It never evicts entries, does not cache failures and has no locking. +/// It can also be refreshed in the background, and warmed up on startup. +/// It is good enough for our use case. +#[allow(clippy::module_name_repetitions)] +#[derive(Debug, Clone, Default)] +pub struct MetadataCache { + cache: Arc>>, +} + +impl MetadataCache { + #[must_use] + pub fn new() -> Self { + Self::default() + } + + /// Warm up the cache by fetching all the known providers from the database + /// and inserting them into the cache. + /// + /// This spawns a background task that will refresh the cache at the given + /// interval. + #[tracing::instrument(name = "metadata_cache.warm_up_and_run", skip_all, err)] + pub async fn warm_up_and_run( + &self, + http_service: HttpService, + interval: std::time::Duration, + repository: &mut R, + ) -> Result, R::Error> { + let providers = repository.upstream_oauth_provider().all().await?; + + for provider in providers { + if let Err(e) = self.fetch(&http_service, &provider.issuer).await { + tracing::error!(issuer = %provider.issuer, error = &e as &dyn std::error::Error, "Failed to fetch provider metadata"); + } + } + + // Spawn a background task to refresh the cache regularly + let cache = self.clone(); + Ok(tokio::spawn(async move { + loop { + // Re-fetch the known metadata at the given interval + tokio::time::sleep(interval).await; + cache.refresh_all(&http_service).await; + } + })) + } + + #[tracing::instrument(name = "metadata_cache.fetch", fields(%issuer), skip_all, err)] + async fn fetch( + &self, + http_service: &HttpService, + issuer: &str, + ) -> Result { + let metadata = mas_oidc_client::requests::discovery::discover(http_service, issuer).await?; + + self.cache + .write() + .await + .insert(issuer.to_owned(), metadata.clone()); + + Ok(metadata) + } + + /// Get the metadata for the given issuer. + #[tracing::instrument(name = "metadata_cache.get", fields(%issuer), skip_all, err)] + pub async fn get( + &self, + http_service: &HttpService, + issuer: &str, + ) -> Result { + let cache = self.cache.read().await; + if let Some(metadata) = cache.get(issuer) { + return Ok(metadata.clone()); + } + + let metadata = self.fetch(http_service, issuer).await?; + Ok(metadata) + } + + #[tracing::instrument(name = "metadata_cache.refresh_all", skip_all)] + async fn refresh_all(&self, http_service: &HttpService) { + // Grab all the keys first to avoid locking the cache for too long + let keys: Vec = { + let cache = self.cache.read().await; + cache.keys().cloned().collect() + }; + + for issuer in keys { + if let Err(e) = self.fetch(http_service, &issuer).await { + tracing::error!(issuer = %issuer, error = &e as &dyn std::error::Error, "Failed to refresh provider metadata"); + } + } + } +} diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index 4e1131fd..5a32f949 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -37,7 +37,7 @@ use thiserror::Error; use ulid::Ulid; use super::{client_credentials_for_provider, UpstreamSessionsCookie}; -use crate::impl_from_error_for_route; +use crate::{impl_from_error_for_route, upstream_oauth2::cache::MetadataCache}; #[derive(Deserialize)] pub struct QueryParams { @@ -128,6 +128,7 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(http_client_factory): State, + State(metadata_cache): State, mut repo: BoxRepository, State(url_builder): State, State(encrypter): State, @@ -185,10 +186,8 @@ pub(crate) async fn get( let http_service = http_client_factory.http_service().await?; - // XXX: we shouldn't discover on-the-fly // Discover the provider - let metadata = - mas_oidc_client::requests::discovery::discover(&http_service, &provider.issuer).await?; + let metadata = metadata_cache.get(&http_service, &provider.issuer).await?; // Fetch the JWKS let jwks = diff --git a/crates/handlers/src/upstream_oauth2/mod.rs b/crates/handlers/src/upstream_oauth2/mod.rs index bc1d4c22..e9974f71 100644 --- a/crates/handlers/src/upstream_oauth2/mod.rs +++ b/crates/handlers/src/upstream_oauth2/mod.rs @@ -22,6 +22,7 @@ use thiserror::Error; use url::Url; pub(crate) mod authorize; +pub(crate) mod cache; pub(crate) mod callback; mod cookie; pub(crate) mod link;