1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Check for existing users ahead of time on upstream OAuth2 registration

This commit is contained in:
Quentin Gliech
2023-11-09 18:18:28 +01:00
parent 8a1329de05
commit 9c94e11e68
8 changed files with 247 additions and 109 deletions

View File

@ -14,7 +14,7 @@
use axum::{
extract::{Path, State},
response::{Html, IntoResponse},
response::{Html, IntoResponse, Response},
Form, TypedHeader,
};
use hyper::StatusCode;
@ -35,12 +35,13 @@ use mas_storage::{
BoxClock, BoxRepository, BoxRng, RepositoryAccess,
};
use mas_templates::{
ErrorContext, TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister,
UpstreamSuggestLink,
ErrorContext, FieldError, FormError, TemplateContext, Templates, ToFormState,
UpstreamExistingLinkContext, UpstreamRegister, UpstreamSuggestLink,
};
use minijinja::Environment;
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tracing::warn;
use ulid::Ulid;
use super::UpstreamSessionsCookie;
@ -83,14 +84,6 @@ pub(crate) enum RouteError {
#[error("Invalid form action")]
InvalidFormAction,
#[error("Missing username")]
MissingUsername,
#[error("Policy violation: {violations:?}")]
PolicyViolation {
violations: Vec<mas_policy::Violation>,
},
#[error(transparent)]
Internal(Box<dyn std::error::Error>),
}
@ -107,16 +100,6 @@ impl IntoResponse for RouteError {
let event_id = sentry::capture_error(&self);
let response = match self {
Self::LinkNotFound => (StatusCode::NOT_FOUND, "Link not found").into_response(),
Self::PolicyViolation { violations } => {
let details = violations.iter().map(|v| v.msg.clone()).collect::<Vec<_>>();
let details = details.join("\n");
let ctx = ErrorContext::new()
.with_description(
"Account registration denied because of policy violation".to_owned(),
)
.with_details(details);
FancyError::new(ctx).into_response()
}
Self::Internal(e) => FancyError::from(e).into_response(),
e => FancyError::from(e).into_response(),
};
@ -171,7 +154,7 @@ fn import_claim(
Ok(())
}
#[derive(Deserialize)]
#[derive(Deserialize, Serialize)]
#[serde(rename_all = "lowercase", tag = "action")]
pub(crate) enum FormData {
Register {
@ -185,6 +168,10 @@ pub(crate) enum FormData {
Link,
}
impl ToFormState for FormData {
type Field = mas_templates::UpstreamRegisterFormField;
}
#[tracing::instrument(
name = "handlers.upstream_oauth2.link.get",
fields(upstream_oauth_link.id = %link_id),
@ -195,6 +182,7 @@ pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
mut repo: BoxRepository,
mut policy: Policy,
PreferredLanguage(locale): PreferredLanguage,
State(templates): State<Templates>,
State(url_builder): State<UrlBuilder>,
@ -339,7 +327,7 @@ pub(crate) async fn get(
.map(|id_token| id_token.into_parts().1)
.unwrap_or_default();
let mut ctx = UpstreamRegister::new(&link);
let mut ctx = UpstreamRegister::default();
let env = {
let mut e = Environment::new();
@ -375,6 +363,7 @@ pub(crate) async fn get(
},
)?;
let mut forced_localpart = None;
import_claim(
&env,
provider
@ -385,10 +374,59 @@ pub(crate) async fn get(
.unwrap_or("{{ user.preferred_username }}"),
&provider.claims_imports.localpart,
|value, force| {
if force {
// We want to run the policy check on the username if it is forced
forced_localpart = Some(value.clone());
}
ctx.set_localpart(value, force);
},
)?;
// Run the policy check and check for existing users
if let Some(localpart) = forced_localpart {
let maybe_existing_user = repo.user().find_by_username(&localpart).await?;
if let Some(existing_user) = maybe_existing_user {
// The mapper returned a username which already exists, but isn't linked to
// this upstream user.
warn!(username = %localpart, user_id = %existing_user.id, "Localpart template returned an existing username");
// TODO: translate
let ctx = ErrorContext::new()
.with_code("User exists")
.with_description(format!(
r#"Upstream account provider returned {localpart:?} as username,
which is not linked to that upstream account"#
))
.with_language(&locale);
return Ok((
cookie_jar,
Html(templates.render_error(&ctx)?).into_response(),
));
}
let res = policy
.evaluate_upstream_oauth_register(&localpart, None)
.await?;
if !res.valid() {
// TODO: translate
let ctx = ErrorContext::new()
.with_code("Policy error")
.with_description(format!(
r#"Upstream account provider returned {localpart:?} as username,
which does not pass the policy check: {res}"#
))
.with_language(&locale);
return Ok((
cookie_jar,
Html(templates.render_error(&ctx)?).into_response(),
));
}
}
let ctx = ctx.with_csrf(csrf_token.form_value()).with_language(locale);
Html(templates.render_upstream_oauth2_do_register(&ctx)?).into_response()
@ -411,10 +449,12 @@ pub(crate) async fn post(
cookie_jar: CookieJar,
user_agent: Option<TypedHeader<headers::UserAgent>>,
mut policy: Policy,
PreferredLanguage(locale): PreferredLanguage,
State(templates): State<Templates>,
State(url_builder): State<UrlBuilder>,
Path(link_id): Path<Ulid>,
Form(form): Form<ProtectedForm<FormData>>,
) -> Result<impl IntoResponse, RouteError> {
) -> Result<Response, RouteError> {
let user_agent = user_agent.map(|ua| ua.as_str().to_owned());
let form = cookie_jar.verify_form(&clock, form)?;
@ -449,8 +489,10 @@ pub(crate) async fn post(
return Err(RouteError::SessionConsumed);
}
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (user_session_info, cookie_jar) = cookie_jar.session_info();
let maybe_user_session = user_session_info.load_session(&mut repo).await?;
let form_state = form.to_form_state();
let session = match (maybe_user_session, link.user_id, form) {
(Some(session), None, FormData::Link) => {
@ -495,13 +537,15 @@ pub(crate) async fn post(
.unwrap_or(false);
// Let's try to import the claims from the ID token
let env = {
let mut e = Environment::new();
e.add_global("user", payload);
e
};
// Create a template context in case we need to re-render because of an error
let mut ctx = UpstreamRegister::default();
let mut name = None;
import_claim(
&env,
@ -515,8 +559,10 @@ pub(crate) async fn post(
|value, force| {
// Import the display name if it is either forced or the user has requested it
if force || import_display_name {
name = Some(value);
name = Some(value.clone());
}
ctx.set_display_name(value, force);
},
)?;
@ -533,8 +579,10 @@ pub(crate) async fn post(
|value, force| {
// Import the email if it is either forced or the user has requested it
if force || import_email {
email = Some(value);
email = Some(value.clone());
}
ctx.set_email(value, force);
},
)?;
@ -551,21 +599,85 @@ pub(crate) async fn post(
|value, force| {
// If the username is forced, override whatever was in the form
if force {
username = Some(value);
username = Some(value.clone());
}
ctx.set_localpart(value, force);
},
)?;
let username = username.ok_or(RouteError::MissingUsername)?;
let username = username.filter(|s| !s.is_empty());
let Some(username) = username else {
let form_state = form_state.with_error_on_field(
mas_templates::UpstreamRegisterFormField::Username,
FieldError::Required,
);
let ctx = ctx
.with_form_state(form_state)
.with_csrf(csrf_token.form_value())
.with_language(locale);
return Ok((
cookie_jar,
Html(templates.render_upstream_oauth2_do_register(&ctx)?),
)
.into_response());
};
// Check if there is an existing user
let existing_user = repo.user().find_by_username(&username).await?;
if let Some(_existing_user) = existing_user {
// If there is an existing user, we can't create a new one
// with the same username
let form_state = form_state.with_error_on_field(
mas_templates::UpstreamRegisterFormField::Username,
FieldError::Exists,
);
let ctx = ctx
.with_form_state(form_state)
.with_csrf(csrf_token.form_value())
.with_language(locale);
return Ok((
cookie_jar,
Html(templates.render_upstream_oauth2_do_register(&ctx)?),
)
.into_response());
}
// Policy check
let res = policy
.evaluate_upstream_oauth_register(&username, email.as_deref())
.await?;
if !res.valid() {
return Err(RouteError::PolicyViolation {
violations: res.violations,
});
let form_state =
res.violations
.into_iter()
.fold(form_state, |form_state, violation| {
match violation.field.as_deref() {
Some("username") => form_state.with_error_on_field(
mas_templates::UpstreamRegisterFormField::Username,
FieldError::Policy {
message: violation.msg,
},
),
_ => form_state.with_error_on_form(FormError::Policy {
message: violation.msg,
}),
}
});
let ctx = ctx
.with_form_state(form_state)
.with_csrf(csrf_token.form_value())
.with_language(locale);
return Ok((
cookie_jar,
Html(templates.render_upstream_oauth2_do_register(&ctx)?),
)
.into_response());
}
// Now we can create the user
@ -631,5 +743,5 @@ pub(crate) async fn post(
repo.save().await?;
Ok((cookie_jar, post_auth_action.go_next(&url_builder)))
Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response())
}

View File

@ -25,7 +25,7 @@ use mas_data_model::{
UpstreamOAuthLink, UpstreamOAuthProvider, User, UserEmail, UserEmailVerification,
};
use mas_i18n::DataLocale;
use mas_router::{Account, GraphQL, PostAuthAction, Route, UrlBuilder};
use mas_router::{Account, GraphQL, PostAuthAction, UrlBuilder};
use rand::Rng;
use serde::{ser::SerializeStruct, Deserialize, Serialize};
use ulid::Ulid;
@ -918,68 +918,78 @@ impl TemplateContext for UpstreamSuggestLink {
}
}
/// User-editeable fields of the upstream account link form
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Hash, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum UpstreamRegisterFormField {
/// The username field
Username,
}
impl FormField for UpstreamRegisterFormField {
fn keep(&self) -> bool {
match self {
Self::Username => true,
}
}
}
/// Context used by the `pages/upstream_oauth2/do_register.html`
/// templates
#[derive(Serialize)]
#[derive(Serialize, Default)]
pub struct UpstreamRegister {
login_link: String,
suggested_localpart: Option<String>,
imported_localpart: Option<String>,
force_localpart: bool,
suggested_display_name: Option<String>,
imported_display_name: Option<String>,
force_display_name: bool,
suggested_email: Option<String>,
imported_email: Option<String>,
force_email: bool,
form_state: FormState<UpstreamRegisterFormField>,
}
impl UpstreamRegister {
/// Constructs a new context with an existing linked user
#[must_use]
pub fn new(link: &UpstreamOAuthLink) -> Self {
Self::for_link_id(link.id)
pub fn new() -> Self {
Self::default()
}
/// Set the suggested localpart
pub fn set_localpart(&mut self, localpart: String, force: bool) {
self.suggested_localpart = Some(localpart);
self.imported_localpart = Some(localpart);
self.force_localpart = force;
}
/// Set the suggested display name
pub fn set_display_name(&mut self, display_name: String, force: bool) {
self.suggested_display_name = Some(display_name);
self.imported_display_name = Some(display_name);
self.force_display_name = force;
}
/// Set the suggested email
pub fn set_email(&mut self, email: String, force: bool) {
self.suggested_email = Some(email);
self.imported_email = Some(email);
self.force_email = force;
}
fn for_link_id(id: Ulid) -> Self {
let login_link = mas_router::Login::and_link_upstream(id)
.path_and_query()
.into();
/// Set the form state
pub fn set_form_state(&mut self, form_state: FormState<UpstreamRegisterFormField>) {
self.form_state = form_state;
}
Self {
login_link,
suggested_localpart: None,
force_localpart: false,
suggested_display_name: None,
force_display_name: false,
suggested_email: None,
force_email: false,
}
/// Set the form state
#[must_use]
pub fn with_form_state(self, form_state: FormState<UpstreamRegisterFormField>) -> Self {
Self { form_state, ..self }
}
}
impl TemplateContext for UpstreamRegister {
fn sample(now: chrono::DateTime<Utc>, rng: &mut impl Rng) -> Vec<Self>
fn sample(_now: chrono::DateTime<Utc>, _rng: &mut impl Rng) -> Vec<Self>
where
Self: Sized,
{
let id = Ulid::from_datetime_with_source(now.into(), rng);
vec![Self::for_link_id(id)]
vec![Self::new()]
}
}

View File

@ -92,6 +92,22 @@ impl<K: Hash + Eq> Default for FormState<K> {
}
}
#[derive(Deserialize, PartialEq, Eq, Hash)]
#[serde(untagged)]
enum KeyOrOther<K> {
Key(K),
Other(String),
}
impl<K> KeyOrOther<K> {
fn key(self) -> Option<K> {
match self {
Self::Key(key) => Some(key),
Self::Other(_) => None,
}
}
}
impl<K: FormField> FormState<K> {
/// Generate a [`FormState`] out of a form
///
@ -101,17 +117,18 @@ impl<K: FormField> FormState<K> {
/// deserialize
pub fn from_form<F: Serialize>(form: &F) -> Self {
let form = serde_json::to_value(form).unwrap();
let fields: HashMap<K, Option<String>> = serde_json::from_value(form).unwrap();
let fields: HashMap<KeyOrOther<K>, Option<String>> = serde_json::from_value(form).unwrap();
let fields = fields
.into_iter()
.map(|(key, value)| {
.filter_map(|(key, value)| {
let key = key.key()?;
let value = key.keep().then_some(value).flatten();
let field = FieldState {
value,
errors: Vec::new(),
};
(key, field)
Some((key, field))
})
.collect();

View File

@ -54,7 +54,8 @@ pub use self::{
LoginContext, LoginFormField, NotFoundContext, PolicyViolationContext, PostAuthContext,
PostAuthContextInner, ReauthContext, ReauthFormField, RegisterContext, RegisterFormField,
SiteBranding, TemplateContext, UpstreamExistingLinkContext, UpstreamRegister,
UpstreamSuggestLink, WithCsrf, WithLanguage, WithOptionalSession, WithSession,
UpstreamRegisterFormField, UpstreamSuggestLink, WithCsrf, WithLanguage,
WithOptionalSession, WithSession,
},
forms::{FieldError, FormError, FormField, FormState, ToFormState},
};