1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Cache the upstream OAuth 2.0 provider metadata

This commit is contained in:
Quentin Gliech
2023-08-28 16:18:49 +02:00
parent 17e28f56c1
commit 07ca145174
8 changed files with 187 additions and 10 deletions

View File

@ -18,7 +18,7 @@ use anyhow::Context;
use clap::Parser;
use itertools::Itertools;
use mas_config::AppConfig;
use mas_handlers::{AppState, CookieManager, HttpClientFactory, MatrixHomeserver};
use mas_handlers::{AppState, CookieManager, HttpClientFactory, MatrixHomeserver, MetadataCache};
use mas_listener::{server::Server, shutdown::ShutdownStream};
use mas_matrix_synapse::SynapseConnection;
use mas_router::UrlBuilder;
@ -127,6 +127,9 @@ impl Options {
// Maximum 50 outgoing HTTP requests at a time
let http_client_factory = HttpClientFactory::new(50);
// The upstream OIDC metadata cache
let metadata_cache = MetadataCache::new();
let conn = SynapseConnection::new(
config.matrix.homeserver.clone(),
config.matrix.endpoint.clone(),
@ -147,6 +150,7 @@ impl Options {
pool,
templates,
key_store,
metadata_cache,
cookie_manager,
encrypter,
url_builder,
@ -158,6 +162,8 @@ impl Options {
conn_acquisition_histogram: None,
};
s.init_metrics()?;
// XXX: this might panic
s.init_metadata_cache().await;
s
};

View File

@ -35,7 +35,7 @@ use rand::SeedableRng;
use sqlx::PgPool;
use thiserror::Error;
use crate::{passwords::PasswordManager, MatrixHomeserver};
use crate::{passwords::PasswordManager, upstream_oauth2::cache::MetadataCache, MatrixHomeserver};
#[derive(Clone)]
pub struct AppState {
@ -50,6 +50,7 @@ pub struct AppState {
pub graphql_schema: mas_graphql::Schema,
pub http_client_factory: HttpClientFactory,
pub password_manager: PasswordManager,
pub metadata_cache: MetadataCache,
pub conn_acquisition_histogram: Option<Histogram<u64>>,
}
@ -100,6 +101,37 @@ impl AppState {
Ok(())
}
/// Init the metadata cache.
///
/// # Panics
///
/// Panics if the metadata cache could not be initialized.
pub async fn init_metadata_cache(&self) {
// XXX: this panics because the error is annoying to propagate
let conn = self
.pool
.acquire()
.await
.expect("Failed to acquire a database connection");
let mut repo = PgRepository::from_conn(conn);
let http_service = self
.http_client_factory
.http_service()
.await
.expect("Failed to create the HTTP service");
self.metadata_cache
.warm_up_and_run(
http_service,
std::time::Duration::from_secs(60 * 15),
&mut repo,
)
.await
.expect("Failed to warm up the metadata cache");
}
}
impl FromRef<AppState> for PgPool {
@ -168,6 +200,12 @@ impl FromRef<AppState> for CookieManager {
}
}
impl FromRef<AppState> for MetadataCache {
fn from_ref(input: &AppState) -> Self {
input.metadata_cache.clone()
}
}
#[async_trait]
impl FromRequestParts<AppState> for BoxClock {
type Rejection = Infallible;

View File

@ -65,7 +65,7 @@ mod graphql;
mod health;
mod oauth2;
pub mod passwords;
mod upstream_oauth2;
pub mod upstream_oauth2;
mod views;
#[cfg(test)]
@ -90,6 +90,7 @@ macro_rules! impl_from_error_for_route {
pub use mas_axum_utils::{cookies::CookieManager, http_client_factory::HttpClientFactory};
pub use self::{app_state::AppState, compat::MatrixHomeserver, graphql::schema as graphql_schema};
pub use crate::upstream_oauth2::cache::MetadataCache;
pub fn healthcheck_router<S, B>() -> Router<S, B>
where
@ -274,6 +275,7 @@ where
Keystore: FromRef<S>,
HttpClientFactory: FromRef<S>,
PasswordManager: FromRef<S>,
MetadataCache: FromRef<S>,
BoxClock: FromRequestParts<S>,
BoxRng: FromRequestParts<S>,
{

View File

@ -48,6 +48,7 @@ use url::Url;
use crate::{
app_state::RepositoryError,
passwords::{Hasher, PasswordManager},
upstream_oauth2::cache::MetadataCache,
MatrixHomeserver,
};
@ -67,6 +68,7 @@ pub(crate) struct TestState {
pub templates: Templates,
pub key_store: Keystore,
pub cookie_manager: CookieManager,
pub metadata_cache: MetadataCache,
pub encrypter: Encrypter,
pub url_builder: UrlBuilder,
pub homeserver: MatrixHomeserver,
@ -106,6 +108,8 @@ impl TestState {
let cookie_manager =
CookieManager::derive_from("https://example.com".parse()?, &[0x42; 32]);
let metadata_cache = MetadataCache::new();
let password_manager = PasswordManager::new([(1, Hasher::argon2id(None))])?;
let homeserver = MatrixHomeserver::new("example.com".to_owned());
@ -146,6 +150,7 @@ impl TestState {
templates,
key_store,
cookie_manager,
metadata_cache,
encrypter,
url_builder,
homeserver,
@ -334,6 +339,12 @@ impl FromRef<TestState> for CookieManager {
}
}
impl FromRef<TestState> for MetadataCache {
fn from_ref(input: &TestState) -> Self {
input.metadata_cache.clone()
}
}
#[async_trait]
impl FromRequestParts<TestState> for BoxClock {
type Rejection = Infallible;

View File

@ -28,7 +28,10 @@ use thiserror::Error;
use ulid::Ulid;
use super::UpstreamSessionsCookie;
use crate::{impl_from_error_for_route, views::shared::OptionalPostAuthAction};
use crate::{
impl_from_error_for_route, upstream_oauth2::cache::MetadataCache,
views::shared::OptionalPostAuthAction,
};
#[derive(Debug, Error)]
pub(crate) enum RouteError {
@ -64,6 +67,7 @@ pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(http_client_factory): State<HttpClientFactory>,
State(metadata_cache): State<MetadataCache>,
mut repo: BoxRepository,
State(url_builder): State<UrlBuilder>,
cookie_jar: CookieJar,
@ -79,8 +83,7 @@ pub(crate) async fn get(
let http_service = http_client_factory.http_service().await?;
// First, discover the provider
let metadata =
mas_oidc_client::requests::discovery::discover(&http_service, &provider.issuer).await?;
let metadata = metadata_cache.get(&http_service, &provider.issuer).await?;
let redirect_uri = url_builder.upstream_oauth_callback(provider.id);

View File

@ -0,0 +1,117 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{collections::HashMap, sync::Arc};
use mas_http::HttpService;
use mas_oidc_client::error::DiscoveryError;
use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, RepositoryAccess};
use oauth2_types::oidc::VerifiedProviderMetadata;
use tokio::sync::RwLock;
/// A simple OIDC metadata cache
///
/// It never evicts entries, does not cache failures and has no locking.
/// It can also be refreshed in the background, and warmed up on startup.
/// It is good enough for our use case.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Clone, Default)]
pub struct MetadataCache {
cache: Arc<RwLock<HashMap<String, VerifiedProviderMetadata>>>,
}
impl MetadataCache {
#[must_use]
pub fn new() -> Self {
Self::default()
}
/// Warm up the cache by fetching all the known providers from the database
/// and inserting them into the cache.
///
/// This spawns a background task that will refresh the cache at the given
/// interval.
#[tracing::instrument(name = "metadata_cache.warm_up_and_run", skip_all, err)]
pub async fn warm_up_and_run<R: RepositoryAccess>(
&self,
http_service: HttpService,
interval: std::time::Duration,
repository: &mut R,
) -> Result<tokio::task::JoinHandle<()>, R::Error> {
let providers = repository.upstream_oauth_provider().all().await?;
for provider in providers {
if let Err(e) = self.fetch(&http_service, &provider.issuer).await {
tracing::error!(issuer = %provider.issuer, error = &e as &dyn std::error::Error, "Failed to fetch provider metadata");
}
}
// Spawn a background task to refresh the cache regularly
let cache = self.clone();
Ok(tokio::spawn(async move {
loop {
// Re-fetch the known metadata at the given interval
tokio::time::sleep(interval).await;
cache.refresh_all(&http_service).await;
}
}))
}
#[tracing::instrument(name = "metadata_cache.fetch", fields(%issuer), skip_all, err)]
async fn fetch(
&self,
http_service: &HttpService,
issuer: &str,
) -> Result<VerifiedProviderMetadata, DiscoveryError> {
let metadata = mas_oidc_client::requests::discovery::discover(http_service, issuer).await?;
self.cache
.write()
.await
.insert(issuer.to_owned(), metadata.clone());
Ok(metadata)
}
/// Get the metadata for the given issuer.
#[tracing::instrument(name = "metadata_cache.get", fields(%issuer), skip_all, err)]
pub async fn get(
&self,
http_service: &HttpService,
issuer: &str,
) -> Result<VerifiedProviderMetadata, DiscoveryError> {
let cache = self.cache.read().await;
if let Some(metadata) = cache.get(issuer) {
return Ok(metadata.clone());
}
let metadata = self.fetch(http_service, issuer).await?;
Ok(metadata)
}
#[tracing::instrument(name = "metadata_cache.refresh_all", skip_all)]
async fn refresh_all(&self, http_service: &HttpService) {
// Grab all the keys first to avoid locking the cache for too long
let keys: Vec<String> = {
let cache = self.cache.read().await;
cache.keys().cloned().collect()
};
for issuer in keys {
if let Err(e) = self.fetch(http_service, &issuer).await {
tracing::error!(issuer = %issuer, error = &e as &dyn std::error::Error, "Failed to refresh provider metadata");
}
}
}
}

View File

@ -37,7 +37,7 @@ use thiserror::Error;
use ulid::Ulid;
use super::{client_credentials_for_provider, UpstreamSessionsCookie};
use crate::impl_from_error_for_route;
use crate::{impl_from_error_for_route, upstream_oauth2::cache::MetadataCache};
#[derive(Deserialize)]
pub struct QueryParams {
@ -128,6 +128,7 @@ pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(http_client_factory): State<HttpClientFactory>,
State(metadata_cache): State<MetadataCache>,
mut repo: BoxRepository,
State(url_builder): State<UrlBuilder>,
State(encrypter): State<Encrypter>,
@ -185,10 +186,8 @@ pub(crate) async fn get(
let http_service = http_client_factory.http_service().await?;
// XXX: we shouldn't discover on-the-fly
// Discover the provider
let metadata =
mas_oidc_client::requests::discovery::discover(&http_service, &provider.issuer).await?;
let metadata = metadata_cache.get(&http_service, &provider.issuer).await?;
// Fetch the JWKS
let jwks =

View File

@ -22,6 +22,7 @@ use thiserror::Error;
use url::Url;
pub(crate) mod authorize;
pub(crate) mod cache;
pub(crate) mod callback;
mod cookie;
pub(crate) mod link;