You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
Cache the upstream OAuth 2.0 provider metadata
This commit is contained in:
@ -18,7 +18,7 @@ use anyhow::Context;
|
||||
use clap::Parser;
|
||||
use itertools::Itertools;
|
||||
use mas_config::AppConfig;
|
||||
use mas_handlers::{AppState, CookieManager, HttpClientFactory, MatrixHomeserver};
|
||||
use mas_handlers::{AppState, CookieManager, HttpClientFactory, MatrixHomeserver, MetadataCache};
|
||||
use mas_listener::{server::Server, shutdown::ShutdownStream};
|
||||
use mas_matrix_synapse::SynapseConnection;
|
||||
use mas_router::UrlBuilder;
|
||||
@ -127,6 +127,9 @@ impl Options {
|
||||
// Maximum 50 outgoing HTTP requests at a time
|
||||
let http_client_factory = HttpClientFactory::new(50);
|
||||
|
||||
// The upstream OIDC metadata cache
|
||||
let metadata_cache = MetadataCache::new();
|
||||
|
||||
let conn = SynapseConnection::new(
|
||||
config.matrix.homeserver.clone(),
|
||||
config.matrix.endpoint.clone(),
|
||||
@ -147,6 +150,7 @@ impl Options {
|
||||
pool,
|
||||
templates,
|
||||
key_store,
|
||||
metadata_cache,
|
||||
cookie_manager,
|
||||
encrypter,
|
||||
url_builder,
|
||||
@ -158,6 +162,8 @@ impl Options {
|
||||
conn_acquisition_histogram: None,
|
||||
};
|
||||
s.init_metrics()?;
|
||||
// XXX: this might panic
|
||||
s.init_metadata_cache().await;
|
||||
s
|
||||
};
|
||||
|
||||
|
@ -35,7 +35,7 @@ use rand::SeedableRng;
|
||||
use sqlx::PgPool;
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::{passwords::PasswordManager, MatrixHomeserver};
|
||||
use crate::{passwords::PasswordManager, upstream_oauth2::cache::MetadataCache, MatrixHomeserver};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
@ -50,6 +50,7 @@ pub struct AppState {
|
||||
pub graphql_schema: mas_graphql::Schema,
|
||||
pub http_client_factory: HttpClientFactory,
|
||||
pub password_manager: PasswordManager,
|
||||
pub metadata_cache: MetadataCache,
|
||||
pub conn_acquisition_histogram: Option<Histogram<u64>>,
|
||||
}
|
||||
|
||||
@ -100,6 +101,37 @@ impl AppState {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Init the metadata cache.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the metadata cache could not be initialized.
|
||||
pub async fn init_metadata_cache(&self) {
|
||||
// XXX: this panics because the error is annoying to propagate
|
||||
let conn = self
|
||||
.pool
|
||||
.acquire()
|
||||
.await
|
||||
.expect("Failed to acquire a database connection");
|
||||
|
||||
let mut repo = PgRepository::from_conn(conn);
|
||||
|
||||
let http_service = self
|
||||
.http_client_factory
|
||||
.http_service()
|
||||
.await
|
||||
.expect("Failed to create the HTTP service");
|
||||
|
||||
self.metadata_cache
|
||||
.warm_up_and_run(
|
||||
http_service,
|
||||
std::time::Duration::from_secs(60 * 15),
|
||||
&mut repo,
|
||||
)
|
||||
.await
|
||||
.expect("Failed to warm up the metadata cache");
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRef<AppState> for PgPool {
|
||||
@ -168,6 +200,12 @@ impl FromRef<AppState> for CookieManager {
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRef<AppState> for MetadataCache {
|
||||
fn from_ref(input: &AppState) -> Self {
|
||||
input.metadata_cache.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl FromRequestParts<AppState> for BoxClock {
|
||||
type Rejection = Infallible;
|
||||
|
@ -65,7 +65,7 @@ mod graphql;
|
||||
mod health;
|
||||
mod oauth2;
|
||||
pub mod passwords;
|
||||
mod upstream_oauth2;
|
||||
pub mod upstream_oauth2;
|
||||
mod views;
|
||||
|
||||
#[cfg(test)]
|
||||
@ -90,6 +90,7 @@ macro_rules! impl_from_error_for_route {
|
||||
pub use mas_axum_utils::{cookies::CookieManager, http_client_factory::HttpClientFactory};
|
||||
|
||||
pub use self::{app_state::AppState, compat::MatrixHomeserver, graphql::schema as graphql_schema};
|
||||
pub use crate::upstream_oauth2::cache::MetadataCache;
|
||||
|
||||
pub fn healthcheck_router<S, B>() -> Router<S, B>
|
||||
where
|
||||
@ -274,6 +275,7 @@ where
|
||||
Keystore: FromRef<S>,
|
||||
HttpClientFactory: FromRef<S>,
|
||||
PasswordManager: FromRef<S>,
|
||||
MetadataCache: FromRef<S>,
|
||||
BoxClock: FromRequestParts<S>,
|
||||
BoxRng: FromRequestParts<S>,
|
||||
{
|
||||
|
@ -48,6 +48,7 @@ use url::Url;
|
||||
use crate::{
|
||||
app_state::RepositoryError,
|
||||
passwords::{Hasher, PasswordManager},
|
||||
upstream_oauth2::cache::MetadataCache,
|
||||
MatrixHomeserver,
|
||||
};
|
||||
|
||||
@ -67,6 +68,7 @@ pub(crate) struct TestState {
|
||||
pub templates: Templates,
|
||||
pub key_store: Keystore,
|
||||
pub cookie_manager: CookieManager,
|
||||
pub metadata_cache: MetadataCache,
|
||||
pub encrypter: Encrypter,
|
||||
pub url_builder: UrlBuilder,
|
||||
pub homeserver: MatrixHomeserver,
|
||||
@ -106,6 +108,8 @@ impl TestState {
|
||||
let cookie_manager =
|
||||
CookieManager::derive_from("https://example.com".parse()?, &[0x42; 32]);
|
||||
|
||||
let metadata_cache = MetadataCache::new();
|
||||
|
||||
let password_manager = PasswordManager::new([(1, Hasher::argon2id(None))])?;
|
||||
|
||||
let homeserver = MatrixHomeserver::new("example.com".to_owned());
|
||||
@ -146,6 +150,7 @@ impl TestState {
|
||||
templates,
|
||||
key_store,
|
||||
cookie_manager,
|
||||
metadata_cache,
|
||||
encrypter,
|
||||
url_builder,
|
||||
homeserver,
|
||||
@ -334,6 +339,12 @@ impl FromRef<TestState> for CookieManager {
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRef<TestState> for MetadataCache {
|
||||
fn from_ref(input: &TestState) -> Self {
|
||||
input.metadata_cache.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl FromRequestParts<TestState> for BoxClock {
|
||||
type Rejection = Infallible;
|
||||
|
@ -28,7 +28,10 @@ use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
|
||||
use super::UpstreamSessionsCookie;
|
||||
use crate::{impl_from_error_for_route, views::shared::OptionalPostAuthAction};
|
||||
use crate::{
|
||||
impl_from_error_for_route, upstream_oauth2::cache::MetadataCache,
|
||||
views::shared::OptionalPostAuthAction,
|
||||
};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub(crate) enum RouteError {
|
||||
@ -64,6 +67,7 @@ pub(crate) async fn get(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
State(http_client_factory): State<HttpClientFactory>,
|
||||
State(metadata_cache): State<MetadataCache>,
|
||||
mut repo: BoxRepository,
|
||||
State(url_builder): State<UrlBuilder>,
|
||||
cookie_jar: CookieJar,
|
||||
@ -79,8 +83,7 @@ pub(crate) async fn get(
|
||||
let http_service = http_client_factory.http_service().await?;
|
||||
|
||||
// First, discover the provider
|
||||
let metadata =
|
||||
mas_oidc_client::requests::discovery::discover(&http_service, &provider.issuer).await?;
|
||||
let metadata = metadata_cache.get(&http_service, &provider.issuer).await?;
|
||||
|
||||
let redirect_uri = url_builder.upstream_oauth_callback(provider.id);
|
||||
|
||||
|
117
crates/handlers/src/upstream_oauth2/cache.rs
Normal file
117
crates/handlers/src/upstream_oauth2/cache.rs
Normal file
@ -0,0 +1,117 @@
|
||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use mas_http::HttpService;
|
||||
use mas_oidc_client::error::DiscoveryError;
|
||||
use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, RepositoryAccess};
|
||||
use oauth2_types::oidc::VerifiedProviderMetadata;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// A simple OIDC metadata cache
|
||||
///
|
||||
/// It never evicts entries, does not cache failures and has no locking.
|
||||
/// It can also be refreshed in the background, and warmed up on startup.
|
||||
/// It is good enough for our use case.
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct MetadataCache {
|
||||
cache: Arc<RwLock<HashMap<String, VerifiedProviderMetadata>>>,
|
||||
}
|
||||
|
||||
impl MetadataCache {
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Warm up the cache by fetching all the known providers from the database
|
||||
/// and inserting them into the cache.
|
||||
///
|
||||
/// This spawns a background task that will refresh the cache at the given
|
||||
/// interval.
|
||||
#[tracing::instrument(name = "metadata_cache.warm_up_and_run", skip_all, err)]
|
||||
pub async fn warm_up_and_run<R: RepositoryAccess>(
|
||||
&self,
|
||||
http_service: HttpService,
|
||||
interval: std::time::Duration,
|
||||
repository: &mut R,
|
||||
) -> Result<tokio::task::JoinHandle<()>, R::Error> {
|
||||
let providers = repository.upstream_oauth_provider().all().await?;
|
||||
|
||||
for provider in providers {
|
||||
if let Err(e) = self.fetch(&http_service, &provider.issuer).await {
|
||||
tracing::error!(issuer = %provider.issuer, error = &e as &dyn std::error::Error, "Failed to fetch provider metadata");
|
||||
}
|
||||
}
|
||||
|
||||
// Spawn a background task to refresh the cache regularly
|
||||
let cache = self.clone();
|
||||
Ok(tokio::spawn(async move {
|
||||
loop {
|
||||
// Re-fetch the known metadata at the given interval
|
||||
tokio::time::sleep(interval).await;
|
||||
cache.refresh_all(&http_service).await;
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "metadata_cache.fetch", fields(%issuer), skip_all, err)]
|
||||
async fn fetch(
|
||||
&self,
|
||||
http_service: &HttpService,
|
||||
issuer: &str,
|
||||
) -> Result<VerifiedProviderMetadata, DiscoveryError> {
|
||||
let metadata = mas_oidc_client::requests::discovery::discover(http_service, issuer).await?;
|
||||
|
||||
self.cache
|
||||
.write()
|
||||
.await
|
||||
.insert(issuer.to_owned(), metadata.clone());
|
||||
|
||||
Ok(metadata)
|
||||
}
|
||||
|
||||
/// Get the metadata for the given issuer.
|
||||
#[tracing::instrument(name = "metadata_cache.get", fields(%issuer), skip_all, err)]
|
||||
pub async fn get(
|
||||
&self,
|
||||
http_service: &HttpService,
|
||||
issuer: &str,
|
||||
) -> Result<VerifiedProviderMetadata, DiscoveryError> {
|
||||
let cache = self.cache.read().await;
|
||||
if let Some(metadata) = cache.get(issuer) {
|
||||
return Ok(metadata.clone());
|
||||
}
|
||||
|
||||
let metadata = self.fetch(http_service, issuer).await?;
|
||||
Ok(metadata)
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "metadata_cache.refresh_all", skip_all)]
|
||||
async fn refresh_all(&self, http_service: &HttpService) {
|
||||
// Grab all the keys first to avoid locking the cache for too long
|
||||
let keys: Vec<String> = {
|
||||
let cache = self.cache.read().await;
|
||||
cache.keys().cloned().collect()
|
||||
};
|
||||
|
||||
for issuer in keys {
|
||||
if let Err(e) = self.fetch(http_service, &issuer).await {
|
||||
tracing::error!(issuer = %issuer, error = &e as &dyn std::error::Error, "Failed to refresh provider metadata");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -37,7 +37,7 @@ use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
|
||||
use super::{client_credentials_for_provider, UpstreamSessionsCookie};
|
||||
use crate::impl_from_error_for_route;
|
||||
use crate::{impl_from_error_for_route, upstream_oauth2::cache::MetadataCache};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct QueryParams {
|
||||
@ -128,6 +128,7 @@ pub(crate) async fn get(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
State(http_client_factory): State<HttpClientFactory>,
|
||||
State(metadata_cache): State<MetadataCache>,
|
||||
mut repo: BoxRepository,
|
||||
State(url_builder): State<UrlBuilder>,
|
||||
State(encrypter): State<Encrypter>,
|
||||
@ -185,10 +186,8 @@ pub(crate) async fn get(
|
||||
|
||||
let http_service = http_client_factory.http_service().await?;
|
||||
|
||||
// XXX: we shouldn't discover on-the-fly
|
||||
// Discover the provider
|
||||
let metadata =
|
||||
mas_oidc_client::requests::discovery::discover(&http_service, &provider.issuer).await?;
|
||||
let metadata = metadata_cache.get(&http_service, &provider.issuer).await?;
|
||||
|
||||
// Fetch the JWKS
|
||||
let jwks =
|
||||
|
@ -22,6 +22,7 @@ use thiserror::Error;
|
||||
use url::Url;
|
||||
|
||||
pub(crate) mod authorize;
|
||||
pub(crate) mod cache;
|
||||
pub(crate) mod callback;
|
||||
mod cookie;
|
||||
pub(crate) mod link;
|
||||
|
Reference in New Issue
Block a user