From 1e9ce8d6d65e714cbd681bb7f66304a3819d033e Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 22 Feb 2023 14:11:00 +0100 Subject: [PATCH] handlers: add tests for client registration --- crates/handlers/src/oauth2/registration.rs | 189 ++++++++++++++++++--- crates/handlers/src/oauth2/revoke.rs | 8 +- crates/handlers/src/test_utils.rs | 7 + 3 files changed, 173 insertions(+), 31 deletions(-) diff --git a/crates/handlers/src/oauth2/registration.rs b/crates/handlers/src/oauth2/registration.rs index b8505b2e..e66f07f2 100644 --- a/crates/handlers/src/oauth2/registration.rs +++ b/crates/handlers/src/oauth2/registration.rs @@ -37,13 +37,13 @@ pub(crate) enum RouteError { #[error(transparent)] Internal(Box), - #[error("invalid redirect uri")] - InvalidRedirectUri, + #[error(transparent)] + JsonExtract(#[from] axum::extract::rejection::JsonRejection), #[error("invalid client metadata")] - InvalidClientMetadata, + InvalidClientMetadata(#[from] ClientMetadataVerificationError), - #[error("denied by the policy")] + #[error("denied by the policy: {0:?}")] PolicyDenied(Vec), } @@ -53,18 +53,6 @@ impl_from_error_for_route!(mas_policy::InstanciateError); impl_from_error_for_route!(mas_policy::EvaluationError); impl_from_error_for_route!(mas_keystore::aead::Error); -impl From for RouteError { - fn from(e: ClientMetadataVerificationError) -> Self { - match e { - ClientMetadataVerificationError::MissingRedirectUris - | ClientMetadataVerificationError::RedirectUriWithFragment(_) => { - Self::InvalidRedirectUri - } - _ => Self::InvalidClientMetadata, - } - } -} - impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { sentry::capture_error(&self); @@ -74,17 +62,59 @@ impl IntoResponse for RouteError { Json(ClientError::from(ClientErrorCode::ServerError)), ) .into_response(), - Self::InvalidRedirectUri => ( + + // This error happens if we managed to parse the incomiong JSON but it can't be + // deserialized to the expected type. In this case we return an + // `invalid_client_metadata` error with the details of the error. + Self::JsonExtract(axum::extract::rejection::JsonRejection::JsonDataError(e)) => ( + StatusCode::BAD_REQUEST, + Json( + ClientError::from(ClientErrorCode::InvalidClientMetadata) + .with_description(e.to_string()), + ), + ) + .into_response(), + + // For all other JSON errors we return a `invalid_request` error, since this is + // probably due to a malformed request. + Self::JsonExtract(_) => ( + StatusCode::BAD_REQUEST, + Json(ClientError::from(ClientErrorCode::InvalidRequest)), + ) + .into_response(), + + // This error comes from the `ClientMetadata::validate` method. We return an + // `invalid_redirect_uri` error if the error is related to the redirect URIs, else we + // return an `invalid_client_metadata` error. + Self::InvalidClientMetadata( + ClientMetadataVerificationError::MissingRedirectUris + | ClientMetadataVerificationError::RedirectUriWithFragment(_), + ) => ( StatusCode::BAD_REQUEST, Json(ClientError::from(ClientErrorCode::InvalidRedirectUri)), ) .into_response(), - Self::InvalidClientMetadata => ( + + Self::InvalidClientMetadata(e) => ( StatusCode::BAD_REQUEST, - Json(ClientError::from(ClientErrorCode::InvalidClientMetadata)), + Json( + ClientError::from(ClientErrorCode::InvalidClientMetadata) + .with_description(e.to_string()), + ), ) .into_response(), + + // For policy violations, we return an `invalid_client_metadata` error with the details + // of the violations in most cases. If a violation includes `redirect_uri` in the + // message, we return an `invalid_redirect_uri` error instead. Self::PolicyDenied(violations) => { + // TODO: detect them better + let code = if violations.iter().any(|v| v.msg.contains("redirect_uri")) { + ClientErrorCode::InvalidRedirectUri + } else { + ClientErrorCode::InvalidClientMetadata + }; + let collected = &violations .iter() .map(|v| v.msg.clone()) @@ -92,11 +122,8 @@ impl IntoResponse for RouteError { let joined = collected.join("; "); ( - StatusCode::UNAUTHORIZED, - Json( - ClientError::from(ClientErrorCode::InvalidClientMetadata) - .with_description(joined), - ), + StatusCode::BAD_REQUEST, + Json(ClientError::from(code).with_description(joined)), ) .into_response() } @@ -111,8 +138,11 @@ pub(crate) async fn post( mut repo: BoxRepository, State(policy_factory): State>, State(encrypter): State, - Json(body): Json, + body: Result, axum::extract::rejection::JsonRejection>, ) -> Result { + // Propagate any JSON extraction error + let Json(body) = body?; + info!(?body, "Client registration"); // Validate the body @@ -179,3 +209,112 @@ pub(crate) async fn post( Ok((StatusCode::CREATED, Json(response))) } + +#[cfg(test)] +mod tests { + use hyper::{Request, StatusCode}; + use mas_router::SimpleRoute; + use oauth2_types::{ + errors::{ClientError, ClientErrorCode}, + registration::ClientRegistrationResponse, + }; + use sqlx::PgPool; + + use crate::test_utils::{init_tracing, RequestBuilderExt, ResponseExt, TestState}; + + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_registration_error(pool: PgPool) { + init_tracing(); + let state = TestState::from_pool(pool).await.unwrap(); + + // Body is not a JSON + let request = Request::post(mas_router::OAuth2RegistrationEndpoint::PATH) + .body("this is not a json".to_owned()) + .unwrap(); + + let response = state.request(request).await; + response.assert_status(StatusCode::BAD_REQUEST); + let response: ClientError = serde_json::from_str(response.body()).unwrap(); + assert_eq!(response.error, ClientErrorCode::InvalidRequest); + + // Invalid client metadata + let request = + Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({ + "client_uri": "this is not a uri", + })); + + let response = state.request(request).await; + response.assert_status(StatusCode::BAD_REQUEST); + let response: ClientError = serde_json::from_str(response.body()).unwrap(); + assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata); + + // Invalid redirect URI + let request = + Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({ + "application_type": "web", + "contacts": ["hello@example.com"], + "client_uri": "https://example.com/", + "redirect_uris": ["http://this-is-insecure.com/"], + })); + + let response = state.request(request).await; + response.assert_status(StatusCode::BAD_REQUEST); + let response: ClientError = serde_json::from_str(response.body()).unwrap(); + assert_eq!(response.error, ClientErrorCode::InvalidRedirectUri); + + // Incoherent response types + let request = + Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({ + "contacts": ["hello@example.com"], + "client_uri": "https://example.com/", + "redirect_uris": ["https://example.com/"], + "response_types": ["id_token"], + "grant_types": ["authorization_code"], + })); + + let response = state.request(request).await; + response.assert_status(StatusCode::BAD_REQUEST); + let response: ClientError = serde_json::from_str(response.body()).unwrap(); + assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata); + } + + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_registration(pool: PgPool) { + init_tracing(); + let state = TestState::from_pool(pool).await.unwrap(); + + // A successful registration with no authentication should not return a client + // secret + let request = + Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({ + "contacts": ["hello@example.com"], + "client_uri": "https://example.com/", + "redirect_uris": ["https://example.com/"], + "response_types": ["code"], + "grant_types": ["authorization_code"], + "token_endpoint_auth_method": "none", + })); + + let response = state.request(request).await; + response.assert_status(StatusCode::CREATED); + let response: ClientRegistrationResponse = serde_json::from_str(response.body()).unwrap(); + assert!(response.client_secret.is_none()); + + // A successful registration with client_secret based authentication should + // return a client secret + let request = + Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({ + "contacts": ["hello@example.com"], + "client_uri": "https://example.com/", + "redirect_uris": ["https://example.com/"], + "response_types": ["code"], + "grant_types": ["authorization_code"], + "token_endpoint_auth_method": "client_secret_basic", + })); + + let response = state.request(request).await; + response.assert_status(StatusCode::CREATED); + let response: ClientRegistrationResponse = serde_json::from_str(response.body()).unwrap(); + assert!(response.client_secret.is_some()); + } +} diff --git a/crates/handlers/src/oauth2/revoke.rs b/crates/handlers/src/oauth2/revoke.rs index 3758dc1f..dd945473 100644 --- a/crates/handlers/src/oauth2/revoke.rs +++ b/crates/handlers/src/oauth2/revoke.rs @@ -213,15 +213,11 @@ mod tests { use sqlx::PgPool; use super::*; - use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState}; + use crate::test_utils::{init_tracing, RequestBuilderExt, ResponseExt, TestState}; #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] async fn test_revoke_access_token(pool: PgPool) { - tracing_subscriber::fmt() - .with_max_level(tracing::Level::INFO) - .with_test_writer() - .init(); - + init_tracing(); let state = TestState::from_pool(pool).await.unwrap(); let request = diff --git a/crates/handlers/src/test_utils.rs b/crates/handlers/src/test_utils.rs index 3bfbfe7b..5736ab6c 100644 --- a/crates/handlers/src/test_utils.rs +++ b/crates/handlers/src/test_utils.rs @@ -43,6 +43,13 @@ use crate::{ MatrixHomeserver, }; +pub(crate) fn init_tracing() { + let _ = tracing_subscriber::fmt() + .with_max_level(tracing::Level::INFO) + .with_test_writer() + .try_init(); +} + #[derive(Clone)] pub(crate) struct TestState { pub pool: PgPool,