1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

handlers: add tests for client registration

This commit is contained in:
Quentin Gliech
2023-02-22 14:11:00 +01:00
parent 304ec10d1b
commit 1e9ce8d6d6
3 changed files with 173 additions and 31 deletions

View File

@ -37,13 +37,13 @@ pub(crate) enum RouteError {
#[error(transparent)]
Internal(Box<dyn std::error::Error + Send + Sync>),
#[error("invalid redirect uri")]
InvalidRedirectUri,
#[error(transparent)]
JsonExtract(#[from] axum::extract::rejection::JsonRejection),
#[error("invalid client metadata")]
InvalidClientMetadata,
InvalidClientMetadata(#[from] ClientMetadataVerificationError),
#[error("denied by the policy")]
#[error("denied by the policy: {0:?}")]
PolicyDenied(Vec<Violation>),
}
@ -53,18 +53,6 @@ impl_from_error_for_route!(mas_policy::InstanciateError);
impl_from_error_for_route!(mas_policy::EvaluationError);
impl_from_error_for_route!(mas_keystore::aead::Error);
impl From<ClientMetadataVerificationError> for RouteError {
fn from(e: ClientMetadataVerificationError) -> Self {
match e {
ClientMetadataVerificationError::MissingRedirectUris
| ClientMetadataVerificationError::RedirectUriWithFragment(_) => {
Self::InvalidRedirectUri
}
_ => Self::InvalidClientMetadata,
}
}
}
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
sentry::capture_error(&self);
@ -74,17 +62,59 @@ impl IntoResponse for RouteError {
Json(ClientError::from(ClientErrorCode::ServerError)),
)
.into_response(),
Self::InvalidRedirectUri => (
// This error happens if we managed to parse the incomiong JSON but it can't be
// deserialized to the expected type. In this case we return an
// `invalid_client_metadata` error with the details of the error.
Self::JsonExtract(axum::extract::rejection::JsonRejection::JsonDataError(e)) => (
StatusCode::BAD_REQUEST,
Json(
ClientError::from(ClientErrorCode::InvalidClientMetadata)
.with_description(e.to_string()),
),
)
.into_response(),
// For all other JSON errors we return a `invalid_request` error, since this is
// probably due to a malformed request.
Self::JsonExtract(_) => (
StatusCode::BAD_REQUEST,
Json(ClientError::from(ClientErrorCode::InvalidRequest)),
)
.into_response(),
// This error comes from the `ClientMetadata::validate` method. We return an
// `invalid_redirect_uri` error if the error is related to the redirect URIs, else we
// return an `invalid_client_metadata` error.
Self::InvalidClientMetadata(
ClientMetadataVerificationError::MissingRedirectUris
| ClientMetadataVerificationError::RedirectUriWithFragment(_),
) => (
StatusCode::BAD_REQUEST,
Json(ClientError::from(ClientErrorCode::InvalidRedirectUri)),
)
.into_response(),
Self::InvalidClientMetadata => (
Self::InvalidClientMetadata(e) => (
StatusCode::BAD_REQUEST,
Json(ClientError::from(ClientErrorCode::InvalidClientMetadata)),
Json(
ClientError::from(ClientErrorCode::InvalidClientMetadata)
.with_description(e.to_string()),
),
)
.into_response(),
// For policy violations, we return an `invalid_client_metadata` error with the details
// of the violations in most cases. If a violation includes `redirect_uri` in the
// message, we return an `invalid_redirect_uri` error instead.
Self::PolicyDenied(violations) => {
// TODO: detect them better
let code = if violations.iter().any(|v| v.msg.contains("redirect_uri")) {
ClientErrorCode::InvalidRedirectUri
} else {
ClientErrorCode::InvalidClientMetadata
};
let collected = &violations
.iter()
.map(|v| v.msg.clone())
@ -92,11 +122,8 @@ impl IntoResponse for RouteError {
let joined = collected.join("; ");
(
StatusCode::UNAUTHORIZED,
Json(
ClientError::from(ClientErrorCode::InvalidClientMetadata)
.with_description(joined),
),
StatusCode::BAD_REQUEST,
Json(ClientError::from(code).with_description(joined)),
)
.into_response()
}
@ -111,8 +138,11 @@ pub(crate) async fn post(
mut repo: BoxRepository,
State(policy_factory): State<Arc<PolicyFactory>>,
State(encrypter): State<Encrypter>,
Json(body): Json<ClientMetadata>,
body: Result<Json<ClientMetadata>, axum::extract::rejection::JsonRejection>,
) -> Result<impl IntoResponse, RouteError> {
// Propagate any JSON extraction error
let Json(body) = body?;
info!(?body, "Client registration");
// Validate the body
@ -179,3 +209,112 @@ pub(crate) async fn post(
Ok((StatusCode::CREATED, Json(response)))
}
#[cfg(test)]
mod tests {
use hyper::{Request, StatusCode};
use mas_router::SimpleRoute;
use oauth2_types::{
errors::{ClientError, ClientErrorCode},
registration::ClientRegistrationResponse,
};
use sqlx::PgPool;
use crate::test_utils::{init_tracing, RequestBuilderExt, ResponseExt, TestState};
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_registration_error(pool: PgPool) {
init_tracing();
let state = TestState::from_pool(pool).await.unwrap();
// Body is not a JSON
let request = Request::post(mas_router::OAuth2RegistrationEndpoint::PATH)
.body("this is not a json".to_owned())
.unwrap();
let response = state.request(request).await;
response.assert_status(StatusCode::BAD_REQUEST);
let response: ClientError = serde_json::from_str(response.body()).unwrap();
assert_eq!(response.error, ClientErrorCode::InvalidRequest);
// Invalid client metadata
let request =
Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
"client_uri": "this is not a uri",
}));
let response = state.request(request).await;
response.assert_status(StatusCode::BAD_REQUEST);
let response: ClientError = serde_json::from_str(response.body()).unwrap();
assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);
// Invalid redirect URI
let request =
Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
"application_type": "web",
"contacts": ["hello@example.com"],
"client_uri": "https://example.com/",
"redirect_uris": ["http://this-is-insecure.com/"],
}));
let response = state.request(request).await;
response.assert_status(StatusCode::BAD_REQUEST);
let response: ClientError = serde_json::from_str(response.body()).unwrap();
assert_eq!(response.error, ClientErrorCode::InvalidRedirectUri);
// Incoherent response types
let request =
Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
"contacts": ["hello@example.com"],
"client_uri": "https://example.com/",
"redirect_uris": ["https://example.com/"],
"response_types": ["id_token"],
"grant_types": ["authorization_code"],
}));
let response = state.request(request).await;
response.assert_status(StatusCode::BAD_REQUEST);
let response: ClientError = serde_json::from_str(response.body()).unwrap();
assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);
}
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_registration(pool: PgPool) {
init_tracing();
let state = TestState::from_pool(pool).await.unwrap();
// A successful registration with no authentication should not return a client
// secret
let request =
Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
"contacts": ["hello@example.com"],
"client_uri": "https://example.com/",
"redirect_uris": ["https://example.com/"],
"response_types": ["code"],
"grant_types": ["authorization_code"],
"token_endpoint_auth_method": "none",
}));
let response = state.request(request).await;
response.assert_status(StatusCode::CREATED);
let response: ClientRegistrationResponse = serde_json::from_str(response.body()).unwrap();
assert!(response.client_secret.is_none());
// A successful registration with client_secret based authentication should
// return a client secret
let request =
Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
"contacts": ["hello@example.com"],
"client_uri": "https://example.com/",
"redirect_uris": ["https://example.com/"],
"response_types": ["code"],
"grant_types": ["authorization_code"],
"token_endpoint_auth_method": "client_secret_basic",
}));
let response = state.request(request).await;
response.assert_status(StatusCode::CREATED);
let response: ClientRegistrationResponse = serde_json::from_str(response.body()).unwrap();
assert!(response.client_secret.is_some());
}
}

View File

@ -213,15 +213,11 @@ mod tests {
use sqlx::PgPool;
use super::*;
use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState};
use crate::test_utils::{init_tracing, RequestBuilderExt, ResponseExt, TestState};
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_revoke_access_token(pool: PgPool) {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::INFO)
.with_test_writer()
.init();
init_tracing();
let state = TestState::from_pool(pool).await.unwrap();
let request =

View File

@ -43,6 +43,13 @@ use crate::{
MatrixHomeserver,
};
pub(crate) fn init_tracing() {
let _ = tracing_subscriber::fmt()
.with_max_level(tracing::Level::INFO)
.with_test_writer()
.try_init();
}
#[derive(Clone)]
pub(crate) struct TestState {
pub pool: PgPool,