diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index d3f94939..b1f709b8 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -91,6 +91,10 @@ where "/reauth", get(self::views::reauth::get).post(self::views::reauth::post), ) + .route( + "/register", + get(self::views::register::get).post(self::views::register::post), + ) .fallback(mas_static_files::Assets) .layer(Extension(pool.clone())) .layer(Extension(templates.clone())) diff --git a/crates/handlers/src/views/mod.rs b/crates/handlers/src/views/mod.rs index 73ceb74f..240ba4e3 100644 --- a/crates/handlers/src/views/mod.rs +++ b/crates/handlers/src/views/mod.rs @@ -27,7 +27,7 @@ pub mod register; pub mod shared; pub mod verify; -use self::{account::filter as account, register::filter as register, verify::filter as verify}; +use self::{account::filter as account, verify::filter as verify}; pub(crate) use self::{ login::LoginRequest, reauth::ReauthRequest, register::RegisterRequest, shared::PostAuthAction, }; @@ -41,8 +41,7 @@ pub(super) fn filter( csrf_config: &CsrfConfig, ) -> BoxedFilter<(Box,)> { let account = account(pool, templates, mailer, encrypter, http_config, csrf_config); - let register = register(pool, templates, encrypter, csrf_config); let verify = verify(pool, templates, encrypter, csrf_config); - account.or(register).unify().or(verify).unify().boxed() + account.or(verify).unify().boxed() } diff --git a/crates/handlers/src/views/register.rs b/crates/handlers/src/views/register.rs index 0adb756d..722d4a69 100644 --- a/crates/handlers/src/views/register.rs +++ b/crates/handlers/src/views/register.rs @@ -1,4 +1,4 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. +// Copyright 2021, 2022 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,28 +15,20 @@ #![allow(clippy::trait_duplication_in_bounds)] use argon2::Argon2; +use axum::{ + extract::{Extension, Form, Query}, + response::{Html, IntoResponse, Redirect, Response}, +}; use hyper::http::uri::{Parts, PathAndQuery, Uri}; -use mas_config::{CsrfConfig, Encrypter}; -use mas_data_model::BrowserSession; -use mas_storage::{ - user::{register_user, start_session}, - PostgresqlBackend, +use mas_axum_utils::{ + csrf::{CsrfExt, ProtectedForm}, + fancy_error, FancyError, PrivateCookieJar, SessionInfoExt, }; +use mas_config::Encrypter; +use mas_storage::user::{register_user, start_session}; use mas_templates::{RegisterContext, TemplateContext, Templates}; -use mas_warp_utils::{ - errors::WrapError, - filters::{ - self, - cookies::{encrypted_cookie_saver, EncryptedCookieSaver}, - csrf::{protected_form, updated_csrf_token}, - database::{connection, transaction}, - session::{optional_session, SessionCookie}, - with_templates, CsrfToken, - }, -}; use serde::Deserialize; -use sqlx::{pool::PoolConnection, PgPool, Postgres, Transaction}; -use warp::{filters::BoxedFilter, reply::html, Filter, Rejection, Reply}; +use sqlx::PgPool; use super::{LoginRequest, PostAuthAction}; @@ -71,103 +63,106 @@ impl RegisterRequest { Ok(uri) } - fn redirect(self) -> Result { - let uri = self - .post_auth_action - .as_ref() - .map(PostAuthAction::build_uri) - .transpose() - .wrap_error()? - .unwrap_or_else(|| Uri::from_static("/")); - Ok(warp::redirect::see_other(uri)) + fn redirect(self) -> Result { + let uri = if let Some(action) = self.post_auth_action { + action.build_uri()? + } else { + Uri::from_static("/") + }; + + Ok(Redirect::to(uri)) } } #[derive(Deserialize)] -struct RegisterForm { +pub(crate) struct RegisterForm { username: String, password: String, password_confirm: String, } -pub(super) fn filter( - pool: &PgPool, - templates: &Templates, - encrypter: &Encrypter, - csrf_config: &CsrfConfig, -) -> BoxedFilter<(Box,)> { - let get = warp::get() - .and(filters::trace::name("GET /register")) - .and(with_templates(templates)) - .and(connection(pool)) - .and(encrypted_cookie_saver(encrypter)) - .and(updated_csrf_token(encrypter, csrf_config)) - .and(warp::query()) - .and(optional_session(pool, encrypter)) - .and_then(get); +pub(crate) async fn get( + Extension(templates): Extension, + Extension(pool): Extension, + Query(query): Query, + cookie_jar: PrivateCookieJar, +) -> Result { + let mut conn = pool + .acquire() + .await + .map_err(fancy_error(templates.clone()))?; - let post = warp::post() - .and(filters::trace::name("POST /register")) - .and(transaction(pool)) - .and(encrypted_cookie_saver(encrypter)) - .and(protected_form(encrypter)) - .and(warp::query()) - .and_then(post); + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(); + let (session_info, cookie_jar) = cookie_jar.session_info(); - warp::path!("register").and(get.or(post).unify()).boxed() -} + let maybe_session = session_info + .load_session(&mut conn) + .await + .map_err(fancy_error(templates.clone()))?; -async fn get( - templates: Templates, - mut conn: PoolConnection, - cookie_saver: EncryptedCookieSaver, - csrf_token: CsrfToken, - query: RegisterRequest, - maybe_session: Option>, -) -> Result, Rejection> { if maybe_session.is_some() { - Ok(Box::new(query.redirect()?)) + let response = query + .redirect() + .map_err(fancy_error(templates.clone()))? + .into_response(); + Ok(response) } else { let ctx = RegisterContext::default(); - let ctx = match query.post_auth_action { + let ctx = match &query.post_auth_action { Some(next) => { - let login_link = LoginRequest::from(next.clone()).build_uri().wrap_error()?; - let next = next.load_context(&mut conn).await.wrap_error()?; + let next = next + .load_context(&mut conn) + .await + .map_err(fancy_error(templates.clone()))?; ctx.with_post_action(next) - .with_login_link(login_link.to_string()) } None => ctx, }; + let login_link = LoginRequest::from(query.post_auth_action) + .build_uri() + .map_err(fancy_error(templates.clone()))?; + let ctx = ctx.with_login_link(login_link.to_string()); let ctx = ctx.with_csrf(csrf_token.form_value()); - let content = templates.render_register(&ctx).await?; - let reply = html(content); - let reply = cookie_saver.save_encrypted(&csrf_token, reply)?; - Ok(Box::new(reply)) + + let content = templates + .render_register(&ctx) + .await + .map_err(fancy_error(templates.clone()))?; + + Ok((cookie_jar.headers(), Html(content)).into_response()) } } -async fn post( - mut txn: Transaction<'_, Postgres>, - cookie_saver: EncryptedCookieSaver, - form: RegisterForm, - query: RegisterRequest, -) -> Result, Rejection> { +pub(crate) async fn post( + Extension(templates): Extension, + Extension(pool): Extension, + Query(query): Query, + cookie_jar: PrivateCookieJar, + Form(form): Form>, +) -> Result { // TODO: display nice form errors + let mut txn = pool.begin().await.map_err(fancy_error(templates.clone()))?; + + let form = cookie_jar + .verify_form(form) + .map_err(fancy_error(templates.clone()))?; + if form.password != form.password_confirm { - return Err(anyhow::anyhow!("password mismatch")).wrap_error(); + return Err(anyhow::anyhow!("password mismatch")).map_err(fancy_error(templates.clone())); } let pfh = Argon2::default(); let user = register_user(&mut txn, pfh, &form.username, &form.password) .await - .wrap_error()?; + .map_err(fancy_error(templates.clone()))?; - let session_info = start_session(&mut txn, user).await.wrap_error()?; + let session = start_session(&mut txn, user) + .await + .map_err(fancy_error(templates.clone()))?; - txn.commit().await.wrap_error()?; + txn.commit().await.map_err(fancy_error(templates.clone()))?; - let session_cookie = SessionCookie::from_session(&session_info); - let reply = query.redirect()?; - let reply = cookie_saver.save_encrypted(&session_cookie, reply)?; - Ok(Box::new(reply)) + let cookie_jar = cookie_jar.set_session(&session); + let reply = query.redirect().map_err(fancy_error(templates.clone()))?; + Ok((cookie_jar.headers(), reply).into_response()) } diff --git a/crates/handlers/src/views/shared.rs b/crates/handlers/src/views/shared.rs index c7f4c172..2f9b8b92 100644 --- a/crates/handlers/src/views/shared.rs +++ b/crates/handlers/src/views/shared.rs @@ -1,4 +1,4 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. +// Copyright 2021, 2022 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#![allow(clippy::trait_duplication_in_bounds)] - use hyper::Uri; use mas_templates::PostAuthContext; use serde::{Deserialize, Serialize};