1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Form error state overhaul

This adds a new FormState structure here to hold the state of an errored
from, including retaining field value and better error codes.

It also adds error recovery for the registration form, and properly
loads the post_login_action context in case of errors.
This commit is contained in:
Quentin Gliech
2022-05-12 13:35:58 +02:00
parent 1a76bfe558
commit 185562c866
16 changed files with 551 additions and 252 deletions

View File

@ -20,7 +20,7 @@ use std::{
use anyhow::Context; use anyhow::Context;
use clap::Parser; use clap::Parser;
use futures::{future::TryFutureExt, stream::TryStreamExt}; use futures::stream::{StreamExt, TryStreamExt};
use hyper::Server; use hyper::Server;
use mas_config::RootConfig; use mas_config::RootConfig;
use mas_email::{MailTransport, Mailer}; use mas_email::{MailTransport, Mailer};
@ -118,20 +118,16 @@ async fn watch_templates(
} }
}); });
let fut = files_changed_stream let fut = files_changed_stream.for_each(move |files| {
.try_for_each(move |files| { let templates = templates.clone();
let templates = templates.clone(); async move {
async move { info!(?files, "Files changed, reloading templates");
info!(?files, "Files changed, reloading templates");
templates templates.clone().reload().await.unwrap_or_else(|err| {
.clone() error!(?err, "Error while reloading templates");
.reload() });
.await }
.context("Could not reload templates") });
}
})
.inspect_err(|err| error!(%err, "Error while watching templates, stop watching"));
tokio::spawn(fut); tokio::spawn(fut);

View File

@ -1,117 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{collections::HashMap, fmt::Debug, hash::Hash};
use serde::{ser::SerializeMap, Serialize};
pub trait HtmlError: Debug + Send + Sync + 'static {
fn html_display(&self) -> String;
}
pub trait WrapFormError<FieldType> {
fn on_form(self) -> ErroredForm<FieldType>;
fn on_field(self, field: FieldType) -> ErroredForm<FieldType>;
}
impl<E, FieldType> WrapFormError<FieldType> for E
where
E: HtmlError,
{
fn on_form(self) -> ErroredForm<FieldType> {
let mut f = ErroredForm::new();
f.form.push(FormError {
error: Box::new(self),
});
f
}
fn on_field(self, field: FieldType) -> ErroredForm<FieldType> {
let mut f = ErroredForm::new();
f.fields.push(FieldError {
field,
error: Box::new(self),
});
f
}
}
#[derive(Debug)]
struct FormError {
error: Box<dyn HtmlError>,
}
impl Serialize for FormError {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(&self.error.html_display())
}
}
#[derive(Debug)]
struct FieldError<FieldType> {
field: FieldType,
error: Box<dyn HtmlError>,
}
#[derive(Debug)]
pub struct ErroredForm<FieldType> {
form: Vec<FormError>,
fields: Vec<FieldError<FieldType>>,
}
impl<T> Default for ErroredForm<T> {
fn default() -> Self {
Self {
form: Vec::new(),
fields: Vec::new(),
}
}
}
impl<T> ErroredForm<T> {
#[must_use]
pub fn new() -> Self {
Self {
form: Vec::new(),
fields: Vec::new(),
}
}
}
impl<FieldType: Copy + Serialize + Hash + Eq> Serialize for ErroredForm<FieldType> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut map = serializer.serialize_map(Some(2))?;
let has_errors = !self.form.is_empty() || !self.fields.is_empty();
map.serialize_entry("has_errors", &has_errors)?;
map.serialize_entry("form_errors", &self.form)?;
let fields: HashMap<FieldType, Vec<String>> =
self.fields.iter().fold(HashMap::new(), |mut map, err| {
map.entry(err.field)
.or_default()
.push(err.error.html_display());
map
});
map.serialize_entry("fields_errors", &fields)?;
map.end()
}
}

View File

@ -22,7 +22,6 @@
clippy::trait_duplication_in_bounds clippy::trait_duplication_in_bounds
)] )]
pub mod errors;
pub(crate) mod oauth2; pub(crate) mod oauth2;
pub(crate) mod tokens; pub(crate) mod tokens;
pub(crate) mod traits; pub(crate) mod traits;

View File

@ -96,7 +96,7 @@ pub(crate) async fn post(
return Ok((cookie_jar, login.go()).into_response()); return Ok((cookie_jar, login.go()).into_response());
}; };
authenticate_session(&mut txn, &mut session, form.current_password).await?; authenticate_session(&mut txn, &mut session, &form.current_password).await?;
// TODO: display nice form errors // TODO: display nice form errors
if form.new_password != form.new_password_confirm { if form.new_password != form.new_password_confirm {

View File

@ -18,25 +18,30 @@ use axum::{
}; };
use axum_extra::extract::PrivateCookieJar; use axum_extra::extract::PrivateCookieJar;
use mas_axum_utils::{ use mas_axum_utils::{
csrf::{CsrfExt, ProtectedForm}, csrf::{CsrfExt, CsrfToken, ProtectedForm},
FancyError, SessionInfoExt, FancyError, SessionInfoExt,
}; };
use mas_config::Encrypter; use mas_config::Encrypter;
use mas_data_model::errors::WrapFormError;
use mas_router::Route; use mas_router::Route;
use mas_storage::user::login; use mas_storage::user::{login, LoginError};
use mas_templates::{LoginContext, LoginFormField, TemplateContext, Templates}; use mas_templates::{
use serde::Deserialize; FieldError, FormError, LoginContext, LoginFormField, TemplateContext, Templates, ToFormState,
use sqlx::PgPool; };
use serde::{Deserialize, Serialize};
use sqlx::{PgConnection, PgPool};
use super::shared::OptionalPostAuthAction; use super::shared::OptionalPostAuthAction;
#[derive(Deserialize)] #[derive(Debug, Deserialize, Serialize)]
pub(crate) struct LoginForm { pub(crate) struct LoginForm {
username: String, username: String,
password: String, password: String,
} }
impl ToFormState for LoginForm {
type Field = LoginFormField;
}
#[tracing::instrument(skip(templates, pool, cookie_jar))] #[tracing::instrument(skip(templates, pool, cookie_jar))]
pub(crate) async fn get( pub(crate) async fn get(
Extension(templates): Extension<Templates>, Extension(templates): Extension<Templates>,
@ -55,19 +60,14 @@ pub(crate) async fn get(
let reply = query.go_next(); let reply = query.go_next();
Ok((cookie_jar, reply).into_response()) Ok((cookie_jar, reply).into_response())
} else { } else {
let ctx = LoginContext::default(); let content = render(
let next = query.load_context(&mut conn).await?; LoginContext::default(),
let ctx = if let Some(next) = next { query,
ctx.with_post_action(next) csrf_token,
} else { &mut conn,
ctx &templates,
}; )
let register_link = mas_router::Register::from(query.post_auth_action).relative_url(); .await?;
let ctx = ctx
.with_register_link(register_link.to_string())
.with_csrf(csrf_token.form_value());
let content = templates.render_login(&ctx).await?;
Ok((cookie_jar, Html(content)).into_response()) Ok((cookie_jar, Html(content)).into_response())
} }
@ -80,33 +80,86 @@ pub(crate) async fn post(
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<LoginForm>>, Form(form): Form<ProtectedForm<LoginForm>>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
use mas_storage::user::LoginError;
let mut conn = pool.acquire().await?; let mut conn = pool.acquire().await?;
let form = cookie_jar.verify_form(form)?; let form = cookie_jar.verify_form(form)?;
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(); let (csrf_token, cookie_jar) = cookie_jar.csrf_token();
// TODO: recover // Validate the form
match login(&mut conn, &form.username, form.password).await { let state = {
let mut state = form.to_form_state();
if form.username.is_empty() {
state.add_error_on_field(LoginFormField::Username, FieldError::Required);
}
if form.password.is_empty() {
state.add_error_on_field(LoginFormField::Password, FieldError::Required);
}
state
};
if !state.is_valid() {
let content = render(
LoginContext::default().with_form_state(state),
query,
csrf_token,
&mut conn,
&templates,
)
.await?;
return Ok((cookie_jar, Html(content)).into_response());
}
match login(&mut conn, &form.username, &form.password).await {
Ok(session_info) => { Ok(session_info) => {
let cookie_jar = cookie_jar.set_session(&session_info); let cookie_jar = cookie_jar.set_session(&session_info);
let reply = query.go_next(); let reply = query.go_next();
Ok((cookie_jar, reply).into_response()) Ok((cookie_jar, reply).into_response())
} }
Err(e) => { Err(e) => {
let errored_form = match e { let state = match e {
LoginError::NotFound { .. } => e.on_field(LoginFormField::Username), LoginError::NotFound { .. } | LoginError::Authentication { .. } => {
LoginError::Authentication { .. } => e.on_field(LoginFormField::Password), state.with_error_on_form(FormError::InvalidCredentials)
LoginError::Other(_) => e.on_form(), }
LoginError::Other(_) => state.with_error_on_form(FormError::Internal),
}; };
let ctx = LoginContext::default()
.with_form_error(errored_form)
.with_csrf(csrf_token.form_value());
let content = templates.render_login(&ctx).await?; let content = render(
LoginContext::default().with_form_state(state),
query,
csrf_token,
&mut conn,
&templates,
)
.await?;
Ok((cookie_jar, Html(content)).into_response()) Ok((cookie_jar, Html(content)).into_response())
} }
} }
} }
async fn render(
ctx: LoginContext,
action: OptionalPostAuthAction,
csrf_token: CsrfToken,
conn: &mut PgConnection,
templates: &Templates,
) -> Result<String, FancyError> {
let next = action.load_context(conn).await?;
let ctx = if let Some(next) = next {
ctx.with_post_action(next)
} else {
ctx
};
let register_link = mas_router::Register::from(action.post_auth_action).relative_url();
let ctx = ctx
.with_register_link(register_link.to_string())
.with_csrf(csrf_token.form_value());
let content = templates.render_login(&ctx).await?;
Ok(content)
}

View File

@ -95,7 +95,7 @@ pub(crate) async fn post(
}; };
// TODO: recover from errors here // TODO: recover from errors here
authenticate_session(&mut txn, &mut session, form.password).await?; authenticate_session(&mut txn, &mut session, &form.password).await?;
let cookie_jar = cookie_jar.set_session(&session); let cookie_jar = cookie_jar.set_session(&session);
txn.commit().await?; txn.commit().await?;

View File

@ -21,25 +21,32 @@ use axum::{
}; };
use axum_extra::extract::PrivateCookieJar; use axum_extra::extract::PrivateCookieJar;
use mas_axum_utils::{ use mas_axum_utils::{
csrf::{CsrfExt, ProtectedForm}, csrf::{CsrfExt, CsrfToken, ProtectedForm},
FancyError, SessionInfoExt, FancyError, SessionInfoExt,
}; };
use mas_config::Encrypter; use mas_config::Encrypter;
use mas_router::Route; use mas_router::Route;
use mas_storage::user::{register_user, start_session}; use mas_storage::user::{register_user, start_session};
use mas_templates::{RegisterContext, TemplateContext, Templates}; use mas_templates::{
use serde::Deserialize; FieldError, FormError, RegisterContext, RegisterFormField, TemplateContext, Templates,
use sqlx::PgPool; ToFormState,
};
use serde::{Deserialize, Serialize};
use sqlx::{PgConnection, PgPool};
use super::shared::OptionalPostAuthAction; use super::shared::OptionalPostAuthAction;
#[derive(Deserialize)] #[derive(Debug, Deserialize, Serialize)]
pub(crate) struct RegisterForm { pub(crate) struct RegisterForm {
username: String, username: String,
password: String, password: String,
password_confirm: String, password_confirm: String,
} }
impl ToFormState for RegisterForm {
type Field = RegisterFormField;
}
pub(crate) async fn get( pub(crate) async fn get(
Extension(templates): Extension<Templates>, Extension(templates): Extension<Templates>,
Extension(pool): Extension<PgPool>, Extension(pool): Extension<PgPool>,
@ -57,36 +64,68 @@ pub(crate) async fn get(
let reply = query.go_next(); let reply = query.go_next();
Ok((cookie_jar, reply).into_response()) Ok((cookie_jar, reply).into_response())
} else { } else {
let ctx = RegisterContext::default(); let content = render(
let next = query.load_context(&mut conn).await?; RegisterContext::default(),
let ctx = if let Some(next) = next { query,
ctx.with_post_action(next) csrf_token,
} else { &mut conn,
ctx &templates,
}; )
let login_link = mas_router::Login::from(query.post_auth_action).relative_url(); .await?;
let ctx = ctx.with_login_link(login_link.to_string());
let ctx = ctx.with_csrf(csrf_token.form_value());
let content = templates.render_register(&ctx).await?;
Ok((cookie_jar, Html(content)).into_response()) Ok((cookie_jar, Html(content)).into_response())
} }
} }
pub(crate) async fn post( pub(crate) async fn post(
Extension(templates): Extension<Templates>,
Extension(pool): Extension<PgPool>, Extension(pool): Extension<PgPool>,
Query(query): Query<OptionalPostAuthAction>, Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<RegisterForm>>, Form(form): Form<ProtectedForm<RegisterForm>>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
// TODO: display nice form errors
let mut txn = pool.begin().await?; let mut txn = pool.begin().await?;
let form = cookie_jar.verify_form(form)?; let form = cookie_jar.verify_form(form)?;
if form.password != form.password_confirm { let (csrf_token, cookie_jar) = cookie_jar.csrf_token();
return Err(anyhow::anyhow!("password mismatch").into());
// Validate the form
let state = {
let mut state = form.to_form_state();
if form.username.is_empty() {
state.add_error_on_field(RegisterFormField::Username, FieldError::Required);
}
if form.password.is_empty() {
state.add_error_on_field(RegisterFormField::Password, FieldError::Required);
}
if form.password_confirm.is_empty() {
state.add_error_on_field(RegisterFormField::PasswordConfirm, FieldError::Required);
}
if form.password != form.password_confirm {
state.add_error_on_form(FormError::PasswordMismatch);
state.add_error_on_field(RegisterFormField::Password, FieldError::Unspecified);
state.add_error_on_field(RegisterFormField::PasswordConfirm, FieldError::Unspecified);
}
state
};
if !state.is_valid() {
let content = render(
RegisterContext::default().with_form_state(state),
query,
csrf_token,
&mut txn,
&templates,
)
.await?;
return Ok((cookie_jar, Html(content)).into_response());
} }
let pfh = Argon2::default(); let pfh = Argon2::default();
@ -100,3 +139,25 @@ pub(crate) async fn post(
let reply = query.go_next(); let reply = query.go_next();
Ok((cookie_jar, reply).into_response()) Ok((cookie_jar, reply).into_response())
} }
async fn render(
ctx: RegisterContext,
action: OptionalPostAuthAction,
csrf_token: CsrfToken,
conn: &mut PgConnection,
templates: &Templates,
) -> Result<String, FancyError> {
let next = action.load_context(conn).await?;
let ctx = if let Some(next) = next {
ctx.with_post_action(next)
} else {
ctx
};
let login_link = mas_router::Login::from(action.post_auth_action).relative_url();
let ctx = ctx
.with_login_link(login_link.to_string())
.with_csrf(csrf_token.form_value());
let content = templates.render_register(&ctx).await?;
Ok(content)
}

View File

@ -18,7 +18,7 @@ use anyhow::{bail, Context};
use argon2::Argon2; use argon2::Argon2;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use mas_data_model::{ use mas_data_model::{
errors::HtmlError, Authentication, BrowserSession, User, UserEmail, UserEmailVerification, Authentication, BrowserSession, User, UserEmail, UserEmailVerification,
UserEmailVerificationState, UserEmailVerificationState,
}; };
use password_hash::{PasswordHash, PasswordHasher, SaltString}; use password_hash::{PasswordHash, PasswordHasher, SaltString};
@ -61,21 +61,11 @@ pub enum LoginError {
Other(#[from] anyhow::Error), Other(#[from] anyhow::Error),
} }
impl HtmlError for LoginError {
fn html_display(&self) -> String {
match self {
LoginError::NotFound { .. } => "Could not find user".to_string(),
LoginError::Authentication { .. } => "Failed to authenticate user".to_string(),
LoginError::Other(e) => format!("Internal error: <pre>{}</pre>", e),
}
}
}
#[tracing::instrument(skip(conn, password))] #[tracing::instrument(skip(conn, password))]
pub async fn login( pub async fn login(
conn: impl Acquire<'_, Database = Postgres>, conn: impl Acquire<'_, Database = Postgres>,
username: &str, username: &str,
password: String, password: &str,
) -> Result<BrowserSession<PostgresqlBackend>, LoginError> { ) -> Result<BrowserSession<PostgresqlBackend>, LoginError> {
let mut txn = conn.begin().await.context("could not start transaction")?; let mut txn = conn.begin().await.context("could not start transaction")?;
let user = lookup_user_by_username(&mut txn, username) let user = lookup_user_by_username(&mut txn, username)
@ -287,7 +277,7 @@ pub enum AuthenticationError {
pub async fn authenticate_session( pub async fn authenticate_session(
txn: &mut Transaction<'_, Postgres>, txn: &mut Transaction<'_, Postgres>,
session: &mut BrowserSession<PostgresqlBackend>, session: &mut BrowserSession<PostgresqlBackend>,
password: String, password: &str,
) -> Result<(), AuthenticationError> { ) -> Result<(), AuthenticationError> {
// First, fetch the hashed password from the user associated with that session // First, fetch the hashed password from the user associated with that session
let hashed_password: String = sqlx::query_scalar!( let hashed_password: String = sqlx::query_scalar!(
@ -307,6 +297,7 @@ pub async fn authenticate_session(
// TODO: pass verifiers list as parameter // TODO: pass verifiers list as parameter
// Verify the password in a blocking thread to avoid blocking the async executor // Verify the password in a blocking thread to avoid blocking the async executor
let password = password.to_string();
task::spawn_blocking(move || { task::spawn_blocking(move || {
let context = Argon2::default(); let context = Argon2::default();
let hasher = PasswordHash::new(&hashed_password).map_err(AuthenticationError::Password)?; let hasher = PasswordHash::new(&hashed_password).map_err(AuthenticationError::Password)?;

View File

@ -16,12 +16,12 @@
#![allow(clippy::trait_duplication_in_bounds)] #![allow(clippy::trait_duplication_in_bounds)]
use mas_data_model::{ use mas_data_model::{AuthorizationGrant, BrowserSession, StorageBackend, User, UserEmail};
errors::ErroredForm, AuthorizationGrant, BrowserSession, StorageBackend, User, UserEmail, use serde::{ser::SerializeStruct, Deserialize, Serialize};
};
use serde::{ser::SerializeStruct, Serialize};
use url::Url; use url::Url;
use crate::{FormField, FormState};
/// Helper trait to construct context wrappers /// Helper trait to construct context wrappers
pub trait TemplateContext: Serialize { pub trait TemplateContext: Serialize {
/// Attach a user session to the template context /// Attach a user session to the template context
@ -219,7 +219,7 @@ impl TemplateContext for IndexContext {
} }
/// Fields of the login form /// Fields of the login form
#[derive(Serialize, Debug, Clone, Copy, Hash, PartialEq, Eq)] #[derive(Serialize, Deserialize, Debug, Clone, Copy, Hash, PartialEq, Eq)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum LoginFormField { pub enum LoginFormField {
/// The username field /// The username field
@ -229,6 +229,15 @@ pub enum LoginFormField {
Password, Password,
} }
impl FormField for LoginFormField {
fn keep(&self) -> bool {
match self {
Self::Username => true,
Self::Password => false,
}
}
}
/// Context used in login and reauth screens, for the post-auth action to do /// Context used in login and reauth screens, for the post-auth action to do
#[derive(Serialize)] #[derive(Serialize)]
#[serde(tag = "kind", rename_all = "snake_case")] #[serde(tag = "kind", rename_all = "snake_case")]
@ -243,7 +252,7 @@ pub enum PostAuthContext {
/// Context used by the `login.html` template /// Context used by the `login.html` template
#[derive(Serialize, Default)] #[derive(Serialize, Default)]
pub struct LoginContext { pub struct LoginContext {
form: ErroredForm<LoginFormField>, form: FormState<LoginFormField>,
next: Option<PostAuthContext>, next: Option<PostAuthContext>,
register_link: String, register_link: String,
} }
@ -255,7 +264,7 @@ impl TemplateContext for LoginContext {
{ {
// TODO: samples with errors // TODO: samples with errors
vec![LoginContext { vec![LoginContext {
form: ErroredForm::default(), form: FormState::default(),
next: None, next: None,
register_link: "/register".to_string(), register_link: "/register".to_string(),
}] }]
@ -263,9 +272,9 @@ impl TemplateContext for LoginContext {
} }
impl LoginContext { impl LoginContext {
/// Add an error on the login form /// Set the form state
#[must_use] #[must_use]
pub fn with_form_error(self, form: ErroredForm<LoginFormField>) -> Self { pub fn with_form_state(self, form: FormState<LoginFormField>) -> Self {
Self { form, ..self } Self { form, ..self }
} }
@ -289,7 +298,7 @@ impl LoginContext {
} }
/// Fields of the registration form /// Fields of the registration form
#[derive(Serialize, Debug, Clone, Copy, Hash, PartialEq, Eq)] #[derive(Serialize, Deserialize, Debug, Clone, Copy, Hash, PartialEq, Eq)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum RegisterFormField { pub enum RegisterFormField {
/// The username field /// The username field
@ -302,10 +311,19 @@ pub enum RegisterFormField {
PasswordConfirm, PasswordConfirm,
} }
impl FormField for RegisterFormField {
fn keep(&self) -> bool {
match self {
Self::Username => true,
Self::Password | Self::PasswordConfirm => false,
}
}
}
/// Context used by the `register.html` template /// Context used by the `register.html` template
#[derive(Serialize, Default)] #[derive(Serialize, Default)]
pub struct RegisterContext { pub struct RegisterContext {
form: ErroredForm<LoginFormField>, form: FormState<RegisterFormField>,
next: Option<PostAuthContext>, next: Option<PostAuthContext>,
login_link: String, login_link: String,
} }
@ -317,7 +335,7 @@ impl TemplateContext for RegisterContext {
{ {
// TODO: samples with errors // TODO: samples with errors
vec![RegisterContext { vec![RegisterContext {
form: ErroredForm::default(), form: FormState::default(),
next: None, next: None,
login_link: "/login".to_string(), login_link: "/login".to_string(),
}] }]
@ -327,7 +345,7 @@ impl TemplateContext for RegisterContext {
impl RegisterContext { impl RegisterContext {
/// Add an error on the registration form /// Add an error on the registration form
#[must_use] #[must_use]
pub fn with_form_error(self, form: ErroredForm<LoginFormField>) -> Self { pub fn with_form_state(self, form: FormState<RegisterFormField>) -> Self {
Self { form, ..self } Self { form, ..self }
} }
@ -377,13 +395,28 @@ impl ConsentContext {
} }
/// Fields of the reauthentication form /// Fields of the reauthentication form
#[derive(Serialize, Debug, Clone, Copy, Hash, PartialEq, Eq)] #[derive(Serialize, Deserialize, Debug, Clone, Copy, Hash, PartialEq, Eq)]
#[serde(rename_all = "kebab-case")] #[serde(rename_all = "kebab-case")]
pub enum ReauthFormField { pub enum ReauthFormField {
/// The password field /// The password field
Password, Password,
} }
impl FormField for ReauthFormField {
fn keep(&self) -> bool {
match self {
Self::Password => false,
}
}
}
/// Context used by the `reauth.html` template
#[derive(Serialize, Default)]
pub struct ReauthContext {
form: FormState<ReauthFormField>,
next: Option<PostAuthContext>,
}
impl TemplateContext for ReauthContext { impl TemplateContext for ReauthContext {
fn sample() -> Vec<Self> fn sample() -> Vec<Self>
where where
@ -391,7 +424,7 @@ impl TemplateContext for ReauthContext {
{ {
// TODO: samples with errors // TODO: samples with errors
vec![ReauthContext { vec![ReauthContext {
form: ErroredForm::default(), form: FormState::default(),
next: None, next: None,
}] }]
} }
@ -400,7 +433,7 @@ impl TemplateContext for ReauthContext {
impl ReauthContext { impl ReauthContext {
/// Add an error on the reauthentication form /// Add an error on the reauthentication form
#[must_use] #[must_use]
pub fn with_form_error(self, form: ErroredForm<ReauthFormField>) -> Self { pub fn with_form_state(self, form: FormState<ReauthFormField>) -> Self {
Self { form, ..self } Self { form, ..self }
} }
@ -414,22 +447,6 @@ impl ReauthContext {
} }
} }
impl Default for ReauthContext {
fn default() -> Self {
Self {
form: ErroredForm::new(),
next: None,
}
}
}
/// Context used by the `reauth.html` template
#[derive(Serialize)]
pub struct ReauthContext {
form: ErroredForm<ReauthFormField>,
next: Option<PostAuthContext>,
}
/// Context used by the `account/index.html` template /// Context used by the `account/index.html` template
#[derive(Serialize)] #[derive(Serialize)]
pub struct AccountContext { pub struct AccountContext {

View File

@ -0,0 +1,239 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{collections::HashMap, hash::Hash};
use serde::{Deserialize, Serialize};
/// A trait which should be used for form field enums
pub trait FormField: Copy + Hash + PartialEq + Eq + Serialize + for<'de> Deserialize<'de> {
/// Return false for fields where values should not be kept (e.g. password
/// fields)
fn keep(&self) -> bool;
}
/// An error on a form field
#[derive(Debug, Serialize)]
#[serde(rename_all = "snake_case", tag = "kind")]
pub enum FieldError {
/// A reuired field is missing
Required,
/// An unspecified error on the field
Unspecified,
}
/// An error on the whole form
#[derive(Debug, Serialize)]
#[serde(rename_all = "snake_case", tag = "kind")]
pub enum FormError {
/// The given credentials are not valid
InvalidCredentials,
/// Password fields don't match
PasswordMismatch,
/// There was an internal error
Internal,
}
#[derive(Debug, Default, Serialize)]
struct FieldState {
value: Option<String>,
errors: Vec<FieldError>,
}
/// The state of a form and its fields
#[derive(Debug, Serialize)]
pub struct FormState<K: Hash + Eq> {
fields: HashMap<K, FieldState>,
errors: Vec<FormError>,
#[serde(skip)]
has_errors: bool,
}
impl<K: Hash + Eq> Default for FormState<K> {
fn default() -> Self {
FormState {
fields: HashMap::default(),
errors: Vec::default(),
has_errors: false,
}
}
}
impl<K: FormField> FormState<K> {
/// Generate a [`FormState`] out of a form
///
/// # Panics
///
/// If the form fails to serialize, or the form field keys fail to
/// deserialize
pub fn from_form<F: Serialize>(form: &F) -> Self {
let form = serde_json::to_value(form).unwrap();
let fields: HashMap<K, Option<String>> = serde_json::from_value(form).unwrap();
let fields = fields
.into_iter()
.map(|(key, value)| {
let value = key.keep().then(|| value).flatten();
let field = FieldState {
value,
errors: Vec::new(),
};
(key, field)
})
.collect();
FormState {
fields,
errors: Vec::new(),
has_errors: false,
}
}
/// Add an error on a form field
pub fn add_error_on_field(&mut self, field: K, error: FieldError) {
self.fields.entry(field).or_default().errors.push(error);
self.has_errors = true;
}
/// Add an error on a form field
#[must_use]
pub fn with_error_on_field(mut self, field: K, error: FieldError) -> Self {
self.add_error_on_field(field, error);
self
}
/// Add an error on the form
pub fn add_error_on_form(&mut self, error: FormError) {
self.errors.push(error);
self.has_errors = true;
}
/// Add an error on the form
#[must_use]
pub fn with_error_on_form(mut self, error: FormError) -> Self {
self.add_error_on_form(error);
self
}
/// Returns `true` if the form has no error attached to it
#[must_use]
pub fn is_valid(&self) -> bool {
!self.has_errors
}
}
/// Utility trait to help creating [`FormState`] out of a form
pub trait ToFormState: Serialize {
/// The enum used for field names
type Field: FormField;
/// Generate a [`FormState`] out of [`Self`]
///
/// # Panics
///
/// If the form fails to serialize or [`Self::Field`] fails to deserialize
fn to_form_state(&self) -> FormState<Self::Field> {
FormState::from_form(&self)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[derive(Serialize)]
struct TestForm {
foo: String,
bar: String,
}
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Hash, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
enum TestFormField {
Foo,
Bar,
}
impl FormField for TestFormField {
fn keep(&self) -> bool {
match self {
Self::Foo => true,
Self::Bar => false,
}
}
}
impl ToFormState for TestForm {
type Field = TestFormField;
}
#[test]
fn form_state_serialization() {
let form = TestForm {
foo: "john".to_string(),
bar: "hunter2".to_string(),
};
let state = form.to_form_state();
let state = serde_json::to_value(&state).unwrap();
assert_eq!(
state,
serde_json::json!({
"errors": [],
"fields": {
"foo": {
"errors": [],
"value": "john",
},
"bar": {
"errors": [],
"value": null
},
}
})
);
let form = TestForm {
foo: "".to_string(),
bar: "".to_string(),
};
let state = form
.to_form_state()
.with_error_on_field(TestFormField::Foo, FieldError::Required)
.with_error_on_field(TestFormField::Bar, FieldError::Required)
.with_error_on_form(FormError::InvalidCredentials);
let state = serde_json::to_value(&state).unwrap();
assert_eq!(
state,
serde_json::json!({
"errors": [{"kind": "invalid_credentials"}],
"fields": {
"foo": {
"errors": [{"kind": "required"}],
"value": "",
},
"bar": {
"errors": [{"kind": "required"}],
"value": null
},
}
})
);
}
}

View File

@ -37,16 +37,20 @@ use tokio::{fs::OpenOptions, io::AsyncWriteExt, sync::RwLock, task::JoinError};
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
mod context; mod context;
mod forms;
mod functions; mod functions;
#[macro_use] #[macro_use]
mod macros; mod macros;
pub use self::context::{ pub use self::{
AccountContext, AccountEmailsContext, ConsentContext, EmailVerificationContext, EmptyContext, context::{
ErrorContext, FormPostContext, IndexContext, LoginContext, LoginFormField, PostAuthContext, AccountContext, AccountEmailsContext, ConsentContext, EmailVerificationContext,
ReauthContext, ReauthFormField, RegisterContext, RegisterFormField, TemplateContext, WithCsrf, EmptyContext, ErrorContext, FormPostContext, IndexContext, LoginContext, LoginFormField,
WithOptionalSession, WithSession, PostAuthContext, ReauthContext, ReauthFormField, RegisterContext, RegisterFormField,
TemplateContext, WithCsrf, WithOptionalSession, WithSession,
},
forms::{FieldError, FormError, FormField, FormState, ToFormState},
}; };
/// Wrapper around [`tera::Tera`] helping rendering the various templates /// Wrapper around [`tera::Tera`] helping rendering the various templates
@ -280,6 +284,7 @@ register_templates! {
"components/field.html", "components/field.html",
"components/back_to_client.html", "components/back_to_client.html",
"components/navbar.html", "components/navbar.html",
"components/errors.html",
"base.html", "base.html",
}; };

View File

@ -1,5 +1,5 @@
{# {#
Copyright 2021 The Matrix.org Foundation C.I.C. Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -18,6 +18,7 @@ limitations under the License.
{% import "components/field.html" as field %} {% import "components/field.html" as field %}
{% import "components/back_to_client.html" as back_to_client %} {% import "components/back_to_client.html" as back_to_client %}
{% import "components/navbar.html" as navbar %} {% import "components/navbar.html" as navbar %}
{% import "components/errors.html" as errors %}
<!DOCTYPE html> <!DOCTYPE html>
<html> <html>

View File

@ -0,0 +1,25 @@
{#
Copyright 2022 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
#}
{% macro form_error_message(error) -%}
{% if error.kind == "invalid_credentials" %}
Invalid credentials
{% elif error.kind == "password_mismatch" %}
Password fields don't match
{% else %}
{{ error.kind }}
{% endif %}
{%- endmacro %}

View File

@ -1,5 +1,5 @@
{# {#
Copyright 2021 The Matrix.org Foundation C.I.C. Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -14,21 +14,36 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
#} #}
{% macro input(label, name, type="text", errors=false, class="") %} {% macro input(label, name, type="text", form_state=false, class="") %}
{% if errors is not empty %} {% if not form_state %}
{% set form_state = dict(errors=[], fields=dict()) %}
{% endif %}
{% set state = form_state.fields[name] | default(value=dict(errors=[], value="")) %}
{% if state.errors is not empty %}
{% set border_color = "border-alert" %} {% set border_color = "border-alert" %}
{% set text_color = "text-alert" %} {% set text_color = "text-alert" %}
{% else %} {% else %}
{% set border_color = "border-grey-50 dark:border-grey-450" %} {% set border_color = "border-grey-50 dark:border-grey-450" %}
{% set text_color = "text-black-800 dark:text-grey-300" %} {% set text_color = "text-black-800 dark:text-grey-300" %}
{% endif %} {% endif %}
<label class="flex flex-col block {{ class }}"> <label class="flex flex-col block {{ class }}">
<div class="mx-2 -mb-3 -mt-2 leading-5 px-1 z-10 self-start bg-white dark:bg-black-900 border-white border-1 dark:border-2 dark:border-black-900 rounded-full text-sm {{ text_color }}">{{ label }}</div> <div class="mx-2 -mb-3 -mt-2 leading-5 px-1 z-10 self-start bg-white dark:bg-black-900 border-white border-1 dark:border-2 dark:border-black-900 rounded-full text-sm {{ text_color }}">{{ label }}</div>
<input name="{{ name }}" class="z-0 px-3 py-2 bg-white dark:bg-black-900 rounded-lg {{ border_color }} border-1 dark:border-2 focus:border-accent focus:ring-0 focus:outline-0" type="{{ type }}" /> <input name="{{ name }}" class="z-0 px-3 py-2 bg-white dark:bg-black-900 rounded-lg {{ border_color }} border-1 dark:border-2 focus:border-accent focus:ring-0 focus:outline-0" type="{{ type }}" {% if state.value %} value="{{ state.value }}" {% endif %} />
{% if errors is not empty %} {% if state.errors is not empty %}
{% for error in errors %} {% for error in state.errors %}
<div class="mx-4 text-sm text-alert">{{ error }}</div> {% if error.kind != "unspecified" %}
<div class="mx-4 text-sm text-alert">
{% if error.kind == "required" %}
This field is required
{% else %}
{{ error.kind }}
{% endif %}
</div>
{% endif %}
{% endfor %} {% endfor %}
{% endif %} {% endif %}
</label> </label>

View File

@ -1,5 +1,5 @@
{# {#
Copyright 2021 The Matrix.org Foundation C.I.C. Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -23,10 +23,17 @@ limitations under the License.
<h1 class="text-lg text-center font-medium">Sign in</h1> <h1 class="text-lg text-center font-medium">Sign in</h1>
<p>Please sign in to continue:</p> <p>Please sign in to continue:</p>
</div> </div>
{% if form.errors is not empty %}
{% for error in form.errors %}
<div class="text-alert font-medium">
{{ errors::form_error_message(error=error) }}
</div>
{% endfor %}
{% endif %}
<input type="hidden" name="csrf" value="{{ csrf_token }}" /> <input type="hidden" name="csrf" value="{{ csrf_token }}" />
{# TODO: errors #} {{ field::input(label="Username", name="username", form_state=form) }}
{{ field::input(label="Username", name="username") }} {{ field::input(label="Password", name="password", type="password", form_state=form) }}
{{ field::input(label="Password", name="password", type="password") }}
{% if next and next.kind == "continue_authorization_grant" %} {% if next and next.kind == "continue_authorization_grant" %}
<div class="grid grid-cols-2 gap-4"> <div class="grid grid-cols-2 gap-4">
{{ back_to_client::link( {{ back_to_client::link(

View File

@ -1,5 +1,5 @@
{# {#
Copyright 2021 The Matrix.org Foundation C.I.C. Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -23,11 +23,18 @@ limitations under the License.
<h1 class="text-lg text-center font-medium">Create an account</h1> <h1 class="text-lg text-center font-medium">Create an account</h1>
<p>Please create an account to get started:</p> <p>Please create an account to get started:</p>
</div> </div>
{% if form.errors is not empty %}
{% for error in form.errors %}
<div class="text-alert font-medium">
{{ errors::form_error_message(error=error) }}
</div>
{% endfor %}
{% endif %}
<input type="hidden" name="csrf" value="{{ csrf_token }}" /> <input type="hidden" name="csrf" value="{{ csrf_token }}" />
{# TODO: errors #} {{ field::input(label="Username", name="username", form_state=form) }}
{{ field::input(label="Username", name="username") }} {{ field::input(label="Password", name="password", type="password", form_state=form) }}
{{ field::input(label="Password", name="password", type="password") }} {{ field::input(label="Confirm Password", name="password_confirm", type="password", form_state=form) }}
{{ field::input(label="Confirm Password", name="password_confirm", type="password") }}
{% if next and next.kind == "continue_authorization_grant" %} {% if next and next.kind == "continue_authorization_grant" %}
<div class="grid grid-cols-2 gap-4"> <div class="grid grid-cols-2 gap-4">