1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-11-21 23:00:50 +03:00

Refactor the upstream link provider template logic

Also adds tests for new account registration through an upstream oauth2
provider
This commit is contained in:
Quentin Gliech
2023-11-13 12:29:01 +01:00
parent 9c94e11e68
commit 89420a2cfc
9 changed files with 560 additions and 153 deletions

View File

@@ -24,7 +24,7 @@ use mas_axum_utils::{
sentry::SentryEventID,
FancyError, SessionInfoExt,
};
use mas_data_model::{UpstreamOAuthProviderImportAction, User};
use mas_data_model::User;
use mas_jose::jwt::Jwt;
use mas_policy::Policy;
use mas_router::UrlBuilder;
@@ -44,9 +44,13 @@ use thiserror::Error;
use tracing::warn;
use ulid::Ulid;
use super::UpstreamSessionsCookie;
use super::{template::environment, UpstreamSessionsCookie};
use crate::{impl_from_error_for_route, views::shared::OptionalPostAuthAction, PreferredLanguage};
const DEFAULT_LOCALPART_TEMPLATE: &str = "{{ user.preferred_username }}";
const DEFAULT_DISPLAYNAME_TEMPLATE: &str = "{{ user.name }}";
const DEFAULT_EMAIL_TEMPLATE: &str = "{{ user.email }}";
#[derive(Debug, Error)]
pub(crate) enum RouteError {
/// Couldn't find the link specified in the URL
@@ -65,6 +69,10 @@ pub(crate) enum RouteError {
#[error("Upstream provider not found")]
ProviderNotFound,
/// Required attribute rendered to an empty string
#[error("Template {template:?} rendered to an empty string")]
RequiredAttributeEmpty { template: String },
/// Required claim was missing in id_token
#[error("Template {template:?} could not be rendered from the upstream provider's response for required claim")]
RequiredAttributeRender {
@@ -85,7 +93,7 @@ pub(crate) enum RouteError {
InvalidFormAction,
#[error(transparent)]
Internal(Box<dyn std::error::Error>),
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
}
impl_from_error_for_route!(mas_templates::TemplateError);
@@ -108,39 +116,38 @@ impl IntoResponse for RouteError {
}
}
/// Utility function to import a claim from the upstream provider's response,
/// based on the preference for that attribute.
/// Utility function to render an attribute template.
///
/// # Parameters
///
/// * `name` - The name of the claim, for error reporting
/// * `value` - The value of the claim, if present
/// * `preference` - The preference for this claim
/// * `run` - A function to run if the claim is present. The first argument is
/// the value of the claim, and the second is whether the claim is forced to
/// be used.
/// * `environment` - The minijinja environment to use to render the template
/// * `template` - The template to use to render the claim
/// * `required` - Whether the attribute is required or not
///
/// # Errors
///
/// Returns an error if the claim is required but missing.
fn import_claim(
/// Returns an error if the attribute is required but fails to render or is
/// empty
fn render_attribute_template(
environment: &Environment,
template: &str,
action: &UpstreamOAuthProviderImportAction,
mut run: impl FnMut(String, bool),
) -> Result<(), RouteError> {
// If this claim is ignored, we don't need to do anything.
if action.ignore() {
return Ok(());
}
required: bool,
) -> Result<Option<String>, RouteError> {
match environment.render_str(template, ()) {
Ok(value) if value.is_empty() => { /* Do nothing on empty strings */ }
Ok(value) if value.is_empty() => {
if required {
return Err(RouteError::RequiredAttributeEmpty {
template: template.to_owned(),
});
}
Ok(value) => run(value, action.is_forced()),
Ok(None)
}
Ok(value) => Ok(Some(value)),
Err(source) => {
if action.is_required() {
if required {
return Err(RouteError::RequiredAttributeRender {
template: template.to_owned(),
source,
@@ -148,10 +155,9 @@ fn import_claim(
}
tracing::warn!(error = &source as &dyn std::error::Error, %template, "Error while rendering template");
Ok(None)
}
}
Ok(())
}
#[derive(Deserialize, Serialize)]
@@ -327,105 +333,120 @@ pub(crate) async fn get(
.map(|id_token| id_token.into_parts().1)
.unwrap_or_default();
let mut ctx = UpstreamRegister::default();
let ctx = UpstreamRegister::default();
let env = {
let mut e = Environment::new();
let mut e = environment();
e.add_global("user", payload);
e
};
import_claim(
&env,
provider
let ctx = if provider.claims_imports.displayname.ignore() {
ctx
} else {
let template = provider
.claims_imports
.displayname
.template
.as_deref()
.unwrap_or("{{ user.name }}"),
&provider.claims_imports.displayname,
|value, force| {
ctx.set_display_name(value, force);
},
)?;
.unwrap_or(DEFAULT_DISPLAYNAME_TEMPLATE);
import_claim(
&env,
provider
match render_attribute_template(
&env,
template,
provider.claims_imports.displayname.is_required(),
)? {
Some(value) => ctx
.with_display_name(value, provider.claims_imports.displayname.is_forced()),
None => ctx,
}
};
let ctx = if provider.claims_imports.email.ignore() {
ctx
} else {
let template = provider
.claims_imports
.email
.template
.as_deref()
.unwrap_or("{{ user.email }}"),
&provider.claims_imports.email,
|value, force| {
ctx.set_email(value, force);
},
)?;
.unwrap_or(DEFAULT_EMAIL_TEMPLATE);
let mut forced_localpart = None;
import_claim(
&env,
provider
match render_attribute_template(
&env,
template,
provider.claims_imports.email.is_required(),
)? {
Some(value) => ctx.with_email(value, provider.claims_imports.email.is_forced()),
None => ctx,
}
};
let ctx = if provider.claims_imports.localpart.ignore() {
ctx
} else {
let template = provider
.claims_imports
.localpart
.template
.as_deref()
.unwrap_or("{{ user.preferred_username }}"),
&provider.claims_imports.localpart,
|value, force| {
if force {
// We want to run the policy check on the username if it is forced
forced_localpart = Some(value.clone());
}
.unwrap_or(DEFAULT_LOCALPART_TEMPLATE);
ctx.set_localpart(value, force);
},
)?;
match render_attribute_template(
&env,
template,
provider.claims_imports.localpart.is_required(),
)? {
Some(localpart) => {
// We could run policy & existing user checks when the user submits the
// form, but this lead to poor UX. This is why we do
// it ahead of time here.
let maybe_existing_user = repo.user().find_by_username(&localpart).await?;
if let Some(existing_user) = maybe_existing_user {
// The mapper returned a username which already exists, but isn't linked
// to this upstream user.
warn!(username = %localpart, user_id = %existing_user.id, "Localpart template returned an existing username");
// Run the policy check and check for existing users
if let Some(localpart) = forced_localpart {
let maybe_existing_user = repo.user().find_by_username(&localpart).await?;
if let Some(existing_user) = maybe_existing_user {
// The mapper returned a username which already exists, but isn't linked to
// this upstream user.
warn!(username = %localpart, user_id = %existing_user.id, "Localpart template returned an existing username");
// TODO: translate
let ctx = ErrorContext::new()
.with_code("User exists")
.with_description(format!(
r#"Upstream account provider returned {localpart:?} as username,
// TODO: translate
let ctx = ErrorContext::new()
.with_code("User exists")
.with_description(format!(
r#"Upstream account provider returned {localpart:?} as username,
which is not linked to that upstream account"#
))
.with_language(&locale);
))
.with_language(&locale);
return Ok((
cookie_jar,
Html(templates.render_error(&ctx)?).into_response(),
));
}
return Ok((
cookie_jar,
Html(templates.render_error(&ctx)?).into_response(),
));
}
let res = policy
.evaluate_upstream_oauth_register(&localpart, None)
.await?;
let res = policy
.evaluate_upstream_oauth_register(&localpart, None)
.await?;
if !res.valid() {
// TODO: translate
let ctx = ErrorContext::new()
.with_code("Policy error")
.with_description(format!(
r#"Upstream account provider returned {localpart:?} as username,
if !res.valid() {
// TODO: translate
let ctx = ErrorContext::new()
.with_code("Policy error")
.with_description(format!(
r#"Upstream account provider returned {localpart:?} as username,
which does not pass the policy check: {res}"#
))
.with_language(&locale);
))
.with_language(&locale);
return Ok((
cookie_jar,
Html(templates.render_error(&ctx)?).into_response(),
));
return Ok((
cookie_jar,
Html(templates.render_error(&ctx)?).into_response(),
));
}
ctx.with_localpart(localpart, provider.claims_imports.localpart.is_forced())
}
None => ctx,
}
}
};
let ctx = ctx.with_csrf(csrf_token.form_value()).with_language(locale);
@@ -496,6 +517,8 @@ pub(crate) async fn post(
let session = match (maybe_user_session, link.user_id, form) {
(Some(session), None, FormData::Link) => {
// The user is already logged in, the link is not linked to any user, and the
// user asked to link their account.
repo.upstream_oauth_link()
.associate_to_user(&link, &session.user)
.await?;
@@ -512,6 +535,11 @@ pub(crate) async fn post(
import_display_name,
},
) => {
// The user got the form to register a new account, and is not logged in.
// Depending on the claims_imports, we've let the user choose their username,
// choose whether they want to import the email and display name, or
// not.
// Those fields are Some("on") if the checkbox is checked
let import_email = import_email.is_some();
let import_display_name = import_display_name.is_some();
@@ -531,6 +559,7 @@ pub(crate) async fn post(
.map(|id_token| id_token.into_parts().1)
.unwrap_or_default();
// Is the email verified according to the upstream provider?
let provider_email_verified = payload
.get_item(&minijinja::Value::from("email_verified"))
.map(|v| v.is_true())
@@ -538,77 +567,91 @@ pub(crate) async fn post(
// Let's try to import the claims from the ID token
let env = {
let mut e = Environment::new();
let mut e = environment();
e.add_global("user", payload);
e
};
// Create a template context in case we need to re-render because of an error
let mut ctx = UpstreamRegister::default();
let ctx = UpstreamRegister::default();
let mut name = None;
import_claim(
&env,
provider
let display_name = if provider
.claims_imports
.displayname
.should_import(import_display_name)
{
let template = provider
.claims_imports
.displayname
.template
.as_deref()
.unwrap_or("{{ user.name }}"),
&provider.claims_imports.displayname,
|value, force| {
// Import the display name if it is either forced or the user has requested it
if force || import_display_name {
name = Some(value.clone());
}
.unwrap_or(DEFAULT_DISPLAYNAME_TEMPLATE);
ctx.set_display_name(value, force);
},
)?;
render_attribute_template(
&env,
template,
provider.claims_imports.displayname.is_required(),
)?
} else {
None
};
let mut email = None;
import_claim(
&env,
provider
let ctx = if let Some(ref display_name) = display_name {
ctx.with_display_name(
display_name.clone(),
provider.claims_imports.email.is_forced(),
)
} else {
ctx
};
let email = if provider.claims_imports.email.should_import(import_email) {
let template = provider
.claims_imports
.email
.template
.as_deref()
.unwrap_or("{{ user.email }}"),
&provider.claims_imports.email,
|value, force| {
// Import the email if it is either forced or the user has requested it
if force || import_email {
email = Some(value.clone());
}
.unwrap_or(DEFAULT_EMAIL_TEMPLATE);
ctx.set_email(value, force);
},
)?;
render_attribute_template(
&env,
template,
provider.claims_imports.email.is_required(),
)?
} else {
None
};
let mut username = username;
import_claim(
&env,
provider
let ctx = if let Some(ref email) = email {
ctx.with_email(email.clone(), provider.claims_imports.email.is_forced())
} else {
ctx
};
let forced_username = if provider.claims_imports.localpart.is_forced() {
let template = provider
.claims_imports
.localpart
.template
.as_deref()
.unwrap_or("{{ user.preferred_username }}"),
&provider.claims_imports.localpart,
|value, force| {
// If the username is forced, override whatever was in the form
if force {
username = Some(value.clone());
}
.unwrap_or(DEFAULT_LOCALPART_TEMPLATE);
ctx.set_localpart(value, force);
},
)?;
render_attribute_template(
&env,
template,
provider.claims_imports.email.is_required(),
)?
} else {
None
};
let username = username.filter(|s| !s.is_empty());
// If there is no forced username, we can use the one the user entered
let username = forced_username
.or(username)
.filter(|username| !username.is_empty());
let Some(username) = username else {
// We're missing a username, let's re-render the form with an error
let form_state = form_state.with_error_on_field(
mas_templates::UpstreamRegisterFormField::Username,
FieldError::Required,
@@ -625,11 +668,16 @@ pub(crate) async fn post(
.into_response());
};
let ctx = ctx.with_localpart(
username.clone(),
provider.claims_imports.localpart.is_forced(),
);
// Check if there is an existing user
let existing_user = repo.user().find_by_username(&username).await?;
if let Some(_existing_user) = existing_user {
// If there is an existing user, we can't create a new one
// with the same username
// with the same username, show an error
let form_state = form_state.with_error_on_field(
mas_templates::UpstreamRegisterFormField::Username,
@@ -687,7 +735,7 @@ pub(crate) async fn post(
let mut job = ProvisionUserJob::new(&user);
// If we have a display name, set it during provisioning
if let Some(name) = name {
if let Some(name) = display_name {
job = job.set_display_name(name);
}
@@ -745,3 +793,172 @@ pub(crate) async fn post(
Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response())
}
#[cfg(test)]
mod tests {
use hyper::{header::CONTENT_TYPE, Request, StatusCode};
use mas_data_model::{
UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderImportPreference,
};
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use mas_jose::jwt::{JsonWebSignatureHeader, Jwt};
use mas_router::Route;
use oauth2_types::scope::{Scope, OPENID};
use sqlx::PgPool;
use super::UpstreamSessionsCookie;
use crate::test_utils::{
init_tracing, CookieHelper, RequestBuilderExt, ResponseExt, TestState,
};
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_register(pool: PgPool) {
init_tracing();
let state = TestState::from_pool(pool).await.unwrap();
let mut rng = state.rng();
let cookies = CookieHelper::new();
let claims_imports = UpstreamOAuthProviderClaimsImports {
localpart: UpstreamOAuthProviderImportPreference {
action: mas_data_model::UpstreamOAuthProviderImportAction::Force,
template: None,
},
email: UpstreamOAuthProviderImportPreference {
action: mas_data_model::UpstreamOAuthProviderImportAction::Force,
template: None,
},
..UpstreamOAuthProviderClaimsImports::default()
};
let id_token = serde_json::json!({
"preferred_username": "john",
"email": "john@example.com",
"email_verified": true,
});
// Grab a key to sign the id_token
// We could generate a key on the fly, but because we have one available here,
// why not use it?
let key = state
.key_store
.signing_key_for_algorithm(&JsonWebSignatureAlg::Rs256)
.unwrap();
let signer = key
.params()
.signing_key_for_alg(&JsonWebSignatureAlg::Rs256)
.unwrap();
let header = JsonWebSignatureHeader::new(JsonWebSignatureAlg::Rs256);
let id_token = Jwt::sign_with_rng(&mut rng, header, id_token, &signer).unwrap();
// Provision a provider and a link
let mut repo = state.repository().await.unwrap();
let provider = repo
.upstream_oauth_provider()
.add(
&mut rng,
&state.clock,
"https://example.com/".to_owned(),
Scope::from_iter([OPENID]),
OAuthClientAuthenticationMethod::None,
None,
"client".to_owned(),
None,
claims_imports,
)
.await
.unwrap();
let session = repo
.upstream_oauth_session()
.add(
&mut rng,
&state.clock,
&provider,
"state".to_owned(),
None,
"nonce".to_owned(),
)
.await
.unwrap();
let link = repo
.upstream_oauth_link()
.add(&mut rng, &state.clock, &provider, "subject".to_owned())
.await
.unwrap();
let session = repo
.upstream_oauth_session()
.complete_with_link(&state.clock, session, &link, Some(id_token.into_string()))
.await
.unwrap();
repo.save().await.unwrap();
let cookie_jar = state.cookie_jar();
let upstream_sessions = UpstreamSessionsCookie::default()
.add(session.id, provider.id, "state".to_owned(), None)
.add_link_to_session(session.id, link.id)
.unwrap();
let cookie_jar = upstream_sessions.save(cookie_jar, &state.clock);
cookies.import(cookie_jar);
let request = Request::get(&*mas_router::UpstreamOAuth2Link::new(link.id).path()).empty();
let request = cookies.with_cookies(request);
let response = state.request(request).await;
cookies.save_cookies(&response);
response.assert_status(StatusCode::OK);
response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8");
// Extract the CSRF token from the response body
let csrf_token = response
.body()
.split("name=\"csrf\" value=\"")
.nth(1)
.unwrap()
.split('\"')
.next()
.unwrap();
let request = Request::post(&*mas_router::UpstreamOAuth2Link::new(link.id).path()).form(
serde_json::json!({
"csrf": csrf_token,
"action": "register",
"import_email": "on",
}),
);
let request = cookies.with_cookies(request);
let response = state.request(request).await;
cookies.save_cookies(&response);
response.assert_status(StatusCode::SEE_OTHER);
// Check that we have a registered user, with the email imported
let mut repo = state.repository().await.unwrap();
let user = repo
.user()
.find_by_username("john")
.await
.unwrap()
.expect("user exists");
let link = repo
.upstream_oauth_link()
.find_by_subject(&provider, "subject")
.await
.unwrap()
.expect("link exists");
assert_eq!(link.user_id, Some(user.id));
let email = repo
.user_email()
.get_primary(&user)
.await
.unwrap()
.expect("email exists");
assert_eq!(email.email, "john@example.com");
assert!(email.confirmed_at.is_some());
}
}