1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Axum migration: /register route

This commit is contained in:
Quentin Gliech
2022-03-25 16:59:21 +01:00
parent b4dc2b38d0
commit 6fb4d27046
4 changed files with 86 additions and 90 deletions

View File

@ -91,6 +91,10 @@ where
"/reauth", "/reauth",
get(self::views::reauth::get).post(self::views::reauth::post), get(self::views::reauth::get).post(self::views::reauth::post),
) )
.route(
"/register",
get(self::views::register::get).post(self::views::register::post),
)
.fallback(mas_static_files::Assets) .fallback(mas_static_files::Assets)
.layer(Extension(pool.clone())) .layer(Extension(pool.clone()))
.layer(Extension(templates.clone())) .layer(Extension(templates.clone()))

View File

@ -27,7 +27,7 @@ pub mod register;
pub mod shared; pub mod shared;
pub mod verify; pub mod verify;
use self::{account::filter as account, register::filter as register, verify::filter as verify}; use self::{account::filter as account, verify::filter as verify};
pub(crate) use self::{ pub(crate) use self::{
login::LoginRequest, reauth::ReauthRequest, register::RegisterRequest, shared::PostAuthAction, login::LoginRequest, reauth::ReauthRequest, register::RegisterRequest, shared::PostAuthAction,
}; };
@ -41,8 +41,7 @@ pub(super) fn filter(
csrf_config: &CsrfConfig, csrf_config: &CsrfConfig,
) -> BoxedFilter<(Box<dyn Reply>,)> { ) -> BoxedFilter<(Box<dyn Reply>,)> {
let account = account(pool, templates, mailer, encrypter, http_config, csrf_config); let account = account(pool, templates, mailer, encrypter, http_config, csrf_config);
let register = register(pool, templates, encrypter, csrf_config);
let verify = verify(pool, templates, encrypter, csrf_config); let verify = verify(pool, templates, encrypter, csrf_config);
account.or(register).unify().or(verify).unify().boxed() account.or(verify).unify().boxed()
} }

View File

@ -1,4 +1,4 @@
// Copyright 2021 The Matrix.org Foundation C.I.C. // Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -15,28 +15,20 @@
#![allow(clippy::trait_duplication_in_bounds)] #![allow(clippy::trait_duplication_in_bounds)]
use argon2::Argon2; use argon2::Argon2;
use axum::{
extract::{Extension, Form, Query},
response::{Html, IntoResponse, Redirect, Response},
};
use hyper::http::uri::{Parts, PathAndQuery, Uri}; use hyper::http::uri::{Parts, PathAndQuery, Uri};
use mas_config::{CsrfConfig, Encrypter}; use mas_axum_utils::{
use mas_data_model::BrowserSession; csrf::{CsrfExt, ProtectedForm},
use mas_storage::{ fancy_error, FancyError, PrivateCookieJar, SessionInfoExt,
user::{register_user, start_session},
PostgresqlBackend,
}; };
use mas_config::Encrypter;
use mas_storage::user::{register_user, start_session};
use mas_templates::{RegisterContext, TemplateContext, Templates}; use mas_templates::{RegisterContext, TemplateContext, Templates};
use mas_warp_utils::{
errors::WrapError,
filters::{
self,
cookies::{encrypted_cookie_saver, EncryptedCookieSaver},
csrf::{protected_form, updated_csrf_token},
database::{connection, transaction},
session::{optional_session, SessionCookie},
with_templates, CsrfToken,
},
};
use serde::Deserialize; use serde::Deserialize;
use sqlx::{pool::PoolConnection, PgPool, Postgres, Transaction}; use sqlx::PgPool;
use warp::{filters::BoxedFilter, reply::html, Filter, Rejection, Reply};
use super::{LoginRequest, PostAuthAction}; use super::{LoginRequest, PostAuthAction};
@ -71,103 +63,106 @@ impl RegisterRequest {
Ok(uri) Ok(uri)
} }
fn redirect(self) -> Result<impl Reply, Rejection> { fn redirect(self) -> Result<impl IntoResponse, anyhow::Error> {
let uri = self let uri = if let Some(action) = self.post_auth_action {
.post_auth_action action.build_uri()?
.as_ref() } else {
.map(PostAuthAction::build_uri) Uri::from_static("/")
.transpose() };
.wrap_error()?
.unwrap_or_else(|| Uri::from_static("/")); Ok(Redirect::to(uri))
Ok(warp::redirect::see_other(uri))
} }
} }
#[derive(Deserialize)] #[derive(Deserialize)]
struct RegisterForm { pub(crate) struct RegisterForm {
username: String, username: String,
password: String, password: String,
password_confirm: String, password_confirm: String,
} }
pub(super) fn filter( pub(crate) async fn get(
pool: &PgPool, Extension(templates): Extension<Templates>,
templates: &Templates, Extension(pool): Extension<PgPool>,
encrypter: &Encrypter, Query(query): Query<RegisterRequest>,
csrf_config: &CsrfConfig, cookie_jar: PrivateCookieJar<Encrypter>,
) -> BoxedFilter<(Box<dyn Reply>,)> { ) -> Result<Response, FancyError> {
let get = warp::get() let mut conn = pool
.and(filters::trace::name("GET /register")) .acquire()
.and(with_templates(templates)) .await
.and(connection(pool)) .map_err(fancy_error(templates.clone()))?;
.and(encrypted_cookie_saver(encrypter))
.and(updated_csrf_token(encrypter, csrf_config))
.and(warp::query())
.and(optional_session(pool, encrypter))
.and_then(get);
let post = warp::post() let (csrf_token, cookie_jar) = cookie_jar.csrf_token();
.and(filters::trace::name("POST /register")) let (session_info, cookie_jar) = cookie_jar.session_info();
.and(transaction(pool))
.and(encrypted_cookie_saver(encrypter))
.and(protected_form(encrypter))
.and(warp::query())
.and_then(post);
warp::path!("register").and(get.or(post).unify()).boxed() let maybe_session = session_info
} .load_session(&mut conn)
.await
.map_err(fancy_error(templates.clone()))?;
async fn get(
templates: Templates,
mut conn: PoolConnection<Postgres>,
cookie_saver: EncryptedCookieSaver,
csrf_token: CsrfToken,
query: RegisterRequest,
maybe_session: Option<BrowserSession<PostgresqlBackend>>,
) -> Result<Box<dyn Reply>, Rejection> {
if maybe_session.is_some() { if maybe_session.is_some() {
Ok(Box::new(query.redirect()?)) let response = query
.redirect()
.map_err(fancy_error(templates.clone()))?
.into_response();
Ok(response)
} else { } else {
let ctx = RegisterContext::default(); let ctx = RegisterContext::default();
let ctx = match query.post_auth_action { let ctx = match &query.post_auth_action {
Some(next) => { Some(next) => {
let login_link = LoginRequest::from(next.clone()).build_uri().wrap_error()?; let next = next
let next = next.load_context(&mut conn).await.wrap_error()?; .load_context(&mut conn)
.await
.map_err(fancy_error(templates.clone()))?;
ctx.with_post_action(next) ctx.with_post_action(next)
.with_login_link(login_link.to_string())
} }
None => ctx, None => ctx,
}; };
let login_link = LoginRequest::from(query.post_auth_action)
.build_uri()
.map_err(fancy_error(templates.clone()))?;
let ctx = ctx.with_login_link(login_link.to_string());
let ctx = ctx.with_csrf(csrf_token.form_value()); let ctx = ctx.with_csrf(csrf_token.form_value());
let content = templates.render_register(&ctx).await?;
let reply = html(content); let content = templates
let reply = cookie_saver.save_encrypted(&csrf_token, reply)?; .render_register(&ctx)
Ok(Box::new(reply)) .await
.map_err(fancy_error(templates.clone()))?;
Ok((cookie_jar.headers(), Html(content)).into_response())
} }
} }
async fn post( pub(crate) async fn post(
mut txn: Transaction<'_, Postgres>, Extension(templates): Extension<Templates>,
cookie_saver: EncryptedCookieSaver, Extension(pool): Extension<PgPool>,
form: RegisterForm, Query(query): Query<RegisterRequest>,
query: RegisterRequest, cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Box<dyn Reply>, Rejection> { Form(form): Form<ProtectedForm<RegisterForm>>,
) -> Result<Response, FancyError> {
// TODO: display nice form errors // TODO: display nice form errors
let mut txn = pool.begin().await.map_err(fancy_error(templates.clone()))?;
let form = cookie_jar
.verify_form(form)
.map_err(fancy_error(templates.clone()))?;
if form.password != form.password_confirm { if form.password != form.password_confirm {
return Err(anyhow::anyhow!("password mismatch")).wrap_error(); return Err(anyhow::anyhow!("password mismatch")).map_err(fancy_error(templates.clone()));
} }
let pfh = Argon2::default(); let pfh = Argon2::default();
let user = register_user(&mut txn, pfh, &form.username, &form.password) let user = register_user(&mut txn, pfh, &form.username, &form.password)
.await .await
.wrap_error()?; .map_err(fancy_error(templates.clone()))?;
let session_info = start_session(&mut txn, user).await.wrap_error()?; let session = start_session(&mut txn, user)
.await
.map_err(fancy_error(templates.clone()))?;
txn.commit().await.wrap_error()?; txn.commit().await.map_err(fancy_error(templates.clone()))?;
let session_cookie = SessionCookie::from_session(&session_info); let cookie_jar = cookie_jar.set_session(&session);
let reply = query.redirect()?; let reply = query.redirect().map_err(fancy_error(templates.clone()))?;
let reply = cookie_saver.save_encrypted(&session_cookie, reply)?; Ok((cookie_jar.headers(), reply).into_response())
Ok(Box::new(reply))
} }

View File

@ -1,4 +1,4 @@
// Copyright 2021 The Matrix.org Foundation C.I.C. // Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#![allow(clippy::trait_duplication_in_bounds)]
use hyper::Uri; use hyper::Uri;
use mas_templates::PostAuthContext; use mas_templates::PostAuthContext;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};