1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-06 06:02:40 +03:00

Axum migration: CSRF token and login page

This commit is contained in:
Quentin Gliech
2022-03-25 15:38:30 +01:00
parent 5d3b4aa182
commit 5e95c705d4
5 changed files with 130 additions and 98 deletions

View File

@@ -19,7 +19,7 @@ use serde::{Deserialize, Serialize};
use serde_with::{serde_as, TimestampSeconds}; use serde_with::{serde_as, TimestampSeconds};
use thiserror::Error; use thiserror::Error;
use crate::{CookieExt, PrivateCookieJar}; use crate::{cookies::CookieDecodeError, CookieExt, PrivateCookieJar};
/// Failed to validate CSRF token /// Failed to validate CSRF token
#[derive(Debug, Error)] #[derive(Debug, Error)]
@@ -28,6 +28,14 @@ pub enum CsrfError {
#[error("CSRF token mismatch")] #[error("CSRF token mismatch")]
Mismatch, Mismatch,
/// The token in the form did not match the token in the cookie
#[error("Missing CSRF cookie")]
Missing,
/// Failed to decode the token
#[error("could not decode CSRF cookie")]
DecodeCookie(#[from] CookieDecodeError),
/// The token expired /// The token expired
#[error("CSRF token expired")] #[error("CSRF token expired")]
Expired, Expired,
@@ -89,8 +97,18 @@ impl CsrfToken {
} }
} }
// A CSRF-protected form
#[derive(Deserialize)]
pub struct ProtectedForm<T> {
csrf: String,
#[serde(flatten)]
inner: T,
}
pub trait CsrfExt { pub trait CsrfExt {
fn csrf_token(self) -> (CsrfToken, Self); fn csrf_token(self) -> (CsrfToken, Self);
fn verify_form<T>(&self, form: ProtectedForm<T>) -> Result<T, CsrfError>;
} }
impl<K> CsrfExt for PrivateCookieJar<K> { impl<K> CsrfExt for PrivateCookieJar<K> {
@@ -108,4 +126,12 @@ impl<K> CsrfExt for PrivateCookieJar<K> {
let jar = jar.add(cookie); let jar = jar.add(cookie);
(new_token, jar) (new_token, jar)
} }
fn verify_form<T>(&self, form: ProtectedForm<T>) -> Result<T, CsrfError> {
let cookie = self.get("csrf").ok_or(CsrfError::Missing)?;
let token: CsrfToken = cookie.decode()?;
let token = token.verify_expiration()?;
token.verify_form_value(&form.csrf)?;
Ok(form.inner)
}
} }

View File

@@ -50,7 +50,9 @@ where
} }
} }
pub fn fancy_error<E: Error + 'static>(templates: Templates) -> impl Fn(E) -> FancyError { pub fn fancy_error<E: std::fmt::Display + 'static>(
templates: Templates,
) -> impl Fn(E) -> FancyError {
move |error: E| FancyError { move |error: E| FancyError {
templates: Some(templates.clone()), templates: Some(templates.clone()),
error: Box::new(error), error: Box::new(error),
@@ -69,7 +71,7 @@ where
pub struct FancyError { pub struct FancyError {
templates: Option<Templates>, templates: Option<Templates>,
error: Box<dyn Error>, error: Box<dyn std::fmt::Display>,
} }
impl IntoResponse for FancyError { impl IntoResponse for FancyError {
@@ -99,4 +101,3 @@ impl IntoResponse for FancyError {
res res
} }
} }

View File

@@ -1,4 +1,4 @@
// Copyright 2021 The Matrix.org Foundation C.I.C. // Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@@ -21,7 +21,7 @@
use std::sync::Arc; use std::sync::Arc;
use axum::{extract::Extension, routing::get, Router}; use axum::{body::HttpBody, extract::Extension, routing::get, Router};
use mas_axum_utils::UrlBuilder; use mas_axum_utils::UrlBuilder;
use mas_config::{Encrypter, RootConfig}; use mas_config::{Encrypter, RootConfig};
use mas_email::Mailer; use mas_email::Mailer;
@@ -61,17 +61,26 @@ pub fn root(
} }
#[must_use] #[must_use]
pub fn router<B: Send + 'static>( pub fn router<B>(
pool: &PgPool, pool: &PgPool,
templates: &Templates, templates: &Templates,
key_store: &Arc<StaticKeystore>, key_store: &Arc<StaticKeystore>,
encrypter: &Encrypter, encrypter: &Encrypter,
mailer: &Mailer, mailer: &Mailer,
url_builder: &UrlBuilder, url_builder: &UrlBuilder,
) -> Router<B> { ) -> Router<B>
where
B: HttpBody + Send + 'static,
<B as HttpBody>::Data: Send,
<B as HttpBody>::Error: std::error::Error + Send + Sync,
{
Router::new() Router::new()
.route("/", get(self::views::index::get)) .route("/", get(self::views::index::get))
.route("/health", get(self::health::get)) .route("/health", get(self::health::get))
.route(
"/login",
get(self::views::login::get).post(self::views::login::post),
)
.fallback(mas_static_files::Assets) .fallback(mas_static_files::Assets)
.layer(Extension(pool.clone())) .layer(Extension(pool.clone()))
.layer(Extension(templates.clone())) .layer(Extension(templates.clone()))

View File

@@ -1,4 +1,4 @@
// Copyright 2021 The Matrix.org Foundation C.I.C. // Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@@ -14,25 +14,21 @@
#![allow(clippy::trait_duplication_in_bounds)] #![allow(clippy::trait_duplication_in_bounds)]
use hyper::http::uri::{Parts, PathAndQuery, Uri}; use axum::{
use mas_config::{CsrfConfig, Encrypter}; extract::{Extension, Form, Query},
use mas_data_model::{errors::WrapFormError, BrowserSession}; response::{Html, IntoResponse, Redirect, Response},
use mas_storage::{user::login, PostgresqlBackend};
use mas_templates::{LoginContext, LoginFormField, TemplateContext, Templates};
use mas_warp_utils::{
errors::WrapError,
filters::{
self,
cookies::{encrypted_cookie_saver, EncryptedCookieSaver},
csrf::{protected_form, updated_csrf_token},
database::connection,
session::{optional_session, SessionCookie},
with_templates, CsrfToken,
},
}; };
use hyper::http::uri::{Parts, PathAndQuery, Uri};
use mas_axum_utils::{
csrf::{CsrfExt, ProtectedForm},
fancy_error, FancyError, PrivateCookieJar, SessionInfoExt,
};
use mas_config::Encrypter;
use mas_data_model::errors::WrapFormError;
use mas_storage::user::login;
use mas_templates::{LoginContext, LoginFormField, TemplateContext, Templates};
use serde::Deserialize; use serde::Deserialize;
use sqlx::{pool::PoolConnection, PgPool, Postgres}; use sqlx::PgPool;
use warp::{filters::BoxedFilter, reply::html, Filter, Rejection, Reply};
use super::{shared::PostAuthAction, RegisterRequest}; use super::{shared::PostAuthAction, RegisterRequest};
@@ -66,100 +62,100 @@ impl LoginRequest {
Ok(uri) Ok(uri)
} }
fn redirect(self) -> Result<impl Reply, Rejection> { fn redirect(self) -> Result<impl IntoResponse, anyhow::Error> {
let uri = self let uri = if let Some(action) = self.post_auth_action {
.post_auth_action action.build_uri()?
.as_ref() } else {
.map(PostAuthAction::build_uri) Uri::from_static("/")
.transpose() };
.wrap_error()?
.unwrap_or_else(|| Uri::from_static("/")); Ok(Redirect::to(uri))
Ok(warp::redirect::see_other(uri))
} }
} }
#[derive(Deserialize)] #[derive(Deserialize)]
struct LoginForm { pub(crate) struct LoginForm {
username: String, username: String,
password: String, password: String,
} }
pub(super) fn filter( pub(crate) async fn get(
pool: &PgPool, Extension(templates): Extension<Templates>,
templates: &Templates, Extension(pool): Extension<PgPool>,
encrypter: &Encrypter, Query(query): Query<LoginRequest>,
csrf_config: &CsrfConfig, cookie_jar: PrivateCookieJar<Encrypter>,
) -> BoxedFilter<(Box<dyn Reply>,)> { ) -> Result<Response, FancyError> {
let get = warp::get() let mut conn = pool
.and(filters::trace::name("GET /login")) .acquire()
.and(with_templates(templates)) .await
.and(connection(pool)) .map_err(fancy_error(templates.clone()))?;
.and(encrypted_cookie_saver(encrypter))
.and(updated_csrf_token(encrypter, csrf_config))
.and(warp::query())
.and(optional_session(pool, encrypter))
.and_then(get);
let post = warp::post() let (csrf_token, cookie_jar) = cookie_jar.csrf_token();
.and(filters::trace::name("POST /login")) let (session_info, cookie_jar) = cookie_jar.session_info();
.and(with_templates(templates))
.and(connection(pool))
.and(encrypted_cookie_saver(encrypter))
.and(updated_csrf_token(encrypter, csrf_config))
.and(protected_form(encrypter))
.and(warp::query())
.and_then(post);
warp::path!("login").and(get.or(post).unify()).boxed() let maybe_session = session_info
} .load_session(&mut conn)
.await
.map_err(fancy_error(templates.clone()))?;
async fn get(
templates: Templates,
mut conn: PoolConnection<Postgres>,
cookie_saver: EncryptedCookieSaver,
csrf_token: CsrfToken,
query: LoginRequest,
maybe_session: Option<BrowserSession<PostgresqlBackend>>,
) -> Result<Box<dyn Reply>, Rejection> {
if maybe_session.is_some() { if maybe_session.is_some() {
Ok(Box::new(query.redirect()?)) let response = query
.redirect()
.map_err(fancy_error(templates.clone()))?
.into_response();
Ok(response)
} else { } else {
let ctx = LoginContext::default(); let ctx = LoginContext::default();
let ctx = match query.post_auth_action { let ctx = match query.post_auth_action {
Some(next) => { Some(next) => {
let register_link = RegisterRequest::from(next.clone()) let register_link = RegisterRequest::from(next.clone())
.build_uri() .build_uri()
.wrap_error()?; .map_err(fancy_error(templates.clone()))?;
let next = next.load_context(&mut conn).await.wrap_error()?; let next = next
.load_context(&mut conn)
.await
.map_err(fancy_error(templates.clone()))?;
ctx.with_post_action(next) ctx.with_post_action(next)
.with_register_link(register_link.to_string()) .with_register_link(register_link.to_string())
} }
None => ctx, None => ctx,
}; };
let ctx = ctx.with_csrf(csrf_token.form_value()); let ctx = ctx.with_csrf(csrf_token.form_value());
let content = templates.render_login(&ctx).await?;
let reply = html(content); let content = templates
let reply = cookie_saver.save_encrypted(&csrf_token, reply)?; .render_login(&ctx)
Ok(Box::new(reply)) .await
.map_err(fancy_error(templates.clone()))?;
Ok((cookie_jar.headers(), Html(content)).into_response())
} }
} }
async fn post( pub(crate) async fn post(
templates: Templates, Extension(templates): Extension<Templates>,
mut conn: PoolConnection<Postgres>, Extension(pool): Extension<PgPool>,
cookie_saver: EncryptedCookieSaver, Query(query): Query<LoginRequest>,
csrf_token: CsrfToken, cookie_jar: PrivateCookieJar<Encrypter>,
form: LoginForm, Form(form): Form<ProtectedForm<LoginForm>>,
query: LoginRequest, ) -> Result<Response, FancyError> {
) -> Result<Box<dyn Reply>, Rejection> {
use mas_storage::user::LoginError; use mas_storage::user::LoginError;
let mut conn = pool
.acquire()
.await
.map_err(fancy_error(templates.clone()))?;
let form = cookie_jar
.verify_form(form)
.map_err(fancy_error(templates.clone()))?;
let (csrf_token, cookie_jar) = cookie_jar.csrf_token();
// TODO: recover // TODO: recover
match login(&mut conn, &form.username, form.password).await { match login(&mut conn, &form.username, form.password).await {
Ok(session_info) => { Ok(session_info) => {
let session_cookie = SessionCookie::from_session(&session_info); let cookie_jar = cookie_jar.set_session(&session_info);
let reply = query.redirect()?; let reply = query.redirect().map_err(fancy_error(templates.clone()))?;
let reply = cookie_saver.save_encrypted(&session_cookie, reply)?; Ok((cookie_jar.headers(), reply).into_response())
Ok(Box::new(reply))
} }
Err(e) => { Err(e) => {
let errored_form = match e { let errored_form = match e {
@@ -170,10 +166,13 @@ async fn post(
let ctx = LoginContext::default() let ctx = LoginContext::default()
.with_form_error(errored_form) .with_form_error(errored_form)
.with_csrf(csrf_token.form_value()); .with_csrf(csrf_token.form_value());
let content = templates.render_login(&ctx).await?;
let reply = html(content); let content = templates
let reply = cookie_saver.save_encrypted(&csrf_token, reply)?; .render_login(&ctx)
Ok(Box::new(reply)) .await
.map_err(fancy_error(templates.clone()))?;
Ok((cookie_jar.headers(), Html(content)).into_response())
} }
} }
} }

View File

@@ -28,8 +28,8 @@ pub mod shared;
pub mod verify; pub mod verify;
use self::{ use self::{
account::filter as account, login::filter as login, logout::filter as logout, account::filter as account, logout::filter as logout, reauth::filter as reauth,
reauth::filter as reauth, register::filter as register, verify::filter as verify, register::filter as register, verify::filter as verify,
}; };
pub(crate) use self::{ pub(crate) use self::{
login::LoginRequest, reauth::ReauthRequest, register::RegisterRequest, shared::PostAuthAction, login::LoginRequest, reauth::ReauthRequest, register::RegisterRequest, shared::PostAuthAction,
@@ -44,15 +44,12 @@ pub(super) fn filter(
csrf_config: &CsrfConfig, csrf_config: &CsrfConfig,
) -> BoxedFilter<(Box<dyn Reply>,)> { ) -> BoxedFilter<(Box<dyn Reply>,)> {
let account = account(pool, templates, mailer, encrypter, http_config, csrf_config); let account = account(pool, templates, mailer, encrypter, http_config, csrf_config);
let login = login(pool, templates, encrypter, csrf_config);
let register = register(pool, templates, encrypter, csrf_config); let register = register(pool, templates, encrypter, csrf_config);
let logout = logout(pool, encrypter); let logout = logout(pool, encrypter);
let reauth = reauth(pool, templates, encrypter, csrf_config); let reauth = reauth(pool, templates, encrypter, csrf_config);
let verify = verify(pool, templates, encrypter, csrf_config); let verify = verify(pool, templates, encrypter, csrf_config);
account account
.or(login)
.unify()
.or(register) .or(register)
.unify() .unify()
.or(logout) .or(logout)