diff --git a/Cargo.lock b/Cargo.lock index a6533dce..a8705be1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2763,6 +2763,7 @@ dependencies = [ "lettre", "mas-data-model", "mas-matrix", + "mas-policy", "mas-storage", "oauth2-types", "serde", diff --git a/crates/axum-utils/src/fancy_error.rs b/crates/axum-utils/src/fancy_error.rs index 363f423e..86e0a60f 100644 --- a/crates/axum-utils/src/fancy_error.rs +++ b/crates/axum-utils/src/fancy_error.rs @@ -23,6 +23,12 @@ pub struct FancyError { context: ErrorContext, } +impl FancyError { + pub fn new(context: ErrorContext) -> Self { + Self { context } + } +} + impl std::fmt::Display for FancyError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let code = self.context.code().unwrap_or("Internal error"); diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index f293aa10..b0014252 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -143,7 +143,7 @@ impl Options { // Listen for SIGHUP register_sighup(&templates)?; - let graphql_schema = mas_handlers::graphql_schema(&pool, conn); + let graphql_schema = mas_handlers::graphql_schema(&pool, &policy_factory, conn); let state = { let mut s = AppState { diff --git a/crates/graphql/Cargo.toml b/crates/graphql/Cargo.toml index 39087f6c..88db93e2 100644 --- a/crates/graphql/Cargo.toml +++ b/crates/graphql/Cargo.toml @@ -22,6 +22,7 @@ url.workspace = true oauth2-types = { path = "../oauth2-types" } mas-data-model = { path = "../data-model" } mas-matrix = { path = "../matrix" } +mas-policy = { path = "../policy" } mas-storage = { path = "../storage" } [[bin]] diff --git a/crates/graphql/src/mutations/user_email.rs b/crates/graphql/src/mutations/user_email.rs index 2ea7052c..ef8c124f 100644 --- a/crates/graphql/src/mutations/user_email.rs +++ b/crates/graphql/src/mutations/user_email.rs @@ -49,6 +49,8 @@ pub enum AddEmailStatus { Exists, /// The email address is invalid Invalid, + /// The email address is not allowed by the policy + Denied, } /// The payload of the `addEmail` mutation @@ -57,6 +59,9 @@ enum AddEmailPayload { Added(mas_data_model::UserEmail), Exists(mas_data_model::UserEmail), Invalid, + Denied { + violations: Vec, + }, } #[Object(use_type_description)] @@ -67,6 +72,7 @@ impl AddEmailPayload { AddEmailPayload::Added(_) => AddEmailStatus::Added, AddEmailPayload::Exists(_) => AddEmailStatus::Exists, AddEmailPayload::Invalid => AddEmailStatus::Invalid, + AddEmailPayload::Denied { .. } => AddEmailStatus::Denied, } } @@ -76,7 +82,7 @@ impl AddEmailPayload { AddEmailPayload::Added(email) | AddEmailPayload::Exists(email) => { Some(UserEmail(email.clone())) } - AddEmailPayload::Invalid => None, + AddEmailPayload::Invalid | AddEmailPayload::Denied { .. } => None, } } @@ -87,7 +93,7 @@ impl AddEmailPayload { let user_id = match self { AddEmailPayload::Added(email) | AddEmailPayload::Exists(email) => email.user_id, - AddEmailPayload::Invalid => return Ok(None), + AddEmailPayload::Invalid | AddEmailPayload::Denied { .. } => return Ok(None), }; let user = repo @@ -98,6 +104,16 @@ impl AddEmailPayload { Ok(Some(User(user))) } + + /// The list of policy violations if the email address was denied + async fn violations(&self) -> Option> { + let AddEmailPayload::Denied { violations } = self else { + return None; + }; + + let messages = violations.iter().map(|v| v.msg.clone()).collect(); + Some(messages) + } } /// The input for the `sendVerificationEmail` mutation @@ -382,6 +398,14 @@ impl UserEmailMutations { return Ok(AddEmailPayload::Invalid); } + let mut policy = state.policy().await?; + let res = policy.evaluate_email(&input.email).await?; + if !res.valid() { + return Ok(AddEmailPayload::Denied { + violations: res.violations, + }); + } + // Find an existing email address let existing_user_email = repo.user_email().find(&user, &input.email).await?; let (added, user_email) = if let Some(user_email) = existing_user_email { diff --git a/crates/graphql/src/state.rs b/crates/graphql/src/state.rs index 90b2e637..441d9a74 100644 --- a/crates/graphql/src/state.rs +++ b/crates/graphql/src/state.rs @@ -13,6 +13,7 @@ // limitations under the License. use mas_matrix::HomeserverConnection; +use mas_policy::Policy; use mas_storage::{BoxClock, BoxRepository, BoxRng, RepositoryError}; use crate::Requester; @@ -20,6 +21,7 @@ use crate::Requester; #[async_trait::async_trait] pub trait State { async fn repository(&self) -> Result; + async fn policy(&self) -> Result; fn homeserver_connection(&self) -> &dyn HomeserverConnection; fn clock(&self) -> BoxClock; fn rng(&self) -> BoxRng; diff --git a/crates/handlers/src/app_state.rs b/crates/handlers/src/app_state.rs index 772c4bf5..cf031d88 100644 --- a/crates/handlers/src/app_state.rs +++ b/crates/handlers/src/app_state.rs @@ -17,12 +17,12 @@ use std::{convert::Infallible, sync::Arc, time::Instant}; use axum::{ async_trait, extract::{FromRef, FromRequestParts}, - response::IntoResponse, + response::{IntoResponse, Response}, }; use hyper::StatusCode; use mas_axum_utils::{cookies::CookieManager, http_client_factory::HttpClientFactory}; use mas_keystore::{Encrypter, Keystore}; -use mas_policy::PolicyFactory; +use mas_policy::{Policy, PolicyFactory}; use mas_router::UrlBuilder; use mas_storage::{BoxClock, BoxRepository, BoxRng, Repository, SystemClock}; use mas_storage_pg::PgRepository; @@ -33,7 +33,6 @@ use opentelemetry::{ }; use rand::SeedableRng; use sqlx::PgPool; -use thiserror::Error; use crate::{passwords::PasswordManager, upstream_oauth2::cache::MetadataCache, MatrixHomeserver}; @@ -176,12 +175,6 @@ impl FromRef for MatrixHomeserver { } } -impl FromRef for Arc { - fn from_ref(input: &AppState) -> Self { - input.policy_factory.clone() - } -} - impl FromRef for HttpClientFactory { fn from_ref(input: &AppState) -> Self { input.http_client_factory.clone() @@ -236,19 +229,41 @@ impl FromRequestParts for BoxRng { } } -#[derive(Debug, Error)] -#[error(transparent)] -pub struct RepositoryError(#[from] mas_storage_pg::DatabaseError); +/// A simple wrapper around an error that implements [`IntoResponse`]. +pub struct ErrorWrapper(T); -impl IntoResponse for RepositoryError { - fn into_response(self) -> axum::response::Response { +impl From for ErrorWrapper { + fn from(input: T) -> Self { + Self(input) + } +} + +impl IntoResponse for ErrorWrapper +where + T: std::error::Error, +{ + fn into_response(self) -> Response { + // TODO: make this a bit more user friendly (StatusCode::INTERNAL_SERVER_ERROR, self.0.to_string()).into_response() } } +#[async_trait] +impl FromRequestParts for Policy { + type Rejection = ErrorWrapper; + + async fn from_request_parts( + _parts: &mut axum::http::request::Parts, + state: &AppState, + ) -> Result { + let policy = state.policy_factory.instantiate().await?; + Ok(policy) + } +} + #[async_trait] impl FromRequestParts for BoxRepository { - type Rejection = RepositoryError; + type Rejection = ErrorWrapper; async fn from_request_parts( _parts: &mut axum::http::request::Parts, diff --git a/crates/handlers/src/graphql/mod.rs b/crates/handlers/src/graphql/mod.rs index 89e97bb8..16fc692c 100644 --- a/crates/handlers/src/graphql/mod.rs +++ b/crates/handlers/src/graphql/mod.rs @@ -31,6 +31,7 @@ use hyper::header::CACHE_CONTROL; use mas_axum_utils::{cookies::CookieJar, FancyError, SessionInfo, SessionInfoExt}; use mas_graphql::{Requester, Schema}; use mas_matrix::HomeserverConnection; +use mas_policy::{InstantiateError, Policy, PolicyFactory}; use mas_storage::{ BoxClock, BoxRepository, BoxRng, Clock, Repository, RepositoryError, SystemClock, }; @@ -48,6 +49,7 @@ mod tests; struct GraphQLState { pool: PgPool, homeserver_connection: Arc>, + policy_factory: Arc, } #[async_trait] @@ -60,6 +62,10 @@ impl mas_graphql::State for GraphQLState { Ok(repo.map_err(RepositoryError::from_error).boxed()) } + async fn policy(&self) -> Result { + self.policy_factory.instantiate().await + } + fn homeserver_connection(&self) -> &dyn HomeserverConnection { self.homeserver_connection.as_ref() } @@ -81,10 +87,12 @@ impl mas_graphql::State for GraphQLState { #[must_use] pub fn schema( pool: &PgPool, + policy_factory: &Arc, homeserver_connection: impl HomeserverConnection + 'static, ) -> Schema { let state = GraphQLState { pool: pool.clone(), + policy_factory: Arc::clone(policy_factory), homeserver_connection: Arc::new(homeserver_connection), }; let state: mas_graphql::BoxState = Box::new(state); diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index dda27c57..cc6554eb 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -30,7 +30,7 @@ clippy::let_with_type_underscore, )] -use std::{convert::Infallible, sync::Arc, time::Duration}; +use std::{convert::Infallible, time::Duration}; use axum::{ body::{Bytes, HttpBody}, @@ -50,7 +50,7 @@ use hyper::{ use mas_axum_utils::{cookies::CookieJar, FancyError}; use mas_http::CorsLayerExt; use mas_keystore::{Encrypter, Keystore}; -use mas_policy::PolicyFactory; +use mas_policy::Policy; use mas_router::{Route, UrlBuilder}; use mas_storage::{BoxClock, BoxRepository, BoxRng}; use mas_templates::{ErrorContext, NotFoundContext, Templates}; @@ -166,12 +166,12 @@ where S: Clone + Send + Sync + 'static, Keystore: FromRef, UrlBuilder: FromRef, - Arc: FromRef, BoxRepository: FromRequestParts, Encrypter: FromRef, HttpClientFactory: FromRef, BoxClock: FromRequestParts, BoxRng: FromRequestParts, + Policy: FromRequestParts, { // All those routes are API-like, with a common CORS layer Router::new() @@ -267,7 +267,6 @@ where ::Error: std::error::Error + Send + Sync, S: Clone + Send + Sync + 'static, UrlBuilder: FromRef, - Arc: FromRef, BoxRepository: FromRequestParts, CookieJar: FromRequestParts, Encrypter: FromRef, @@ -278,6 +277,7 @@ where MetadataCache: FromRef, BoxClock: FromRequestParts, BoxRng: FromRequestParts, + Policy: FromRequestParts, { Router::new() // XXX: hard-coded redirect from /account to /account/ diff --git a/crates/handlers/src/oauth2/authorization/complete.rs b/crates/handlers/src/oauth2/authorization/complete.rs index 6c5b7b7f..48b1800c 100644 --- a/crates/handlers/src/oauth2/authorization/complete.rs +++ b/crates/handlers/src/oauth2/authorization/complete.rs @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::Arc; - use axum::{ extract::{Path, State}, response::{Html, IntoResponse, Response}, @@ -22,7 +20,7 @@ use hyper::StatusCode; use mas_axum_utils::{cookies::CookieJar, csrf::CsrfExt, SessionInfoExt}; use mas_data_model::{AuthorizationGrant, BrowserSession, Client, Device}; use mas_keystore::Keystore; -use mas_policy::{EvaluationResult, PolicyFactory}; +use mas_policy::{EvaluationResult, Policy}; use mas_router::{PostAuthAction, Route, UrlBuilder}; use mas_storage::{ oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2SessionRepository}, @@ -76,7 +74,6 @@ impl IntoResponse for RouteError { impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_templates::TemplateError); impl_from_error_for_route!(mas_policy::LoadError); -impl_from_error_for_route!(mas_policy::InstantiateError); impl_from_error_for_route!(mas_policy::EvaluationError); impl_from_error_for_route!(super::callback::IntoCallbackDestinationError); impl_from_error_for_route!(super::callback::CallbackDestinationError); @@ -90,10 +87,10 @@ impl_from_error_for_route!(super::callback::CallbackDestinationError); pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, - State(policy_factory): State>, State(templates): State, State(url_builder): State, State(key_store): State, + policy: Policy, mut repo: BoxRepository, cookie_jar: CookieJar, Path(grant_id): Path, @@ -128,7 +125,7 @@ pub(crate) async fn get( &clock, repo, key_store, - &policy_factory, + policy, url_builder, grant, &client, @@ -187,7 +184,6 @@ pub enum GrantCompletionError { impl_from_error_for_route!(GrantCompletionError: mas_storage::RepositoryError); impl_from_error_for_route!(GrantCompletionError: super::callback::IntoCallbackDestinationError); impl_from_error_for_route!(GrantCompletionError: mas_policy::LoadError); -impl_from_error_for_route!(GrantCompletionError: mas_policy::InstantiateError); impl_from_error_for_route!(GrantCompletionError: mas_policy::EvaluationError); impl_from_error_for_route!(GrantCompletionError: super::super::IdTokenSignatureError); @@ -196,7 +192,7 @@ pub(crate) async fn complete( clock: &impl Clock, mut repo: BoxRepository, key_store: Keystore, - policy_factory: &PolicyFactory, + mut policy: Policy, url_builder: UrlBuilder, grant: AuthorizationGrant, client: &Client, @@ -220,7 +216,6 @@ pub(crate) async fn complete( }; // Run through the policy - let mut policy = policy_factory.instantiate().await?; let res = policy .evaluate_authorization_grant(&grant, client, &browser_session.user) .await?; diff --git a/crates/handlers/src/oauth2/authorization/mod.rs b/crates/handlers/src/oauth2/authorization/mod.rs index 8fec59fe..1cef1ac4 100644 --- a/crates/handlers/src/oauth2/authorization/mod.rs +++ b/crates/handlers/src/oauth2/authorization/mod.rs @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::Arc; - use axum::{ extract::{Form, State}, response::{Html, IntoResponse, Response}, @@ -22,7 +20,7 @@ use hyper::StatusCode; use mas_axum_utils::{cookies::CookieJar, csrf::CsrfExt, SessionInfoExt}; use mas_data_model::{AuthorizationCode, Pkce}; use mas_keystore::Keystore; -use mas_policy::PolicyFactory; +use mas_policy::Policy; use mas_router::{PostAuthAction, Route, UrlBuilder}; use mas_storage::{ oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository}, @@ -94,7 +92,6 @@ impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_templates::TemplateError); impl_from_error_for_route!(self::callback::CallbackDestinationError); impl_from_error_for_route!(mas_policy::LoadError); -impl_from_error_for_route!(mas_policy::InstantiateError); impl_from_error_for_route!(mas_policy::EvaluationError); #[derive(Deserialize)] @@ -140,10 +137,10 @@ fn resolve_response_mode( pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, - State(policy_factory): State>, State(templates): State, State(key_store): State, State(url_builder): State, + policy: Policy, mut repo: BoxRepository, cookie_jar: CookieJar, Form(params): Form, @@ -346,7 +343,7 @@ pub(crate) async fn get( &clock, repo, key_store, - &policy_factory, + policy, url_builder, grant, &client, @@ -393,7 +390,7 @@ pub(crate) async fn get( &clock, repo, key_store, - &policy_factory, + policy, url_builder, grant, &client, diff --git a/crates/handlers/src/oauth2/consent.rs b/crates/handlers/src/oauth2/consent.rs index 85acb82f..c448923a 100644 --- a/crates/handlers/src/oauth2/consent.rs +++ b/crates/handlers/src/oauth2/consent.rs @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::Arc; - use axum::{ extract::{Form, Path, State}, response::{Html, IntoResponse, Response}, @@ -25,7 +23,7 @@ use mas_axum_utils::{ SessionInfoExt, }; use mas_data_model::{AuthorizationGrantStage, Device}; -use mas_policy::PolicyFactory; +use mas_policy::Policy; use mas_router::{PostAuthAction, Route}; use mas_storage::{ oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository}, @@ -61,7 +59,6 @@ pub enum RouteError { impl_from_error_for_route!(mas_templates::TemplateError); impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_policy::LoadError); -impl_from_error_for_route!(mas_policy::InstantiateError); impl_from_error_for_route!(mas_policy::EvaluationError); impl IntoResponse for RouteError { @@ -80,8 +77,8 @@ impl IntoResponse for RouteError { pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, - State(policy_factory): State>, State(templates): State, + mut policy: Policy, mut repo: BoxRepository, cookie_jar: CookieJar, Path(grant_id): Path, @@ -109,7 +106,6 @@ pub(crate) async fn get( if let Some(session) = maybe_session { let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); - let mut policy = policy_factory.instantiate().await?; let res = policy .evaluate_authorization_grant(&grant, &client, &session.user) .await?; @@ -146,7 +142,7 @@ pub(crate) async fn get( pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, - State(policy_factory): State>, + mut policy: Policy, mut repo: BoxRepository, cookie_jar: CookieJar, Path(grant_id): Path, @@ -176,7 +172,6 @@ pub(crate) async fn post( .await? .ok_or(RouteError::NoSuchClient)?; - let mut policy = policy_factory.instantiate().await?; let res = policy .evaluate_authorization_grant(&grant, &client, &session.user) .await?; diff --git a/crates/handlers/src/oauth2/registration.rs b/crates/handlers/src/oauth2/registration.rs index d2f2c5e3..99cefc7a 100644 --- a/crates/handlers/src/oauth2/registration.rs +++ b/crates/handlers/src/oauth2/registration.rs @@ -12,13 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::Arc; - use axum::{extract::State, response::IntoResponse, Json}; use hyper::StatusCode; use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_keystore::Encrypter; -use mas_policy::{PolicyFactory, Violation}; +use mas_policy::{Policy, Violation}; use mas_storage::{oauth2::OAuth2ClientRepository, BoxClock, BoxRepository, BoxRng}; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, @@ -49,7 +47,6 @@ pub(crate) enum RouteError { impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_policy::LoadError); -impl_from_error_for_route!(mas_policy::InstantiateError); impl_from_error_for_route!(mas_policy::EvaluationError); impl_from_error_for_route!(mas_keystore::aead::Error); @@ -136,7 +133,7 @@ pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, mut repo: BoxRepository, - State(policy_factory): State>, + mut policy: Policy, State(encrypter): State, body: Result, axum::extract::rejection::JsonRejection>, ) -> Result { @@ -148,7 +145,6 @@ pub(crate) async fn post( // Validate the body let metadata = body.validate()?; - let mut policy = policy_factory.instantiate().await?; let res = policy.evaluate_client_registration(&metadata).await?; if !res.valid() { return Err(RouteError::PolicyDenied(res.violations)); diff --git a/crates/handlers/src/test_utils.rs b/crates/handlers/src/test_utils.rs index ca3a581f..6ed765e1 100644 --- a/crates/handlers/src/test_utils.rs +++ b/crates/handlers/src/test_utils.rs @@ -33,10 +33,10 @@ use hyper::{ use mas_axum_utils::{cookies::CookieManager, http_client_factory::HttpClientFactory}; use mas_keystore::{Encrypter, JsonWebKey, JsonWebKeySet, Keystore, PrivateKey}; use mas_matrix::{HomeserverConnection, MockHomeserverConnection}; -use mas_policy::PolicyFactory; +use mas_policy::{InstantiateError, Policy, PolicyFactory}; use mas_router::{SimpleRoute, UrlBuilder}; use mas_storage::{clock::MockClock, BoxClock, BoxRepository, BoxRng, Repository}; -use mas_storage_pg::PgRepository; +use mas_storage_pg::{DatabaseError, PgRepository}; use mas_templates::Templates; use rand::SeedableRng; use rand_chacha::ChaChaRng; @@ -46,7 +46,7 @@ use tower::{Layer, Service, ServiceExt}; use url::Url; use crate::{ - app_state::RepositoryError, + app_state::ErrorWrapper, passwords::{Hasher, PasswordManager}, upstream_oauth2::cache::MetadataCache, MatrixHomeserver, @@ -138,6 +138,7 @@ impl TestState { let graphql_state = TestGraphQLState { pool: pool.clone(), + policy_factory: Arc::clone(&policy_factory), homeserver_connection, rng: Arc::clone(&rng), clock: Arc::clone(&clock), @@ -202,7 +203,7 @@ impl TestState { Response::from_parts(parts, body) } - pub async fn repository(&self) -> Result { + pub async fn repository(&self) -> Result { let repo = PgRepository::from_pool(&self.pool).await?; Ok(repo .map_err(mas_storage::RepositoryError::from_error) @@ -243,6 +244,7 @@ impl TestState { struct TestGraphQLState { pool: PgPool, homeserver_connection: MockHomeserverConnection, + policy_factory: Arc, clock: Arc, rng: Arc>, } @@ -259,6 +261,10 @@ impl mas_graphql::State for TestGraphQLState { .boxed()) } + async fn policy(&self) -> Result { + self.policy_factory.instantiate().await + } + fn homeserver_connection(&self) -> &dyn HomeserverConnection { &self.homeserver_connection } @@ -316,12 +322,6 @@ impl FromRef for MatrixHomeserver { } } -impl FromRef for Arc { - fn from_ref(input: &TestState) -> Self { - input.policy_factory.clone() - } -} - impl FromRef for HttpClientFactory { fn from_ref(input: &TestState) -> Self { input.http_client_factory.clone() @@ -374,7 +374,7 @@ impl FromRequestParts for BoxRng { #[async_trait] impl FromRequestParts for BoxRepository { - type Rejection = RepositoryError; + type Rejection = ErrorWrapper; async fn from_request_parts( _parts: &mut axum::http::request::Parts, @@ -387,6 +387,19 @@ impl FromRequestParts for BoxRepository { } } +#[async_trait] +impl FromRequestParts for Policy { + type Rejection = ErrorWrapper; + + async fn from_request_parts( + _parts: &mut axum::http::request::Parts, + state: &TestState, + ) -> Result { + let policy = state.policy_factory.instantiate().await?; + Ok(policy) + } +} + pub(crate) trait RequestBuilderExt { /// Builds the request with the given JSON value as body. fn json(self, body: T) -> hyper::Request; diff --git a/crates/handlers/src/views/account/emails/add.rs b/crates/handlers/src/views/account/emails/add.rs index 49036010..e4577d74 100644 --- a/crates/handlers/src/views/account/emails/add.rs +++ b/crates/handlers/src/views/account/emails/add.rs @@ -21,13 +21,14 @@ use mas_axum_utils::{ csrf::{CsrfExt, ProtectedForm}, FancyError, SessionInfoExt, }; +use mas_policy::Policy; use mas_router::Route; use mas_storage::{ job::{JobRepositoryExt, VerifyEmailJob}, user::UserEmailRepository, BoxClock, BoxRepository, BoxRng, }; -use mas_templates::{EmailAddContext, TemplateContext, Templates}; +use mas_templates::{EmailAddContext, ErrorContext, TemplateContext, Templates}; use serde::Deserialize; use crate::views::shared::OptionalPostAuthAction; @@ -69,6 +70,7 @@ pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, mut repo: BoxRepository, + mut policy: Policy, cookie_jar: CookieJar, Query(query): Query, Form(form): Form>, @@ -83,23 +85,56 @@ pub(crate) async fn post( return Ok((cookie_jar, login.go()).into_response()); }; - let user_email = repo - .user_email() - .add(&mut rng, &clock, &session.user, form.email) - .await?; + // XXX: we really should show human readable errors on the form here - let next = mas_router::AccountVerifyEmail::new(user_email.id); - let next = if let Some(action) = query.post_auth_action { - next.and_then(action) + // Validate the email address + if form.email.parse::().is_err() { + return Err(anyhow::anyhow!("Invalid email address").into()); + } + + // Run the email policy + let res = policy.evaluate_email(&form.email).await?; + if !res.valid() { + return Err(FancyError::new( + ErrorContext::new() + .with_description(format!("Email address {:?} denied by policy", form.email)) + .with_details(format!("{res}")), + )); + } + + // Find an existing email address + let existing_user_email = repo.user_email().find(&session.user, &form.email).await?; + let user_email = if let Some(user_email) = existing_user_email { + user_email } else { - next + let user_email = repo + .user_email() + .add(&mut rng, &clock, &session.user, form.email) + .await?; + + user_email }; - repo.job() - .schedule_job(VerifyEmailJob::new(&user_email)) - .await?; + // If the email was not confirmed, send a confirmation email & redirect to the + // verify page + let next = if user_email.confirmed_at.is_none() { + repo.job() + .schedule_job(VerifyEmailJob::new(&user_email)) + .await?; + + let next = mas_router::AccountVerifyEmail::new(user_email.id); + let next = if let Some(action) = query.post_auth_action { + next.and_then(action) + } else { + next + }; + + next.go() + } else { + query.go_next_or_default(&mas_router::Account) + }; repo.save().await?; - Ok((cookie_jar, next.go()).into_response()) + Ok((cookie_jar, next).into_response()) } diff --git a/crates/handlers/src/views/account/password.rs b/crates/handlers/src/views/account/password.rs index 4b2544a1..f21b5743 100644 --- a/crates/handlers/src/views/account/password.rs +++ b/crates/handlers/src/views/account/password.rs @@ -24,6 +24,7 @@ use mas_axum_utils::{ FancyError, SessionInfoExt, }; use mas_data_model::BrowserSession; +use mas_policy::Policy; use mas_router::Route; use mas_storage::{ user::{BrowserSessionRepository, UserPasswordRepository}, @@ -93,6 +94,7 @@ pub(crate) async fn post( clock: BoxClock, State(password_manager): State, State(templates): State, + mut policy: Policy, mut repo: BoxRepository, cookie_jar: CookieJar, Form(form): Form>, @@ -119,6 +121,13 @@ pub(crate) async fn post( .await? .context("user has no password")?; + let res = policy.evaluate_password(&form.new_password).await?; + + // TODO: display nice form errors + if !res.valid() { + return Err(anyhow::anyhow!("Password policy violation: {res}").into()); + } + let password = Zeroizing::new(form.current_password.into_bytes()); let new_password = Zeroizing::new(form.new_password.into_bytes()); let new_password_confirm = Zeroizing::new(form.new_password_confirm.into_bytes()); @@ -133,7 +142,7 @@ pub(crate) async fn post( // TODO: display nice form errors if new_password != new_password_confirm { - return Err(anyhow::anyhow!("password mismatch").into()); + return Err(anyhow::anyhow!("Password mismatch").into()); } let (version, hashed_password) = password_manager.hash(&mut rng, new_password).await?; diff --git a/crates/handlers/src/views/register.rs b/crates/handlers/src/views/register.rs index 3e32bad1..070bb954 100644 --- a/crates/handlers/src/views/register.rs +++ b/crates/handlers/src/views/register.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{str::FromStr, sync::Arc}; +use std::str::FromStr; use axum::{ extract::{Form, Query, State}, @@ -27,7 +27,7 @@ use mas_axum_utils::{ csrf::{CsrfExt, CsrfToken, ProtectedForm}, FancyError, SessionInfoExt, }; -use mas_policy::PolicyFactory; +use mas_policy::Policy; use mas_router::Route; use mas_storage::{ job::{JobRepositoryExt, ProvisionUserJob, VerifyEmailJob}, @@ -101,8 +101,8 @@ pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, State(password_manager): State, - State(policy_factory): State>, State(templates): State, + mut policy: Policy, mut repo: BoxRepository, Query(query): Query, cookie_jar: CookieJar, @@ -148,7 +148,6 @@ pub(crate) async fn post( state.add_error_on_field(RegisterFormField::PasswordConfirm, FieldError::Unspecified); } - let mut policy = policy_factory.instantiate().await?; let res = policy .evaluate_register(&form.username, &form.password, &form.email) .await?; diff --git a/crates/policy/src/bin/schema.rs b/crates/policy/src/bin/schema.rs index 7742a28a..53547db6 100644 --- a/crates/policy/src/bin/schema.rs +++ b/crates/policy/src/bin/schema.rs @@ -43,6 +43,8 @@ fn write_schema(out_dir: Option<&Path>, file: &str) { writer.flush().expect("Failed to flush writer"); } +/// Write the input schemas to the output directory. +/// They are then used in rego files to type check the input. fn main() { let output_root = std::env::var("OUT_DIR").map(PathBuf::from).ok(); let output_root = output_root.as_deref(); diff --git a/crates/policy/src/model.rs b/crates/policy/src/model.rs index 3cc9ff1f..65c4a005 100644 --- a/crates/policy/src/model.rs +++ b/crates/policy/src/model.rs @@ -16,6 +16,7 @@ use mas_data_model::{AuthorizationGrant, Client, User}; use oauth2_types::registration::VerifiedClientMetadata; use serde::{Deserialize, Serialize}; +/// A single violation of a policy. #[derive(Deserialize, Debug)] #[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] pub struct Violation { @@ -23,19 +24,37 @@ pub struct Violation { pub field: Option, } +/// The result of a policy evaluation. #[derive(Deserialize, Debug)] pub struct EvaluationResult { #[serde(rename = "result")] pub violations: Vec, } +impl std::fmt::Display for EvaluationResult { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut first = true; + for violation in &self.violations { + if first { + first = false; + } else { + write!(f, ", ")?; + } + write!(f, "{}", violation.msg)?; + } + Ok(()) + } +} + impl EvaluationResult { + /// Returns true if the policy evaluation was successful. #[must_use] pub fn valid(&self) -> bool { self.violations.is_empty() } } +/// Input for the user registration policy. #[derive(Serialize, Debug)] #[serde(tag = "registration_method", rename_all = "snake_case")] #[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] @@ -47,6 +66,7 @@ pub enum RegisterInput<'a> { }, } +/// Input for the client registration policy. #[derive(Serialize, Debug)] #[serde(rename_all = "snake_case")] #[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] @@ -58,6 +78,7 @@ pub struct ClientRegistrationInput<'a> { pub client_metadata: &'a VerifiedClientMetadata, } +/// Input for the authorization grant policy. #[derive(Serialize, Debug)] #[serde(rename_all = "snake_case")] #[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] @@ -81,6 +102,7 @@ pub struct AuthorizationGrantInput<'a> { pub authorization_grant: &'a AuthorizationGrant, } +/// Input for the email add policy. #[derive(Serialize, Debug)] #[serde(rename_all = "snake_case")] #[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] @@ -88,6 +110,7 @@ pub struct EmailInput<'a> { pub email: &'a str, } +/// Input for the password set policy. #[derive(Serialize, Debug)] #[serde(rename_all = "snake_case")] #[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] diff --git a/docs/config.schema.json b/docs/config.schema.json index 835163db..80663733 100644 --- a/docs/config.schema.json +++ b/docs/config.schema.json @@ -147,6 +147,8 @@ "authorization_grant_entrypoint": "authorization_grant/violation", "client_registration_entrypoint": "client_registration/violation", "data": null, + "email_entrypoint": "email/violation", + "password_entrypoint": "password/violation", "register_entrypoint": "register/violation", "wasm_module": "./policies/policy.wasm" }, @@ -1349,6 +1351,16 @@ "description": "Arbitrary data to pass to the policy", "default": null }, + "email_entrypoint": { + "description": "Entrypoint to use when adding an email address", + "default": "email/violation", + "type": "string" + }, + "password_entrypoint": { + "description": "Entrypoint to use when changing password", + "default": "password/violation", + "type": "string" + }, "register_entrypoint": { "description": "Entrypoint to use when evaluating user registrations", "default": "register/violation", diff --git a/frontend/schema.graphql b/frontend/schema.graphql index 821534e3..5a89768e 100644 --- a/frontend/schema.graphql +++ b/frontend/schema.graphql @@ -28,6 +28,10 @@ type AddEmailPayload { The user to whom the email address was added """ user: User + """ + The list of policy violations if the email address was denied + """ + violations: [String!] } """ @@ -46,6 +50,10 @@ enum AddEmailStatus { The email address is invalid """ INVALID + """ + The email address is not allowed by the policy + """ + DENIED } type Anonymous implements Node { diff --git a/frontend/src/components/UserProfile/AddEmailForm.tsx b/frontend/src/components/UserProfile/AddEmailForm.tsx index dabad326..315220fd 100644 --- a/frontend/src/components/UserProfile/AddEmailForm.tsx +++ b/frontend/src/components/UserProfile/AddEmailForm.tsx @@ -30,6 +30,7 @@ const ADD_EMAIL_MUTATION = graphql(/* GraphQL */ ` mutation AddEmail($userId: ID!, $email: String!) { addEmail(input: { userId: $userId, email: $email }) { status + violations email { id ...UserEmail_email @@ -79,6 +80,8 @@ const AddEmailForm: React.FC<{ const status = addEmailResult.data?.addEmail.status ?? null; const emailExists = status === "EXISTS"; const emailInvalid = status === "INVALID"; + const emailDenied = status === "DENIED"; + const violations = addEmailResult.data?.addEmail.violations ?? []; return ( <> @@ -95,6 +98,17 @@ const AddEmailForm: React.FC<{ )} + {emailDenied && ( + + The entered email is not allowed by the server policy. +
    + {violations.map((violation, index) => ( +
  • • {violation}
  • + ))} +
+
+ )} + diff --git a/frontend/src/gql/gql.ts b/frontend/src/gql/gql.ts index 009714c1..199583a9 100644 --- a/frontend/src/gql/gql.ts +++ b/frontend/src/gql/gql.ts @@ -47,7 +47,7 @@ const documents = { types.UserGreetingDocument, "\n fragment UserHome_user on User {\n id\n\n primaryEmail {\n id\n ...UserEmail_email\n }\n\n confirmedEmails: emails(first: 0, state: CONFIRMED) {\n totalCount\n }\n\n unverifiedEmails: emails(first: 0, state: PENDING) {\n totalCount\n }\n\n browserSessions(first: 0, state: ACTIVE) {\n totalCount\n }\n\n oauth2Sessions(first: 0, state: ACTIVE) {\n totalCount\n }\n\n compatSessions(first: 0, state: ACTIVE) {\n totalCount\n }\n }\n": types.UserHome_UserFragmentDoc, - "\n mutation AddEmail($userId: ID!, $email: String!) {\n addEmail(input: { userId: $userId, email: $email }) {\n status\n email {\n id\n ...UserEmail_email\n }\n }\n }\n": + "\n mutation AddEmail($userId: ID!, $email: String!) {\n addEmail(input: { userId: $userId, email: $email }) {\n status\n violations\n email {\n id\n ...UserEmail_email\n }\n }\n }\n": types.AddEmailDocument, "\n query UserEmailListQuery(\n $userId: ID!\n $first: Int\n $after: String\n $last: Int\n $before: String\n ) {\n user(id: $userId) {\n id\n\n emails(first: $first, after: $after, last: $last, before: $before) {\n edges {\n cursor\n node {\n id\n ...UserEmail_email\n }\n }\n totalCount\n pageInfo {\n hasNextPage\n hasPreviousPage\n startCursor\n endCursor\n }\n }\n }\n }\n": types.UserEmailListQueryDocument, @@ -191,8 +191,8 @@ export function graphql( * The graphql function is used to parse GraphQL queries into a document that can be used by GraphQL clients. */ export function graphql( - source: "\n mutation AddEmail($userId: ID!, $email: String!) {\n addEmail(input: { userId: $userId, email: $email }) {\n status\n email {\n id\n ...UserEmail_email\n }\n }\n }\n", -): (typeof documents)["\n mutation AddEmail($userId: ID!, $email: String!) {\n addEmail(input: { userId: $userId, email: $email }) {\n status\n email {\n id\n ...UserEmail_email\n }\n }\n }\n"]; + source: "\n mutation AddEmail($userId: ID!, $email: String!) {\n addEmail(input: { userId: $userId, email: $email }) {\n status\n violations\n email {\n id\n ...UserEmail_email\n }\n }\n }\n", +): (typeof documents)["\n mutation AddEmail($userId: ID!, $email: String!) {\n addEmail(input: { userId: $userId, email: $email }) {\n status\n violations\n email {\n id\n ...UserEmail_email\n }\n }\n }\n"]; /** * The graphql function is used to parse GraphQL queries into a document that can be used by GraphQL clients. */ diff --git a/frontend/src/gql/graphql.ts b/frontend/src/gql/graphql.ts index 2b39fb92..96748bed 100644 --- a/frontend/src/gql/graphql.ts +++ b/frontend/src/gql/graphql.ts @@ -54,12 +54,16 @@ export type AddEmailPayload = { status: AddEmailStatus; /** The user to whom the email address was added */ user?: Maybe; + /** The list of policy violations if the email address was denied */ + violations?: Maybe>; }; /** The status of the `addEmail` mutation */ export enum AddEmailStatus { /** The email address was added */ Added = "ADDED", + /** The email address is not allowed by the policy */ + Denied = "DENIED", /** The email address already exists */ Exists = "EXISTS", /** The email address is invalid */ @@ -1231,6 +1235,7 @@ export type AddEmailMutation = { addEmail: { __typename?: "AddEmailPayload"; status: AddEmailStatus; + violations?: Array | null; email?: | ({ __typename?: "UserEmail"; id: string } & { " $fragmentRefs"?: { @@ -3129,6 +3134,7 @@ export const AddEmailDocument = { kind: "SelectionSet", selections: [ { kind: "Field", name: { kind: "Name", value: "status" } }, + { kind: "Field", name: { kind: "Name", value: "violations" } }, { kind: "Field", name: { kind: "Name", value: "email" }, diff --git a/frontend/src/gql/schema.ts b/frontend/src/gql/schema.ts index e5d546ab..c27b9bd3 100644 --- a/frontend/src/gql/schema.ts +++ b/frontend/src/gql/schema.ts @@ -42,6 +42,20 @@ export default { }, args: [], }, + { + name: "violations", + type: { + kind: "LIST", + ofType: { + kind: "NON_NULL", + ofType: { + kind: "SCALAR", + name: "Any", + }, + }, + }, + args: [], + }, ], interfaces: [], }, diff --git a/policies/schema/authorization_grant_input.json b/policies/schema/authorization_grant_input.json index a1a49a8d..afd230c4 100644 --- a/policies/schema/authorization_grant_input.json +++ b/policies/schema/authorization_grant_input.json @@ -1,6 +1,7 @@ { "$schema": "http://json-schema.org/draft-07/schema#", "title": "AuthorizationGrantInput", + "description": "Input for the authorization grant policy.", "type": "object", "required": [ "authorization_grant", diff --git a/policies/schema/client_registration_input.json b/policies/schema/client_registration_input.json index 7261068e..cc9957a8 100644 --- a/policies/schema/client_registration_input.json +++ b/policies/schema/client_registration_input.json @@ -1,6 +1,7 @@ { "$schema": "http://json-schema.org/draft-07/schema#", "title": "ClientRegistrationInput", + "description": "Input for the client registration policy.", "type": "object", "required": [ "client_metadata" diff --git a/policies/schema/email_input.json b/policies/schema/email_input.json index 487eb4b9..19f4af52 100644 --- a/policies/schema/email_input.json +++ b/policies/schema/email_input.json @@ -1,6 +1,7 @@ { "$schema": "http://json-schema.org/draft-07/schema#", "title": "EmailInput", + "description": "Input for the email add policy.", "type": "object", "required": [ "email" diff --git a/policies/schema/password_input.json b/policies/schema/password_input.json index d85b2862..c3cbf92d 100644 --- a/policies/schema/password_input.json +++ b/policies/schema/password_input.json @@ -1,6 +1,7 @@ { "$schema": "http://json-schema.org/draft-07/schema#", "title": "PasswordInput", + "description": "Input for the password set policy.", "type": "object", "required": [ "password" diff --git a/policies/schema/register_input.json b/policies/schema/register_input.json index d77ce66e..1f1585aa 100644 --- a/policies/schema/register_input.json +++ b/policies/schema/register_input.json @@ -1,6 +1,7 @@ { "$schema": "http://json-schema.org/draft-07/schema#", "title": "RegisterInput", + "description": "Input for the user registration policy.", "oneOf": [ { "type": "object",