You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-31 09:24:31 +03:00
Allow overriding usptream OAuth2 providers endpoints
Also have a way to disable OIDC discovery when all the endpoints are known.
This commit is contained in:
@ -48,8 +48,9 @@ pub use self::{
|
|||||||
upstream_oauth2::{
|
upstream_oauth2::{
|
||||||
UpsreamOAuthProviderSetEmailVerification, UpstreamOAuthAuthorizationSession,
|
UpsreamOAuthProviderSetEmailVerification, UpstreamOAuthAuthorizationSession,
|
||||||
UpstreamOAuthAuthorizationSessionState, UpstreamOAuthLink, UpstreamOAuthProvider,
|
UpstreamOAuthAuthorizationSessionState, UpstreamOAuthLink, UpstreamOAuthProvider,
|
||||||
UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderImportAction,
|
UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode,
|
||||||
UpstreamOAuthProviderImportPreference, UpstreamOAuthProviderSubjectPreference,
|
UpstreamOAuthProviderImportAction, UpstreamOAuthProviderImportPreference,
|
||||||
|
UpstreamOAuthProviderPkceMode, UpstreamOAuthProviderSubjectPreference,
|
||||||
},
|
},
|
||||||
users::{
|
users::{
|
||||||
Authentication, AuthenticationMethod, BrowserSession, Password, User, UserEmail,
|
Authentication, AuthenticationMethod, BrowserSession, Password, User, UserEmail,
|
||||||
|
@ -20,8 +20,10 @@ pub use self::{
|
|||||||
link::UpstreamOAuthLink,
|
link::UpstreamOAuthLink,
|
||||||
provider::{
|
provider::{
|
||||||
ClaimsImports as UpstreamOAuthProviderClaimsImports,
|
ClaimsImports as UpstreamOAuthProviderClaimsImports,
|
||||||
|
DiscoveryMode as UpstreamOAuthProviderDiscoveryMode,
|
||||||
ImportAction as UpstreamOAuthProviderImportAction,
|
ImportAction as UpstreamOAuthProviderImportAction,
|
||||||
ImportPreference as UpstreamOAuthProviderImportPreference,
|
ImportPreference as UpstreamOAuthProviderImportPreference,
|
||||||
|
PkceMode as UpstreamOAuthProviderPkceMode,
|
||||||
SetEmailVerification as UpsreamOAuthProviderSetEmailVerification,
|
SetEmailVerification as UpsreamOAuthProviderSetEmailVerification,
|
||||||
SubjectPreference as UpstreamOAuthProviderSubjectPreference, UpstreamOAuthProvider,
|
SubjectPreference as UpstreamOAuthProviderSubjectPreference, UpstreamOAuthProvider,
|
||||||
},
|
},
|
||||||
|
@ -17,11 +17,45 @@ use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod
|
|||||||
use oauth2_types::scope::Scope;
|
use oauth2_types::scope::Scope;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use ulid::Ulid;
|
use ulid::Ulid;
|
||||||
|
use url::Url;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum DiscoveryMode {
|
||||||
|
/// Use OIDC discovery to fetch and verify the provider metadata
|
||||||
|
#[default]
|
||||||
|
Oidc,
|
||||||
|
|
||||||
|
/// Use OIDC discovery to fetch the provider metadata, but don't verify it
|
||||||
|
Insecure,
|
||||||
|
|
||||||
|
/// Don't fetch the provider metadata
|
||||||
|
Disabled,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum PkceMode {
|
||||||
|
/// Use PKCE if the provider supports it
|
||||||
|
#[default]
|
||||||
|
Auto,
|
||||||
|
|
||||||
|
/// Always use PKCE with the S256 method
|
||||||
|
S256,
|
||||||
|
|
||||||
|
/// Don't use PKCE
|
||||||
|
Disabled,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
|
||||||
pub struct UpstreamOAuthProvider {
|
pub struct UpstreamOAuthProvider {
|
||||||
pub id: Ulid,
|
pub id: Ulid,
|
||||||
pub issuer: String,
|
pub issuer: String,
|
||||||
|
pub discovery_mode: DiscoveryMode,
|
||||||
|
pub pkce_mode: PkceMode,
|
||||||
|
pub jwks_uri_override: Option<Url>,
|
||||||
|
pub authorization_endpoint_override: Option<Url>,
|
||||||
|
pub token_endpoint_override: Option<Url>,
|
||||||
pub scope: Scope,
|
pub scope: Scope,
|
||||||
pub client_id: String,
|
pub client_id: String,
|
||||||
pub encrypted_client_secret: Option<String>,
|
pub encrypted_client_secret: Option<String>,
|
||||||
|
@ -29,7 +29,7 @@ use mas_storage::{
|
|||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use ulid::Ulid;
|
use ulid::Ulid;
|
||||||
|
|
||||||
use super::UpstreamSessionsCookie;
|
use super::{cache::LazyProviderInfos, UpstreamSessionsCookie};
|
||||||
use crate::{
|
use crate::{
|
||||||
impl_from_error_for_route, upstream_oauth2::cache::MetadataCache,
|
impl_from_error_for_route, upstream_oauth2::cache::MetadataCache,
|
||||||
views::shared::OptionalPostAuthAction,
|
views::shared::OptionalPostAuthAction,
|
||||||
@ -87,23 +87,28 @@ pub(crate) async fn get(
|
|||||||
let http_service = http_client_factory.http_service("upstream_oauth2.authorize");
|
let http_service = http_client_factory.http_service("upstream_oauth2.authorize");
|
||||||
|
|
||||||
// First, discover the provider
|
// First, discover the provider
|
||||||
let metadata = metadata_cache.get(&http_service, &provider.issuer).await?;
|
// This is done lazyly according to provider.discovery_mode and the various
|
||||||
|
// endpoint overrides
|
||||||
|
let mut lazy_metadata = LazyProviderInfos::new(&metadata_cache, &provider, &http_service);
|
||||||
|
lazy_metadata.maybe_discover().await?;
|
||||||
|
|
||||||
let redirect_uri = url_builder.upstream_oauth_callback(provider.id);
|
let redirect_uri = url_builder.upstream_oauth_callback(provider.id);
|
||||||
|
|
||||||
let mut data = AuthorizationRequestData::new(
|
let data = AuthorizationRequestData::new(
|
||||||
provider.client_id.clone(),
|
provider.client_id.clone(),
|
||||||
provider.scope.clone(),
|
provider.scope.clone(),
|
||||||
redirect_uri,
|
redirect_uri,
|
||||||
);
|
);
|
||||||
|
|
||||||
if let Some(methods) = metadata.code_challenge_methods_supported.clone() {
|
let data = if let Some(methods) = lazy_metadata.pkce_methods().await? {
|
||||||
data = data.with_code_challenge_methods_supported(methods);
|
data.with_code_challenge_methods_supported(methods)
|
||||||
}
|
} else {
|
||||||
|
data
|
||||||
|
};
|
||||||
|
|
||||||
// Build an authorization request for it
|
// Build an authorization request for it
|
||||||
let (url, data) = mas_oidc_client::requests::authorization_code::build_authorization_url(
|
let (url, data) = mas_oidc_client::requests::authorization_code::build_authorization_url(
|
||||||
metadata.authorization_endpoint().clone(),
|
lazy_metadata.authorization_endpoint().await?.clone(),
|
||||||
data,
|
data,
|
||||||
&mut rng,
|
&mut rng,
|
||||||
)?;
|
)?;
|
||||||
|
@ -14,11 +14,128 @@
|
|||||||
|
|
||||||
use std::{collections::HashMap, sync::Arc};
|
use std::{collections::HashMap, sync::Arc};
|
||||||
|
|
||||||
|
use mas_data_model::{
|
||||||
|
UpstreamOAuthProvider, UpstreamOAuthProviderDiscoveryMode, UpstreamOAuthProviderPkceMode,
|
||||||
|
};
|
||||||
use mas_http::HttpService;
|
use mas_http::HttpService;
|
||||||
|
use mas_iana::oauth::PkceCodeChallengeMethod;
|
||||||
use mas_oidc_client::error::DiscoveryError;
|
use mas_oidc_client::error::DiscoveryError;
|
||||||
use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, RepositoryAccess};
|
use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, RepositoryAccess};
|
||||||
use oauth2_types::oidc::VerifiedProviderMetadata;
|
use oauth2_types::oidc::VerifiedProviderMetadata;
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
|
use url::Url;
|
||||||
|
|
||||||
|
/// A high-level layer over metadata cache and provider configuration, which
|
||||||
|
/// resolves endpoint overrides and discovery modes.
|
||||||
|
pub struct LazyProviderInfos<'a> {
|
||||||
|
cache: &'a MetadataCache,
|
||||||
|
provider: &'a UpstreamOAuthProvider,
|
||||||
|
http_service: &'a HttpService,
|
||||||
|
loaded_metadata: Option<Arc<VerifiedProviderMetadata>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> LazyProviderInfos<'a> {
|
||||||
|
pub fn new(
|
||||||
|
cache: &'a MetadataCache,
|
||||||
|
provider: &'a UpstreamOAuthProvider,
|
||||||
|
http_service: &'a HttpService,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
cache,
|
||||||
|
provider,
|
||||||
|
http_service,
|
||||||
|
loaded_metadata: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Trigger the discovery process and return the metadata if discovery is
|
||||||
|
/// enabled.
|
||||||
|
pub async fn maybe_discover<'b>(
|
||||||
|
&'b mut self,
|
||||||
|
) -> Result<Option<&'b VerifiedProviderMetadata>, DiscoveryError> {
|
||||||
|
match self.load().await {
|
||||||
|
Ok(metadata) => Ok(Some(metadata)),
|
||||||
|
Err(DiscoveryError::Disabled) => Ok(None),
|
||||||
|
Err(e) => Err(e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn load<'b>(&'b mut self) -> Result<&'b VerifiedProviderMetadata, DiscoveryError> {
|
||||||
|
if self.loaded_metadata.is_none() {
|
||||||
|
let verify = match self.provider.discovery_mode {
|
||||||
|
UpstreamOAuthProviderDiscoveryMode::Oidc => true,
|
||||||
|
UpstreamOAuthProviderDiscoveryMode::Insecure => false,
|
||||||
|
UpstreamOAuthProviderDiscoveryMode::Disabled => {
|
||||||
|
return Err(DiscoveryError::Disabled)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let metadata = self
|
||||||
|
.cache
|
||||||
|
.get(self.http_service, &self.provider.issuer, verify)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
self.loaded_metadata = Some(metadata);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(self.loaded_metadata.as_ref().unwrap())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the JWKS URI for the provider.
|
||||||
|
///
|
||||||
|
/// Uses [`UpstreamOAuthProvider.jwks_uri_override`] if set, otherwise uses
|
||||||
|
/// the one from discovery.
|
||||||
|
pub async fn jwks_uri(&mut self) -> Result<&Url, DiscoveryError> {
|
||||||
|
if let Some(jwks_uri) = &self.provider.jwks_uri_override {
|
||||||
|
return Ok(jwks_uri);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(self.load().await?.jwks_uri())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the authorization endpoint for the provider.
|
||||||
|
///
|
||||||
|
/// Uses [`UpstreamOAuthProvider.authorization_endpoint_override`] if set,
|
||||||
|
/// otherwise uses the one from discovery.
|
||||||
|
pub async fn authorization_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
|
||||||
|
if let Some(authorization_endpoint) = &self.provider.authorization_endpoint_override {
|
||||||
|
return Ok(authorization_endpoint);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(self.load().await?.authorization_endpoint())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the token endpoint for the provider.
|
||||||
|
///
|
||||||
|
/// Uses [`UpstreamOAuthProvider.token_endpoint_override`] if set, otherwise
|
||||||
|
/// uses the one from discovery.
|
||||||
|
pub async fn token_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
|
||||||
|
if let Some(token_endpoint) = &self.provider.token_endpoint_override {
|
||||||
|
return Ok(token_endpoint);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(self.load().await?.token_endpoint())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the PKCE methods supported by the provider.
|
||||||
|
///
|
||||||
|
/// If the mode is set to auto, it will use the ones from discovery,
|
||||||
|
/// defaulting to none if discovery is disabled.
|
||||||
|
pub async fn pkce_methods(
|
||||||
|
&mut self,
|
||||||
|
) -> Result<Option<Vec<PkceCodeChallengeMethod>>, DiscoveryError> {
|
||||||
|
let methods = match self.provider.pkce_mode {
|
||||||
|
UpstreamOAuthProviderPkceMode::Auto => self
|
||||||
|
.maybe_discover()
|
||||||
|
.await?
|
||||||
|
.and_then(|metadata| metadata.code_challenge_methods_supported.clone()),
|
||||||
|
UpstreamOAuthProviderPkceMode::S256 => Some(vec![PkceCodeChallengeMethod::S256]),
|
||||||
|
UpstreamOAuthProviderPkceMode::Disabled => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(methods)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// A simple OIDC metadata cache
|
/// A simple OIDC metadata cache
|
||||||
///
|
///
|
||||||
@ -28,7 +145,8 @@ use tokio::sync::RwLock;
|
|||||||
#[allow(clippy::module_name_repetitions)]
|
#[allow(clippy::module_name_repetitions)]
|
||||||
#[derive(Debug, Clone, Default)]
|
#[derive(Debug, Clone, Default)]
|
||||||
pub struct MetadataCache {
|
pub struct MetadataCache {
|
||||||
cache: Arc<RwLock<HashMap<String, VerifiedProviderMetadata>>>,
|
cache: Arc<RwLock<HashMap<String, Arc<VerifiedProviderMetadata>>>>,
|
||||||
|
insecure_cache: Arc<RwLock<HashMap<String, Arc<VerifiedProviderMetadata>>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MetadataCache {
|
impl MetadataCache {
|
||||||
@ -52,7 +170,13 @@ impl MetadataCache {
|
|||||||
let providers = repository.upstream_oauth_provider().all().await?;
|
let providers = repository.upstream_oauth_provider().all().await?;
|
||||||
|
|
||||||
for provider in providers {
|
for provider in providers {
|
||||||
if let Err(e) = self.fetch(&http_service, &provider.issuer).await {
|
let verify = match provider.discovery_mode {
|
||||||
|
UpstreamOAuthProviderDiscoveryMode::Oidc => true,
|
||||||
|
UpstreamOAuthProviderDiscoveryMode::Insecure => false,
|
||||||
|
UpstreamOAuthProviderDiscoveryMode::Disabled => continue,
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Err(e) = self.fetch(&http_service, &provider.issuer, verify).await {
|
||||||
tracing::error!(issuer = %provider.issuer, error = &e as &dyn std::error::Error, "Failed to fetch provider metadata");
|
tracing::error!(issuer = %provider.issuer, error = &e as &dyn std::error::Error, "Failed to fetch provider metadata");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -73,15 +197,32 @@ impl MetadataCache {
|
|||||||
&self,
|
&self,
|
||||||
http_service: &HttpService,
|
http_service: &HttpService,
|
||||||
issuer: &str,
|
issuer: &str,
|
||||||
) -> Result<VerifiedProviderMetadata, DiscoveryError> {
|
verify: bool,
|
||||||
let metadata = mas_oidc_client::requests::discovery::discover(http_service, issuer).await?;
|
) -> Result<Arc<VerifiedProviderMetadata>, DiscoveryError> {
|
||||||
|
if verify {
|
||||||
|
let metadata =
|
||||||
|
mas_oidc_client::requests::discovery::discover(http_service, issuer).await?;
|
||||||
|
let metadata = Arc::new(metadata);
|
||||||
|
|
||||||
self.cache
|
self.cache
|
||||||
.write()
|
.write()
|
||||||
.await
|
.await
|
||||||
.insert(issuer.to_owned(), metadata.clone());
|
.insert(issuer.to_owned(), metadata.clone());
|
||||||
|
|
||||||
Ok(metadata)
|
Ok(metadata)
|
||||||
|
} else {
|
||||||
|
let metadata =
|
||||||
|
mas_oidc_client::requests::discovery::insecure_discover(http_service, issuer)
|
||||||
|
.await?;
|
||||||
|
let metadata = Arc::new(metadata);
|
||||||
|
|
||||||
|
self.insecure_cache
|
||||||
|
.write()
|
||||||
|
.await
|
||||||
|
.insert(issuer.to_owned(), metadata.clone());
|
||||||
|
|
||||||
|
Ok(metadata)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the metadata for the given issuer.
|
/// Get the metadata for the given issuer.
|
||||||
@ -90,13 +231,21 @@ impl MetadataCache {
|
|||||||
&self,
|
&self,
|
||||||
http_service: &HttpService,
|
http_service: &HttpService,
|
||||||
issuer: &str,
|
issuer: &str,
|
||||||
) -> Result<VerifiedProviderMetadata, DiscoveryError> {
|
verify: bool,
|
||||||
let cache = self.cache.read().await;
|
) -> Result<Arc<VerifiedProviderMetadata>, DiscoveryError> {
|
||||||
if let Some(metadata) = cache.get(issuer) {
|
let cache = if verify {
|
||||||
return Ok(metadata.clone());
|
self.cache.read().await
|
||||||
}
|
} else {
|
||||||
|
self.insecure_cache.read().await
|
||||||
|
};
|
||||||
|
|
||||||
let metadata = self.fetch(http_service, issuer).await?;
|
if let Some(metadata) = cache.get(issuer) {
|
||||||
|
return Ok(Arc::clone(metadata));
|
||||||
|
}
|
||||||
|
// Drop the cache guard so that we don't deadlock when we try to fetch
|
||||||
|
drop(cache);
|
||||||
|
|
||||||
|
let metadata = self.fetch(http_service, issuer, verify).await?;
|
||||||
Ok(metadata)
|
Ok(metadata)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -109,9 +258,369 @@ impl MetadataCache {
|
|||||||
};
|
};
|
||||||
|
|
||||||
for issuer in keys {
|
for issuer in keys {
|
||||||
if let Err(e) = self.fetch(http_service, &issuer).await {
|
if let Err(e) = self.fetch(http_service, &issuer, true).await {
|
||||||
|
tracing::error!(issuer = %issuer, error = &e as &dyn std::error::Error, "Failed to refresh provider metadata");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do the same for the insecure cache
|
||||||
|
let keys: Vec<String> = {
|
||||||
|
let cache = self.insecure_cache.read().await;
|
||||||
|
cache.keys().cloned().collect()
|
||||||
|
};
|
||||||
|
|
||||||
|
for issuer in keys {
|
||||||
|
if let Err(e) = self.fetch(http_service, &issuer, false).await {
|
||||||
tracing::error!(issuer = %issuer, error = &e as &dyn std::error::Error, "Failed to refresh provider metadata");
|
tracing::error!(issuer = %issuer, error = &e as &dyn std::error::Error, "Failed to refresh provider metadata");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
#![allow(clippy::too_many_lines)]
|
||||||
|
|
||||||
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
|
|
||||||
|
use hyper::{body::Bytes, Request, Response, StatusCode};
|
||||||
|
use mas_data_model::UpstreamOAuthProviderClaimsImports;
|
||||||
|
use mas_http::BoxCloneSyncService;
|
||||||
|
use mas_iana::oauth::OAuthClientAuthenticationMethod;
|
||||||
|
use mas_storage::{clock::MockClock, Clock};
|
||||||
|
use oauth2_types::scope::{Scope, OPENID};
|
||||||
|
use tower::BoxError;
|
||||||
|
use ulid::Ulid;
|
||||||
|
|
||||||
|
use crate::test_utils::init_tracing;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_metadata_cache() {
|
||||||
|
init_tracing();
|
||||||
|
let calls = Arc::new(AtomicUsize::new(0));
|
||||||
|
let closure_calls = Arc::clone(&calls);
|
||||||
|
let handler = move |req: Request<Bytes>| {
|
||||||
|
let calls = Arc::clone(&closure_calls);
|
||||||
|
async move {
|
||||||
|
calls.fetch_add(1, Ordering::SeqCst);
|
||||||
|
|
||||||
|
let body = match req.uri().authority().unwrap().as_str() {
|
||||||
|
"valid.example.com" => Bytes::from_static(
|
||||||
|
br#"{
|
||||||
|
"issuer": "https://valid.example.com/",
|
||||||
|
"authorization_endpoint": "https://valid.example.com/authorize",
|
||||||
|
"token_endpoint": "https://valid.example.com/token",
|
||||||
|
"jwks_uri": "https://valid.example.com/jwks",
|
||||||
|
"response_types_supported": [
|
||||||
|
"code"
|
||||||
|
],
|
||||||
|
"grant_types_supported": [
|
||||||
|
"authorization_code"
|
||||||
|
],
|
||||||
|
"subject_types_supported": [
|
||||||
|
"public"
|
||||||
|
],
|
||||||
|
"id_token_signing_alg_values_supported": [
|
||||||
|
"RS256"
|
||||||
|
],
|
||||||
|
"scopes_supported": [
|
||||||
|
"openid",
|
||||||
|
"profile",
|
||||||
|
"email"
|
||||||
|
]
|
||||||
|
}"#,
|
||||||
|
),
|
||||||
|
"insecure.example.com" => Bytes::from_static(
|
||||||
|
br#"{
|
||||||
|
"issuer": "http://insecure.example.com/",
|
||||||
|
"authorization_endpoint": "http://insecure.example.com/authorize",
|
||||||
|
"token_endpoint": "http://insecure.example.com/token",
|
||||||
|
"jwks_uri": "http://insecure.example.com/jwks",
|
||||||
|
"response_types_supported": [
|
||||||
|
"code"
|
||||||
|
],
|
||||||
|
"grant_types_supported": [
|
||||||
|
"authorization_code"
|
||||||
|
],
|
||||||
|
"subject_types_supported": [
|
||||||
|
"public"
|
||||||
|
],
|
||||||
|
"id_token_signing_alg_values_supported": [
|
||||||
|
"RS256"
|
||||||
|
],
|
||||||
|
"scopes_supported": [
|
||||||
|
"openid",
|
||||||
|
"profile",
|
||||||
|
"email"
|
||||||
|
]
|
||||||
|
}"#,
|
||||||
|
),
|
||||||
|
_ => Bytes::default(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut response = Response::new(body);
|
||||||
|
*response.status_mut() = StatusCode::OK;
|
||||||
|
Ok::<_, BoxError>(response)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let service = BoxCloneSyncService::new(tower::service_fn(handler));
|
||||||
|
let cache = MetadataCache::new();
|
||||||
|
|
||||||
|
// An inexistant issuer should fail
|
||||||
|
cache
|
||||||
|
.get(&service, "https://inexistant.example.com/", true)
|
||||||
|
.await
|
||||||
|
.unwrap_err();
|
||||||
|
assert_eq!(calls.load(Ordering::SeqCst), 1);
|
||||||
|
|
||||||
|
// A valid issuer should succeed
|
||||||
|
cache
|
||||||
|
.get(&service, "https://valid.example.com/", true)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(calls.load(Ordering::SeqCst), 2);
|
||||||
|
|
||||||
|
// Calling again should not trigger a new fetch
|
||||||
|
cache
|
||||||
|
.get(&service, "https://valid.example.com/", true)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(calls.load(Ordering::SeqCst), 2);
|
||||||
|
|
||||||
|
// An insecure issuer should work with insecure discovery
|
||||||
|
cache
|
||||||
|
.get(&service, "http://insecure.example.com/", false)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(calls.load(Ordering::SeqCst), 3);
|
||||||
|
|
||||||
|
// Doing it again shpoild not trigger a new fetch
|
||||||
|
cache
|
||||||
|
.get(&service, "http://insecure.example.com/", false)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(calls.load(Ordering::SeqCst), 3);
|
||||||
|
|
||||||
|
// But it should fail with secure discovery
|
||||||
|
// Note that it still fetched because secure and insecure caches are distinct
|
||||||
|
cache
|
||||||
|
.get(&service, "http://insecure.example.com/", true)
|
||||||
|
.await
|
||||||
|
.unwrap_err();
|
||||||
|
assert_eq!(calls.load(Ordering::SeqCst), 4);
|
||||||
|
|
||||||
|
// Calling refresh should refresh all the known valid issuers
|
||||||
|
cache.refresh_all(&service).await;
|
||||||
|
assert_eq!(calls.load(Ordering::SeqCst), 6);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_lazy_provider_infos() {
|
||||||
|
init_tracing();
|
||||||
|
let calls = Arc::new(AtomicUsize::new(0));
|
||||||
|
let closure_calls = Arc::clone(&calls);
|
||||||
|
let handler = move |req: Request<Bytes>| {
|
||||||
|
let calls = Arc::clone(&closure_calls);
|
||||||
|
async move {
|
||||||
|
calls.fetch_add(1, Ordering::SeqCst);
|
||||||
|
|
||||||
|
let body = match req.uri().authority().unwrap().as_str() {
|
||||||
|
"valid.example.com" => Bytes::from_static(
|
||||||
|
br#"{
|
||||||
|
"issuer": "https://valid.example.com/",
|
||||||
|
"authorization_endpoint": "https://valid.example.com/authorize",
|
||||||
|
"token_endpoint": "https://valid.example.com/token",
|
||||||
|
"jwks_uri": "https://valid.example.com/jwks",
|
||||||
|
"response_types_supported": [
|
||||||
|
"code"
|
||||||
|
],
|
||||||
|
"grant_types_supported": [
|
||||||
|
"authorization_code"
|
||||||
|
],
|
||||||
|
"subject_types_supported": [
|
||||||
|
"public"
|
||||||
|
],
|
||||||
|
"id_token_signing_alg_values_supported": [
|
||||||
|
"RS256"
|
||||||
|
],
|
||||||
|
"scopes_supported": [
|
||||||
|
"openid",
|
||||||
|
"profile",
|
||||||
|
"email"
|
||||||
|
]
|
||||||
|
}"#,
|
||||||
|
),
|
||||||
|
"insecure.example.com" => Bytes::from_static(
|
||||||
|
br#"{
|
||||||
|
"issuer": "http://insecure.example.com/",
|
||||||
|
"authorization_endpoint": "http://insecure.example.com/authorize",
|
||||||
|
"token_endpoint": "http://insecure.example.com/token",
|
||||||
|
"jwks_uri": "http://insecure.example.com/jwks",
|
||||||
|
"response_types_supported": [
|
||||||
|
"code"
|
||||||
|
],
|
||||||
|
"grant_types_supported": [
|
||||||
|
"authorization_code"
|
||||||
|
],
|
||||||
|
"subject_types_supported": [
|
||||||
|
"public"
|
||||||
|
],
|
||||||
|
"id_token_signing_alg_values_supported": [
|
||||||
|
"RS256"
|
||||||
|
],
|
||||||
|
"scopes_supported": [
|
||||||
|
"openid",
|
||||||
|
"profile",
|
||||||
|
"email"
|
||||||
|
]
|
||||||
|
}"#,
|
||||||
|
),
|
||||||
|
_ => Bytes::default(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut response = Response::new(body);
|
||||||
|
*response.status_mut() = StatusCode::OK;
|
||||||
|
Ok::<_, BoxError>(response)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let clock = MockClock::default();
|
||||||
|
let service = BoxCloneSyncService::new(tower::service_fn(handler));
|
||||||
|
let provider = UpstreamOAuthProvider {
|
||||||
|
id: Ulid::nil(),
|
||||||
|
issuer: "https://valid.example.com/".to_owned(),
|
||||||
|
discovery_mode: UpstreamOAuthProviderDiscoveryMode::Oidc,
|
||||||
|
pkce_mode: UpstreamOAuthProviderPkceMode::Auto,
|
||||||
|
jwks_uri_override: None,
|
||||||
|
authorization_endpoint_override: None,
|
||||||
|
token_endpoint_override: None,
|
||||||
|
scope: Scope::from_iter([OPENID]),
|
||||||
|
client_id: "client_id".to_owned(),
|
||||||
|
encrypted_client_secret: None,
|
||||||
|
token_endpoint_signing_alg: None,
|
||||||
|
token_endpoint_auth_method: OAuthClientAuthenticationMethod::None,
|
||||||
|
created_at: clock.now(),
|
||||||
|
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Without any override, it should just use discovery
|
||||||
|
{
|
||||||
|
let cache = MetadataCache::new();
|
||||||
|
let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &service);
|
||||||
|
assert_eq!(calls.load(Ordering::SeqCst), 0);
|
||||||
|
lazy_metadata.maybe_discover().await.unwrap();
|
||||||
|
assert_eq!(calls.load(Ordering::SeqCst), 1);
|
||||||
|
assert_eq!(
|
||||||
|
lazy_metadata
|
||||||
|
.authorization_endpoint()
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.as_str(),
|
||||||
|
"https://valid.example.com/authorize"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test overriding endpoints
|
||||||
|
{
|
||||||
|
let provider = UpstreamOAuthProvider {
|
||||||
|
jwks_uri_override: Some("https://valid.example.com/jwks_override".parse().unwrap()),
|
||||||
|
authorization_endpoint_override: Some(
|
||||||
|
"https://valid.example.com/authorize_override"
|
||||||
|
.parse()
|
||||||
|
.unwrap(),
|
||||||
|
),
|
||||||
|
token_endpoint_override: Some(
|
||||||
|
"https://valid.example.com/token_override".parse().unwrap(),
|
||||||
|
),
|
||||||
|
..provider.clone()
|
||||||
|
};
|
||||||
|
let cache = MetadataCache::new();
|
||||||
|
let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &service);
|
||||||
|
assert_eq!(
|
||||||
|
lazy_metadata.jwks_uri().await.unwrap().as_str(),
|
||||||
|
"https://valid.example.com/jwks_override"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
lazy_metadata
|
||||||
|
.authorization_endpoint()
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.as_str(),
|
||||||
|
"https://valid.example.com/authorize_override"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
lazy_metadata.token_endpoint().await.unwrap().as_str(),
|
||||||
|
"https://valid.example.com/token_override"
|
||||||
|
);
|
||||||
|
// This shouldn't trigger a new fetch as the endpoint is overriden
|
||||||
|
assert_eq!(calls.load(Ordering::SeqCst), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insecure providers don't work with secure discovery
|
||||||
|
{
|
||||||
|
let provider = UpstreamOAuthProvider {
|
||||||
|
issuer: "http://insecure.example.com/".to_owned(),
|
||||||
|
..provider.clone()
|
||||||
|
};
|
||||||
|
let cache = MetadataCache::new();
|
||||||
|
let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &service);
|
||||||
|
lazy_metadata.authorization_endpoint().await.unwrap_err();
|
||||||
|
// This triggered a fetch, even though it failed
|
||||||
|
assert_eq!(calls.load(Ordering::SeqCst), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insecure providers work with insecure discovery
|
||||||
|
{
|
||||||
|
let provider = UpstreamOAuthProvider {
|
||||||
|
issuer: "http://insecure.example.com/".to_owned(),
|
||||||
|
discovery_mode: UpstreamOAuthProviderDiscoveryMode::Insecure,
|
||||||
|
..provider.clone()
|
||||||
|
};
|
||||||
|
let cache = MetadataCache::new();
|
||||||
|
let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &service);
|
||||||
|
assert_eq!(
|
||||||
|
lazy_metadata
|
||||||
|
.authorization_endpoint()
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.as_str(),
|
||||||
|
"http://insecure.example.com/authorize"
|
||||||
|
);
|
||||||
|
// This triggered a fetch
|
||||||
|
assert_eq!(calls.load(Ordering::SeqCst), 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Getting endpoints when discovery is disabled only works for overriden ones
|
||||||
|
{
|
||||||
|
let provider = UpstreamOAuthProvider {
|
||||||
|
discovery_mode: UpstreamOAuthProviderDiscoveryMode::Disabled,
|
||||||
|
authorization_endpoint_override: Some(
|
||||||
|
Url::parse("https://valid.example.com/authorize_override").unwrap(),
|
||||||
|
),
|
||||||
|
token_endpoint_override: None,
|
||||||
|
..provider.clone()
|
||||||
|
};
|
||||||
|
let cache = MetadataCache::new();
|
||||||
|
let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &service);
|
||||||
|
// This should not fail, but also does nothing
|
||||||
|
assert!(lazy_metadata.maybe_discover().await.unwrap().is_none());
|
||||||
|
assert_eq!(
|
||||||
|
lazy_metadata
|
||||||
|
.authorization_endpoint()
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.as_str(),
|
||||||
|
"https://valid.example.com/authorize_override"
|
||||||
|
);
|
||||||
|
assert!(matches!(
|
||||||
|
lazy_metadata.token_endpoint().await,
|
||||||
|
Err(DiscoveryError::Disabled),
|
||||||
|
));
|
||||||
|
// This did not trigger a fetch
|
||||||
|
assert_eq!(calls.load(Ordering::SeqCst), 3);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -37,7 +37,7 @@ use serde::Deserialize;
|
|||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use ulid::Ulid;
|
use ulid::Ulid;
|
||||||
|
|
||||||
use super::{client_credentials_for_provider, UpstreamSessionsCookie};
|
use super::{cache::LazyProviderInfos, client_credentials_for_provider, UpstreamSessionsCookie};
|
||||||
use crate::{impl_from_error_for_route, upstream_oauth2::cache::MetadataCache};
|
use crate::{impl_from_error_for_route, upstream_oauth2::cache::MetadataCache};
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
@ -191,18 +191,17 @@ pub(crate) async fn get(
|
|||||||
};
|
};
|
||||||
|
|
||||||
let http_service = http_client_factory.http_service("upstream_oauth2.callback");
|
let http_service = http_client_factory.http_service("upstream_oauth2.callback");
|
||||||
|
let mut lazy_metadata = LazyProviderInfos::new(&metadata_cache, &provider, &http_service);
|
||||||
// Discover the provider
|
|
||||||
let metadata = metadata_cache.get(&http_service, &provider.issuer).await?;
|
|
||||||
|
|
||||||
// Fetch the JWKS
|
// Fetch the JWKS
|
||||||
let jwks =
|
let jwks =
|
||||||
mas_oidc_client::requests::jose::fetch_jwks(&http_service, metadata.jwks_uri()).await?;
|
mas_oidc_client::requests::jose::fetch_jwks(&http_service, lazy_metadata.jwks_uri().await?)
|
||||||
|
.await?;
|
||||||
|
|
||||||
// Figure out the client credentials
|
// Figure out the client credentials
|
||||||
let client_credentials = client_credentials_for_provider(
|
let client_credentials = client_credentials_for_provider(
|
||||||
&provider,
|
&provider,
|
||||||
metadata.token_endpoint(),
|
lazy_metadata.token_endpoint().await?,
|
||||||
&keystore,
|
&keystore,
|
||||||
&encrypter,
|
&encrypter,
|
||||||
)?;
|
)?;
|
||||||
@ -229,7 +228,7 @@ pub(crate) async fn get(
|
|||||||
mas_oidc_client::requests::authorization_code::access_token_with_authorization_code(
|
mas_oidc_client::requests::authorization_code::access_token_with_authorization_code(
|
||||||
&http_service,
|
&http_service,
|
||||||
client_credentials,
|
client_credentials,
|
||||||
metadata.token_endpoint(),
|
lazy_metadata.token_endpoint().await?,
|
||||||
code,
|
code,
|
||||||
validation_data,
|
validation_data,
|
||||||
Some(id_token_verification_data),
|
Some(id_token_verification_data),
|
||||||
|
@ -93,6 +93,10 @@ pub enum DiscoveryError {
|
|||||||
/// An error occurred sending the request.
|
/// An error occurred sending the request.
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
Service(BoxError),
|
Service(BoxError),
|
||||||
|
|
||||||
|
/// Discovery is disabled for this provider.
|
||||||
|
#[error("Discovery is disabled for this provider")]
|
||||||
|
Disabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S> From<json_response::Error<S>> for DiscoveryError
|
impl<S> From<json_response::Error<S>> for DiscoveryError
|
||||||
|
@ -14,7 +14,10 @@
|
|||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports};
|
use mas_data_model::{
|
||||||
|
UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode,
|
||||||
|
UpstreamOAuthProviderPkceMode,
|
||||||
|
};
|
||||||
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
|
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
|
||||||
use mas_storage::{
|
use mas_storage::{
|
||||||
upstream_oauth2::{UpstreamOAuthProviderFilter, UpstreamOAuthProviderRepository},
|
upstream_oauth2::{UpstreamOAuthProviderFilter, UpstreamOAuthProviderRepository},
|
||||||
@ -99,6 +102,13 @@ impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
|
|||||||
token_endpoint_signing_alg,
|
token_endpoint_signing_alg,
|
||||||
created_at: value.created_at,
|
created_at: value.created_at,
|
||||||
claims_imports: value.claims_imports.0,
|
claims_imports: value.claims_imports.0,
|
||||||
|
|
||||||
|
// TODO
|
||||||
|
authorization_endpoint_override: None,
|
||||||
|
token_endpoint_override: None,
|
||||||
|
jwks_uri_override: None,
|
||||||
|
discovery_mode: UpstreamOAuthProviderDiscoveryMode::default(),
|
||||||
|
pkce_mode: UpstreamOAuthProviderPkceMode::default(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -213,6 +223,13 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
|||||||
token_endpoint_auth_method,
|
token_endpoint_auth_method,
|
||||||
created_at,
|
created_at,
|
||||||
claims_imports,
|
claims_imports,
|
||||||
|
|
||||||
|
// TODO
|
||||||
|
authorization_endpoint_override: None,
|
||||||
|
token_endpoint_override: None,
|
||||||
|
jwks_uri_override: None,
|
||||||
|
discovery_mode: UpstreamOAuthProviderDiscoveryMode::default(),
|
||||||
|
pkce_mode: UpstreamOAuthProviderPkceMode::default(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -357,6 +374,13 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
|||||||
token_endpoint_auth_method,
|
token_endpoint_auth_method,
|
||||||
created_at,
|
created_at,
|
||||||
claims_imports,
|
claims_imports,
|
||||||
|
|
||||||
|
// TODO
|
||||||
|
authorization_endpoint_override: None,
|
||||||
|
token_endpoint_override: None,
|
||||||
|
jwks_uri_override: None,
|
||||||
|
discovery_mode: UpstreamOAuthProviderDiscoveryMode::default(),
|
||||||
|
pkce_mode: UpstreamOAuthProviderPkceMode::default(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user