diff --git a/crates/core/src/errors.rs b/crates/core/src/errors.rs index 49dc15ae..d7cb5a9c 100644 --- a/crates/core/src/errors.rs +++ b/crates/core/src/errors.rs @@ -22,6 +22,10 @@ pub struct WrappedError(anyhow::Error); impl warp::reject::Reject for WrappedError {} +pub fn wrapped_error>(e: T) -> impl Reject { + WrappedError(e.into()) +} + pub trait WrapError { fn wrap_error(self) -> Result; } diff --git a/crates/core/src/filters/authenticate.rs b/crates/core/src/filters/authenticate.rs index 504f994c..57f314c3 100644 --- a/crates/core/src/filters/authenticate.rs +++ b/crates/core/src/filters/authenticate.rs @@ -14,16 +14,60 @@ use chrono::Utc; use headers::{authorization::Bearer, Authorization}; +use hyper::StatusCode; use sqlx::{pool::PoolConnection, PgPool, Postgres}; -use warp::{Filter, Rejection}; - -use super::{database::with_connection, headers::with_typed_header}; -use crate::{ - errors::WrapError, - storage::oauth2::access_token::{lookup_access_token, OAuth2AccessTokenLookup}, - tokens, +use thiserror::Error; +use warp::{ + reject::{MissingHeader, Reject}, + reply::{with_header, with_status}, + Filter, Rejection, Reply, }; +use super::{ + database::with_connection, + headers::{with_typed_header, InvalidTypedHeader}, +}; +use crate::{ + errors::wrapped_error, + storage::oauth2::access_token::{ + lookup_access_token, AccessTokenLookupError, OAuth2AccessTokenLookup, + }, + tokens::{self, TokenFormatError, TokenType}, +}; + +#[derive(Debug, Error)] +enum AuthenticationError { + #[error("invalid token format")] + TokenFormat(#[from] TokenFormatError), + + #[error("invalid token type {0:?}, expected an access token")] + WrongTokenType(TokenType), + + #[error("unknown token")] + TokenNotFound(#[source] AccessTokenLookupError), + + #[error("token is not active")] + TokenInactive, + + #[error("token expired")] + TokenExpired, + + #[error("missing authorization header")] + MissingAuthorizationHeader, + + #[error("invalid authorization header")] + InvalidAuthorizationHeader, +} + +impl Reject for AuthenticationError {} + +/// Authenticate a request using an access token as a bearer authorization +/// +/// # Rejections +/// +/// This can reject with either a [`AuthenticationError`] or with a generic +/// wrapped sqlx error. +#[must_use] pub fn with_authentication( pool: &PgPool, ) -> impl Filter + Clone + Send + Sync + 'static @@ -31,6 +75,8 @@ pub fn with_authentication( with_connection(pool) .and(with_typed_header()) .and_then(authenticate) + .recover(recover) + .unify() } async fn authenticate( @@ -38,18 +84,58 @@ async fn authenticate( auth: Authorization, ) -> Result { let token = auth.0.token(); - let token_type = tokens::check(token).wrap_error()?; + let token_type = tokens::check(token).map_err(AuthenticationError::TokenFormat)?; + if token_type != tokens::TokenType::AccessToken { - return Err(anyhow::anyhow!("wrong token type")).wrap_error(); + return Err(AuthenticationError::WrongTokenType(token_type).into()); } - let token = lookup_access_token(&mut conn, token).await.wrap_error()?; - let exp = token.exp(); + let token = lookup_access_token(&mut conn, token).await.map_err(|e| { + if e.not_found() { + // This error happens if the token was not found and should be recovered + warp::reject::custom(AuthenticationError::TokenNotFound(e)) + } else { + // This is a generic database error that we want to propagate + warp::reject::custom(wrapped_error(e)) + } + })?; - // Check it is active and did not expire - if !token.active || exp < Utc::now() { - return Err(anyhow::anyhow!("token expired")).wrap_error(); + if !token.active { + return Err(AuthenticationError::TokenInactive.into()); + } + + if token.exp() < Utc::now() { + return Err(AuthenticationError::TokenExpired.into()); } Ok(token) } + +/// Transform the rejections from the [`with_typed_header`] filter +async fn recover(rejection: Rejection) -> Result { + if rejection.find::().is_some() { + return Err(warp::reject::custom( + AuthenticationError::MissingAuthorizationHeader, + )); + } + + if rejection.find::().is_some() { + return Err(warp::reject::custom( + AuthenticationError::InvalidAuthorizationHeader, + )); + } + + Err(rejection) +} + +pub async fn recover_unauthorized(rejection: Rejection) -> Result { + if rejection.find::().is_some() { + // TODO: have the issuer/realm here + let reply = "invalid token"; + let reply = with_status(reply, StatusCode::UNAUTHORIZED); + let reply = with_header(reply, "WWW-Authenticate", r#"Bearer error="invalid_token""#); + return Ok(reply); + } + + Err(rejection) +} diff --git a/crates/core/src/filters/headers.rs b/crates/core/src/filters/headers.rs index 473d1402..d79c6061 100644 --- a/crates/core/src/filters/headers.rs +++ b/crates/core/src/filters/headers.rs @@ -13,10 +13,10 @@ // limitations under the License. // use headers::{Header, HeaderMapExt, HeaderValue}; -use warp::{Filter, Rejection, Reply}; - -use crate::errors::WrapError; +use thiserror::Error; +use warp::{reject::Reject, Filter, Rejection, Reply}; +/// Add a typed header to a reply pub fn typed_header(header: H, reply: R) -> WithTypedHeader { WithTypedHeader { reply, header } } @@ -38,6 +38,18 @@ where } } +#[derive(Debug, Error)] +#[error("could not decode header {1}")] +pub struct InvalidTypedHeader(#[source] headers::Error, &'static str); + +impl Reject for InvalidTypedHeader {} + +/// Extract a typed header from the request +/// +/// # Rejections +/// +/// This can reject with either a [`warp::reject::MissingHeader`] or a +/// [`InvalidTypedHeader`]. pub fn with_typed_header( ) -> impl Filter + Clone + Send + Sync + 'static { warp::header::value(T::name().as_str()).and_then(decode_typed_header) @@ -45,6 +57,6 @@ pub fn with_typed_header( async fn decode_typed_header(header: HeaderValue) -> Result { let mut it = std::iter::once(&header); - let decoded = T::decode(&mut it).wrap_error()?; + let decoded = T::decode(&mut it).map_err(|e| InvalidTypedHeader(e, T::name().as_str()))?; Ok(decoded) } diff --git a/crates/core/src/handlers/oauth2/userinfo.rs b/crates/core/src/handlers/oauth2/userinfo.rs index 1592f9d2..02d839cc 100644 --- a/crates/core/src/handlers/oauth2/userinfo.rs +++ b/crates/core/src/handlers/oauth2/userinfo.rs @@ -17,7 +17,8 @@ use sqlx::PgPool; use warp::{Filter, Rejection, Reply}; use crate::{ - config::OAuth2Config, filters::authenticate::with_authentication, + config::OAuth2Config, + filters::authenticate::{recover_unauthorized, with_authentication}, storage::oauth2::access_token::OAuth2AccessTokenLookup, }; @@ -34,6 +35,7 @@ pub(super) fn filter( .and(warp::get().or(warp::post()).unify()) .and(with_authentication(pool)) .and_then(userinfo) + .recover(recover_unauthorized) } async fn userinfo(token: OAuth2AccessTokenLookup) -> Result { diff --git a/crates/core/src/storage/oauth2/access_token.rs b/crates/core/src/storage/oauth2/access_token.rs index fd0d737d..b1458837 100644 --- a/crates/core/src/storage/oauth2/access_token.rs +++ b/crates/core/src/storage/oauth2/access_token.rs @@ -18,6 +18,7 @@ use anyhow::Context; use chrono::{DateTime, Duration, Utc}; use serde::Serialize; use sqlx::{Executor, FromRow, Postgres}; +use thiserror::Error; #[derive(FromRow, Serialize)] pub struct OAuth2AccessToken { @@ -73,11 +74,22 @@ impl OAuth2AccessTokenLookup { } } +#[derive(Debug, Error)] +#[error("failed to lookup access token")] +pub struct AccessTokenLookupError(#[from] sqlx::Error); + +impl AccessTokenLookupError { + #[must_use] + pub fn not_found(&self) -> bool { + matches!(self.0, sqlx::Error::RowNotFound) + } +} + pub async fn lookup_access_token( executor: impl Executor<'_, Database = Postgres>, token: &str, -) -> anyhow::Result { - sqlx::query_as!( +) -> Result { + let res = sqlx::query_as!( OAuth2AccessTokenLookup, r#" SELECT @@ -99,8 +111,9 @@ pub async fn lookup_access_token( token, ) .fetch_one(executor) - .await - .context("could not introspect oauth2 access token") + .await?; + + Ok(res) } pub async fn revoke_access_token(