1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Simplify error handling in user-facing routes

This commit is contained in:
Quentin Gliech
2022-05-10 16:51:12 +02:00
parent 2cba5e7ad2
commit ca7b26cf18
17 changed files with 192 additions and 383 deletions

View File

@ -12,92 +12,32 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::{convert::Infallible, error::Error};
use async_trait::async_trait;
use axum::{ use axum::{
body::{HttpBody, StreamBody},
extract::{Extension, FromRequest, RequestParts},
http::StatusCode, http::StatusCode,
response::{IntoResponse, Response}, response::{IntoResponse, Response},
Extension,
}; };
use futures_util::FutureExt; use mas_templates::ErrorContext;
use headers::{ContentType, HeaderMapExt};
use mas_templates::{ErrorContext, Templates};
use sqlx::PgPool;
struct DatabaseConnection(sqlx::pool::PoolConnection<sqlx::Postgres>);
#[async_trait]
impl<B> FromRequest<B> for DatabaseConnection
where
B: Send,
{
type Rejection = FancyError;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let Extension(templates) = Extension::<Templates>::from_request(req)
.await
.map_err(internal_error)?;
let Extension(pool) = Extension::<PgPool>::from_request(req)
.await
.map_err(fancy_error(templates))?;
let conn = pool.acquire().await.map_err(internal_error)?;
Ok(Self(conn))
}
}
pub fn fancy_error<E: std::fmt::Display + 'static>(
templates: Templates,
) -> impl Fn(E) -> FancyError {
move |error: E| FancyError {
templates: Some(templates.clone()),
error: Box::new(error),
}
}
pub fn internal_error<E: Error + 'static>(error: E) -> FancyError
where
E: Error,
{
FancyError {
templates: None,
error: Box::new(error),
}
}
pub struct FancyError { pub struct FancyError {
templates: Option<Templates>, context: ErrorContext,
error: Box<dyn std::fmt::Display>, }
impl<E: std::fmt::Display> From<E> for FancyError {
fn from(err: E) -> Self {
let context = ErrorContext::new().with_description(err.to_string());
FancyError { context }
}
} }
impl IntoResponse for FancyError { impl IntoResponse for FancyError {
fn into_response(self) -> Response { fn into_response(self) -> Response {
let error = format!("{}", self.error); let error = format!("{:?}", self.context);
let context = ErrorContext::new().with_description(error.clone()); (
let body = match self.templates { StatusCode::INTERNAL_SERVER_ERROR,
Some(templates) => { Extension(self.context),
let stream = (async move { error,
Ok::<_, Infallible>(match templates.render_error(&context).await { )
Ok(s) => s, .into_response()
Err(_e) => "failed to render error template".to_string(),
})
})
.into_stream();
StreamBody::new(stream).boxed_unsync()
}
None => axum::body::Full::from(error)
.map_err(|_e| unreachable!())
.boxed_unsync(),
};
let mut res = Response::new(body);
*res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
res.headers_mut().typed_insert(ContentType::html());
res
} }
} }

View File

@ -21,6 +21,6 @@ pub mod user_authorization;
pub use self::{ pub use self::{
cookies::CookieExt, cookies::CookieExt,
fancy_error::{fancy_error, internal_error, FancyError}, fancy_error::FancyError,
session::{SessionInfo, SessionInfoExt}, session::{SessionInfo, SessionInfoExt},
}; };

View File

@ -32,6 +32,7 @@ use mas_storage::{
}; };
use serde::{de::DeserializeOwned, Deserialize}; use serde::{de::DeserializeOwned, Deserialize};
use sqlx::{Acquire, Postgres}; use sqlx::{Acquire, Postgres};
use thiserror::Error;
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct AuthorizedForm<F> { struct AuthorizedForm<F> {
@ -111,10 +112,18 @@ pub enum UserAuthorizationError {
InternalError(Box<dyn Error>), InternalError(Box<dyn Error>),
} }
#[derive(Debug, Error)]
pub enum AuthorizationVerificationError { pub enum AuthorizationVerificationError {
#[error("missing token")]
MissingToken, MissingToken,
#[error("invalid token")]
InvalidToken, InvalidToken,
#[error("missing form")]
MissingForm, MissingForm,
#[error(transparent)]
InternalError(Box<dyn Error>), InternalError(Box<dyn Error>),
} }

View File

@ -13,19 +13,18 @@
// limitations under the License. // limitations under the License.
use axum::{extract::Extension, response::IntoResponse}; use axum::{extract::Extension, response::IntoResponse};
use mas_axum_utils::{internal_error, FancyError}; use mas_axum_utils::FancyError;
use sqlx::PgPool; use sqlx::PgPool;
use tracing::{info_span, Instrument}; use tracing::{info_span, Instrument};
pub async fn get(Extension(pool): Extension<PgPool>) -> Result<impl IntoResponse, FancyError> { pub async fn get(Extension(pool): Extension<PgPool>) -> Result<impl IntoResponse, FancyError> {
let mut conn = pool.acquire().await.map_err(internal_error)?; let mut conn = pool.acquire().await?;
sqlx::query("SELECT $1") sqlx::query("SELECT $1")
.bind(1_i64) .bind(1_i64)
.execute(&mut conn) .execute(&mut conn)
.instrument(info_span!("DB health")) .instrument(info_span!("DB health"))
.await .await?;
.map_err(internal_error)?;
Ok("ok") Ok("ok")
} }

View File

@ -19,11 +19,12 @@
clippy::unused_async // Some axum handlers need that clippy::unused_async // Some axum handlers need that
)] )]
use std::{sync::Arc, time::Duration}; use std::{convert::Infallible, sync::Arc, time::Duration};
use axum::{ use axum::{
body::HttpBody, body::HttpBody,
extract::Extension, extract::Extension,
response::{Html, IntoResponse},
routing::{get, on, post, MethodFilter}, routing::{get, on, post, MethodFilter},
Router, Router,
}; };
@ -33,8 +34,9 @@ use mas_email::Mailer;
use mas_http::CorsLayerExt; use mas_http::CorsLayerExt;
use mas_jose::StaticKeystore; use mas_jose::StaticKeystore;
use mas_router::{Route, UrlBuilder}; use mas_router::{Route, UrlBuilder};
use mas_templates::Templates; use mas_templates::{ErrorContext, Templates};
use sqlx::PgPool; use sqlx::PgPool;
use tower::util::ThenLayer;
use tower_http::cors::{Any, CorsLayer}; use tower_http::cors::{Any, CorsLayer};
mod health; mod health;
@ -42,6 +44,7 @@ mod oauth2;
mod views; mod views;
#[must_use] #[must_use]
#[allow(clippy::too_many_lines, clippy::missing_panics_doc)]
pub fn router<B>( pub fn router<B>(
pool: &PgPool, pool: &PgPool,
templates: &Templates, templates: &Templates,
@ -102,47 +105,71 @@ where
.max_age(Duration::from_secs(60 * 60)), .max_age(Duration::from_secs(60 * 60)),
); );
Router::new() let human_router = {
.route(mas_router::Index::route(), get(self::views::index::get)) let templates = templates.clone();
.route(mas_router::Healthcheck::route(), get(self::health::get)) Router::new()
.route( .route(mas_router::Index::route(), get(self::views::index::get))
mas_router::Login::route(), .route(mas_router::Healthcheck::route(), get(self::health::get))
get(self::views::login::get).post(self::views::login::post), .route(
) mas_router::Login::route(),
.route(mas_router::Logout::route(), post(self::views::logout::post)) get(self::views::login::get).post(self::views::login::post),
.route( )
mas_router::Reauth::route(), .route(mas_router::Logout::route(), post(self::views::logout::post))
get(self::views::reauth::get).post(self::views::reauth::post), .route(
) mas_router::Reauth::route(),
.route( get(self::views::reauth::get).post(self::views::reauth::post),
mas_router::Register::route(), )
get(self::views::register::get).post(self::views::register::post), .route(
) mas_router::Register::route(),
.route( get(self::views::register::get).post(self::views::register::post),
mas_router::VerifyEmail::route(), )
get(self::views::verify::get), .route(
) mas_router::VerifyEmail::route(),
.route(mas_router::Account::route(), get(self::views::account::get)) get(self::views::verify::get),
.route( )
mas_router::AccountPassword::route(), .route(mas_router::Account::route(), get(self::views::account::get))
get(self::views::account::password::get).post(self::views::account::password::post), .route(
) mas_router::AccountPassword::route(),
.route( get(self::views::account::password::get).post(self::views::account::password::post),
mas_router::AccountEmails::route(), )
get(self::views::account::emails::get).post(self::views::account::emails::post), .route(
) mas_router::AccountEmails::route(),
.route( get(self::views::account::emails::get).post(self::views::account::emails::post),
mas_router::OAuth2AuthorizationEndpoint::route(), )
get(self::oauth2::authorization::get), .route(
) mas_router::OAuth2AuthorizationEndpoint::route(),
.route( get(self::oauth2::authorization::get),
mas_router::ContinueAuthorizationGrant::route(), )
get(self::oauth2::authorization::complete::get), .route(
) mas_router::ContinueAuthorizationGrant::route(),
.route( get(self::oauth2::authorization::complete::get),
mas_router::Consent::route(), )
get(self::oauth2::consent::get).post(self::oauth2::consent::post), .route(
) mas_router::Consent::route(),
get(self::oauth2::consent::get).post(self::oauth2::consent::post),
)
.layer(ThenLayer::new(
move |result: Result<axum::response::Response, Infallible>| async move {
let response = result.unwrap();
if response.status().is_server_error() {
// Error responses should have an ErrorContext attached to them
let ext = response.extensions().get::<ErrorContext>();
if let Some(ctx) = ext {
if let Ok(res) = templates.render_error(ctx).await {
let (mut parts, _original_body) = response.into_parts();
parts.headers.remove(CONTENT_TYPE);
return Ok((parts, Html(res)).into_response());
}
}
}
Ok(response)
},
))
};
human_router
.merge(api_router) .merge(api_router)
.layer(Extension(pool.clone())) .layer(Extension(pool.clone()))
.layer(Extension(templates.clone())) .layer(Extension(templates.clone()))

View File

@ -20,8 +20,7 @@ use axum::{
Json, TypedHeader, Json, TypedHeader,
}; };
use headers::ContentType; use headers::ContentType;
use hyper::StatusCode; use mas_axum_utils::{user_authorization::UserAuthorization, FancyError};
use mas_axum_utils::{internal_error, user_authorization::UserAuthorization};
use mas_jose::{DecodedJsonWebToken, SigningKeystore, StaticKeystore}; use mas_jose::{DecodedJsonWebToken, SigningKeystore, StaticKeystore};
use mas_router::UrlBuilder; use mas_router::UrlBuilder;
use mime::Mime; use mime::Mime;
@ -52,18 +51,11 @@ pub async fn get(
Extension(pool): Extension<PgPool>, Extension(pool): Extension<PgPool>,
Extension(key_store): Extension<Arc<StaticKeystore>>, Extension(key_store): Extension<Arc<StaticKeystore>>,
user_authorization: UserAuthorization, user_authorization: UserAuthorization,
) -> Result<Response, Response> { ) -> Result<Response, FancyError> {
// TODO: error handling // TODO: error handling
let mut conn = pool let mut conn = pool.acquire().await?;
.acquire()
.await
.map_err(internal_error)
.map_err(IntoResponse::into_response)?;
let session = user_authorization let session = user_authorization.protected(&mut conn).await?;
.protected(&mut conn)
.await
.map_err(IntoResponse::into_response)?;
let user = session.browser_session.user; let user = session.browser_session.user;
let mut user_info = UserInfo { let mut user_info = UserInfo {
@ -81,11 +73,7 @@ pub async fn get(
} }
if let Some(alg) = session.client.userinfo_signed_response_alg { if let Some(alg) = session.client.userinfo_signed_response_alg {
let header = key_store let header = key_store.prepare_header(alg).await?;
.prepare_header(alg)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
.map_err(IntoResponse::into_response)?;
let user_info = SignedUserInfo { let user_info = SignedUserInfo {
iss: url_builder.oidc_issuer().to_string(), iss: url_builder.oidc_issuer().to_string(),
@ -94,11 +82,7 @@ pub async fn get(
}; };
let user_info = DecodedJsonWebToken::new(header, user_info); let user_info = DecodedJsonWebToken::new(header, user_info);
let user_info = user_info let user_info = user_info.sign(key_store.as_ref()).await?;
.sign(key_store.as_ref())
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
.map_err(IntoResponse::into_response)?;
let token = user_info.serialize(); let token = user_info.serialize();
let application_jwt: Mime = "application/jwt".parse().unwrap(); let application_jwt: Mime = "application/jwt".parse().unwrap();

View File

@ -20,7 +20,7 @@ use axum_extra::extract::PrivateCookieJar;
use lettre::{message::Mailbox, Address}; use lettre::{message::Mailbox, Address};
use mas_axum_utils::{ use mas_axum_utils::{
csrf::{CsrfExt, ProtectedForm}, csrf::{CsrfExt, ProtectedForm},
fancy_error, FancyError, SessionInfoExt, FancyError, SessionInfoExt,
}; };
use mas_config::Encrypter; use mas_config::Encrypter;
use mas_data_model::{BrowserSession, User, UserEmail}; use mas_data_model::{BrowserSession, User, UserEmail};
@ -53,17 +53,11 @@ pub(crate) async fn get(
Extension(pool): Extension<PgPool>, Extension(pool): Extension<PgPool>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let mut conn = pool let mut conn = pool.acquire().await?;
.acquire()
.await
.map_err(fancy_error(templates.clone()))?;
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info let maybe_session = session_info.load_session(&mut conn).await?;
.load_session(&mut conn)
.await
.map_err(fancy_error(templates.clone()))?;
if let Some(session) = maybe_session { if let Some(session) = maybe_session {
render(templates, session, cookie_jar, &mut conn).await render(templates, session, cookie_jar, &mut conn).await
@ -81,18 +75,13 @@ async fn render(
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(); let (csrf_token, cookie_jar) = cookie_jar.csrf_token();
let emails = get_user_emails(executor, &session.user) let emails = get_user_emails(executor, &session.user).await?;
.await
.map_err(fancy_error(templates.clone()))?;
let ctx = AccountEmailsContext::new(emails) let ctx = AccountEmailsContext::new(emails)
.with_session(session) .with_session(session)
.with_csrf(csrf_token.form_value()); .with_csrf(csrf_token.form_value());
let content = templates let content = templates.render_account_emails(&ctx).await?;
.render_account_emails(&ctx)
.await
.map_err(fancy_error(templates))?;
Ok((cookie_jar, Html(content)).into_response()) Ok((cookie_jar, Html(content)).into_response())
} }
@ -136,14 +125,11 @@ pub(crate) async fn post(
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<ManagementForm>>, Form(form): Form<ProtectedForm<ManagementForm>>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let mut txn = pool.begin().await.map_err(fancy_error(templates.clone()))?; let mut txn = pool.begin().await?;
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info let maybe_session = session_info.load_session(&mut txn).await?;
.load_session(&mut txn)
.await
.map_err(fancy_error(templates.clone()))?;
let mut session = if let Some(session) = maybe_session { let mut session = if let Some(session) = maybe_session {
session session
@ -152,55 +138,39 @@ pub(crate) async fn post(
return Ok((cookie_jar, login.go()).into_response()); return Ok((cookie_jar, login.go()).into_response());
}; };
let form = cookie_jar let form = cookie_jar.verify_form(form)?;
.verify_form(form)
.map_err(fancy_error(templates.clone()))?;
match form { match form {
ManagementForm::Add { email } => { ManagementForm::Add { email } => {
let user_email = add_user_email(&mut txn, &session.user, email) let user_email = add_user_email(&mut txn, &session.user, email).await?;
.await
.map_err(fancy_error(templates.clone()))?;
start_email_verification(&mailer, &url_builder, &mut txn, &session.user, &user_email) start_email_verification(&mailer, &url_builder, &mut txn, &session.user, &user_email)
.await .await?;
.map_err(fancy_error(templates.clone()))?;
} }
ManagementForm::Remove { data } => { ManagementForm::Remove { data } => {
let id = data.parse().map_err(fancy_error(templates.clone()))?; let id = data.parse()?;
let email = get_user_email(&mut txn, &session.user, id) let email = get_user_email(&mut txn, &session.user, id).await?;
.await remove_user_email(&mut txn, email).await?;
.map_err(fancy_error(templates.clone()))?;
remove_user_email(&mut txn, email)
.await
.map_err(fancy_error(templates.clone()))?;
} }
ManagementForm::ResendConfirmation { data } => { ManagementForm::ResendConfirmation { data } => {
let id = data.parse().map_err(fancy_error(templates.clone()))?; let id = data.parse()?;
let user_email = get_user_email(&mut txn, &session.user, id) let user_email = get_user_email(&mut txn, &session.user, id).await?;
.await
.map_err(fancy_error(templates.clone()))?;
start_email_verification(&mailer, &url_builder, &mut txn, &session.user, &user_email) start_email_verification(&mailer, &url_builder, &mut txn, &session.user, &user_email)
.await .await?;
.map_err(fancy_error(templates.clone()))?;
} }
ManagementForm::SetPrimary { data } => { ManagementForm::SetPrimary { data } => {
let id = data.parse().map_err(fancy_error(templates.clone()))?; let id = data.parse()?;
let email = get_user_email(&mut txn, &session.user, id) let email = get_user_email(&mut txn, &session.user, id).await?;
.await set_user_email_as_primary(&mut txn, &email).await?;
.map_err(fancy_error(templates.clone()))?;
set_user_email_as_primary(&mut txn, &email)
.await
.map_err(fancy_error(templates.clone()))?;
session.user.primary_email = Some(email); session.user.primary_email = Some(email);
} }
}; };
let reply = render(templates.clone(), session, cookie_jar, &mut txn).await?; let reply = render(templates.clone(), session, cookie_jar, &mut txn).await?;
txn.commit().await.map_err(fancy_error(templates.clone()))?; txn.commit().await?;
Ok(reply) Ok(reply)
} }

View File

@ -20,7 +20,7 @@ use axum::{
response::{Html, IntoResponse, Response}, response::{Html, IntoResponse, Response},
}; };
use axum_extra::extract::PrivateCookieJar; use axum_extra::extract::PrivateCookieJar;
use mas_axum_utils::{csrf::CsrfExt, fancy_error, FancyError, SessionInfoExt}; use mas_axum_utils::{csrf::CsrfExt, FancyError, SessionInfoExt};
use mas_config::Encrypter; use mas_config::Encrypter;
use mas_router::Route; use mas_router::Route;
use mas_storage::user::{count_active_sessions, get_user_emails}; use mas_storage::user::{count_active_sessions, get_user_emails};
@ -32,18 +32,12 @@ pub(crate) async fn get(
Extension(pool): Extension<PgPool>, Extension(pool): Extension<PgPool>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let mut conn = pool let mut conn = pool.acquire().await?;
.acquire()
.await
.map_err(fancy_error(templates.clone()))?;
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(); let (csrf_token, cookie_jar) = cookie_jar.csrf_token();
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info let maybe_session = session_info.load_session(&mut conn).await?;
.load_session(&mut conn)
.await
.map_err(fancy_error(templates.clone()))?;
let session = if let Some(session) = maybe_session { let session = if let Some(session) = maybe_session {
session session
@ -52,22 +46,15 @@ pub(crate) async fn get(
return Ok((cookie_jar, login.go()).into_response()); return Ok((cookie_jar, login.go()).into_response());
}; };
let active_sessions = count_active_sessions(&mut conn, &session.user) let active_sessions = count_active_sessions(&mut conn, &session.user).await?;
.await
.map_err(fancy_error(templates.clone()))?;
let emails = get_user_emails(&mut conn, &session.user) let emails = get_user_emails(&mut conn, &session.user).await?;
.await
.map_err(fancy_error(templates.clone()))?;
let ctx = AccountContext::new(active_sessions, emails) let ctx = AccountContext::new(active_sessions, emails)
.with_session(session) .with_session(session)
.with_csrf(csrf_token.form_value()); .with_csrf(csrf_token.form_value());
let content = templates let content = templates.render_account_index(&ctx).await?;
.render_account_index(&ctx)
.await
.map_err(fancy_error(templates.clone()))?;
Ok((cookie_jar, Html(content)).into_response()) Ok((cookie_jar, Html(content)).into_response())
} }

View File

@ -20,7 +20,7 @@ use axum::{
use axum_extra::extract::PrivateCookieJar; use axum_extra::extract::PrivateCookieJar;
use mas_axum_utils::{ use mas_axum_utils::{
csrf::{CsrfExt, ProtectedForm}, csrf::{CsrfExt, ProtectedForm},
fancy_error, FancyError, SessionInfoExt, FancyError, SessionInfoExt,
}; };
use mas_config::Encrypter; use mas_config::Encrypter;
use mas_data_model::BrowserSession; use mas_data_model::BrowserSession;
@ -45,17 +45,11 @@ pub(crate) async fn get(
Extension(pool): Extension<PgPool>, Extension(pool): Extension<PgPool>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let mut conn = pool let mut conn = pool.acquire().await?;
.acquire()
.await
.map_err(fancy_error(templates.clone()))?;
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info let maybe_session = session_info.load_session(&mut conn).await?;
.load_session(&mut conn)
.await
.map_err(fancy_error(templates.clone()))?;
if let Some(session) = maybe_session { if let Some(session) = maybe_session {
render(templates, session, cookie_jar).await render(templates, session, cookie_jar).await
@ -76,10 +70,7 @@ async fn render(
.with_session(session) .with_session(session)
.with_csrf(csrf_token.form_value()); .with_csrf(csrf_token.form_value());
let content = templates let content = templates.render_account_password(&ctx).await?;
.render_account_password(&ctx)
.await
.map_err(fancy_error(templates))?;
Ok((cookie_jar, Html(content)).into_response()) Ok((cookie_jar, Html(content)).into_response())
} }
@ -90,18 +81,13 @@ pub(crate) async fn post(
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<ChangeForm>>, Form(form): Form<ProtectedForm<ChangeForm>>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let mut txn = pool.begin().await.map_err(fancy_error(templates.clone()))?; let mut txn = pool.begin().await?;
let form = cookie_jar let form = cookie_jar.verify_form(form)?;
.verify_form(form)
.map_err(fancy_error(templates.clone()))?;
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info let maybe_session = session_info.load_session(&mut txn).await?;
.load_session(&mut txn)
.await
.map_err(fancy_error(templates.clone()))?;
let mut session = if let Some(session) = maybe_session { let mut session = if let Some(session) = maybe_session {
session session
@ -110,23 +96,19 @@ pub(crate) async fn post(
return Ok((cookie_jar, login.go()).into_response()); return Ok((cookie_jar, login.go()).into_response());
}; };
authenticate_session(&mut txn, &mut session, form.current_password) authenticate_session(&mut txn, &mut session, form.current_password).await?;
.await
.map_err(fancy_error(templates.clone()))?;
// TODO: display nice form errors // TODO: display nice form errors
if form.new_password != form.new_password_confirm { if form.new_password != form.new_password_confirm {
return Err(anyhow::anyhow!("password mismatch")).map_err(fancy_error(templates.clone())); return Err(anyhow::anyhow!("password mismatch").into());
} }
let phf = Argon2::default(); let phf = Argon2::default();
set_password(&mut txn, phf, &session.user, &form.new_password) set_password(&mut txn, phf, &session.user, &form.new_password).await?;
.await
.map_err(fancy_error(templates.clone()))?;
let reply = render(templates.clone(), session, cookie_jar).await?; let reply = render(templates.clone(), session, cookie_jar).await?;
txn.commit().await.map_err(fancy_error(templates.clone()))?; txn.commit().await?;
Ok(reply) Ok(reply)
} }

View File

@ -17,7 +17,7 @@ use axum::{
response::{Html, IntoResponse}, response::{Html, IntoResponse},
}; };
use axum_extra::extract::PrivateCookieJar; use axum_extra::extract::PrivateCookieJar;
use mas_axum_utils::{csrf::CsrfExt, fancy_error, FancyError, SessionInfoExt}; use mas_axum_utils::{csrf::CsrfExt, FancyError, SessionInfoExt};
use mas_config::Encrypter; use mas_config::Encrypter;
use mas_router::UrlBuilder; use mas_router::UrlBuilder;
use mas_templates::{IndexContext, TemplateContext, Templates}; use mas_templates::{IndexContext, TemplateContext, Templates};
@ -29,26 +29,17 @@ pub async fn get(
Extension(pool): Extension<PgPool>, Extension(pool): Extension<PgPool>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<impl IntoResponse, FancyError> { ) -> Result<impl IntoResponse, FancyError> {
let mut conn = pool let mut conn = pool.acquire().await?;
.acquire()
.await
.map_err(fancy_error(templates.clone()))?;
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(); let (csrf_token, cookie_jar) = cookie_jar.csrf_token();
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let session = session_info let session = session_info.load_session(&mut conn).await?;
.load_session(&mut conn)
.await
.map_err(fancy_error(templates.clone()))?;
let ctx = IndexContext::new(url_builder.oidc_discovery()) let ctx = IndexContext::new(url_builder.oidc_discovery())
.maybe_with_session(session) .maybe_with_session(session)
.with_csrf(csrf_token.form_value()); .with_csrf(csrf_token.form_value());
let content = templates let content = templates.render_index(&ctx).await?;
.render_index(&ctx)
.await
.map_err(fancy_error(templates))?;
Ok((cookie_jar, Html(content))) Ok((cookie_jar, Html(content)))
} }

View File

@ -19,7 +19,7 @@ use axum::{
use axum_extra::extract::PrivateCookieJar; use axum_extra::extract::PrivateCookieJar;
use mas_axum_utils::{ use mas_axum_utils::{
csrf::{CsrfExt, ProtectedForm}, csrf::{CsrfExt, ProtectedForm},
fancy_error, FancyError, SessionInfoExt, FancyError, SessionInfoExt,
}; };
use mas_config::Encrypter; use mas_config::Encrypter;
use mas_data_model::errors::WrapFormError; use mas_data_model::errors::WrapFormError;
@ -44,28 +44,19 @@ pub(crate) async fn get(
Query(query): Query<OptionalPostAuthAction>, Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let mut conn = pool let mut conn = pool.acquire().await?;
.acquire()
.await
.map_err(fancy_error(templates.clone()))?;
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(); let (csrf_token, cookie_jar) = cookie_jar.csrf_token();
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info let maybe_session = session_info.load_session(&mut conn).await?;
.load_session(&mut conn)
.await
.map_err(fancy_error(templates.clone()))?;
if maybe_session.is_some() { if maybe_session.is_some() {
let reply = query.go_next(); let reply = query.go_next();
Ok((cookie_jar, reply).into_response()) Ok((cookie_jar, reply).into_response())
} else { } else {
let ctx = LoginContext::default(); let ctx = LoginContext::default();
let next = query let next = query.load_context(&mut conn).await?;
.load_context(&mut conn)
.await
.map_err(fancy_error(templates.clone()))?;
let ctx = if let Some(next) = next { let ctx = if let Some(next) = next {
ctx.with_post_action(next) ctx.with_post_action(next)
} else { } else {
@ -76,10 +67,7 @@ pub(crate) async fn get(
.with_register_link(register_link.to_string()) .with_register_link(register_link.to_string())
.with_csrf(csrf_token.form_value()); .with_csrf(csrf_token.form_value());
let content = templates let content = templates.render_login(&ctx).await?;
.render_login(&ctx)
.await
.map_err(fancy_error(templates.clone()))?;
Ok((cookie_jar, Html(content)).into_response()) Ok((cookie_jar, Html(content)).into_response())
} }
@ -93,14 +81,9 @@ pub(crate) async fn post(
Form(form): Form<ProtectedForm<LoginForm>>, Form(form): Form<ProtectedForm<LoginForm>>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
use mas_storage::user::LoginError; use mas_storage::user::LoginError;
let mut conn = pool let mut conn = pool.acquire().await?;
.acquire()
.await
.map_err(fancy_error(templates.clone()))?;
let form = cookie_jar let form = cookie_jar.verify_form(form)?;
.verify_form(form)
.map_err(fancy_error(templates.clone()))?;
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(); let (csrf_token, cookie_jar) = cookie_jar.csrf_token();
@ -121,10 +104,7 @@ pub(crate) async fn post(
.with_form_error(errored_form) .with_form_error(errored_form)
.with_csrf(csrf_token.form_value()); .with_csrf(csrf_token.form_value());
let content = templates let content = templates.render_login(&ctx).await?;
.render_login(&ctx)
.await
.map_err(fancy_error(templates.clone()))?;
Ok((cookie_jar, Html(content)).into_response()) Ok((cookie_jar, Html(content)).into_response())
} }

View File

@ -19,40 +19,31 @@ use axum::{
use axum_extra::extract::PrivateCookieJar; use axum_extra::extract::PrivateCookieJar;
use mas_axum_utils::{ use mas_axum_utils::{
csrf::{CsrfExt, ProtectedForm}, csrf::{CsrfExt, ProtectedForm},
fancy_error, FancyError, SessionInfoExt, FancyError, SessionInfoExt,
}; };
use mas_config::Encrypter; use mas_config::Encrypter;
use mas_storage::user::end_session; use mas_storage::user::end_session;
use mas_templates::Templates;
use sqlx::PgPool; use sqlx::PgPool;
pub(crate) async fn post( pub(crate) async fn post(
Extension(templates): Extension<Templates>,
Extension(pool): Extension<PgPool>, Extension(pool): Extension<PgPool>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<()>>, Form(form): Form<ProtectedForm<()>>,
) -> Result<impl IntoResponse, FancyError> { ) -> Result<impl IntoResponse, FancyError> {
let mut txn = pool.begin().await.map_err(fancy_error(templates.clone()))?; let mut txn = pool.begin().await?;
cookie_jar cookie_jar.verify_form(form)?;
.verify_form(form)
.map_err(fancy_error(templates.clone()))?;
let (session_info, mut cookie_jar) = cookie_jar.session_info(); let (session_info, mut cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info let maybe_session = session_info.load_session(&mut txn).await?;
.load_session(&mut txn)
.await
.map_err(fancy_error(templates.clone()))?;
if let Some(session) = maybe_session { if let Some(session) = maybe_session {
end_session(&mut txn, &session) end_session(&mut txn, &session).await?;
.await
.map_err(fancy_error(templates.clone()))?;
cookie_jar = cookie_jar.update_session_info(&session_info.mark_session_ended()); cookie_jar = cookie_jar.update_session_info(&session_info.mark_session_ended());
} }
txn.commit().await.map_err(fancy_error(templates))?; txn.commit().await?;
Ok((cookie_jar, Redirect::to("/login"))) Ok((cookie_jar, Redirect::to("/login")))
} }

View File

@ -19,7 +19,7 @@ use axum::{
use axum_extra::extract::PrivateCookieJar; use axum_extra::extract::PrivateCookieJar;
use mas_axum_utils::{ use mas_axum_utils::{
csrf::{CsrfExt, ProtectedForm}, csrf::{CsrfExt, ProtectedForm},
fancy_error, FancyError, SessionInfoExt, FancyError, SessionInfoExt,
}; };
use mas_config::Encrypter; use mas_config::Encrypter;
use mas_router::Route; use mas_router::Route;
@ -41,18 +41,12 @@ pub(crate) async fn get(
Query(query): Query<OptionalPostAuthAction>, Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let mut conn = pool let mut conn = pool.acquire().await?;
.acquire()
.await
.map_err(fancy_error(templates.clone()))?;
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(); let (csrf_token, cookie_jar) = cookie_jar.csrf_token();
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info let maybe_session = session_info.load_session(&mut conn).await?;
.load_session(&mut conn)
.await
.map_err(fancy_error(templates.clone()))?;
let session = if let Some(session) = maybe_session { let session = if let Some(session) = maybe_session {
session session
@ -64,10 +58,7 @@ pub(crate) async fn get(
}; };
let ctx = ReauthContext::default(); let ctx = ReauthContext::default();
let next = query let next = query.load_context(&mut conn).await?;
.load_context(&mut conn)
.await
.map_err(fancy_error(templates.clone()))?;
let ctx = if let Some(next) = next { let ctx = if let Some(next) = next {
ctx.with_post_action(next) ctx.with_post_action(next)
} else { } else {
@ -75,33 +66,24 @@ pub(crate) async fn get(
}; };
let ctx = ctx.with_session(session).with_csrf(csrf_token.form_value()); let ctx = ctx.with_session(session).with_csrf(csrf_token.form_value());
let content = templates let content = templates.render_reauth(&ctx).await?;
.render_reauth(&ctx)
.await
.map_err(fancy_error(templates.clone()))?;
Ok((cookie_jar, Html(content)).into_response()) Ok((cookie_jar, Html(content)).into_response())
} }
pub(crate) async fn post( pub(crate) async fn post(
Extension(templates): Extension<Templates>,
Extension(pool): Extension<PgPool>, Extension(pool): Extension<PgPool>,
Query(query): Query<OptionalPostAuthAction>, Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<ReauthForm>>, Form(form): Form<ProtectedForm<ReauthForm>>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let mut txn = pool.begin().await.map_err(fancy_error(templates.clone()))?; let mut txn = pool.begin().await?;
let form = cookie_jar let form = cookie_jar.verify_form(form)?;
.verify_form(form)
.map_err(fancy_error(templates.clone()))?;
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info let maybe_session = session_info.load_session(&mut txn).await?;
.load_session(&mut txn)
.await
.map_err(fancy_error(templates.clone()))?;
let mut session = if let Some(session) = maybe_session { let mut session = if let Some(session) = maybe_session {
session session
@ -113,11 +95,9 @@ pub(crate) async fn post(
}; };
// TODO: recover from errors here // TODO: recover from errors here
authenticate_session(&mut txn, &mut session, form.password) authenticate_session(&mut txn, &mut session, form.password).await?;
.await
.map_err(fancy_error(templates.clone()))?;
let cookie_jar = cookie_jar.set_session(&session); let cookie_jar = cookie_jar.set_session(&session);
txn.commit().await.map_err(fancy_error(templates.clone()))?; txn.commit().await?;
let reply = query.go_next(); let reply = query.go_next();
Ok((cookie_jar, reply).into_response()) Ok((cookie_jar, reply).into_response())

View File

@ -22,7 +22,7 @@ use axum::{
use axum_extra::extract::PrivateCookieJar; use axum_extra::extract::PrivateCookieJar;
use mas_axum_utils::{ use mas_axum_utils::{
csrf::{CsrfExt, ProtectedForm}, csrf::{CsrfExt, ProtectedForm},
fancy_error, FancyError, SessionInfoExt, FancyError, SessionInfoExt,
}; };
use mas_config::Encrypter; use mas_config::Encrypter;
use mas_router::Route; use mas_router::Route;
@ -46,28 +46,19 @@ pub(crate) async fn get(
Query(query): Query<OptionalPostAuthAction>, Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let mut conn = pool let mut conn = pool.acquire().await?;
.acquire()
.await
.map_err(fancy_error(templates.clone()))?;
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(); let (csrf_token, cookie_jar) = cookie_jar.csrf_token();
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info let maybe_session = session_info.load_session(&mut conn).await?;
.load_session(&mut conn)
.await
.map_err(fancy_error(templates.clone()))?;
if maybe_session.is_some() { if maybe_session.is_some() {
let reply = query.go_next(); let reply = query.go_next();
Ok((cookie_jar, reply).into_response()) Ok((cookie_jar, reply).into_response())
} else { } else {
let ctx = RegisterContext::default(); let ctx = RegisterContext::default();
let next = query let next = query.load_context(&mut conn).await?;
.load_context(&mut conn)
.await
.map_err(fancy_error(templates.clone()))?;
let ctx = if let Some(next) = next { let ctx = if let Some(next) = next {
ctx.with_post_action(next) ctx.with_post_action(next)
} else { } else {
@ -77,43 +68,33 @@ pub(crate) async fn get(
let ctx = ctx.with_login_link(login_link.to_string()); let ctx = ctx.with_login_link(login_link.to_string());
let ctx = ctx.with_csrf(csrf_token.form_value()); let ctx = ctx.with_csrf(csrf_token.form_value());
let content = templates let content = templates.render_register(&ctx).await?;
.render_register(&ctx)
.await
.map_err(fancy_error(templates.clone()))?;
Ok((cookie_jar, Html(content)).into_response()) Ok((cookie_jar, Html(content)).into_response())
} }
} }
pub(crate) async fn post( pub(crate) async fn post(
Extension(templates): Extension<Templates>,
Extension(pool): Extension<PgPool>, Extension(pool): Extension<PgPool>,
Query(query): Query<OptionalPostAuthAction>, Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<RegisterForm>>, Form(form): Form<ProtectedForm<RegisterForm>>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
// TODO: display nice form errors // TODO: display nice form errors
let mut txn = pool.begin().await.map_err(fancy_error(templates.clone()))?; let mut txn = pool.begin().await?;
let form = cookie_jar let form = cookie_jar.verify_form(form)?;
.verify_form(form)
.map_err(fancy_error(templates.clone()))?;
if form.password != form.password_confirm { if form.password != form.password_confirm {
return Err(anyhow::anyhow!("password mismatch")).map_err(fancy_error(templates.clone())); return Err(anyhow::anyhow!("password mismatch").into());
} }
let pfh = Argon2::default(); let pfh = Argon2::default();
let user = register_user(&mut txn, pfh, &form.username, &form.password) let user = register_user(&mut txn, pfh, &form.username, &form.password).await?;
.await
.map_err(fancy_error(templates.clone()))?;
let session = start_session(&mut txn, user) let session = start_session(&mut txn, user).await?;
.await
.map_err(fancy_error(templates.clone()))?;
txn.commit().await.map_err(fancy_error(templates.clone()))?; txn.commit().await?;
let cookie_jar = cookie_jar.set_session(&session); let cookie_jar = cookie_jar.set_session(&session);
let reply = query.go_next(); let reply = query.go_next();

View File

@ -18,7 +18,7 @@ use axum::{
}; };
use axum_extra::extract::PrivateCookieJar; use axum_extra::extract::PrivateCookieJar;
use chrono::Duration; use chrono::Duration;
use mas_axum_utils::{csrf::CsrfExt, fancy_error, FancyError, SessionInfoExt}; use mas_axum_utils::{csrf::CsrfExt, FancyError, SessionInfoExt};
use mas_config::Encrypter; use mas_config::Encrypter;
use mas_storage::user::{ use mas_storage::user::{
consume_email_verification, lookup_user_email_verification_code, mark_user_email_as_verified, consume_email_verification, lookup_user_email_verification_code, mark_user_email_as_verified,
@ -32,40 +32,29 @@ pub(crate) async fn get(
Path(code): Path<String>, Path(code): Path<String>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<impl IntoResponse, FancyError> { ) -> Result<impl IntoResponse, FancyError> {
let mut txn = pool.begin().await.map_err(fancy_error(templates.clone()))?; let mut txn = pool.begin().await?;
// TODO: make those 8 hours configurable // TODO: make those 8 hours configurable
let verification = lookup_user_email_verification_code(&mut txn, &code, Duration::hours(8)) let verification =
.await lookup_user_email_verification_code(&mut txn, &code, Duration::hours(8)).await?;
.map_err(fancy_error(templates.clone()))?;
// TODO: display nice errors if the code was already consumed or expired // TODO: display nice errors if the code was already consumed or expired
let verification = consume_email_verification(&mut txn, verification) let verification = consume_email_verification(&mut txn, verification).await?;
.await
.map_err(fancy_error(templates.clone()))?;
let _email = mark_user_email_as_verified(&mut txn, verification.email) let _email = mark_user_email_as_verified(&mut txn, verification.email).await?;
.await
.map_err(fancy_error(templates.clone()))?;
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(); let (csrf_token, cookie_jar) = cookie_jar.csrf_token();
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info let maybe_session = session_info.load_session(&mut txn).await?;
.load_session(&mut txn)
.await
.map_err(fancy_error(templates.clone()))?;
let ctx = EmptyContext let ctx = EmptyContext
.maybe_with_session(maybe_session) .maybe_with_session(maybe_session)
.with_csrf(csrf_token.form_value()); .with_csrf(csrf_token.form_value());
let content = templates let content = templates.render_email_verification_done(&ctx).await?;
.render_email_verification_done(&ctx)
.await
.map_err(fancy_error(templates.clone()))?;
txn.commit().await.map_err(fancy_error(templates.clone()))?; txn.commit().await?;
Ok((cookie_jar, Html(content))) Ok((cookie_jar, Html(content)))
} }

View File

@ -153,7 +153,6 @@ pub fn service<B: HttpBody + Send + 'static>(
let builtin = self::builtin::service(); let builtin = self::builtin::service();
let svc = if let Some(path) = path { let svc = if let Some(path) = path {
// TODO: fallback seems to have issues
let handler = ServeDir::new(path) let handler = ServeDir::new(path)
.append_index_html_on_directories(false) .append_index_html_on_directories(false)
.fallback(builtin); .fallback(builtin);

View File

@ -555,7 +555,7 @@ impl<T> FormPostContext<T> {
} }
/// Context used by the `error.html` template /// Context used by the `error.html` template
#[derive(Default, Serialize)] #[derive(Default, Serialize, Debug, Clone)]
pub struct ErrorContext { pub struct ErrorContext {
code: Option<&'static str>, code: Option<&'static str>,
description: Option<String>, description: Option<String>,