diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index fe9163f5..9c00b8f8 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -55,6 +55,19 @@ mod oauth2; mod upstream_oauth2; mod views; +/// Implement `From` for `RouteError`, for "internal server error" kind of +/// errors. +#[macro_export] +macro_rules! impl_from_error_for_route { + ($error:ty) => { + impl From<$error> for self::RouteError { + fn from(e: $error) -> Self { + Self::InternalError(Box::new(e)) + } + } + }; +} + pub use mas_axum_utils::http_client_factory::HttpClientFactory; pub use self::{app_state::AppState, compat::MatrixHomeserver, graphql::schema as graphql_schema}; diff --git a/crates/handlers/src/upstream_oauth2/authorize.rs b/crates/handlers/src/upstream_oauth2/authorize.rs index c6b94049..e0d58a05 100644 --- a/crates/handlers/src/upstream_oauth2/authorize.rs +++ b/crates/handlers/src/upstream_oauth2/authorize.rs @@ -19,26 +19,21 @@ use axum::{ use axum_extra::extract::{cookie::Cookie, PrivateCookieJar}; use hyper::StatusCode; use mas_axum_utils::http_client_factory::HttpClientFactory; -use mas_http::ClientInitError; use mas_keystore::Encrypter; -use mas_oidc_client::{ - error::{AuthorizationError, DiscoveryError}, - requests::authorization_code::AuthorizationRequestData, -}; +use mas_oidc_client::requests::authorization_code::AuthorizationRequestData; use mas_router::UrlBuilder; use mas_storage::{upstream_oauth2::lookup_provider, LookupResultExt}; use sqlx::PgPool; use thiserror::Error; use ulid::Ulid; +use crate::impl_from_error_for_route; + #[derive(Debug, Error)] pub(crate) enum RouteError { #[error("Provider not found")] ProviderNotFound, - #[error(transparent)] - Authorization(#[from] AuthorizationError), - #[error(transparent)] InternalError(Box), @@ -46,37 +41,16 @@ pub(crate) enum RouteError { Anyhow(#[from] anyhow::Error), } -impl From for RouteError { - fn from(e: sqlx::Error) -> Self { - Self::InternalError(Box::new(e)) - } -} - -impl From for RouteError { - fn from(e: DiscoveryError) -> Self { - Self::InternalError(Box::new(e)) - } -} - -impl From for RouteError { - fn from(e: mas_storage::upstream_oauth2::ProviderLookupError) -> Self { - Self::InternalError(Box::new(e)) - } -} - -impl From for RouteError { - fn from(e: ClientInitError) -> Self { - Self::InternalError(Box::new(e)) - } -} +impl_from_error_for_route!(sqlx::Error); +impl_from_error_for_route!(mas_http::ClientInitError); +impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError); +impl_from_error_for_route!(mas_oidc_client::error::AuthorizationError); +impl_from_error_for_route!(mas_storage::upstream_oauth2::ProviderLookupError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { match self { Self::ProviderNotFound => (StatusCode::NOT_FOUND, "Provider not found").into_response(), - Self::Authorization(e) => { - (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response() - } Self::InternalError(e) => { (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response() } diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index 467956ba..43d90a26 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -19,17 +19,15 @@ use axum::{ use axum_extra::extract::PrivateCookieJar; use hyper::StatusCode; use mas_axum_utils::http_client_factory::HttpClientFactory; -use mas_http::ClientInitError; use mas_jose::claims::ClaimError; use mas_keystore::{Encrypter, Keystore}; -use mas_oidc_client::{ - error::{DiscoveryError, JwksError, TokenAuthorizationCodeError}, - requests::{authorization_code::AuthorizationValidationData, jose::JwtVerificationData}, +use mas_oidc_client::requests::{ + authorization_code::AuthorizationValidationData, jose::JwtVerificationData, }; use mas_router::{Route, UrlBuilder}; use mas_storage::{ upstream_oauth2::{add_link, complete_session, lookup_link_by_subject, lookup_session}, - GenericLookupError, LookupResultExt, + LookupResultExt, }; use oauth2_types::errors::ClientErrorCode; use serde::Deserialize; @@ -37,7 +35,8 @@ use sqlx::PgPool; use thiserror::Error; use ulid::Ulid; -use super::{client_credentials_for_provider, ProviderCredentialsError}; +use super::client_credentials_for_provider; +use crate::impl_from_error_for_route; #[derive(Deserialize)] pub struct QueryParams { @@ -100,53 +99,14 @@ pub(crate) enum RouteError { Anyhow(#[from] anyhow::Error), } -impl From for RouteError { - fn from(e: GenericLookupError) -> Self { - Self::InternalError(Box::new(e)) - } -} - -impl From for RouteError { - fn from(e: sqlx::Error) -> Self { - Self::InternalError(Box::new(e)) - } -} - -impl From for RouteError { - fn from(e: DiscoveryError) -> Self { - Self::InternalError(Box::new(e)) - } -} - -impl From for RouteError { - fn from(e: JwksError) -> Self { - Self::InternalError(Box::new(e)) - } -} - -impl From for RouteError { - fn from(e: TokenAuthorizationCodeError) -> Self { - Self::InternalError(Box::new(e)) - } -} - -impl From for RouteError { - fn from(e: mas_storage::upstream_oauth2::SessionLookupError) -> Self { - Self::InternalError(Box::new(e)) - } -} - -impl From for RouteError { - fn from(e: ClientInitError) -> Self { - Self::InternalError(Box::new(e)) - } -} - -impl From for RouteError { - fn from(e: ProviderCredentialsError) -> Self { - Self::InternalError(Box::new(e)) - } -} +impl_from_error_for_route!(mas_storage::GenericLookupError); +impl_from_error_for_route!(mas_storage::upstream_oauth2::SessionLookupError); +impl_from_error_for_route!(mas_http::ClientInitError); +impl_from_error_for_route!(sqlx::Error); +impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError); +impl_from_error_for_route!(mas_oidc_client::error::JwksError); +impl_from_error_for_route!(mas_oidc_client::error::TokenAuthorizationCodeError); +impl_from_error_for_route!(super::ProviderCredentialsError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index 0b7d3e0d..20a58a04 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -20,7 +20,7 @@ use axum::{ use axum_extra::extract::PrivateCookieJar; use hyper::StatusCode; use mas_axum_utils::{ - csrf::{CsrfError, CsrfExt, ProtectedForm}, + csrf::{CsrfExt, ProtectedForm}, SessionInfoExt, }; use mas_keystore::Encrypter; @@ -31,18 +31,17 @@ use mas_storage::{ }, user::{ authenticate_session_with_upstream, lookup_user, register_passwordless_user, start_session, - ActiveSessionLookupError, UserLookupError, }, - GenericLookupError, LookupResultExt, -}; -use mas_templates::{ - EmptyContext, TemplateContext, TemplateError, Templates, UpstreamExistingLinkContext, + LookupResultExt, }; +use mas_templates::{EmptyContext, TemplateContext, Templates, UpstreamExistingLinkContext}; use serde::Deserialize; use sqlx::PgPool; use thiserror::Error; use ulid::Ulid; +use crate::impl_from_error_for_route; + #[derive(Debug, Error)] pub(crate) enum RouteError { /// Couldn't find the link specified in the URL @@ -73,41 +72,12 @@ pub(crate) enum RouteError { Anyhow(#[from] anyhow::Error), } -impl From for RouteError { - fn from(e: sqlx::Error) -> Self { - Self::InternalError(Box::new(e)) - } -} - -impl From for RouteError { - fn from(e: TemplateError) -> Self { - Self::InternalError(Box::new(e)) - } -} - -impl From for RouteError { - fn from(e: ActiveSessionLookupError) -> Self { - Self::InternalError(Box::new(e)) - } -} - -impl From for RouteError { - fn from(e: CsrfError) -> Self { - Self::InternalError(Box::new(e)) - } -} - -impl From for RouteError { - fn from(e: UserLookupError) -> Self { - Self::InternalError(Box::new(e)) - } -} - -impl From for RouteError { - fn from(e: GenericLookupError) -> Self { - Self::InternalError(Box::new(e)) - } -} +impl_from_error_for_route!(sqlx::Error); +impl_from_error_for_route!(mas_templates::TemplateError); +impl_from_error_for_route!(mas_storage::GenericLookupError); +impl_from_error_for_route!(mas_storage::user::ActiveSessionLookupError); +impl_from_error_for_route!(mas_storage::user::UserLookupError); +impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response {