You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-31 09:24:31 +03:00
handlers: add a test for OIDC discovery
This commit is contained in:
@ -153,3 +153,27 @@ pub(crate) async fn get(
|
|||||||
|
|
||||||
Json(metadata)
|
Json(metadata)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use hyper::{Request, StatusCode};
|
||||||
|
use oauth2_types::oidc::ProviderMetadata;
|
||||||
|
use sqlx::PgPool;
|
||||||
|
|
||||||
|
use crate::test_utils::{init_tracing, RequestBuilderExt, ResponseExt, TestState};
|
||||||
|
|
||||||
|
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
|
||||||
|
async fn test_valid_discovery_metadata(pool: PgPool) {
|
||||||
|
init_tracing();
|
||||||
|
let state = TestState::from_pool(pool).await.unwrap();
|
||||||
|
|
||||||
|
let request = Request::get("/.well-known/openid-configuration").empty();
|
||||||
|
let response = state.request(request).await;
|
||||||
|
response.assert_status(StatusCode::OK);
|
||||||
|
|
||||||
|
let metadata: ProviderMetadata = response.json();
|
||||||
|
metadata
|
||||||
|
.validate(state.url_builder.oidc_issuer().as_str())
|
||||||
|
.expect("Invalid metadata");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -234,7 +234,7 @@ mod tests {
|
|||||||
|
|
||||||
let response = state.request(request).await;
|
let response = state.request(request).await;
|
||||||
response.assert_status(StatusCode::BAD_REQUEST);
|
response.assert_status(StatusCode::BAD_REQUEST);
|
||||||
let response: ClientError = serde_json::from_str(response.body()).unwrap();
|
let response: ClientError = response.json();
|
||||||
assert_eq!(response.error, ClientErrorCode::InvalidRequest);
|
assert_eq!(response.error, ClientErrorCode::InvalidRequest);
|
||||||
|
|
||||||
// Invalid client metadata
|
// Invalid client metadata
|
||||||
@ -245,7 +245,7 @@ mod tests {
|
|||||||
|
|
||||||
let response = state.request(request).await;
|
let response = state.request(request).await;
|
||||||
response.assert_status(StatusCode::BAD_REQUEST);
|
response.assert_status(StatusCode::BAD_REQUEST);
|
||||||
let response: ClientError = serde_json::from_str(response.body()).unwrap();
|
let response: ClientError = response.json();
|
||||||
assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);
|
assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);
|
||||||
|
|
||||||
// Invalid redirect URI
|
// Invalid redirect URI
|
||||||
@ -259,7 +259,7 @@ mod tests {
|
|||||||
|
|
||||||
let response = state.request(request).await;
|
let response = state.request(request).await;
|
||||||
response.assert_status(StatusCode::BAD_REQUEST);
|
response.assert_status(StatusCode::BAD_REQUEST);
|
||||||
let response: ClientError = serde_json::from_str(response.body()).unwrap();
|
let response: ClientError = response.json();
|
||||||
assert_eq!(response.error, ClientErrorCode::InvalidRedirectUri);
|
assert_eq!(response.error, ClientErrorCode::InvalidRedirectUri);
|
||||||
|
|
||||||
// Incoherent response types
|
// Incoherent response types
|
||||||
@ -274,7 +274,7 @@ mod tests {
|
|||||||
|
|
||||||
let response = state.request(request).await;
|
let response = state.request(request).await;
|
||||||
response.assert_status(StatusCode::BAD_REQUEST);
|
response.assert_status(StatusCode::BAD_REQUEST);
|
||||||
let response: ClientError = serde_json::from_str(response.body()).unwrap();
|
let response: ClientError = response.json();
|
||||||
assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);
|
assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -297,7 +297,7 @@ mod tests {
|
|||||||
|
|
||||||
let response = state.request(request).await;
|
let response = state.request(request).await;
|
||||||
response.assert_status(StatusCode::CREATED);
|
response.assert_status(StatusCode::CREATED);
|
||||||
let response: ClientRegistrationResponse = serde_json::from_str(response.body()).unwrap();
|
let response: ClientRegistrationResponse = response.json();
|
||||||
assert!(response.client_secret.is_none());
|
assert!(response.client_secret.is_none());
|
||||||
|
|
||||||
// A successful registration with client_secret based authentication should
|
// A successful registration with client_secret based authentication should
|
||||||
@ -314,7 +314,7 @@ mod tests {
|
|||||||
|
|
||||||
let response = state.request(request).await;
|
let response = state.request(request).await;
|
||||||
response.assert_status(StatusCode::CREATED);
|
response.assert_status(StatusCode::CREATED);
|
||||||
let response: ClientRegistrationResponse = serde_json::from_str(response.body()).unwrap();
|
let response: ClientRegistrationResponse = response.json();
|
||||||
assert!(response.client_secret.is_some());
|
assert!(response.client_secret.is_some());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -233,8 +233,7 @@ mod tests {
|
|||||||
let response = state.request(request).await;
|
let response = state.request(request).await;
|
||||||
response.assert_status(StatusCode::CREATED);
|
response.assert_status(StatusCode::CREATED);
|
||||||
|
|
||||||
let client_registration: ClientRegistrationResponse =
|
let client_registration: ClientRegistrationResponse = response.json();
|
||||||
serde_json::from_str(response.body()).unwrap();
|
|
||||||
|
|
||||||
let client_id = client_registration.client_id;
|
let client_id = client_registration.client_id;
|
||||||
let client_secret = client_registration.client_secret.unwrap();
|
let client_secret = client_registration.client_secret.unwrap();
|
||||||
@ -313,7 +312,7 @@ mod tests {
|
|||||||
let response = state.request(request).await;
|
let response = state.request(request).await;
|
||||||
response.assert_status(StatusCode::OK);
|
response.assert_status(StatusCode::OK);
|
||||||
|
|
||||||
let token: AccessTokenResponse = serde_json::from_str(response.body()).unwrap();
|
let token: AccessTokenResponse = response.json();
|
||||||
|
|
||||||
// Check that the token is valid
|
// Check that the token is valid
|
||||||
assert!(state.is_access_token_valid(&token.access_token).await);
|
assert!(state.is_access_token_valid(&token.access_token).await);
|
||||||
@ -395,7 +394,7 @@ mod tests {
|
|||||||
let response = state.request(request).await;
|
let response = state.request(request).await;
|
||||||
response.assert_status(StatusCode::OK);
|
response.assert_status(StatusCode::OK);
|
||||||
|
|
||||||
let token: AccessTokenResponse = serde_json::from_str(response.body()).unwrap();
|
let token: AccessTokenResponse = response.json();
|
||||||
|
|
||||||
// Use the refresh token to get a new access token.
|
// Use the refresh token to get a new access token.
|
||||||
let request =
|
let request =
|
||||||
@ -410,7 +409,7 @@ mod tests {
|
|||||||
response.assert_status(StatusCode::OK);
|
response.assert_status(StatusCode::OK);
|
||||||
|
|
||||||
let old_token = token;
|
let old_token = token;
|
||||||
let token: AccessTokenResponse = serde_json::from_str(response.body()).unwrap();
|
let token: AccessTokenResponse = response.json();
|
||||||
assert!(state.is_access_token_valid(&token.access_token).await);
|
assert!(state.is_access_token_valid(&token.access_token).await);
|
||||||
assert!(!state.is_access_token_valid(&old_token.access_token).await);
|
assert!(!state.is_access_token_valid(&old_token.access_token).await);
|
||||||
|
|
||||||
|
@ -19,8 +19,8 @@ use axum::{
|
|||||||
body::HttpBody,
|
body::HttpBody,
|
||||||
extract::{FromRef, FromRequestParts},
|
extract::{FromRef, FromRequestParts},
|
||||||
};
|
};
|
||||||
use headers::{Authorization, ContentType, HeaderMapExt};
|
use headers::{Authorization, ContentType, HeaderMapExt, HeaderName, HeaderValue};
|
||||||
use hyper::{Request, Response, StatusCode};
|
use hyper::{header::CONTENT_TYPE, Request, Response, StatusCode};
|
||||||
use mas_axum_utils::http_client_factory::HttpClientFactory;
|
use mas_axum_utils::http_client_factory::HttpClientFactory;
|
||||||
use mas_email::{MailTransport, Mailer};
|
use mas_email::{MailTransport, Mailer};
|
||||||
use mas_keystore::{Encrypter, JsonWebKey, JsonWebKeySet, Keystore, PrivateKey};
|
use mas_keystore::{Encrypter, JsonWebKey, JsonWebKeySet, Keystore, PrivateKey};
|
||||||
@ -31,7 +31,7 @@ use mas_storage_pg::PgRepository;
|
|||||||
use mas_templates::Templates;
|
use mas_templates::Templates;
|
||||||
use rand::SeedableRng;
|
use rand::SeedableRng;
|
||||||
use rand_chacha::ChaChaRng;
|
use rand_chacha::ChaChaRng;
|
||||||
use serde::Serialize;
|
use serde::{de::DeserializeOwned, Serialize};
|
||||||
use sqlx::PgPool;
|
use sqlx::PgPool;
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
use tower::{Service, ServiceExt};
|
use tower::{Service, ServiceExt};
|
||||||
@ -366,6 +366,22 @@ pub(crate) trait ResponseExt {
|
|||||||
///
|
///
|
||||||
/// Panics if the response has a different status code.
|
/// Panics if the response has a different status code.
|
||||||
fn assert_status(&self, status: StatusCode);
|
fn assert_status(&self, status: StatusCode);
|
||||||
|
|
||||||
|
/// Asserts that the response has the given header value.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if the response does not have the given header or if the header
|
||||||
|
/// value does not match.
|
||||||
|
fn assert_header_value(&self, header: HeaderName, value: &str);
|
||||||
|
|
||||||
|
/// Get the response body as JSON.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if the response is missing the `Content-Type: application/json`,
|
||||||
|
/// or if the body is not valid JSON.
|
||||||
|
fn json<T: DeserializeOwned>(&self) -> T;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ResponseExt for Response<String> {
|
impl ResponseExt for Response<String> {
|
||||||
@ -380,4 +396,26 @@ impl ResponseExt for Response<String> {
|
|||||||
self.body()
|
self.body()
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[track_caller]
|
||||||
|
fn assert_header_value(&self, header: HeaderName, value: &str) {
|
||||||
|
let actual_value = self
|
||||||
|
.headers()
|
||||||
|
.get(&header)
|
||||||
|
.unwrap_or_else(|| panic!("Missing header {header}"));
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
actual_value,
|
||||||
|
value,
|
||||||
|
"Header mismatch: got {:?}, expected {:?}",
|
||||||
|
self.headers().get(header),
|
||||||
|
value
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[track_caller]
|
||||||
|
fn json<T: DeserializeOwned>(&self) -> T {
|
||||||
|
self.assert_header_value(CONTENT_TYPE, "application/json");
|
||||||
|
serde_json::from_str(self.body()).expect("JSON deserialization failed")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user