1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-07 22:41:18 +03:00

Run the registration policy on upstream OAuth registration

This commit is contained in:
Quentin Gliech
2023-08-30 18:36:53 +02:00
parent 7fcd022eea
commit 23571e87ea
14 changed files with 207 additions and 41 deletions

View File

@ -24,6 +24,7 @@ pub struct FancyError {
}
impl FancyError {
#[must_use]
pub fn new(context: ErrorContext) -> Self {
Self { context }
}

View File

@ -21,10 +21,11 @@ use hyper::StatusCode;
use mas_axum_utils::{
cookies::CookieJar,
csrf::{CsrfExt, ProtectedForm},
SessionInfoExt,
FancyError, SessionInfoExt,
};
use mas_data_model::{UpstreamOAuthProviderImportPreference, User};
use mas_jose::jwt::Jwt;
use mas_policy::Policy;
use mas_storage::{
job::{JobRepositoryExt, ProvisionUserJob},
upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository},
@ -32,7 +33,8 @@ use mas_storage::{
BoxClock, BoxRepository, BoxRng, RepositoryAccess,
};
use mas_templates::{
TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister, UpstreamSuggestLink,
ErrorContext, TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister,
UpstreamSuggestLink,
};
use serde::Deserialize;
use thiserror::Error;
@ -76,6 +78,11 @@ pub(crate) enum RouteError {
#[error("Missing username")]
MissingUsername,
#[error("Policy violation: {violations:?}")]
PolicyViolation {
violations: Vec<mas_policy::Violation>,
},
#[error(transparent)]
Internal(Box<dyn std::error::Error>),
}
@ -84,6 +91,7 @@ impl_from_error_for_route!(mas_templates::TemplateError);
impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError);
impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_policy::EvaluationError);
impl_from_error_for_route!(mas_jose::jwt::JwtDecodeError);
impl IntoResponse for RouteError {
@ -91,6 +99,16 @@ impl IntoResponse for RouteError {
sentry::capture_error(&self);
match self {
Self::LinkNotFound => (StatusCode::NOT_FOUND, "Link not found").into_response(),
Self::PolicyViolation { violations } => {
let details = violations.iter().map(|v| v.msg.clone()).collect::<Vec<_>>();
let details = details.join("\n");
let ctx = ErrorContext::new()
.with_description(
"Account registration denied because of policy violation".to_owned(),
)
.with_details(details);
FancyError::new(ctx).into_response()
}
Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
e => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
}
@ -358,6 +376,7 @@ pub(crate) async fn post(
mut repo: BoxRepository,
cookie_jar: CookieJar,
user_agent: Option<TypedHeader<headers::UserAgent>>,
mut policy: Policy,
Path(link_id): Path<Ulid>,
Form(form): Form<ProtectedForm<FormData>>,
) -> Result<impl IntoResponse, RouteError> {
@ -478,6 +497,16 @@ pub(crate) async fn post(
let username = username.ok_or(RouteError::MissingUsername)?;
// Policy check
let res = policy
.evaluate_upstream_oauth_register(&username, email.as_deref())
.await?;
if !res.valid() {
return Err(RouteError::PolicyViolation {
violations: res.violations,
});
}
// Now we can create the user
let user = repo.user().add(&mut rng, &clock, username).await?;

View File

@ -107,12 +107,9 @@ pub(crate) async fn post(
let user_email = if let Some(user_email) = existing_user_email {
user_email
} else {
let user_email = repo
.user_email()
repo.user_email()
.add(&mut rng, &clock, &session.user, form.email)
.await?;
user_email
.await?
};
// If the email was not confirmed, send a confirmation email & redirect to the

View File

@ -251,6 +251,31 @@ impl Policy {
Ok(res)
}
#[tracing::instrument(
name = "policy.evaluate.upstream_oauth_register",
skip_all,
fields(
input.registration_method = "password",
input.user.username = username,
input.user.email = email,
),
err,
)]
pub async fn evaluate_upstream_oauth_register(
&mut self,
username: &str,
email: Option<&str>,
) -> Result<EvaluationResult, EvaluationError> {
let input = RegisterInput::UpstreamOAuth2 { username, email };
let [res]: [EvaluationResult; 1] = self
.instance
.evaluate(&mut self.store, &self.entrypoints.register, &input)
.await?;
Ok(res)
}
#[tracing::instrument(skip(self))]
pub async fn evaluate_client_registration(
&mut self,

View File

@ -56,14 +56,23 @@ impl EvaluationResult {
/// Input for the user registration policy.
#[derive(Serialize, Debug)]
#[serde(tag = "registration_method", rename_all = "snake_case")]
#[serde(tag = "registration_method")]
#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]
pub enum RegisterInput<'a> {
#[serde(rename = "password")]
Password {
username: &'a str,
password: &'a str,
email: &'a str,
},
#[serde(rename = "upstream-oauth2")]
UpstreamOAuth2 {
username: &'a str,
#[serde(skip_serializing_if = "Option::is_none")]
email: Option<&'a str>,
},
}
/// Input for the client registration policy.