diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index 1e158f38..59259383 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -20,7 +20,7 @@ use std::{ use anyhow::Context; use clap::Parser; -use futures::{future::TryFutureExt, stream::TryStreamExt}; +use futures::stream::{StreamExt, TryStreamExt}; use hyper::Server; use mas_config::RootConfig; use mas_email::{MailTransport, Mailer}; @@ -118,20 +118,16 @@ async fn watch_templates( } }); - let fut = files_changed_stream - .try_for_each(move |files| { - let templates = templates.clone(); - async move { - info!(?files, "Files changed, reloading templates"); + let fut = files_changed_stream.for_each(move |files| { + let templates = templates.clone(); + async move { + info!(?files, "Files changed, reloading templates"); - templates - .clone() - .reload() - .await - .context("Could not reload templates") - } - }) - .inspect_err(|err| error!(%err, "Error while watching templates, stop watching")); + templates.clone().reload().await.unwrap_or_else(|err| { + error!(?err, "Error while reloading templates"); + }); + } + }); tokio::spawn(fut); diff --git a/crates/data-model/src/errors.rs b/crates/data-model/src/errors.rs deleted file mode 100644 index 61d5f011..00000000 --- a/crates/data-model/src/errors.rs +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::{collections::HashMap, fmt::Debug, hash::Hash}; - -use serde::{ser::SerializeMap, Serialize}; - -pub trait HtmlError: Debug + Send + Sync + 'static { - fn html_display(&self) -> String; -} - -pub trait WrapFormError { - fn on_form(self) -> ErroredForm; - fn on_field(self, field: FieldType) -> ErroredForm; -} - -impl WrapFormError for E -where - E: HtmlError, -{ - fn on_form(self) -> ErroredForm { - let mut f = ErroredForm::new(); - f.form.push(FormError { - error: Box::new(self), - }); - f - } - - fn on_field(self, field: FieldType) -> ErroredForm { - let mut f = ErroredForm::new(); - f.fields.push(FieldError { - field, - error: Box::new(self), - }); - f - } -} - -#[derive(Debug)] -struct FormError { - error: Box, -} - -impl Serialize for FormError { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - serializer.serialize_str(&self.error.html_display()) - } -} - -#[derive(Debug)] -struct FieldError { - field: FieldType, - error: Box, -} - -#[derive(Debug)] -pub struct ErroredForm { - form: Vec, - fields: Vec>, -} - -impl Default for ErroredForm { - fn default() -> Self { - Self { - form: Vec::new(), - fields: Vec::new(), - } - } -} - -impl ErroredForm { - #[must_use] - pub fn new() -> Self { - Self { - form: Vec::new(), - fields: Vec::new(), - } - } -} - -impl Serialize for ErroredForm { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - let mut map = serializer.serialize_map(Some(2))?; - let has_errors = !self.form.is_empty() || !self.fields.is_empty(); - map.serialize_entry("has_errors", &has_errors)?; - map.serialize_entry("form_errors", &self.form)?; - - let fields: HashMap> = - self.fields.iter().fold(HashMap::new(), |mut map, err| { - map.entry(err.field) - .or_default() - .push(err.error.html_display()); - map - }); - - map.serialize_entry("fields_errors", &fields)?; - - map.end() - } -} diff --git a/crates/data-model/src/lib.rs b/crates/data-model/src/lib.rs index 25690f49..9f6ef9f8 100644 --- a/crates/data-model/src/lib.rs +++ b/crates/data-model/src/lib.rs @@ -22,7 +22,6 @@ clippy::trait_duplication_in_bounds )] -pub mod errors; pub(crate) mod oauth2; pub(crate) mod tokens; pub(crate) mod traits; diff --git a/crates/handlers/src/views/account/password.rs b/crates/handlers/src/views/account/password.rs index 51cdd80a..127c9631 100644 --- a/crates/handlers/src/views/account/password.rs +++ b/crates/handlers/src/views/account/password.rs @@ -96,7 +96,7 @@ pub(crate) async fn post( return Ok((cookie_jar, login.go()).into_response()); }; - authenticate_session(&mut txn, &mut session, form.current_password).await?; + authenticate_session(&mut txn, &mut session, &form.current_password).await?; // TODO: display nice form errors if form.new_password != form.new_password_confirm { diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index c7130e68..7dcb48e5 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -18,25 +18,30 @@ use axum::{ }; use axum_extra::extract::PrivateCookieJar; use mas_axum_utils::{ - csrf::{CsrfExt, ProtectedForm}, + csrf::{CsrfExt, CsrfToken, ProtectedForm}, FancyError, SessionInfoExt, }; use mas_config::Encrypter; -use mas_data_model::errors::WrapFormError; use mas_router::Route; -use mas_storage::user::login; -use mas_templates::{LoginContext, LoginFormField, TemplateContext, Templates}; -use serde::Deserialize; -use sqlx::PgPool; +use mas_storage::user::{login, LoginError}; +use mas_templates::{ + FieldError, FormError, LoginContext, LoginFormField, TemplateContext, Templates, ToFormState, +}; +use serde::{Deserialize, Serialize}; +use sqlx::{PgConnection, PgPool}; use super::shared::OptionalPostAuthAction; -#[derive(Deserialize)] +#[derive(Debug, Deserialize, Serialize)] pub(crate) struct LoginForm { username: String, password: String, } +impl ToFormState for LoginForm { + type Field = LoginFormField; +} + #[tracing::instrument(skip(templates, pool, cookie_jar))] pub(crate) async fn get( Extension(templates): Extension, @@ -55,19 +60,14 @@ pub(crate) async fn get( let reply = query.go_next(); Ok((cookie_jar, reply).into_response()) } else { - let ctx = LoginContext::default(); - let next = query.load_context(&mut conn).await?; - let ctx = if let Some(next) = next { - ctx.with_post_action(next) - } else { - ctx - }; - let register_link = mas_router::Register::from(query.post_auth_action).relative_url(); - let ctx = ctx - .with_register_link(register_link.to_string()) - .with_csrf(csrf_token.form_value()); - - let content = templates.render_login(&ctx).await?; + let content = render( + LoginContext::default(), + query, + csrf_token, + &mut conn, + &templates, + ) + .await?; Ok((cookie_jar, Html(content)).into_response()) } @@ -80,33 +80,86 @@ pub(crate) async fn post( cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { - use mas_storage::user::LoginError; let mut conn = pool.acquire().await?; let form = cookie_jar.verify_form(form)?; let (csrf_token, cookie_jar) = cookie_jar.csrf_token(); - // TODO: recover - match login(&mut conn, &form.username, form.password).await { + // Validate the form + let state = { + let mut state = form.to_form_state(); + + if form.username.is_empty() { + state.add_error_on_field(LoginFormField::Username, FieldError::Required); + } + + if form.password.is_empty() { + state.add_error_on_field(LoginFormField::Password, FieldError::Required); + } + + state + }; + + if !state.is_valid() { + let content = render( + LoginContext::default().with_form_state(state), + query, + csrf_token, + &mut conn, + &templates, + ) + .await?; + + return Ok((cookie_jar, Html(content)).into_response()); + } + + match login(&mut conn, &form.username, &form.password).await { Ok(session_info) => { let cookie_jar = cookie_jar.set_session(&session_info); let reply = query.go_next(); Ok((cookie_jar, reply).into_response()) } Err(e) => { - let errored_form = match e { - LoginError::NotFound { .. } => e.on_field(LoginFormField::Username), - LoginError::Authentication { .. } => e.on_field(LoginFormField::Password), - LoginError::Other(_) => e.on_form(), + let state = match e { + LoginError::NotFound { .. } | LoginError::Authentication { .. } => { + state.with_error_on_form(FormError::InvalidCredentials) + } + LoginError::Other(_) => state.with_error_on_form(FormError::Internal), }; - let ctx = LoginContext::default() - .with_form_error(errored_form) - .with_csrf(csrf_token.form_value()); - let content = templates.render_login(&ctx).await?; + let content = render( + LoginContext::default().with_form_state(state), + query, + csrf_token, + &mut conn, + &templates, + ) + .await?; Ok((cookie_jar, Html(content)).into_response()) } } } + +async fn render( + ctx: LoginContext, + action: OptionalPostAuthAction, + csrf_token: CsrfToken, + conn: &mut PgConnection, + templates: &Templates, +) -> Result { + let next = action.load_context(conn).await?; + let ctx = if let Some(next) = next { + ctx.with_post_action(next) + } else { + ctx + }; + let register_link = mas_router::Register::from(action.post_auth_action).relative_url(); + let ctx = ctx + .with_register_link(register_link.to_string()) + .with_csrf(csrf_token.form_value()); + + let content = templates.render_login(&ctx).await?; + Ok(content) +} diff --git a/crates/handlers/src/views/reauth.rs b/crates/handlers/src/views/reauth.rs index 3daa2482..0b6c78f0 100644 --- a/crates/handlers/src/views/reauth.rs +++ b/crates/handlers/src/views/reauth.rs @@ -95,7 +95,7 @@ pub(crate) async fn post( }; // TODO: recover from errors here - authenticate_session(&mut txn, &mut session, form.password).await?; + authenticate_session(&mut txn, &mut session, &form.password).await?; let cookie_jar = cookie_jar.set_session(&session); txn.commit().await?; diff --git a/crates/handlers/src/views/register.rs b/crates/handlers/src/views/register.rs index 8357631b..904d3364 100644 --- a/crates/handlers/src/views/register.rs +++ b/crates/handlers/src/views/register.rs @@ -21,25 +21,32 @@ use axum::{ }; use axum_extra::extract::PrivateCookieJar; use mas_axum_utils::{ - csrf::{CsrfExt, ProtectedForm}, + csrf::{CsrfExt, CsrfToken, ProtectedForm}, FancyError, SessionInfoExt, }; use mas_config::Encrypter; use mas_router::Route; use mas_storage::user::{register_user, start_session}; -use mas_templates::{RegisterContext, TemplateContext, Templates}; -use serde::Deserialize; -use sqlx::PgPool; +use mas_templates::{ + FieldError, FormError, RegisterContext, RegisterFormField, TemplateContext, Templates, + ToFormState, +}; +use serde::{Deserialize, Serialize}; +use sqlx::{PgConnection, PgPool}; use super::shared::OptionalPostAuthAction; -#[derive(Deserialize)] +#[derive(Debug, Deserialize, Serialize)] pub(crate) struct RegisterForm { username: String, password: String, password_confirm: String, } +impl ToFormState for RegisterForm { + type Field = RegisterFormField; +} + pub(crate) async fn get( Extension(templates): Extension, Extension(pool): Extension, @@ -57,36 +64,68 @@ pub(crate) async fn get( let reply = query.go_next(); Ok((cookie_jar, reply).into_response()) } else { - let ctx = RegisterContext::default(); - let next = query.load_context(&mut conn).await?; - let ctx = if let Some(next) = next { - ctx.with_post_action(next) - } else { - ctx - }; - let login_link = mas_router::Login::from(query.post_auth_action).relative_url(); - let ctx = ctx.with_login_link(login_link.to_string()); - let ctx = ctx.with_csrf(csrf_token.form_value()); - - let content = templates.render_register(&ctx).await?; + let content = render( + RegisterContext::default(), + query, + csrf_token, + &mut conn, + &templates, + ) + .await?; Ok((cookie_jar, Html(content)).into_response()) } } pub(crate) async fn post( + Extension(templates): Extension, Extension(pool): Extension, Query(query): Query, cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { - // TODO: display nice form errors let mut txn = pool.begin().await?; let form = cookie_jar.verify_form(form)?; - if form.password != form.password_confirm { - return Err(anyhow::anyhow!("password mismatch").into()); + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(); + + // Validate the form + let state = { + let mut state = form.to_form_state(); + + if form.username.is_empty() { + state.add_error_on_field(RegisterFormField::Username, FieldError::Required); + } + + if form.password.is_empty() { + state.add_error_on_field(RegisterFormField::Password, FieldError::Required); + } + + if form.password_confirm.is_empty() { + state.add_error_on_field(RegisterFormField::PasswordConfirm, FieldError::Required); + } + + if form.password != form.password_confirm { + state.add_error_on_form(FormError::PasswordMismatch); + state.add_error_on_field(RegisterFormField::Password, FieldError::Unspecified); + state.add_error_on_field(RegisterFormField::PasswordConfirm, FieldError::Unspecified); + } + + state + }; + + if !state.is_valid() { + let content = render( + RegisterContext::default().with_form_state(state), + query, + csrf_token, + &mut txn, + &templates, + ) + .await?; + + return Ok((cookie_jar, Html(content)).into_response()); } let pfh = Argon2::default(); @@ -100,3 +139,25 @@ pub(crate) async fn post( let reply = query.go_next(); Ok((cookie_jar, reply).into_response()) } + +async fn render( + ctx: RegisterContext, + action: OptionalPostAuthAction, + csrf_token: CsrfToken, + conn: &mut PgConnection, + templates: &Templates, +) -> Result { + let next = action.load_context(conn).await?; + let ctx = if let Some(next) = next { + ctx.with_post_action(next) + } else { + ctx + }; + let login_link = mas_router::Login::from(action.post_auth_action).relative_url(); + let ctx = ctx + .with_login_link(login_link.to_string()) + .with_csrf(csrf_token.form_value()); + + let content = templates.render_register(&ctx).await?; + Ok(content) +} diff --git a/crates/storage/src/user.rs b/crates/storage/src/user.rs index 7309e24c..1ca939b6 100644 --- a/crates/storage/src/user.rs +++ b/crates/storage/src/user.rs @@ -18,7 +18,7 @@ use anyhow::{bail, Context}; use argon2::Argon2; use chrono::{DateTime, Utc}; use mas_data_model::{ - errors::HtmlError, Authentication, BrowserSession, User, UserEmail, UserEmailVerification, + Authentication, BrowserSession, User, UserEmail, UserEmailVerification, UserEmailVerificationState, }; use password_hash::{PasswordHash, PasswordHasher, SaltString}; @@ -61,21 +61,11 @@ pub enum LoginError { Other(#[from] anyhow::Error), } -impl HtmlError for LoginError { - fn html_display(&self) -> String { - match self { - LoginError::NotFound { .. } => "Could not find user".to_string(), - LoginError::Authentication { .. } => "Failed to authenticate user".to_string(), - LoginError::Other(e) => format!("Internal error:
{}
", e), - } - } -} - #[tracing::instrument(skip(conn, password))] pub async fn login( conn: impl Acquire<'_, Database = Postgres>, username: &str, - password: String, + password: &str, ) -> Result, LoginError> { let mut txn = conn.begin().await.context("could not start transaction")?; let user = lookup_user_by_username(&mut txn, username) @@ -287,7 +277,7 @@ pub enum AuthenticationError { pub async fn authenticate_session( txn: &mut Transaction<'_, Postgres>, session: &mut BrowserSession, - password: String, + password: &str, ) -> Result<(), AuthenticationError> { // First, fetch the hashed password from the user associated with that session let hashed_password: String = sqlx::query_scalar!( @@ -307,6 +297,7 @@ pub async fn authenticate_session( // TODO: pass verifiers list as parameter // Verify the password in a blocking thread to avoid blocking the async executor + let password = password.to_string(); task::spawn_blocking(move || { let context = Argon2::default(); let hasher = PasswordHash::new(&hashed_password).map_err(AuthenticationError::Password)?; diff --git a/crates/templates/src/context.rs b/crates/templates/src/context.rs index e45d34a3..744c2885 100644 --- a/crates/templates/src/context.rs +++ b/crates/templates/src/context.rs @@ -16,12 +16,12 @@ #![allow(clippy::trait_duplication_in_bounds)] -use mas_data_model::{ - errors::ErroredForm, AuthorizationGrant, BrowserSession, StorageBackend, User, UserEmail, -}; -use serde::{ser::SerializeStruct, Serialize}; +use mas_data_model::{AuthorizationGrant, BrowserSession, StorageBackend, User, UserEmail}; +use serde::{ser::SerializeStruct, Deserialize, Serialize}; use url::Url; +use crate::{FormField, FormState}; + /// Helper trait to construct context wrappers pub trait TemplateContext: Serialize { /// Attach a user session to the template context @@ -219,7 +219,7 @@ impl TemplateContext for IndexContext { } /// Fields of the login form -#[derive(Serialize, Debug, Clone, Copy, Hash, PartialEq, Eq)] +#[derive(Serialize, Deserialize, Debug, Clone, Copy, Hash, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum LoginFormField { /// The username field @@ -229,6 +229,15 @@ pub enum LoginFormField { Password, } +impl FormField for LoginFormField { + fn keep(&self) -> bool { + match self { + Self::Username => true, + Self::Password => false, + } + } +} + /// Context used in login and reauth screens, for the post-auth action to do #[derive(Serialize)] #[serde(tag = "kind", rename_all = "snake_case")] @@ -243,7 +252,7 @@ pub enum PostAuthContext { /// Context used by the `login.html` template #[derive(Serialize, Default)] pub struct LoginContext { - form: ErroredForm, + form: FormState, next: Option, register_link: String, } @@ -255,7 +264,7 @@ impl TemplateContext for LoginContext { { // TODO: samples with errors vec![LoginContext { - form: ErroredForm::default(), + form: FormState::default(), next: None, register_link: "/register".to_string(), }] @@ -263,9 +272,9 @@ impl TemplateContext for LoginContext { } impl LoginContext { - /// Add an error on the login form + /// Set the form state #[must_use] - pub fn with_form_error(self, form: ErroredForm) -> Self { + pub fn with_form_state(self, form: FormState) -> Self { Self { form, ..self } } @@ -289,7 +298,7 @@ impl LoginContext { } /// Fields of the registration form -#[derive(Serialize, Debug, Clone, Copy, Hash, PartialEq, Eq)] +#[derive(Serialize, Deserialize, Debug, Clone, Copy, Hash, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum RegisterFormField { /// The username field @@ -302,10 +311,19 @@ pub enum RegisterFormField { PasswordConfirm, } +impl FormField for RegisterFormField { + fn keep(&self) -> bool { + match self { + Self::Username => true, + Self::Password | Self::PasswordConfirm => false, + } + } +} + /// Context used by the `register.html` template #[derive(Serialize, Default)] pub struct RegisterContext { - form: ErroredForm, + form: FormState, next: Option, login_link: String, } @@ -317,7 +335,7 @@ impl TemplateContext for RegisterContext { { // TODO: samples with errors vec![RegisterContext { - form: ErroredForm::default(), + form: FormState::default(), next: None, login_link: "/login".to_string(), }] @@ -327,7 +345,7 @@ impl TemplateContext for RegisterContext { impl RegisterContext { /// Add an error on the registration form #[must_use] - pub fn with_form_error(self, form: ErroredForm) -> Self { + pub fn with_form_state(self, form: FormState) -> Self { Self { form, ..self } } @@ -377,13 +395,28 @@ impl ConsentContext { } /// Fields of the reauthentication form -#[derive(Serialize, Debug, Clone, Copy, Hash, PartialEq, Eq)] +#[derive(Serialize, Deserialize, Debug, Clone, Copy, Hash, PartialEq, Eq)] #[serde(rename_all = "kebab-case")] pub enum ReauthFormField { /// The password field Password, } +impl FormField for ReauthFormField { + fn keep(&self) -> bool { + match self { + Self::Password => false, + } + } +} + +/// Context used by the `reauth.html` template +#[derive(Serialize, Default)] +pub struct ReauthContext { + form: FormState, + next: Option, +} + impl TemplateContext for ReauthContext { fn sample() -> Vec where @@ -391,7 +424,7 @@ impl TemplateContext for ReauthContext { { // TODO: samples with errors vec![ReauthContext { - form: ErroredForm::default(), + form: FormState::default(), next: None, }] } @@ -400,7 +433,7 @@ impl TemplateContext for ReauthContext { impl ReauthContext { /// Add an error on the reauthentication form #[must_use] - pub fn with_form_error(self, form: ErroredForm) -> Self { + pub fn with_form_state(self, form: FormState) -> Self { Self { form, ..self } } @@ -414,22 +447,6 @@ impl ReauthContext { } } -impl Default for ReauthContext { - fn default() -> Self { - Self { - form: ErroredForm::new(), - next: None, - } - } -} - -/// Context used by the `reauth.html` template -#[derive(Serialize)] -pub struct ReauthContext { - form: ErroredForm, - next: Option, -} - /// Context used by the `account/index.html` template #[derive(Serialize)] pub struct AccountContext { diff --git a/crates/templates/src/forms.rs b/crates/templates/src/forms.rs new file mode 100644 index 00000000..8d88a825 --- /dev/null +++ b/crates/templates/src/forms.rs @@ -0,0 +1,239 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{collections::HashMap, hash::Hash}; + +use serde::{Deserialize, Serialize}; + +/// A trait which should be used for form field enums +pub trait FormField: Copy + Hash + PartialEq + Eq + Serialize + for<'de> Deserialize<'de> { + /// Return false for fields where values should not be kept (e.g. password + /// fields) + fn keep(&self) -> bool; +} + +/// An error on a form field +#[derive(Debug, Serialize)] +#[serde(rename_all = "snake_case", tag = "kind")] +pub enum FieldError { + /// A reuired field is missing + Required, + + /// An unspecified error on the field + Unspecified, +} + +/// An error on the whole form +#[derive(Debug, Serialize)] +#[serde(rename_all = "snake_case", tag = "kind")] +pub enum FormError { + /// The given credentials are not valid + InvalidCredentials, + + /// Password fields don't match + PasswordMismatch, + + /// There was an internal error + Internal, +} + +#[derive(Debug, Default, Serialize)] +struct FieldState { + value: Option, + errors: Vec, +} + +/// The state of a form and its fields +#[derive(Debug, Serialize)] +pub struct FormState { + fields: HashMap, + errors: Vec, + + #[serde(skip)] + has_errors: bool, +} + +impl Default for FormState { + fn default() -> Self { + FormState { + fields: HashMap::default(), + errors: Vec::default(), + has_errors: false, + } + } +} + +impl FormState { + /// Generate a [`FormState`] out of a form + /// + /// # Panics + /// + /// If the form fails to serialize, or the form field keys fail to + /// deserialize + pub fn from_form(form: &F) -> Self { + let form = serde_json::to_value(form).unwrap(); + let fields: HashMap> = serde_json::from_value(form).unwrap(); + + let fields = fields + .into_iter() + .map(|(key, value)| { + let value = key.keep().then(|| value).flatten(); + let field = FieldState { + value, + errors: Vec::new(), + }; + (key, field) + }) + .collect(); + + FormState { + fields, + errors: Vec::new(), + has_errors: false, + } + } + + /// Add an error on a form field + pub fn add_error_on_field(&mut self, field: K, error: FieldError) { + self.fields.entry(field).or_default().errors.push(error); + self.has_errors = true; + } + + /// Add an error on a form field + #[must_use] + pub fn with_error_on_field(mut self, field: K, error: FieldError) -> Self { + self.add_error_on_field(field, error); + self + } + + /// Add an error on the form + pub fn add_error_on_form(&mut self, error: FormError) { + self.errors.push(error); + self.has_errors = true; + } + + /// Add an error on the form + #[must_use] + pub fn with_error_on_form(mut self, error: FormError) -> Self { + self.add_error_on_form(error); + self + } + + /// Returns `true` if the form has no error attached to it + #[must_use] + pub fn is_valid(&self) -> bool { + !self.has_errors + } +} + +/// Utility trait to help creating [`FormState`] out of a form +pub trait ToFormState: Serialize { + /// The enum used for field names + type Field: FormField; + + /// Generate a [`FormState`] out of [`Self`] + /// + /// # Panics + /// + /// If the form fails to serialize or [`Self::Field`] fails to deserialize + fn to_form_state(&self) -> FormState { + FormState::from_form(&self) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(Serialize)] + struct TestForm { + foo: String, + bar: String, + } + + #[derive(Serialize, Deserialize, Debug, Clone, Copy, Hash, PartialEq, Eq)] + #[serde(rename_all = "snake_case")] + enum TestFormField { + Foo, + Bar, + } + + impl FormField for TestFormField { + fn keep(&self) -> bool { + match self { + Self::Foo => true, + Self::Bar => false, + } + } + } + + impl ToFormState for TestForm { + type Field = TestFormField; + } + + #[test] + fn form_state_serialization() { + let form = TestForm { + foo: "john".to_string(), + bar: "hunter2".to_string(), + }; + + let state = form.to_form_state(); + let state = serde_json::to_value(&state).unwrap(); + assert_eq!( + state, + serde_json::json!({ + "errors": [], + "fields": { + "foo": { + "errors": [], + "value": "john", + }, + "bar": { + "errors": [], + "value": null + }, + } + }) + ); + + let form = TestForm { + foo: "".to_string(), + bar: "".to_string(), + }; + let state = form + .to_form_state() + .with_error_on_field(TestFormField::Foo, FieldError::Required) + .with_error_on_field(TestFormField::Bar, FieldError::Required) + .with_error_on_form(FormError::InvalidCredentials); + + let state = serde_json::to_value(&state).unwrap(); + assert_eq!( + state, + serde_json::json!({ + "errors": [{"kind": "invalid_credentials"}], + "fields": { + "foo": { + "errors": [{"kind": "required"}], + "value": "", + }, + "bar": { + "errors": [{"kind": "required"}], + "value": null + }, + } + }) + ); + } +} diff --git a/crates/templates/src/lib.rs b/crates/templates/src/lib.rs index 3ff52657..232cf994 100644 --- a/crates/templates/src/lib.rs +++ b/crates/templates/src/lib.rs @@ -37,16 +37,20 @@ use tokio::{fs::OpenOptions, io::AsyncWriteExt, sync::RwLock, task::JoinError}; use tracing::{debug, info, warn}; mod context; +mod forms; mod functions; #[macro_use] mod macros; -pub use self::context::{ - AccountContext, AccountEmailsContext, ConsentContext, EmailVerificationContext, EmptyContext, - ErrorContext, FormPostContext, IndexContext, LoginContext, LoginFormField, PostAuthContext, - ReauthContext, ReauthFormField, RegisterContext, RegisterFormField, TemplateContext, WithCsrf, - WithOptionalSession, WithSession, +pub use self::{ + context::{ + AccountContext, AccountEmailsContext, ConsentContext, EmailVerificationContext, + EmptyContext, ErrorContext, FormPostContext, IndexContext, LoginContext, LoginFormField, + PostAuthContext, ReauthContext, ReauthFormField, RegisterContext, RegisterFormField, + TemplateContext, WithCsrf, WithOptionalSession, WithSession, + }, + forms::{FieldError, FormError, FormField, FormState, ToFormState}, }; /// Wrapper around [`tera::Tera`] helping rendering the various templates @@ -280,6 +284,7 @@ register_templates! { "components/field.html", "components/back_to_client.html", "components/navbar.html", + "components/errors.html", "base.html", }; diff --git a/crates/templates/src/res/base.html b/crates/templates/src/res/base.html index 162a0b48..2483cb2f 100644 --- a/crates/templates/src/res/base.html +++ b/crates/templates/src/res/base.html @@ -1,5 +1,5 @@ {# -Copyright 2021 The Matrix.org Foundation C.I.C. +Copyright 2021, 2022 The Matrix.org Foundation C.I.C. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ limitations under the License. {% import "components/field.html" as field %} {% import "components/back_to_client.html" as back_to_client %} {% import "components/navbar.html" as navbar %} +{% import "components/errors.html" as errors %} diff --git a/crates/templates/src/res/components/errors.html b/crates/templates/src/res/components/errors.html new file mode 100644 index 00000000..d398316d --- /dev/null +++ b/crates/templates/src/res/components/errors.html @@ -0,0 +1,25 @@ +{# +Copyright 2022 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +#} + +{% macro form_error_message(error) -%} + {% if error.kind == "invalid_credentials" %} + Invalid credentials + {% elif error.kind == "password_mismatch" %} + Password fields don't match + {% else %} + {{ error.kind }} + {% endif %} +{%- endmacro %} diff --git a/crates/templates/src/res/components/field.html b/crates/templates/src/res/components/field.html index 2fdd1fff..ebfab1d8 100644 --- a/crates/templates/src/res/components/field.html +++ b/crates/templates/src/res/components/field.html @@ -1,5 +1,5 @@ {# -Copyright 2021 The Matrix.org Foundation C.I.C. +Copyright 2021, 2022 The Matrix.org Foundation C.I.C. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,21 +14,36 @@ See the License for the specific language governing permissions and limitations under the License. #} -{% macro input(label, name, type="text", errors=false, class="") %} - {% if errors is not empty %} +{% macro input(label, name, type="text", form_state=false, class="") %} + {% if not form_state %} + {% set form_state = dict(errors=[], fields=dict()) %} + {% endif %} + + {% set state = form_state.fields[name] | default(value=dict(errors=[], value="")) %} + + {% if state.errors is not empty %} {% set border_color = "border-alert" %} {% set text_color = "text-alert" %} {% else %} {% set border_color = "border-grey-50 dark:border-grey-450" %} {% set text_color = "text-black-800 dark:text-grey-300" %} {% endif %} + diff --git a/crates/templates/src/res/pages/login.html b/crates/templates/src/res/pages/login.html index 16b823f1..e56a7a1f 100644 --- a/crates/templates/src/res/pages/login.html +++ b/crates/templates/src/res/pages/login.html @@ -1,5 +1,5 @@ {# -Copyright 2021 The Matrix.org Foundation C.I.C. +Copyright 2021, 2022 The Matrix.org Foundation C.I.C. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,10 +23,17 @@ limitations under the License.

Sign in

Please sign in to continue:

+ {% if form.errors is not empty %} + {% for error in form.errors %} +
+ {{ errors::form_error_message(error=error) }} +
+ {% endfor %} + {% endif %} + - {# TODO: errors #} - {{ field::input(label="Username", name="username") }} - {{ field::input(label="Password", name="password", type="password") }} + {{ field::input(label="Username", name="username", form_state=form) }} + {{ field::input(label="Password", name="password", type="password", form_state=form) }} {% if next and next.kind == "continue_authorization_grant" %}
{{ back_to_client::link( diff --git a/crates/templates/src/res/pages/register.html b/crates/templates/src/res/pages/register.html index 661a0510..fc7bce0d 100644 --- a/crates/templates/src/res/pages/register.html +++ b/crates/templates/src/res/pages/register.html @@ -1,5 +1,5 @@ {# -Copyright 2021 The Matrix.org Foundation C.I.C. +Copyright 2021, 2022 The Matrix.org Foundation C.I.C. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,11 +23,18 @@ limitations under the License.

Create an account

Please create an account to get started:

+ {% if form.errors is not empty %} + {% for error in form.errors %} +
+ {{ errors::form_error_message(error=error) }} +
+ {% endfor %} + {% endif %} + - {# TODO: errors #} - {{ field::input(label="Username", name="username") }} - {{ field::input(label="Password", name="password", type="password") }} - {{ field::input(label="Confirm Password", name="password_confirm", type="password") }} + {{ field::input(label="Username", name="username", form_state=form) }} + {{ field::input(label="Password", name="password", type="password", form_state=form) }} + {{ field::input(label="Confirm Password", name="password_confirm", type="password", form_state=form) }} {% if next and next.kind == "continue_authorization_grant" %}