1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

policy: define custom errors and ditch anyhow

This commit is contained in:
Quentin Gliech
2022-12-08 14:07:53 +01:00
parent 68890b7291
commit 13b1ac7c83
8 changed files with 103 additions and 73 deletions

View File

@ -14,7 +14,6 @@
use std::sync::Arc;
use anyhow::anyhow;
use axum::{
extract::{Path, State},
response::{IntoResponse, Response},
@ -44,10 +43,6 @@ pub enum RouteError {
#[error(transparent)]
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
// TODO: remove this one: needed because mas_policy returns errors from anyhow
#[error(transparent)]
Anyhow(#[from] anyhow::Error),
#[error("authorization grant was not found")]
NotFound,
@ -67,9 +62,6 @@ impl IntoResponse for RouteError {
"authorization grant not in a pending state",
)
.into_response(),
RouteError::Anyhow(e) => {
(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response()
}
RouteError::Internal(e) => {
(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response()
}
@ -79,6 +71,9 @@ impl IntoResponse for RouteError {
impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_storage::DatabaseError);
impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::InstanciateError);
impl_from_error_for_route!(mas_policy::EvaluationError);
impl_from_error_for_route!(super::callback::IntoCallbackDestinationError);
impl_from_error_for_route!(super::callback::CallbackDestinationError);
@ -126,7 +121,6 @@ pub(crate) async fn get(
}
Err(GrantCompletionError::NotPending) => Err(RouteError::NotPending),
Err(GrantCompletionError::Internal(e)) => Err(RouteError::Internal(e)),
Err(GrantCompletionError::Anyhow(e)) => Err(RouteError::Anyhow(e)),
}
}
@ -135,9 +129,6 @@ pub enum GrantCompletionError {
#[error(transparent)]
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error(transparent)]
Anyhow(#[from] anyhow::Error),
#[error("authorization grant is not in a pending state")]
NotPending,
@ -154,6 +145,9 @@ pub enum GrantCompletionError {
impl_from_error_for_route!(GrantCompletionError: sqlx::Error);
impl_from_error_for_route!(GrantCompletionError: mas_storage::DatabaseError);
impl_from_error_for_route!(GrantCompletionError: super::callback::IntoCallbackDestinationError);
impl_from_error_for_route!(GrantCompletionError: mas_policy::LoadError);
impl_from_error_for_route!(GrantCompletionError: mas_policy::InstanciateError);
impl_from_error_for_route!(GrantCompletionError: mas_policy::EvaluationError);
pub(crate) async fn complete(
grant: AuthorizationGrant,
@ -214,7 +208,9 @@ pub(crate) async fn complete(
// Did they request an ID token?
if grant.response_type_id_token {
// TODO
return Err(anyhow!("id tokens are not implemented yet").into());
return Err(GrantCompletionError::Internal(
"ID tokens are not implemented yet".into(),
));
}
txn.commit().await?;

View File

@ -14,7 +14,6 @@
use std::sync::Arc;
use anyhow::{anyhow, Context};
use axum::{
extract::{Form, State},
response::{IntoResponse, Response},
@ -52,13 +51,12 @@ pub enum RouteError {
#[error(transparent)]
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
// TODO: remove this one
#[error(transparent)]
Anyhow(#[from] anyhow::Error),
#[error("could not find client")]
ClientNotFound,
#[error("invalid response mode")]
InvalidResponseMode,
#[error("invalid parameters")]
IntoCallbackDestination(#[from] self::callback::IntoCallbackDestinationError),
@ -73,12 +71,12 @@ impl IntoResponse for RouteError {
RouteError::Internal(e) => {
(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response()
}
RouteError::Anyhow(e) => {
(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response()
}
RouteError::ClientNotFound => {
(StatusCode::BAD_REQUEST, "could not find client").into_response()
}
RouteError::InvalidResponseMode => {
(StatusCode::BAD_REQUEST, "invalid response mode").into_response()
}
RouteError::IntoCallbackDestination(e) => {
(StatusCode::BAD_REQUEST, e.to_string()).into_response()
}
@ -94,6 +92,9 @@ impl IntoResponse for RouteError {
impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_storage::DatabaseError);
impl_from_error_for_route!(self::callback::CallbackDestinationError);
impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::InstanciateError);
impl_from_error_for_route!(mas_policy::EvaluationError);
#[derive(Deserialize)]
pub(crate) struct Params {
@ -110,7 +111,7 @@ pub(crate) struct Params {
fn resolve_response_mode(
response_type: &ResponseType,
suggested_response_mode: Option<ResponseMode>,
) -> anyhow::Result<ResponseMode> {
) -> Result<ResponseMode, RouteError> {
use ResponseMode as M;
// If the response type includes either "token" or "id_token", the default
@ -119,7 +120,7 @@ fn resolve_response_mode(
if response_type.has_token() || response_type.has_id_token() {
match suggested_response_mode {
None => Ok(M::Fragment),
Some(M::Query) => Err(anyhow!("invalid response mode")),
Some(M::Query) => Err(RouteError::InvalidResponseMode),
Some(mode) => Ok(mode),
}
} else {
@ -166,10 +167,7 @@ pub(crate) async fn get(
let templates = templates.clone();
let callback_destination = callback_destination.clone();
async move {
let maybe_session = session_info
.load_session(&mut txn)
.await
.context("failed to load browser session")?;
let maybe_session = session_info.load_session(&mut txn).await?;
let prompt = params.auth.prompt.as_deref().unwrap_or_default();
// Check if the request/request_uri/registration params are used. If so, reply
@ -356,13 +354,12 @@ pub(crate) async fn get(
.go(&templates, ClientError::from(ClientErrorCode::AccessDenied))
.await?
}
Err(GrantCompletionError::Anyhow(a)) => return Err(RouteError::Anyhow(a)),
Err(GrantCompletionError::Internal(e)) => {
return Err(RouteError::Internal(e))
}
Err(GrantCompletionError::NotPending) => {
Err(e @ GrantCompletionError::NotPending) => {
// This should never happen
return Err(anyhow!("authorization grant is not pending").into());
return Err(RouteError::Internal(Box::new(e)));
}
}
}
@ -387,13 +384,12 @@ pub(crate) async fn get(
.go()
.into_response()
}
Err(GrantCompletionError::Anyhow(a)) => return Err(RouteError::Anyhow(a)),
Err(GrantCompletionError::Internal(e)) => {
return Err(RouteError::Internal(e))
}
Err(GrantCompletionError::NotPending) => {
Err(e @ GrantCompletionError::NotPending) => {
// This should never happen
return Err(anyhow!("authorization grant is not pending").into());
return Err(RouteError::Internal(Box::new(e)));
}
}
}

View File

@ -44,10 +44,6 @@ pub enum RouteError {
#[error(transparent)]
Internal(Box<dyn std::error::Error + Send + Sync>),
// TODO: remove this one, needed because of mas_policy
#[error(transparent)]
Anyhow(#[from] anyhow::Error),
#[error(transparent)]
Csrf(#[from] mas_axum_utils::csrf::CsrfError),
@ -64,6 +60,9 @@ pub enum RouteError {
impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_templates::TemplateError);
impl_from_error_for_route!(mas_storage::DatabaseError);
impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::InstanciateError);
impl_from_error_for_route!(mas_policy::EvaluationError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {

View File

@ -39,10 +39,6 @@ pub(crate) enum RouteError {
#[error(transparent)]
Internal(Box<dyn std::error::Error + Send + Sync>),
// TODO: remove this, needed because of mas_policy
#[error(transparent)]
Anyhow(#[from] anyhow::Error),
#[error("invalid redirect uri")]
InvalidRedirectUri,
@ -54,6 +50,10 @@ pub(crate) enum RouteError {
}
impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::InstanciateError);
impl_from_error_for_route!(mas_policy::EvaluationError);
impl_from_error_for_route!(mas_keystore::aead::Error);
impl From<ClientMetadataVerificationError> for RouteError {
fn from(e: ClientMetadataVerificationError) -> Self {
@ -70,7 +70,7 @@ impl From<ClientMetadataVerificationError> for RouteError {
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
match self {
Self::Internal(_) | Self::Anyhow(_) => (
Self::Internal(_) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(ClientError::from(ClientErrorCode::ServerError)),
)

View File

@ -12,10 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::Context;
use std::string::FromUtf8Error;
use mas_data_model::UpstreamOAuthProvider;
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use mas_keystore::{Encrypter, Keystore};
use mas_keystore::{DecryptError, Encrypter, Keystore};
use mas_oidc_client::types::client_credentials::{ClientCredentials, JwtSigningMethod};
use thiserror::Error;
use url::Url;
@ -28,14 +29,21 @@ pub(crate) mod link;
use self::cookie::UpstreamSessions as UpstreamSessionsCookie;
#[derive(Debug, Error)]
#[allow(clippy::enum_variant_names)]
enum ProviderCredentialsError {
#[error("Provider doesn't have a client secret")]
MissingClientSecret,
#[error("Could not decrypt client secret")]
DecryptClientSecret {
#[from]
inner: DecryptError,
},
#[error("Client secret is invalid")]
InvalidClientSecret {
#[source]
inner: anyhow::Error,
#[from]
inner: FromUtf8Error,
},
}
@ -52,13 +60,9 @@ fn client_credentials_for_provider(
.encrypted_client_secret
.as_deref()
.map(|encrypted_client_secret| {
encrypter
.decrypt_string(encrypted_client_secret)
.and_then(|client_secret| {
String::from_utf8(client_secret)
.context("Client secret contains non-UTF8 bytes")
})
.map_err(|inner| ProviderCredentialsError::InvalidClientSecret { inner })
let decrypted = encrypter.decrypt_string(encrypted_client_secret)?;
let decrypted = String::from_utf8(decrypted)?;
Ok::<_, ProviderCredentialsError>(decrypted)
})
.transpose()?;

View File

@ -19,6 +19,7 @@ use base64ct::{Base64, Encoding};
use chacha20poly1305::{ChaCha20Poly1305, KeyInit};
use cookie::Key;
use generic_array::GenericArray;
use thiserror::Error;
/// Helps encrypting and decrypting data
#[derive(Clone)]
@ -33,6 +34,14 @@ impl From<Encrypter> for Key {
}
}
#[derive(Debug, Error)]
#[error("Decryption error")]
pub enum DecryptError {
Aead(#[from] aead::Error),
Base64(#[from] base64ct::Error),
Shape,
}
impl Encrypter {
/// Creates an [`Encrypter`] out of an encryption key
#[must_use]
@ -50,7 +59,7 @@ impl Encrypter {
/// # Errors
///
/// Will return `Err` when the payload failed to encrypt
pub fn encrypt(&self, nonce: &[u8; 12], decrypted: &[u8]) -> anyhow::Result<Vec<u8>> {
pub fn encrypt(&self, nonce: &[u8; 12], decrypted: &[u8]) -> Result<Vec<u8>, aead::Error> {
let nonce = GenericArray::from_slice(&nonce[..]);
let encrypted = self.aead.encrypt(nonce, decrypted)?;
Ok(encrypted)
@ -61,7 +70,7 @@ impl Encrypter {
/// # Errors
///
/// Will return `Err` when the payload failed to decrypt
pub fn decrypt(&self, nonce: &[u8; 12], encrypted: &[u8]) -> anyhow::Result<Vec<u8>> {
pub fn decrypt(&self, nonce: &[u8; 12], encrypted: &[u8]) -> Result<Vec<u8>, aead::Error> {
let nonce = GenericArray::from_slice(&nonce[..]);
let encrypted = self.aead.decrypt(nonce, encrypted)?;
Ok(encrypted)
@ -72,7 +81,7 @@ impl Encrypter {
/// # Errors
///
/// Will return `Err` when the payload failed to encrypt
pub fn encryt_to_string(&self, decrypted: &[u8]) -> anyhow::Result<String> {
pub fn encryt_to_string(&self, decrypted: &[u8]) -> Result<String, aead::Error> {
let nonce = rand::random();
let encrypted = self.encrypt(&nonce, decrypted)?;
let encrypted = [&nonce[..], &encrypted].concat();
@ -85,17 +94,16 @@ impl Encrypter {
/// # Errors
///
/// Will return `Err` when the payload failed to decrypt
pub fn decrypt_string(&self, encrypted: &str) -> anyhow::Result<Vec<u8>> {
pub fn decrypt_string(&self, encrypted: &str) -> Result<Vec<u8>, DecryptError> {
let encrypted = Base64::decode_vec(encrypted)?;
let nonce: &[u8; 12] = encrypted
.get(0..12)
.ok_or_else(|| anyhow::anyhow!("invalid payload serialization"))?
.try_into()?;
.ok_or(DecryptError::Shape)?
.try_into()
.map_err(|_| DecryptError::Shape)?;
let payload = encrypted
.get(12..)
.ok_or_else(|| anyhow::anyhow!("invalid payload serialization"))?;
let payload = encrypted.get(12..).ok_or(DecryptError::Shape)?;
let decrypted_client_secret = self.decrypt(nonce, payload)?;

View File

@ -43,7 +43,9 @@ use thiserror::Error;
mod encrypter;
pub use self::encrypter::Encrypter;
pub use aead;
pub use self::encrypter::{DecryptError, Encrypter};
/// Error type used when a key could not be loaded
#[derive(Debug, Error)]

View File

@ -17,7 +17,6 @@
#![warn(clippy::pedantic)]
#![allow(clippy::missing_errors_doc)]
use anyhow::bail;
use mas_data_model::{AuthorizationGrant, User};
use oauth2_types::registration::VerifiedClientMetadata;
use opa_wasm::Runtime;
@ -41,13 +40,25 @@ pub enum LoadError {
Compilation(#[source] anyhow::Error),
#[error("failed to instantiate a test instance")]
Instantiate(#[source] anyhow::Error),
Instantiate(#[source] InstanciateError),
#[cfg(feature = "cache")]
#[error("could not load wasmtime cache configuration")]
CacheSetup(#[source] anyhow::Error),
}
#[derive(Debug, Error)]
pub enum InstanciateError {
#[error("failed to create WASM runtime")]
Runtime(#[source] anyhow::Error),
#[error("missing entrypoint {entrypoint}")]
MissingEntrypoint { entrypoint: String },
#[error("failed to load policy data")]
LoadData(#[source] anyhow::Error),
}
pub struct PolicyFactory {
engine: Engine,
module: Module,
@ -58,7 +69,7 @@ pub struct PolicyFactory {
}
impl PolicyFactory {
#[tracing::instrument(skip(source), err(Display))]
#[tracing::instrument(skip(source), err)]
pub async fn load(
mut source: impl AsyncRead + std::marker::Unpin,
data: serde_json::Value,
@ -107,9 +118,11 @@ impl PolicyFactory {
}
#[tracing::instrument(skip(self), err)]
pub async fn instantiate(&self) -> Result<Policy, anyhow::Error> {
pub async fn instantiate(&self) -> Result<Policy, InstanciateError> {
let mut store = Store::new(&self.engine, ());
let runtime = Runtime::new(&mut store, &self.module).await?;
let runtime = Runtime::new(&mut store, &self.module)
.await
.map_err(InstanciateError::Runtime)?;
// Check that we have the required entrypoints
let entrypoints = runtime.entrypoints();
@ -120,11 +133,16 @@ impl PolicyFactory {
self.authorization_grant_endpoint.as_str(),
] {
if !entrypoints.contains(e) {
bail!("missing entrypoint {e}")
return Err(InstanciateError::MissingEntrypoint {
entrypoint: e.to_owned(),
});
}
}
let instance = runtime.with_data(&mut store, &self.data).await?;
let instance = runtime
.with_data(&mut store, &self.data)
.await
.map_err(InstanciateError::LoadData)?;
Ok(Policy {
store,
@ -163,6 +181,13 @@ pub struct Policy {
authorization_grant_endpoint: String,
}
#[derive(Debug, Error)]
#[error("failed to evaluate policy")]
pub enum EvaluationError {
Serialization(#[from] serde_json::Error),
Evaluation(#[from] anyhow::Error),
}
impl Policy {
#[tracing::instrument(skip(self, password))]
pub async fn evaluate_register(
@ -170,7 +195,7 @@ impl Policy {
username: &str,
password: &str,
email: &str,
) -> Result<EvaluationResult, anyhow::Error> {
) -> Result<EvaluationResult, EvaluationError> {
let input = serde_json::json!({
"user": {
"username": username,
@ -191,7 +216,7 @@ impl Policy {
pub async fn evaluate_client_registration(
&mut self,
client_metadata: &VerifiedClientMetadata,
) -> Result<EvaluationResult, anyhow::Error> {
) -> Result<EvaluationResult, EvaluationError> {
let client_metadata = serde_json::to_value(client_metadata)?;
let input = serde_json::json!({
"client_metadata": client_metadata,
@ -214,7 +239,7 @@ impl Policy {
&mut self,
authorization_grant: &AuthorizationGrant,
user: &User,
) -> Result<EvaluationResult, anyhow::Error> {
) -> Result<EvaluationResult, EvaluationError> {
let authorization_grant = serde_json::to_value(authorization_grant)?;
let user = serde_json::to_value(user)?;
let input = serde_json::json!({