From 9c94e11e684cae00b00082428c3e2c2d7a47c341 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 9 Nov 2023 18:18:28 +0100 Subject: [PATCH] Check for existing users ahead of time on upstream OAuth2 registration --- crates/handlers/src/upstream_oauth2/link.rs | 180 ++++++++++++++---- crates/templates/src/context.rs | 64 ++++--- crates/templates/src/forms.rs | 23 ++- crates/templates/src/lib.rs | 3 +- templates/components/errors.html | 2 + templates/components/field.html | 23 +-- .../pages/upstream_oauth2/do_register.html | 37 ++-- translations/en.json | 24 +-- 8 files changed, 247 insertions(+), 109 deletions(-) diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index deccb2e2..65fd83b6 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -14,7 +14,7 @@ use axum::{ extract::{Path, State}, - response::{Html, IntoResponse}, + response::{Html, IntoResponse, Response}, Form, TypedHeader, }; use hyper::StatusCode; @@ -35,12 +35,13 @@ use mas_storage::{ BoxClock, BoxRepository, BoxRng, RepositoryAccess, }; use mas_templates::{ - ErrorContext, TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister, - UpstreamSuggestLink, + ErrorContext, FieldError, FormError, TemplateContext, Templates, ToFormState, + UpstreamExistingLinkContext, UpstreamRegister, UpstreamSuggestLink, }; use minijinja::Environment; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use thiserror::Error; +use tracing::warn; use ulid::Ulid; use super::UpstreamSessionsCookie; @@ -83,14 +84,6 @@ pub(crate) enum RouteError { #[error("Invalid form action")] InvalidFormAction, - #[error("Missing username")] - MissingUsername, - - #[error("Policy violation: {violations:?}")] - PolicyViolation { - violations: Vec, - }, - #[error(transparent)] Internal(Box), } @@ -107,16 +100,6 @@ impl IntoResponse for RouteError { let event_id = sentry::capture_error(&self); let response = match self { Self::LinkNotFound => (StatusCode::NOT_FOUND, "Link not found").into_response(), - Self::PolicyViolation { violations } => { - let details = violations.iter().map(|v| v.msg.clone()).collect::>(); - let details = details.join("\n"); - let ctx = ErrorContext::new() - .with_description( - "Account registration denied because of policy violation".to_owned(), - ) - .with_details(details); - FancyError::new(ctx).into_response() - } Self::Internal(e) => FancyError::from(e).into_response(), e => FancyError::from(e).into_response(), }; @@ -171,7 +154,7 @@ fn import_claim( Ok(()) } -#[derive(Deserialize)] +#[derive(Deserialize, Serialize)] #[serde(rename_all = "lowercase", tag = "action")] pub(crate) enum FormData { Register { @@ -185,6 +168,10 @@ pub(crate) enum FormData { Link, } +impl ToFormState for FormData { + type Field = mas_templates::UpstreamRegisterFormField; +} + #[tracing::instrument( name = "handlers.upstream_oauth2.link.get", fields(upstream_oauth_link.id = %link_id), @@ -195,6 +182,7 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, mut repo: BoxRepository, + mut policy: Policy, PreferredLanguage(locale): PreferredLanguage, State(templates): State, State(url_builder): State, @@ -339,7 +327,7 @@ pub(crate) async fn get( .map(|id_token| id_token.into_parts().1) .unwrap_or_default(); - let mut ctx = UpstreamRegister::new(&link); + let mut ctx = UpstreamRegister::default(); let env = { let mut e = Environment::new(); @@ -375,6 +363,7 @@ pub(crate) async fn get( }, )?; + let mut forced_localpart = None; import_claim( &env, provider @@ -385,10 +374,59 @@ pub(crate) async fn get( .unwrap_or("{{ user.preferred_username }}"), &provider.claims_imports.localpart, |value, force| { + if force { + // We want to run the policy check on the username if it is forced + forced_localpart = Some(value.clone()); + } + ctx.set_localpart(value, force); }, )?; + // Run the policy check and check for existing users + if let Some(localpart) = forced_localpart { + let maybe_existing_user = repo.user().find_by_username(&localpart).await?; + if let Some(existing_user) = maybe_existing_user { + // The mapper returned a username which already exists, but isn't linked to + // this upstream user. + warn!(username = %localpart, user_id = %existing_user.id, "Localpart template returned an existing username"); + + // TODO: translate + let ctx = ErrorContext::new() + .with_code("User exists") + .with_description(format!( + r#"Upstream account provider returned {localpart:?} as username, + which is not linked to that upstream account"# + )) + .with_language(&locale); + + return Ok(( + cookie_jar, + Html(templates.render_error(&ctx)?).into_response(), + )); + } + + let res = policy + .evaluate_upstream_oauth_register(&localpart, None) + .await?; + + if !res.valid() { + // TODO: translate + let ctx = ErrorContext::new() + .with_code("Policy error") + .with_description(format!( + r#"Upstream account provider returned {localpart:?} as username, + which does not pass the policy check: {res}"# + )) + .with_language(&locale); + + return Ok(( + cookie_jar, + Html(templates.render_error(&ctx)?).into_response(), + )); + } + } + let ctx = ctx.with_csrf(csrf_token.form_value()).with_language(locale); Html(templates.render_upstream_oauth2_do_register(&ctx)?).into_response() @@ -411,10 +449,12 @@ pub(crate) async fn post( cookie_jar: CookieJar, user_agent: Option>, mut policy: Policy, + PreferredLanguage(locale): PreferredLanguage, + State(templates): State, State(url_builder): State, Path(link_id): Path, Form(form): Form>, -) -> Result { +) -> Result { let user_agent = user_agent.map(|ua| ua.as_str().to_owned()); let form = cookie_jar.verify_form(&clock, form)?; @@ -449,8 +489,10 @@ pub(crate) async fn post( return Err(RouteError::SessionConsumed); } + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (user_session_info, cookie_jar) = cookie_jar.session_info(); let maybe_user_session = user_session_info.load_session(&mut repo).await?; + let form_state = form.to_form_state(); let session = match (maybe_user_session, link.user_id, form) { (Some(session), None, FormData::Link) => { @@ -495,13 +537,15 @@ pub(crate) async fn post( .unwrap_or(false); // Let's try to import the claims from the ID token - let env = { let mut e = Environment::new(); e.add_global("user", payload); e }; + // Create a template context in case we need to re-render because of an error + let mut ctx = UpstreamRegister::default(); + let mut name = None; import_claim( &env, @@ -515,8 +559,10 @@ pub(crate) async fn post( |value, force| { // Import the display name if it is either forced or the user has requested it if force || import_display_name { - name = Some(value); + name = Some(value.clone()); } + + ctx.set_display_name(value, force); }, )?; @@ -533,8 +579,10 @@ pub(crate) async fn post( |value, force| { // Import the email if it is either forced or the user has requested it if force || import_email { - email = Some(value); + email = Some(value.clone()); } + + ctx.set_email(value, force); }, )?; @@ -551,21 +599,85 @@ pub(crate) async fn post( |value, force| { // If the username is forced, override whatever was in the form if force { - username = Some(value); + username = Some(value.clone()); } + + ctx.set_localpart(value, force); }, )?; - let username = username.ok_or(RouteError::MissingUsername)?; + let username = username.filter(|s| !s.is_empty()); + + let Some(username) = username else { + let form_state = form_state.with_error_on_field( + mas_templates::UpstreamRegisterFormField::Username, + FieldError::Required, + ); + + let ctx = ctx + .with_form_state(form_state) + .with_csrf(csrf_token.form_value()) + .with_language(locale); + return Ok(( + cookie_jar, + Html(templates.render_upstream_oauth2_do_register(&ctx)?), + ) + .into_response()); + }; + + // Check if there is an existing user + let existing_user = repo.user().find_by_username(&username).await?; + if let Some(_existing_user) = existing_user { + // If there is an existing user, we can't create a new one + // with the same username + + let form_state = form_state.with_error_on_field( + mas_templates::UpstreamRegisterFormField::Username, + FieldError::Exists, + ); + + let ctx = ctx + .with_form_state(form_state) + .with_csrf(csrf_token.form_value()) + .with_language(locale); + return Ok(( + cookie_jar, + Html(templates.render_upstream_oauth2_do_register(&ctx)?), + ) + .into_response()); + } // Policy check let res = policy .evaluate_upstream_oauth_register(&username, email.as_deref()) .await?; if !res.valid() { - return Err(RouteError::PolicyViolation { - violations: res.violations, - }); + let form_state = + res.violations + .into_iter() + .fold(form_state, |form_state, violation| { + match violation.field.as_deref() { + Some("username") => form_state.with_error_on_field( + mas_templates::UpstreamRegisterFormField::Username, + FieldError::Policy { + message: violation.msg, + }, + ), + _ => form_state.with_error_on_form(FormError::Policy { + message: violation.msg, + }), + } + }); + + let ctx = ctx + .with_form_state(form_state) + .with_csrf(csrf_token.form_value()) + .with_language(locale); + return Ok(( + cookie_jar, + Html(templates.render_upstream_oauth2_do_register(&ctx)?), + ) + .into_response()); } // Now we can create the user @@ -631,5 +743,5 @@ pub(crate) async fn post( repo.save().await?; - Ok((cookie_jar, post_auth_action.go_next(&url_builder))) + Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response()) } diff --git a/crates/templates/src/context.rs b/crates/templates/src/context.rs index f6ca8e89..228064ce 100644 --- a/crates/templates/src/context.rs +++ b/crates/templates/src/context.rs @@ -25,7 +25,7 @@ use mas_data_model::{ UpstreamOAuthLink, UpstreamOAuthProvider, User, UserEmail, UserEmailVerification, }; use mas_i18n::DataLocale; -use mas_router::{Account, GraphQL, PostAuthAction, Route, UrlBuilder}; +use mas_router::{Account, GraphQL, PostAuthAction, UrlBuilder}; use rand::Rng; use serde::{ser::SerializeStruct, Deserialize, Serialize}; use ulid::Ulid; @@ -918,68 +918,78 @@ impl TemplateContext for UpstreamSuggestLink { } } +/// User-editeable fields of the upstream account link form +#[derive(Serialize, Deserialize, Debug, Clone, Copy, Hash, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum UpstreamRegisterFormField { + /// The username field + Username, +} + +impl FormField for UpstreamRegisterFormField { + fn keep(&self) -> bool { + match self { + Self::Username => true, + } + } +} + /// Context used by the `pages/upstream_oauth2/do_register.html` /// templates -#[derive(Serialize)] +#[derive(Serialize, Default)] pub struct UpstreamRegister { - login_link: String, - suggested_localpart: Option, + imported_localpart: Option, force_localpart: bool, - suggested_display_name: Option, + imported_display_name: Option, force_display_name: bool, - suggested_email: Option, + imported_email: Option, force_email: bool, + form_state: FormState, } impl UpstreamRegister { /// Constructs a new context with an existing linked user #[must_use] - pub fn new(link: &UpstreamOAuthLink) -> Self { - Self::for_link_id(link.id) + pub fn new() -> Self { + Self::default() } /// Set the suggested localpart pub fn set_localpart(&mut self, localpart: String, force: bool) { - self.suggested_localpart = Some(localpart); + self.imported_localpart = Some(localpart); self.force_localpart = force; } /// Set the suggested display name pub fn set_display_name(&mut self, display_name: String, force: bool) { - self.suggested_display_name = Some(display_name); + self.imported_display_name = Some(display_name); self.force_display_name = force; } /// Set the suggested email pub fn set_email(&mut self, email: String, force: bool) { - self.suggested_email = Some(email); + self.imported_email = Some(email); self.force_email = force; } - fn for_link_id(id: Ulid) -> Self { - let login_link = mas_router::Login::and_link_upstream(id) - .path_and_query() - .into(); + /// Set the form state + pub fn set_form_state(&mut self, form_state: FormState) { + self.form_state = form_state; + } - Self { - login_link, - suggested_localpart: None, - force_localpart: false, - suggested_display_name: None, - force_display_name: false, - suggested_email: None, - force_email: false, - } + /// Set the form state + #[must_use] + pub fn with_form_state(self, form_state: FormState) -> Self { + Self { form_state, ..self } } } impl TemplateContext for UpstreamRegister { - fn sample(now: chrono::DateTime, rng: &mut impl Rng) -> Vec + fn sample(_now: chrono::DateTime, _rng: &mut impl Rng) -> Vec where Self: Sized, { - let id = Ulid::from_datetime_with_source(now.into(), rng); - vec![Self::for_link_id(id)] + vec![Self::new()] } } diff --git a/crates/templates/src/forms.rs b/crates/templates/src/forms.rs index c66b09b5..85e5014a 100644 --- a/crates/templates/src/forms.rs +++ b/crates/templates/src/forms.rs @@ -92,6 +92,22 @@ impl Default for FormState { } } +#[derive(Deserialize, PartialEq, Eq, Hash)] +#[serde(untagged)] +enum KeyOrOther { + Key(K), + Other(String), +} + +impl KeyOrOther { + fn key(self) -> Option { + match self { + Self::Key(key) => Some(key), + Self::Other(_) => None, + } + } +} + impl FormState { /// Generate a [`FormState`] out of a form /// @@ -101,17 +117,18 @@ impl FormState { /// deserialize pub fn from_form(form: &F) -> Self { let form = serde_json::to_value(form).unwrap(); - let fields: HashMap> = serde_json::from_value(form).unwrap(); + let fields: HashMap, Option> = serde_json::from_value(form).unwrap(); let fields = fields .into_iter() - .map(|(key, value)| { + .filter_map(|(key, value)| { + let key = key.key()?; let value = key.keep().then_some(value).flatten(); let field = FieldState { value, errors: Vec::new(), }; - (key, field) + Some((key, field)) }) .collect(); diff --git a/crates/templates/src/lib.rs b/crates/templates/src/lib.rs index 7b3b8786..0cfad97e 100644 --- a/crates/templates/src/lib.rs +++ b/crates/templates/src/lib.rs @@ -54,7 +54,8 @@ pub use self::{ LoginContext, LoginFormField, NotFoundContext, PolicyViolationContext, PostAuthContext, PostAuthContextInner, ReauthContext, ReauthFormField, RegisterContext, RegisterFormField, SiteBranding, TemplateContext, UpstreamExistingLinkContext, UpstreamRegister, - UpstreamSuggestLink, WithCsrf, WithLanguage, WithOptionalSession, WithSession, + UpstreamRegisterFormField, UpstreamSuggestLink, WithCsrf, WithLanguage, + WithOptionalSession, WithSession, }, forms::{FieldError, FormError, FormField, FormState, ToFormState}, }; diff --git a/templates/components/errors.html b/templates/components/errors.html index d60723fc..ae7384bc 100644 --- a/templates/components/errors.html +++ b/templates/components/errors.html @@ -19,6 +19,8 @@ limitations under the License. {{ _("mas.errors.invalid_credentials") }} {% elif error.kind == "password_mismatch" %} {{ _("mas.errors.password_mismatch") }} + {% elif error.kind == "policy" %} + {{ _("mas.errors.denied_policy", policy=error.message) }} {% else %} {{ error.kind }} {% endif %} diff --git a/templates/components/field.html b/templates/components/field.html index 1469ce70..61e8cc44 100644 --- a/templates/components/field.html +++ b/templates/components/field.html @@ -20,28 +20,13 @@ limitations under the License. form-{{- cnt.next() -}} {%- endmacro %} -{% macro attributes(field) -%} +{% macro attributes(field, default_value=None) -%} + {%- set value = field.value | default(default_value) -%} name="{{ field.name }}" id="{{ field.id }}" {%- if field.errors is not empty %} data-invalid{% endif %} - {%- if field.value %} value="{{ field.value }}" {% endif %} + {%- if value %} value="{{ value }}" {% endif %} {%- endmacro %} -{% macro input(label, name, type="text", form_state=false, autocomplete=false, class="", inputmode="text", autocorrect=false, autocapitalize=false, disabled=false, required=false, readonly=false) %} - {% call(field) field(label=label, name=name, form_state=form_state, class=class) %} - - {% endcall %} -{% endmacro %} - {% macro field(label, name, form_state=false, class="") %} {% set field_id = new_id() %} {% if not form_state %} @@ -66,7 +51,7 @@ limitations under the License. {% if field.errors is not empty %} {% for error in field.errors %} {% if error.kind != "unspecified" %} -
+
{% if error.kind == "required" %} {{ _("mas.errors.field_required") }} {% elif error.kind == "exists" and field.name == "username" %} diff --git a/templates/pages/upstream_oauth2/do_register.html b/templates/pages/upstream_oauth2/do_register.html index cd6b35ba..c15f3014 100644 --- a/templates/pages/upstream_oauth2/do_register.html +++ b/templates/pages/upstream_oauth2/do_register.html @@ -53,31 +53,42 @@ limitations under the License. + {% if form_state.errors is not empty %} + {% for error in form_state.errors %} +
+ {{- errors.form_error_message(error=error) -}} +
+ {% endfor %} + {% endif %} + + {% if force_localpart %} {% call(f) field.field(label=_("common.mxid"), name="mxid") %} - +
- {{ _("mas.upstream_oauth2.register.enforced_by_policy") }} + {{- _("mas.upstream_oauth2.register.enforced_by_policy") -}}
{% endcall %} {% else %} - {% call(f) field.field(label=_("common.username"), name="username") %} - + {% call(f) field.field(label=_("common.username"), name="username", form_state=form_state) %} + -
- @{{ suggested_localpart or (_("common.username") | lower) }}:{{ branding.server_name }} -
+ {% if f.errors is empty %} +
+ @{{ imported_localpart or (_("common.username") | lower) }}:{{ branding.server_name }} +
+ {% endif %} {% endcall %} {% endif %} - {% if suggested_email %} + {% if imported_email %}
{% call(f) field.field(label=_("common.email_address"), name="email", class="flex-1") %} - +
- {{ _("mas.upstream_oauth2.register.imported_from_upstream") }} + {{- _("mas.upstream_oauth2.register.imported_from_upstream") -}}
{% endcall %} @@ -99,13 +110,13 @@ limitations under the License.
{% endif %} - {% if suggested_display_name %} + {% if imported_display_name %}
{% call(f) field.field(label=_("common.display_name"), name="display_name", class="flex-1") %} - +
- {{ _("mas.upstream_oauth2.register.imported_from_upstream") }} + {{- _("mas.upstream_oauth2.register.imported_from_upstream") -}}
{% endcall %} diff --git a/translations/en.json b/translations/en.json index 9bac5162..a3dba79b 100644 --- a/translations/en.json +++ b/translations/en.json @@ -10,7 +10,7 @@ }, "create_account": "Create Account", "@create_account": { - "context": "pages/login.html:69:35-61, pages/upstream_oauth2/do_register.html:132:26-52" + "context": "pages/login.html:69:35-61, pages/upstream_oauth2/do_register.html:143:26-52" }, "sign_in": "Sign in", "@sign_in": { @@ -63,15 +63,15 @@ "common": { "display_name": "Display Name", "@display_name": { - "context": "pages/upstream_oauth2/do_register.html:104:37-61" + "context": "pages/upstream_oauth2/do_register.html:115:37-61" }, "email_address": "Email address", "@email_address": { - "context": "pages/account/emails/add.html:41:33-58, pages/register.html:47:35-60, pages/upstream_oauth2/do_register.html:76:37-62" + "context": "pages/account/emails/add.html:41:33-58, pages/register.html:47:35-60, pages/upstream_oauth2/do_register.html:87:37-62" }, "mxid": "Matrix ID", "@mxid": { - "context": "pages/upstream_oauth2/do_register.html:57:35-51" + "context": "pages/upstream_oauth2/do_register.html:66:35-51" }, "password": "Password", "@password": { @@ -83,7 +83,7 @@ }, "username": "Username", "@username": { - "context": "pages/login.html:51:37-57, pages/register.html:43:35-55, pages/upstream_oauth2/do_register.html:65:35-55, pages/upstream_oauth2/do_register.html:69:38-58" + "context": "pages/login.html:51:37-57, pages/register.html:43:35-55, pages/upstream_oauth2/do_register.html:74:35-55, pages/upstream_oauth2/do_register.html:79:39-59" } }, "error": { @@ -167,11 +167,11 @@ "errors": { "denied_policy": "Denied by policy: %(policy)s", "@denied_policy": { - "context": "components/field.html:75:17-68" + "context": "components/errors.html:23:7-58, components/field.html:60:17-68" }, "field_required": "This field is required", "@field_required": { - "context": "components/field.html:71:17-47" + "context": "components/field.html:56:17-47" }, "invalid_credentials": "Invalid credentials", "@invalid_credentials": { @@ -183,7 +183,7 @@ }, "username_taken": "This username is already taken", "@username_taken": { - "context": "components/field.html:73:17-47" + "context": "components/field.html:58:17-47" } }, "login": { @@ -251,7 +251,7 @@ }, "or_separator": "Or", "@or_separator": { - "context": "components/field.html:90:10-31", + "context": "components/field.html:75:10-31", "description": "Separator between the login methods" }, "policy_violation": { @@ -353,7 +353,7 @@ }, "enforced_by_policy": "Enforced by server policy", "@enforced_by_policy": { - "context": "pages/upstream_oauth2/do_register.html:61:13-65" + "context": "pages/upstream_oauth2/do_register.html:70:14-66" }, "forced_display_name": "Will use the following display name", "@forced_display_name": { @@ -379,7 +379,7 @@ }, "imported_from_upstream": "Imported from your upstream account", "@imported_from_upstream": { - "context": "pages/upstream_oauth2/do_register.html:108:15-71, pages/upstream_oauth2/do_register.html:80:15-71" + "context": "pages/upstream_oauth2/do_register.html:119:16-72, pages/upstream_oauth2/do_register.html:91:16-72" }, "link_existing": "Link to an existing account", "@link_existing": { @@ -395,7 +395,7 @@ }, "use": "Use", "@use": { - "context": "pages/upstream_oauth2/do_register.html:124:20-57, pages/upstream_oauth2/do_register.html:95:18-55" + "context": "pages/upstream_oauth2/do_register.html:106:18-55, pages/upstream_oauth2/do_register.html:135:20-57" } }, "suggest_link": {