diff --git a/crates/axum-utils/src/csrf.rs b/crates/axum-utils/src/csrf.rs index 3c843bfe..e01f6e4c 100644 --- a/crates/axum-utils/src/csrf.rs +++ b/crates/axum-utils/src/csrf.rs @@ -19,7 +19,7 @@ use serde::{Deserialize, Serialize}; use serde_with::{serde_as, TimestampSeconds}; use thiserror::Error; -use crate::{CookieExt, PrivateCookieJar}; +use crate::{cookies::CookieDecodeError, CookieExt, PrivateCookieJar}; /// Failed to validate CSRF token #[derive(Debug, Error)] @@ -28,6 +28,14 @@ pub enum CsrfError { #[error("CSRF token mismatch")] Mismatch, + /// The token in the form did not match the token in the cookie + #[error("Missing CSRF cookie")] + Missing, + + /// Failed to decode the token + #[error("could not decode CSRF cookie")] + DecodeCookie(#[from] CookieDecodeError), + /// The token expired #[error("CSRF token expired")] Expired, @@ -89,8 +97,18 @@ impl CsrfToken { } } +// A CSRF-protected form +#[derive(Deserialize)] +pub struct ProtectedForm { + csrf: String, + + #[serde(flatten)] + inner: T, +} + pub trait CsrfExt { fn csrf_token(self) -> (CsrfToken, Self); + fn verify_form(&self, form: ProtectedForm) -> Result; } impl CsrfExt for PrivateCookieJar { @@ -108,4 +126,12 @@ impl CsrfExt for PrivateCookieJar { let jar = jar.add(cookie); (new_token, jar) } + + fn verify_form(&self, form: ProtectedForm) -> Result { + let cookie = self.get("csrf").ok_or(CsrfError::Missing)?; + let token: CsrfToken = cookie.decode()?; + let token = token.verify_expiration()?; + token.verify_form_value(&form.csrf)?; + Ok(form.inner) + } } diff --git a/crates/axum-utils/src/fancy_error.rs b/crates/axum-utils/src/fancy_error.rs index 991da794..1baf2f2b 100644 --- a/crates/axum-utils/src/fancy_error.rs +++ b/crates/axum-utils/src/fancy_error.rs @@ -50,7 +50,9 @@ where } } -pub fn fancy_error(templates: Templates) -> impl Fn(E) -> FancyError { +pub fn fancy_error( + templates: Templates, +) -> impl Fn(E) -> FancyError { move |error: E| FancyError { templates: Some(templates.clone()), error: Box::new(error), @@ -69,7 +71,7 @@ where pub struct FancyError { templates: Option, - error: Box, + error: Box, } impl IntoResponse for FancyError { @@ -99,4 +101,3 @@ impl IntoResponse for FancyError { res } } - diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index e963efa1..8e93d3aa 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. +// Copyright 2021, 2022 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ use std::sync::Arc; -use axum::{extract::Extension, routing::get, Router}; +use axum::{body::HttpBody, extract::Extension, routing::get, Router}; use mas_axum_utils::UrlBuilder; use mas_config::{Encrypter, RootConfig}; use mas_email::Mailer; @@ -61,17 +61,26 @@ pub fn root( } #[must_use] -pub fn router( +pub fn router( pool: &PgPool, templates: &Templates, key_store: &Arc, encrypter: &Encrypter, mailer: &Mailer, url_builder: &UrlBuilder, -) -> Router { +) -> Router +where + B: HttpBody + Send + 'static, + ::Data: Send, + ::Error: std::error::Error + Send + Sync, +{ Router::new() .route("/", get(self::views::index::get)) .route("/health", get(self::health::get)) + .route( + "/login", + get(self::views::login::get).post(self::views::login::post), + ) .fallback(mas_static_files::Assets) .layer(Extension(pool.clone())) .layer(Extension(templates.clone())) diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index 654c0935..befde971 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -1,4 +1,4 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. +// Copyright 2021, 2022 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,25 +14,21 @@ #![allow(clippy::trait_duplication_in_bounds)] -use hyper::http::uri::{Parts, PathAndQuery, Uri}; -use mas_config::{CsrfConfig, Encrypter}; -use mas_data_model::{errors::WrapFormError, BrowserSession}; -use mas_storage::{user::login, PostgresqlBackend}; -use mas_templates::{LoginContext, LoginFormField, TemplateContext, Templates}; -use mas_warp_utils::{ - errors::WrapError, - filters::{ - self, - cookies::{encrypted_cookie_saver, EncryptedCookieSaver}, - csrf::{protected_form, updated_csrf_token}, - database::connection, - session::{optional_session, SessionCookie}, - with_templates, CsrfToken, - }, +use axum::{ + extract::{Extension, Form, Query}, + response::{Html, IntoResponse, Redirect, Response}, }; +use hyper::http::uri::{Parts, PathAndQuery, Uri}; +use mas_axum_utils::{ + csrf::{CsrfExt, ProtectedForm}, + fancy_error, FancyError, PrivateCookieJar, SessionInfoExt, +}; +use mas_config::Encrypter; +use mas_data_model::errors::WrapFormError; +use mas_storage::user::login; +use mas_templates::{LoginContext, LoginFormField, TemplateContext, Templates}; use serde::Deserialize; -use sqlx::{pool::PoolConnection, PgPool, Postgres}; -use warp::{filters::BoxedFilter, reply::html, Filter, Rejection, Reply}; +use sqlx::PgPool; use super::{shared::PostAuthAction, RegisterRequest}; @@ -66,100 +62,100 @@ impl LoginRequest { Ok(uri) } - fn redirect(self) -> Result { - let uri = self - .post_auth_action - .as_ref() - .map(PostAuthAction::build_uri) - .transpose() - .wrap_error()? - .unwrap_or_else(|| Uri::from_static("/")); - Ok(warp::redirect::see_other(uri)) + fn redirect(self) -> Result { + let uri = if let Some(action) = self.post_auth_action { + action.build_uri()? + } else { + Uri::from_static("/") + }; + + Ok(Redirect::to(uri)) } } #[derive(Deserialize)] -struct LoginForm { +pub(crate) struct LoginForm { username: String, password: String, } -pub(super) fn filter( - pool: &PgPool, - templates: &Templates, - encrypter: &Encrypter, - csrf_config: &CsrfConfig, -) -> BoxedFilter<(Box,)> { - let get = warp::get() - .and(filters::trace::name("GET /login")) - .and(with_templates(templates)) - .and(connection(pool)) - .and(encrypted_cookie_saver(encrypter)) - .and(updated_csrf_token(encrypter, csrf_config)) - .and(warp::query()) - .and(optional_session(pool, encrypter)) - .and_then(get); +pub(crate) async fn get( + Extension(templates): Extension, + Extension(pool): Extension, + Query(query): Query, + cookie_jar: PrivateCookieJar, +) -> Result { + let mut conn = pool + .acquire() + .await + .map_err(fancy_error(templates.clone()))?; - let post = warp::post() - .and(filters::trace::name("POST /login")) - .and(with_templates(templates)) - .and(connection(pool)) - .and(encrypted_cookie_saver(encrypter)) - .and(updated_csrf_token(encrypter, csrf_config)) - .and(protected_form(encrypter)) - .and(warp::query()) - .and_then(post); + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(); + let (session_info, cookie_jar) = cookie_jar.session_info(); - warp::path!("login").and(get.or(post).unify()).boxed() -} + let maybe_session = session_info + .load_session(&mut conn) + .await + .map_err(fancy_error(templates.clone()))?; -async fn get( - templates: Templates, - mut conn: PoolConnection, - cookie_saver: EncryptedCookieSaver, - csrf_token: CsrfToken, - query: LoginRequest, - maybe_session: Option>, -) -> Result, Rejection> { if maybe_session.is_some() { - Ok(Box::new(query.redirect()?)) + let response = query + .redirect() + .map_err(fancy_error(templates.clone()))? + .into_response(); + Ok(response) } else { let ctx = LoginContext::default(); let ctx = match query.post_auth_action { Some(next) => { let register_link = RegisterRequest::from(next.clone()) .build_uri() - .wrap_error()?; - let next = next.load_context(&mut conn).await.wrap_error()?; + .map_err(fancy_error(templates.clone()))?; + let next = next + .load_context(&mut conn) + .await + .map_err(fancy_error(templates.clone()))?; ctx.with_post_action(next) .with_register_link(register_link.to_string()) } None => ctx, }; let ctx = ctx.with_csrf(csrf_token.form_value()); - let content = templates.render_login(&ctx).await?; - let reply = html(content); - let reply = cookie_saver.save_encrypted(&csrf_token, reply)?; - Ok(Box::new(reply)) + + let content = templates + .render_login(&ctx) + .await + .map_err(fancy_error(templates.clone()))?; + + Ok((cookie_jar.headers(), Html(content)).into_response()) } } -async fn post( - templates: Templates, - mut conn: PoolConnection, - cookie_saver: EncryptedCookieSaver, - csrf_token: CsrfToken, - form: LoginForm, - query: LoginRequest, -) -> Result, Rejection> { +pub(crate) async fn post( + Extension(templates): Extension, + Extension(pool): Extension, + Query(query): Query, + cookie_jar: PrivateCookieJar, + Form(form): Form>, +) -> Result { use mas_storage::user::LoginError; + let mut conn = pool + .acquire() + .await + .map_err(fancy_error(templates.clone()))?; + + let form = cookie_jar + .verify_form(form) + .map_err(fancy_error(templates.clone()))?; + + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(); + // TODO: recover match login(&mut conn, &form.username, form.password).await { Ok(session_info) => { - let session_cookie = SessionCookie::from_session(&session_info); - let reply = query.redirect()?; - let reply = cookie_saver.save_encrypted(&session_cookie, reply)?; - Ok(Box::new(reply)) + let cookie_jar = cookie_jar.set_session(&session_info); + let reply = query.redirect().map_err(fancy_error(templates.clone()))?; + Ok((cookie_jar.headers(), reply).into_response()) } Err(e) => { let errored_form = match e { @@ -170,10 +166,13 @@ async fn post( let ctx = LoginContext::default() .with_form_error(errored_form) .with_csrf(csrf_token.form_value()); - let content = templates.render_login(&ctx).await?; - let reply = html(content); - let reply = cookie_saver.save_encrypted(&csrf_token, reply)?; - Ok(Box::new(reply)) + + let content = templates + .render_login(&ctx) + .await + .map_err(fancy_error(templates.clone()))?; + + Ok((cookie_jar.headers(), Html(content)).into_response()) } } } diff --git a/crates/handlers/src/views/mod.rs b/crates/handlers/src/views/mod.rs index 937a0c32..6a6b7fe2 100644 --- a/crates/handlers/src/views/mod.rs +++ b/crates/handlers/src/views/mod.rs @@ -28,8 +28,8 @@ pub mod shared; pub mod verify; use self::{ - account::filter as account, login::filter as login, logout::filter as logout, - reauth::filter as reauth, register::filter as register, verify::filter as verify, + account::filter as account, logout::filter as logout, reauth::filter as reauth, + register::filter as register, verify::filter as verify, }; pub(crate) use self::{ login::LoginRequest, reauth::ReauthRequest, register::RegisterRequest, shared::PostAuthAction, @@ -44,15 +44,12 @@ pub(super) fn filter( csrf_config: &CsrfConfig, ) -> BoxedFilter<(Box,)> { let account = account(pool, templates, mailer, encrypter, http_config, csrf_config); - let login = login(pool, templates, encrypter, csrf_config); let register = register(pool, templates, encrypter, csrf_config); let logout = logout(pool, encrypter); let reauth = reauth(pool, templates, encrypter, csrf_config); let verify = verify(pool, templates, encrypter, csrf_config); account - .or(login) - .unify() .or(register) .unify() .or(logout)