You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
Check for existing users ahead of time on upstream OAuth2 registration
This commit is contained in:
@ -14,7 +14,7 @@
|
||||
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
response::{Html, IntoResponse},
|
||||
response::{Html, IntoResponse, Response},
|
||||
Form, TypedHeader,
|
||||
};
|
||||
use hyper::StatusCode;
|
||||
@ -35,12 +35,13 @@ use mas_storage::{
|
||||
BoxClock, BoxRepository, BoxRng, RepositoryAccess,
|
||||
};
|
||||
use mas_templates::{
|
||||
ErrorContext, TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister,
|
||||
UpstreamSuggestLink,
|
||||
ErrorContext, FieldError, FormError, TemplateContext, Templates, ToFormState,
|
||||
UpstreamExistingLinkContext, UpstreamRegister, UpstreamSuggestLink,
|
||||
};
|
||||
use minijinja::Environment;
|
||||
use serde::Deserialize;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
use tracing::warn;
|
||||
use ulid::Ulid;
|
||||
|
||||
use super::UpstreamSessionsCookie;
|
||||
@ -83,14 +84,6 @@ pub(crate) enum RouteError {
|
||||
#[error("Invalid form action")]
|
||||
InvalidFormAction,
|
||||
|
||||
#[error("Missing username")]
|
||||
MissingUsername,
|
||||
|
||||
#[error("Policy violation: {violations:?}")]
|
||||
PolicyViolation {
|
||||
violations: Vec<mas_policy::Violation>,
|
||||
},
|
||||
|
||||
#[error(transparent)]
|
||||
Internal(Box<dyn std::error::Error>),
|
||||
}
|
||||
@ -107,16 +100,6 @@ impl IntoResponse for RouteError {
|
||||
let event_id = sentry::capture_error(&self);
|
||||
let response = match self {
|
||||
Self::LinkNotFound => (StatusCode::NOT_FOUND, "Link not found").into_response(),
|
||||
Self::PolicyViolation { violations } => {
|
||||
let details = violations.iter().map(|v| v.msg.clone()).collect::<Vec<_>>();
|
||||
let details = details.join("\n");
|
||||
let ctx = ErrorContext::new()
|
||||
.with_description(
|
||||
"Account registration denied because of policy violation".to_owned(),
|
||||
)
|
||||
.with_details(details);
|
||||
FancyError::new(ctx).into_response()
|
||||
}
|
||||
Self::Internal(e) => FancyError::from(e).into_response(),
|
||||
e => FancyError::from(e).into_response(),
|
||||
};
|
||||
@ -171,7 +154,7 @@ fn import_claim(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[derive(Deserialize, Serialize)]
|
||||
#[serde(rename_all = "lowercase", tag = "action")]
|
||||
pub(crate) enum FormData {
|
||||
Register {
|
||||
@ -185,6 +168,10 @@ pub(crate) enum FormData {
|
||||
Link,
|
||||
}
|
||||
|
||||
impl ToFormState for FormData {
|
||||
type Field = mas_templates::UpstreamRegisterFormField;
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "handlers.upstream_oauth2.link.get",
|
||||
fields(upstream_oauth_link.id = %link_id),
|
||||
@ -195,6 +182,7 @@ pub(crate) async fn get(
|
||||
mut rng: BoxRng,
|
||||
clock: BoxClock,
|
||||
mut repo: BoxRepository,
|
||||
mut policy: Policy,
|
||||
PreferredLanguage(locale): PreferredLanguage,
|
||||
State(templates): State<Templates>,
|
||||
State(url_builder): State<UrlBuilder>,
|
||||
@ -339,7 +327,7 @@ pub(crate) async fn get(
|
||||
.map(|id_token| id_token.into_parts().1)
|
||||
.unwrap_or_default();
|
||||
|
||||
let mut ctx = UpstreamRegister::new(&link);
|
||||
let mut ctx = UpstreamRegister::default();
|
||||
|
||||
let env = {
|
||||
let mut e = Environment::new();
|
||||
@ -375,6 +363,7 @@ pub(crate) async fn get(
|
||||
},
|
||||
)?;
|
||||
|
||||
let mut forced_localpart = None;
|
||||
import_claim(
|
||||
&env,
|
||||
provider
|
||||
@ -385,10 +374,59 @@ pub(crate) async fn get(
|
||||
.unwrap_or("{{ user.preferred_username }}"),
|
||||
&provider.claims_imports.localpart,
|
||||
|value, force| {
|
||||
if force {
|
||||
// We want to run the policy check on the username if it is forced
|
||||
forced_localpart = Some(value.clone());
|
||||
}
|
||||
|
||||
ctx.set_localpart(value, force);
|
||||
},
|
||||
)?;
|
||||
|
||||
// Run the policy check and check for existing users
|
||||
if let Some(localpart) = forced_localpart {
|
||||
let maybe_existing_user = repo.user().find_by_username(&localpart).await?;
|
||||
if let Some(existing_user) = maybe_existing_user {
|
||||
// The mapper returned a username which already exists, but isn't linked to
|
||||
// this upstream user.
|
||||
warn!(username = %localpart, user_id = %existing_user.id, "Localpart template returned an existing username");
|
||||
|
||||
// TODO: translate
|
||||
let ctx = ErrorContext::new()
|
||||
.with_code("User exists")
|
||||
.with_description(format!(
|
||||
r#"Upstream account provider returned {localpart:?} as username,
|
||||
which is not linked to that upstream account"#
|
||||
))
|
||||
.with_language(&locale);
|
||||
|
||||
return Ok((
|
||||
cookie_jar,
|
||||
Html(templates.render_error(&ctx)?).into_response(),
|
||||
));
|
||||
}
|
||||
|
||||
let res = policy
|
||||
.evaluate_upstream_oauth_register(&localpart, None)
|
||||
.await?;
|
||||
|
||||
if !res.valid() {
|
||||
// TODO: translate
|
||||
let ctx = ErrorContext::new()
|
||||
.with_code("Policy error")
|
||||
.with_description(format!(
|
||||
r#"Upstream account provider returned {localpart:?} as username,
|
||||
which does not pass the policy check: {res}"#
|
||||
))
|
||||
.with_language(&locale);
|
||||
|
||||
return Ok((
|
||||
cookie_jar,
|
||||
Html(templates.render_error(&ctx)?).into_response(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let ctx = ctx.with_csrf(csrf_token.form_value()).with_language(locale);
|
||||
|
||||
Html(templates.render_upstream_oauth2_do_register(&ctx)?).into_response()
|
||||
@ -411,10 +449,12 @@ pub(crate) async fn post(
|
||||
cookie_jar: CookieJar,
|
||||
user_agent: Option<TypedHeader<headers::UserAgent>>,
|
||||
mut policy: Policy,
|
||||
PreferredLanguage(locale): PreferredLanguage,
|
||||
State(templates): State<Templates>,
|
||||
State(url_builder): State<UrlBuilder>,
|
||||
Path(link_id): Path<Ulid>,
|
||||
Form(form): Form<ProtectedForm<FormData>>,
|
||||
) -> Result<impl IntoResponse, RouteError> {
|
||||
) -> Result<Response, RouteError> {
|
||||
let user_agent = user_agent.map(|ua| ua.as_str().to_owned());
|
||||
let form = cookie_jar.verify_form(&clock, form)?;
|
||||
|
||||
@ -449,8 +489,10 @@ pub(crate) async fn post(
|
||||
return Err(RouteError::SessionConsumed);
|
||||
}
|
||||
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
|
||||
let (user_session_info, cookie_jar) = cookie_jar.session_info();
|
||||
let maybe_user_session = user_session_info.load_session(&mut repo).await?;
|
||||
let form_state = form.to_form_state();
|
||||
|
||||
let session = match (maybe_user_session, link.user_id, form) {
|
||||
(Some(session), None, FormData::Link) => {
|
||||
@ -495,13 +537,15 @@ pub(crate) async fn post(
|
||||
.unwrap_or(false);
|
||||
|
||||
// Let's try to import the claims from the ID token
|
||||
|
||||
let env = {
|
||||
let mut e = Environment::new();
|
||||
e.add_global("user", payload);
|
||||
e
|
||||
};
|
||||
|
||||
// Create a template context in case we need to re-render because of an error
|
||||
let mut ctx = UpstreamRegister::default();
|
||||
|
||||
let mut name = None;
|
||||
import_claim(
|
||||
&env,
|
||||
@ -515,8 +559,10 @@ pub(crate) async fn post(
|
||||
|value, force| {
|
||||
// Import the display name if it is either forced or the user has requested it
|
||||
if force || import_display_name {
|
||||
name = Some(value);
|
||||
name = Some(value.clone());
|
||||
}
|
||||
|
||||
ctx.set_display_name(value, force);
|
||||
},
|
||||
)?;
|
||||
|
||||
@ -533,8 +579,10 @@ pub(crate) async fn post(
|
||||
|value, force| {
|
||||
// Import the email if it is either forced or the user has requested it
|
||||
if force || import_email {
|
||||
email = Some(value);
|
||||
email = Some(value.clone());
|
||||
}
|
||||
|
||||
ctx.set_email(value, force);
|
||||
},
|
||||
)?;
|
||||
|
||||
@ -551,21 +599,85 @@ pub(crate) async fn post(
|
||||
|value, force| {
|
||||
// If the username is forced, override whatever was in the form
|
||||
if force {
|
||||
username = Some(value);
|
||||
username = Some(value.clone());
|
||||
}
|
||||
|
||||
ctx.set_localpart(value, force);
|
||||
},
|
||||
)?;
|
||||
|
||||
let username = username.ok_or(RouteError::MissingUsername)?;
|
||||
let username = username.filter(|s| !s.is_empty());
|
||||
|
||||
let Some(username) = username else {
|
||||
let form_state = form_state.with_error_on_field(
|
||||
mas_templates::UpstreamRegisterFormField::Username,
|
||||
FieldError::Required,
|
||||
);
|
||||
|
||||
let ctx = ctx
|
||||
.with_form_state(form_state)
|
||||
.with_csrf(csrf_token.form_value())
|
||||
.with_language(locale);
|
||||
return Ok((
|
||||
cookie_jar,
|
||||
Html(templates.render_upstream_oauth2_do_register(&ctx)?),
|
||||
)
|
||||
.into_response());
|
||||
};
|
||||
|
||||
// Check if there is an existing user
|
||||
let existing_user = repo.user().find_by_username(&username).await?;
|
||||
if let Some(_existing_user) = existing_user {
|
||||
// If there is an existing user, we can't create a new one
|
||||
// with the same username
|
||||
|
||||
let form_state = form_state.with_error_on_field(
|
||||
mas_templates::UpstreamRegisterFormField::Username,
|
||||
FieldError::Exists,
|
||||
);
|
||||
|
||||
let ctx = ctx
|
||||
.with_form_state(form_state)
|
||||
.with_csrf(csrf_token.form_value())
|
||||
.with_language(locale);
|
||||
return Ok((
|
||||
cookie_jar,
|
||||
Html(templates.render_upstream_oauth2_do_register(&ctx)?),
|
||||
)
|
||||
.into_response());
|
||||
}
|
||||
|
||||
// Policy check
|
||||
let res = policy
|
||||
.evaluate_upstream_oauth_register(&username, email.as_deref())
|
||||
.await?;
|
||||
if !res.valid() {
|
||||
return Err(RouteError::PolicyViolation {
|
||||
violations: res.violations,
|
||||
});
|
||||
let form_state =
|
||||
res.violations
|
||||
.into_iter()
|
||||
.fold(form_state, |form_state, violation| {
|
||||
match violation.field.as_deref() {
|
||||
Some("username") => form_state.with_error_on_field(
|
||||
mas_templates::UpstreamRegisterFormField::Username,
|
||||
FieldError::Policy {
|
||||
message: violation.msg,
|
||||
},
|
||||
),
|
||||
_ => form_state.with_error_on_form(FormError::Policy {
|
||||
message: violation.msg,
|
||||
}),
|
||||
}
|
||||
});
|
||||
|
||||
let ctx = ctx
|
||||
.with_form_state(form_state)
|
||||
.with_csrf(csrf_token.form_value())
|
||||
.with_language(locale);
|
||||
return Ok((
|
||||
cookie_jar,
|
||||
Html(templates.render_upstream_oauth2_do_register(&ctx)?),
|
||||
)
|
||||
.into_response());
|
||||
}
|
||||
|
||||
// Now we can create the user
|
||||
@ -631,5 +743,5 @@ pub(crate) async fn post(
|
||||
|
||||
repo.save().await?;
|
||||
|
||||
Ok((cookie_jar, post_auth_action.go_next(&url_builder)))
|
||||
Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response())
|
||||
}
|
||||
|
@ -25,7 +25,7 @@ use mas_data_model::{
|
||||
UpstreamOAuthLink, UpstreamOAuthProvider, User, UserEmail, UserEmailVerification,
|
||||
};
|
||||
use mas_i18n::DataLocale;
|
||||
use mas_router::{Account, GraphQL, PostAuthAction, Route, UrlBuilder};
|
||||
use mas_router::{Account, GraphQL, PostAuthAction, UrlBuilder};
|
||||
use rand::Rng;
|
||||
use serde::{ser::SerializeStruct, Deserialize, Serialize};
|
||||
use ulid::Ulid;
|
||||
@ -918,68 +918,78 @@ impl TemplateContext for UpstreamSuggestLink {
|
||||
}
|
||||
}
|
||||
|
||||
/// User-editeable fields of the upstream account link form
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Hash, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum UpstreamRegisterFormField {
|
||||
/// The username field
|
||||
Username,
|
||||
}
|
||||
|
||||
impl FormField for UpstreamRegisterFormField {
|
||||
fn keep(&self) -> bool {
|
||||
match self {
|
||||
Self::Username => true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Context used by the `pages/upstream_oauth2/do_register.html`
|
||||
/// templates
|
||||
#[derive(Serialize)]
|
||||
#[derive(Serialize, Default)]
|
||||
pub struct UpstreamRegister {
|
||||
login_link: String,
|
||||
suggested_localpart: Option<String>,
|
||||
imported_localpart: Option<String>,
|
||||
force_localpart: bool,
|
||||
suggested_display_name: Option<String>,
|
||||
imported_display_name: Option<String>,
|
||||
force_display_name: bool,
|
||||
suggested_email: Option<String>,
|
||||
imported_email: Option<String>,
|
||||
force_email: bool,
|
||||
form_state: FormState<UpstreamRegisterFormField>,
|
||||
}
|
||||
|
||||
impl UpstreamRegister {
|
||||
/// Constructs a new context with an existing linked user
|
||||
#[must_use]
|
||||
pub fn new(link: &UpstreamOAuthLink) -> Self {
|
||||
Self::for_link_id(link.id)
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Set the suggested localpart
|
||||
pub fn set_localpart(&mut self, localpart: String, force: bool) {
|
||||
self.suggested_localpart = Some(localpart);
|
||||
self.imported_localpart = Some(localpart);
|
||||
self.force_localpart = force;
|
||||
}
|
||||
|
||||
/// Set the suggested display name
|
||||
pub fn set_display_name(&mut self, display_name: String, force: bool) {
|
||||
self.suggested_display_name = Some(display_name);
|
||||
self.imported_display_name = Some(display_name);
|
||||
self.force_display_name = force;
|
||||
}
|
||||
|
||||
/// Set the suggested email
|
||||
pub fn set_email(&mut self, email: String, force: bool) {
|
||||
self.suggested_email = Some(email);
|
||||
self.imported_email = Some(email);
|
||||
self.force_email = force;
|
||||
}
|
||||
|
||||
fn for_link_id(id: Ulid) -> Self {
|
||||
let login_link = mas_router::Login::and_link_upstream(id)
|
||||
.path_and_query()
|
||||
.into();
|
||||
/// Set the form state
|
||||
pub fn set_form_state(&mut self, form_state: FormState<UpstreamRegisterFormField>) {
|
||||
self.form_state = form_state;
|
||||
}
|
||||
|
||||
Self {
|
||||
login_link,
|
||||
suggested_localpart: None,
|
||||
force_localpart: false,
|
||||
suggested_display_name: None,
|
||||
force_display_name: false,
|
||||
suggested_email: None,
|
||||
force_email: false,
|
||||
}
|
||||
/// Set the form state
|
||||
#[must_use]
|
||||
pub fn with_form_state(self, form_state: FormState<UpstreamRegisterFormField>) -> Self {
|
||||
Self { form_state, ..self }
|
||||
}
|
||||
}
|
||||
|
||||
impl TemplateContext for UpstreamRegister {
|
||||
fn sample(now: chrono::DateTime<Utc>, rng: &mut impl Rng) -> Vec<Self>
|
||||
fn sample(_now: chrono::DateTime<Utc>, _rng: &mut impl Rng) -> Vec<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
let id = Ulid::from_datetime_with_source(now.into(), rng);
|
||||
vec![Self::for_link_id(id)]
|
||||
vec![Self::new()]
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -92,6 +92,22 @@ impl<K: Hash + Eq> Default for FormState<K> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize, PartialEq, Eq, Hash)]
|
||||
#[serde(untagged)]
|
||||
enum KeyOrOther<K> {
|
||||
Key(K),
|
||||
Other(String),
|
||||
}
|
||||
|
||||
impl<K> KeyOrOther<K> {
|
||||
fn key(self) -> Option<K> {
|
||||
match self {
|
||||
Self::Key(key) => Some(key),
|
||||
Self::Other(_) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<K: FormField> FormState<K> {
|
||||
/// Generate a [`FormState`] out of a form
|
||||
///
|
||||
@ -101,17 +117,18 @@ impl<K: FormField> FormState<K> {
|
||||
/// deserialize
|
||||
pub fn from_form<F: Serialize>(form: &F) -> Self {
|
||||
let form = serde_json::to_value(form).unwrap();
|
||||
let fields: HashMap<K, Option<String>> = serde_json::from_value(form).unwrap();
|
||||
let fields: HashMap<KeyOrOther<K>, Option<String>> = serde_json::from_value(form).unwrap();
|
||||
|
||||
let fields = fields
|
||||
.into_iter()
|
||||
.map(|(key, value)| {
|
||||
.filter_map(|(key, value)| {
|
||||
let key = key.key()?;
|
||||
let value = key.keep().then_some(value).flatten();
|
||||
let field = FieldState {
|
||||
value,
|
||||
errors: Vec::new(),
|
||||
};
|
||||
(key, field)
|
||||
Some((key, field))
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
@ -54,7 +54,8 @@ pub use self::{
|
||||
LoginContext, LoginFormField, NotFoundContext, PolicyViolationContext, PostAuthContext,
|
||||
PostAuthContextInner, ReauthContext, ReauthFormField, RegisterContext, RegisterFormField,
|
||||
SiteBranding, TemplateContext, UpstreamExistingLinkContext, UpstreamRegister,
|
||||
UpstreamSuggestLink, WithCsrf, WithLanguage, WithOptionalSession, WithSession,
|
||||
UpstreamRegisterFormField, UpstreamSuggestLink, WithCsrf, WithLanguage,
|
||||
WithOptionalSession, WithSession,
|
||||
},
|
||||
forms::{FieldError, FormError, FormField, FormState, ToFormState},
|
||||
};
|
||||
|
Reference in New Issue
Block a user