diff --git a/Cargo.lock b/Cargo.lock index e554b923..42c9f9b7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1605,6 +1605,7 @@ dependencies = [ "oauth2-types", "serde", "serde_json", + "serde_urlencoded", "tera", "thiserror", "tokio", diff --git a/crates/core/src/handlers/oauth2/authorization.rs b/crates/core/src/handlers/oauth2/authorization.rs index 3d586ec4..d72f8978 100644 --- a/crates/core/src/handlers/oauth2/authorization.rs +++ b/crates/core/src/handlers/oauth2/authorization.rs @@ -441,7 +441,7 @@ async fn get( } } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone)] pub(crate) struct ContinueAuthorizationGrant { #[serde( with = "serde_with::rust::display_fromstr", diff --git a/crates/core/src/handlers/views/login.rs b/crates/core/src/handlers/views/login.rs index 432e8a6c..febc1bf1 100644 --- a/crates/core/src/handlers/views/login.rs +++ b/crates/core/src/handlers/views/login.rs @@ -20,7 +20,7 @@ use serde::Deserialize; use sqlx::{pool::PoolConnection, PgPool, Postgres}; use warp::{reply::html, Filter, Rejection, Reply}; -use super::shared::PostAuthAction; +use super::{shared::PostAuthAction, RegisterRequest}; use crate::{ errors::WrapError, filters::{ @@ -130,8 +130,12 @@ async fn get( let ctx = LoginContext::default(); let ctx = match query.post_auth_action { Some(next) => { + let register_link = RegisterRequest::from(next.clone()) + .build_uri() + .wrap_error()?; let next = next.load_context(&mut conn).await.wrap_error()?; ctx.with_post_action(next) + .with_register_link(register_link.to_string()) } None => ctx, }; diff --git a/crates/core/src/handlers/views/mod.rs b/crates/core/src/handlers/views/mod.rs index 8ad59530..a6c60dab 100644 --- a/crates/core/src/handlers/views/mod.rs +++ b/crates/core/src/handlers/views/mod.rs @@ -28,7 +28,9 @@ use self::{ index::filter as index, login::filter as login, logout::filter as logout, reauth::filter as reauth, register::filter as register, }; -pub(crate) use self::{login::LoginRequest, reauth::ReauthRequest, shared::PostAuthAction}; +pub(crate) use self::{ + login::LoginRequest, reauth::ReauthRequest, register::RegisterRequest, shared::PostAuthAction, +}; pub(super) fn filter( pool: &PgPool, diff --git a/crates/core/src/handlers/views/register.rs b/crates/core/src/handlers/views/register.rs index bfe1e5f2..02275e9f 100644 --- a/crates/core/src/handlers/views/register.rs +++ b/crates/core/src/handlers/views/register.rs @@ -15,12 +15,13 @@ use argon2::Argon2; use hyper::http::uri::{Parts, PathAndQuery, Uri}; use mas_config::{CookiesConfig, CsrfConfig}; -use mas_data_model::BrowserSession; -use mas_templates::{EmptyContext, TemplateContext, Templates}; -use serde::{Deserialize, Serialize}; +use mas_data_model::{BrowserSession, StorageBackend}; +use mas_templates::{RegisterContext, TemplateContext, Templates}; +use serde::Deserialize; use sqlx::{pool::PoolConnection, PgPool, Postgres}; use warp::{reply::html, Filter, Rejection, Reply}; +use super::{LoginRequest, PostAuthAction}; use crate::{ errors::WrapError, filters::{ @@ -33,21 +34,34 @@ use crate::{ storage::{register_user, user::start_session, PostgresqlBackend}, }; -#[derive(Serialize, Deserialize)] -pub struct RegisterRequest { - next: Option, +#[derive(Deserialize)] +#[serde(bound(deserialize = "S::AuthorizationGrantData: std::str::FromStr, + ::Err: std::fmt::Display"))] +pub struct RegisterRequest { + #[serde(flatten)] + post_auth_action: Option>, } -impl RegisterRequest { - #[allow(dead_code)] - pub fn new(next: Option) -> Self { - Self { next } +impl From> for RegisterRequest { + fn from(post_auth_action: PostAuthAction) -> Self { + Self { + post_auth_action: Some(post_auth_action), + } } +} +impl RegisterRequest { #[allow(dead_code)] - pub fn build_uri(&self) -> anyhow::Result { - let qs = serde_urlencoded::to_string(self)?; - let path_and_query = PathAndQuery::try_from(format!("/register?{}", qs))?; + pub fn build_uri(&self) -> anyhow::Result + where + S::AuthorizationGrantData: std::fmt::Display, + { + let path_and_query = if let Some(next) = &self.post_auth_action { + let qs = serde_urlencoded::to_string(next)?; + PathAndQuery::try_from(format!("/register?{}", qs))? + } else { + PathAndQuery::from_static("/register") + }; let uri = Uri::from_parts({ let mut parts = Parts::default(); parts.path_and_query = Some(path_and_query); @@ -56,19 +70,17 @@ impl RegisterRequest { Ok(uri) } - fn redirect(self) -> Result { - let uri: Uri = Uri::from_parts({ - let mut parts = Parts::default(); - parts.path_and_query = Some( - self.next - .map(warp::http::uri::PathAndQuery::try_from) - .transpose() - .wrap_error()? - .unwrap_or_else(|| PathAndQuery::from_static("/")), - ); - parts - }) - .wrap_error()?; + fn redirect(self) -> Result + where + S::AuthorizationGrantData: std::fmt::Display, + { + let uri = self + .post_auth_action + .as_ref() + .map(PostAuthAction::build_uri) + .transpose() + .wrap_error()? + .unwrap_or_else(|| Uri::from_static("/")); Ok(warp::redirect::see_other(uri)) } } @@ -88,6 +100,7 @@ pub(super) fn filter( ) -> impl Filter + Clone + Send + Sync + 'static { let get = warp::get() .and(with_templates(templates)) + .and(connection(pool)) .and(encrypted_cookie_saver(cookies_config)) .and(updated_csrf_token(cookies_config, csrf_config)) .and(warp::query()) @@ -106,15 +119,26 @@ pub(super) fn filter( async fn get( templates: Templates, + mut conn: PoolConnection, cookie_saver: EncryptedCookieSaver, csrf_token: CsrfToken, - query: RegisterRequest, + query: RegisterRequest, maybe_session: Option>, ) -> Result, Rejection> { if maybe_session.is_some() { Ok(Box::new(query.redirect()?)) } else { - let ctx = EmptyContext.with_csrf(csrf_token.form_value()); + let ctx = RegisterContext::default(); + let ctx = match query.post_auth_action { + Some(next) => { + let login_link = LoginRequest::from(next.clone()).build_uri().wrap_error()?; + let next = next.load_context(&mut conn).await.wrap_error()?; + ctx.with_post_action(next) + .with_login_link(login_link.to_string()) + } + None => ctx, + }; + let ctx = ctx.with_csrf(csrf_token.form_value()); let content = templates.render_register(&ctx).await?; let reply = html(content); let reply = cookie_saver.save_encrypted(&csrf_token, reply)?; @@ -126,8 +150,9 @@ async fn post( mut conn: PoolConnection, cookie_saver: EncryptedCookieSaver, form: RegisterForm, - query: RegisterRequest, + query: RegisterRequest, ) -> Result { + // TODO: display nice form errors if form.password != form.password_confirm { return Err(anyhow::anyhow!("password mismatch")).wrap_error(); } diff --git a/crates/core/src/handlers/views/shared.rs b/crates/core/src/handlers/views/shared.rs index be7b3b65..6d331d3f 100644 --- a/crates/core/src/handlers/views/shared.rs +++ b/crates/core/src/handlers/views/shared.rs @@ -21,7 +21,7 @@ use sqlx::PgExecutor; use super::super::oauth2::ContinueAuthorizationGrant; use crate::storage::PostgresqlBackend; -#[derive(Deserialize, Serialize)] +#[derive(Deserialize, Serialize, Clone)] #[serde(rename_all = "snake_case", tag = "next")] pub(crate) enum PostAuthAction { #[serde(bound( diff --git a/crates/templates/Cargo.toml b/crates/templates/Cargo.toml index 7e18b3e1..039e04c9 100644 --- a/crates/templates/Cargo.toml +++ b/crates/templates/Cargo.toml @@ -18,6 +18,7 @@ thiserror = "1.0.30" tera = "1.15.0" serde = { version = "1.0.131", features = ["derive"] } serde_json = "1.0.72" +serde_urlencoded = "0.7.0" url = "2.2.2" warp = "0.3.2" diff --git a/crates/templates/src/context.rs b/crates/templates/src/context.rs index 83c0bc7f..55f07f35 100644 --- a/crates/templates/src/context.rs +++ b/crates/templates/src/context.rs @@ -204,7 +204,7 @@ impl TemplateContext for IndexContext { } #[derive(Serialize, Debug, Clone, Copy, Hash, PartialEq, Eq)] -#[serde(rename_all = "kebab-case")] +#[serde(rename_all = "snake_case")] pub enum LoginFormField { Username, Password, @@ -222,6 +222,7 @@ pub enum PostAuthContext { pub struct LoginContext { form: ErroredForm, next: Option, + register_link: String, } impl TemplateContext for LoginContext { @@ -233,6 +234,7 @@ impl TemplateContext for LoginContext { vec![LoginContext { form: ErroredForm::default(), next: None, + register_link: "/register".to_string(), }] } } @@ -250,6 +252,14 @@ impl LoginContext { ..self } } + + #[must_use] + pub fn with_register_link(self, register_link: String) -> Self { + Self { + register_link, + ..self + } + } } impl Default for LoginContext { @@ -257,6 +267,67 @@ impl Default for LoginContext { Self { form: ErroredForm::new(), next: None, + register_link: "/register".to_string(), + } + } +} + +#[derive(Serialize, Debug, Clone, Copy, Hash, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum RegisterFormField { + Username, + Password, + PasswordConfirm, +} + +/// Context used by the `register.html` template +#[derive(Serialize)] +pub struct RegisterContext { + form: ErroredForm, + next: Option, + login_link: String, +} + +impl TemplateContext for RegisterContext { + fn sample() -> Vec + where + Self: Sized, + { + // TODO: samples with errors + vec![RegisterContext { + form: ErroredForm::default(), + next: None, + login_link: "/login".to_string(), + }] + } +} + +impl RegisterContext { + #[must_use] + pub fn with_form_error(self, form: ErroredForm) -> Self { + Self { form, ..self } + } + + #[must_use] + pub fn with_post_action(self, next: PostAuthContext) -> Self { + Self { + next: Some(next), + ..self + } + } + + #[must_use] + pub fn with_login_link(self, login_link: String) -> Self { + Self { login_link, ..self } + } +} + +impl Default for RegisterContext { + fn default() -> Self { + Self { + form: ErroredForm::new(), + next: None, + login_link: "/login".to_string(), } } } diff --git a/crates/templates/src/functions.rs b/crates/templates/src/functions.rs index 01aac8f9..3d80f9b9 100644 --- a/crates/templates/src/functions.rs +++ b/crates/templates/src/functions.rs @@ -14,10 +14,16 @@ //! Additional functions, tests and filters used in templates +use std::{collections::HashMap, str::FromStr}; + use tera::{helpers::tests::number_args_allowed, Tera, Value}; +use url::Url; pub fn register(tera: &mut Tera) { tera.register_tester("empty", self::tester_empty); + tera.register_function("add_params_to_uri", function_add_params_to_uri); + tera.register_function("merge", function_merge); + tera.register_function("dict", function_dict); } fn tester_empty(value: Option<&Value>, params: &[Value]) -> Result { @@ -28,3 +34,84 @@ fn tester_empty(value: Option<&Value>, params: &[Value]) -> Result Ok(false), } } + +enum ParamsWhere { + Fragment, + Query, +} + +fn function_add_params_to_uri(params: &HashMap) -> Result { + use ParamsWhere::{Fragment, Query}; + + // First, get the `uri`, `mode` and `params` parameters + let uri = params + .get("uri") + .and_then(Value::as_str) + .ok_or_else(|| tera::Error::msg("Invalid parameter `uri`"))?; + let uri = Url::from_str(uri).map_err(|e| tera::Error::chain(uri, e))?; + let mode = params + .get("mode") + .and_then(Value::as_str) + .ok_or_else(|| tera::Error::msg("Invalid parameter `mode`"))?; + let mode = match mode { + "fragment" => Fragment, + "query" => Query, + _ => return Err(tera::Error::msg("Invalid mode")), + }; + let params = params + .get("params") + .and_then(Value::as_object) + .ok_or_else(|| tera::Error::msg("Invalid parameter `params`"))?; + + // Get the relevant part of the URI and parse for existing parameters + let existing = match mode { + Fragment => uri.fragment(), + Query => uri.query(), + }; + let existing: HashMap = existing + .map(serde_urlencoded::from_str) + .transpose() + .map_err(|e| tera::Error::chain(e, "Could not parse existing `uri` parameters"))? + .unwrap_or_default(); + + // Merge the exising and the additional parameters together + let params: HashMap<&String, &Value> = params + .iter() + // Filter out the `uri` and `mode` params + .filter(|(k, _v)| k != &"uri" && k != &"mode") + .chain(existing.iter()) + .collect(); + + // Transform them back to urlencoded + let params = serde_urlencoded::to_string(params) + .map_err(|e| tera::Error::chain(e, "Could not serialize back parameters"))?; + + let uri = { + let mut uri = uri; + match mode { + Fragment => uri.set_fragment(Some(¶ms)), + Query => uri.set_query(Some(¶ms)), + }; + uri + }; + + Ok(Value::String(uri.to_string())) +} + +fn function_merge(params: &HashMap) -> Result { + let mut ret = serde_json::Map::new(); + for (k, v) in params { + let v = v + .as_object() + .ok_or_else(|| tera::Error::msg(format!("Parameter {:?} should be an object", k)))?; + ret.extend(v.clone()); + } + + Ok(Value::Object(ret)) +} + +#[allow(clippy::unnecessary_wraps)] +fn function_dict(params: &HashMap) -> Result { + let ret = params.clone().into_iter().collect(); + Ok(Value::Object(ret)) +} diff --git a/crates/templates/src/lib.rs b/crates/templates/src/lib.rs index ff4058e0..69013444 100644 --- a/crates/templates/src/lib.rs +++ b/crates/templates/src/lib.rs @@ -48,8 +48,8 @@ mod macros; pub use self::context::{ EmptyContext, ErrorContext, FormPostContext, IndexContext, LoginContext, LoginFormField, - PostAuthContext, ReauthContext, ReauthFormField, TemplateContext, WithCsrf, - WithOptionalSession, WithSession, + PostAuthContext, ReauthContext, ReauthFormField, RegisterContext, RegisterFormField, + TemplateContext, WithCsrf, WithOptionalSession, WithSession, }; /// Wrapper around [`tera::Tera`] helping rendering the various templates @@ -280,6 +280,7 @@ register_templates! { extra = { "components/button.html", "components/field.html", + "components/back_to_client.html", "base.html", }; @@ -287,7 +288,7 @@ register_templates! { pub fn render_login(WithCsrf) { "login.html" } /// Render the registration page - pub fn render_register(WithCsrf) { "register.html" } + pub fn render_register(WithCsrf) { "register.html" } /// Render the home page pub fn render_index(WithCsrf>) { "index.html" } diff --git a/crates/templates/src/res/base.html b/crates/templates/src/res/base.html index 6bb47d39..06476078 100644 --- a/crates/templates/src/res/base.html +++ b/crates/templates/src/res/base.html @@ -16,6 +16,7 @@ limitations under the License. {% import "components/button.html" as button %} {% import "components/field.html" as field %} +{% import "components/back_to_client.html" as back_to_client %} diff --git a/crates/templates/src/res/components/back_to_client.html b/crates/templates/src/res/components/back_to_client.html new file mode 100644 index 00000000..6ac30a48 --- /dev/null +++ b/crates/templates/src/res/components/back_to_client.html @@ -0,0 +1,30 @@ +{# +Copyright 2021 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +#} + +{% macro link(text, class="", uri, mode, params) %} + {% if mode == "form_post" %} +
+ {% for key, value in params %} + + {% endfor %} + +
+ {% elif mode == "fragment" or mode == "query" %} + {{ text }} + {% else %} + {{ throw(message="Invalid mode") }} + {% endif %} +{% endmacro %} diff --git a/crates/templates/src/res/components/button.html b/crates/templates/src/res/components/button.html index b4eaaf8e..acedec66 100644 --- a/crates/templates/src/res/components/button.html +++ b/crates/templates/src/res/components/button.html @@ -35,7 +35,7 @@ limitations under the License. {% endmacro %} {% macro link_text(text, href="#", class="") %} -{{ text }} + {{ text }} {% endmacro %} {% macro link_ghost(text, href="#", class="") %} diff --git a/crates/templates/src/res/login.html b/crates/templates/src/res/login.html index 8775bdd6..d5fbc4d4 100644 --- a/crates/templates/src/res/login.html +++ b/crates/templates/src/res/login.html @@ -18,7 +18,12 @@ limitations under the License. {% block navbar_start %} {% if next and next.kind == "continue_authorization_grant" %} - ← Back + {{ back_to_client::link( + text="← Back", + uri=next.grant.redirect_uri, + mode=next.grant.response_mode, + params=dict(error="access_denied", state=next.grant.state) + ) }} {% endif %} {% endblock %} @@ -40,14 +45,19 @@ limitations under the License. {{ field::input(label="Username", name="username") }} {{ field::input(label="Password", name="password", type="password") }} {{ button::button(text="Next") }} - {{ button::link_text(text="Create account", href="/register") }} + {{ button::link_text(text="Create account", href=register_link) }} {% if next and next.kind == "continue_authorization_grant" %}
- {# TODO: proper back link #} - {{ button::link_text(text="Return to application", href="/") }} + {{ back_to_client::link( + text="Return to application", + class=button::text_class(), + uri=next.grant.redirect_uri, + mode=next.grant.response_mode, + params=dict(error="access_denied", state=next.grant.state) + ) }}
{% endif %} {% endblock content %} diff --git a/crates/templates/src/res/reauth.html b/crates/templates/src/res/reauth.html index 8240b65b..cd3d448b 100644 --- a/crates/templates/src/res/reauth.html +++ b/crates/templates/src/res/reauth.html @@ -16,6 +16,17 @@ limitations under the License. {% extends "base.html" %} +{% block navbar_start %} + {% if next and next.kind == "continue_authorization_grant" %} + {{ back_to_client::link( + text="← Back", + uri=next.grant.redirect_uri, + mode=next.grant.response_mode, + params=dict(error="access_denied", state=next.grant.state) + ) }} + {% endif %} +{% endblock %} + {% block content %}
@@ -44,5 +55,17 @@ limitations under the License. {% endif %} + + {% if next and next.kind == "continue_authorization_grant" %} +
+ {{ back_to_client::link( + text="Return to application", + class=button::text_class(), + uri=next.grant.redirect_uri, + mode=next.grant.response_mode, + params=dict(error="access_denied", state=next.grant.state) + ) }} +
+ {% endif %} {% endblock content %} diff --git a/crates/templates/src/res/register.html b/crates/templates/src/res/register.html index 008ad18d..d5998766 100644 --- a/crates/templates/src/res/register.html +++ b/crates/templates/src/res/register.html @@ -18,7 +18,12 @@ limitations under the License. {% block navbar_start %} {% if next and next.kind == "continue_authorization_grant" %} - ← Back + {{ back_to_client::link( + text="← Back", + uri=next.grant.redirect_uri, + mode=next.grant.response_mode, + params=dict(error="access_denied", state=next.grant.state) + ) }} {% endif %} {% endblock %} @@ -40,14 +45,19 @@ limitations under the License. {{ field::input(label="Confirm Password", name="password_confirm", type="password") }} {{ button::button(text="Next") }} {# TODO: proper link #} - {{ button::link_text(text="Login instead", href="/login") }} + {{ button::link_text(text="Login instead", href=login_link) }}
- {% if next %} + {% if next and next.kind == "continue_authorization_grant" %}
- {# TODO: proper back link #} - {{ button::link_text(text="Return to application", href="/") }} + {{ back_to_client::link( + text="Return to application", + class=button::text_class(), + uri=next.grant.redirect_uri, + mode=next.grant.response_mode, + params=dict(error="access_denied", state=next.grant.state) + ) }}
{% endif %} {% endblock content %}