1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-11-20 12:02:22 +03:00

Cache the upstream OAuth 2.0 provider metadata

This commit is contained in:
Quentin Gliech
2023-08-28 16:18:49 +02:00
parent 17e28f56c1
commit 07ca145174
8 changed files with 187 additions and 10 deletions

View File

@@ -28,7 +28,10 @@ use thiserror::Error;
use ulid::Ulid;
use super::UpstreamSessionsCookie;
use crate::{impl_from_error_for_route, views::shared::OptionalPostAuthAction};
use crate::{
impl_from_error_for_route, upstream_oauth2::cache::MetadataCache,
views::shared::OptionalPostAuthAction,
};
#[derive(Debug, Error)]
pub(crate) enum RouteError {
@@ -64,6 +67,7 @@ pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(http_client_factory): State<HttpClientFactory>,
State(metadata_cache): State<MetadataCache>,
mut repo: BoxRepository,
State(url_builder): State<UrlBuilder>,
cookie_jar: CookieJar,
@@ -79,8 +83,7 @@ pub(crate) async fn get(
let http_service = http_client_factory.http_service().await?;
// First, discover the provider
let metadata =
mas_oidc_client::requests::discovery::discover(&http_service, &provider.issuer).await?;
let metadata = metadata_cache.get(&http_service, &provider.issuer).await?;
let redirect_uri = url_builder.upstream_oauth_callback(provider.id);

View File

@@ -0,0 +1,117 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{collections::HashMap, sync::Arc};
use mas_http::HttpService;
use mas_oidc_client::error::DiscoveryError;
use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, RepositoryAccess};
use oauth2_types::oidc::VerifiedProviderMetadata;
use tokio::sync::RwLock;
/// A simple OIDC metadata cache
///
/// It never evicts entries, does not cache failures and has no locking.
/// It can also be refreshed in the background, and warmed up on startup.
/// It is good enough for our use case.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Clone, Default)]
pub struct MetadataCache {
cache: Arc<RwLock<HashMap<String, VerifiedProviderMetadata>>>,
}
impl MetadataCache {
#[must_use]
pub fn new() -> Self {
Self::default()
}
/// Warm up the cache by fetching all the known providers from the database
/// and inserting them into the cache.
///
/// This spawns a background task that will refresh the cache at the given
/// interval.
#[tracing::instrument(name = "metadata_cache.warm_up_and_run", skip_all, err)]
pub async fn warm_up_and_run<R: RepositoryAccess>(
&self,
http_service: HttpService,
interval: std::time::Duration,
repository: &mut R,
) -> Result<tokio::task::JoinHandle<()>, R::Error> {
let providers = repository.upstream_oauth_provider().all().await?;
for provider in providers {
if let Err(e) = self.fetch(&http_service, &provider.issuer).await {
tracing::error!(issuer = %provider.issuer, error = &e as &dyn std::error::Error, "Failed to fetch provider metadata");
}
}
// Spawn a background task to refresh the cache regularly
let cache = self.clone();
Ok(tokio::spawn(async move {
loop {
// Re-fetch the known metadata at the given interval
tokio::time::sleep(interval).await;
cache.refresh_all(&http_service).await;
}
}))
}
#[tracing::instrument(name = "metadata_cache.fetch", fields(%issuer), skip_all, err)]
async fn fetch(
&self,
http_service: &HttpService,
issuer: &str,
) -> Result<VerifiedProviderMetadata, DiscoveryError> {
let metadata = mas_oidc_client::requests::discovery::discover(http_service, issuer).await?;
self.cache
.write()
.await
.insert(issuer.to_owned(), metadata.clone());
Ok(metadata)
}
/// Get the metadata for the given issuer.
#[tracing::instrument(name = "metadata_cache.get", fields(%issuer), skip_all, err)]
pub async fn get(
&self,
http_service: &HttpService,
issuer: &str,
) -> Result<VerifiedProviderMetadata, DiscoveryError> {
let cache = self.cache.read().await;
if let Some(metadata) = cache.get(issuer) {
return Ok(metadata.clone());
}
let metadata = self.fetch(http_service, issuer).await?;
Ok(metadata)
}
#[tracing::instrument(name = "metadata_cache.refresh_all", skip_all)]
async fn refresh_all(&self, http_service: &HttpService) {
// Grab all the keys first to avoid locking the cache for too long
let keys: Vec<String> = {
let cache = self.cache.read().await;
cache.keys().cloned().collect()
};
for issuer in keys {
if let Err(e) = self.fetch(http_service, &issuer).await {
tracing::error!(issuer = %issuer, error = &e as &dyn std::error::Error, "Failed to refresh provider metadata");
}
}
}
}

View File

@@ -37,7 +37,7 @@ use thiserror::Error;
use ulid::Ulid;
use super::{client_credentials_for_provider, UpstreamSessionsCookie};
use crate::impl_from_error_for_route;
use crate::{impl_from_error_for_route, upstream_oauth2::cache::MetadataCache};
#[derive(Deserialize)]
pub struct QueryParams {
@@ -128,6 +128,7 @@ pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(http_client_factory): State<HttpClientFactory>,
State(metadata_cache): State<MetadataCache>,
mut repo: BoxRepository,
State(url_builder): State<UrlBuilder>,
State(encrypter): State<Encrypter>,
@@ -185,10 +186,8 @@ pub(crate) async fn get(
let http_service = http_client_factory.http_service().await?;
// XXX: we shouldn't discover on-the-fly
// Discover the provider
let metadata =
mas_oidc_client::requests::discovery::discover(&http_service, &provider.issuer).await?;
let metadata = metadata_cache.get(&http_service, &provider.issuer).await?;
// Fetch the JWKS
let jwks =

View File

@@ -22,6 +22,7 @@ use thiserror::Error;
use url::Url;
pub(crate) mod authorize;
pub(crate) mod cache;
pub(crate) mod callback;
mod cookie;
pub(crate) mod link;