You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-08-09 04:22:45 +03:00
Refactor the upstream link provider template logic
Also adds tests for new account registration through an upstream oauth2 provider
This commit is contained in:
5
Cargo.lock
generated
5
Cargo.lock
generated
@@ -2946,6 +2946,7 @@ dependencies = [
|
||||
"axum",
|
||||
"axum-extra",
|
||||
"axum-macros",
|
||||
"base64ct",
|
||||
"bcrypt",
|
||||
"camino",
|
||||
"chrono",
|
||||
@@ -4957,9 +4958,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "self_cell"
|
||||
version = "1.0.1"
|
||||
version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4c309e515543e67811222dbc9e3dd7e1056279b782e1dacffe4242b718734fb6"
|
||||
checksum = "e388332cd64eb80cd595a00941baf513caffae8dce9cfd0467fc9c66397dade6"
|
||||
|
||||
[[package]]
|
||||
name = "semver"
|
||||
|
@@ -55,6 +55,22 @@ impl CookieManager {
|
||||
let key = Key::derive_from(key);
|
||||
Self::new(base_url, key)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn cookie_jar(&self) -> CookieJar {
|
||||
let inner = PrivateCookieJar::new(self.key.clone());
|
||||
let options = self.options.clone();
|
||||
|
||||
CookieJar { inner, options }
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn cookie_jar_from_headers(&self, headers: &http::HeaderMap) -> CookieJar {
|
||||
let inner = PrivateCookieJar::from_headers(headers, self.key.clone());
|
||||
let options = self.options.clone();
|
||||
|
||||
CookieJar { inner, options }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -67,10 +83,7 @@ where
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
||||
let cookie_manager = CookieManager::from_ref(state);
|
||||
let inner = PrivateCookieJar::from_headers(&parts.headers, cookie_manager.key.clone());
|
||||
let options = cookie_manager.options.clone();
|
||||
|
||||
Ok(CookieJar { inner, options })
|
||||
Ok(cookie_manager.cookie_jar_from_headers(&parts.headers))
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -130,4 +130,13 @@ impl ImportAction {
|
||||
pub fn is_required(&self) -> bool {
|
||||
matches!(self, Self::Require)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn should_import(&self, user_preference: bool) -> bool {
|
||||
match self {
|
||||
Self::Ignore => false,
|
||||
Self::Suggest => user_preference,
|
||||
Self::Force | Self::Require => true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -51,6 +51,7 @@ pbkdf2 = { version = "0.12.2", features = ["password-hash", "std", "simple", "pa
|
||||
zeroize = "1.6.0"
|
||||
|
||||
# Various data types and utilities
|
||||
base64ct = "1.6.0"
|
||||
camino.workspace = true
|
||||
chrono.workspace = true
|
||||
psl = "2.1.4"
|
||||
|
@@ -22,6 +22,7 @@ use axum::{
|
||||
async_trait,
|
||||
body::{Bytes, HttpBody},
|
||||
extract::{FromRef, FromRequestParts},
|
||||
response::{IntoResponse, IntoResponseParts},
|
||||
};
|
||||
use cookie_store::{CookieStore, RawCookie};
|
||||
use futures_util::future::BoxFuture;
|
||||
@@ -31,7 +32,9 @@ use hyper::{
|
||||
Request, Response, StatusCode,
|
||||
};
|
||||
use mas_axum_utils::{
|
||||
cookies::CookieManager, http_client_factory::HttpClientFactory, ErrorWrapper,
|
||||
cookies::{CookieJar, CookieManager},
|
||||
http_client_factory::HttpClientFactory,
|
||||
ErrorWrapper,
|
||||
};
|
||||
use mas_i18n::Translator;
|
||||
use mas_keystore::{Encrypter, JsonWebKey, JsonWebKeySet, Keystore, PrivateKey};
|
||||
@@ -264,6 +267,11 @@ impl TestState {
|
||||
_ => panic!("Unexpected status code: {}", response.status()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get an empty cookie jar
|
||||
pub fn cookie_jar(&self) -> CookieJar {
|
||||
self.cookie_manager.cookie_jar()
|
||||
}
|
||||
}
|
||||
|
||||
struct TestGraphQLState {
|
||||
@@ -631,6 +639,11 @@ impl CookieHelper {
|
||||
&url,
|
||||
);
|
||||
}
|
||||
|
||||
pub fn import(&self, res: impl IntoResponseParts) {
|
||||
let response = (res, "").into_response();
|
||||
self.save_cookies(&response);
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Layer<S> for CookieHelper {
|
||||
|
@@ -24,7 +24,7 @@ use mas_axum_utils::{
|
||||
sentry::SentryEventID,
|
||||
FancyError, SessionInfoExt,
|
||||
};
|
||||
use mas_data_model::{UpstreamOAuthProviderImportAction, User};
|
||||
use mas_data_model::User;
|
||||
use mas_jose::jwt::Jwt;
|
||||
use mas_policy::Policy;
|
||||
use mas_router::UrlBuilder;
|
||||
@@ -44,9 +44,13 @@ use thiserror::Error;
|
||||
use tracing::warn;
|
||||
use ulid::Ulid;
|
||||
|
||||
use super::UpstreamSessionsCookie;
|
||||
use super::{template::environment, UpstreamSessionsCookie};
|
||||
use crate::{impl_from_error_for_route, views::shared::OptionalPostAuthAction, PreferredLanguage};
|
||||
|
||||
const DEFAULT_LOCALPART_TEMPLATE: &str = "{{ user.preferred_username }}";
|
||||
const DEFAULT_DISPLAYNAME_TEMPLATE: &str = "{{ user.name }}";
|
||||
const DEFAULT_EMAIL_TEMPLATE: &str = "{{ user.email }}";
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub(crate) enum RouteError {
|
||||
/// Couldn't find the link specified in the URL
|
||||
@@ -65,6 +69,10 @@ pub(crate) enum RouteError {
|
||||
#[error("Upstream provider not found")]
|
||||
ProviderNotFound,
|
||||
|
||||
/// Required attribute rendered to an empty string
|
||||
#[error("Template {template:?} rendered to an empty string")]
|
||||
RequiredAttributeEmpty { template: String },
|
||||
|
||||
/// Required claim was missing in id_token
|
||||
#[error("Template {template:?} could not be rendered from the upstream provider's response for required claim")]
|
||||
RequiredAttributeRender {
|
||||
@@ -85,7 +93,7 @@ pub(crate) enum RouteError {
|
||||
InvalidFormAction,
|
||||
|
||||
#[error(transparent)]
|
||||
Internal(Box<dyn std::error::Error>),
|
||||
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
|
||||
}
|
||||
|
||||
impl_from_error_for_route!(mas_templates::TemplateError);
|
||||
@@ -108,39 +116,38 @@ impl IntoResponse for RouteError {
|
||||
}
|
||||
}
|
||||
|
||||
/// Utility function to import a claim from the upstream provider's response,
|
||||
/// based on the preference for that attribute.
|
||||
/// Utility function to render an attribute template.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// * `name` - The name of the claim, for error reporting
|
||||
/// * `value` - The value of the claim, if present
|
||||
/// * `preference` - The preference for this claim
|
||||
/// * `run` - A function to run if the claim is present. The first argument is
|
||||
/// the value of the claim, and the second is whether the claim is forced to
|
||||
/// be used.
|
||||
/// * `environment` - The minijinja environment to use to render the template
|
||||
/// * `template` - The template to use to render the claim
|
||||
/// * `required` - Whether the attribute is required or not
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the claim is required but missing.
|
||||
fn import_claim(
|
||||
/// Returns an error if the attribute is required but fails to render or is
|
||||
/// empty
|
||||
fn render_attribute_template(
|
||||
environment: &Environment,
|
||||
template: &str,
|
||||
action: &UpstreamOAuthProviderImportAction,
|
||||
mut run: impl FnMut(String, bool),
|
||||
) -> Result<(), RouteError> {
|
||||
// If this claim is ignored, we don't need to do anything.
|
||||
if action.ignore() {
|
||||
return Ok(());
|
||||
required: bool,
|
||||
) -> Result<Option<String>, RouteError> {
|
||||
match environment.render_str(template, ()) {
|
||||
Ok(value) if value.is_empty() => {
|
||||
if required {
|
||||
return Err(RouteError::RequiredAttributeEmpty {
|
||||
template: template.to_owned(),
|
||||
});
|
||||
}
|
||||
|
||||
match environment.render_str(template, ()) {
|
||||
Ok(value) if value.is_empty() => { /* Do nothing on empty strings */ }
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
Ok(value) => run(value, action.is_forced()),
|
||||
Ok(value) => Ok(Some(value)),
|
||||
|
||||
Err(source) => {
|
||||
if action.is_required() {
|
||||
if required {
|
||||
return Err(RouteError::RequiredAttributeRender {
|
||||
template: template.to_owned(),
|
||||
source,
|
||||
@@ -148,10 +155,9 @@ fn import_claim(
|
||||
}
|
||||
|
||||
tracing::warn!(error = &source as &dyn std::error::Error, %template, "Error while rendering template");
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
@@ -327,68 +333,78 @@ pub(crate) async fn get(
|
||||
.map(|id_token| id_token.into_parts().1)
|
||||
.unwrap_or_default();
|
||||
|
||||
let mut ctx = UpstreamRegister::default();
|
||||
let ctx = UpstreamRegister::default();
|
||||
|
||||
let env = {
|
||||
let mut e = Environment::new();
|
||||
let mut e = environment();
|
||||
e.add_global("user", payload);
|
||||
e
|
||||
};
|
||||
|
||||
import_claim(
|
||||
&env,
|
||||
provider
|
||||
let ctx = if provider.claims_imports.displayname.ignore() {
|
||||
ctx
|
||||
} else {
|
||||
let template = provider
|
||||
.claims_imports
|
||||
.displayname
|
||||
.template
|
||||
.as_deref()
|
||||
.unwrap_or("{{ user.name }}"),
|
||||
&provider.claims_imports.displayname,
|
||||
|value, force| {
|
||||
ctx.set_display_name(value, force);
|
||||
},
|
||||
)?;
|
||||
.unwrap_or(DEFAULT_DISPLAYNAME_TEMPLATE);
|
||||
|
||||
import_claim(
|
||||
match render_attribute_template(
|
||||
&env,
|
||||
provider
|
||||
template,
|
||||
provider.claims_imports.displayname.is_required(),
|
||||
)? {
|
||||
Some(value) => ctx
|
||||
.with_display_name(value, provider.claims_imports.displayname.is_forced()),
|
||||
None => ctx,
|
||||
}
|
||||
};
|
||||
|
||||
let ctx = if provider.claims_imports.email.ignore() {
|
||||
ctx
|
||||
} else {
|
||||
let template = provider
|
||||
.claims_imports
|
||||
.email
|
||||
.template
|
||||
.as_deref()
|
||||
.unwrap_or("{{ user.email }}"),
|
||||
&provider.claims_imports.email,
|
||||
|value, force| {
|
||||
ctx.set_email(value, force);
|
||||
},
|
||||
)?;
|
||||
.unwrap_or(DEFAULT_EMAIL_TEMPLATE);
|
||||
|
||||
let mut forced_localpart = None;
|
||||
import_claim(
|
||||
match render_attribute_template(
|
||||
&env,
|
||||
provider
|
||||
template,
|
||||
provider.claims_imports.email.is_required(),
|
||||
)? {
|
||||
Some(value) => ctx.with_email(value, provider.claims_imports.email.is_forced()),
|
||||
None => ctx,
|
||||
}
|
||||
};
|
||||
|
||||
let ctx = if provider.claims_imports.localpart.ignore() {
|
||||
ctx
|
||||
} else {
|
||||
let template = provider
|
||||
.claims_imports
|
||||
.localpart
|
||||
.template
|
||||
.as_deref()
|
||||
.unwrap_or("{{ user.preferred_username }}"),
|
||||
&provider.claims_imports.localpart,
|
||||
|value, force| {
|
||||
if force {
|
||||
// We want to run the policy check on the username if it is forced
|
||||
forced_localpart = Some(value.clone());
|
||||
}
|
||||
.unwrap_or(DEFAULT_LOCALPART_TEMPLATE);
|
||||
|
||||
ctx.set_localpart(value, force);
|
||||
},
|
||||
)?;
|
||||
|
||||
// Run the policy check and check for existing users
|
||||
if let Some(localpart) = forced_localpart {
|
||||
match render_attribute_template(
|
||||
&env,
|
||||
template,
|
||||
provider.claims_imports.localpart.is_required(),
|
||||
)? {
|
||||
Some(localpart) => {
|
||||
// We could run policy & existing user checks when the user submits the
|
||||
// form, but this lead to poor UX. This is why we do
|
||||
// it ahead of time here.
|
||||
let maybe_existing_user = repo.user().find_by_username(&localpart).await?;
|
||||
if let Some(existing_user) = maybe_existing_user {
|
||||
// The mapper returned a username which already exists, but isn't linked to
|
||||
// this upstream user.
|
||||
// The mapper returned a username which already exists, but isn't linked
|
||||
// to this upstream user.
|
||||
warn!(username = %localpart, user_id = %existing_user.id, "Localpart template returned an existing username");
|
||||
|
||||
// TODO: translate
|
||||
@@ -425,7 +441,12 @@ pub(crate) async fn get(
|
||||
Html(templates.render_error(&ctx)?).into_response(),
|
||||
));
|
||||
}
|
||||
|
||||
ctx.with_localpart(localpart, provider.claims_imports.localpart.is_forced())
|
||||
}
|
||||
None => ctx,
|
||||
}
|
||||
};
|
||||
|
||||
let ctx = ctx.with_csrf(csrf_token.form_value()).with_language(locale);
|
||||
|
||||
@@ -496,6 +517,8 @@ pub(crate) async fn post(
|
||||
|
||||
let session = match (maybe_user_session, link.user_id, form) {
|
||||
(Some(session), None, FormData::Link) => {
|
||||
// The user is already logged in, the link is not linked to any user, and the
|
||||
// user asked to link their account.
|
||||
repo.upstream_oauth_link()
|
||||
.associate_to_user(&link, &session.user)
|
||||
.await?;
|
||||
@@ -512,6 +535,11 @@ pub(crate) async fn post(
|
||||
import_display_name,
|
||||
},
|
||||
) => {
|
||||
// The user got the form to register a new account, and is not logged in.
|
||||
// Depending on the claims_imports, we've let the user choose their username,
|
||||
// choose whether they want to import the email and display name, or
|
||||
// not.
|
||||
|
||||
// Those fields are Some("on") if the checkbox is checked
|
||||
let import_email = import_email.is_some();
|
||||
let import_display_name = import_display_name.is_some();
|
||||
@@ -531,6 +559,7 @@ pub(crate) async fn post(
|
||||
.map(|id_token| id_token.into_parts().1)
|
||||
.unwrap_or_default();
|
||||
|
||||
// Is the email verified according to the upstream provider?
|
||||
let provider_email_verified = payload
|
||||
.get_item(&minijinja::Value::from("email_verified"))
|
||||
.map(|v| v.is_true())
|
||||
@@ -538,77 +567,91 @@ pub(crate) async fn post(
|
||||
|
||||
// Let's try to import the claims from the ID token
|
||||
let env = {
|
||||
let mut e = Environment::new();
|
||||
let mut e = environment();
|
||||
e.add_global("user", payload);
|
||||
e
|
||||
};
|
||||
|
||||
// Create a template context in case we need to re-render because of an error
|
||||
let mut ctx = UpstreamRegister::default();
|
||||
let ctx = UpstreamRegister::default();
|
||||
|
||||
let mut name = None;
|
||||
import_claim(
|
||||
&env,
|
||||
provider
|
||||
let display_name = if provider
|
||||
.claims_imports
|
||||
.displayname
|
||||
.should_import(import_display_name)
|
||||
{
|
||||
let template = provider
|
||||
.claims_imports
|
||||
.displayname
|
||||
.template
|
||||
.as_deref()
|
||||
.unwrap_or("{{ user.name }}"),
|
||||
&provider.claims_imports.displayname,
|
||||
|value, force| {
|
||||
// Import the display name if it is either forced or the user has requested it
|
||||
if force || import_display_name {
|
||||
name = Some(value.clone());
|
||||
}
|
||||
.unwrap_or(DEFAULT_DISPLAYNAME_TEMPLATE);
|
||||
|
||||
ctx.set_display_name(value, force);
|
||||
},
|
||||
)?;
|
||||
|
||||
let mut email = None;
|
||||
import_claim(
|
||||
render_attribute_template(
|
||||
&env,
|
||||
provider
|
||||
template,
|
||||
provider.claims_imports.displayname.is_required(),
|
||||
)?
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let ctx = if let Some(ref display_name) = display_name {
|
||||
ctx.with_display_name(
|
||||
display_name.clone(),
|
||||
provider.claims_imports.email.is_forced(),
|
||||
)
|
||||
} else {
|
||||
ctx
|
||||
};
|
||||
|
||||
let email = if provider.claims_imports.email.should_import(import_email) {
|
||||
let template = provider
|
||||
.claims_imports
|
||||
.email
|
||||
.template
|
||||
.as_deref()
|
||||
.unwrap_or("{{ user.email }}"),
|
||||
&provider.claims_imports.email,
|
||||
|value, force| {
|
||||
// Import the email if it is either forced or the user has requested it
|
||||
if force || import_email {
|
||||
email = Some(value.clone());
|
||||
}
|
||||
.unwrap_or(DEFAULT_EMAIL_TEMPLATE);
|
||||
|
||||
ctx.set_email(value, force);
|
||||
},
|
||||
)?;
|
||||
|
||||
let mut username = username;
|
||||
import_claim(
|
||||
render_attribute_template(
|
||||
&env,
|
||||
provider
|
||||
template,
|
||||
provider.claims_imports.email.is_required(),
|
||||
)?
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let ctx = if let Some(ref email) = email {
|
||||
ctx.with_email(email.clone(), provider.claims_imports.email.is_forced())
|
||||
} else {
|
||||
ctx
|
||||
};
|
||||
|
||||
let forced_username = if provider.claims_imports.localpart.is_forced() {
|
||||
let template = provider
|
||||
.claims_imports
|
||||
.localpart
|
||||
.template
|
||||
.as_deref()
|
||||
.unwrap_or("{{ user.preferred_username }}"),
|
||||
&provider.claims_imports.localpart,
|
||||
|value, force| {
|
||||
// If the username is forced, override whatever was in the form
|
||||
if force {
|
||||
username = Some(value.clone());
|
||||
}
|
||||
.unwrap_or(DEFAULT_LOCALPART_TEMPLATE);
|
||||
|
||||
ctx.set_localpart(value, force);
|
||||
},
|
||||
)?;
|
||||
render_attribute_template(
|
||||
&env,
|
||||
template,
|
||||
provider.claims_imports.email.is_required(),
|
||||
)?
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let username = username.filter(|s| !s.is_empty());
|
||||
// If there is no forced username, we can use the one the user entered
|
||||
let username = forced_username
|
||||
.or(username)
|
||||
.filter(|username| !username.is_empty());
|
||||
|
||||
let Some(username) = username else {
|
||||
// We're missing a username, let's re-render the form with an error
|
||||
let form_state = form_state.with_error_on_field(
|
||||
mas_templates::UpstreamRegisterFormField::Username,
|
||||
FieldError::Required,
|
||||
@@ -625,11 +668,16 @@ pub(crate) async fn post(
|
||||
.into_response());
|
||||
};
|
||||
|
||||
let ctx = ctx.with_localpart(
|
||||
username.clone(),
|
||||
provider.claims_imports.localpart.is_forced(),
|
||||
);
|
||||
|
||||
// Check if there is an existing user
|
||||
let existing_user = repo.user().find_by_username(&username).await?;
|
||||
if let Some(_existing_user) = existing_user {
|
||||
// If there is an existing user, we can't create a new one
|
||||
// with the same username
|
||||
// with the same username, show an error
|
||||
|
||||
let form_state = form_state.with_error_on_field(
|
||||
mas_templates::UpstreamRegisterFormField::Username,
|
||||
@@ -687,7 +735,7 @@ pub(crate) async fn post(
|
||||
let mut job = ProvisionUserJob::new(&user);
|
||||
|
||||
// If we have a display name, set it during provisioning
|
||||
if let Some(name) = name {
|
||||
if let Some(name) = display_name {
|
||||
job = job.set_display_name(name);
|
||||
}
|
||||
|
||||
@@ -745,3 +793,172 @@ pub(crate) async fn post(
|
||||
|
||||
Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use hyper::{header::CONTENT_TYPE, Request, StatusCode};
|
||||
use mas_data_model::{
|
||||
UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderImportPreference,
|
||||
};
|
||||
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
|
||||
use mas_jose::jwt::{JsonWebSignatureHeader, Jwt};
|
||||
use mas_router::Route;
|
||||
use oauth2_types::scope::{Scope, OPENID};
|
||||
use sqlx::PgPool;
|
||||
|
||||
use super::UpstreamSessionsCookie;
|
||||
use crate::test_utils::{
|
||||
init_tracing, CookieHelper, RequestBuilderExt, ResponseExt, TestState,
|
||||
};
|
||||
|
||||
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
|
||||
async fn test_register(pool: PgPool) {
|
||||
init_tracing();
|
||||
let state = TestState::from_pool(pool).await.unwrap();
|
||||
let mut rng = state.rng();
|
||||
let cookies = CookieHelper::new();
|
||||
|
||||
let claims_imports = UpstreamOAuthProviderClaimsImports {
|
||||
localpart: UpstreamOAuthProviderImportPreference {
|
||||
action: mas_data_model::UpstreamOAuthProviderImportAction::Force,
|
||||
template: None,
|
||||
},
|
||||
email: UpstreamOAuthProviderImportPreference {
|
||||
action: mas_data_model::UpstreamOAuthProviderImportAction::Force,
|
||||
template: None,
|
||||
},
|
||||
..UpstreamOAuthProviderClaimsImports::default()
|
||||
};
|
||||
|
||||
let id_token = serde_json::json!({
|
||||
"preferred_username": "john",
|
||||
"email": "john@example.com",
|
||||
"email_verified": true,
|
||||
});
|
||||
|
||||
// Grab a key to sign the id_token
|
||||
// We could generate a key on the fly, but because we have one available here,
|
||||
// why not use it?
|
||||
let key = state
|
||||
.key_store
|
||||
.signing_key_for_algorithm(&JsonWebSignatureAlg::Rs256)
|
||||
.unwrap();
|
||||
|
||||
let signer = key
|
||||
.params()
|
||||
.signing_key_for_alg(&JsonWebSignatureAlg::Rs256)
|
||||
.unwrap();
|
||||
let header = JsonWebSignatureHeader::new(JsonWebSignatureAlg::Rs256);
|
||||
let id_token = Jwt::sign_with_rng(&mut rng, header, id_token, &signer).unwrap();
|
||||
|
||||
// Provision a provider and a link
|
||||
let mut repo = state.repository().await.unwrap();
|
||||
let provider = repo
|
||||
.upstream_oauth_provider()
|
||||
.add(
|
||||
&mut rng,
|
||||
&state.clock,
|
||||
"https://example.com/".to_owned(),
|
||||
Scope::from_iter([OPENID]),
|
||||
OAuthClientAuthenticationMethod::None,
|
||||
None,
|
||||
"client".to_owned(),
|
||||
None,
|
||||
claims_imports,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let session = repo
|
||||
.upstream_oauth_session()
|
||||
.add(
|
||||
&mut rng,
|
||||
&state.clock,
|
||||
&provider,
|
||||
"state".to_owned(),
|
||||
None,
|
||||
"nonce".to_owned(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let link = repo
|
||||
.upstream_oauth_link()
|
||||
.add(&mut rng, &state.clock, &provider, "subject".to_owned())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let session = repo
|
||||
.upstream_oauth_session()
|
||||
.complete_with_link(&state.clock, session, &link, Some(id_token.into_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
repo.save().await.unwrap();
|
||||
|
||||
let cookie_jar = state.cookie_jar();
|
||||
let upstream_sessions = UpstreamSessionsCookie::default()
|
||||
.add(session.id, provider.id, "state".to_owned(), None)
|
||||
.add_link_to_session(session.id, link.id)
|
||||
.unwrap();
|
||||
let cookie_jar = upstream_sessions.save(cookie_jar, &state.clock);
|
||||
cookies.import(cookie_jar);
|
||||
|
||||
let request = Request::get(&*mas_router::UpstreamOAuth2Link::new(link.id).path()).empty();
|
||||
let request = cookies.with_cookies(request);
|
||||
let response = state.request(request).await;
|
||||
cookies.save_cookies(&response);
|
||||
response.assert_status(StatusCode::OK);
|
||||
response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8");
|
||||
|
||||
// Extract the CSRF token from the response body
|
||||
let csrf_token = response
|
||||
.body()
|
||||
.split("name=\"csrf\" value=\"")
|
||||
.nth(1)
|
||||
.unwrap()
|
||||
.split('\"')
|
||||
.next()
|
||||
.unwrap();
|
||||
|
||||
let request = Request::post(&*mas_router::UpstreamOAuth2Link::new(link.id).path()).form(
|
||||
serde_json::json!({
|
||||
"csrf": csrf_token,
|
||||
"action": "register",
|
||||
"import_email": "on",
|
||||
}),
|
||||
);
|
||||
let request = cookies.with_cookies(request);
|
||||
let response = state.request(request).await;
|
||||
cookies.save_cookies(&response);
|
||||
response.assert_status(StatusCode::SEE_OTHER);
|
||||
|
||||
// Check that we have a registered user, with the email imported
|
||||
let mut repo = state.repository().await.unwrap();
|
||||
let user = repo
|
||||
.user()
|
||||
.find_by_username("john")
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("user exists");
|
||||
|
||||
let link = repo
|
||||
.upstream_oauth_link()
|
||||
.find_by_subject(&provider, "subject")
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("link exists");
|
||||
|
||||
assert_eq!(link.user_id, Some(user.id));
|
||||
|
||||
let email = repo
|
||||
.user_email()
|
||||
.get_primary(&user)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("email exists");
|
||||
|
||||
assert_eq!(email.email, "john@example.com");
|
||||
assert!(email.confirmed_at.is_some());
|
||||
}
|
||||
}
|
||||
|
@@ -26,6 +26,7 @@ pub(crate) mod cache;
|
||||
pub(crate) mod callback;
|
||||
mod cookie;
|
||||
pub(crate) mod link;
|
||||
mod template;
|
||||
|
||||
use self::cookie::UpstreamSessions as UpstreamSessionsCookie;
|
||||
|
||||
|
122
crates/handlers/src/upstream_oauth2/template.rs
Normal file
122
crates/handlers/src/upstream_oauth2/template.rs
Normal file
@@ -0,0 +1,122 @@
|
||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use base64ct::{Base64, Encoding};
|
||||
use minijinja::{Environment, Error, ErrorKind, Value};
|
||||
|
||||
fn split(value: &str, separator: Option<&str>) -> Vec<String> {
|
||||
value
|
||||
.split(separator.unwrap_or(" "))
|
||||
.map(ToOwned::to_owned)
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
fn b64decode(value: &str) -> Result<Value, Error> {
|
||||
let bytes = Base64::decode_vec(value).map_err(|e| {
|
||||
Error::new(
|
||||
ErrorKind::InvalidOperation,
|
||||
"Failed to decode base64 string",
|
||||
)
|
||||
.with_source(e)
|
||||
})?;
|
||||
|
||||
// It is not obvious, but the cleanest way to get a Value stored as raw bytes is
|
||||
// to wrap it in an Arc, because Value implements From<Arc<Vec<u8>>>
|
||||
Ok(Value::from(Arc::new(bytes)))
|
||||
}
|
||||
|
||||
fn b64encode(bytes: &[u8]) -> String {
|
||||
Base64::encode_string(bytes)
|
||||
}
|
||||
|
||||
/// Decode a Tag-Length-Value encoded byte array into a map of tag to value.
|
||||
fn tlvdecode(bytes: &[u8]) -> Result<HashMap<u8, Value>, Error> {
|
||||
let mut iter = bytes.iter().copied();
|
||||
let mut ret = HashMap::new();
|
||||
loop {
|
||||
// TODO: this assumes the tag and the length are both single bytes, which is not
|
||||
// always the case with protobufs. We should properly decode varints
|
||||
// here.
|
||||
let Some(tag) = iter.next() else {
|
||||
break;
|
||||
};
|
||||
|
||||
let len = iter
|
||||
.next()
|
||||
.ok_or_else(|| Error::new(ErrorKind::InvalidOperation, "Invalid ILV encoding"))?;
|
||||
|
||||
let mut bytes = Vec::with_capacity(len.into());
|
||||
for _ in 0..len {
|
||||
bytes.push(
|
||||
iter.next().ok_or_else(|| {
|
||||
Error::new(ErrorKind::InvalidOperation, "Invalid ILV encoding")
|
||||
})?,
|
||||
);
|
||||
}
|
||||
|
||||
ret.insert(tag, Value::from(Arc::new(bytes)));
|
||||
}
|
||||
|
||||
Ok(ret)
|
||||
}
|
||||
|
||||
fn string(value: &Value) -> String {
|
||||
value.to_string()
|
||||
}
|
||||
|
||||
pub fn environment() -> Environment<'static> {
|
||||
let mut env = Environment::new();
|
||||
|
||||
env.add_filter("split", split);
|
||||
env.add_filter("b64decode", b64decode);
|
||||
env.add_filter("b64encode", b64encode);
|
||||
env.add_filter("tlvdecode", tlvdecode);
|
||||
env.add_filter("string", string);
|
||||
|
||||
env
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::environment;
|
||||
|
||||
#[test]
|
||||
fn test_split() {
|
||||
let env = environment();
|
||||
let res = env
|
||||
.render_str(r#"{{ 'foo, bar' | split(', ') | join(" | ") }}"#, ())
|
||||
.unwrap();
|
||||
assert_eq!(res, "foo | bar");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ilvdecode() {
|
||||
let env = environment();
|
||||
let res = env
|
||||
.render_str(
|
||||
r#"
|
||||
{%- set tlv = 'Cg0wLTM4NS0yODA4OS0wEgRtb2Nr' | b64decode | tlvdecode -%}
|
||||
{%- if tlv[18]|string != 'mock' -%}
|
||||
{{ "FAIL"/0 }}
|
||||
{%- endif -%}
|
||||
{{- tlv[10]|string -}}
|
||||
"#,
|
||||
(),
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(res, "0-385-28089-0");
|
||||
}
|
||||
}
|
@@ -954,24 +954,54 @@ impl UpstreamRegister {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Set the suggested localpart
|
||||
/// Set the imported localpart
|
||||
pub fn set_localpart(&mut self, localpart: String, force: bool) {
|
||||
self.imported_localpart = Some(localpart);
|
||||
self.force_localpart = force;
|
||||
}
|
||||
|
||||
/// Set the suggested display name
|
||||
/// Set the imported localpart
|
||||
#[must_use]
|
||||
pub fn with_localpart(self, localpart: String, force: bool) -> Self {
|
||||
Self {
|
||||
imported_localpart: Some(localpart),
|
||||
force_localpart: force,
|
||||
..self
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the imported display name
|
||||
pub fn set_display_name(&mut self, display_name: String, force: bool) {
|
||||
self.imported_display_name = Some(display_name);
|
||||
self.force_display_name = force;
|
||||
}
|
||||
|
||||
/// Set the suggested email
|
||||
/// Set the imported display name
|
||||
#[must_use]
|
||||
pub fn with_display_name(self, display_name: String, force: bool) -> Self {
|
||||
Self {
|
||||
imported_display_name: Some(display_name),
|
||||
force_display_name: force,
|
||||
..self
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the imported email
|
||||
pub fn set_email(&mut self, email: String, force: bool) {
|
||||
self.imported_email = Some(email);
|
||||
self.force_email = force;
|
||||
}
|
||||
|
||||
/// Set the imported email
|
||||
#[must_use]
|
||||
pub fn with_email(self, email: String, force: bool) -> Self {
|
||||
Self {
|
||||
imported_email: Some(email),
|
||||
force_email: force,
|
||||
..self
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the form state
|
||||
pub fn set_form_state(&mut self, form_state: FormState<UpstreamRegisterFormField>) {
|
||||
self.form_state = form_state;
|
||||
|
Reference in New Issue
Block a user