From 0e21f00d17d6cee843d2f146ad3b2b640262f433 Mon Sep 17 00:00:00 2001 From: Hugh Nimmo-Smith Date: Fri, 8 Jul 2022 22:11:54 +0100 Subject: [PATCH] Return reason for invalid_client_metadata in HTTP response (#298) --- crates/handlers/src/oauth2/registration.rs | 52 ++++++++++++++++++---- 1 file changed, 44 insertions(+), 8 deletions(-) diff --git a/crates/handlers/src/oauth2/registration.rs b/crates/handlers/src/oauth2/registration.rs index 18bcef0d..68fcff91 100644 --- a/crates/handlers/src/oauth2/registration.rs +++ b/crates/handlers/src/oauth2/registration.rs @@ -17,7 +17,7 @@ use std::sync::Arc; use axum::{response::IntoResponse, Extension, Json}; use hyper::StatusCode; use mas_iana::oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod}; -use mas_policy::PolicyFactory; +use mas_policy::{PolicyFactory, Violation}; use mas_storage::oauth2::client::insert_client; use oauth2_types::{ errors::{INVALID_CLIENT_METADATA, INVALID_REDIRECT_URI, SERVER_ERROR}, @@ -44,7 +44,7 @@ pub(crate) enum RouteError { InvalidClientMetadata, #[error("denied by the policy")] - PolicyDenied, + PolicyDenied(Vec), } impl From for RouteError { @@ -53,17 +53,53 @@ impl From for RouteError { } } +// TODO: there is probably a better way to do achieve this. ClientError only +// works for static strings +#[derive(serde::Serialize)] +struct PolicyError { + error: String, + error_description: String, +} + +impl PolicyError { + #[must_use] + pub const fn new(error: String, error_description: String) -> Self { + Self { + error, + error_description, + } + } +} + impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { match self { Self::Internal(_) | Self::Anyhow(_) => { - (StatusCode::INTERNAL_SERVER_ERROR, Json(SERVER_ERROR)) + (StatusCode::INTERNAL_SERVER_ERROR, Json(SERVER_ERROR)).into_response() + } + Self::InvalidRedirectUri => { + (StatusCode::BAD_REQUEST, Json(INVALID_REDIRECT_URI)).into_response() + } + Self::InvalidClientMetadata => { + (StatusCode::BAD_REQUEST, Json(INVALID_CLIENT_METADATA)).into_response() + } + Self::PolicyDenied(violations) => { + let collected = &violations + .iter() + .map(|v| v.msg.clone()) + .collect::>(); + let joined = collected.join("; "); + + ( + StatusCode::UNAUTHORIZED, + Json(PolicyError::new( + "invalid_client_metadata".to_string(), + joined, + )), + ) + .into_response() } - Self::InvalidRedirectUri => (StatusCode::BAD_REQUEST, Json(INVALID_REDIRECT_URI)), - Self::InvalidClientMetadata => (StatusCode::BAD_REQUEST, Json(INVALID_CLIENT_METADATA)), - Self::PolicyDenied => (StatusCode::UNAUTHORIZED, Json(INVALID_CLIENT_METADATA)), } - .into_response() } } @@ -121,7 +157,7 @@ pub(crate) async fn post( let mut policy = policy_factory.instantiate().await?; let res = policy.evaluate_client_registration(&body).await?; if !res.valid() { - return Err(RouteError::PolicyDenied); + return Err(RouteError::PolicyDenied(res.violations)); } // Grab a txn