You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-11-20 12:02:22 +03:00
Ground work to import upstream OIDC claims on registration.
This commit is contained in:
@@ -143,7 +143,9 @@ impl UpstreamSessions {
|
||||
) -> Result<(Ulid, Option<&PostAuthAction>), UpstreamSessionNotFound> {
|
||||
self.0
|
||||
.iter()
|
||||
.find(|p| p.link == Some(link_id))
|
||||
.filter(|p| p.link == Some(link_id))
|
||||
// Find the session with the highest ID, aka. the most recent one
|
||||
.reduce(|a, b| if a.session > b.session { a } else { b })
|
||||
.map(|p| (p.session, p.post_auth_action.as_ref()))
|
||||
.ok_or(UpstreamSessionNotFound)
|
||||
}
|
||||
|
||||
@@ -23,6 +23,8 @@ use mas_axum_utils::{
|
||||
csrf::{CsrfExt, ProtectedForm},
|
||||
SessionInfoExt,
|
||||
};
|
||||
use mas_data_model::UpstreamOAuthProviderImportPreference;
|
||||
use mas_jose::jwt::Jwt;
|
||||
use mas_keystore::Encrypter;
|
||||
use mas_storage::{
|
||||
job::{JobRepositoryExt, ProvisionUserJob},
|
||||
@@ -55,6 +57,14 @@ pub(crate) enum RouteError {
|
||||
#[error("User not found")]
|
||||
UserNotFound,
|
||||
|
||||
/// Couldn't find upstream provider
|
||||
#[error("Upstream provider not found")]
|
||||
ProviderNotFound,
|
||||
|
||||
/// Required claim was missing in id_token
|
||||
#[error("Required claim {0:?} missing from the upstream provider's response")]
|
||||
RequiredClaimMissing(&'static str),
|
||||
|
||||
/// Session was already consumed
|
||||
#[error("Session already consumed")]
|
||||
SessionConsumed,
|
||||
@@ -73,6 +83,7 @@ impl_from_error_for_route!(mas_templates::TemplateError);
|
||||
impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError);
|
||||
impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
|
||||
impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl_from_error_for_route!(mas_jose::jwt::JwtDecodeError);
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
@@ -85,6 +96,51 @@ impl IntoResponse for RouteError {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Default)]
|
||||
struct StandardClaims {
|
||||
name: Option<String>,
|
||||
email: Option<String>,
|
||||
preferred_username: Option<String>,
|
||||
}
|
||||
|
||||
/// Utility function to import a claim from the upstream provider's response,
|
||||
/// based on the preference for that attribute.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// * `name` - The name of the claim, for error reporting
|
||||
/// * `value` - The value of the claim, if present
|
||||
/// * `preference` - The preference for this claim
|
||||
/// * `run` - A function to run if the claim is present. The first argument is
|
||||
/// the value of the claim, and the second is whether the claim is forced to
|
||||
/// be used.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the claim is required but missing.
|
||||
fn import_claim(
|
||||
name: &'static str,
|
||||
value: Option<String>,
|
||||
preference: &UpstreamOAuthProviderImportPreference,
|
||||
mut run: impl FnMut(String, bool) -> (),
|
||||
) -> Result<(), RouteError> {
|
||||
// If this claim is ignored, we don't need to do anything.
|
||||
if preference.ignore() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// If this claim is required and missing, we can't continue.
|
||||
if value.is_none() && preference.is_required() {
|
||||
return Err(RouteError::RequiredClaimMissing(name));
|
||||
}
|
||||
|
||||
if let Some(value) = value {
|
||||
run(value, preference.is_forced());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(rename_all = "lowercase", tag = "action")]
|
||||
pub(crate) enum FormData {
|
||||
@@ -206,7 +262,51 @@ pub(crate) async fn get(
|
||||
(None, None) => {
|
||||
// Session not linked and used not logged in: suggest creating an
|
||||
// account or logging in an existing user
|
||||
let ctx = UpstreamRegister::new(&link).with_csrf(csrf_token.form_value());
|
||||
let id_token = upstream_session
|
||||
.id_token()
|
||||
.map(Jwt::<'_, StandardClaims>::try_from)
|
||||
.transpose()?;
|
||||
|
||||
let provider = repo
|
||||
.upstream_oauth_provider()
|
||||
.lookup(link.provider_id)
|
||||
.await?
|
||||
.ok_or(RouteError::ProviderNotFound)?;
|
||||
|
||||
let payload = id_token
|
||||
.map(|id_token| id_token.into_parts().1)
|
||||
.unwrap_or_default();
|
||||
|
||||
let mut ctx = UpstreamRegister::new(&link);
|
||||
|
||||
import_claim(
|
||||
"name",
|
||||
payload.name,
|
||||
&provider.claims_imports.displayname,
|
||||
|value, force| {
|
||||
ctx.set_display_name(value, force);
|
||||
},
|
||||
)?;
|
||||
|
||||
import_claim(
|
||||
"email",
|
||||
payload.email,
|
||||
&provider.claims_imports.email,
|
||||
|value, force| {
|
||||
ctx.set_email(value, force);
|
||||
},
|
||||
)?;
|
||||
|
||||
import_claim(
|
||||
"username",
|
||||
payload.preferred_username,
|
||||
&provider.claims_imports.localpart,
|
||||
|value, force| {
|
||||
ctx.set_localpart(value, force);
|
||||
},
|
||||
)?;
|
||||
|
||||
let ctx = ctx.with_csrf(csrf_token.form_value());
|
||||
|
||||
templates.render_upstream_oauth2_do_register(&ctx).await?
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user