You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-31 09:24:31 +03:00
Form error state overhaul
This adds a new FormState structure here to hold the state of an errored from, including retaining field value and better error codes. It also adds error recovery for the registration form, and properly loads the post_login_action context in case of errors.
This commit is contained in:
@ -20,7 +20,7 @@ use std::{
|
||||
|
||||
use anyhow::Context;
|
||||
use clap::Parser;
|
||||
use futures::{future::TryFutureExt, stream::TryStreamExt};
|
||||
use futures::stream::{StreamExt, TryStreamExt};
|
||||
use hyper::Server;
|
||||
use mas_config::RootConfig;
|
||||
use mas_email::{MailTransport, Mailer};
|
||||
@ -118,20 +118,16 @@ async fn watch_templates(
|
||||
}
|
||||
});
|
||||
|
||||
let fut = files_changed_stream
|
||||
.try_for_each(move |files| {
|
||||
let templates = templates.clone();
|
||||
async move {
|
||||
info!(?files, "Files changed, reloading templates");
|
||||
let fut = files_changed_stream.for_each(move |files| {
|
||||
let templates = templates.clone();
|
||||
async move {
|
||||
info!(?files, "Files changed, reloading templates");
|
||||
|
||||
templates
|
||||
.clone()
|
||||
.reload()
|
||||
.await
|
||||
.context("Could not reload templates")
|
||||
}
|
||||
})
|
||||
.inspect_err(|err| error!(%err, "Error while watching templates, stop watching"));
|
||||
templates.clone().reload().await.unwrap_or_else(|err| {
|
||||
error!(?err, "Error while reloading templates");
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
tokio::spawn(fut);
|
||||
|
||||
|
@ -1,117 +0,0 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{collections::HashMap, fmt::Debug, hash::Hash};
|
||||
|
||||
use serde::{ser::SerializeMap, Serialize};
|
||||
|
||||
pub trait HtmlError: Debug + Send + Sync + 'static {
|
||||
fn html_display(&self) -> String;
|
||||
}
|
||||
|
||||
pub trait WrapFormError<FieldType> {
|
||||
fn on_form(self) -> ErroredForm<FieldType>;
|
||||
fn on_field(self, field: FieldType) -> ErroredForm<FieldType>;
|
||||
}
|
||||
|
||||
impl<E, FieldType> WrapFormError<FieldType> for E
|
||||
where
|
||||
E: HtmlError,
|
||||
{
|
||||
fn on_form(self) -> ErroredForm<FieldType> {
|
||||
let mut f = ErroredForm::new();
|
||||
f.form.push(FormError {
|
||||
error: Box::new(self),
|
||||
});
|
||||
f
|
||||
}
|
||||
|
||||
fn on_field(self, field: FieldType) -> ErroredForm<FieldType> {
|
||||
let mut f = ErroredForm::new();
|
||||
f.fields.push(FieldError {
|
||||
field,
|
||||
error: Box::new(self),
|
||||
});
|
||||
f
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct FormError {
|
||||
error: Box<dyn HtmlError>,
|
||||
}
|
||||
|
||||
impl Serialize for FormError {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
serializer.serialize_str(&self.error.html_display())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct FieldError<FieldType> {
|
||||
field: FieldType,
|
||||
error: Box<dyn HtmlError>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ErroredForm<FieldType> {
|
||||
form: Vec<FormError>,
|
||||
fields: Vec<FieldError<FieldType>>,
|
||||
}
|
||||
|
||||
impl<T> Default for ErroredForm<T> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
form: Vec::new(),
|
||||
fields: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> ErroredForm<T> {
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
form: Vec::new(),
|
||||
fields: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<FieldType: Copy + Serialize + Hash + Eq> Serialize for ErroredForm<FieldType> {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
let mut map = serializer.serialize_map(Some(2))?;
|
||||
let has_errors = !self.form.is_empty() || !self.fields.is_empty();
|
||||
map.serialize_entry("has_errors", &has_errors)?;
|
||||
map.serialize_entry("form_errors", &self.form)?;
|
||||
|
||||
let fields: HashMap<FieldType, Vec<String>> =
|
||||
self.fields.iter().fold(HashMap::new(), |mut map, err| {
|
||||
map.entry(err.field)
|
||||
.or_default()
|
||||
.push(err.error.html_display());
|
||||
map
|
||||
});
|
||||
|
||||
map.serialize_entry("fields_errors", &fields)?;
|
||||
|
||||
map.end()
|
||||
}
|
||||
}
|
@ -22,7 +22,6 @@
|
||||
clippy::trait_duplication_in_bounds
|
||||
)]
|
||||
|
||||
pub mod errors;
|
||||
pub(crate) mod oauth2;
|
||||
pub(crate) mod tokens;
|
||||
pub(crate) mod traits;
|
||||
|
@ -96,7 +96,7 @@ pub(crate) async fn post(
|
||||
return Ok((cookie_jar, login.go()).into_response());
|
||||
};
|
||||
|
||||
authenticate_session(&mut txn, &mut session, form.current_password).await?;
|
||||
authenticate_session(&mut txn, &mut session, &form.current_password).await?;
|
||||
|
||||
// TODO: display nice form errors
|
||||
if form.new_password != form.new_password_confirm {
|
||||
|
@ -18,25 +18,30 @@ use axum::{
|
||||
};
|
||||
use axum_extra::extract::PrivateCookieJar;
|
||||
use mas_axum_utils::{
|
||||
csrf::{CsrfExt, ProtectedForm},
|
||||
csrf::{CsrfExt, CsrfToken, ProtectedForm},
|
||||
FancyError, SessionInfoExt,
|
||||
};
|
||||
use mas_config::Encrypter;
|
||||
use mas_data_model::errors::WrapFormError;
|
||||
use mas_router::Route;
|
||||
use mas_storage::user::login;
|
||||
use mas_templates::{LoginContext, LoginFormField, TemplateContext, Templates};
|
||||
use serde::Deserialize;
|
||||
use sqlx::PgPool;
|
||||
use mas_storage::user::{login, LoginError};
|
||||
use mas_templates::{
|
||||
FieldError, FormError, LoginContext, LoginFormField, TemplateContext, Templates, ToFormState,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{PgConnection, PgPool};
|
||||
|
||||
use super::shared::OptionalPostAuthAction;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub(crate) struct LoginForm {
|
||||
username: String,
|
||||
password: String,
|
||||
}
|
||||
|
||||
impl ToFormState for LoginForm {
|
||||
type Field = LoginFormField;
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(templates, pool, cookie_jar))]
|
||||
pub(crate) async fn get(
|
||||
Extension(templates): Extension<Templates>,
|
||||
@ -55,19 +60,14 @@ pub(crate) async fn get(
|
||||
let reply = query.go_next();
|
||||
Ok((cookie_jar, reply).into_response())
|
||||
} else {
|
||||
let ctx = LoginContext::default();
|
||||
let next = query.load_context(&mut conn).await?;
|
||||
let ctx = if let Some(next) = next {
|
||||
ctx.with_post_action(next)
|
||||
} else {
|
||||
ctx
|
||||
};
|
||||
let register_link = mas_router::Register::from(query.post_auth_action).relative_url();
|
||||
let ctx = ctx
|
||||
.with_register_link(register_link.to_string())
|
||||
.with_csrf(csrf_token.form_value());
|
||||
|
||||
let content = templates.render_login(&ctx).await?;
|
||||
let content = render(
|
||||
LoginContext::default(),
|
||||
query,
|
||||
csrf_token,
|
||||
&mut conn,
|
||||
&templates,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok((cookie_jar, Html(content)).into_response())
|
||||
}
|
||||
@ -80,33 +80,86 @@ pub(crate) async fn post(
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
Form(form): Form<ProtectedForm<LoginForm>>,
|
||||
) -> Result<Response, FancyError> {
|
||||
use mas_storage::user::LoginError;
|
||||
let mut conn = pool.acquire().await?;
|
||||
|
||||
let form = cookie_jar.verify_form(form)?;
|
||||
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token();
|
||||
|
||||
// TODO: recover
|
||||
match login(&mut conn, &form.username, form.password).await {
|
||||
// Validate the form
|
||||
let state = {
|
||||
let mut state = form.to_form_state();
|
||||
|
||||
if form.username.is_empty() {
|
||||
state.add_error_on_field(LoginFormField::Username, FieldError::Required);
|
||||
}
|
||||
|
||||
if form.password.is_empty() {
|
||||
state.add_error_on_field(LoginFormField::Password, FieldError::Required);
|
||||
}
|
||||
|
||||
state
|
||||
};
|
||||
|
||||
if !state.is_valid() {
|
||||
let content = render(
|
||||
LoginContext::default().with_form_state(state),
|
||||
query,
|
||||
csrf_token,
|
||||
&mut conn,
|
||||
&templates,
|
||||
)
|
||||
.await?;
|
||||
|
||||
return Ok((cookie_jar, Html(content)).into_response());
|
||||
}
|
||||
|
||||
match login(&mut conn, &form.username, &form.password).await {
|
||||
Ok(session_info) => {
|
||||
let cookie_jar = cookie_jar.set_session(&session_info);
|
||||
let reply = query.go_next();
|
||||
Ok((cookie_jar, reply).into_response())
|
||||
}
|
||||
Err(e) => {
|
||||
let errored_form = match e {
|
||||
LoginError::NotFound { .. } => e.on_field(LoginFormField::Username),
|
||||
LoginError::Authentication { .. } => e.on_field(LoginFormField::Password),
|
||||
LoginError::Other(_) => e.on_form(),
|
||||
let state = match e {
|
||||
LoginError::NotFound { .. } | LoginError::Authentication { .. } => {
|
||||
state.with_error_on_form(FormError::InvalidCredentials)
|
||||
}
|
||||
LoginError::Other(_) => state.with_error_on_form(FormError::Internal),
|
||||
};
|
||||
let ctx = LoginContext::default()
|
||||
.with_form_error(errored_form)
|
||||
.with_csrf(csrf_token.form_value());
|
||||
|
||||
let content = templates.render_login(&ctx).await?;
|
||||
let content = render(
|
||||
LoginContext::default().with_form_state(state),
|
||||
query,
|
||||
csrf_token,
|
||||
&mut conn,
|
||||
&templates,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok((cookie_jar, Html(content)).into_response())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn render(
|
||||
ctx: LoginContext,
|
||||
action: OptionalPostAuthAction,
|
||||
csrf_token: CsrfToken,
|
||||
conn: &mut PgConnection,
|
||||
templates: &Templates,
|
||||
) -> Result<String, FancyError> {
|
||||
let next = action.load_context(conn).await?;
|
||||
let ctx = if let Some(next) = next {
|
||||
ctx.with_post_action(next)
|
||||
} else {
|
||||
ctx
|
||||
};
|
||||
let register_link = mas_router::Register::from(action.post_auth_action).relative_url();
|
||||
let ctx = ctx
|
||||
.with_register_link(register_link.to_string())
|
||||
.with_csrf(csrf_token.form_value());
|
||||
|
||||
let content = templates.render_login(&ctx).await?;
|
||||
Ok(content)
|
||||
}
|
||||
|
@ -95,7 +95,7 @@ pub(crate) async fn post(
|
||||
};
|
||||
|
||||
// TODO: recover from errors here
|
||||
authenticate_session(&mut txn, &mut session, form.password).await?;
|
||||
authenticate_session(&mut txn, &mut session, &form.password).await?;
|
||||
let cookie_jar = cookie_jar.set_session(&session);
|
||||
txn.commit().await?;
|
||||
|
||||
|
@ -21,25 +21,32 @@ use axum::{
|
||||
};
|
||||
use axum_extra::extract::PrivateCookieJar;
|
||||
use mas_axum_utils::{
|
||||
csrf::{CsrfExt, ProtectedForm},
|
||||
csrf::{CsrfExt, CsrfToken, ProtectedForm},
|
||||
FancyError, SessionInfoExt,
|
||||
};
|
||||
use mas_config::Encrypter;
|
||||
use mas_router::Route;
|
||||
use mas_storage::user::{register_user, start_session};
|
||||
use mas_templates::{RegisterContext, TemplateContext, Templates};
|
||||
use serde::Deserialize;
|
||||
use sqlx::PgPool;
|
||||
use mas_templates::{
|
||||
FieldError, FormError, RegisterContext, RegisterFormField, TemplateContext, Templates,
|
||||
ToFormState,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{PgConnection, PgPool};
|
||||
|
||||
use super::shared::OptionalPostAuthAction;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub(crate) struct RegisterForm {
|
||||
username: String,
|
||||
password: String,
|
||||
password_confirm: String,
|
||||
}
|
||||
|
||||
impl ToFormState for RegisterForm {
|
||||
type Field = RegisterFormField;
|
||||
}
|
||||
|
||||
pub(crate) async fn get(
|
||||
Extension(templates): Extension<Templates>,
|
||||
Extension(pool): Extension<PgPool>,
|
||||
@ -57,36 +64,68 @@ pub(crate) async fn get(
|
||||
let reply = query.go_next();
|
||||
Ok((cookie_jar, reply).into_response())
|
||||
} else {
|
||||
let ctx = RegisterContext::default();
|
||||
let next = query.load_context(&mut conn).await?;
|
||||
let ctx = if let Some(next) = next {
|
||||
ctx.with_post_action(next)
|
||||
} else {
|
||||
ctx
|
||||
};
|
||||
let login_link = mas_router::Login::from(query.post_auth_action).relative_url();
|
||||
let ctx = ctx.with_login_link(login_link.to_string());
|
||||
let ctx = ctx.with_csrf(csrf_token.form_value());
|
||||
|
||||
let content = templates.render_register(&ctx).await?;
|
||||
let content = render(
|
||||
RegisterContext::default(),
|
||||
query,
|
||||
csrf_token,
|
||||
&mut conn,
|
||||
&templates,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok((cookie_jar, Html(content)).into_response())
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn post(
|
||||
Extension(templates): Extension<Templates>,
|
||||
Extension(pool): Extension<PgPool>,
|
||||
Query(query): Query<OptionalPostAuthAction>,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
Form(form): Form<ProtectedForm<RegisterForm>>,
|
||||
) -> Result<Response, FancyError> {
|
||||
// TODO: display nice form errors
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
let form = cookie_jar.verify_form(form)?;
|
||||
|
||||
if form.password != form.password_confirm {
|
||||
return Err(anyhow::anyhow!("password mismatch").into());
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token();
|
||||
|
||||
// Validate the form
|
||||
let state = {
|
||||
let mut state = form.to_form_state();
|
||||
|
||||
if form.username.is_empty() {
|
||||
state.add_error_on_field(RegisterFormField::Username, FieldError::Required);
|
||||
}
|
||||
|
||||
if form.password.is_empty() {
|
||||
state.add_error_on_field(RegisterFormField::Password, FieldError::Required);
|
||||
}
|
||||
|
||||
if form.password_confirm.is_empty() {
|
||||
state.add_error_on_field(RegisterFormField::PasswordConfirm, FieldError::Required);
|
||||
}
|
||||
|
||||
if form.password != form.password_confirm {
|
||||
state.add_error_on_form(FormError::PasswordMismatch);
|
||||
state.add_error_on_field(RegisterFormField::Password, FieldError::Unspecified);
|
||||
state.add_error_on_field(RegisterFormField::PasswordConfirm, FieldError::Unspecified);
|
||||
}
|
||||
|
||||
state
|
||||
};
|
||||
|
||||
if !state.is_valid() {
|
||||
let content = render(
|
||||
RegisterContext::default().with_form_state(state),
|
||||
query,
|
||||
csrf_token,
|
||||
&mut txn,
|
||||
&templates,
|
||||
)
|
||||
.await?;
|
||||
|
||||
return Ok((cookie_jar, Html(content)).into_response());
|
||||
}
|
||||
|
||||
let pfh = Argon2::default();
|
||||
@ -100,3 +139,25 @@ pub(crate) async fn post(
|
||||
let reply = query.go_next();
|
||||
Ok((cookie_jar, reply).into_response())
|
||||
}
|
||||
|
||||
async fn render(
|
||||
ctx: RegisterContext,
|
||||
action: OptionalPostAuthAction,
|
||||
csrf_token: CsrfToken,
|
||||
conn: &mut PgConnection,
|
||||
templates: &Templates,
|
||||
) -> Result<String, FancyError> {
|
||||
let next = action.load_context(conn).await?;
|
||||
let ctx = if let Some(next) = next {
|
||||
ctx.with_post_action(next)
|
||||
} else {
|
||||
ctx
|
||||
};
|
||||
let login_link = mas_router::Login::from(action.post_auth_action).relative_url();
|
||||
let ctx = ctx
|
||||
.with_login_link(login_link.to_string())
|
||||
.with_csrf(csrf_token.form_value());
|
||||
|
||||
let content = templates.render_register(&ctx).await?;
|
||||
Ok(content)
|
||||
}
|
||||
|
@ -18,7 +18,7 @@ use anyhow::{bail, Context};
|
||||
use argon2::Argon2;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{
|
||||
errors::HtmlError, Authentication, BrowserSession, User, UserEmail, UserEmailVerification,
|
||||
Authentication, BrowserSession, User, UserEmail, UserEmailVerification,
|
||||
UserEmailVerificationState,
|
||||
};
|
||||
use password_hash::{PasswordHash, PasswordHasher, SaltString};
|
||||
@ -61,21 +61,11 @@ pub enum LoginError {
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
impl HtmlError for LoginError {
|
||||
fn html_display(&self) -> String {
|
||||
match self {
|
||||
LoginError::NotFound { .. } => "Could not find user".to_string(),
|
||||
LoginError::Authentication { .. } => "Failed to authenticate user".to_string(),
|
||||
LoginError::Other(e) => format!("Internal error: <pre>{}</pre>", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(conn, password))]
|
||||
pub async fn login(
|
||||
conn: impl Acquire<'_, Database = Postgres>,
|
||||
username: &str,
|
||||
password: String,
|
||||
password: &str,
|
||||
) -> Result<BrowserSession<PostgresqlBackend>, LoginError> {
|
||||
let mut txn = conn.begin().await.context("could not start transaction")?;
|
||||
let user = lookup_user_by_username(&mut txn, username)
|
||||
@ -287,7 +277,7 @@ pub enum AuthenticationError {
|
||||
pub async fn authenticate_session(
|
||||
txn: &mut Transaction<'_, Postgres>,
|
||||
session: &mut BrowserSession<PostgresqlBackend>,
|
||||
password: String,
|
||||
password: &str,
|
||||
) -> Result<(), AuthenticationError> {
|
||||
// First, fetch the hashed password from the user associated with that session
|
||||
let hashed_password: String = sqlx::query_scalar!(
|
||||
@ -307,6 +297,7 @@ pub async fn authenticate_session(
|
||||
|
||||
// TODO: pass verifiers list as parameter
|
||||
// Verify the password in a blocking thread to avoid blocking the async executor
|
||||
let password = password.to_string();
|
||||
task::spawn_blocking(move || {
|
||||
let context = Argon2::default();
|
||||
let hasher = PasswordHash::new(&hashed_password).map_err(AuthenticationError::Password)?;
|
||||
|
@ -16,12 +16,12 @@
|
||||
|
||||
#![allow(clippy::trait_duplication_in_bounds)]
|
||||
|
||||
use mas_data_model::{
|
||||
errors::ErroredForm, AuthorizationGrant, BrowserSession, StorageBackend, User, UserEmail,
|
||||
};
|
||||
use serde::{ser::SerializeStruct, Serialize};
|
||||
use mas_data_model::{AuthorizationGrant, BrowserSession, StorageBackend, User, UserEmail};
|
||||
use serde::{ser::SerializeStruct, Deserialize, Serialize};
|
||||
use url::Url;
|
||||
|
||||
use crate::{FormField, FormState};
|
||||
|
||||
/// Helper trait to construct context wrappers
|
||||
pub trait TemplateContext: Serialize {
|
||||
/// Attach a user session to the template context
|
||||
@ -219,7 +219,7 @@ impl TemplateContext for IndexContext {
|
||||
}
|
||||
|
||||
/// Fields of the login form
|
||||
#[derive(Serialize, Debug, Clone, Copy, Hash, PartialEq, Eq)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Hash, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum LoginFormField {
|
||||
/// The username field
|
||||
@ -229,6 +229,15 @@ pub enum LoginFormField {
|
||||
Password,
|
||||
}
|
||||
|
||||
impl FormField for LoginFormField {
|
||||
fn keep(&self) -> bool {
|
||||
match self {
|
||||
Self::Username => true,
|
||||
Self::Password => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Context used in login and reauth screens, for the post-auth action to do
|
||||
#[derive(Serialize)]
|
||||
#[serde(tag = "kind", rename_all = "snake_case")]
|
||||
@ -243,7 +252,7 @@ pub enum PostAuthContext {
|
||||
/// Context used by the `login.html` template
|
||||
#[derive(Serialize, Default)]
|
||||
pub struct LoginContext {
|
||||
form: ErroredForm<LoginFormField>,
|
||||
form: FormState<LoginFormField>,
|
||||
next: Option<PostAuthContext>,
|
||||
register_link: String,
|
||||
}
|
||||
@ -255,7 +264,7 @@ impl TemplateContext for LoginContext {
|
||||
{
|
||||
// TODO: samples with errors
|
||||
vec![LoginContext {
|
||||
form: ErroredForm::default(),
|
||||
form: FormState::default(),
|
||||
next: None,
|
||||
register_link: "/register".to_string(),
|
||||
}]
|
||||
@ -263,9 +272,9 @@ impl TemplateContext for LoginContext {
|
||||
}
|
||||
|
||||
impl LoginContext {
|
||||
/// Add an error on the login form
|
||||
/// Set the form state
|
||||
#[must_use]
|
||||
pub fn with_form_error(self, form: ErroredForm<LoginFormField>) -> Self {
|
||||
pub fn with_form_state(self, form: FormState<LoginFormField>) -> Self {
|
||||
Self { form, ..self }
|
||||
}
|
||||
|
||||
@ -289,7 +298,7 @@ impl LoginContext {
|
||||
}
|
||||
|
||||
/// Fields of the registration form
|
||||
#[derive(Serialize, Debug, Clone, Copy, Hash, PartialEq, Eq)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Hash, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum RegisterFormField {
|
||||
/// The username field
|
||||
@ -302,10 +311,19 @@ pub enum RegisterFormField {
|
||||
PasswordConfirm,
|
||||
}
|
||||
|
||||
impl FormField for RegisterFormField {
|
||||
fn keep(&self) -> bool {
|
||||
match self {
|
||||
Self::Username => true,
|
||||
Self::Password | Self::PasswordConfirm => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Context used by the `register.html` template
|
||||
#[derive(Serialize, Default)]
|
||||
pub struct RegisterContext {
|
||||
form: ErroredForm<LoginFormField>,
|
||||
form: FormState<RegisterFormField>,
|
||||
next: Option<PostAuthContext>,
|
||||
login_link: String,
|
||||
}
|
||||
@ -317,7 +335,7 @@ impl TemplateContext for RegisterContext {
|
||||
{
|
||||
// TODO: samples with errors
|
||||
vec![RegisterContext {
|
||||
form: ErroredForm::default(),
|
||||
form: FormState::default(),
|
||||
next: None,
|
||||
login_link: "/login".to_string(),
|
||||
}]
|
||||
@ -327,7 +345,7 @@ impl TemplateContext for RegisterContext {
|
||||
impl RegisterContext {
|
||||
/// Add an error on the registration form
|
||||
#[must_use]
|
||||
pub fn with_form_error(self, form: ErroredForm<LoginFormField>) -> Self {
|
||||
pub fn with_form_state(self, form: FormState<RegisterFormField>) -> Self {
|
||||
Self { form, ..self }
|
||||
}
|
||||
|
||||
@ -377,13 +395,28 @@ impl ConsentContext {
|
||||
}
|
||||
|
||||
/// Fields of the reauthentication form
|
||||
#[derive(Serialize, Debug, Clone, Copy, Hash, PartialEq, Eq)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Hash, PartialEq, Eq)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub enum ReauthFormField {
|
||||
/// The password field
|
||||
Password,
|
||||
}
|
||||
|
||||
impl FormField for ReauthFormField {
|
||||
fn keep(&self) -> bool {
|
||||
match self {
|
||||
Self::Password => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Context used by the `reauth.html` template
|
||||
#[derive(Serialize, Default)]
|
||||
pub struct ReauthContext {
|
||||
form: FormState<ReauthFormField>,
|
||||
next: Option<PostAuthContext>,
|
||||
}
|
||||
|
||||
impl TemplateContext for ReauthContext {
|
||||
fn sample() -> Vec<Self>
|
||||
where
|
||||
@ -391,7 +424,7 @@ impl TemplateContext for ReauthContext {
|
||||
{
|
||||
// TODO: samples with errors
|
||||
vec![ReauthContext {
|
||||
form: ErroredForm::default(),
|
||||
form: FormState::default(),
|
||||
next: None,
|
||||
}]
|
||||
}
|
||||
@ -400,7 +433,7 @@ impl TemplateContext for ReauthContext {
|
||||
impl ReauthContext {
|
||||
/// Add an error on the reauthentication form
|
||||
#[must_use]
|
||||
pub fn with_form_error(self, form: ErroredForm<ReauthFormField>) -> Self {
|
||||
pub fn with_form_state(self, form: FormState<ReauthFormField>) -> Self {
|
||||
Self { form, ..self }
|
||||
}
|
||||
|
||||
@ -414,22 +447,6 @@ impl ReauthContext {
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ReauthContext {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
form: ErroredForm::new(),
|
||||
next: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Context used by the `reauth.html` template
|
||||
#[derive(Serialize)]
|
||||
pub struct ReauthContext {
|
||||
form: ErroredForm<ReauthFormField>,
|
||||
next: Option<PostAuthContext>,
|
||||
}
|
||||
|
||||
/// Context used by the `account/index.html` template
|
||||
#[derive(Serialize)]
|
||||
pub struct AccountContext {
|
||||
|
239
crates/templates/src/forms.rs
Normal file
239
crates/templates/src/forms.rs
Normal file
@ -0,0 +1,239 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{collections::HashMap, hash::Hash};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// A trait which should be used for form field enums
|
||||
pub trait FormField: Copy + Hash + PartialEq + Eq + Serialize + for<'de> Deserialize<'de> {
|
||||
/// Return false for fields where values should not be kept (e.g. password
|
||||
/// fields)
|
||||
fn keep(&self) -> bool;
|
||||
}
|
||||
|
||||
/// An error on a form field
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "snake_case", tag = "kind")]
|
||||
pub enum FieldError {
|
||||
/// A reuired field is missing
|
||||
Required,
|
||||
|
||||
/// An unspecified error on the field
|
||||
Unspecified,
|
||||
}
|
||||
|
||||
/// An error on the whole form
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "snake_case", tag = "kind")]
|
||||
pub enum FormError {
|
||||
/// The given credentials are not valid
|
||||
InvalidCredentials,
|
||||
|
||||
/// Password fields don't match
|
||||
PasswordMismatch,
|
||||
|
||||
/// There was an internal error
|
||||
Internal,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize)]
|
||||
struct FieldState {
|
||||
value: Option<String>,
|
||||
errors: Vec<FieldError>,
|
||||
}
|
||||
|
||||
/// The state of a form and its fields
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct FormState<K: Hash + Eq> {
|
||||
fields: HashMap<K, FieldState>,
|
||||
errors: Vec<FormError>,
|
||||
|
||||
#[serde(skip)]
|
||||
has_errors: bool,
|
||||
}
|
||||
|
||||
impl<K: Hash + Eq> Default for FormState<K> {
|
||||
fn default() -> Self {
|
||||
FormState {
|
||||
fields: HashMap::default(),
|
||||
errors: Vec::default(),
|
||||
has_errors: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<K: FormField> FormState<K> {
|
||||
/// Generate a [`FormState`] out of a form
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// If the form fails to serialize, or the form field keys fail to
|
||||
/// deserialize
|
||||
pub fn from_form<F: Serialize>(form: &F) -> Self {
|
||||
let form = serde_json::to_value(form).unwrap();
|
||||
let fields: HashMap<K, Option<String>> = serde_json::from_value(form).unwrap();
|
||||
|
||||
let fields = fields
|
||||
.into_iter()
|
||||
.map(|(key, value)| {
|
||||
let value = key.keep().then(|| value).flatten();
|
||||
let field = FieldState {
|
||||
value,
|
||||
errors: Vec::new(),
|
||||
};
|
||||
(key, field)
|
||||
})
|
||||
.collect();
|
||||
|
||||
FormState {
|
||||
fields,
|
||||
errors: Vec::new(),
|
||||
has_errors: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Add an error on a form field
|
||||
pub fn add_error_on_field(&mut self, field: K, error: FieldError) {
|
||||
self.fields.entry(field).or_default().errors.push(error);
|
||||
self.has_errors = true;
|
||||
}
|
||||
|
||||
/// Add an error on a form field
|
||||
#[must_use]
|
||||
pub fn with_error_on_field(mut self, field: K, error: FieldError) -> Self {
|
||||
self.add_error_on_field(field, error);
|
||||
self
|
||||
}
|
||||
|
||||
/// Add an error on the form
|
||||
pub fn add_error_on_form(&mut self, error: FormError) {
|
||||
self.errors.push(error);
|
||||
self.has_errors = true;
|
||||
}
|
||||
|
||||
/// Add an error on the form
|
||||
#[must_use]
|
||||
pub fn with_error_on_form(mut self, error: FormError) -> Self {
|
||||
self.add_error_on_form(error);
|
||||
self
|
||||
}
|
||||
|
||||
/// Returns `true` if the form has no error attached to it
|
||||
#[must_use]
|
||||
pub fn is_valid(&self) -> bool {
|
||||
!self.has_errors
|
||||
}
|
||||
}
|
||||
|
||||
/// Utility trait to help creating [`FormState`] out of a form
|
||||
pub trait ToFormState: Serialize {
|
||||
/// The enum used for field names
|
||||
type Field: FormField;
|
||||
|
||||
/// Generate a [`FormState`] out of [`Self`]
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// If the form fails to serialize or [`Self::Field`] fails to deserialize
|
||||
fn to_form_state(&self) -> FormState<Self::Field> {
|
||||
FormState::from_form(&self)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct TestForm {
|
||||
foo: String,
|
||||
bar: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Hash, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum TestFormField {
|
||||
Foo,
|
||||
Bar,
|
||||
}
|
||||
|
||||
impl FormField for TestFormField {
|
||||
fn keep(&self) -> bool {
|
||||
match self {
|
||||
Self::Foo => true,
|
||||
Self::Bar => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ToFormState for TestForm {
|
||||
type Field = TestFormField;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn form_state_serialization() {
|
||||
let form = TestForm {
|
||||
foo: "john".to_string(),
|
||||
bar: "hunter2".to_string(),
|
||||
};
|
||||
|
||||
let state = form.to_form_state();
|
||||
let state = serde_json::to_value(&state).unwrap();
|
||||
assert_eq!(
|
||||
state,
|
||||
serde_json::json!({
|
||||
"errors": [],
|
||||
"fields": {
|
||||
"foo": {
|
||||
"errors": [],
|
||||
"value": "john",
|
||||
},
|
||||
"bar": {
|
||||
"errors": [],
|
||||
"value": null
|
||||
},
|
||||
}
|
||||
})
|
||||
);
|
||||
|
||||
let form = TestForm {
|
||||
foo: "".to_string(),
|
||||
bar: "".to_string(),
|
||||
};
|
||||
let state = form
|
||||
.to_form_state()
|
||||
.with_error_on_field(TestFormField::Foo, FieldError::Required)
|
||||
.with_error_on_field(TestFormField::Bar, FieldError::Required)
|
||||
.with_error_on_form(FormError::InvalidCredentials);
|
||||
|
||||
let state = serde_json::to_value(&state).unwrap();
|
||||
assert_eq!(
|
||||
state,
|
||||
serde_json::json!({
|
||||
"errors": [{"kind": "invalid_credentials"}],
|
||||
"fields": {
|
||||
"foo": {
|
||||
"errors": [{"kind": "required"}],
|
||||
"value": "",
|
||||
},
|
||||
"bar": {
|
||||
"errors": [{"kind": "required"}],
|
||||
"value": null
|
||||
},
|
||||
}
|
||||
})
|
||||
);
|
||||
}
|
||||
}
|
@ -37,16 +37,20 @@ use tokio::{fs::OpenOptions, io::AsyncWriteExt, sync::RwLock, task::JoinError};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
mod context;
|
||||
mod forms;
|
||||
mod functions;
|
||||
|
||||
#[macro_use]
|
||||
mod macros;
|
||||
|
||||
pub use self::context::{
|
||||
AccountContext, AccountEmailsContext, ConsentContext, EmailVerificationContext, EmptyContext,
|
||||
ErrorContext, FormPostContext, IndexContext, LoginContext, LoginFormField, PostAuthContext,
|
||||
ReauthContext, ReauthFormField, RegisterContext, RegisterFormField, TemplateContext, WithCsrf,
|
||||
WithOptionalSession, WithSession,
|
||||
pub use self::{
|
||||
context::{
|
||||
AccountContext, AccountEmailsContext, ConsentContext, EmailVerificationContext,
|
||||
EmptyContext, ErrorContext, FormPostContext, IndexContext, LoginContext, LoginFormField,
|
||||
PostAuthContext, ReauthContext, ReauthFormField, RegisterContext, RegisterFormField,
|
||||
TemplateContext, WithCsrf, WithOptionalSession, WithSession,
|
||||
},
|
||||
forms::{FieldError, FormError, FormField, FormState, ToFormState},
|
||||
};
|
||||
|
||||
/// Wrapper around [`tera::Tera`] helping rendering the various templates
|
||||
@ -280,6 +284,7 @@ register_templates! {
|
||||
"components/field.html",
|
||||
"components/back_to_client.html",
|
||||
"components/navbar.html",
|
||||
"components/errors.html",
|
||||
"base.html",
|
||||
};
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
{#
|
||||
Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -18,6 +18,7 @@ limitations under the License.
|
||||
{% import "components/field.html" as field %}
|
||||
{% import "components/back_to_client.html" as back_to_client %}
|
||||
{% import "components/navbar.html" as navbar %}
|
||||
{% import "components/errors.html" as errors %}
|
||||
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
|
25
crates/templates/src/res/components/errors.html
Normal file
25
crates/templates/src/res/components/errors.html
Normal file
@ -0,0 +1,25 @@
|
||||
{#
|
||||
Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
#}
|
||||
|
||||
{% macro form_error_message(error) -%}
|
||||
{% if error.kind == "invalid_credentials" %}
|
||||
Invalid credentials
|
||||
{% elif error.kind == "password_mismatch" %}
|
||||
Password fields don't match
|
||||
{% else %}
|
||||
{{ error.kind }}
|
||||
{% endif %}
|
||||
{%- endmacro %}
|
@ -1,5 +1,5 @@
|
||||
{#
|
||||
Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -14,21 +14,36 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
#}
|
||||
|
||||
{% macro input(label, name, type="text", errors=false, class="") %}
|
||||
{% if errors is not empty %}
|
||||
{% macro input(label, name, type="text", form_state=false, class="") %}
|
||||
{% if not form_state %}
|
||||
{% set form_state = dict(errors=[], fields=dict()) %}
|
||||
{% endif %}
|
||||
|
||||
{% set state = form_state.fields[name] | default(value=dict(errors=[], value="")) %}
|
||||
|
||||
{% if state.errors is not empty %}
|
||||
{% set border_color = "border-alert" %}
|
||||
{% set text_color = "text-alert" %}
|
||||
{% else %}
|
||||
{% set border_color = "border-grey-50 dark:border-grey-450" %}
|
||||
{% set text_color = "text-black-800 dark:text-grey-300" %}
|
||||
{% endif %}
|
||||
|
||||
<label class="flex flex-col block {{ class }}">
|
||||
<div class="mx-2 -mb-3 -mt-2 leading-5 px-1 z-10 self-start bg-white dark:bg-black-900 border-white border-1 dark:border-2 dark:border-black-900 rounded-full text-sm {{ text_color }}">{{ label }}</div>
|
||||
<input name="{{ name }}" class="z-0 px-3 py-2 bg-white dark:bg-black-900 rounded-lg {{ border_color }} border-1 dark:border-2 focus:border-accent focus:ring-0 focus:outline-0" type="{{ type }}" />
|
||||
<input name="{{ name }}" class="z-0 px-3 py-2 bg-white dark:bg-black-900 rounded-lg {{ border_color }} border-1 dark:border-2 focus:border-accent focus:ring-0 focus:outline-0" type="{{ type }}" {% if state.value %} value="{{ state.value }}" {% endif %} />
|
||||
|
||||
{% if errors is not empty %}
|
||||
{% for error in errors %}
|
||||
<div class="mx-4 text-sm text-alert">{{ error }}</div>
|
||||
{% if state.errors is not empty %}
|
||||
{% for error in state.errors %}
|
||||
{% if error.kind != "unspecified" %}
|
||||
<div class="mx-4 text-sm text-alert">
|
||||
{% if error.kind == "required" %}
|
||||
This field is required
|
||||
{% else %}
|
||||
{{ error.kind }}
|
||||
{% endif %}
|
||||
</div>
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
</label>
|
||||
|
@ -1,5 +1,5 @@
|
||||
{#
|
||||
Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -23,10 +23,17 @@ limitations under the License.
|
||||
<h1 class="text-lg text-center font-medium">Sign in</h1>
|
||||
<p>Please sign in to continue:</p>
|
||||
</div>
|
||||
{% if form.errors is not empty %}
|
||||
{% for error in form.errors %}
|
||||
<div class="text-alert font-medium">
|
||||
{{ errors::form_error_message(error=error) }}
|
||||
</div>
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
<input type="hidden" name="csrf" value="{{ csrf_token }}" />
|
||||
{# TODO: errors #}
|
||||
{{ field::input(label="Username", name="username") }}
|
||||
{{ field::input(label="Password", name="password", type="password") }}
|
||||
{{ field::input(label="Username", name="username", form_state=form) }}
|
||||
{{ field::input(label="Password", name="password", type="password", form_state=form) }}
|
||||
{% if next and next.kind == "continue_authorization_grant" %}
|
||||
<div class="grid grid-cols-2 gap-4">
|
||||
{{ back_to_client::link(
|
||||
|
@ -1,5 +1,5 @@
|
||||
{#
|
||||
Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -23,11 +23,18 @@ limitations under the License.
|
||||
<h1 class="text-lg text-center font-medium">Create an account</h1>
|
||||
<p>Please create an account to get started:</p>
|
||||
</div>
|
||||
{% if form.errors is not empty %}
|
||||
{% for error in form.errors %}
|
||||
<div class="text-alert font-medium">
|
||||
{{ errors::form_error_message(error=error) }}
|
||||
</div>
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
<input type="hidden" name="csrf" value="{{ csrf_token }}" />
|
||||
{# TODO: errors #}
|
||||
{{ field::input(label="Username", name="username") }}
|
||||
{{ field::input(label="Password", name="password", type="password") }}
|
||||
{{ field::input(label="Confirm Password", name="password_confirm", type="password") }}
|
||||
{{ field::input(label="Username", name="username", form_state=form) }}
|
||||
{{ field::input(label="Password", name="password", type="password", form_state=form) }}
|
||||
{{ field::input(label="Confirm Password", name="password_confirm", type="password", form_state=form) }}
|
||||
|
||||
{% if next and next.kind == "continue_authorization_grant" %}
|
||||
<div class="grid grid-cols-2 gap-4">
|
||||
|
Reference in New Issue
Block a user