1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Check for existing users ahead of time on upstream OAuth2 registration

This commit is contained in:
Quentin Gliech
2023-11-09 18:18:28 +01:00
parent 8a1329de05
commit 9c94e11e68
8 changed files with 247 additions and 109 deletions

View File

@ -14,7 +14,7 @@
use axum::{
extract::{Path, State},
response::{Html, IntoResponse},
response::{Html, IntoResponse, Response},
Form, TypedHeader,
};
use hyper::StatusCode;
@ -35,12 +35,13 @@ use mas_storage::{
BoxClock, BoxRepository, BoxRng, RepositoryAccess,
};
use mas_templates::{
ErrorContext, TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister,
UpstreamSuggestLink,
ErrorContext, FieldError, FormError, TemplateContext, Templates, ToFormState,
UpstreamExistingLinkContext, UpstreamRegister, UpstreamSuggestLink,
};
use minijinja::Environment;
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tracing::warn;
use ulid::Ulid;
use super::UpstreamSessionsCookie;
@ -83,14 +84,6 @@ pub(crate) enum RouteError {
#[error("Invalid form action")]
InvalidFormAction,
#[error("Missing username")]
MissingUsername,
#[error("Policy violation: {violations:?}")]
PolicyViolation {
violations: Vec<mas_policy::Violation>,
},
#[error(transparent)]
Internal(Box<dyn std::error::Error>),
}
@ -107,16 +100,6 @@ impl IntoResponse for RouteError {
let event_id = sentry::capture_error(&self);
let response = match self {
Self::LinkNotFound => (StatusCode::NOT_FOUND, "Link not found").into_response(),
Self::PolicyViolation { violations } => {
let details = violations.iter().map(|v| v.msg.clone()).collect::<Vec<_>>();
let details = details.join("\n");
let ctx = ErrorContext::new()
.with_description(
"Account registration denied because of policy violation".to_owned(),
)
.with_details(details);
FancyError::new(ctx).into_response()
}
Self::Internal(e) => FancyError::from(e).into_response(),
e => FancyError::from(e).into_response(),
};
@ -171,7 +154,7 @@ fn import_claim(
Ok(())
}
#[derive(Deserialize)]
#[derive(Deserialize, Serialize)]
#[serde(rename_all = "lowercase", tag = "action")]
pub(crate) enum FormData {
Register {
@ -185,6 +168,10 @@ pub(crate) enum FormData {
Link,
}
impl ToFormState for FormData {
type Field = mas_templates::UpstreamRegisterFormField;
}
#[tracing::instrument(
name = "handlers.upstream_oauth2.link.get",
fields(upstream_oauth_link.id = %link_id),
@ -195,6 +182,7 @@ pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
mut repo: BoxRepository,
mut policy: Policy,
PreferredLanguage(locale): PreferredLanguage,
State(templates): State<Templates>,
State(url_builder): State<UrlBuilder>,
@ -339,7 +327,7 @@ pub(crate) async fn get(
.map(|id_token| id_token.into_parts().1)
.unwrap_or_default();
let mut ctx = UpstreamRegister::new(&link);
let mut ctx = UpstreamRegister::default();
let env = {
let mut e = Environment::new();
@ -375,6 +363,7 @@ pub(crate) async fn get(
},
)?;
let mut forced_localpart = None;
import_claim(
&env,
provider
@ -385,10 +374,59 @@ pub(crate) async fn get(
.unwrap_or("{{ user.preferred_username }}"),
&provider.claims_imports.localpart,
|value, force| {
if force {
// We want to run the policy check on the username if it is forced
forced_localpart = Some(value.clone());
}
ctx.set_localpart(value, force);
},
)?;
// Run the policy check and check for existing users
if let Some(localpart) = forced_localpart {
let maybe_existing_user = repo.user().find_by_username(&localpart).await?;
if let Some(existing_user) = maybe_existing_user {
// The mapper returned a username which already exists, but isn't linked to
// this upstream user.
warn!(username = %localpart, user_id = %existing_user.id, "Localpart template returned an existing username");
// TODO: translate
let ctx = ErrorContext::new()
.with_code("User exists")
.with_description(format!(
r#"Upstream account provider returned {localpart:?} as username,
which is not linked to that upstream account"#
))
.with_language(&locale);
return Ok((
cookie_jar,
Html(templates.render_error(&ctx)?).into_response(),
));
}
let res = policy
.evaluate_upstream_oauth_register(&localpart, None)
.await?;
if !res.valid() {
// TODO: translate
let ctx = ErrorContext::new()
.with_code("Policy error")
.with_description(format!(
r#"Upstream account provider returned {localpart:?} as username,
which does not pass the policy check: {res}"#
))
.with_language(&locale);
return Ok((
cookie_jar,
Html(templates.render_error(&ctx)?).into_response(),
));
}
}
let ctx = ctx.with_csrf(csrf_token.form_value()).with_language(locale);
Html(templates.render_upstream_oauth2_do_register(&ctx)?).into_response()
@ -411,10 +449,12 @@ pub(crate) async fn post(
cookie_jar: CookieJar,
user_agent: Option<TypedHeader<headers::UserAgent>>,
mut policy: Policy,
PreferredLanguage(locale): PreferredLanguage,
State(templates): State<Templates>,
State(url_builder): State<UrlBuilder>,
Path(link_id): Path<Ulid>,
Form(form): Form<ProtectedForm<FormData>>,
) -> Result<impl IntoResponse, RouteError> {
) -> Result<Response, RouteError> {
let user_agent = user_agent.map(|ua| ua.as_str().to_owned());
let form = cookie_jar.verify_form(&clock, form)?;
@ -449,8 +489,10 @@ pub(crate) async fn post(
return Err(RouteError::SessionConsumed);
}
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (user_session_info, cookie_jar) = cookie_jar.session_info();
let maybe_user_session = user_session_info.load_session(&mut repo).await?;
let form_state = form.to_form_state();
let session = match (maybe_user_session, link.user_id, form) {
(Some(session), None, FormData::Link) => {
@ -495,13 +537,15 @@ pub(crate) async fn post(
.unwrap_or(false);
// Let's try to import the claims from the ID token
let env = {
let mut e = Environment::new();
e.add_global("user", payload);
e
};
// Create a template context in case we need to re-render because of an error
let mut ctx = UpstreamRegister::default();
let mut name = None;
import_claim(
&env,
@ -515,8 +559,10 @@ pub(crate) async fn post(
|value, force| {
// Import the display name if it is either forced or the user has requested it
if force || import_display_name {
name = Some(value);
name = Some(value.clone());
}
ctx.set_display_name(value, force);
},
)?;
@ -533,8 +579,10 @@ pub(crate) async fn post(
|value, force| {
// Import the email if it is either forced or the user has requested it
if force || import_email {
email = Some(value);
email = Some(value.clone());
}
ctx.set_email(value, force);
},
)?;
@ -551,21 +599,85 @@ pub(crate) async fn post(
|value, force| {
// If the username is forced, override whatever was in the form
if force {
username = Some(value);
username = Some(value.clone());
}
ctx.set_localpart(value, force);
},
)?;
let username = username.ok_or(RouteError::MissingUsername)?;
let username = username.filter(|s| !s.is_empty());
let Some(username) = username else {
let form_state = form_state.with_error_on_field(
mas_templates::UpstreamRegisterFormField::Username,
FieldError::Required,
);
let ctx = ctx
.with_form_state(form_state)
.with_csrf(csrf_token.form_value())
.with_language(locale);
return Ok((
cookie_jar,
Html(templates.render_upstream_oauth2_do_register(&ctx)?),
)
.into_response());
};
// Check if there is an existing user
let existing_user = repo.user().find_by_username(&username).await?;
if let Some(_existing_user) = existing_user {
// If there is an existing user, we can't create a new one
// with the same username
let form_state = form_state.with_error_on_field(
mas_templates::UpstreamRegisterFormField::Username,
FieldError::Exists,
);
let ctx = ctx
.with_form_state(form_state)
.with_csrf(csrf_token.form_value())
.with_language(locale);
return Ok((
cookie_jar,
Html(templates.render_upstream_oauth2_do_register(&ctx)?),
)
.into_response());
}
// Policy check
let res = policy
.evaluate_upstream_oauth_register(&username, email.as_deref())
.await?;
if !res.valid() {
return Err(RouteError::PolicyViolation {
violations: res.violations,
});
let form_state =
res.violations
.into_iter()
.fold(form_state, |form_state, violation| {
match violation.field.as_deref() {
Some("username") => form_state.with_error_on_field(
mas_templates::UpstreamRegisterFormField::Username,
FieldError::Policy {
message: violation.msg,
},
),
_ => form_state.with_error_on_form(FormError::Policy {
message: violation.msg,
}),
}
});
let ctx = ctx
.with_form_state(form_state)
.with_csrf(csrf_token.form_value())
.with_language(locale);
return Ok((
cookie_jar,
Html(templates.render_upstream_oauth2_do_register(&ctx)?),
)
.into_response());
}
// Now we can create the user
@ -631,5 +743,5 @@ pub(crate) async fn post(
repo.save().await?;
Ok((cookie_jar, post_auth_action.go_next(&url_builder)))
Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response())
}

View File

@ -25,7 +25,7 @@ use mas_data_model::{
UpstreamOAuthLink, UpstreamOAuthProvider, User, UserEmail, UserEmailVerification,
};
use mas_i18n::DataLocale;
use mas_router::{Account, GraphQL, PostAuthAction, Route, UrlBuilder};
use mas_router::{Account, GraphQL, PostAuthAction, UrlBuilder};
use rand::Rng;
use serde::{ser::SerializeStruct, Deserialize, Serialize};
use ulid::Ulid;
@ -918,68 +918,78 @@ impl TemplateContext for UpstreamSuggestLink {
}
}
/// User-editeable fields of the upstream account link form
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Hash, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum UpstreamRegisterFormField {
/// The username field
Username,
}
impl FormField for UpstreamRegisterFormField {
fn keep(&self) -> bool {
match self {
Self::Username => true,
}
}
}
/// Context used by the `pages/upstream_oauth2/do_register.html`
/// templates
#[derive(Serialize)]
#[derive(Serialize, Default)]
pub struct UpstreamRegister {
login_link: String,
suggested_localpart: Option<String>,
imported_localpart: Option<String>,
force_localpart: bool,
suggested_display_name: Option<String>,
imported_display_name: Option<String>,
force_display_name: bool,
suggested_email: Option<String>,
imported_email: Option<String>,
force_email: bool,
form_state: FormState<UpstreamRegisterFormField>,
}
impl UpstreamRegister {
/// Constructs a new context with an existing linked user
#[must_use]
pub fn new(link: &UpstreamOAuthLink) -> Self {
Self::for_link_id(link.id)
pub fn new() -> Self {
Self::default()
}
/// Set the suggested localpart
pub fn set_localpart(&mut self, localpart: String, force: bool) {
self.suggested_localpart = Some(localpart);
self.imported_localpart = Some(localpart);
self.force_localpart = force;
}
/// Set the suggested display name
pub fn set_display_name(&mut self, display_name: String, force: bool) {
self.suggested_display_name = Some(display_name);
self.imported_display_name = Some(display_name);
self.force_display_name = force;
}
/// Set the suggested email
pub fn set_email(&mut self, email: String, force: bool) {
self.suggested_email = Some(email);
self.imported_email = Some(email);
self.force_email = force;
}
fn for_link_id(id: Ulid) -> Self {
let login_link = mas_router::Login::and_link_upstream(id)
.path_and_query()
.into();
/// Set the form state
pub fn set_form_state(&mut self, form_state: FormState<UpstreamRegisterFormField>) {
self.form_state = form_state;
}
Self {
login_link,
suggested_localpart: None,
force_localpart: false,
suggested_display_name: None,
force_display_name: false,
suggested_email: None,
force_email: false,
}
/// Set the form state
#[must_use]
pub fn with_form_state(self, form_state: FormState<UpstreamRegisterFormField>) -> Self {
Self { form_state, ..self }
}
}
impl TemplateContext for UpstreamRegister {
fn sample(now: chrono::DateTime<Utc>, rng: &mut impl Rng) -> Vec<Self>
fn sample(_now: chrono::DateTime<Utc>, _rng: &mut impl Rng) -> Vec<Self>
where
Self: Sized,
{
let id = Ulid::from_datetime_with_source(now.into(), rng);
vec![Self::for_link_id(id)]
vec![Self::new()]
}
}

View File

@ -92,6 +92,22 @@ impl<K: Hash + Eq> Default for FormState<K> {
}
}
#[derive(Deserialize, PartialEq, Eq, Hash)]
#[serde(untagged)]
enum KeyOrOther<K> {
Key(K),
Other(String),
}
impl<K> KeyOrOther<K> {
fn key(self) -> Option<K> {
match self {
Self::Key(key) => Some(key),
Self::Other(_) => None,
}
}
}
impl<K: FormField> FormState<K> {
/// Generate a [`FormState`] out of a form
///
@ -101,17 +117,18 @@ impl<K: FormField> FormState<K> {
/// deserialize
pub fn from_form<F: Serialize>(form: &F) -> Self {
let form = serde_json::to_value(form).unwrap();
let fields: HashMap<K, Option<String>> = serde_json::from_value(form).unwrap();
let fields: HashMap<KeyOrOther<K>, Option<String>> = serde_json::from_value(form).unwrap();
let fields = fields
.into_iter()
.map(|(key, value)| {
.filter_map(|(key, value)| {
let key = key.key()?;
let value = key.keep().then_some(value).flatten();
let field = FieldState {
value,
errors: Vec::new(),
};
(key, field)
Some((key, field))
})
.collect();

View File

@ -54,7 +54,8 @@ pub use self::{
LoginContext, LoginFormField, NotFoundContext, PolicyViolationContext, PostAuthContext,
PostAuthContextInner, ReauthContext, ReauthFormField, RegisterContext, RegisterFormField,
SiteBranding, TemplateContext, UpstreamExistingLinkContext, UpstreamRegister,
UpstreamSuggestLink, WithCsrf, WithLanguage, WithOptionalSession, WithSession,
UpstreamRegisterFormField, UpstreamSuggestLink, WithCsrf, WithLanguage,
WithOptionalSession, WithSession,
},
forms::{FieldError, FormError, FormField, FormState, ToFormState},
};

View File

@ -19,6 +19,8 @@ limitations under the License.
{{ _("mas.errors.invalid_credentials") }}
{% elif error.kind == "password_mismatch" %}
{{ _("mas.errors.password_mismatch") }}
{% elif error.kind == "policy" %}
{{ _("mas.errors.denied_policy", policy=error.message) }}
{% else %}
{{ error.kind }}
{% endif %}

View File

@ -20,28 +20,13 @@ limitations under the License.
form-{{- cnt.next() -}}
{%- endmacro %}
{% macro attributes(field) -%}
{% macro attributes(field, default_value=None) -%}
{%- set value = field.value | default(default_value) -%}
name="{{ field.name }}" id="{{ field.id }}"
{%- if field.errors is not empty %} data-invalid{% endif %}
{%- if field.value %} value="{{ field.value }}" {% endif %}
{%- if value %} value="{{ value }}" {% endif %}
{%- endmacro %}
{% macro input(label, name, type="text", form_state=false, autocomplete=false, class="", inputmode="text", autocorrect=false, autocapitalize=false, disabled=false, required=false, readonly=false) %}
{% call(field) field(label=label, name=name, form_state=form_state, class=class) %}
<input {{ attributes(field) }}
class="cpd-text-control"
type="{{ type }}"
inputmode="{{ inputmode }}"
{% if required %} required {% endif %}
{% if disabled %} disabled {% endif %}
{% if autocomplete %} autocomplete="{{ autocomplete }}" {% endif %}
{% if autocorrect %} autocorrect="{{ autocorrect }}" {% endif %}
{% if autocapitalize %} autocapitalize="{{ autocapitalize }}" {% endif %}
{% if readonly %} readonly {% endif %}
/>
{% endcall %}
{% endmacro %}
{% macro field(label, name, form_state=false, class="") %}
{% set field_id = new_id() %}
{% if not form_state %}
@ -66,7 +51,7 @@ limitations under the License.
{% if field.errors is not empty %}
{% for error in field.errors %}
{% if error.kind != "unspecified" %}
<div class="cpd-form-error-message">
<div class="cpd-form-message cpd-form-error-message">
{% if error.kind == "required" %}
{{ _("mas.errors.field_required") }}
{% elif error.kind == "exists" and field.name == "username" %}

View File

@ -53,31 +53,42 @@ limitations under the License.
<input type="hidden" name="csrf" value="{{ csrf_token }}" />
<input type="hidden" name="action" value="register" />
{% if form_state.errors is not empty %}
{% for error in form_state.errors %}
<div class="text-critical font-medium">
{{- errors.form_error_message(error=error) -}}
</div>
{% endfor %}
{% endif %}
{% if force_localpart %}
{% call(f) field.field(label=_("common.mxid"), name="mxid") %}
<input {{ field.attributes(f) }} class="cpd-text-control" type="text" value="@{{ suggested_localpart }}:{{ branding.server_name }}" readonly aria-describedby="{{ f.id }}-help" />
<input {{ field.attributes(f) }} class="cpd-text-control" type="text" value="@{{ imported_localpart }}:{{ branding.server_name }}" readonly aria-describedby="{{ f.id }}-help" />
<div class="cpd-form-message cpd-form-help-message" id="{{ f.id }}-help">
{{ _("mas.upstream_oauth2.register.enforced_by_policy") }}
{{- _("mas.upstream_oauth2.register.enforced_by_policy") -}}
</div>
{% endcall %}
{% else %}
{% call(f) field.field(label=_("common.username"), name="username") %}
<input {{ field.attributes(f) }} class="cpd-text-control" type="text" autocomplete="username" autocorrect="off" autocapitalize="none" value="{{ suggested_localpart or '' }}" aria-describedby="{{ f.id }}-help" />
{% call(f) field.field(label=_("common.username"), name="username", form_state=form_state) %}
<input {{ field.attributes(f) }} class="cpd-text-control" type="text" autocomplete="username" autocorrect="off" autocapitalize="none" value="{{ imported_localpart or '' }}" aria-describedby="{{ f.id }}-help" />
<div class="cpd-form-message cpd-form-help-message" id="{{ f.id }}-help">
@{{ suggested_localpart or (_("common.username") | lower) }}:{{ branding.server_name }}
</div>
{% if f.errors is empty %}
<div class="cpd-form-message cpd-form-help-message" id="{{ f.id }}-help">
@{{ imported_localpart or (_("common.username") | lower) }}:{{ branding.server_name }}
</div>
{% endif %}
{% endcall %}
{% endif %}
{% if suggested_email %}
{% if imported_email %}
<div class="flex gap-6 items-center">
{% call(f) field.field(label=_("common.email_address"), name="email", class="flex-1") %}
<input {{ field.attributes(f) }} class="cpd-text-control" type="email" value="{{ suggested_email }}" readonly aria-describedby="{{ f.id }}-help" />
<input {{ field.attributes(f) }} class="cpd-text-control" type="email" value="{{ imported_email }}" readonly aria-describedby="{{ f.id }}-help" />
<div class="cpd-form-message cpd-form-help-message" id="{{ f.id }}-help">
{{ _("mas.upstream_oauth2.register.imported_from_upstream") }}
{{- _("mas.upstream_oauth2.register.imported_from_upstream") -}}
</div>
{% endcall %}
@ -99,13 +110,13 @@ limitations under the License.
</div>
{% endif %}
{% if suggested_display_name %}
{% if imported_display_name %}
<div class="flex gap-6 items-center">
{% call(f) field.field(label=_("common.display_name"), name="display_name", class="flex-1") %}
<input {{ field.attributes(f) }} class="cpd-text-control" type="text" value="{{ suggested_display_name }}" readonly />
<input {{ field.attributes(f) }} class="cpd-text-control" type="text" value="{{ imported_display_name }}" readonly />
<div class="cpd-form-message cpd-form-help-message">
{{ _("mas.upstream_oauth2.register.imported_from_upstream") }}
{{- _("mas.upstream_oauth2.register.imported_from_upstream") -}}
</div>
{% endcall %}

View File

@ -10,7 +10,7 @@
},
"create_account": "Create Account",
"@create_account": {
"context": "pages/login.html:69:35-61, pages/upstream_oauth2/do_register.html:132:26-52"
"context": "pages/login.html:69:35-61, pages/upstream_oauth2/do_register.html:143:26-52"
},
"sign_in": "Sign in",
"@sign_in": {
@ -63,15 +63,15 @@
"common": {
"display_name": "Display Name",
"@display_name": {
"context": "pages/upstream_oauth2/do_register.html:104:37-61"
"context": "pages/upstream_oauth2/do_register.html:115:37-61"
},
"email_address": "Email address",
"@email_address": {
"context": "pages/account/emails/add.html:41:33-58, pages/register.html:47:35-60, pages/upstream_oauth2/do_register.html:76:37-62"
"context": "pages/account/emails/add.html:41:33-58, pages/register.html:47:35-60, pages/upstream_oauth2/do_register.html:87:37-62"
},
"mxid": "Matrix ID",
"@mxid": {
"context": "pages/upstream_oauth2/do_register.html:57:35-51"
"context": "pages/upstream_oauth2/do_register.html:66:35-51"
},
"password": "Password",
"@password": {
@ -83,7 +83,7 @@
},
"username": "Username",
"@username": {
"context": "pages/login.html:51:37-57, pages/register.html:43:35-55, pages/upstream_oauth2/do_register.html:65:35-55, pages/upstream_oauth2/do_register.html:69:38-58"
"context": "pages/login.html:51:37-57, pages/register.html:43:35-55, pages/upstream_oauth2/do_register.html:74:35-55, pages/upstream_oauth2/do_register.html:79:39-59"
}
},
"error": {
@ -167,11 +167,11 @@
"errors": {
"denied_policy": "Denied by policy: %(policy)s",
"@denied_policy": {
"context": "components/field.html:75:17-68"
"context": "components/errors.html:23:7-58, components/field.html:60:17-68"
},
"field_required": "This field is required",
"@field_required": {
"context": "components/field.html:71:17-47"
"context": "components/field.html:56:17-47"
},
"invalid_credentials": "Invalid credentials",
"@invalid_credentials": {
@ -183,7 +183,7 @@
},
"username_taken": "This username is already taken",
"@username_taken": {
"context": "components/field.html:73:17-47"
"context": "components/field.html:58:17-47"
}
},
"login": {
@ -251,7 +251,7 @@
},
"or_separator": "Or",
"@or_separator": {
"context": "components/field.html:90:10-31",
"context": "components/field.html:75:10-31",
"description": "Separator between the login methods"
},
"policy_violation": {
@ -353,7 +353,7 @@
},
"enforced_by_policy": "Enforced by server policy",
"@enforced_by_policy": {
"context": "pages/upstream_oauth2/do_register.html:61:13-65"
"context": "pages/upstream_oauth2/do_register.html:70:14-66"
},
"forced_display_name": "Will use the following display name",
"@forced_display_name": {
@ -379,7 +379,7 @@
},
"imported_from_upstream": "Imported from your upstream account",
"@imported_from_upstream": {
"context": "pages/upstream_oauth2/do_register.html:108:15-71, pages/upstream_oauth2/do_register.html:80:15-71"
"context": "pages/upstream_oauth2/do_register.html:119:16-72, pages/upstream_oauth2/do_register.html:91:16-72"
},
"link_existing": "Link to an existing account",
"@link_existing": {
@ -395,7 +395,7 @@
},
"use": "Use",
"@use": {
"context": "pages/upstream_oauth2/do_register.html:124:20-57, pages/upstream_oauth2/do_register.html:95:18-55"
"context": "pages/upstream_oauth2/do_register.html:106:18-55, pages/upstream_oauth2/do_register.html:135:20-57"
}
},
"suggest_link": {