You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-11-20 12:02:22 +03:00
Cache the upstream OAuth 2.0 provider metadata
This commit is contained in:
@@ -28,7 +28,10 @@ use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
|
||||
use super::UpstreamSessionsCookie;
|
||||
use crate::{impl_from_error_for_route, views::shared::OptionalPostAuthAction};
|
||||
use crate::{
|
||||
impl_from_error_for_route, upstream_oauth2::cache::MetadataCache,
|
||||
views::shared::OptionalPostAuthAction,
|
||||
};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub(crate) enum RouteError {
|
||||
@@ -64,6 +67,7 @@ pub(crate) async fn get(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
State(http_client_factory): State<HttpClientFactory>,
|
||||
State(metadata_cache): State<MetadataCache>,
|
||||
mut repo: BoxRepository,
|
||||
State(url_builder): State<UrlBuilder>,
|
||||
cookie_jar: CookieJar,
|
||||
@@ -79,8 +83,7 @@ pub(crate) async fn get(
|
||||
let http_service = http_client_factory.http_service().await?;
|
||||
|
||||
// First, discover the provider
|
||||
let metadata =
|
||||
mas_oidc_client::requests::discovery::discover(&http_service, &provider.issuer).await?;
|
||||
let metadata = metadata_cache.get(&http_service, &provider.issuer).await?;
|
||||
|
||||
let redirect_uri = url_builder.upstream_oauth_callback(provider.id);
|
||||
|
||||
|
||||
117
crates/handlers/src/upstream_oauth2/cache.rs
Normal file
117
crates/handlers/src/upstream_oauth2/cache.rs
Normal file
@@ -0,0 +1,117 @@
|
||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use mas_http::HttpService;
|
||||
use mas_oidc_client::error::DiscoveryError;
|
||||
use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, RepositoryAccess};
|
||||
use oauth2_types::oidc::VerifiedProviderMetadata;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// A simple OIDC metadata cache
|
||||
///
|
||||
/// It never evicts entries, does not cache failures and has no locking.
|
||||
/// It can also be refreshed in the background, and warmed up on startup.
|
||||
/// It is good enough for our use case.
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct MetadataCache {
|
||||
cache: Arc<RwLock<HashMap<String, VerifiedProviderMetadata>>>,
|
||||
}
|
||||
|
||||
impl MetadataCache {
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Warm up the cache by fetching all the known providers from the database
|
||||
/// and inserting them into the cache.
|
||||
///
|
||||
/// This spawns a background task that will refresh the cache at the given
|
||||
/// interval.
|
||||
#[tracing::instrument(name = "metadata_cache.warm_up_and_run", skip_all, err)]
|
||||
pub async fn warm_up_and_run<R: RepositoryAccess>(
|
||||
&self,
|
||||
http_service: HttpService,
|
||||
interval: std::time::Duration,
|
||||
repository: &mut R,
|
||||
) -> Result<tokio::task::JoinHandle<()>, R::Error> {
|
||||
let providers = repository.upstream_oauth_provider().all().await?;
|
||||
|
||||
for provider in providers {
|
||||
if let Err(e) = self.fetch(&http_service, &provider.issuer).await {
|
||||
tracing::error!(issuer = %provider.issuer, error = &e as &dyn std::error::Error, "Failed to fetch provider metadata");
|
||||
}
|
||||
}
|
||||
|
||||
// Spawn a background task to refresh the cache regularly
|
||||
let cache = self.clone();
|
||||
Ok(tokio::spawn(async move {
|
||||
loop {
|
||||
// Re-fetch the known metadata at the given interval
|
||||
tokio::time::sleep(interval).await;
|
||||
cache.refresh_all(&http_service).await;
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "metadata_cache.fetch", fields(%issuer), skip_all, err)]
|
||||
async fn fetch(
|
||||
&self,
|
||||
http_service: &HttpService,
|
||||
issuer: &str,
|
||||
) -> Result<VerifiedProviderMetadata, DiscoveryError> {
|
||||
let metadata = mas_oidc_client::requests::discovery::discover(http_service, issuer).await?;
|
||||
|
||||
self.cache
|
||||
.write()
|
||||
.await
|
||||
.insert(issuer.to_owned(), metadata.clone());
|
||||
|
||||
Ok(metadata)
|
||||
}
|
||||
|
||||
/// Get the metadata for the given issuer.
|
||||
#[tracing::instrument(name = "metadata_cache.get", fields(%issuer), skip_all, err)]
|
||||
pub async fn get(
|
||||
&self,
|
||||
http_service: &HttpService,
|
||||
issuer: &str,
|
||||
) -> Result<VerifiedProviderMetadata, DiscoveryError> {
|
||||
let cache = self.cache.read().await;
|
||||
if let Some(metadata) = cache.get(issuer) {
|
||||
return Ok(metadata.clone());
|
||||
}
|
||||
|
||||
let metadata = self.fetch(http_service, issuer).await?;
|
||||
Ok(metadata)
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "metadata_cache.refresh_all", skip_all)]
|
||||
async fn refresh_all(&self, http_service: &HttpService) {
|
||||
// Grab all the keys first to avoid locking the cache for too long
|
||||
let keys: Vec<String> = {
|
||||
let cache = self.cache.read().await;
|
||||
cache.keys().cloned().collect()
|
||||
};
|
||||
|
||||
for issuer in keys {
|
||||
if let Err(e) = self.fetch(http_service, &issuer).await {
|
||||
tracing::error!(issuer = %issuer, error = &e as &dyn std::error::Error, "Failed to refresh provider metadata");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -37,7 +37,7 @@ use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
|
||||
use super::{client_credentials_for_provider, UpstreamSessionsCookie};
|
||||
use crate::impl_from_error_for_route;
|
||||
use crate::{impl_from_error_for_route, upstream_oauth2::cache::MetadataCache};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct QueryParams {
|
||||
@@ -128,6 +128,7 @@ pub(crate) async fn get(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
State(http_client_factory): State<HttpClientFactory>,
|
||||
State(metadata_cache): State<MetadataCache>,
|
||||
mut repo: BoxRepository,
|
||||
State(url_builder): State<UrlBuilder>,
|
||||
State(encrypter): State<Encrypter>,
|
||||
@@ -185,10 +186,8 @@ pub(crate) async fn get(
|
||||
|
||||
let http_service = http_client_factory.http_service().await?;
|
||||
|
||||
// XXX: we shouldn't discover on-the-fly
|
||||
// Discover the provider
|
||||
let metadata =
|
||||
mas_oidc_client::requests::discovery::discover(&http_service, &provider.issuer).await?;
|
||||
let metadata = metadata_cache.get(&http_service, &provider.issuer).await?;
|
||||
|
||||
// Fetch the JWKS
|
||||
let jwks =
|
||||
|
||||
@@ -22,6 +22,7 @@ use thiserror::Error;
|
||||
use url::Url;
|
||||
|
||||
pub(crate) mod authorize;
|
||||
pub(crate) mod cache;
|
||||
pub(crate) mod callback;
|
||||
mod cookie;
|
||||
pub(crate) mod link;
|
||||
|
||||
Reference in New Issue
Block a user