1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Ground work to import upstream OIDC claims on registration.

This commit is contained in:
Quentin Gliech
2023-06-21 18:09:46 +02:00
parent 00ea31b9c9
commit c183830489
11 changed files with 481 additions and 161 deletions

View File

@ -47,7 +47,8 @@ pub use self::{
}, },
upstream_oauth2::{ upstream_oauth2::{
UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState, UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState,
UpstreamOAuthLink, UpstreamOAuthProvider, UpstreamOAuthLink, UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports,
UpstreamOAuthProviderImportPreference,
}, },
users::{ users::{
Authentication, BrowserSession, Password, User, UserEmail, UserEmailVerification, Authentication, BrowserSession, Password, User, UserEmail, UserEmailVerification,

View File

@ -18,6 +18,9 @@ mod session;
pub use self::{ pub use self::{
link::UpstreamOAuthLink, link::UpstreamOAuthLink,
provider::UpstreamOAuthProvider, provider::{
ClaimsImports as UpstreamOAuthProviderClaimsImports,
ImportPreference as UpstreamOAuthProviderImportPreference, UpstreamOAuthProvider,
},
session::{UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState}, session::{UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState},
}; };

View File

@ -15,7 +15,7 @@
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use oauth2_types::scope::Scope; use oauth2_types::scope::Scope;
use serde::Serialize; use serde::{Deserialize, Serialize};
use ulid::Ulid; use ulid::Ulid;
#[derive(Debug, Clone, PartialEq, Eq, Serialize)] #[derive(Debug, Clone, PartialEq, Eq, Serialize)]
@ -28,4 +28,62 @@ pub struct UpstreamOAuthProvider {
pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>, pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
pub token_endpoint_auth_method: OAuthClientAuthenticationMethod, pub token_endpoint_auth_method: OAuthClientAuthenticationMethod,
pub created_at: DateTime<Utc>, pub created_at: DateTime<Utc>,
pub claims_imports: ClaimsImports,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct ClaimsImports {
#[serde(default)]
pub localpart: ImportPreference,
#[serde(default)]
pub displayname: ImportPreference,
#[serde(default)]
pub email: ImportPreference,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct ImportPreference {
#[serde(default)]
pub action: ImportAction,
}
impl std::ops::Deref for ImportPreference {
type Target = ImportAction;
fn deref(&self) -> &Self::Target {
&self.action
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum ImportAction {
/// Ignore the claim
#[default]
Ignore,
/// Suggest the claim value, but allow the user to change it
Suggest,
/// Force the claim value, but don't fail if it is missing
Force,
/// Force the claim value, and fail if it is missing
Require,
}
impl ImportAction {
pub fn is_forced(&self) -> bool {
matches!(self, Self::Force | Self::Require)
}
pub fn ignore(&self) -> bool {
matches!(self, Self::Ignore)
}
pub fn is_required(&self) -> bool {
matches!(self, Self::Require)
}
} }

View File

@ -143,7 +143,9 @@ impl UpstreamSessions {
) -> Result<(Ulid, Option<&PostAuthAction>), UpstreamSessionNotFound> { ) -> Result<(Ulid, Option<&PostAuthAction>), UpstreamSessionNotFound> {
self.0 self.0
.iter() .iter()
.find(|p| p.link == Some(link_id)) .filter(|p| p.link == Some(link_id))
// Find the session with the highest ID, aka. the most recent one
.reduce(|a, b| if a.session > b.session { a } else { b })
.map(|p| (p.session, p.post_auth_action.as_ref())) .map(|p| (p.session, p.post_auth_action.as_ref()))
.ok_or(UpstreamSessionNotFound) .ok_or(UpstreamSessionNotFound)
} }

View File

@ -23,6 +23,8 @@ use mas_axum_utils::{
csrf::{CsrfExt, ProtectedForm}, csrf::{CsrfExt, ProtectedForm},
SessionInfoExt, SessionInfoExt,
}; };
use mas_data_model::UpstreamOAuthProviderImportPreference;
use mas_jose::jwt::Jwt;
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_storage::{ use mas_storage::{
job::{JobRepositoryExt, ProvisionUserJob}, job::{JobRepositoryExt, ProvisionUserJob},
@ -55,6 +57,14 @@ pub(crate) enum RouteError {
#[error("User not found")] #[error("User not found")]
UserNotFound, UserNotFound,
/// Couldn't find upstream provider
#[error("Upstream provider not found")]
ProviderNotFound,
/// Required claim was missing in id_token
#[error("Required claim {0:?} missing from the upstream provider's response")]
RequiredClaimMissing(&'static str),
/// Session was already consumed /// Session was already consumed
#[error("Session already consumed")] #[error("Session already consumed")]
SessionConsumed, SessionConsumed,
@ -73,6 +83,7 @@ impl_from_error_for_route!(mas_templates::TemplateError);
impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError); impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError);
impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound); impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_jose::jwt::JwtDecodeError);
impl IntoResponse for RouteError { impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response { fn into_response(self) -> axum::response::Response {
@ -85,6 +96,51 @@ impl IntoResponse for RouteError {
} }
} }
#[derive(Deserialize, Default)]
struct StandardClaims {
name: Option<String>,
email: Option<String>,
preferred_username: Option<String>,
}
/// Utility function to import a claim from the upstream provider's response,
/// based on the preference for that attribute.
///
/// # Parameters
///
/// * `name` - The name of the claim, for error reporting
/// * `value` - The value of the claim, if present
/// * `preference` - The preference for this claim
/// * `run` - A function to run if the claim is present. The first argument is
/// the value of the claim, and the second is whether the claim is forced to
/// be used.
///
/// # Errors
///
/// Returns an error if the claim is required but missing.
fn import_claim(
name: &'static str,
value: Option<String>,
preference: &UpstreamOAuthProviderImportPreference,
mut run: impl FnMut(String, bool) -> (),
) -> Result<(), RouteError> {
// If this claim is ignored, we don't need to do anything.
if preference.ignore() {
return Ok(());
}
// If this claim is required and missing, we can't continue.
if value.is_none() && preference.is_required() {
return Err(RouteError::RequiredClaimMissing(name));
}
if let Some(value) = value {
run(value, preference.is_forced());
}
Ok(())
}
#[derive(Deserialize)] #[derive(Deserialize)]
#[serde(rename_all = "lowercase", tag = "action")] #[serde(rename_all = "lowercase", tag = "action")]
pub(crate) enum FormData { pub(crate) enum FormData {
@ -206,7 +262,51 @@ pub(crate) async fn get(
(None, None) => { (None, None) => {
// Session not linked and used not logged in: suggest creating an // Session not linked and used not logged in: suggest creating an
// account or logging in an existing user // account or logging in an existing user
let ctx = UpstreamRegister::new(&link).with_csrf(csrf_token.form_value()); let id_token = upstream_session
.id_token()
.map(Jwt::<'_, StandardClaims>::try_from)
.transpose()?;
let provider = repo
.upstream_oauth_provider()
.lookup(link.provider_id)
.await?
.ok_or(RouteError::ProviderNotFound)?;
let payload = id_token
.map(|id_token| id_token.into_parts().1)
.unwrap_or_default();
let mut ctx = UpstreamRegister::new(&link);
import_claim(
"name",
payload.name,
&provider.claims_imports.displayname,
|value, force| {
ctx.set_display_name(value, force);
},
)?;
import_claim(
"email",
payload.email,
&provider.claims_imports.email,
|value, force| {
ctx.set_email(value, force);
},
)?;
import_claim(
"username",
payload.preferred_username,
&provider.claims_imports.localpart,
|value, force| {
ctx.set_localpart(value, force);
},
)?;
let ctx = ctx.with_csrf(csrf_token.form_value());
templates.render_upstream_oauth2_do_register(&ctx).await? templates.render_upstream_oauth2_do_register(&ctx).await?
} }

View File

@ -0,0 +1,19 @@
-- Copyright 2023 The Matrix.org Foundation C.I.C.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
ALTER TABLE upstream_oauth_providers
ADD COLUMN claims_imports
JSONB
NOT NULL
DEFAULT '{}';

View File

@ -102,66 +102,6 @@
}, },
"query": "\n SELECT user_id\n , username\n , primary_user_email_id\n , created_at\n FROM users\n WHERE user_id = $1\n " "query": "\n SELECT user_id\n , username\n , primary_user_email_id\n , created_at\n FROM users\n WHERE user_id = $1\n "
}, },
"154e2e4488ff87e09163698750b56a43127cee4e1392785416a586d40a4d9b21": {
"describe": {
"columns": [
{
"name": "upstream_oauth_provider_id",
"ordinal": 0,
"type_info": "Uuid"
},
{
"name": "issuer",
"ordinal": 1,
"type_info": "Text"
},
{
"name": "scope",
"ordinal": 2,
"type_info": "Text"
},
{
"name": "client_id",
"ordinal": 3,
"type_info": "Text"
},
{
"name": "encrypted_client_secret",
"ordinal": 4,
"type_info": "Text"
},
{
"name": "token_endpoint_signing_alg",
"ordinal": 5,
"type_info": "Text"
},
{
"name": "token_endpoint_auth_method",
"ordinal": 6,
"type_info": "Text"
},
{
"name": "created_at",
"ordinal": 7,
"type_info": "Timestamptz"
}
],
"nullable": [
false,
false,
false,
false,
true,
true,
false,
false
],
"parameters": {
"Left": []
}
},
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at\n FROM upstream_oauth_providers\n "
},
"1a8701f5672de052bb766933f60b93249acc7237b996e8b93cd61b9f69c902ff": { "1a8701f5672de052bb766933f60b93249acc7237b996e8b93cd61b9f69c902ff": {
"describe": { "describe": {
"columns": [], "columns": [],
@ -225,25 +165,6 @@
}, },
"query": "\n SELECT user_email_confirmation_code_id\n , user_email_id\n , code\n , created_at\n , expires_at\n , consumed_at\n FROM user_email_confirmation_codes\n WHERE code = $1\n AND user_email_id = $2\n " "query": "\n SELECT user_email_confirmation_code_id\n , user_email_id\n , code\n , created_at\n , expires_at\n , consumed_at\n FROM user_email_confirmation_codes\n WHERE code = $1\n AND user_email_id = $2\n "
}, },
"1ee5cecfafd4726a4ebc08da8a34c09178e6e1e072581c8fca9d3d76967792cb": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid",
"Text",
"Text",
"Text",
"Text",
"Text",
"Text",
"Timestamptz"
]
}
},
"query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)\n "
},
"1f6297fb323e9f2fbfa1c9e3225c0b3037c8c4714533a6240c62275332aa58dc": { "1f6297fb323e9f2fbfa1c9e3225c0b3037c8c4714533a6240c62275332aa58dc": {
"describe": { "describe": {
"columns": [], "columns": [],
@ -800,6 +721,26 @@
}, },
"query": "\n INSERT INTO upstream_oauth_links (\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n ) VALUES ($1, $2, NULL, $3, $4)\n " "query": "\n INSERT INTO upstream_oauth_links (\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n ) VALUES ($1, $2, NULL, $3, $4)\n "
}, },
"6021c1b9e17b0b2e8b511888f8c6be00683ba0635a13eb7fcd403d3d4a3f90db": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid",
"Text",
"Text",
"Text",
"Text",
"Text",
"Text",
"Timestamptz",
"Jsonb"
]
}
},
"query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n created_at,\n claims_imports\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)\n "
},
"64e6ea47c2e877c1ebe4338d64d9ad8a6c1c777d1daea024b8ca2e7f0dd75b0f": { "64e6ea47c2e877c1ebe4338d64d9ad8a6c1c777d1daea024b8ca2e7f0dd75b0f": {
"describe": { "describe": {
"columns": [], "columns": [],
@ -817,6 +758,74 @@
}, },
"query": "\n INSERT INTO upstream_oauth_authorization_sessions (\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n state,\n code_challenge_verifier,\n nonce,\n created_at,\n completed_at,\n consumed_at,\n id_token\n ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL)\n " "query": "\n INSERT INTO upstream_oauth_authorization_sessions (\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n state,\n code_challenge_verifier,\n nonce,\n created_at,\n completed_at,\n consumed_at,\n id_token\n ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL)\n "
}, },
"6733c54a8d9ed93a760f365a9362fdb0f77340d7a4df642a2942174aba2c6502": {
"describe": {
"columns": [
{
"name": "upstream_oauth_provider_id",
"ordinal": 0,
"type_info": "Uuid"
},
{
"name": "issuer",
"ordinal": 1,
"type_info": "Text"
},
{
"name": "scope",
"ordinal": 2,
"type_info": "Text"
},
{
"name": "client_id",
"ordinal": 3,
"type_info": "Text"
},
{
"name": "encrypted_client_secret",
"ordinal": 4,
"type_info": "Text"
},
{
"name": "token_endpoint_signing_alg",
"ordinal": 5,
"type_info": "Text"
},
{
"name": "token_endpoint_auth_method",
"ordinal": 6,
"type_info": "Text"
},
{
"name": "created_at",
"ordinal": 7,
"type_info": "Timestamptz"
},
{
"name": "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
"ordinal": 8,
"type_info": "Jsonb"
}
],
"nullable": [
false,
false,
false,
false,
true,
true,
false,
false,
false
],
"parameters": {
"Left": [
"Uuid"
]
}
},
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n claims_imports as \"claims_imports: Json<UpstreamOAuthProviderClaimsImports>\"\n FROM upstream_oauth_providers\n WHERE upstream_oauth_provider_id = $1\n "
},
"67ab838035946ddc15b43dd2f79d10b233d07e863b3a5c776c5db97cff263c8c": { "67ab838035946ddc15b43dd2f79d10b233d07e863b3a5c776c5db97cff263c8c": {
"describe": { "describe": {
"columns": [ "columns": [
@ -1410,68 +1419,6 @@
}, },
"query": "\n SELECT scope_token\n FROM oauth2_consents\n WHERE user_id = $1 AND oauth2_client_id = $2\n " "query": "\n SELECT scope_token\n FROM oauth2_consents\n WHERE user_id = $1 AND oauth2_client_id = $2\n "
}, },
"8f7a9fb1f24c24f8dbc3c193df2a742c9ac730ab958587b67297de2d4b843863": {
"describe": {
"columns": [
{
"name": "upstream_oauth_provider_id",
"ordinal": 0,
"type_info": "Uuid"
},
{
"name": "issuer",
"ordinal": 1,
"type_info": "Text"
},
{
"name": "scope",
"ordinal": 2,
"type_info": "Text"
},
{
"name": "client_id",
"ordinal": 3,
"type_info": "Text"
},
{
"name": "encrypted_client_secret",
"ordinal": 4,
"type_info": "Text"
},
{
"name": "token_endpoint_signing_alg",
"ordinal": 5,
"type_info": "Text"
},
{
"name": "token_endpoint_auth_method",
"ordinal": 6,
"type_info": "Text"
},
{
"name": "created_at",
"ordinal": 7,
"type_info": "Timestamptz"
}
],
"nullable": [
false,
false,
false,
false,
true,
true,
false,
false
],
"parameters": {
"Left": [
"Uuid"
]
}
},
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at\n FROM upstream_oauth_providers\n WHERE upstream_oauth_provider_id = $1\n "
},
"90b5512c0c9dc3b3eb6500056cc72f9993216d9b553c2e33a7edec26ffb0fc59": { "90b5512c0c9dc3b3eb6500056cc72f9993216d9b553c2e33a7edec26ffb0fc59": {
"describe": { "describe": {
"columns": [], "columns": [],
@ -1717,6 +1664,72 @@
}, },
"query": "\n UPDATE compat_sessions cs\n SET finished_at = $2\n WHERE compat_session_id = $1\n " "query": "\n UPDATE compat_sessions cs\n SET finished_at = $2\n WHERE compat_session_id = $1\n "
}, },
"af65441068530b68826561d4308e15923ba6c6882ded4860ebde4a7641359abb": {
"describe": {
"columns": [
{
"name": "upstream_oauth_provider_id",
"ordinal": 0,
"type_info": "Uuid"
},
{
"name": "issuer",
"ordinal": 1,
"type_info": "Text"
},
{
"name": "scope",
"ordinal": 2,
"type_info": "Text"
},
{
"name": "client_id",
"ordinal": 3,
"type_info": "Text"
},
{
"name": "encrypted_client_secret",
"ordinal": 4,
"type_info": "Text"
},
{
"name": "token_endpoint_signing_alg",
"ordinal": 5,
"type_info": "Text"
},
{
"name": "token_endpoint_auth_method",
"ordinal": 6,
"type_info": "Text"
},
{
"name": "created_at",
"ordinal": 7,
"type_info": "Timestamptz"
},
{
"name": "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
"ordinal": 8,
"type_info": "Jsonb"
}
],
"nullable": [
false,
false,
false,
false,
true,
true,
false,
false,
false
],
"parameters": {
"Left": []
}
},
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n claims_imports as \"claims_imports: Json<UpstreamOAuthProviderClaimsImports>\"\n FROM upstream_oauth_providers\n "
},
"afa86e79e3de2a83265cb0db8549d378a2f11b2a27bbd86d60558318c87eb698": { "afa86e79e3de2a83265cb0db8549d378a2f11b2a27bbd86d60558318c87eb698": {
"describe": { "describe": {
"columns": [], "columns": [],

View File

@ -14,12 +14,12 @@
use async_trait::async_trait; use async_trait::async_trait;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use mas_data_model::UpstreamOAuthProvider; use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports};
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, Clock, Page, Pagination}; use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, Clock, Page, Pagination};
use oauth2_types::scope::Scope; use oauth2_types::scope::Scope;
use rand::RngCore; use rand::RngCore;
use sqlx::{PgConnection, QueryBuilder}; use sqlx::{types::Json, PgConnection, QueryBuilder};
use ulid::Ulid; use ulid::Ulid;
use uuid::Uuid; use uuid::Uuid;
@ -52,6 +52,7 @@ struct ProviderLookup {
token_endpoint_signing_alg: Option<String>, token_endpoint_signing_alg: Option<String>,
token_endpoint_auth_method: String, token_endpoint_auth_method: String,
created_at: DateTime<Utc>, created_at: DateTime<Utc>,
claims_imports: Json<UpstreamOAuthProviderClaimsImports>,
} }
impl TryFrom<ProviderLookup> for UpstreamOAuthProvider { impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
@ -90,6 +91,7 @@ impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
token_endpoint_auth_method, token_endpoint_auth_method,
token_endpoint_signing_alg, token_endpoint_signing_alg,
created_at: value.created_at, created_at: value.created_at,
claims_imports: value.claims_imports.0,
}) })
} }
} }
@ -119,7 +121,8 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
encrypted_client_secret, encrypted_client_secret,
token_endpoint_signing_alg, token_endpoint_signing_alg,
token_endpoint_auth_method, token_endpoint_auth_method,
created_at created_at,
claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>"
FROM upstream_oauth_providers FROM upstream_oauth_providers
WHERE upstream_oauth_provider_id = $1 WHERE upstream_oauth_provider_id = $1
"#, "#,
@ -165,6 +168,8 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
let id = Ulid::from_datetime_with_source(created_at.into(), rng); let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("upstream_oauth_provider.id", tracing::field::display(id)); tracing::Span::current().record("upstream_oauth_provider.id", tracing::field::display(id));
let claims_imports = UpstreamOAuthProviderClaimsImports::default();
sqlx::query!( sqlx::query!(
r#" r#"
INSERT INTO upstream_oauth_providers ( INSERT INTO upstream_oauth_providers (
@ -175,8 +180,9 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
token_endpoint_signing_alg, token_endpoint_signing_alg,
client_id, client_id,
encrypted_client_secret, encrypted_client_secret,
created_at created_at,
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) claims_imports
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
"#, "#,
Uuid::from(id), Uuid::from(id),
&issuer, &issuer,
@ -186,6 +192,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
&client_id, &client_id,
encrypted_client_secret.as_deref(), encrypted_client_secret.as_deref(),
created_at, created_at,
Json(&claims_imports) as _,
) )
.traced() .traced()
.execute(&mut *self.conn) .execute(&mut *self.conn)
@ -200,6 +207,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
token_endpoint_signing_alg, token_endpoint_signing_alg,
token_endpoint_auth_method, token_endpoint_auth_method,
created_at, created_at,
claims_imports,
}) })
} }
@ -225,7 +233,8 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
encrypted_client_secret, encrypted_client_secret,
token_endpoint_signing_alg, token_endpoint_signing_alg,
token_endpoint_auth_method, token_endpoint_auth_method,
created_at created_at,
claims_imports
FROM upstream_oauth_providers FROM upstream_oauth_providers
WHERE 1 = 1 WHERE 1 = 1
"#, "#,
@ -263,7 +272,8 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
encrypted_client_secret, encrypted_client_secret,
token_endpoint_signing_alg, token_endpoint_signing_alg,
token_endpoint_auth_method, token_endpoint_auth_method,
created_at created_at,
claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>"
FROM upstream_oauth_providers FROM upstream_oauth_providers
"#, "#,
) )

View File

@ -856,6 +856,12 @@ impl TemplateContext for UpstreamSuggestLink {
#[derive(Serialize)] #[derive(Serialize)]
pub struct UpstreamRegister { pub struct UpstreamRegister {
login_link: String, login_link: String,
suggested_localpart: Option<String>,
force_localpart: bool,
suggested_display_name: Option<String>,
force_display_name: bool,
suggested_email: Option<String>,
force_email: bool,
} }
impl UpstreamRegister { impl UpstreamRegister {
@ -865,12 +871,38 @@ impl UpstreamRegister {
Self::for_link_id(link.id) Self::for_link_id(link.id)
} }
/// Set the suggested localpart
pub fn set_localpart(&mut self, localpart: String, force: bool) {
self.suggested_localpart = Some(localpart);
self.force_localpart = force;
}
/// Set the suggested display name
pub fn set_display_name(&mut self, display_name: String, force: bool) {
self.suggested_display_name = Some(display_name);
self.force_display_name = force;
}
/// Set the suggested email
pub fn set_email(&mut self, email: String, force: bool) {
self.suggested_email = Some(email);
self.force_email = force;
}
fn for_link_id(id: Ulid) -> Self { fn for_link_id(id: Ulid) -> Self {
let login_link = mas_router::Login::and_link_upstream(id) let login_link = mas_router::Login::and_link_upstream(id)
.relative_url() .relative_url()
.into(); .into();
Self { login_link } Self {
login_link,
suggested_localpart: None,
force_localpart: false,
suggested_display_name: None,
force_display_name: false,
suggested_email: None,
force_email: false,
}
} }
} }

View File

@ -50,14 +50,68 @@ limitations under the License.
<a class="{{ self::outline_class() }} {{ class }}" href="{{ href }}">{{ text }}</a> <a class="{{ self::outline_class() }} {{ class }}" href="{{ href }}">{{ text }}</a>
{% endmacro %} {% endmacro %}
{% macro button(text, name="", type="submit", class="", value="") %} {% macro button(
<button name="{{ name }}" value="{{ value }}" type="{{ type }}" class="{{ self::plain_class() }} {{ class }}">{{ text }}</button> text,
name="",
type="submit",
class="",
value="",
disabled=False,
autocomplete=False,
autocorrect=False,
autocapitalize=False) %}
<button
name="{{ name }}"
value="{{ value }}"
type="{{ type }}"
{% if disabled %}disabled{% endif %}
class="{{ self::plain_class() }} {{ class }}"
{% if autocapitalize %}autocapitilize="{{ autocapitilize }}"{% endif %}
{% if autocomplete %}autocomplete="{{ autocomplete }}"{% endif %}
{% if autocorrect %}autocorrect="{{ autocorrect }}"{% endif %}
>{{ text }}</button>
{% endmacro %} {% endmacro %}
{% macro button_text(text, name="", type="submit", class="", value="") %} {% macro button_text(
<button name="{{ name }}" value="{{ value }}" type="{{ type }}" class="{{ self::text_class() }} {{ class }}">{{ text }}</button> text,
name="",
type="submit",
class="",
value="",
disabled=False,
autocomplete=False,
autocorrect=False,
autocapitalize=False) %}
<button
name="{{ name }}"
value="{{ value }}"
type="{{ type }}"
{% if disabled %}disabled{% endif %}
class="{{ self::text_class() }} {{ class }}"
{% if autocapitalize %}autocapitilize="{{ autocapitilize }}"{% endif %}
{% if autocomplete %}autocomplete="{{ autocomplete }}"{% endif %}
{% if autocorrect %}autocorrect="{{ autocorrect }}"{% endif %}
>{{ text }}</button>
{% endmacro %} {% endmacro %}
{% macro button_outline(text, name="", type="submit", class="", value="") %} {% macro button_outline(
<button name="{{ name }}" value="{{ value }}" type="{{ type }}" class="{{ self::outline_class() }} {{ class }}">{{ text }}</button> text,
name="",
type="submit",
class="",
value="",
disabled=False,
autocomplete=False,
autocorrect=False,
autocapitalize=False) %}
<button
name="{{ name }}"
value="{{ value }}"
type="{{ type }}"
{% if disabled %}disabled{% endif %}
{% if autocapitalize %}autocapitilize="{{ autocapitilize }}"{% endif %}
{% if autocomplete %}autocomplete="{{ autocomplete }}"{% endif %}
{% if autocorrect %}autocorrect="{{ autocorrect }}"{% endif %}
class="{{ self::outline_class() }} {{ class }}"
>{{ text }}</button>
{% endmacro %} {% endmacro %}

View File

@ -26,7 +26,35 @@ limitations under the License.
<input type="hidden" name="csrf" value="{{ csrf_token }}" /> <input type="hidden" name="csrf" value="{{ csrf_token }}" />
<input type="hidden" name="action" value="register" /> <input type="hidden" name="action" value="register" />
{{ field::input(label="Username", name="username", autocomplete="username", autocorrect="off", autocapitalize="none") }} {{ field::input(label="Username", name="username", autocomplete="username", autocorrect="off", autocapitalize="none", disabled=force_localpart, value=suggested_localpart) }}
{% if suggested_email %}
<div class="rounded-lg bg-grey-25 dark:bg-grey-450 p-4">
<div class="font-medium">
{% if force_email %}
Will import the following email address
{% else %}
<input type="checkbox" name="import_email" id="import_email" value="1" checked="checked" />
<label for="import_email">Import email address</label>
{% endif %}
</div>
<div class="font-mono">{{ suggested_email }}</div>
</div>
{% endif %}
{% if suggested_display_name %}
<div class="rounded-lg bg-grey-25 dark:bg-grey-450 p-4">
<div class="font-medium">
{% if force_display_name %}
Will import the following display name
{% else %}
<input type="checkbox" name="import_display_name" id="import_display_name" value="1" checked="checked" />
<label for="import_display_name">Import display name</label>
{% endif %}
</div>
<div class="font-mono">{{ suggested_display_name }}</div>
</div>
{% endif %}
{{ button::button(text="Create a new account") }} {{ button::button(text="Create a new account") }}
</form> </form>