diff --git a/crates/core/src/filters/cookies.rs b/crates/core/src/filters/cookies.rs index cb2dcfe5..dafbce3c 100644 --- a/crates/core/src/filters/cookies.rs +++ b/crates/core/src/filters/cookies.rs @@ -23,8 +23,12 @@ use data_encoding::BASE64URL_NOPAD; use headers::{Header, HeaderValue, SetCookie}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use thiserror::Error; -use warp::{reject::Reject, Filter, Rejection, Reply}; +use warp::{ + reject::{MissingCookie, Reject}, + Filter, Rejection, Reply, +}; +use super::none_on_error; use crate::{ config::CookiesConfig, errors::WrapError, @@ -108,17 +112,16 @@ impl EncryptedCookie { #[must_use] pub fn maybe_encrypted( options: &CookiesConfig, -) -> impl Filter,), Error = Infallible> + Clone + Send + Sync + 'static +) -> impl Filter,), Error = Rejection> + Clone + Send + Sync + 'static where T: DeserializeOwned + EncryptableCookieValue + 'static, { - encrypted(options).map(Some).recover(recover::).unify() -} - -async fn recover(_rejection: Rejection) -> Result, Infallible> { - // We could actually look for MissingCookie and CookieDecryptionError - // rejections, but nothing else should happen here anyway - Ok(None) + encrypted(options) + .map(Some) + .recover(none_on_error::) + .unify() + .recover(none_on_error::>) + .unify() } /// Extract an encrypted cookie @@ -143,6 +146,7 @@ where }) } +/// Get an [`EncryptedCookieSaver`] to help saving an [`EncryptableCookieValue`] #[must_use] pub fn encrypted_cookie_saver( options: &CookiesConfig, @@ -153,7 +157,7 @@ pub fn encrypted_cookie_saver( } /// A cookie that can be encrypted with a well-known cookie key -pub trait EncryptableCookieValue: Send + Sync + std::fmt::Debug { +pub trait EncryptableCookieValue: Serialize + Send + Sync + std::fmt::Debug { fn cookie_key() -> &'static str; } @@ -163,7 +167,7 @@ pub struct EncryptedCookieSaver { } impl EncryptedCookieSaver { - pub fn save_encrypted( + pub fn save_encrypted( &self, cookie: &T, reply: R, diff --git a/crates/core/src/filters/headers.rs b/crates/core/src/filters/headers.rs index 78d8ae32..1fc43cce 100644 --- a/crates/core/src/filters/headers.rs +++ b/crates/core/src/filters/headers.rs @@ -16,6 +16,7 @@ use headers::{Header, HeaderValue}; use thiserror::Error; use warp::{reject::Reject, Filter, Rejection}; +/// Failed to decode typed header #[derive(Debug, Error)] #[error("could not decode header {1}")] pub struct InvalidTypedHeader(#[source] headers::Error, &'static str); diff --git a/crates/core/src/filters/mod.rs b/crates/core/src/filters/mod.rs index eceb18d6..847a68f4 100644 --- a/crates/core/src/filters/mod.rs +++ b/crates/core/src/filters/mod.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! Set of [`warp`] filters + #![allow(clippy::unused_async)] // Some warp filters need that pub mod csrf; @@ -25,7 +27,7 @@ pub mod session; use std::convert::Infallible; -use warp::Filter; +use warp::{Filter, Rejection}; pub use self::csrf::CsrfToken; use crate::{ @@ -48,3 +50,30 @@ pub fn with_keys( let keyset = oauth2_config.keys.clone(); warp::any().map(move || keyset.clone()) } + +/// Recover a particular rejection type with a `None` option variant +/// +/// # Example +/// +/// ```rust +/// extern crate warp; +/// +/// use warp::{filters::header::header, reject::MissingHeader, Filter}; +/// +/// use mas_core::filters::none_on_error; +/// +/// header("Content-Length") +/// .map(Some) +/// .recover(none_on_error::<_, MissingHeader>) +/// .unify() +/// .map(|length: Option| { +/// format!("header: {:?}", length) +/// }); +/// ``` +pub async fn none_on_error(rejection: Rejection) -> Result, Rejection> { + if rejection.find::().is_some() { + Ok(None) + } else { + Err(rejection) + } +} diff --git a/crates/core/src/filters/session.rs b/crates/core/src/filters/session.rs index 10b7a6bb..44b005c7 100644 --- a/crates/core/src/filters/session.rs +++ b/crates/core/src/filters/session.rs @@ -14,18 +14,37 @@ use serde::{Deserialize, Serialize}; use sqlx::{pool::PoolConnection, Executor, PgPool, Postgres}; -use warp::{Filter, Rejection}; +use thiserror::Error; +use tracing::warn; +use warp::{ + reject::{MissingCookie, Reject}, + Filter, Rejection, +}; use super::{ - cookies::{encrypted, maybe_encrypted, EncryptableCookieValue}, + cookies::{encrypted, CookieDecryptionError, EncryptableCookieValue}, database::connection, + none_on_error, }; use crate::{ config::CookiesConfig, - errors::WrapError, - storage::{lookup_active_session, SessionInfo}, + storage::{lookup_active_session, user::ActiveSessionLookupError, SessionInfo}, }; +#[derive(Error, Debug)] +pub enum SessionLoadError { + #[error("missing session cookie")] + MissingCookie, + + #[error("unable to parse or decrypt session cookie")] + InvalidCookie, + + #[error("unknown or inactive session")] + UnknownSession, +} + +impl Reject for SessionLoadError {} + #[derive(Serialize, Deserialize, Debug)] pub struct SessionCookie { current: i64, @@ -42,7 +61,7 @@ impl SessionCookie { pub async fn load_session_info( &self, executor: impl Executor<'_, Database = Postgres>, - ) -> anyhow::Result { + ) -> Result { let res = lookup_active_session(executor, self.current).await?; Ok(res) } @@ -61,31 +80,59 @@ pub fn optional_session( cookies_config: &CookiesConfig, ) -> impl Filter,), Error = Rejection> + Clone + Send + Sync + 'static { - maybe_encrypted(cookies_config) - .and(connection(pool)) - .and_then( - |maybe_session: Option, mut conn: PoolConnection| async move { - let maybe_session_info = if let Some(session) = maybe_session { - session.load_session_info(&mut conn).await.ok() - } else { - None - }; - Ok::<_, Rejection>(maybe_session_info) - }, - ) + session(pool, cookies_config) + .map(Some) + .recover(none_on_error::<_, SessionLoadError>) + .unify() } /// Extract a user session information, rejecting if not logged in +/// +/// # Rejections +/// +/// This filter will reject with a [`SessionLoadError`] when the session is +/// inactive or missing. It will reject with a wrapped error on other database +/// failures. #[must_use] pub fn session( pool: &PgPool, cookies_config: &CookiesConfig, ) -> impl Filter + Clone + Send + Sync + 'static { - // TODO: this should be wrapped up in a recoverable error - encrypted(cookies_config).and(connection(pool)).and_then( - |session: SessionCookie, mut conn: PoolConnection| async move { - let session_info = session.load_session_info(&mut conn).await.wrap_error()?; - Ok::<_, Rejection>(session_info) - }, - ) + encrypted(cookies_config) + .and(connection(pool)) + .and_then(load_session) + .recover(recover) + .unify() +} + +async fn load_session( + session: SessionCookie, + mut conn: PoolConnection, +) -> Result { + let session_info = session.load_session_info(&mut conn).await?; + Ok(session_info) +} + +/// Recover from expected rejections, to transform them into a +/// [`SessionLoadError`] +async fn recover(rejection: Rejection) -> Result { + if let Some(e) = rejection.find::() { + if e.not_found() { + return Err(warp::reject::custom(SessionLoadError::UnknownSession)); + } + + // If we're here, there is a real database error that should be + // propagated + } + + if let Some(_e) = rejection.find::() { + return Err(warp::reject::custom(SessionLoadError::MissingCookie)); + } + + if let Some(error) = rejection.find::>() { + warn!(?error, "could not decrypt session cookie"); + return Err(warp::reject::custom(SessionLoadError::InvalidCookie)); + } + + Err(rejection) } diff --git a/crates/core/src/storage/user.rs b/crates/core/src/storage/user.rs index a599edeb..b35414f7 100644 --- a/crates/core/src/storage/user.rs +++ b/crates/core/src/storage/user.rs @@ -24,6 +24,7 @@ use sqlx::{Acquire, Executor, FromRow, Postgres, Transaction}; use thiserror::Error; use tokio::task; use tracing::{info_span, Instrument}; +use warp::reject::Reject; use crate::errors::HtmlError; @@ -142,14 +143,14 @@ pub async fn login( #[error("could not fetch session")] pub struct ActiveSessionLookupError(#[from] sqlx::Error); -/* +impl Reject for ActiveSessionLookupError {} + impl ActiveSessionLookupError { #[must_use] pub fn not_found(&self) -> bool { matches!(self.0, sqlx::Error::RowNotFound) } } -*/ pub async fn lookup_active_session( executor: impl Executor<'_, Database = Postgres>,