1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Simplify OAuth2 error types

This commit is contained in:
Quentin Gliech
2022-04-07 10:08:10 +02:00
parent 54170cac6f
commit bbcd03fa73
2 changed files with 86 additions and 326 deletions

View File

@ -1,4 +1,4 @@
// Copyright 2021 The Matrix.org Foundation C.I.C. // Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -12,356 +12,127 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use http::status::StatusCode;
use serde::ser::{Serialize, SerializeMap};
use url::Url;
#[derive(serde::Serialize)] #[derive(serde::Serialize)]
pub struct ClientError { pub struct ClientError {
pub error: &'static str, pub error: &'static str,
pub error_description: &'static str, pub error_description: &'static str,
} }
pub trait OAuth2Error: std::fmt::Debug + Send + Sync { impl ClientError {
/// A single ASCII error code. #[must_use]
/// pub const fn new(error: &'static str, error_description: &'static str) -> Self {
/// Maps to the required "error" field. Self {
fn error(&self) -> &'static str; error,
error_description,
/// Human-readable ASCII text providing additional information, used to
/// assist the client developer in understanding the error that
/// occurred.
///
/// Maps to the optional `error_description` field.
fn description(&self) -> Option<String> {
None
}
/// A URI identifying a human-readable web page with information about the
/// error, used to provide the client developer with additional
/// information about the error.
///
/// Maps to the optional `error_uri` field.
fn uri(&self) -> Option<Url> {
None
}
/// Wraps the error with an `ErrorResponse` to help serializing.
fn into_response(self) -> ErrorResponse
where
Self: Sized + 'static,
{
ErrorResponse(Box::new(self))
}
}
pub trait OAuth2ErrorCode: OAuth2Error + 'static {
/// The HTTP status code that must be returned by this error
fn status(&self) -> StatusCode;
}
impl OAuth2Error for &Box<dyn OAuth2ErrorCode> {
fn error(&self) -> &'static str {
self.as_ref().error()
}
fn description(&self) -> Option<String> {
self.as_ref().description()
}
fn uri(&self) -> Option<Url> {
self.as_ref().uri()
}
}
#[derive(Debug)]
pub struct ErrorResponse(Box<dyn OAuth2Error>);
impl From<Box<dyn OAuth2Error>> for ErrorResponse {
fn from(b: Box<dyn OAuth2Error>) -> Self {
Self(b)
}
}
impl OAuth2Error for ErrorResponse {
fn error(&self) -> &'static str {
self.0.error()
}
fn description(&self) -> Option<String> {
self.0.description()
}
fn uri(&self) -> Option<Url> {
self.0.uri()
}
}
impl Serialize for ErrorResponse {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let error = self.0.error();
let description = self.0.description();
let uri = self.0.uri();
// Count the number of fields to serialize
let len = {
let mut x = 1;
if description.is_some() {
x += 1;
}
if uri.is_some() {
x += 1;
}
x
};
let mut map = serializer.serialize_map(Some(len))?;
map.serialize_entry("error", error)?;
if let Some(ref description) = description {
map.serialize_entry("error_description", description)?;
}
if let Some(ref uri) = uri {
map.serialize_entry("error_uri", uri)?;
}
map.end()
}
}
macro_rules! oauth2_error_def {
($name:ident) => {
#[derive(Debug, Clone)]
pub struct $name;
};
}
macro_rules! oauth2_error_status {
($name:ident, $code:ident) => {
impl $crate::errors::OAuth2ErrorCode for $name {
fn status(&self) -> ::http::status::StatusCode {
::http::status::StatusCode::$code
} }
} }
};
}
macro_rules! oauth2_error_error {
($err:literal) => {
fn error(&self) -> &'static str {
$err
}
};
}
macro_rules! oauth2_error_const {
($const:ident, $err:literal, $description:expr) => {
pub const $const: ClientError = ClientError {
error: $err,
error_description: $description,
};
};
}
macro_rules! oauth2_error_description {
($description:expr) => {
fn description(&self) -> Option<String> {
Some(($description).to_string())
}
};
}
macro_rules! oauth2_error {
($name:ident, $const:ident, $err:literal => $description:expr) => {
oauth2_error_const!($const, $err, $description);
oauth2_error_def!($name);
impl $crate::errors::OAuth2Error for $name {
oauth2_error_error!($err);
oauth2_error_description!(indoc::indoc! {$description});
}
};
($name:ident, $const:ident, $err:literal) => {
oauth2_error_def!($name);
impl $crate::errors::OAuth2Error for $name {
oauth2_error_error!($err);
}
};
($name:ident, $const:ident, code: $code:ident, $err:literal => $description:expr) => {
oauth2_error!($name, $const, $err => $description);
oauth2_error_status!($name, $code);
};
($name:ident, $const:ident, code: $code:ident, $err:literal) => {
oauth2_error!($name, $const, $err);
oauth2_error_status!($name, $code);
};
} }
pub mod rfc6749 { pub mod rfc6749 {
use super::ClientError; use super::ClientError;
oauth2_error! { pub const INVALID_REQUEST: ClientError = ClientError::new(
InvalidRequest, "invalid_request",
INVALID_REQUEST, "The request is missing a required parameter, \
code: BAD_REQUEST, includes an invalid parameter value, \
"invalid_request" => includes a parameter more than once, \
"The request is missing a required parameter, includes an invalid parameter value, \ or is otherwise malformed.",
includes a parameter more than once, or is otherwise malformed." );
}
oauth2_error! { pub const INVALID_CLIENT: ClientError =
InvalidClient, ClientError::new("invalid_client", "Client authentication failed.");
INVALID_CLIENT,
code: BAD_REQUEST,
"invalid_client" =>
"Client authentication failed."
}
oauth2_error! { pub const INVALID_GRANT: ClientError = ClientError::new(
InvalidGrant, "invalid_grant",
INVALID_GRANT, "The provided access grant is invalid, expired, or revoked.",
code: BAD_REQUEST, );
"invalid_grant" =>
"The provided access grant is invalid, expired, or revoked."
}
oauth2_error! { pub const UNAUTHORIZED_CLIENT: ClientError = ClientError::new(
UnauthorizedClient, "unauthorized_client",
UNAUTHORIZED_CLIENT, "The client is not authorized to request an access token using this method.",
code: BAD_REQUEST, );
"unauthorized_client" =>
"The client is not authorized to request an access token using this method."
}
oauth2_error! { pub const UNSUPPORTED_GRANT_TYPE: ClientError = ClientError::new(
UnsupportedGrantType, "unsupported_grant_type",
UNSUPPORTED_GRANT_TYPE, "The authorization grant type is not supported by the authorization server.",
code: BAD_REQUEST, );
"unsupported_grant_type" =>
"The authorization grant type is not supported by the authorization server."
}
oauth2_error! { pub const ACCESS_DENIED: ClientError = ClientError::new(
AccessDenied, "access_denied",
ACCESS_DENIED, "The resource owner or authorization server denied the request.",
"access_denied" => );
"The resource owner or authorization server denied the request."
}
oauth2_error! { pub const UNSUPPORTED_RESPONSE_TYPE: ClientError = ClientError::new(
UnsupportedResponseType, "unsupported_response_type",
UNSUPPORTED_RESPONSE_TYPE, "The authorization server does not support obtaining an access token using this method.",
"unsupported_response_type" => );
"The authorization server does not support obtaining an access token using this method."
}
oauth2_error! { pub const INVALID_SCOPE: ClientError = ClientError::new(
InvalidScope, "invalid_scope",
INVALID_SCOPE, "The requested scope is invalid, unknown, or malformed.",
code: BAD_REQUEST, );
"invalid_scope" =>
"The requested scope is invalid, unknown, or malformed."
}
oauth2_error! { pub const SERVER_ERROR: ClientError = ClientError::new(
ServerError, "server_error",
SERVER_ERROR, "The authorization server encountered an unexpected condition \
code: INTERNAL_SERVER_ERROR, that prevented it from fulfilling the request.",
"server_error" => );
"The authorization server encountered an unexpected \
condition that prevented it from fulfilling the request."
}
oauth2_error! { pub const TEMPORARILY_UNAVAILABLE: ClientError = ClientError::new(
TemporarilyUnavailable, "temporarily_unavailable",
TEMPORARILY_UNAVAILABLE, "The authorization server is currently unable to handle the request \
"temporarily_unavailable" => due to a temporary overloading or maintenance of the server.",
"The authorization server is currently unable to handle \ );
the request due to a temporary overloading or maintenance \
of the server."
}
} }
pub mod oidc_core { pub mod oidc_core {
use super::ClientError; use super::ClientError;
oauth2_error! { pub const INTERACTION_REQUIRED: ClientError = ClientError::new(
InteractionRequired, "interaction_required",
INTERACTION_REQUIRED, "The Authorization Server requires End-User interaction of some form to proceed.",
"interaction_required" => );
"The Authorization Server requires End-User interaction of some form to proceed."
}
oauth2_error! { pub const LOGIN_REQUIRED: ClientError = ClientError::new(
LoginRequired, "login_required",
LOGIN_REQUIRED, "The Authorization Server requires End-User authentication.",
"login_required" => );
"The Authorization Server requires End-User authentication."
}
oauth2_error! { pub const ACCOUNT_SELECTION_REQUIRED: ClientError = ClientError::new(
AccountSelectionRequired, "account_selection_required",
ACCOUNT_SELECTION_REQUIRED, "The End-User is REQUIRED to select a session at the Authorization Server.",
"account_selection_required" );
}
oauth2_error! { pub const CONSENT_REQUIRED: ClientError = ClientError::new(
ConsentRequired, "consent_required",
CONSENT_REQUIRED, "The Authorization Server requires End-User consent.",
"consent_required" );
}
oauth2_error! { pub const INVALID_REQUEST_URI: ClientError = ClientError::new(
InvalidRequestUri, "invalid_request_uri",
INVALID_REQUEST_URI, "The request_uri in the Authorization Request returns an error or contains invalid data. ",
"invalid_request_uri" => );
"The request_uri in the Authorization Request returns an error or contains invalid data. "
}
oauth2_error! { pub const INVALID_REQUEST_OBJECT: ClientError = ClientError::new(
InvalidRequestObject, "invalid_request_object",
INVALID_REQUEST_OBJECT, "The request parameter contains an invalid Request Object.",
"invalid_request_object" => );
"The request parameter contains an invalid Request Object."
}
oauth2_error! { pub const REQUEST_NOT_SUPPORTED: ClientError = ClientError::new(
RequestNotSupported, "request_not_supported",
REQUEST_NOT_SUPPORTED, "The provider does not support use of the request parameter.",
"request_not_supported" => );
"The provider does not support use of the request parameter."
}
oauth2_error! { pub const REQUEST_URI_NOT_SUPPORTED: ClientError = ClientError::new(
RequestUriNotSupported, "request_uri_not_supported",
REQUEST_URI_NOT_SUPPORTED, "The provider does not support use of the request_uri parameter.",
"request_uri_not_supported" => );
"The provider does not support use of the request_uri parameter."
}
oauth2_error! { pub const REGISTRATION_NOT_SUPPORTED: ClientError = ClientError::new(
RegistrationNotSupported, "registration_not_supported",
REGISTRATION_NOT_SUPPORTED, "The provider does not support use of the registration parameter.",
"registration_not_supported" => );
"The provider does not support use of the registration parameter."
}
} }
pub use oidc_core::*; pub use oidc_core::*;
pub use rfc6749::*; pub use rfc6749::*;
#[cfg(test)]
mod tests {
use serde_json::json;
use super::*;
#[test]
fn serialize_error() {
let expected = json!({
"error": "invalid_grant",
"error_description": "The provided access grant is invalid, expired, or revoked."
});
let actual = serde_json::to_value(InvalidGrant.into_response()).unwrap();
assert_eq!(expected, actual);
}
}

View File

@ -19,7 +19,6 @@
use mas_data_model::{ use mas_data_model::{
errors::ErroredForm, AuthorizationGrant, BrowserSession, StorageBackend, User, UserEmail, errors::ErroredForm, AuthorizationGrant, BrowserSession, StorageBackend, User, UserEmail,
}; };
use oauth2_types::errors::OAuth2Error;
use serde::{ser::SerializeStruct, Serialize}; use serde::{ser::SerializeStruct, Serialize};
use url::Url; use url::Url;
@ -599,13 +598,3 @@ impl ErrorContext {
self self
} }
} }
impl From<Box<dyn OAuth2Error>> for ErrorContext {
fn from(err: Box<dyn OAuth2Error>) -> Self {
let mut ctx = ErrorContext::new().with_code(err.error());
if let Some(desc) = err.description() {
ctx = ctx.with_description(desc);
}
ctx
}
}