You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-11-21 23:00:50 +03:00
Check for existing users ahead of time on upstream OAuth2 registration
This commit is contained in:
@@ -14,7 +14,7 @@
|
||||
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
response::{Html, IntoResponse},
|
||||
response::{Html, IntoResponse, Response},
|
||||
Form, TypedHeader,
|
||||
};
|
||||
use hyper::StatusCode;
|
||||
@@ -35,12 +35,13 @@ use mas_storage::{
|
||||
BoxClock, BoxRepository, BoxRng, RepositoryAccess,
|
||||
};
|
||||
use mas_templates::{
|
||||
ErrorContext, TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister,
|
||||
UpstreamSuggestLink,
|
||||
ErrorContext, FieldError, FormError, TemplateContext, Templates, ToFormState,
|
||||
UpstreamExistingLinkContext, UpstreamRegister, UpstreamSuggestLink,
|
||||
};
|
||||
use minijinja::Environment;
|
||||
use serde::Deserialize;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
use tracing::warn;
|
||||
use ulid::Ulid;
|
||||
|
||||
use super::UpstreamSessionsCookie;
|
||||
@@ -83,14 +84,6 @@ pub(crate) enum RouteError {
|
||||
#[error("Invalid form action")]
|
||||
InvalidFormAction,
|
||||
|
||||
#[error("Missing username")]
|
||||
MissingUsername,
|
||||
|
||||
#[error("Policy violation: {violations:?}")]
|
||||
PolicyViolation {
|
||||
violations: Vec<mas_policy::Violation>,
|
||||
},
|
||||
|
||||
#[error(transparent)]
|
||||
Internal(Box<dyn std::error::Error>),
|
||||
}
|
||||
@@ -107,16 +100,6 @@ impl IntoResponse for RouteError {
|
||||
let event_id = sentry::capture_error(&self);
|
||||
let response = match self {
|
||||
Self::LinkNotFound => (StatusCode::NOT_FOUND, "Link not found").into_response(),
|
||||
Self::PolicyViolation { violations } => {
|
||||
let details = violations.iter().map(|v| v.msg.clone()).collect::<Vec<_>>();
|
||||
let details = details.join("\n");
|
||||
let ctx = ErrorContext::new()
|
||||
.with_description(
|
||||
"Account registration denied because of policy violation".to_owned(),
|
||||
)
|
||||
.with_details(details);
|
||||
FancyError::new(ctx).into_response()
|
||||
}
|
||||
Self::Internal(e) => FancyError::from(e).into_response(),
|
||||
e => FancyError::from(e).into_response(),
|
||||
};
|
||||
@@ -171,7 +154,7 @@ fn import_claim(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[derive(Deserialize, Serialize)]
|
||||
#[serde(rename_all = "lowercase", tag = "action")]
|
||||
pub(crate) enum FormData {
|
||||
Register {
|
||||
@@ -185,6 +168,10 @@ pub(crate) enum FormData {
|
||||
Link,
|
||||
}
|
||||
|
||||
impl ToFormState for FormData {
|
||||
type Field = mas_templates::UpstreamRegisterFormField;
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "handlers.upstream_oauth2.link.get",
|
||||
fields(upstream_oauth_link.id = %link_id),
|
||||
@@ -195,6 +182,7 @@ pub(crate) async fn get(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
mut repo: BoxRepository,
|
||||
mut policy: Policy,
|
||||
PreferredLanguage(locale): PreferredLanguage,
|
||||
State(templates): State<Templates>,
|
||||
State(url_builder): State<UrlBuilder>,
|
||||
@@ -339,7 +327,7 @@ pub(crate) async fn get(
|
||||
.map(|id_token| id_token.into_parts().1)
|
||||
.unwrap_or_default();
|
||||
|
||||
let mut ctx = UpstreamRegister::new(&link);
|
||||
let mut ctx = UpstreamRegister::default();
|
||||
|
||||
let env = {
|
||||
let mut e = Environment::new();
|
||||
@@ -375,6 +363,7 @@ pub(crate) async fn get(
|
||||
},
|
||||
)?;
|
||||
|
||||
let mut forced_localpart = None;
|
||||
import_claim(
|
||||
&env,
|
||||
provider
|
||||
@@ -385,10 +374,59 @@ pub(crate) async fn get(
|
||||
.unwrap_or("{{ user.preferred_username }}"),
|
||||
&provider.claims_imports.localpart,
|
||||
|value, force| {
|
||||
if force {
|
||||
// We want to run the policy check on the username if it is forced
|
||||
forced_localpart = Some(value.clone());
|
||||
}
|
||||
|
||||
ctx.set_localpart(value, force);
|
||||
},
|
||||
)?;
|
||||
|
||||
// Run the policy check and check for existing users
|
||||
if let Some(localpart) = forced_localpart {
|
||||
let maybe_existing_user = repo.user().find_by_username(&localpart).await?;
|
||||
if let Some(existing_user) = maybe_existing_user {
|
||||
// The mapper returned a username which already exists, but isn't linked to
|
||||
// this upstream user.
|
||||
warn!(username = %localpart, user_id = %existing_user.id, "Localpart template returned an existing username");
|
||||
|
||||
// TODO: translate
|
||||
let ctx = ErrorContext::new()
|
||||
.with_code("User exists")
|
||||
.with_description(format!(
|
||||
r#"Upstream account provider returned {localpart:?} as username,
|
||||
which is not linked to that upstream account"#
|
||||
))
|
||||
.with_language(&locale);
|
||||
|
||||
return Ok((
|
||||
cookie_jar,
|
||||
Html(templates.render_error(&ctx)?).into_response(),
|
||||
));
|
||||
}
|
||||
|
||||
let res = policy
|
||||
.evaluate_upstream_oauth_register(&localpart, None)
|
||||
.await?;
|
||||
|
||||
if !res.valid() {
|
||||
// TODO: translate
|
||||
let ctx = ErrorContext::new()
|
||||
.with_code("Policy error")
|
||||
.with_description(format!(
|
||||
r#"Upstream account provider returned {localpart:?} as username,
|
||||
which does not pass the policy check: {res}"#
|
||||
))
|
||||
.with_language(&locale);
|
||||
|
||||
return Ok((
|
||||
cookie_jar,
|
||||
Html(templates.render_error(&ctx)?).into_response(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let ctx = ctx.with_csrf(csrf_token.form_value()).with_language(locale);
|
||||
|
||||
Html(templates.render_upstream_oauth2_do_register(&ctx)?).into_response()
|
||||
@@ -411,10 +449,12 @@ pub(crate) async fn post(
|
||||
cookie_jar: CookieJar,
|
||||
user_agent: Option<TypedHeader<headers::UserAgent>>,
|
||||
mut policy: Policy,
|
||||
PreferredLanguage(locale): PreferredLanguage,
|
||||
State(templates): State<Templates>,
|
||||
State(url_builder): State<UrlBuilder>,
|
||||
Path(link_id): Path<Ulid>,
|
||||
Form(form): Form<ProtectedForm<FormData>>,
|
||||
) -> Result<impl IntoResponse, RouteError> {
|
||||
) -> Result<Response, RouteError> {
|
||||
let user_agent = user_agent.map(|ua| ua.as_str().to_owned());
|
||||
let form = cookie_jar.verify_form(&clock, form)?;
|
||||
|
||||
@@ -449,8 +489,10 @@ pub(crate) async fn post(
|
||||
return Err(RouteError::SessionConsumed);
|
||||
}
|
||||
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
|
||||
let (user_session_info, cookie_jar) = cookie_jar.session_info();
|
||||
let maybe_user_session = user_session_info.load_session(&mut repo).await?;
|
||||
let form_state = form.to_form_state();
|
||||
|
||||
let session = match (maybe_user_session, link.user_id, form) {
|
||||
(Some(session), None, FormData::Link) => {
|
||||
@@ -495,13 +537,15 @@ pub(crate) async fn post(
|
||||
.unwrap_or(false);
|
||||
|
||||
// Let's try to import the claims from the ID token
|
||||
|
||||
let env = {
|
||||
let mut e = Environment::new();
|
||||
e.add_global("user", payload);
|
||||
e
|
||||
};
|
||||
|
||||
// Create a template context in case we need to re-render because of an error
|
||||
let mut ctx = UpstreamRegister::default();
|
||||
|
||||
let mut name = None;
|
||||
import_claim(
|
||||
&env,
|
||||
@@ -515,8 +559,10 @@ pub(crate) async fn post(
|
||||
|value, force| {
|
||||
// Import the display name if it is either forced or the user has requested it
|
||||
if force || import_display_name {
|
||||
name = Some(value);
|
||||
name = Some(value.clone());
|
||||
}
|
||||
|
||||
ctx.set_display_name(value, force);
|
||||
},
|
||||
)?;
|
||||
|
||||
@@ -533,8 +579,10 @@ pub(crate) async fn post(
|
||||
|value, force| {
|
||||
// Import the email if it is either forced or the user has requested it
|
||||
if force || import_email {
|
||||
email = Some(value);
|
||||
email = Some(value.clone());
|
||||
}
|
||||
|
||||
ctx.set_email(value, force);
|
||||
},
|
||||
)?;
|
||||
|
||||
@@ -551,21 +599,85 @@ pub(crate) async fn post(
|
||||
|value, force| {
|
||||
// If the username is forced, override whatever was in the form
|
||||
if force {
|
||||
username = Some(value);
|
||||
username = Some(value.clone());
|
||||
}
|
||||
|
||||
ctx.set_localpart(value, force);
|
||||
},
|
||||
)?;
|
||||
|
||||
let username = username.ok_or(RouteError::MissingUsername)?;
|
||||
let username = username.filter(|s| !s.is_empty());
|
||||
|
||||
let Some(username) = username else {
|
||||
let form_state = form_state.with_error_on_field(
|
||||
mas_templates::UpstreamRegisterFormField::Username,
|
||||
FieldError::Required,
|
||||
);
|
||||
|
||||
let ctx = ctx
|
||||
.with_form_state(form_state)
|
||||
.with_csrf(csrf_token.form_value())
|
||||
.with_language(locale);
|
||||
return Ok((
|
||||
cookie_jar,
|
||||
Html(templates.render_upstream_oauth2_do_register(&ctx)?),
|
||||
)
|
||||
.into_response());
|
||||
};
|
||||
|
||||
// Check if there is an existing user
|
||||
let existing_user = repo.user().find_by_username(&username).await?;
|
||||
if let Some(_existing_user) = existing_user {
|
||||
// If there is an existing user, we can't create a new one
|
||||
// with the same username
|
||||
|
||||
let form_state = form_state.with_error_on_field(
|
||||
mas_templates::UpstreamRegisterFormField::Username,
|
||||
FieldError::Exists,
|
||||
);
|
||||
|
||||
let ctx = ctx
|
||||
.with_form_state(form_state)
|
||||
.with_csrf(csrf_token.form_value())
|
||||
.with_language(locale);
|
||||
return Ok((
|
||||
cookie_jar,
|
||||
Html(templates.render_upstream_oauth2_do_register(&ctx)?),
|
||||
)
|
||||
.into_response());
|
||||
}
|
||||
|
||||
// Policy check
|
||||
let res = policy
|
||||
.evaluate_upstream_oauth_register(&username, email.as_deref())
|
||||
.await?;
|
||||
if !res.valid() {
|
||||
return Err(RouteError::PolicyViolation {
|
||||
violations: res.violations,
|
||||
});
|
||||
let form_state =
|
||||
res.violations
|
||||
.into_iter()
|
||||
.fold(form_state, |form_state, violation| {
|
||||
match violation.field.as_deref() {
|
||||
Some("username") => form_state.with_error_on_field(
|
||||
mas_templates::UpstreamRegisterFormField::Username,
|
||||
FieldError::Policy {
|
||||
message: violation.msg,
|
||||
},
|
||||
),
|
||||
_ => form_state.with_error_on_form(FormError::Policy {
|
||||
message: violation.msg,
|
||||
}),
|
||||
}
|
||||
});
|
||||
|
||||
let ctx = ctx
|
||||
.with_form_state(form_state)
|
||||
.with_csrf(csrf_token.form_value())
|
||||
.with_language(locale);
|
||||
return Ok((
|
||||
cookie_jar,
|
||||
Html(templates.render_upstream_oauth2_do_register(&ctx)?),
|
||||
)
|
||||
.into_response());
|
||||
}
|
||||
|
||||
// Now we can create the user
|
||||
@@ -631,5 +743,5 @@ pub(crate) async fn post(
|
||||
|
||||
repo.save().await?;
|
||||
|
||||
Ok((cookie_jar, post_auth_action.go_next(&url_builder)))
|
||||
Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user