You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-31 09:24:31 +03:00
handlers: add tests for client registration
This commit is contained in:
@ -37,13 +37,13 @@ pub(crate) enum RouteError {
|
|||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
Internal(Box<dyn std::error::Error + Send + Sync>),
|
Internal(Box<dyn std::error::Error + Send + Sync>),
|
||||||
|
|
||||||
#[error("invalid redirect uri")]
|
#[error(transparent)]
|
||||||
InvalidRedirectUri,
|
JsonExtract(#[from] axum::extract::rejection::JsonRejection),
|
||||||
|
|
||||||
#[error("invalid client metadata")]
|
#[error("invalid client metadata")]
|
||||||
InvalidClientMetadata,
|
InvalidClientMetadata(#[from] ClientMetadataVerificationError),
|
||||||
|
|
||||||
#[error("denied by the policy")]
|
#[error("denied by the policy: {0:?}")]
|
||||||
PolicyDenied(Vec<Violation>),
|
PolicyDenied(Vec<Violation>),
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -53,18 +53,6 @@ impl_from_error_for_route!(mas_policy::InstanciateError);
|
|||||||
impl_from_error_for_route!(mas_policy::EvaluationError);
|
impl_from_error_for_route!(mas_policy::EvaluationError);
|
||||||
impl_from_error_for_route!(mas_keystore::aead::Error);
|
impl_from_error_for_route!(mas_keystore::aead::Error);
|
||||||
|
|
||||||
impl From<ClientMetadataVerificationError> for RouteError {
|
|
||||||
fn from(e: ClientMetadataVerificationError) -> Self {
|
|
||||||
match e {
|
|
||||||
ClientMetadataVerificationError::MissingRedirectUris
|
|
||||||
| ClientMetadataVerificationError::RedirectUriWithFragment(_) => {
|
|
||||||
Self::InvalidRedirectUri
|
|
||||||
}
|
|
||||||
_ => Self::InvalidClientMetadata,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl IntoResponse for RouteError {
|
impl IntoResponse for RouteError {
|
||||||
fn into_response(self) -> axum::response::Response {
|
fn into_response(self) -> axum::response::Response {
|
||||||
sentry::capture_error(&self);
|
sentry::capture_error(&self);
|
||||||
@ -74,17 +62,59 @@ impl IntoResponse for RouteError {
|
|||||||
Json(ClientError::from(ClientErrorCode::ServerError)),
|
Json(ClientError::from(ClientErrorCode::ServerError)),
|
||||||
)
|
)
|
||||||
.into_response(),
|
.into_response(),
|
||||||
Self::InvalidRedirectUri => (
|
|
||||||
|
// This error happens if we managed to parse the incomiong JSON but it can't be
|
||||||
|
// deserialized to the expected type. In this case we return an
|
||||||
|
// `invalid_client_metadata` error with the details of the error.
|
||||||
|
Self::JsonExtract(axum::extract::rejection::JsonRejection::JsonDataError(e)) => (
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
Json(
|
||||||
|
ClientError::from(ClientErrorCode::InvalidClientMetadata)
|
||||||
|
.with_description(e.to_string()),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
|
||||||
|
// For all other JSON errors we return a `invalid_request` error, since this is
|
||||||
|
// probably due to a malformed request.
|
||||||
|
Self::JsonExtract(_) => (
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
Json(ClientError::from(ClientErrorCode::InvalidRequest)),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
|
||||||
|
// This error comes from the `ClientMetadata::validate` method. We return an
|
||||||
|
// `invalid_redirect_uri` error if the error is related to the redirect URIs, else we
|
||||||
|
// return an `invalid_client_metadata` error.
|
||||||
|
Self::InvalidClientMetadata(
|
||||||
|
ClientMetadataVerificationError::MissingRedirectUris
|
||||||
|
| ClientMetadataVerificationError::RedirectUriWithFragment(_),
|
||||||
|
) => (
|
||||||
StatusCode::BAD_REQUEST,
|
StatusCode::BAD_REQUEST,
|
||||||
Json(ClientError::from(ClientErrorCode::InvalidRedirectUri)),
|
Json(ClientError::from(ClientErrorCode::InvalidRedirectUri)),
|
||||||
)
|
)
|
||||||
.into_response(),
|
.into_response(),
|
||||||
Self::InvalidClientMetadata => (
|
|
||||||
|
Self::InvalidClientMetadata(e) => (
|
||||||
StatusCode::BAD_REQUEST,
|
StatusCode::BAD_REQUEST,
|
||||||
Json(ClientError::from(ClientErrorCode::InvalidClientMetadata)),
|
Json(
|
||||||
|
ClientError::from(ClientErrorCode::InvalidClientMetadata)
|
||||||
|
.with_description(e.to_string()),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
.into_response(),
|
.into_response(),
|
||||||
|
|
||||||
|
// For policy violations, we return an `invalid_client_metadata` error with the details
|
||||||
|
// of the violations in most cases. If a violation includes `redirect_uri` in the
|
||||||
|
// message, we return an `invalid_redirect_uri` error instead.
|
||||||
Self::PolicyDenied(violations) => {
|
Self::PolicyDenied(violations) => {
|
||||||
|
// TODO: detect them better
|
||||||
|
let code = if violations.iter().any(|v| v.msg.contains("redirect_uri")) {
|
||||||
|
ClientErrorCode::InvalidRedirectUri
|
||||||
|
} else {
|
||||||
|
ClientErrorCode::InvalidClientMetadata
|
||||||
|
};
|
||||||
|
|
||||||
let collected = &violations
|
let collected = &violations
|
||||||
.iter()
|
.iter()
|
||||||
.map(|v| v.msg.clone())
|
.map(|v| v.msg.clone())
|
||||||
@ -92,11 +122,8 @@ impl IntoResponse for RouteError {
|
|||||||
let joined = collected.join("; ");
|
let joined = collected.join("; ");
|
||||||
|
|
||||||
(
|
(
|
||||||
StatusCode::UNAUTHORIZED,
|
StatusCode::BAD_REQUEST,
|
||||||
Json(
|
Json(ClientError::from(code).with_description(joined)),
|
||||||
ClientError::from(ClientErrorCode::InvalidClientMetadata)
|
|
||||||
.with_description(joined),
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
.into_response()
|
.into_response()
|
||||||
}
|
}
|
||||||
@ -111,8 +138,11 @@ pub(crate) async fn post(
|
|||||||
mut repo: BoxRepository,
|
mut repo: BoxRepository,
|
||||||
State(policy_factory): State<Arc<PolicyFactory>>,
|
State(policy_factory): State<Arc<PolicyFactory>>,
|
||||||
State(encrypter): State<Encrypter>,
|
State(encrypter): State<Encrypter>,
|
||||||
Json(body): Json<ClientMetadata>,
|
body: Result<Json<ClientMetadata>, axum::extract::rejection::JsonRejection>,
|
||||||
) -> Result<impl IntoResponse, RouteError> {
|
) -> Result<impl IntoResponse, RouteError> {
|
||||||
|
// Propagate any JSON extraction error
|
||||||
|
let Json(body) = body?;
|
||||||
|
|
||||||
info!(?body, "Client registration");
|
info!(?body, "Client registration");
|
||||||
|
|
||||||
// Validate the body
|
// Validate the body
|
||||||
@ -179,3 +209,112 @@ pub(crate) async fn post(
|
|||||||
|
|
||||||
Ok((StatusCode::CREATED, Json(response)))
|
Ok((StatusCode::CREATED, Json(response)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use hyper::{Request, StatusCode};
|
||||||
|
use mas_router::SimpleRoute;
|
||||||
|
use oauth2_types::{
|
||||||
|
errors::{ClientError, ClientErrorCode},
|
||||||
|
registration::ClientRegistrationResponse,
|
||||||
|
};
|
||||||
|
use sqlx::PgPool;
|
||||||
|
|
||||||
|
use crate::test_utils::{init_tracing, RequestBuilderExt, ResponseExt, TestState};
|
||||||
|
|
||||||
|
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
|
||||||
|
async fn test_registration_error(pool: PgPool) {
|
||||||
|
init_tracing();
|
||||||
|
let state = TestState::from_pool(pool).await.unwrap();
|
||||||
|
|
||||||
|
// Body is not a JSON
|
||||||
|
let request = Request::post(mas_router::OAuth2RegistrationEndpoint::PATH)
|
||||||
|
.body("this is not a json".to_owned())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let response = state.request(request).await;
|
||||||
|
response.assert_status(StatusCode::BAD_REQUEST);
|
||||||
|
let response: ClientError = serde_json::from_str(response.body()).unwrap();
|
||||||
|
assert_eq!(response.error, ClientErrorCode::InvalidRequest);
|
||||||
|
|
||||||
|
// Invalid client metadata
|
||||||
|
let request =
|
||||||
|
Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
|
||||||
|
"client_uri": "this is not a uri",
|
||||||
|
}));
|
||||||
|
|
||||||
|
let response = state.request(request).await;
|
||||||
|
response.assert_status(StatusCode::BAD_REQUEST);
|
||||||
|
let response: ClientError = serde_json::from_str(response.body()).unwrap();
|
||||||
|
assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);
|
||||||
|
|
||||||
|
// Invalid redirect URI
|
||||||
|
let request =
|
||||||
|
Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
|
||||||
|
"application_type": "web",
|
||||||
|
"contacts": ["hello@example.com"],
|
||||||
|
"client_uri": "https://example.com/",
|
||||||
|
"redirect_uris": ["http://this-is-insecure.com/"],
|
||||||
|
}));
|
||||||
|
|
||||||
|
let response = state.request(request).await;
|
||||||
|
response.assert_status(StatusCode::BAD_REQUEST);
|
||||||
|
let response: ClientError = serde_json::from_str(response.body()).unwrap();
|
||||||
|
assert_eq!(response.error, ClientErrorCode::InvalidRedirectUri);
|
||||||
|
|
||||||
|
// Incoherent response types
|
||||||
|
let request =
|
||||||
|
Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
|
||||||
|
"contacts": ["hello@example.com"],
|
||||||
|
"client_uri": "https://example.com/",
|
||||||
|
"redirect_uris": ["https://example.com/"],
|
||||||
|
"response_types": ["id_token"],
|
||||||
|
"grant_types": ["authorization_code"],
|
||||||
|
}));
|
||||||
|
|
||||||
|
let response = state.request(request).await;
|
||||||
|
response.assert_status(StatusCode::BAD_REQUEST);
|
||||||
|
let response: ClientError = serde_json::from_str(response.body()).unwrap();
|
||||||
|
assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
|
||||||
|
async fn test_registration(pool: PgPool) {
|
||||||
|
init_tracing();
|
||||||
|
let state = TestState::from_pool(pool).await.unwrap();
|
||||||
|
|
||||||
|
// A successful registration with no authentication should not return a client
|
||||||
|
// secret
|
||||||
|
let request =
|
||||||
|
Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
|
||||||
|
"contacts": ["hello@example.com"],
|
||||||
|
"client_uri": "https://example.com/",
|
||||||
|
"redirect_uris": ["https://example.com/"],
|
||||||
|
"response_types": ["code"],
|
||||||
|
"grant_types": ["authorization_code"],
|
||||||
|
"token_endpoint_auth_method": "none",
|
||||||
|
}));
|
||||||
|
|
||||||
|
let response = state.request(request).await;
|
||||||
|
response.assert_status(StatusCode::CREATED);
|
||||||
|
let response: ClientRegistrationResponse = serde_json::from_str(response.body()).unwrap();
|
||||||
|
assert!(response.client_secret.is_none());
|
||||||
|
|
||||||
|
// A successful registration with client_secret based authentication should
|
||||||
|
// return a client secret
|
||||||
|
let request =
|
||||||
|
Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
|
||||||
|
"contacts": ["hello@example.com"],
|
||||||
|
"client_uri": "https://example.com/",
|
||||||
|
"redirect_uris": ["https://example.com/"],
|
||||||
|
"response_types": ["code"],
|
||||||
|
"grant_types": ["authorization_code"],
|
||||||
|
"token_endpoint_auth_method": "client_secret_basic",
|
||||||
|
}));
|
||||||
|
|
||||||
|
let response = state.request(request).await;
|
||||||
|
response.assert_status(StatusCode::CREATED);
|
||||||
|
let response: ClientRegistrationResponse = serde_json::from_str(response.body()).unwrap();
|
||||||
|
assert!(response.client_secret.is_some());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -213,15 +213,11 @@ mod tests {
|
|||||||
use sqlx::PgPool;
|
use sqlx::PgPool;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState};
|
use crate::test_utils::{init_tracing, RequestBuilderExt, ResponseExt, TestState};
|
||||||
|
|
||||||
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
|
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
|
||||||
async fn test_revoke_access_token(pool: PgPool) {
|
async fn test_revoke_access_token(pool: PgPool) {
|
||||||
tracing_subscriber::fmt()
|
init_tracing();
|
||||||
.with_max_level(tracing::Level::INFO)
|
|
||||||
.with_test_writer()
|
|
||||||
.init();
|
|
||||||
|
|
||||||
let state = TestState::from_pool(pool).await.unwrap();
|
let state = TestState::from_pool(pool).await.unwrap();
|
||||||
|
|
||||||
let request =
|
let request =
|
||||||
|
@ -43,6 +43,13 @@ use crate::{
|
|||||||
MatrixHomeserver,
|
MatrixHomeserver,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
pub(crate) fn init_tracing() {
|
||||||
|
let _ = tracing_subscriber::fmt()
|
||||||
|
.with_max_level(tracing::Level::INFO)
|
||||||
|
.with_test_writer()
|
||||||
|
.try_init();
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub(crate) struct TestState {
|
pub(crate) struct TestState {
|
||||||
pub pool: PgPool,
|
pub pool: PgPool,
|
||||||
|
Reference in New Issue
Block a user