You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-07 22:41:18 +03:00
Run the registration policy on upstream OAuth registration
This commit is contained in:
@ -24,6 +24,7 @@ pub struct FancyError {
|
||||
}
|
||||
|
||||
impl FancyError {
|
||||
#[must_use]
|
||||
pub fn new(context: ErrorContext) -> Self {
|
||||
Self { context }
|
||||
}
|
||||
|
@ -21,10 +21,11 @@ use hyper::StatusCode;
|
||||
use mas_axum_utils::{
|
||||
cookies::CookieJar,
|
||||
csrf::{CsrfExt, ProtectedForm},
|
||||
SessionInfoExt,
|
||||
FancyError, SessionInfoExt,
|
||||
};
|
||||
use mas_data_model::{UpstreamOAuthProviderImportPreference, User};
|
||||
use mas_jose::jwt::Jwt;
|
||||
use mas_policy::Policy;
|
||||
use mas_storage::{
|
||||
job::{JobRepositoryExt, ProvisionUserJob},
|
||||
upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository},
|
||||
@ -32,7 +33,8 @@ use mas_storage::{
|
||||
BoxClock, BoxRepository, BoxRng, RepositoryAccess,
|
||||
};
|
||||
use mas_templates::{
|
||||
TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister, UpstreamSuggestLink,
|
||||
ErrorContext, TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister,
|
||||
UpstreamSuggestLink,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use thiserror::Error;
|
||||
@ -76,6 +78,11 @@ pub(crate) enum RouteError {
|
||||
#[error("Missing username")]
|
||||
MissingUsername,
|
||||
|
||||
#[error("Policy violation: {violations:?}")]
|
||||
PolicyViolation {
|
||||
violations: Vec<mas_policy::Violation>,
|
||||
},
|
||||
|
||||
#[error(transparent)]
|
||||
Internal(Box<dyn std::error::Error>),
|
||||
}
|
||||
@ -84,6 +91,7 @@ impl_from_error_for_route!(mas_templates::TemplateError);
|
||||
impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError);
|
||||
impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
|
||||
impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl_from_error_for_route!(mas_policy::EvaluationError);
|
||||
impl_from_error_for_route!(mas_jose::jwt::JwtDecodeError);
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
@ -91,6 +99,16 @@ impl IntoResponse for RouteError {
|
||||
sentry::capture_error(&self);
|
||||
match self {
|
||||
Self::LinkNotFound => (StatusCode::NOT_FOUND, "Link not found").into_response(),
|
||||
Self::PolicyViolation { violations } => {
|
||||
let details = violations.iter().map(|v| v.msg.clone()).collect::<Vec<_>>();
|
||||
let details = details.join("\n");
|
||||
let ctx = ErrorContext::new()
|
||||
.with_description(
|
||||
"Account registration denied because of policy violation".to_owned(),
|
||||
)
|
||||
.with_details(details);
|
||||
FancyError::new(ctx).into_response()
|
||||
}
|
||||
Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
|
||||
e => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
|
||||
}
|
||||
@ -358,6 +376,7 @@ pub(crate) async fn post(
|
||||
mut repo: BoxRepository,
|
||||
cookie_jar: CookieJar,
|
||||
user_agent: Option<TypedHeader<headers::UserAgent>>,
|
||||
mut policy: Policy,
|
||||
Path(link_id): Path<Ulid>,
|
||||
Form(form): Form<ProtectedForm<FormData>>,
|
||||
) -> Result<impl IntoResponse, RouteError> {
|
||||
@ -478,6 +497,16 @@ pub(crate) async fn post(
|
||||
|
||||
let username = username.ok_or(RouteError::MissingUsername)?;
|
||||
|
||||
// Policy check
|
||||
let res = policy
|
||||
.evaluate_upstream_oauth_register(&username, email.as_deref())
|
||||
.await?;
|
||||
if !res.valid() {
|
||||
return Err(RouteError::PolicyViolation {
|
||||
violations: res.violations,
|
||||
});
|
||||
}
|
||||
|
||||
// Now we can create the user
|
||||
let user = repo.user().add(&mut rng, &clock, username).await?;
|
||||
|
||||
|
@ -107,12 +107,9 @@ pub(crate) async fn post(
|
||||
let user_email = if let Some(user_email) = existing_user_email {
|
||||
user_email
|
||||
} else {
|
||||
let user_email = repo
|
||||
.user_email()
|
||||
repo.user_email()
|
||||
.add(&mut rng, &clock, &session.user, form.email)
|
||||
.await?;
|
||||
|
||||
user_email
|
||||
.await?
|
||||
};
|
||||
|
||||
// If the email was not confirmed, send a confirmation email & redirect to the
|
||||
|
@ -251,6 +251,31 @@ impl Policy {
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "policy.evaluate.upstream_oauth_register",
|
||||
skip_all,
|
||||
fields(
|
||||
input.registration_method = "password",
|
||||
input.user.username = username,
|
||||
input.user.email = email,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn evaluate_upstream_oauth_register(
|
||||
&mut self,
|
||||
username: &str,
|
||||
email: Option<&str>,
|
||||
) -> Result<EvaluationResult, EvaluationError> {
|
||||
let input = RegisterInput::UpstreamOAuth2 { username, email };
|
||||
|
||||
let [res]: [EvaluationResult; 1] = self
|
||||
.instance
|
||||
.evaluate(&mut self.store, &self.entrypoints.register, &input)
|
||||
.await?;
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub async fn evaluate_client_registration(
|
||||
&mut self,
|
||||
|
@ -56,14 +56,23 @@ impl EvaluationResult {
|
||||
|
||||
/// Input for the user registration policy.
|
||||
#[derive(Serialize, Debug)]
|
||||
#[serde(tag = "registration_method", rename_all = "snake_case")]
|
||||
#[serde(tag = "registration_method")]
|
||||
#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]
|
||||
pub enum RegisterInput<'a> {
|
||||
#[serde(rename = "password")]
|
||||
Password {
|
||||
username: &'a str,
|
||||
password: &'a str,
|
||||
email: &'a str,
|
||||
},
|
||||
|
||||
#[serde(rename = "upstream-oauth2")]
|
||||
UpstreamOAuth2 {
|
||||
username: &'a str,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
email: Option<&'a str>,
|
||||
},
|
||||
}
|
||||
|
||||
/// Input for the client registration policy.
|
||||
|
Reference in New Issue
Block a user