diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index 0508911a..0248bad9 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -29,7 +29,7 @@ use mas_keystore::Encrypter; use mas_storage::{ job::{JobRepositoryExt, ProvisionUserJob}, upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository}, - user::{BrowserSessionRepository, UserRepository}, + user::{BrowserSessionRepository, UserEmailRepository, UserRepository}, BoxClock, BoxRepository, BoxRng, RepositoryAccess, }; use mas_templates::{ @@ -100,6 +100,8 @@ impl IntoResponse for RouteError { struct StandardClaims { name: Option, email: Option, + #[serde(default)] + email_verified: bool, preferred_username: Option, } @@ -144,7 +146,13 @@ fn import_claim( #[derive(Deserialize)] #[serde(rename_all = "lowercase", tag = "action")] pub(crate) enum FormData { - Register { username: String }, + Register { + username: String, + #[serde(default)] + import_email: Option, + #[serde(default)] + import_display_name: Option, + }, Link, Login, } @@ -298,7 +306,7 @@ pub(crate) async fn get( )?; import_claim( - "username", + "preferred_username", payload.preferred_username, &provider.claims_imports.localpart, |value, force| { @@ -384,12 +392,102 @@ pub(crate) async fn post( repo.browser_session().add(&mut rng, &clock, &user).await? } - (None, None, FormData::Register { username }) => { + ( + None, + None, + FormData::Register { + username, + import_email, + import_display_name, + }, + ) => { + // Those fields are Some("on") if the checkbox is checked + let import_email = import_email.is_some(); + let import_display_name = import_display_name.is_some(); + + let id_token = upstream_session + .id_token() + .map(Jwt::<'_, StandardClaims>::try_from) + .transpose()?; + + let provider = repo + .upstream_oauth_provider() + .lookup(link.provider_id) + .await? + .ok_or(RouteError::ProviderNotFound)?; + + let payload = id_token + .map(|id_token| id_token.into_parts().1) + .unwrap_or_default(); + + // Let's try to import the claims from the ID token + + let mut name = None; + import_claim( + "name", + payload.name, + &provider.claims_imports.displayname, + |value, force| { + // Import the display name if it is either forced or the user has requested it + if force || import_display_name { + name = Some(value); + } + }, + )?; + + let mut email = None; + import_claim( + "email", + payload.email, + &provider.claims_imports.email, + |value, force| { + // Import the email if it is either forced or the user has requested it + if force || import_email { + email = Some(value); + } + }, + )?; + + let mut username = username; + import_claim( + "preferred_username", + payload.preferred_username, + &provider.claims_imports.localpart, + |value, force| { + // If the username is forced, override whatever was in the form + if force { + username = value; + } + }, + )?; + + // Now we can create the user let user = repo.user().add(&mut rng, &clock, username).await?; - repo.job() - .schedule_job(ProvisionUserJob::new(&user)) - .await?; + // And schedule the job to provision it + let mut job = ProvisionUserJob::new(&user); + + // If we have a display name, set it during provisioning + if let Some(name) = name { + job = job.set_display_name(name); + } + + repo.job().schedule_job(job).await?; + + // If we have an email, add it to the user + if let Some(email) = email { + let user_email = repo + .user_email() + .add(&mut rng, &clock, &user, email) + .await?; + + // Mark the email as verified if the upstream provider says it is. + if payload.email_verified { + repo.user_email() + .mark_as_verified(&clock, user_email) + .await?; + } + } repo.upstream_oauth_link() .associate_to_user(&link, &user) diff --git a/crates/storage/src/job.rs b/crates/storage/src/job.rs index 8a5d3564..754e4994 100644 --- a/crates/storage/src/job.rs +++ b/crates/storage/src/job.rs @@ -256,13 +256,30 @@ mod jobs { #[derive(Serialize, Deserialize, Debug, Clone)] pub struct ProvisionUserJob { user_id: Ulid, + set_display_name: Option, } impl ProvisionUserJob { /// Create a new job to provision the user on the homeserver. #[must_use] pub fn new(user: &User) -> Self { - Self { user_id: user.id } + Self { + user_id: user.id, + set_display_name: None, + } + } + + /// Set the display name of the user. + #[must_use] + pub fn set_display_name(mut self, display_name: String) -> Self { + self.set_display_name = Some(display_name); + self + } + + /// Get the display name to be set. + #[must_use] + pub fn display_name_to_set(&self) -> Option<&str> { + self.set_display_name.as_deref() } /// The ID of the user to provision. diff --git a/crates/tasks/src/matrix.rs b/crates/tasks/src/matrix.rs index b6158e45..649f297d 100644 --- a/crates/tasks/src/matrix.rs +++ b/crates/tasks/src/matrix.rs @@ -69,7 +69,12 @@ async fn provision_user( repo.cancel().await?; - let request = ProvisionRequest::new(mxid.clone(), user.sub.clone()).set_emails(emails); + let mut request = ProvisionRequest::new(mxid.clone(), user.sub.clone()).set_emails(emails); + + if let Some(display_name) = job.display_name_to_set() { + request = request.set_displayname(display_name.to_owned()); + } + let created = matrix.provision_user(&request).await?; if created { diff --git a/templates/pages/upstream_oauth2/do_register.html b/templates/pages/upstream_oauth2/do_register.html index bb148f52..84e4f7e2 100644 --- a/templates/pages/upstream_oauth2/do_register.html +++ b/templates/pages/upstream_oauth2/do_register.html @@ -21,7 +21,11 @@ limitations under the License.

- Choose your username + {% if force_localpart %} + Create a new account + {% else %} + Choose your username + {% endif %}

@@ -34,7 +38,7 @@ limitations under the License. {% if force_email %} Will import the following email address {% else %} - + {% endif %}
@@ -48,7 +52,7 @@ limitations under the License. {% if force_display_name %} Will import the following display name {% else %} - + {% endif %}