1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-09 04:22:45 +03:00

Make sure we validate passwords & emails by the policy at all stages

Also refactors the way we get the policy engines in requests
This commit is contained in:
Quentin Gliech
2023-08-30 16:47:57 +02:00
parent 23151ef092
commit 7fcd022eea
30 changed files with 264 additions and 84 deletions

1
Cargo.lock generated
View File

@@ -2763,6 +2763,7 @@ dependencies = [
"lettre", "lettre",
"mas-data-model", "mas-data-model",
"mas-matrix", "mas-matrix",
"mas-policy",
"mas-storage", "mas-storage",
"oauth2-types", "oauth2-types",
"serde", "serde",

View File

@@ -23,6 +23,12 @@ pub struct FancyError {
context: ErrorContext, context: ErrorContext,
} }
impl FancyError {
pub fn new(context: ErrorContext) -> Self {
Self { context }
}
}
impl std::fmt::Display for FancyError { impl std::fmt::Display for FancyError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let code = self.context.code().unwrap_or("Internal error"); let code = self.context.code().unwrap_or("Internal error");

View File

@@ -143,7 +143,7 @@ impl Options {
// Listen for SIGHUP // Listen for SIGHUP
register_sighup(&templates)?; register_sighup(&templates)?;
let graphql_schema = mas_handlers::graphql_schema(&pool, conn); let graphql_schema = mas_handlers::graphql_schema(&pool, &policy_factory, conn);
let state = { let state = {
let mut s = AppState { let mut s = AppState {

View File

@@ -22,6 +22,7 @@ url.workspace = true
oauth2-types = { path = "../oauth2-types" } oauth2-types = { path = "../oauth2-types" }
mas-data-model = { path = "../data-model" } mas-data-model = { path = "../data-model" }
mas-matrix = { path = "../matrix" } mas-matrix = { path = "../matrix" }
mas-policy = { path = "../policy" }
mas-storage = { path = "../storage" } mas-storage = { path = "../storage" }
[[bin]] [[bin]]

View File

@@ -49,6 +49,8 @@ pub enum AddEmailStatus {
Exists, Exists,
/// The email address is invalid /// The email address is invalid
Invalid, Invalid,
/// The email address is not allowed by the policy
Denied,
} }
/// The payload of the `addEmail` mutation /// The payload of the `addEmail` mutation
@@ -57,6 +59,9 @@ enum AddEmailPayload {
Added(mas_data_model::UserEmail), Added(mas_data_model::UserEmail),
Exists(mas_data_model::UserEmail), Exists(mas_data_model::UserEmail),
Invalid, Invalid,
Denied {
violations: Vec<mas_policy::Violation>,
},
} }
#[Object(use_type_description)] #[Object(use_type_description)]
@@ -67,6 +72,7 @@ impl AddEmailPayload {
AddEmailPayload::Added(_) => AddEmailStatus::Added, AddEmailPayload::Added(_) => AddEmailStatus::Added,
AddEmailPayload::Exists(_) => AddEmailStatus::Exists, AddEmailPayload::Exists(_) => AddEmailStatus::Exists,
AddEmailPayload::Invalid => AddEmailStatus::Invalid, AddEmailPayload::Invalid => AddEmailStatus::Invalid,
AddEmailPayload::Denied { .. } => AddEmailStatus::Denied,
} }
} }
@@ -76,7 +82,7 @@ impl AddEmailPayload {
AddEmailPayload::Added(email) | AddEmailPayload::Exists(email) => { AddEmailPayload::Added(email) | AddEmailPayload::Exists(email) => {
Some(UserEmail(email.clone())) Some(UserEmail(email.clone()))
} }
AddEmailPayload::Invalid => None, AddEmailPayload::Invalid | AddEmailPayload::Denied { .. } => None,
} }
} }
@@ -87,7 +93,7 @@ impl AddEmailPayload {
let user_id = match self { let user_id = match self {
AddEmailPayload::Added(email) | AddEmailPayload::Exists(email) => email.user_id, AddEmailPayload::Added(email) | AddEmailPayload::Exists(email) => email.user_id,
AddEmailPayload::Invalid => return Ok(None), AddEmailPayload::Invalid | AddEmailPayload::Denied { .. } => return Ok(None),
}; };
let user = repo let user = repo
@@ -98,6 +104,16 @@ impl AddEmailPayload {
Ok(Some(User(user))) Ok(Some(User(user)))
} }
/// The list of policy violations if the email address was denied
async fn violations(&self) -> Option<Vec<String>> {
let AddEmailPayload::Denied { violations } = self else {
return None;
};
let messages = violations.iter().map(|v| v.msg.clone()).collect();
Some(messages)
}
} }
/// The input for the `sendVerificationEmail` mutation /// The input for the `sendVerificationEmail` mutation
@@ -382,6 +398,14 @@ impl UserEmailMutations {
return Ok(AddEmailPayload::Invalid); return Ok(AddEmailPayload::Invalid);
} }
let mut policy = state.policy().await?;
let res = policy.evaluate_email(&input.email).await?;
if !res.valid() {
return Ok(AddEmailPayload::Denied {
violations: res.violations,
});
}
// Find an existing email address // Find an existing email address
let existing_user_email = repo.user_email().find(&user, &input.email).await?; let existing_user_email = repo.user_email().find(&user, &input.email).await?;
let (added, user_email) = if let Some(user_email) = existing_user_email { let (added, user_email) = if let Some(user_email) = existing_user_email {

View File

@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
use mas_matrix::HomeserverConnection; use mas_matrix::HomeserverConnection;
use mas_policy::Policy;
use mas_storage::{BoxClock, BoxRepository, BoxRng, RepositoryError}; use mas_storage::{BoxClock, BoxRepository, BoxRng, RepositoryError};
use crate::Requester; use crate::Requester;
@@ -20,6 +21,7 @@ use crate::Requester;
#[async_trait::async_trait] #[async_trait::async_trait]
pub trait State { pub trait State {
async fn repository(&self) -> Result<BoxRepository, RepositoryError>; async fn repository(&self) -> Result<BoxRepository, RepositoryError>;
async fn policy(&self) -> Result<Policy, mas_policy::InstantiateError>;
fn homeserver_connection(&self) -> &dyn HomeserverConnection<Error = anyhow::Error>; fn homeserver_connection(&self) -> &dyn HomeserverConnection<Error = anyhow::Error>;
fn clock(&self) -> BoxClock; fn clock(&self) -> BoxClock;
fn rng(&self) -> BoxRng; fn rng(&self) -> BoxRng;

View File

@@ -17,12 +17,12 @@ use std::{convert::Infallible, sync::Arc, time::Instant};
use axum::{ use axum::{
async_trait, async_trait,
extract::{FromRef, FromRequestParts}, extract::{FromRef, FromRequestParts},
response::IntoResponse, response::{IntoResponse, Response},
}; };
use hyper::StatusCode; use hyper::StatusCode;
use mas_axum_utils::{cookies::CookieManager, http_client_factory::HttpClientFactory}; use mas_axum_utils::{cookies::CookieManager, http_client_factory::HttpClientFactory};
use mas_keystore::{Encrypter, Keystore}; use mas_keystore::{Encrypter, Keystore};
use mas_policy::PolicyFactory; use mas_policy::{Policy, PolicyFactory};
use mas_router::UrlBuilder; use mas_router::UrlBuilder;
use mas_storage::{BoxClock, BoxRepository, BoxRng, Repository, SystemClock}; use mas_storage::{BoxClock, BoxRepository, BoxRng, Repository, SystemClock};
use mas_storage_pg::PgRepository; use mas_storage_pg::PgRepository;
@@ -33,7 +33,6 @@ use opentelemetry::{
}; };
use rand::SeedableRng; use rand::SeedableRng;
use sqlx::PgPool; use sqlx::PgPool;
use thiserror::Error;
use crate::{passwords::PasswordManager, upstream_oauth2::cache::MetadataCache, MatrixHomeserver}; use crate::{passwords::PasswordManager, upstream_oauth2::cache::MetadataCache, MatrixHomeserver};
@@ -176,12 +175,6 @@ impl FromRef<AppState> for MatrixHomeserver {
} }
} }
impl FromRef<AppState> for Arc<PolicyFactory> {
fn from_ref(input: &AppState) -> Self {
input.policy_factory.clone()
}
}
impl FromRef<AppState> for HttpClientFactory { impl FromRef<AppState> for HttpClientFactory {
fn from_ref(input: &AppState) -> Self { fn from_ref(input: &AppState) -> Self {
input.http_client_factory.clone() input.http_client_factory.clone()
@@ -236,19 +229,41 @@ impl FromRequestParts<AppState> for BoxRng {
} }
} }
#[derive(Debug, Error)] /// A simple wrapper around an error that implements [`IntoResponse`].
#[error(transparent)] pub struct ErrorWrapper<T>(T);
pub struct RepositoryError(#[from] mas_storage_pg::DatabaseError);
impl IntoResponse for RepositoryError { impl<T> From<T> for ErrorWrapper<T> {
fn into_response(self) -> axum::response::Response { fn from(input: T) -> Self {
Self(input)
}
}
impl<T> IntoResponse for ErrorWrapper<T>
where
T: std::error::Error,
{
fn into_response(self) -> Response {
// TODO: make this a bit more user friendly
(StatusCode::INTERNAL_SERVER_ERROR, self.0.to_string()).into_response() (StatusCode::INTERNAL_SERVER_ERROR, self.0.to_string()).into_response()
} }
} }
#[async_trait]
impl FromRequestParts<AppState> for Policy {
type Rejection = ErrorWrapper<mas_policy::InstantiateError>;
async fn from_request_parts(
_parts: &mut axum::http::request::Parts,
state: &AppState,
) -> Result<Self, Self::Rejection> {
let policy = state.policy_factory.instantiate().await?;
Ok(policy)
}
}
#[async_trait] #[async_trait]
impl FromRequestParts<AppState> for BoxRepository { impl FromRequestParts<AppState> for BoxRepository {
type Rejection = RepositoryError; type Rejection = ErrorWrapper<mas_storage_pg::DatabaseError>;
async fn from_request_parts( async fn from_request_parts(
_parts: &mut axum::http::request::Parts, _parts: &mut axum::http::request::Parts,

View File

@@ -31,6 +31,7 @@ use hyper::header::CACHE_CONTROL;
use mas_axum_utils::{cookies::CookieJar, FancyError, SessionInfo, SessionInfoExt}; use mas_axum_utils::{cookies::CookieJar, FancyError, SessionInfo, SessionInfoExt};
use mas_graphql::{Requester, Schema}; use mas_graphql::{Requester, Schema};
use mas_matrix::HomeserverConnection; use mas_matrix::HomeserverConnection;
use mas_policy::{InstantiateError, Policy, PolicyFactory};
use mas_storage::{ use mas_storage::{
BoxClock, BoxRepository, BoxRng, Clock, Repository, RepositoryError, SystemClock, BoxClock, BoxRepository, BoxRng, Clock, Repository, RepositoryError, SystemClock,
}; };
@@ -48,6 +49,7 @@ mod tests;
struct GraphQLState { struct GraphQLState {
pool: PgPool, pool: PgPool,
homeserver_connection: Arc<dyn HomeserverConnection<Error = anyhow::Error>>, homeserver_connection: Arc<dyn HomeserverConnection<Error = anyhow::Error>>,
policy_factory: Arc<PolicyFactory>,
} }
#[async_trait] #[async_trait]
@@ -60,6 +62,10 @@ impl mas_graphql::State for GraphQLState {
Ok(repo.map_err(RepositoryError::from_error).boxed()) Ok(repo.map_err(RepositoryError::from_error).boxed())
} }
async fn policy(&self) -> Result<Policy, InstantiateError> {
self.policy_factory.instantiate().await
}
fn homeserver_connection(&self) -> &dyn HomeserverConnection<Error = anyhow::Error> { fn homeserver_connection(&self) -> &dyn HomeserverConnection<Error = anyhow::Error> {
self.homeserver_connection.as_ref() self.homeserver_connection.as_ref()
} }
@@ -81,10 +87,12 @@ impl mas_graphql::State for GraphQLState {
#[must_use] #[must_use]
pub fn schema( pub fn schema(
pool: &PgPool, pool: &PgPool,
policy_factory: &Arc<PolicyFactory>,
homeserver_connection: impl HomeserverConnection<Error = anyhow::Error> + 'static, homeserver_connection: impl HomeserverConnection<Error = anyhow::Error> + 'static,
) -> Schema { ) -> Schema {
let state = GraphQLState { let state = GraphQLState {
pool: pool.clone(), pool: pool.clone(),
policy_factory: Arc::clone(policy_factory),
homeserver_connection: Arc::new(homeserver_connection), homeserver_connection: Arc::new(homeserver_connection),
}; };
let state: mas_graphql::BoxState = Box::new(state); let state: mas_graphql::BoxState = Box::new(state);

View File

@@ -30,7 +30,7 @@
clippy::let_with_type_underscore, clippy::let_with_type_underscore,
)] )]
use std::{convert::Infallible, sync::Arc, time::Duration}; use std::{convert::Infallible, time::Duration};
use axum::{ use axum::{
body::{Bytes, HttpBody}, body::{Bytes, HttpBody},
@@ -50,7 +50,7 @@ use hyper::{
use mas_axum_utils::{cookies::CookieJar, FancyError}; use mas_axum_utils::{cookies::CookieJar, FancyError};
use mas_http::CorsLayerExt; use mas_http::CorsLayerExt;
use mas_keystore::{Encrypter, Keystore}; use mas_keystore::{Encrypter, Keystore};
use mas_policy::PolicyFactory; use mas_policy::Policy;
use mas_router::{Route, UrlBuilder}; use mas_router::{Route, UrlBuilder};
use mas_storage::{BoxClock, BoxRepository, BoxRng}; use mas_storage::{BoxClock, BoxRepository, BoxRng};
use mas_templates::{ErrorContext, NotFoundContext, Templates}; use mas_templates::{ErrorContext, NotFoundContext, Templates};
@@ -166,12 +166,12 @@ where
S: Clone + Send + Sync + 'static, S: Clone + Send + Sync + 'static,
Keystore: FromRef<S>, Keystore: FromRef<S>,
UrlBuilder: FromRef<S>, UrlBuilder: FromRef<S>,
Arc<PolicyFactory>: FromRef<S>,
BoxRepository: FromRequestParts<S>, BoxRepository: FromRequestParts<S>,
Encrypter: FromRef<S>, Encrypter: FromRef<S>,
HttpClientFactory: FromRef<S>, HttpClientFactory: FromRef<S>,
BoxClock: FromRequestParts<S>, BoxClock: FromRequestParts<S>,
BoxRng: FromRequestParts<S>, BoxRng: FromRequestParts<S>,
Policy: FromRequestParts<S>,
{ {
// All those routes are API-like, with a common CORS layer // All those routes are API-like, with a common CORS layer
Router::new() Router::new()
@@ -267,7 +267,6 @@ where
<B as HttpBody>::Error: std::error::Error + Send + Sync, <B as HttpBody>::Error: std::error::Error + Send + Sync,
S: Clone + Send + Sync + 'static, S: Clone + Send + Sync + 'static,
UrlBuilder: FromRef<S>, UrlBuilder: FromRef<S>,
Arc<PolicyFactory>: FromRef<S>,
BoxRepository: FromRequestParts<S>, BoxRepository: FromRequestParts<S>,
CookieJar: FromRequestParts<S>, CookieJar: FromRequestParts<S>,
Encrypter: FromRef<S>, Encrypter: FromRef<S>,
@@ -278,6 +277,7 @@ where
MetadataCache: FromRef<S>, MetadataCache: FromRef<S>,
BoxClock: FromRequestParts<S>, BoxClock: FromRequestParts<S>,
BoxRng: FromRequestParts<S>, BoxRng: FromRequestParts<S>,
Policy: FromRequestParts<S>,
{ {
Router::new() Router::new()
// XXX: hard-coded redirect from /account to /account/ // XXX: hard-coded redirect from /account to /account/

View File

@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::sync::Arc;
use axum::{ use axum::{
extract::{Path, State}, extract::{Path, State},
response::{Html, IntoResponse, Response}, response::{Html, IntoResponse, Response},
@@ -22,7 +20,7 @@ use hyper::StatusCode;
use mas_axum_utils::{cookies::CookieJar, csrf::CsrfExt, SessionInfoExt}; use mas_axum_utils::{cookies::CookieJar, csrf::CsrfExt, SessionInfoExt};
use mas_data_model::{AuthorizationGrant, BrowserSession, Client, Device}; use mas_data_model::{AuthorizationGrant, BrowserSession, Client, Device};
use mas_keystore::Keystore; use mas_keystore::Keystore;
use mas_policy::{EvaluationResult, PolicyFactory}; use mas_policy::{EvaluationResult, Policy};
use mas_router::{PostAuthAction, Route, UrlBuilder}; use mas_router::{PostAuthAction, Route, UrlBuilder};
use mas_storage::{ use mas_storage::{
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2SessionRepository}, oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2SessionRepository},
@@ -76,7 +74,6 @@ impl IntoResponse for RouteError {
impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_templates::TemplateError); impl_from_error_for_route!(mas_templates::TemplateError);
impl_from_error_for_route!(mas_policy::LoadError); impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::InstantiateError);
impl_from_error_for_route!(mas_policy::EvaluationError); impl_from_error_for_route!(mas_policy::EvaluationError);
impl_from_error_for_route!(super::callback::IntoCallbackDestinationError); impl_from_error_for_route!(super::callback::IntoCallbackDestinationError);
impl_from_error_for_route!(super::callback::CallbackDestinationError); impl_from_error_for_route!(super::callback::CallbackDestinationError);
@@ -90,10 +87,10 @@ impl_from_error_for_route!(super::callback::CallbackDestinationError);
pub(crate) async fn get( pub(crate) async fn get(
mut rng: BoxRng, mut rng: BoxRng,
clock: BoxClock, clock: BoxClock,
State(policy_factory): State<Arc<PolicyFactory>>,
State(templates): State<Templates>, State(templates): State<Templates>,
State(url_builder): State<UrlBuilder>, State(url_builder): State<UrlBuilder>,
State(key_store): State<Keystore>, State(key_store): State<Keystore>,
policy: Policy,
mut repo: BoxRepository, mut repo: BoxRepository,
cookie_jar: CookieJar, cookie_jar: CookieJar,
Path(grant_id): Path<Ulid>, Path(grant_id): Path<Ulid>,
@@ -128,7 +125,7 @@ pub(crate) async fn get(
&clock, &clock,
repo, repo,
key_store, key_store,
&policy_factory, policy,
url_builder, url_builder,
grant, grant,
&client, &client,
@@ -187,7 +184,6 @@ pub enum GrantCompletionError {
impl_from_error_for_route!(GrantCompletionError: mas_storage::RepositoryError); impl_from_error_for_route!(GrantCompletionError: mas_storage::RepositoryError);
impl_from_error_for_route!(GrantCompletionError: super::callback::IntoCallbackDestinationError); impl_from_error_for_route!(GrantCompletionError: super::callback::IntoCallbackDestinationError);
impl_from_error_for_route!(GrantCompletionError: mas_policy::LoadError); impl_from_error_for_route!(GrantCompletionError: mas_policy::LoadError);
impl_from_error_for_route!(GrantCompletionError: mas_policy::InstantiateError);
impl_from_error_for_route!(GrantCompletionError: mas_policy::EvaluationError); impl_from_error_for_route!(GrantCompletionError: mas_policy::EvaluationError);
impl_from_error_for_route!(GrantCompletionError: super::super::IdTokenSignatureError); impl_from_error_for_route!(GrantCompletionError: super::super::IdTokenSignatureError);
@@ -196,7 +192,7 @@ pub(crate) async fn complete(
clock: &impl Clock, clock: &impl Clock,
mut repo: BoxRepository, mut repo: BoxRepository,
key_store: Keystore, key_store: Keystore,
policy_factory: &PolicyFactory, mut policy: Policy,
url_builder: UrlBuilder, url_builder: UrlBuilder,
grant: AuthorizationGrant, grant: AuthorizationGrant,
client: &Client, client: &Client,
@@ -220,7 +216,6 @@ pub(crate) async fn complete(
}; };
// Run through the policy // Run through the policy
let mut policy = policy_factory.instantiate().await?;
let res = policy let res = policy
.evaluate_authorization_grant(&grant, client, &browser_session.user) .evaluate_authorization_grant(&grant, client, &browser_session.user)
.await?; .await?;

View File

@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::sync::Arc;
use axum::{ use axum::{
extract::{Form, State}, extract::{Form, State},
response::{Html, IntoResponse, Response}, response::{Html, IntoResponse, Response},
@@ -22,7 +20,7 @@ use hyper::StatusCode;
use mas_axum_utils::{cookies::CookieJar, csrf::CsrfExt, SessionInfoExt}; use mas_axum_utils::{cookies::CookieJar, csrf::CsrfExt, SessionInfoExt};
use mas_data_model::{AuthorizationCode, Pkce}; use mas_data_model::{AuthorizationCode, Pkce};
use mas_keystore::Keystore; use mas_keystore::Keystore;
use mas_policy::PolicyFactory; use mas_policy::Policy;
use mas_router::{PostAuthAction, Route, UrlBuilder}; use mas_router::{PostAuthAction, Route, UrlBuilder};
use mas_storage::{ use mas_storage::{
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository}, oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository},
@@ -94,7 +92,6 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_templates::TemplateError); impl_from_error_for_route!(mas_templates::TemplateError);
impl_from_error_for_route!(self::callback::CallbackDestinationError); impl_from_error_for_route!(self::callback::CallbackDestinationError);
impl_from_error_for_route!(mas_policy::LoadError); impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::InstantiateError);
impl_from_error_for_route!(mas_policy::EvaluationError); impl_from_error_for_route!(mas_policy::EvaluationError);
#[derive(Deserialize)] #[derive(Deserialize)]
@@ -140,10 +137,10 @@ fn resolve_response_mode(
pub(crate) async fn get( pub(crate) async fn get(
mut rng: BoxRng, mut rng: BoxRng,
clock: BoxClock, clock: BoxClock,
State(policy_factory): State<Arc<PolicyFactory>>,
State(templates): State<Templates>, State(templates): State<Templates>,
State(key_store): State<Keystore>, State(key_store): State<Keystore>,
State(url_builder): State<UrlBuilder>, State(url_builder): State<UrlBuilder>,
policy: Policy,
mut repo: BoxRepository, mut repo: BoxRepository,
cookie_jar: CookieJar, cookie_jar: CookieJar,
Form(params): Form<Params>, Form(params): Form<Params>,
@@ -346,7 +343,7 @@ pub(crate) async fn get(
&clock, &clock,
repo, repo,
key_store, key_store,
&policy_factory, policy,
url_builder, url_builder,
grant, grant,
&client, &client,
@@ -393,7 +390,7 @@ pub(crate) async fn get(
&clock, &clock,
repo, repo,
key_store, key_store,
&policy_factory, policy,
url_builder, url_builder,
grant, grant,
&client, &client,

View File

@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::sync::Arc;
use axum::{ use axum::{
extract::{Form, Path, State}, extract::{Form, Path, State},
response::{Html, IntoResponse, Response}, response::{Html, IntoResponse, Response},
@@ -25,7 +23,7 @@ use mas_axum_utils::{
SessionInfoExt, SessionInfoExt,
}; };
use mas_data_model::{AuthorizationGrantStage, Device}; use mas_data_model::{AuthorizationGrantStage, Device};
use mas_policy::PolicyFactory; use mas_policy::Policy;
use mas_router::{PostAuthAction, Route}; use mas_router::{PostAuthAction, Route};
use mas_storage::{ use mas_storage::{
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository}, oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository},
@@ -61,7 +59,6 @@ pub enum RouteError {
impl_from_error_for_route!(mas_templates::TemplateError); impl_from_error_for_route!(mas_templates::TemplateError);
impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_policy::LoadError); impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::InstantiateError);
impl_from_error_for_route!(mas_policy::EvaluationError); impl_from_error_for_route!(mas_policy::EvaluationError);
impl IntoResponse for RouteError { impl IntoResponse for RouteError {
@@ -80,8 +77,8 @@ impl IntoResponse for RouteError {
pub(crate) async fn get( pub(crate) async fn get(
mut rng: BoxRng, mut rng: BoxRng,
clock: BoxClock, clock: BoxClock,
State(policy_factory): State<Arc<PolicyFactory>>,
State(templates): State<Templates>, State(templates): State<Templates>,
mut policy: Policy,
mut repo: BoxRepository, mut repo: BoxRepository,
cookie_jar: CookieJar, cookie_jar: CookieJar,
Path(grant_id): Path<Ulid>, Path(grant_id): Path<Ulid>,
@@ -109,7 +106,6 @@ pub(crate) async fn get(
if let Some(session) = maybe_session { if let Some(session) = maybe_session {
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let mut policy = policy_factory.instantiate().await?;
let res = policy let res = policy
.evaluate_authorization_grant(&grant, &client, &session.user) .evaluate_authorization_grant(&grant, &client, &session.user)
.await?; .await?;
@@ -146,7 +142,7 @@ pub(crate) async fn get(
pub(crate) async fn post( pub(crate) async fn post(
mut rng: BoxRng, mut rng: BoxRng,
clock: BoxClock, clock: BoxClock,
State(policy_factory): State<Arc<PolicyFactory>>, mut policy: Policy,
mut repo: BoxRepository, mut repo: BoxRepository,
cookie_jar: CookieJar, cookie_jar: CookieJar,
Path(grant_id): Path<Ulid>, Path(grant_id): Path<Ulid>,
@@ -176,7 +172,6 @@ pub(crate) async fn post(
.await? .await?
.ok_or(RouteError::NoSuchClient)?; .ok_or(RouteError::NoSuchClient)?;
let mut policy = policy_factory.instantiate().await?;
let res = policy let res = policy
.evaluate_authorization_grant(&grant, &client, &session.user) .evaluate_authorization_grant(&grant, &client, &session.user)
.await?; .await?;

View File

@@ -12,13 +12,11 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::sync::Arc;
use axum::{extract::State, response::IntoResponse, Json}; use axum::{extract::State, response::IntoResponse, Json};
use hyper::StatusCode; use hyper::StatusCode;
use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_policy::{PolicyFactory, Violation}; use mas_policy::{Policy, Violation};
use mas_storage::{oauth2::OAuth2ClientRepository, BoxClock, BoxRepository, BoxRng}; use mas_storage::{oauth2::OAuth2ClientRepository, BoxClock, BoxRepository, BoxRng};
use oauth2_types::{ use oauth2_types::{
errors::{ClientError, ClientErrorCode}, errors::{ClientError, ClientErrorCode},
@@ -49,7 +47,6 @@ pub(crate) enum RouteError {
impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_policy::LoadError); impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::InstantiateError);
impl_from_error_for_route!(mas_policy::EvaluationError); impl_from_error_for_route!(mas_policy::EvaluationError);
impl_from_error_for_route!(mas_keystore::aead::Error); impl_from_error_for_route!(mas_keystore::aead::Error);
@@ -136,7 +133,7 @@ pub(crate) async fn post(
mut rng: BoxRng, mut rng: BoxRng,
clock: BoxClock, clock: BoxClock,
mut repo: BoxRepository, mut repo: BoxRepository,
State(policy_factory): State<Arc<PolicyFactory>>, mut policy: Policy,
State(encrypter): State<Encrypter>, State(encrypter): State<Encrypter>,
body: Result<Json<ClientMetadata>, axum::extract::rejection::JsonRejection>, body: Result<Json<ClientMetadata>, axum::extract::rejection::JsonRejection>,
) -> Result<impl IntoResponse, RouteError> { ) -> Result<impl IntoResponse, RouteError> {
@@ -148,7 +145,6 @@ pub(crate) async fn post(
// Validate the body // Validate the body
let metadata = body.validate()?; let metadata = body.validate()?;
let mut policy = policy_factory.instantiate().await?;
let res = policy.evaluate_client_registration(&metadata).await?; let res = policy.evaluate_client_registration(&metadata).await?;
if !res.valid() { if !res.valid() {
return Err(RouteError::PolicyDenied(res.violations)); return Err(RouteError::PolicyDenied(res.violations));

View File

@@ -33,10 +33,10 @@ use hyper::{
use mas_axum_utils::{cookies::CookieManager, http_client_factory::HttpClientFactory}; use mas_axum_utils::{cookies::CookieManager, http_client_factory::HttpClientFactory};
use mas_keystore::{Encrypter, JsonWebKey, JsonWebKeySet, Keystore, PrivateKey}; use mas_keystore::{Encrypter, JsonWebKey, JsonWebKeySet, Keystore, PrivateKey};
use mas_matrix::{HomeserverConnection, MockHomeserverConnection}; use mas_matrix::{HomeserverConnection, MockHomeserverConnection};
use mas_policy::PolicyFactory; use mas_policy::{InstantiateError, Policy, PolicyFactory};
use mas_router::{SimpleRoute, UrlBuilder}; use mas_router::{SimpleRoute, UrlBuilder};
use mas_storage::{clock::MockClock, BoxClock, BoxRepository, BoxRng, Repository}; use mas_storage::{clock::MockClock, BoxClock, BoxRepository, BoxRng, Repository};
use mas_storage_pg::PgRepository; use mas_storage_pg::{DatabaseError, PgRepository};
use mas_templates::Templates; use mas_templates::Templates;
use rand::SeedableRng; use rand::SeedableRng;
use rand_chacha::ChaChaRng; use rand_chacha::ChaChaRng;
@@ -46,7 +46,7 @@ use tower::{Layer, Service, ServiceExt};
use url::Url; use url::Url;
use crate::{ use crate::{
app_state::RepositoryError, app_state::ErrorWrapper,
passwords::{Hasher, PasswordManager}, passwords::{Hasher, PasswordManager},
upstream_oauth2::cache::MetadataCache, upstream_oauth2::cache::MetadataCache,
MatrixHomeserver, MatrixHomeserver,
@@ -138,6 +138,7 @@ impl TestState {
let graphql_state = TestGraphQLState { let graphql_state = TestGraphQLState {
pool: pool.clone(), pool: pool.clone(),
policy_factory: Arc::clone(&policy_factory),
homeserver_connection, homeserver_connection,
rng: Arc::clone(&rng), rng: Arc::clone(&rng),
clock: Arc::clone(&clock), clock: Arc::clone(&clock),
@@ -202,7 +203,7 @@ impl TestState {
Response::from_parts(parts, body) Response::from_parts(parts, body)
} }
pub async fn repository(&self) -> Result<BoxRepository, RepositoryError> { pub async fn repository(&self) -> Result<BoxRepository, DatabaseError> {
let repo = PgRepository::from_pool(&self.pool).await?; let repo = PgRepository::from_pool(&self.pool).await?;
Ok(repo Ok(repo
.map_err(mas_storage::RepositoryError::from_error) .map_err(mas_storage::RepositoryError::from_error)
@@ -243,6 +244,7 @@ impl TestState {
struct TestGraphQLState { struct TestGraphQLState {
pool: PgPool, pool: PgPool,
homeserver_connection: MockHomeserverConnection, homeserver_connection: MockHomeserverConnection,
policy_factory: Arc<PolicyFactory>,
clock: Arc<MockClock>, clock: Arc<MockClock>,
rng: Arc<Mutex<ChaChaRng>>, rng: Arc<Mutex<ChaChaRng>>,
} }
@@ -259,6 +261,10 @@ impl mas_graphql::State for TestGraphQLState {
.boxed()) .boxed())
} }
async fn policy(&self) -> Result<Policy, InstantiateError> {
self.policy_factory.instantiate().await
}
fn homeserver_connection(&self) -> &dyn HomeserverConnection<Error = anyhow::Error> { fn homeserver_connection(&self) -> &dyn HomeserverConnection<Error = anyhow::Error> {
&self.homeserver_connection &self.homeserver_connection
} }
@@ -316,12 +322,6 @@ impl FromRef<TestState> for MatrixHomeserver {
} }
} }
impl FromRef<TestState> for Arc<PolicyFactory> {
fn from_ref(input: &TestState) -> Self {
input.policy_factory.clone()
}
}
impl FromRef<TestState> for HttpClientFactory { impl FromRef<TestState> for HttpClientFactory {
fn from_ref(input: &TestState) -> Self { fn from_ref(input: &TestState) -> Self {
input.http_client_factory.clone() input.http_client_factory.clone()
@@ -374,7 +374,7 @@ impl FromRequestParts<TestState> for BoxRng {
#[async_trait] #[async_trait]
impl FromRequestParts<TestState> for BoxRepository { impl FromRequestParts<TestState> for BoxRepository {
type Rejection = RepositoryError; type Rejection = ErrorWrapper<mas_storage_pg::DatabaseError>;
async fn from_request_parts( async fn from_request_parts(
_parts: &mut axum::http::request::Parts, _parts: &mut axum::http::request::Parts,
@@ -387,6 +387,19 @@ impl FromRequestParts<TestState> for BoxRepository {
} }
} }
#[async_trait]
impl FromRequestParts<TestState> for Policy {
type Rejection = ErrorWrapper<mas_policy::InstantiateError>;
async fn from_request_parts(
_parts: &mut axum::http::request::Parts,
state: &TestState,
) -> Result<Self, Self::Rejection> {
let policy = state.policy_factory.instantiate().await?;
Ok(policy)
}
}
pub(crate) trait RequestBuilderExt { pub(crate) trait RequestBuilderExt {
/// Builds the request with the given JSON value as body. /// Builds the request with the given JSON value as body.
fn json<T: Serialize>(self, body: T) -> hyper::Request<String>; fn json<T: Serialize>(self, body: T) -> hyper::Request<String>;

View File

@@ -21,13 +21,14 @@ use mas_axum_utils::{
csrf::{CsrfExt, ProtectedForm}, csrf::{CsrfExt, ProtectedForm},
FancyError, SessionInfoExt, FancyError, SessionInfoExt,
}; };
use mas_policy::Policy;
use mas_router::Route; use mas_router::Route;
use mas_storage::{ use mas_storage::{
job::{JobRepositoryExt, VerifyEmailJob}, job::{JobRepositoryExt, VerifyEmailJob},
user::UserEmailRepository, user::UserEmailRepository,
BoxClock, BoxRepository, BoxRng, BoxClock, BoxRepository, BoxRng,
}; };
use mas_templates::{EmailAddContext, TemplateContext, Templates}; use mas_templates::{EmailAddContext, ErrorContext, TemplateContext, Templates};
use serde::Deserialize; use serde::Deserialize;
use crate::views::shared::OptionalPostAuthAction; use crate::views::shared::OptionalPostAuthAction;
@@ -69,6 +70,7 @@ pub(crate) async fn post(
mut rng: BoxRng, mut rng: BoxRng,
clock: BoxClock, clock: BoxClock,
mut repo: BoxRepository, mut repo: BoxRepository,
mut policy: Policy,
cookie_jar: CookieJar, cookie_jar: CookieJar,
Query(query): Query<OptionalPostAuthAction>, Query(query): Query<OptionalPostAuthAction>,
Form(form): Form<ProtectedForm<EmailForm>>, Form(form): Form<ProtectedForm<EmailForm>>,
@@ -83,23 +85,56 @@ pub(crate) async fn post(
return Ok((cookie_jar, login.go()).into_response()); return Ok((cookie_jar, login.go()).into_response());
}; };
let user_email = repo // XXX: we really should show human readable errors on the form here
.user_email()
.add(&mut rng, &clock, &session.user, form.email)
.await?;
let next = mas_router::AccountVerifyEmail::new(user_email.id); // Validate the email address
let next = if let Some(action) = query.post_auth_action { if form.email.parse::<lettre::Address>().is_err() {
next.and_then(action) return Err(anyhow::anyhow!("Invalid email address").into());
}
// Run the email policy
let res = policy.evaluate_email(&form.email).await?;
if !res.valid() {
return Err(FancyError::new(
ErrorContext::new()
.with_description(format!("Email address {:?} denied by policy", form.email))
.with_details(format!("{res}")),
));
}
// Find an existing email address
let existing_user_email = repo.user_email().find(&session.user, &form.email).await?;
let user_email = if let Some(user_email) = existing_user_email {
user_email
} else { } else {
next let user_email = repo
.user_email()
.add(&mut rng, &clock, &session.user, form.email)
.await?;
user_email
}; };
repo.job() // If the email was not confirmed, send a confirmation email & redirect to the
.schedule_job(VerifyEmailJob::new(&user_email)) // verify page
.await?; let next = if user_email.confirmed_at.is_none() {
repo.job()
.schedule_job(VerifyEmailJob::new(&user_email))
.await?;
let next = mas_router::AccountVerifyEmail::new(user_email.id);
let next = if let Some(action) = query.post_auth_action {
next.and_then(action)
} else {
next
};
next.go()
} else {
query.go_next_or_default(&mas_router::Account)
};
repo.save().await?; repo.save().await?;
Ok((cookie_jar, next.go()).into_response()) Ok((cookie_jar, next).into_response())
} }

View File

@@ -24,6 +24,7 @@ use mas_axum_utils::{
FancyError, SessionInfoExt, FancyError, SessionInfoExt,
}; };
use mas_data_model::BrowserSession; use mas_data_model::BrowserSession;
use mas_policy::Policy;
use mas_router::Route; use mas_router::Route;
use mas_storage::{ use mas_storage::{
user::{BrowserSessionRepository, UserPasswordRepository}, user::{BrowserSessionRepository, UserPasswordRepository},
@@ -93,6 +94,7 @@ pub(crate) async fn post(
clock: BoxClock, clock: BoxClock,
State(password_manager): State<PasswordManager>, State(password_manager): State<PasswordManager>,
State(templates): State<Templates>, State(templates): State<Templates>,
mut policy: Policy,
mut repo: BoxRepository, mut repo: BoxRepository,
cookie_jar: CookieJar, cookie_jar: CookieJar,
Form(form): Form<ProtectedForm<ChangeForm>>, Form(form): Form<ProtectedForm<ChangeForm>>,
@@ -119,6 +121,13 @@ pub(crate) async fn post(
.await? .await?
.context("user has no password")?; .context("user has no password")?;
let res = policy.evaluate_password(&form.new_password).await?;
// TODO: display nice form errors
if !res.valid() {
return Err(anyhow::anyhow!("Password policy violation: {res}").into());
}
let password = Zeroizing::new(form.current_password.into_bytes()); let password = Zeroizing::new(form.current_password.into_bytes());
let new_password = Zeroizing::new(form.new_password.into_bytes()); let new_password = Zeroizing::new(form.new_password.into_bytes());
let new_password_confirm = Zeroizing::new(form.new_password_confirm.into_bytes()); let new_password_confirm = Zeroizing::new(form.new_password_confirm.into_bytes());
@@ -133,7 +142,7 @@ pub(crate) async fn post(
// TODO: display nice form errors // TODO: display nice form errors
if new_password != new_password_confirm { if new_password != new_password_confirm {
return Err(anyhow::anyhow!("password mismatch").into()); return Err(anyhow::anyhow!("Password mismatch").into());
} }
let (version, hashed_password) = password_manager.hash(&mut rng, new_password).await?; let (version, hashed_password) = password_manager.hash(&mut rng, new_password).await?;

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::{str::FromStr, sync::Arc}; use std::str::FromStr;
use axum::{ use axum::{
extract::{Form, Query, State}, extract::{Form, Query, State},
@@ -27,7 +27,7 @@ use mas_axum_utils::{
csrf::{CsrfExt, CsrfToken, ProtectedForm}, csrf::{CsrfExt, CsrfToken, ProtectedForm},
FancyError, SessionInfoExt, FancyError, SessionInfoExt,
}; };
use mas_policy::PolicyFactory; use mas_policy::Policy;
use mas_router::Route; use mas_router::Route;
use mas_storage::{ use mas_storage::{
job::{JobRepositoryExt, ProvisionUserJob, VerifyEmailJob}, job::{JobRepositoryExt, ProvisionUserJob, VerifyEmailJob},
@@ -101,8 +101,8 @@ pub(crate) async fn post(
mut rng: BoxRng, mut rng: BoxRng,
clock: BoxClock, clock: BoxClock,
State(password_manager): State<PasswordManager>, State(password_manager): State<PasswordManager>,
State(policy_factory): State<Arc<PolicyFactory>>,
State(templates): State<Templates>, State(templates): State<Templates>,
mut policy: Policy,
mut repo: BoxRepository, mut repo: BoxRepository,
Query(query): Query<OptionalPostAuthAction>, Query(query): Query<OptionalPostAuthAction>,
cookie_jar: CookieJar, cookie_jar: CookieJar,
@@ -148,7 +148,6 @@ pub(crate) async fn post(
state.add_error_on_field(RegisterFormField::PasswordConfirm, FieldError::Unspecified); state.add_error_on_field(RegisterFormField::PasswordConfirm, FieldError::Unspecified);
} }
let mut policy = policy_factory.instantiate().await?;
let res = policy let res = policy
.evaluate_register(&form.username, &form.password, &form.email) .evaluate_register(&form.username, &form.password, &form.email)
.await?; .await?;

View File

@@ -43,6 +43,8 @@ fn write_schema<T: JsonSchema>(out_dir: Option<&Path>, file: &str) {
writer.flush().expect("Failed to flush writer"); writer.flush().expect("Failed to flush writer");
} }
/// Write the input schemas to the output directory.
/// They are then used in rego files to type check the input.
fn main() { fn main() {
let output_root = std::env::var("OUT_DIR").map(PathBuf::from).ok(); let output_root = std::env::var("OUT_DIR").map(PathBuf::from).ok();
let output_root = output_root.as_deref(); let output_root = output_root.as_deref();

View File

@@ -16,6 +16,7 @@ use mas_data_model::{AuthorizationGrant, Client, User};
use oauth2_types::registration::VerifiedClientMetadata; use oauth2_types::registration::VerifiedClientMetadata;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
/// A single violation of a policy.
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]
#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] #[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]
pub struct Violation { pub struct Violation {
@@ -23,19 +24,37 @@ pub struct Violation {
pub field: Option<String>, pub field: Option<String>,
} }
/// The result of a policy evaluation.
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]
pub struct EvaluationResult { pub struct EvaluationResult {
#[serde(rename = "result")] #[serde(rename = "result")]
pub violations: Vec<Violation>, pub violations: Vec<Violation>,
} }
impl std::fmt::Display for EvaluationResult {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut first = true;
for violation in &self.violations {
if first {
first = false;
} else {
write!(f, ", ")?;
}
write!(f, "{}", violation.msg)?;
}
Ok(())
}
}
impl EvaluationResult { impl EvaluationResult {
/// Returns true if the policy evaluation was successful.
#[must_use] #[must_use]
pub fn valid(&self) -> bool { pub fn valid(&self) -> bool {
self.violations.is_empty() self.violations.is_empty()
} }
} }
/// Input for the user registration policy.
#[derive(Serialize, Debug)] #[derive(Serialize, Debug)]
#[serde(tag = "registration_method", rename_all = "snake_case")] #[serde(tag = "registration_method", rename_all = "snake_case")]
#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] #[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]
@@ -47,6 +66,7 @@ pub enum RegisterInput<'a> {
}, },
} }
/// Input for the client registration policy.
#[derive(Serialize, Debug)] #[derive(Serialize, Debug)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] #[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]
@@ -58,6 +78,7 @@ pub struct ClientRegistrationInput<'a> {
pub client_metadata: &'a VerifiedClientMetadata, pub client_metadata: &'a VerifiedClientMetadata,
} }
/// Input for the authorization grant policy.
#[derive(Serialize, Debug)] #[derive(Serialize, Debug)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] #[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]
@@ -81,6 +102,7 @@ pub struct AuthorizationGrantInput<'a> {
pub authorization_grant: &'a AuthorizationGrant, pub authorization_grant: &'a AuthorizationGrant,
} }
/// Input for the email add policy.
#[derive(Serialize, Debug)] #[derive(Serialize, Debug)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] #[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]
@@ -88,6 +110,7 @@ pub struct EmailInput<'a> {
pub email: &'a str, pub email: &'a str,
} }
/// Input for the password set policy.
#[derive(Serialize, Debug)] #[derive(Serialize, Debug)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] #[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]

View File

@@ -147,6 +147,8 @@
"authorization_grant_entrypoint": "authorization_grant/violation", "authorization_grant_entrypoint": "authorization_grant/violation",
"client_registration_entrypoint": "client_registration/violation", "client_registration_entrypoint": "client_registration/violation",
"data": null, "data": null,
"email_entrypoint": "email/violation",
"password_entrypoint": "password/violation",
"register_entrypoint": "register/violation", "register_entrypoint": "register/violation",
"wasm_module": "./policies/policy.wasm" "wasm_module": "./policies/policy.wasm"
}, },
@@ -1349,6 +1351,16 @@
"description": "Arbitrary data to pass to the policy", "description": "Arbitrary data to pass to the policy",
"default": null "default": null
}, },
"email_entrypoint": {
"description": "Entrypoint to use when adding an email address",
"default": "email/violation",
"type": "string"
},
"password_entrypoint": {
"description": "Entrypoint to use when changing password",
"default": "password/violation",
"type": "string"
},
"register_entrypoint": { "register_entrypoint": {
"description": "Entrypoint to use when evaluating user registrations", "description": "Entrypoint to use when evaluating user registrations",
"default": "register/violation", "default": "register/violation",

View File

@@ -28,6 +28,10 @@ type AddEmailPayload {
The user to whom the email address was added The user to whom the email address was added
""" """
user: User user: User
"""
The list of policy violations if the email address was denied
"""
violations: [String!]
} }
""" """
@@ -46,6 +50,10 @@ enum AddEmailStatus {
The email address is invalid The email address is invalid
""" """
INVALID INVALID
"""
The email address is not allowed by the policy
"""
DENIED
} }
type Anonymous implements Node { type Anonymous implements Node {

View File

@@ -30,6 +30,7 @@ const ADD_EMAIL_MUTATION = graphql(/* GraphQL */ `
mutation AddEmail($userId: ID!, $email: String!) { mutation AddEmail($userId: ID!, $email: String!) {
addEmail(input: { userId: $userId, email: $email }) { addEmail(input: { userId: $userId, email: $email }) {
status status
violations
email { email {
id id
...UserEmail_email ...UserEmail_email
@@ -79,6 +80,8 @@ const AddEmailForm: React.FC<{
const status = addEmailResult.data?.addEmail.status ?? null; const status = addEmailResult.data?.addEmail.status ?? null;
const emailExists = status === "EXISTS"; const emailExists = status === "EXISTS";
const emailInvalid = status === "INVALID"; const emailInvalid = status === "INVALID";
const emailDenied = status === "DENIED";
const violations = addEmailResult.data?.addEmail.violations ?? [];
return ( return (
<> <>
@@ -95,6 +98,17 @@ const AddEmailForm: React.FC<{
</Alert> </Alert>
)} )}
{emailDenied && (
<Alert type="critical" title="Email denied by policy">
The entered email is not allowed by the server policy.
<ul>
{violations.map((violation, index) => (
<li key={index}> {violation}</li>
))}
</ul>
</Alert>
)}
<Field name="email" className="my-2"> <Field name="email" className="my-2">
<Label>Add email</Label> <Label>Add email</Label>
<Control disabled={pending} inputMode="email" ref={fieldRef} /> <Control disabled={pending} inputMode="email" ref={fieldRef} />

View File

@@ -47,7 +47,7 @@ const documents = {
types.UserGreetingDocument, types.UserGreetingDocument,
"\n fragment UserHome_user on User {\n id\n\n primaryEmail {\n id\n ...UserEmail_email\n }\n\n confirmedEmails: emails(first: 0, state: CONFIRMED) {\n totalCount\n }\n\n unverifiedEmails: emails(first: 0, state: PENDING) {\n totalCount\n }\n\n browserSessions(first: 0, state: ACTIVE) {\n totalCount\n }\n\n oauth2Sessions(first: 0, state: ACTIVE) {\n totalCount\n }\n\n compatSessions(first: 0, state: ACTIVE) {\n totalCount\n }\n }\n": "\n fragment UserHome_user on User {\n id\n\n primaryEmail {\n id\n ...UserEmail_email\n }\n\n confirmedEmails: emails(first: 0, state: CONFIRMED) {\n totalCount\n }\n\n unverifiedEmails: emails(first: 0, state: PENDING) {\n totalCount\n }\n\n browserSessions(first: 0, state: ACTIVE) {\n totalCount\n }\n\n oauth2Sessions(first: 0, state: ACTIVE) {\n totalCount\n }\n\n compatSessions(first: 0, state: ACTIVE) {\n totalCount\n }\n }\n":
types.UserHome_UserFragmentDoc, types.UserHome_UserFragmentDoc,
"\n mutation AddEmail($userId: ID!, $email: String!) {\n addEmail(input: { userId: $userId, email: $email }) {\n status\n email {\n id\n ...UserEmail_email\n }\n }\n }\n": "\n mutation AddEmail($userId: ID!, $email: String!) {\n addEmail(input: { userId: $userId, email: $email }) {\n status\n violations\n email {\n id\n ...UserEmail_email\n }\n }\n }\n":
types.AddEmailDocument, types.AddEmailDocument,
"\n query UserEmailListQuery(\n $userId: ID!\n $first: Int\n $after: String\n $last: Int\n $before: String\n ) {\n user(id: $userId) {\n id\n\n emails(first: $first, after: $after, last: $last, before: $before) {\n edges {\n cursor\n node {\n id\n ...UserEmail_email\n }\n }\n totalCount\n pageInfo {\n hasNextPage\n hasPreviousPage\n startCursor\n endCursor\n }\n }\n }\n }\n": "\n query UserEmailListQuery(\n $userId: ID!\n $first: Int\n $after: String\n $last: Int\n $before: String\n ) {\n user(id: $userId) {\n id\n\n emails(first: $first, after: $after, last: $last, before: $before) {\n edges {\n cursor\n node {\n id\n ...UserEmail_email\n }\n }\n totalCount\n pageInfo {\n hasNextPage\n hasPreviousPage\n startCursor\n endCursor\n }\n }\n }\n }\n":
types.UserEmailListQueryDocument, types.UserEmailListQueryDocument,
@@ -191,8 +191,8 @@ export function graphql(
* The graphql function is used to parse GraphQL queries into a document that can be used by GraphQL clients. * The graphql function is used to parse GraphQL queries into a document that can be used by GraphQL clients.
*/ */
export function graphql( export function graphql(
source: "\n mutation AddEmail($userId: ID!, $email: String!) {\n addEmail(input: { userId: $userId, email: $email }) {\n status\n email {\n id\n ...UserEmail_email\n }\n }\n }\n", source: "\n mutation AddEmail($userId: ID!, $email: String!) {\n addEmail(input: { userId: $userId, email: $email }) {\n status\n violations\n email {\n id\n ...UserEmail_email\n }\n }\n }\n",
): (typeof documents)["\n mutation AddEmail($userId: ID!, $email: String!) {\n addEmail(input: { userId: $userId, email: $email }) {\n status\n email {\n id\n ...UserEmail_email\n }\n }\n }\n"]; ): (typeof documents)["\n mutation AddEmail($userId: ID!, $email: String!) {\n addEmail(input: { userId: $userId, email: $email }) {\n status\n violations\n email {\n id\n ...UserEmail_email\n }\n }\n }\n"];
/** /**
* The graphql function is used to parse GraphQL queries into a document that can be used by GraphQL clients. * The graphql function is used to parse GraphQL queries into a document that can be used by GraphQL clients.
*/ */

View File

@@ -54,12 +54,16 @@ export type AddEmailPayload = {
status: AddEmailStatus; status: AddEmailStatus;
/** The user to whom the email address was added */ /** The user to whom the email address was added */
user?: Maybe<User>; user?: Maybe<User>;
/** The list of policy violations if the email address was denied */
violations?: Maybe<Array<Scalars["String"]["output"]>>;
}; };
/** The status of the `addEmail` mutation */ /** The status of the `addEmail` mutation */
export enum AddEmailStatus { export enum AddEmailStatus {
/** The email address was added */ /** The email address was added */
Added = "ADDED", Added = "ADDED",
/** The email address is not allowed by the policy */
Denied = "DENIED",
/** The email address already exists */ /** The email address already exists */
Exists = "EXISTS", Exists = "EXISTS",
/** The email address is invalid */ /** The email address is invalid */
@@ -1231,6 +1235,7 @@ export type AddEmailMutation = {
addEmail: { addEmail: {
__typename?: "AddEmailPayload"; __typename?: "AddEmailPayload";
status: AddEmailStatus; status: AddEmailStatus;
violations?: Array<string> | null;
email?: email?:
| ({ __typename?: "UserEmail"; id: string } & { | ({ __typename?: "UserEmail"; id: string } & {
" $fragmentRefs"?: { " $fragmentRefs"?: {
@@ -3129,6 +3134,7 @@ export const AddEmailDocument = {
kind: "SelectionSet", kind: "SelectionSet",
selections: [ selections: [
{ kind: "Field", name: { kind: "Name", value: "status" } }, { kind: "Field", name: { kind: "Name", value: "status" } },
{ kind: "Field", name: { kind: "Name", value: "violations" } },
{ {
kind: "Field", kind: "Field",
name: { kind: "Name", value: "email" }, name: { kind: "Name", value: "email" },

View File

@@ -42,6 +42,20 @@ export default {
}, },
args: [], args: [],
}, },
{
name: "violations",
type: {
kind: "LIST",
ofType: {
kind: "NON_NULL",
ofType: {
kind: "SCALAR",
name: "Any",
},
},
},
args: [],
},
], ],
interfaces: [], interfaces: [],
}, },

View File

@@ -1,6 +1,7 @@
{ {
"$schema": "http://json-schema.org/draft-07/schema#", "$schema": "http://json-schema.org/draft-07/schema#",
"title": "AuthorizationGrantInput", "title": "AuthorizationGrantInput",
"description": "Input for the authorization grant policy.",
"type": "object", "type": "object",
"required": [ "required": [
"authorization_grant", "authorization_grant",

View File

@@ -1,6 +1,7 @@
{ {
"$schema": "http://json-schema.org/draft-07/schema#", "$schema": "http://json-schema.org/draft-07/schema#",
"title": "ClientRegistrationInput", "title": "ClientRegistrationInput",
"description": "Input for the client registration policy.",
"type": "object", "type": "object",
"required": [ "required": [
"client_metadata" "client_metadata"

View File

@@ -1,6 +1,7 @@
{ {
"$schema": "http://json-schema.org/draft-07/schema#", "$schema": "http://json-schema.org/draft-07/schema#",
"title": "EmailInput", "title": "EmailInput",
"description": "Input for the email add policy.",
"type": "object", "type": "object",
"required": [ "required": [
"email" "email"

View File

@@ -1,6 +1,7 @@
{ {
"$schema": "http://json-schema.org/draft-07/schema#", "$schema": "http://json-schema.org/draft-07/schema#",
"title": "PasswordInput", "title": "PasswordInput",
"description": "Input for the password set policy.",
"type": "object", "type": "object",
"required": [ "required": [
"password" "password"

View File

@@ -1,6 +1,7 @@
{ {
"$schema": "http://json-schema.org/draft-07/schema#", "$schema": "http://json-schema.org/draft-07/schema#",
"title": "RegisterInput", "title": "RegisterInput",
"description": "Input for the user registration policy.",
"oneOf": [ "oneOf": [
{ {
"type": "object", "type": "object",