You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-31 09:24:31 +03:00
policy: define custom errors and ditch anyhow
This commit is contained in:
@ -14,7 +14,6 @@
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::anyhow;
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
response::{IntoResponse, Response},
|
||||
@ -44,10 +43,6 @@ pub enum RouteError {
|
||||
#[error(transparent)]
|
||||
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
|
||||
|
||||
// TODO: remove this one: needed because mas_policy returns errors from anyhow
|
||||
#[error(transparent)]
|
||||
Anyhow(#[from] anyhow::Error),
|
||||
|
||||
#[error("authorization grant was not found")]
|
||||
NotFound,
|
||||
|
||||
@ -67,9 +62,6 @@ impl IntoResponse for RouteError {
|
||||
"authorization grant not in a pending state",
|
||||
)
|
||||
.into_response(),
|
||||
RouteError::Anyhow(e) => {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response()
|
||||
}
|
||||
RouteError::Internal(e) => {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response()
|
||||
}
|
||||
@ -79,6 +71,9 @@ impl IntoResponse for RouteError {
|
||||
|
||||
impl_from_error_for_route!(sqlx::Error);
|
||||
impl_from_error_for_route!(mas_storage::DatabaseError);
|
||||
impl_from_error_for_route!(mas_policy::LoadError);
|
||||
impl_from_error_for_route!(mas_policy::InstanciateError);
|
||||
impl_from_error_for_route!(mas_policy::EvaluationError);
|
||||
impl_from_error_for_route!(super::callback::IntoCallbackDestinationError);
|
||||
impl_from_error_for_route!(super::callback::CallbackDestinationError);
|
||||
|
||||
@ -126,7 +121,6 @@ pub(crate) async fn get(
|
||||
}
|
||||
Err(GrantCompletionError::NotPending) => Err(RouteError::NotPending),
|
||||
Err(GrantCompletionError::Internal(e)) => Err(RouteError::Internal(e)),
|
||||
Err(GrantCompletionError::Anyhow(e)) => Err(RouteError::Anyhow(e)),
|
||||
}
|
||||
}
|
||||
|
||||
@ -135,9 +129,6 @@ pub enum GrantCompletionError {
|
||||
#[error(transparent)]
|
||||
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
|
||||
|
||||
#[error(transparent)]
|
||||
Anyhow(#[from] anyhow::Error),
|
||||
|
||||
#[error("authorization grant is not in a pending state")]
|
||||
NotPending,
|
||||
|
||||
@ -154,6 +145,9 @@ pub enum GrantCompletionError {
|
||||
impl_from_error_for_route!(GrantCompletionError: sqlx::Error);
|
||||
impl_from_error_for_route!(GrantCompletionError: mas_storage::DatabaseError);
|
||||
impl_from_error_for_route!(GrantCompletionError: super::callback::IntoCallbackDestinationError);
|
||||
impl_from_error_for_route!(GrantCompletionError: mas_policy::LoadError);
|
||||
impl_from_error_for_route!(GrantCompletionError: mas_policy::InstanciateError);
|
||||
impl_from_error_for_route!(GrantCompletionError: mas_policy::EvaluationError);
|
||||
|
||||
pub(crate) async fn complete(
|
||||
grant: AuthorizationGrant,
|
||||
@ -214,7 +208,9 @@ pub(crate) async fn complete(
|
||||
// Did they request an ID token?
|
||||
if grant.response_type_id_token {
|
||||
// TODO
|
||||
return Err(anyhow!("id tokens are not implemented yet").into());
|
||||
return Err(GrantCompletionError::Internal(
|
||||
"ID tokens are not implemented yet".into(),
|
||||
));
|
||||
}
|
||||
|
||||
txn.commit().await?;
|
||||
|
@ -14,7 +14,6 @@
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use axum::{
|
||||
extract::{Form, State},
|
||||
response::{IntoResponse, Response},
|
||||
@ -52,13 +51,12 @@ pub enum RouteError {
|
||||
#[error(transparent)]
|
||||
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
|
||||
|
||||
// TODO: remove this one
|
||||
#[error(transparent)]
|
||||
Anyhow(#[from] anyhow::Error),
|
||||
|
||||
#[error("could not find client")]
|
||||
ClientNotFound,
|
||||
|
||||
#[error("invalid response mode")]
|
||||
InvalidResponseMode,
|
||||
|
||||
#[error("invalid parameters")]
|
||||
IntoCallbackDestination(#[from] self::callback::IntoCallbackDestinationError),
|
||||
|
||||
@ -73,12 +71,12 @@ impl IntoResponse for RouteError {
|
||||
RouteError::Internal(e) => {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response()
|
||||
}
|
||||
RouteError::Anyhow(e) => {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response()
|
||||
}
|
||||
RouteError::ClientNotFound => {
|
||||
(StatusCode::BAD_REQUEST, "could not find client").into_response()
|
||||
}
|
||||
RouteError::InvalidResponseMode => {
|
||||
(StatusCode::BAD_REQUEST, "invalid response mode").into_response()
|
||||
}
|
||||
RouteError::IntoCallbackDestination(e) => {
|
||||
(StatusCode::BAD_REQUEST, e.to_string()).into_response()
|
||||
}
|
||||
@ -94,6 +92,9 @@ impl IntoResponse for RouteError {
|
||||
impl_from_error_for_route!(sqlx::Error);
|
||||
impl_from_error_for_route!(mas_storage::DatabaseError);
|
||||
impl_from_error_for_route!(self::callback::CallbackDestinationError);
|
||||
impl_from_error_for_route!(mas_policy::LoadError);
|
||||
impl_from_error_for_route!(mas_policy::InstanciateError);
|
||||
impl_from_error_for_route!(mas_policy::EvaluationError);
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub(crate) struct Params {
|
||||
@ -110,7 +111,7 @@ pub(crate) struct Params {
|
||||
fn resolve_response_mode(
|
||||
response_type: &ResponseType,
|
||||
suggested_response_mode: Option<ResponseMode>,
|
||||
) -> anyhow::Result<ResponseMode> {
|
||||
) -> Result<ResponseMode, RouteError> {
|
||||
use ResponseMode as M;
|
||||
|
||||
// If the response type includes either "token" or "id_token", the default
|
||||
@ -119,7 +120,7 @@ fn resolve_response_mode(
|
||||
if response_type.has_token() || response_type.has_id_token() {
|
||||
match suggested_response_mode {
|
||||
None => Ok(M::Fragment),
|
||||
Some(M::Query) => Err(anyhow!("invalid response mode")),
|
||||
Some(M::Query) => Err(RouteError::InvalidResponseMode),
|
||||
Some(mode) => Ok(mode),
|
||||
}
|
||||
} else {
|
||||
@ -166,10 +167,7 @@ pub(crate) async fn get(
|
||||
let templates = templates.clone();
|
||||
let callback_destination = callback_destination.clone();
|
||||
async move {
|
||||
let maybe_session = session_info
|
||||
.load_session(&mut txn)
|
||||
.await
|
||||
.context("failed to load browser session")?;
|
||||
let maybe_session = session_info.load_session(&mut txn).await?;
|
||||
let prompt = params.auth.prompt.as_deref().unwrap_or_default();
|
||||
|
||||
// Check if the request/request_uri/registration params are used. If so, reply
|
||||
@ -356,13 +354,12 @@ pub(crate) async fn get(
|
||||
.go(&templates, ClientError::from(ClientErrorCode::AccessDenied))
|
||||
.await?
|
||||
}
|
||||
Err(GrantCompletionError::Anyhow(a)) => return Err(RouteError::Anyhow(a)),
|
||||
Err(GrantCompletionError::Internal(e)) => {
|
||||
return Err(RouteError::Internal(e))
|
||||
}
|
||||
Err(GrantCompletionError::NotPending) => {
|
||||
Err(e @ GrantCompletionError::NotPending) => {
|
||||
// This should never happen
|
||||
return Err(anyhow!("authorization grant is not pending").into());
|
||||
return Err(RouteError::Internal(Box::new(e)));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -387,13 +384,12 @@ pub(crate) async fn get(
|
||||
.go()
|
||||
.into_response()
|
||||
}
|
||||
Err(GrantCompletionError::Anyhow(a)) => return Err(RouteError::Anyhow(a)),
|
||||
Err(GrantCompletionError::Internal(e)) => {
|
||||
return Err(RouteError::Internal(e))
|
||||
}
|
||||
Err(GrantCompletionError::NotPending) => {
|
||||
Err(e @ GrantCompletionError::NotPending) => {
|
||||
// This should never happen
|
||||
return Err(anyhow!("authorization grant is not pending").into());
|
||||
return Err(RouteError::Internal(Box::new(e)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -44,10 +44,6 @@ pub enum RouteError {
|
||||
#[error(transparent)]
|
||||
Internal(Box<dyn std::error::Error + Send + Sync>),
|
||||
|
||||
// TODO: remove this one, needed because of mas_policy
|
||||
#[error(transparent)]
|
||||
Anyhow(#[from] anyhow::Error),
|
||||
|
||||
#[error(transparent)]
|
||||
Csrf(#[from] mas_axum_utils::csrf::CsrfError),
|
||||
|
||||
@ -64,6 +60,9 @@ pub enum RouteError {
|
||||
impl_from_error_for_route!(sqlx::Error);
|
||||
impl_from_error_for_route!(mas_templates::TemplateError);
|
||||
impl_from_error_for_route!(mas_storage::DatabaseError);
|
||||
impl_from_error_for_route!(mas_policy::LoadError);
|
||||
impl_from_error_for_route!(mas_policy::InstanciateError);
|
||||
impl_from_error_for_route!(mas_policy::EvaluationError);
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
|
@ -39,10 +39,6 @@ pub(crate) enum RouteError {
|
||||
#[error(transparent)]
|
||||
Internal(Box<dyn std::error::Error + Send + Sync>),
|
||||
|
||||
// TODO: remove this, needed because of mas_policy
|
||||
#[error(transparent)]
|
||||
Anyhow(#[from] anyhow::Error),
|
||||
|
||||
#[error("invalid redirect uri")]
|
||||
InvalidRedirectUri,
|
||||
|
||||
@ -54,6 +50,10 @@ pub(crate) enum RouteError {
|
||||
}
|
||||
|
||||
impl_from_error_for_route!(sqlx::Error);
|
||||
impl_from_error_for_route!(mas_policy::LoadError);
|
||||
impl_from_error_for_route!(mas_policy::InstanciateError);
|
||||
impl_from_error_for_route!(mas_policy::EvaluationError);
|
||||
impl_from_error_for_route!(mas_keystore::aead::Error);
|
||||
|
||||
impl From<ClientMetadataVerificationError> for RouteError {
|
||||
fn from(e: ClientMetadataVerificationError) -> Self {
|
||||
@ -70,7 +70,7 @@ impl From<ClientMetadataVerificationError> for RouteError {
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
match self {
|
||||
Self::Internal(_) | Self::Anyhow(_) => (
|
||||
Self::Internal(_) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ClientError::from(ClientErrorCode::ServerError)),
|
||||
)
|
||||
|
@ -12,10 +12,11 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use anyhow::Context;
|
||||
use std::string::FromUtf8Error;
|
||||
|
||||
use mas_data_model::UpstreamOAuthProvider;
|
||||
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
|
||||
use mas_keystore::{Encrypter, Keystore};
|
||||
use mas_keystore::{DecryptError, Encrypter, Keystore};
|
||||
use mas_oidc_client::types::client_credentials::{ClientCredentials, JwtSigningMethod};
|
||||
use thiserror::Error;
|
||||
use url::Url;
|
||||
@ -28,14 +29,21 @@ pub(crate) mod link;
|
||||
use self::cookie::UpstreamSessions as UpstreamSessionsCookie;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[allow(clippy::enum_variant_names)]
|
||||
enum ProviderCredentialsError {
|
||||
#[error("Provider doesn't have a client secret")]
|
||||
MissingClientSecret,
|
||||
|
||||
#[error("Could not decrypt client secret")]
|
||||
DecryptClientSecret {
|
||||
#[from]
|
||||
inner: DecryptError,
|
||||
},
|
||||
|
||||
#[error("Client secret is invalid")]
|
||||
InvalidClientSecret {
|
||||
#[source]
|
||||
inner: anyhow::Error,
|
||||
#[from]
|
||||
inner: FromUtf8Error,
|
||||
},
|
||||
}
|
||||
|
||||
@ -52,13 +60,9 @@ fn client_credentials_for_provider(
|
||||
.encrypted_client_secret
|
||||
.as_deref()
|
||||
.map(|encrypted_client_secret| {
|
||||
encrypter
|
||||
.decrypt_string(encrypted_client_secret)
|
||||
.and_then(|client_secret| {
|
||||
String::from_utf8(client_secret)
|
||||
.context("Client secret contains non-UTF8 bytes")
|
||||
})
|
||||
.map_err(|inner| ProviderCredentialsError::InvalidClientSecret { inner })
|
||||
let decrypted = encrypter.decrypt_string(encrypted_client_secret)?;
|
||||
let decrypted = String::from_utf8(decrypted)?;
|
||||
Ok::<_, ProviderCredentialsError>(decrypted)
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
|
@ -19,6 +19,7 @@ use base64ct::{Base64, Encoding};
|
||||
use chacha20poly1305::{ChaCha20Poly1305, KeyInit};
|
||||
use cookie::Key;
|
||||
use generic_array::GenericArray;
|
||||
use thiserror::Error;
|
||||
|
||||
/// Helps encrypting and decrypting data
|
||||
#[derive(Clone)]
|
||||
@ -33,6 +34,14 @@ impl From<Encrypter> for Key {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("Decryption error")]
|
||||
pub enum DecryptError {
|
||||
Aead(#[from] aead::Error),
|
||||
Base64(#[from] base64ct::Error),
|
||||
Shape,
|
||||
}
|
||||
|
||||
impl Encrypter {
|
||||
/// Creates an [`Encrypter`] out of an encryption key
|
||||
#[must_use]
|
||||
@ -50,7 +59,7 @@ impl Encrypter {
|
||||
/// # Errors
|
||||
///
|
||||
/// Will return `Err` when the payload failed to encrypt
|
||||
pub fn encrypt(&self, nonce: &[u8; 12], decrypted: &[u8]) -> anyhow::Result<Vec<u8>> {
|
||||
pub fn encrypt(&self, nonce: &[u8; 12], decrypted: &[u8]) -> Result<Vec<u8>, aead::Error> {
|
||||
let nonce = GenericArray::from_slice(&nonce[..]);
|
||||
let encrypted = self.aead.encrypt(nonce, decrypted)?;
|
||||
Ok(encrypted)
|
||||
@ -61,7 +70,7 @@ impl Encrypter {
|
||||
/// # Errors
|
||||
///
|
||||
/// Will return `Err` when the payload failed to decrypt
|
||||
pub fn decrypt(&self, nonce: &[u8; 12], encrypted: &[u8]) -> anyhow::Result<Vec<u8>> {
|
||||
pub fn decrypt(&self, nonce: &[u8; 12], encrypted: &[u8]) -> Result<Vec<u8>, aead::Error> {
|
||||
let nonce = GenericArray::from_slice(&nonce[..]);
|
||||
let encrypted = self.aead.decrypt(nonce, encrypted)?;
|
||||
Ok(encrypted)
|
||||
@ -72,7 +81,7 @@ impl Encrypter {
|
||||
/// # Errors
|
||||
///
|
||||
/// Will return `Err` when the payload failed to encrypt
|
||||
pub fn encryt_to_string(&self, decrypted: &[u8]) -> anyhow::Result<String> {
|
||||
pub fn encryt_to_string(&self, decrypted: &[u8]) -> Result<String, aead::Error> {
|
||||
let nonce = rand::random();
|
||||
let encrypted = self.encrypt(&nonce, decrypted)?;
|
||||
let encrypted = [&nonce[..], &encrypted].concat();
|
||||
@ -85,17 +94,16 @@ impl Encrypter {
|
||||
/// # Errors
|
||||
///
|
||||
/// Will return `Err` when the payload failed to decrypt
|
||||
pub fn decrypt_string(&self, encrypted: &str) -> anyhow::Result<Vec<u8>> {
|
||||
pub fn decrypt_string(&self, encrypted: &str) -> Result<Vec<u8>, DecryptError> {
|
||||
let encrypted = Base64::decode_vec(encrypted)?;
|
||||
|
||||
let nonce: &[u8; 12] = encrypted
|
||||
.get(0..12)
|
||||
.ok_or_else(|| anyhow::anyhow!("invalid payload serialization"))?
|
||||
.try_into()?;
|
||||
.ok_or(DecryptError::Shape)?
|
||||
.try_into()
|
||||
.map_err(|_| DecryptError::Shape)?;
|
||||
|
||||
let payload = encrypted
|
||||
.get(12..)
|
||||
.ok_or_else(|| anyhow::anyhow!("invalid payload serialization"))?;
|
||||
let payload = encrypted.get(12..).ok_or(DecryptError::Shape)?;
|
||||
|
||||
let decrypted_client_secret = self.decrypt(nonce, payload)?;
|
||||
|
||||
|
@ -43,7 +43,9 @@ use thiserror::Error;
|
||||
|
||||
mod encrypter;
|
||||
|
||||
pub use self::encrypter::Encrypter;
|
||||
pub use aead;
|
||||
|
||||
pub use self::encrypter::{DecryptError, Encrypter};
|
||||
|
||||
/// Error type used when a key could not be loaded
|
||||
#[derive(Debug, Error)]
|
||||
|
@ -17,7 +17,6 @@
|
||||
#![warn(clippy::pedantic)]
|
||||
#![allow(clippy::missing_errors_doc)]
|
||||
|
||||
use anyhow::bail;
|
||||
use mas_data_model::{AuthorizationGrant, User};
|
||||
use oauth2_types::registration::VerifiedClientMetadata;
|
||||
use opa_wasm::Runtime;
|
||||
@ -41,13 +40,25 @@ pub enum LoadError {
|
||||
Compilation(#[source] anyhow::Error),
|
||||
|
||||
#[error("failed to instantiate a test instance")]
|
||||
Instantiate(#[source] anyhow::Error),
|
||||
Instantiate(#[source] InstanciateError),
|
||||
|
||||
#[cfg(feature = "cache")]
|
||||
#[error("could not load wasmtime cache configuration")]
|
||||
CacheSetup(#[source] anyhow::Error),
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum InstanciateError {
|
||||
#[error("failed to create WASM runtime")]
|
||||
Runtime(#[source] anyhow::Error),
|
||||
|
||||
#[error("missing entrypoint {entrypoint}")]
|
||||
MissingEntrypoint { entrypoint: String },
|
||||
|
||||
#[error("failed to load policy data")]
|
||||
LoadData(#[source] anyhow::Error),
|
||||
}
|
||||
|
||||
pub struct PolicyFactory {
|
||||
engine: Engine,
|
||||
module: Module,
|
||||
@ -58,7 +69,7 @@ pub struct PolicyFactory {
|
||||
}
|
||||
|
||||
impl PolicyFactory {
|
||||
#[tracing::instrument(skip(source), err(Display))]
|
||||
#[tracing::instrument(skip(source), err)]
|
||||
pub async fn load(
|
||||
mut source: impl AsyncRead + std::marker::Unpin,
|
||||
data: serde_json::Value,
|
||||
@ -107,9 +118,11 @@ impl PolicyFactory {
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self), err)]
|
||||
pub async fn instantiate(&self) -> Result<Policy, anyhow::Error> {
|
||||
pub async fn instantiate(&self) -> Result<Policy, InstanciateError> {
|
||||
let mut store = Store::new(&self.engine, ());
|
||||
let runtime = Runtime::new(&mut store, &self.module).await?;
|
||||
let runtime = Runtime::new(&mut store, &self.module)
|
||||
.await
|
||||
.map_err(InstanciateError::Runtime)?;
|
||||
|
||||
// Check that we have the required entrypoints
|
||||
let entrypoints = runtime.entrypoints();
|
||||
@ -120,11 +133,16 @@ impl PolicyFactory {
|
||||
self.authorization_grant_endpoint.as_str(),
|
||||
] {
|
||||
if !entrypoints.contains(e) {
|
||||
bail!("missing entrypoint {e}")
|
||||
return Err(InstanciateError::MissingEntrypoint {
|
||||
entrypoint: e.to_owned(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let instance = runtime.with_data(&mut store, &self.data).await?;
|
||||
let instance = runtime
|
||||
.with_data(&mut store, &self.data)
|
||||
.await
|
||||
.map_err(InstanciateError::LoadData)?;
|
||||
|
||||
Ok(Policy {
|
||||
store,
|
||||
@ -163,6 +181,13 @@ pub struct Policy {
|
||||
authorization_grant_endpoint: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("failed to evaluate policy")]
|
||||
pub enum EvaluationError {
|
||||
Serialization(#[from] serde_json::Error),
|
||||
Evaluation(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
impl Policy {
|
||||
#[tracing::instrument(skip(self, password))]
|
||||
pub async fn evaluate_register(
|
||||
@ -170,7 +195,7 @@ impl Policy {
|
||||
username: &str,
|
||||
password: &str,
|
||||
email: &str,
|
||||
) -> Result<EvaluationResult, anyhow::Error> {
|
||||
) -> Result<EvaluationResult, EvaluationError> {
|
||||
let input = serde_json::json!({
|
||||
"user": {
|
||||
"username": username,
|
||||
@ -191,7 +216,7 @@ impl Policy {
|
||||
pub async fn evaluate_client_registration(
|
||||
&mut self,
|
||||
client_metadata: &VerifiedClientMetadata,
|
||||
) -> Result<EvaluationResult, anyhow::Error> {
|
||||
) -> Result<EvaluationResult, EvaluationError> {
|
||||
let client_metadata = serde_json::to_value(client_metadata)?;
|
||||
let input = serde_json::json!({
|
||||
"client_metadata": client_metadata,
|
||||
@ -214,7 +239,7 @@ impl Policy {
|
||||
&mut self,
|
||||
authorization_grant: &AuthorizationGrant,
|
||||
user: &User,
|
||||
) -> Result<EvaluationResult, anyhow::Error> {
|
||||
) -> Result<EvaluationResult, EvaluationError> {
|
||||
let authorization_grant = serde_json::to_value(authorization_grant)?;
|
||||
let user = serde_json::to_value(user)?;
|
||||
let input = serde_json::json!({
|
||||
|
Reference in New Issue
Block a user