You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
handlers: add tests for client registration
This commit is contained in:
@ -37,13 +37,13 @@ pub(crate) enum RouteError {
|
||||
#[error(transparent)]
|
||||
Internal(Box<dyn std::error::Error + Send + Sync>),
|
||||
|
||||
#[error("invalid redirect uri")]
|
||||
InvalidRedirectUri,
|
||||
#[error(transparent)]
|
||||
JsonExtract(#[from] axum::extract::rejection::JsonRejection),
|
||||
|
||||
#[error("invalid client metadata")]
|
||||
InvalidClientMetadata,
|
||||
InvalidClientMetadata(#[from] ClientMetadataVerificationError),
|
||||
|
||||
#[error("denied by the policy")]
|
||||
#[error("denied by the policy: {0:?}")]
|
||||
PolicyDenied(Vec<Violation>),
|
||||
}
|
||||
|
||||
@ -53,18 +53,6 @@ impl_from_error_for_route!(mas_policy::InstanciateError);
|
||||
impl_from_error_for_route!(mas_policy::EvaluationError);
|
||||
impl_from_error_for_route!(mas_keystore::aead::Error);
|
||||
|
||||
impl From<ClientMetadataVerificationError> for RouteError {
|
||||
fn from(e: ClientMetadataVerificationError) -> Self {
|
||||
match e {
|
||||
ClientMetadataVerificationError::MissingRedirectUris
|
||||
| ClientMetadataVerificationError::RedirectUriWithFragment(_) => {
|
||||
Self::InvalidRedirectUri
|
||||
}
|
||||
_ => Self::InvalidClientMetadata,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
sentry::capture_error(&self);
|
||||
@ -74,17 +62,59 @@ impl IntoResponse for RouteError {
|
||||
Json(ClientError::from(ClientErrorCode::ServerError)),
|
||||
)
|
||||
.into_response(),
|
||||
Self::InvalidRedirectUri => (
|
||||
|
||||
// This error happens if we managed to parse the incomiong JSON but it can't be
|
||||
// deserialized to the expected type. In this case we return an
|
||||
// `invalid_client_metadata` error with the details of the error.
|
||||
Self::JsonExtract(axum::extract::rejection::JsonRejection::JsonDataError(e)) => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(
|
||||
ClientError::from(ClientErrorCode::InvalidClientMetadata)
|
||||
.with_description(e.to_string()),
|
||||
),
|
||||
)
|
||||
.into_response(),
|
||||
|
||||
// For all other JSON errors we return a `invalid_request` error, since this is
|
||||
// probably due to a malformed request.
|
||||
Self::JsonExtract(_) => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(ClientError::from(ClientErrorCode::InvalidRequest)),
|
||||
)
|
||||
.into_response(),
|
||||
|
||||
// This error comes from the `ClientMetadata::validate` method. We return an
|
||||
// `invalid_redirect_uri` error if the error is related to the redirect URIs, else we
|
||||
// return an `invalid_client_metadata` error.
|
||||
Self::InvalidClientMetadata(
|
||||
ClientMetadataVerificationError::MissingRedirectUris
|
||||
| ClientMetadataVerificationError::RedirectUriWithFragment(_),
|
||||
) => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(ClientError::from(ClientErrorCode::InvalidRedirectUri)),
|
||||
)
|
||||
.into_response(),
|
||||
Self::InvalidClientMetadata => (
|
||||
|
||||
Self::InvalidClientMetadata(e) => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(ClientError::from(ClientErrorCode::InvalidClientMetadata)),
|
||||
Json(
|
||||
ClientError::from(ClientErrorCode::InvalidClientMetadata)
|
||||
.with_description(e.to_string()),
|
||||
),
|
||||
)
|
||||
.into_response(),
|
||||
|
||||
// For policy violations, we return an `invalid_client_metadata` error with the details
|
||||
// of the violations in most cases. If a violation includes `redirect_uri` in the
|
||||
// message, we return an `invalid_redirect_uri` error instead.
|
||||
Self::PolicyDenied(violations) => {
|
||||
// TODO: detect them better
|
||||
let code = if violations.iter().any(|v| v.msg.contains("redirect_uri")) {
|
||||
ClientErrorCode::InvalidRedirectUri
|
||||
} else {
|
||||
ClientErrorCode::InvalidClientMetadata
|
||||
};
|
||||
|
||||
let collected = &violations
|
||||
.iter()
|
||||
.map(|v| v.msg.clone())
|
||||
@ -92,11 +122,8 @@ impl IntoResponse for RouteError {
|
||||
let joined = collected.join("; ");
|
||||
|
||||
(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(
|
||||
ClientError::from(ClientErrorCode::InvalidClientMetadata)
|
||||
.with_description(joined),
|
||||
),
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(ClientError::from(code).with_description(joined)),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
@ -111,8 +138,11 @@ pub(crate) async fn post(
|
||||
mut repo: BoxRepository,
|
||||
State(policy_factory): State<Arc<PolicyFactory>>,
|
||||
State(encrypter): State<Encrypter>,
|
||||
Json(body): Json<ClientMetadata>,
|
||||
body: Result<Json<ClientMetadata>, axum::extract::rejection::JsonRejection>,
|
||||
) -> Result<impl IntoResponse, RouteError> {
|
||||
// Propagate any JSON extraction error
|
||||
let Json(body) = body?;
|
||||
|
||||
info!(?body, "Client registration");
|
||||
|
||||
// Validate the body
|
||||
@ -179,3 +209,112 @@ pub(crate) async fn post(
|
||||
|
||||
Ok((StatusCode::CREATED, Json(response)))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use hyper::{Request, StatusCode};
|
||||
use mas_router::SimpleRoute;
|
||||
use oauth2_types::{
|
||||
errors::{ClientError, ClientErrorCode},
|
||||
registration::ClientRegistrationResponse,
|
||||
};
|
||||
use sqlx::PgPool;
|
||||
|
||||
use crate::test_utils::{init_tracing, RequestBuilderExt, ResponseExt, TestState};
|
||||
|
||||
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
|
||||
async fn test_registration_error(pool: PgPool) {
|
||||
init_tracing();
|
||||
let state = TestState::from_pool(pool).await.unwrap();
|
||||
|
||||
// Body is not a JSON
|
||||
let request = Request::post(mas_router::OAuth2RegistrationEndpoint::PATH)
|
||||
.body("this is not a json".to_owned())
|
||||
.unwrap();
|
||||
|
||||
let response = state.request(request).await;
|
||||
response.assert_status(StatusCode::BAD_REQUEST);
|
||||
let response: ClientError = serde_json::from_str(response.body()).unwrap();
|
||||
assert_eq!(response.error, ClientErrorCode::InvalidRequest);
|
||||
|
||||
// Invalid client metadata
|
||||
let request =
|
||||
Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
|
||||
"client_uri": "this is not a uri",
|
||||
}));
|
||||
|
||||
let response = state.request(request).await;
|
||||
response.assert_status(StatusCode::BAD_REQUEST);
|
||||
let response: ClientError = serde_json::from_str(response.body()).unwrap();
|
||||
assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);
|
||||
|
||||
// Invalid redirect URI
|
||||
let request =
|
||||
Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
|
||||
"application_type": "web",
|
||||
"contacts": ["hello@example.com"],
|
||||
"client_uri": "https://example.com/",
|
||||
"redirect_uris": ["http://this-is-insecure.com/"],
|
||||
}));
|
||||
|
||||
let response = state.request(request).await;
|
||||
response.assert_status(StatusCode::BAD_REQUEST);
|
||||
let response: ClientError = serde_json::from_str(response.body()).unwrap();
|
||||
assert_eq!(response.error, ClientErrorCode::InvalidRedirectUri);
|
||||
|
||||
// Incoherent response types
|
||||
let request =
|
||||
Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
|
||||
"contacts": ["hello@example.com"],
|
||||
"client_uri": "https://example.com/",
|
||||
"redirect_uris": ["https://example.com/"],
|
||||
"response_types": ["id_token"],
|
||||
"grant_types": ["authorization_code"],
|
||||
}));
|
||||
|
||||
let response = state.request(request).await;
|
||||
response.assert_status(StatusCode::BAD_REQUEST);
|
||||
let response: ClientError = serde_json::from_str(response.body()).unwrap();
|
||||
assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);
|
||||
}
|
||||
|
||||
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
|
||||
async fn test_registration(pool: PgPool) {
|
||||
init_tracing();
|
||||
let state = TestState::from_pool(pool).await.unwrap();
|
||||
|
||||
// A successful registration with no authentication should not return a client
|
||||
// secret
|
||||
let request =
|
||||
Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
|
||||
"contacts": ["hello@example.com"],
|
||||
"client_uri": "https://example.com/",
|
||||
"redirect_uris": ["https://example.com/"],
|
||||
"response_types": ["code"],
|
||||
"grant_types": ["authorization_code"],
|
||||
"token_endpoint_auth_method": "none",
|
||||
}));
|
||||
|
||||
let response = state.request(request).await;
|
||||
response.assert_status(StatusCode::CREATED);
|
||||
let response: ClientRegistrationResponse = serde_json::from_str(response.body()).unwrap();
|
||||
assert!(response.client_secret.is_none());
|
||||
|
||||
// A successful registration with client_secret based authentication should
|
||||
// return a client secret
|
||||
let request =
|
||||
Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
|
||||
"contacts": ["hello@example.com"],
|
||||
"client_uri": "https://example.com/",
|
||||
"redirect_uris": ["https://example.com/"],
|
||||
"response_types": ["code"],
|
||||
"grant_types": ["authorization_code"],
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
}));
|
||||
|
||||
let response = state.request(request).await;
|
||||
response.assert_status(StatusCode::CREATED);
|
||||
let response: ClientRegistrationResponse = serde_json::from_str(response.body()).unwrap();
|
||||
assert!(response.client_secret.is_some());
|
||||
}
|
||||
}
|
||||
|
@ -213,15 +213,11 @@ mod tests {
|
||||
use sqlx::PgPool;
|
||||
|
||||
use super::*;
|
||||
use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState};
|
||||
use crate::test_utils::{init_tracing, RequestBuilderExt, ResponseExt, TestState};
|
||||
|
||||
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
|
||||
async fn test_revoke_access_token(pool: PgPool) {
|
||||
tracing_subscriber::fmt()
|
||||
.with_max_level(tracing::Level::INFO)
|
||||
.with_test_writer()
|
||||
.init();
|
||||
|
||||
init_tracing();
|
||||
let state = TestState::from_pool(pool).await.unwrap();
|
||||
|
||||
let request =
|
||||
|
@ -43,6 +43,13 @@ use crate::{
|
||||
MatrixHomeserver,
|
||||
};
|
||||
|
||||
pub(crate) fn init_tracing() {
|
||||
let _ = tracing_subscriber::fmt()
|
||||
.with_max_level(tracing::Level::INFO)
|
||||
.with_test_writer()
|
||||
.try_init();
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct TestState {
|
||||
pub pool: PgPool,
|
||||
|
Reference in New Issue
Block a user