diff --git a/crates/handlers/src/oauth2/discovery.rs b/crates/handlers/src/oauth2/discovery.rs index 35c74a41..9d8eff64 100644 --- a/crates/handlers/src/oauth2/discovery.rs +++ b/crates/handlers/src/oauth2/discovery.rs @@ -153,3 +153,27 @@ pub(crate) async fn get( Json(metadata) } + +#[cfg(test)] +mod tests { + use hyper::{Request, StatusCode}; + use oauth2_types::oidc::ProviderMetadata; + use sqlx::PgPool; + + use crate::test_utils::{init_tracing, RequestBuilderExt, ResponseExt, TestState}; + + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_valid_discovery_metadata(pool: PgPool) { + init_tracing(); + let state = TestState::from_pool(pool).await.unwrap(); + + let request = Request::get("/.well-known/openid-configuration").empty(); + let response = state.request(request).await; + response.assert_status(StatusCode::OK); + + let metadata: ProviderMetadata = response.json(); + metadata + .validate(state.url_builder.oidc_issuer().as_str()) + .expect("Invalid metadata"); + } +} diff --git a/crates/handlers/src/oauth2/registration.rs b/crates/handlers/src/oauth2/registration.rs index e66f07f2..17373fc0 100644 --- a/crates/handlers/src/oauth2/registration.rs +++ b/crates/handlers/src/oauth2/registration.rs @@ -234,7 +234,7 @@ mod tests { let response = state.request(request).await; response.assert_status(StatusCode::BAD_REQUEST); - let response: ClientError = serde_json::from_str(response.body()).unwrap(); + let response: ClientError = response.json(); assert_eq!(response.error, ClientErrorCode::InvalidRequest); // Invalid client metadata @@ -245,7 +245,7 @@ mod tests { let response = state.request(request).await; response.assert_status(StatusCode::BAD_REQUEST); - let response: ClientError = serde_json::from_str(response.body()).unwrap(); + let response: ClientError = response.json(); assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata); // Invalid redirect URI @@ -259,7 +259,7 @@ mod tests { let response = state.request(request).await; response.assert_status(StatusCode::BAD_REQUEST); - let response: ClientError = serde_json::from_str(response.body()).unwrap(); + let response: ClientError = response.json(); assert_eq!(response.error, ClientErrorCode::InvalidRedirectUri); // Incoherent response types @@ -274,7 +274,7 @@ mod tests { let response = state.request(request).await; response.assert_status(StatusCode::BAD_REQUEST); - let response: ClientError = serde_json::from_str(response.body()).unwrap(); + let response: ClientError = response.json(); assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata); } @@ -297,7 +297,7 @@ mod tests { let response = state.request(request).await; response.assert_status(StatusCode::CREATED); - let response: ClientRegistrationResponse = serde_json::from_str(response.body()).unwrap(); + let response: ClientRegistrationResponse = response.json(); assert!(response.client_secret.is_none()); // A successful registration with client_secret based authentication should @@ -314,7 +314,7 @@ mod tests { let response = state.request(request).await; response.assert_status(StatusCode::CREATED); - let response: ClientRegistrationResponse = serde_json::from_str(response.body()).unwrap(); + let response: ClientRegistrationResponse = response.json(); assert!(response.client_secret.is_some()); } } diff --git a/crates/handlers/src/oauth2/revoke.rs b/crates/handlers/src/oauth2/revoke.rs index dd945473..c80febc2 100644 --- a/crates/handlers/src/oauth2/revoke.rs +++ b/crates/handlers/src/oauth2/revoke.rs @@ -233,8 +233,7 @@ mod tests { let response = state.request(request).await; response.assert_status(StatusCode::CREATED); - let client_registration: ClientRegistrationResponse = - serde_json::from_str(response.body()).unwrap(); + let client_registration: ClientRegistrationResponse = response.json(); let client_id = client_registration.client_id; let client_secret = client_registration.client_secret.unwrap(); @@ -313,7 +312,7 @@ mod tests { let response = state.request(request).await; response.assert_status(StatusCode::OK); - let token: AccessTokenResponse = serde_json::from_str(response.body()).unwrap(); + let token: AccessTokenResponse = response.json(); // Check that the token is valid assert!(state.is_access_token_valid(&token.access_token).await); @@ -395,7 +394,7 @@ mod tests { let response = state.request(request).await; response.assert_status(StatusCode::OK); - let token: AccessTokenResponse = serde_json::from_str(response.body()).unwrap(); + let token: AccessTokenResponse = response.json(); // Use the refresh token to get a new access token. let request = @@ -410,7 +409,7 @@ mod tests { response.assert_status(StatusCode::OK); let old_token = token; - let token: AccessTokenResponse = serde_json::from_str(response.body()).unwrap(); + let token: AccessTokenResponse = response.json(); assert!(state.is_access_token_valid(&token.access_token).await); assert!(!state.is_access_token_valid(&old_token.access_token).await); diff --git a/crates/handlers/src/test_utils.rs b/crates/handlers/src/test_utils.rs index 5736ab6c..2a2ef0b4 100644 --- a/crates/handlers/src/test_utils.rs +++ b/crates/handlers/src/test_utils.rs @@ -19,8 +19,8 @@ use axum::{ body::HttpBody, extract::{FromRef, FromRequestParts}, }; -use headers::{Authorization, ContentType, HeaderMapExt}; -use hyper::{Request, Response, StatusCode}; +use headers::{Authorization, ContentType, HeaderMapExt, HeaderName, HeaderValue}; +use hyper::{header::CONTENT_TYPE, Request, Response, StatusCode}; use mas_axum_utils::http_client_factory::HttpClientFactory; use mas_email::{MailTransport, Mailer}; use mas_keystore::{Encrypter, JsonWebKey, JsonWebKeySet, Keystore, PrivateKey}; @@ -31,7 +31,7 @@ use mas_storage_pg::PgRepository; use mas_templates::Templates; use rand::SeedableRng; use rand_chacha::ChaChaRng; -use serde::Serialize; +use serde::{de::DeserializeOwned, Serialize}; use sqlx::PgPool; use tokio::sync::Mutex; use tower::{Service, ServiceExt}; @@ -366,6 +366,22 @@ pub(crate) trait ResponseExt { /// /// Panics if the response has a different status code. fn assert_status(&self, status: StatusCode); + + /// Asserts that the response has the given header value. + /// + /// # Panics + /// + /// Panics if the response does not have the given header or if the header + /// value does not match. + fn assert_header_value(&self, header: HeaderName, value: &str); + + /// Get the response body as JSON. + /// + /// # Panics + /// + /// Panics if the response is missing the `Content-Type: application/json`, + /// or if the body is not valid JSON. + fn json(&self) -> T; } impl ResponseExt for Response { @@ -380,4 +396,26 @@ impl ResponseExt for Response { self.body() ); } + + #[track_caller] + fn assert_header_value(&self, header: HeaderName, value: &str) { + let actual_value = self + .headers() + .get(&header) + .unwrap_or_else(|| panic!("Missing header {header}")); + + assert_eq!( + actual_value, + value, + "Header mismatch: got {:?}, expected {:?}", + self.headers().get(header), + value + ); + } + + #[track_caller] + fn json(&self) -> T { + self.assert_header_value(CONTENT_TYPE, "application/json"); + serde_json::from_str(self.body()).expect("JSON deserialization failed") + } }