1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-07 17:03:01 +03:00

Save the imported attributes

This commit is contained in:
Quentin Gliech
2023-06-21 18:50:03 +02:00
parent c183830489
commit 31788a95f2
4 changed files with 136 additions and 12 deletions

View File

@@ -29,7 +29,7 @@ use mas_keystore::Encrypter;
use mas_storage::{ use mas_storage::{
job::{JobRepositoryExt, ProvisionUserJob}, job::{JobRepositoryExt, ProvisionUserJob},
upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository}, upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository},
user::{BrowserSessionRepository, UserRepository}, user::{BrowserSessionRepository, UserEmailRepository, UserRepository},
BoxClock, BoxRepository, BoxRng, RepositoryAccess, BoxClock, BoxRepository, BoxRng, RepositoryAccess,
}; };
use mas_templates::{ use mas_templates::{
@@ -100,6 +100,8 @@ impl IntoResponse for RouteError {
struct StandardClaims { struct StandardClaims {
name: Option<String>, name: Option<String>,
email: Option<String>, email: Option<String>,
#[serde(default)]
email_verified: bool,
preferred_username: Option<String>, preferred_username: Option<String>,
} }
@@ -144,7 +146,13 @@ fn import_claim(
#[derive(Deserialize)] #[derive(Deserialize)]
#[serde(rename_all = "lowercase", tag = "action")] #[serde(rename_all = "lowercase", tag = "action")]
pub(crate) enum FormData { pub(crate) enum FormData {
Register { username: String }, Register {
username: String,
#[serde(default)]
import_email: Option<String>,
#[serde(default)]
import_display_name: Option<String>,
},
Link, Link,
Login, Login,
} }
@@ -298,7 +306,7 @@ pub(crate) async fn get(
)?; )?;
import_claim( import_claim(
"username", "preferred_username",
payload.preferred_username, payload.preferred_username,
&provider.claims_imports.localpart, &provider.claims_imports.localpart,
|value, force| { |value, force| {
@@ -384,13 +392,103 @@ pub(crate) async fn post(
repo.browser_session().add(&mut rng, &clock, &user).await? repo.browser_session().add(&mut rng, &clock, &user).await?
} }
(None, None, FormData::Register { username }) => { (
None,
None,
FormData::Register {
username,
import_email,
import_display_name,
},
) => {
// Those fields are Some("on") if the checkbox is checked
let import_email = import_email.is_some();
let import_display_name = import_display_name.is_some();
let id_token = upstream_session
.id_token()
.map(Jwt::<'_, StandardClaims>::try_from)
.transpose()?;
let provider = repo
.upstream_oauth_provider()
.lookup(link.provider_id)
.await?
.ok_or(RouteError::ProviderNotFound)?;
let payload = id_token
.map(|id_token| id_token.into_parts().1)
.unwrap_or_default();
// Let's try to import the claims from the ID token
let mut name = None;
import_claim(
"name",
payload.name,
&provider.claims_imports.displayname,
|value, force| {
// Import the display name if it is either forced or the user has requested it
if force || import_display_name {
name = Some(value);
}
},
)?;
let mut email = None;
import_claim(
"email",
payload.email,
&provider.claims_imports.email,
|value, force| {
// Import the email if it is either forced or the user has requested it
if force || import_email {
email = Some(value);
}
},
)?;
let mut username = username;
import_claim(
"preferred_username",
payload.preferred_username,
&provider.claims_imports.localpart,
|value, force| {
// If the username is forced, override whatever was in the form
if force {
username = value;
}
},
)?;
// Now we can create the user
let user = repo.user().add(&mut rng, &clock, username).await?; let user = repo.user().add(&mut rng, &clock, username).await?;
repo.job() // And schedule the job to provision it
.schedule_job(ProvisionUserJob::new(&user)) let mut job = ProvisionUserJob::new(&user);
// If we have a display name, set it during provisioning
if let Some(name) = name {
job = job.set_display_name(name);
}
repo.job().schedule_job(job).await?;
// If we have an email, add it to the user
if let Some(email) = email {
let user_email = repo
.user_email()
.add(&mut rng, &clock, &user, email)
.await?; .await?;
// Mark the email as verified if the upstream provider says it is.
if payload.email_verified {
repo.user_email()
.mark_as_verified(&clock, user_email)
.await?;
}
}
repo.upstream_oauth_link() repo.upstream_oauth_link()
.associate_to_user(&link, &user) .associate_to_user(&link, &user)
.await?; .await?;

View File

@@ -256,13 +256,30 @@ mod jobs {
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ProvisionUserJob { pub struct ProvisionUserJob {
user_id: Ulid, user_id: Ulid,
set_display_name: Option<String>,
} }
impl ProvisionUserJob { impl ProvisionUserJob {
/// Create a new job to provision the user on the homeserver. /// Create a new job to provision the user on the homeserver.
#[must_use] #[must_use]
pub fn new(user: &User) -> Self { pub fn new(user: &User) -> Self {
Self { user_id: user.id } Self {
user_id: user.id,
set_display_name: None,
}
}
/// Set the display name of the user.
#[must_use]
pub fn set_display_name(mut self, display_name: String) -> Self {
self.set_display_name = Some(display_name);
self
}
/// Get the display name to be set.
#[must_use]
pub fn display_name_to_set(&self) -> Option<&str> {
self.set_display_name.as_deref()
} }
/// The ID of the user to provision. /// The ID of the user to provision.

View File

@@ -69,7 +69,12 @@ async fn provision_user(
repo.cancel().await?; repo.cancel().await?;
let request = ProvisionRequest::new(mxid.clone(), user.sub.clone()).set_emails(emails); let mut request = ProvisionRequest::new(mxid.clone(), user.sub.clone()).set_emails(emails);
if let Some(display_name) = job.display_name_to_set() {
request = request.set_displayname(display_name.to_owned());
}
let created = matrix.provision_user(&request).await?; let created = matrix.provision_user(&request).await?;
if created { if created {

View File

@@ -21,7 +21,11 @@ limitations under the License.
<div class="grid grid-cols-1 gap-6 w-96"> <div class="grid grid-cols-1 gap-6 w-96">
<form method="POST" class="grid grid-cols-1 gap-6"> <form method="POST" class="grid grid-cols-1 gap-6">
<h1 class="rounded-lg bg-grey-25 dark:bg-grey-450 p-2 text-center font-medium text-lg"> <h1 class="rounded-lg bg-grey-25 dark:bg-grey-450 p-2 text-center font-medium text-lg">
{% if force_localpart %}
Create a new account
{% else %}
Choose your username Choose your username
{% endif %}
</h1> </h1>
<input type="hidden" name="csrf" value="{{ csrf_token }}" /> <input type="hidden" name="csrf" value="{{ csrf_token }}" />
@@ -34,7 +38,7 @@ limitations under the License.
{% if force_email %} {% if force_email %}
Will import the following email address Will import the following email address
{% else %} {% else %}
<input type="checkbox" name="import_email" id="import_email" value="1" checked="checked" /> <input type="checkbox" name="import_email" id="import_email" checked="checked" />
<label for="import_email">Import email address</label> <label for="import_email">Import email address</label>
{% endif %} {% endif %}
</div> </div>
@@ -48,7 +52,7 @@ limitations under the License.
{% if force_display_name %} {% if force_display_name %}
Will import the following display name Will import the following display name
{% else %} {% else %}
<input type="checkbox" name="import_display_name" id="import_display_name" value="1" checked="checked" /> <input type="checkbox" name="import_display_name" id="import_display_name" checked="checked" />
<label for="import_display_name">Import display name</label> <label for="import_display_name">Import display name</label>
{% endif %} {% endif %}
</div> </div>