1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-09 04:22:45 +03:00

Make sure we validate passwords & emails by the policy at all stages

Also refactors the way we get the policy engines in requests
This commit is contained in:
Quentin Gliech
2023-08-30 16:47:57 +02:00
parent 23151ef092
commit 7fcd022eea
30 changed files with 264 additions and 84 deletions

1
Cargo.lock generated
View File

@@ -2763,6 +2763,7 @@ dependencies = [
"lettre",
"mas-data-model",
"mas-matrix",
"mas-policy",
"mas-storage",
"oauth2-types",
"serde",

View File

@@ -23,6 +23,12 @@ pub struct FancyError {
context: ErrorContext,
}
impl FancyError {
pub fn new(context: ErrorContext) -> Self {
Self { context }
}
}
impl std::fmt::Display for FancyError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let code = self.context.code().unwrap_or("Internal error");

View File

@@ -143,7 +143,7 @@ impl Options {
// Listen for SIGHUP
register_sighup(&templates)?;
let graphql_schema = mas_handlers::graphql_schema(&pool, conn);
let graphql_schema = mas_handlers::graphql_schema(&pool, &policy_factory, conn);
let state = {
let mut s = AppState {

View File

@@ -22,6 +22,7 @@ url.workspace = true
oauth2-types = { path = "../oauth2-types" }
mas-data-model = { path = "../data-model" }
mas-matrix = { path = "../matrix" }
mas-policy = { path = "../policy" }
mas-storage = { path = "../storage" }
[[bin]]

View File

@@ -49,6 +49,8 @@ pub enum AddEmailStatus {
Exists,
/// The email address is invalid
Invalid,
/// The email address is not allowed by the policy
Denied,
}
/// The payload of the `addEmail` mutation
@@ -57,6 +59,9 @@ enum AddEmailPayload {
Added(mas_data_model::UserEmail),
Exists(mas_data_model::UserEmail),
Invalid,
Denied {
violations: Vec<mas_policy::Violation>,
},
}
#[Object(use_type_description)]
@@ -67,6 +72,7 @@ impl AddEmailPayload {
AddEmailPayload::Added(_) => AddEmailStatus::Added,
AddEmailPayload::Exists(_) => AddEmailStatus::Exists,
AddEmailPayload::Invalid => AddEmailStatus::Invalid,
AddEmailPayload::Denied { .. } => AddEmailStatus::Denied,
}
}
@@ -76,7 +82,7 @@ impl AddEmailPayload {
AddEmailPayload::Added(email) | AddEmailPayload::Exists(email) => {
Some(UserEmail(email.clone()))
}
AddEmailPayload::Invalid => None,
AddEmailPayload::Invalid | AddEmailPayload::Denied { .. } => None,
}
}
@@ -87,7 +93,7 @@ impl AddEmailPayload {
let user_id = match self {
AddEmailPayload::Added(email) | AddEmailPayload::Exists(email) => email.user_id,
AddEmailPayload::Invalid => return Ok(None),
AddEmailPayload::Invalid | AddEmailPayload::Denied { .. } => return Ok(None),
};
let user = repo
@@ -98,6 +104,16 @@ impl AddEmailPayload {
Ok(Some(User(user)))
}
/// The list of policy violations if the email address was denied
async fn violations(&self) -> Option<Vec<String>> {
let AddEmailPayload::Denied { violations } = self else {
return None;
};
let messages = violations.iter().map(|v| v.msg.clone()).collect();
Some(messages)
}
}
/// The input for the `sendVerificationEmail` mutation
@@ -382,6 +398,14 @@ impl UserEmailMutations {
return Ok(AddEmailPayload::Invalid);
}
let mut policy = state.policy().await?;
let res = policy.evaluate_email(&input.email).await?;
if !res.valid() {
return Ok(AddEmailPayload::Denied {
violations: res.violations,
});
}
// Find an existing email address
let existing_user_email = repo.user_email().find(&user, &input.email).await?;
let (added, user_email) = if let Some(user_email) = existing_user_email {

View File

@@ -13,6 +13,7 @@
// limitations under the License.
use mas_matrix::HomeserverConnection;
use mas_policy::Policy;
use mas_storage::{BoxClock, BoxRepository, BoxRng, RepositoryError};
use crate::Requester;
@@ -20,6 +21,7 @@ use crate::Requester;
#[async_trait::async_trait]
pub trait State {
async fn repository(&self) -> Result<BoxRepository, RepositoryError>;
async fn policy(&self) -> Result<Policy, mas_policy::InstantiateError>;
fn homeserver_connection(&self) -> &dyn HomeserverConnection<Error = anyhow::Error>;
fn clock(&self) -> BoxClock;
fn rng(&self) -> BoxRng;

View File

@@ -17,12 +17,12 @@ use std::{convert::Infallible, sync::Arc, time::Instant};
use axum::{
async_trait,
extract::{FromRef, FromRequestParts},
response::IntoResponse,
response::{IntoResponse, Response},
};
use hyper::StatusCode;
use mas_axum_utils::{cookies::CookieManager, http_client_factory::HttpClientFactory};
use mas_keystore::{Encrypter, Keystore};
use mas_policy::PolicyFactory;
use mas_policy::{Policy, PolicyFactory};
use mas_router::UrlBuilder;
use mas_storage::{BoxClock, BoxRepository, BoxRng, Repository, SystemClock};
use mas_storage_pg::PgRepository;
@@ -33,7 +33,6 @@ use opentelemetry::{
};
use rand::SeedableRng;
use sqlx::PgPool;
use thiserror::Error;
use crate::{passwords::PasswordManager, upstream_oauth2::cache::MetadataCache, MatrixHomeserver};
@@ -176,12 +175,6 @@ impl FromRef<AppState> for MatrixHomeserver {
}
}
impl FromRef<AppState> for Arc<PolicyFactory> {
fn from_ref(input: &AppState) -> Self {
input.policy_factory.clone()
}
}
impl FromRef<AppState> for HttpClientFactory {
fn from_ref(input: &AppState) -> Self {
input.http_client_factory.clone()
@@ -236,19 +229,41 @@ impl FromRequestParts<AppState> for BoxRng {
}
}
#[derive(Debug, Error)]
#[error(transparent)]
pub struct RepositoryError(#[from] mas_storage_pg::DatabaseError);
/// A simple wrapper around an error that implements [`IntoResponse`].
pub struct ErrorWrapper<T>(T);
impl IntoResponse for RepositoryError {
fn into_response(self) -> axum::response::Response {
impl<T> From<T> for ErrorWrapper<T> {
fn from(input: T) -> Self {
Self(input)
}
}
impl<T> IntoResponse for ErrorWrapper<T>
where
T: std::error::Error,
{
fn into_response(self) -> Response {
// TODO: make this a bit more user friendly
(StatusCode::INTERNAL_SERVER_ERROR, self.0.to_string()).into_response()
}
}
#[async_trait]
impl FromRequestParts<AppState> for Policy {
type Rejection = ErrorWrapper<mas_policy::InstantiateError>;
async fn from_request_parts(
_parts: &mut axum::http::request::Parts,
state: &AppState,
) -> Result<Self, Self::Rejection> {
let policy = state.policy_factory.instantiate().await?;
Ok(policy)
}
}
#[async_trait]
impl FromRequestParts<AppState> for BoxRepository {
type Rejection = RepositoryError;
type Rejection = ErrorWrapper<mas_storage_pg::DatabaseError>;
async fn from_request_parts(
_parts: &mut axum::http::request::Parts,

View File

@@ -31,6 +31,7 @@ use hyper::header::CACHE_CONTROL;
use mas_axum_utils::{cookies::CookieJar, FancyError, SessionInfo, SessionInfoExt};
use mas_graphql::{Requester, Schema};
use mas_matrix::HomeserverConnection;
use mas_policy::{InstantiateError, Policy, PolicyFactory};
use mas_storage::{
BoxClock, BoxRepository, BoxRng, Clock, Repository, RepositoryError, SystemClock,
};
@@ -48,6 +49,7 @@ mod tests;
struct GraphQLState {
pool: PgPool,
homeserver_connection: Arc<dyn HomeserverConnection<Error = anyhow::Error>>,
policy_factory: Arc<PolicyFactory>,
}
#[async_trait]
@@ -60,6 +62,10 @@ impl mas_graphql::State for GraphQLState {
Ok(repo.map_err(RepositoryError::from_error).boxed())
}
async fn policy(&self) -> Result<Policy, InstantiateError> {
self.policy_factory.instantiate().await
}
fn homeserver_connection(&self) -> &dyn HomeserverConnection<Error = anyhow::Error> {
self.homeserver_connection.as_ref()
}
@@ -81,10 +87,12 @@ impl mas_graphql::State for GraphQLState {
#[must_use]
pub fn schema(
pool: &PgPool,
policy_factory: &Arc<PolicyFactory>,
homeserver_connection: impl HomeserverConnection<Error = anyhow::Error> + 'static,
) -> Schema {
let state = GraphQLState {
pool: pool.clone(),
policy_factory: Arc::clone(policy_factory),
homeserver_connection: Arc::new(homeserver_connection),
};
let state: mas_graphql::BoxState = Box::new(state);

View File

@@ -30,7 +30,7 @@
clippy::let_with_type_underscore,
)]
use std::{convert::Infallible, sync::Arc, time::Duration};
use std::{convert::Infallible, time::Duration};
use axum::{
body::{Bytes, HttpBody},
@@ -50,7 +50,7 @@ use hyper::{
use mas_axum_utils::{cookies::CookieJar, FancyError};
use mas_http::CorsLayerExt;
use mas_keystore::{Encrypter, Keystore};
use mas_policy::PolicyFactory;
use mas_policy::Policy;
use mas_router::{Route, UrlBuilder};
use mas_storage::{BoxClock, BoxRepository, BoxRng};
use mas_templates::{ErrorContext, NotFoundContext, Templates};
@@ -166,12 +166,12 @@ where
S: Clone + Send + Sync + 'static,
Keystore: FromRef<S>,
UrlBuilder: FromRef<S>,
Arc<PolicyFactory>: FromRef<S>,
BoxRepository: FromRequestParts<S>,
Encrypter: FromRef<S>,
HttpClientFactory: FromRef<S>,
BoxClock: FromRequestParts<S>,
BoxRng: FromRequestParts<S>,
Policy: FromRequestParts<S>,
{
// All those routes are API-like, with a common CORS layer
Router::new()
@@ -267,7 +267,6 @@ where
<B as HttpBody>::Error: std::error::Error + Send + Sync,
S: Clone + Send + Sync + 'static,
UrlBuilder: FromRef<S>,
Arc<PolicyFactory>: FromRef<S>,
BoxRepository: FromRequestParts<S>,
CookieJar: FromRequestParts<S>,
Encrypter: FromRef<S>,
@@ -278,6 +277,7 @@ where
MetadataCache: FromRef<S>,
BoxClock: FromRequestParts<S>,
BoxRng: FromRequestParts<S>,
Policy: FromRequestParts<S>,
{
Router::new()
// XXX: hard-coded redirect from /account to /account/

View File

@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use axum::{
extract::{Path, State},
response::{Html, IntoResponse, Response},
@@ -22,7 +20,7 @@ use hyper::StatusCode;
use mas_axum_utils::{cookies::CookieJar, csrf::CsrfExt, SessionInfoExt};
use mas_data_model::{AuthorizationGrant, BrowserSession, Client, Device};
use mas_keystore::Keystore;
use mas_policy::{EvaluationResult, PolicyFactory};
use mas_policy::{EvaluationResult, Policy};
use mas_router::{PostAuthAction, Route, UrlBuilder};
use mas_storage::{
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2SessionRepository},
@@ -76,7 +74,6 @@ impl IntoResponse for RouteError {
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_templates::TemplateError);
impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::InstantiateError);
impl_from_error_for_route!(mas_policy::EvaluationError);
impl_from_error_for_route!(super::callback::IntoCallbackDestinationError);
impl_from_error_for_route!(super::callback::CallbackDestinationError);
@@ -90,10 +87,10 @@ impl_from_error_for_route!(super::callback::CallbackDestinationError);
pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(policy_factory): State<Arc<PolicyFactory>>,
State(templates): State<Templates>,
State(url_builder): State<UrlBuilder>,
State(key_store): State<Keystore>,
policy: Policy,
mut repo: BoxRepository,
cookie_jar: CookieJar,
Path(grant_id): Path<Ulid>,
@@ -128,7 +125,7 @@ pub(crate) async fn get(
&clock,
repo,
key_store,
&policy_factory,
policy,
url_builder,
grant,
&client,
@@ -187,7 +184,6 @@ pub enum GrantCompletionError {
impl_from_error_for_route!(GrantCompletionError: mas_storage::RepositoryError);
impl_from_error_for_route!(GrantCompletionError: super::callback::IntoCallbackDestinationError);
impl_from_error_for_route!(GrantCompletionError: mas_policy::LoadError);
impl_from_error_for_route!(GrantCompletionError: mas_policy::InstantiateError);
impl_from_error_for_route!(GrantCompletionError: mas_policy::EvaluationError);
impl_from_error_for_route!(GrantCompletionError: super::super::IdTokenSignatureError);
@@ -196,7 +192,7 @@ pub(crate) async fn complete(
clock: &impl Clock,
mut repo: BoxRepository,
key_store: Keystore,
policy_factory: &PolicyFactory,
mut policy: Policy,
url_builder: UrlBuilder,
grant: AuthorizationGrant,
client: &Client,
@@ -220,7 +216,6 @@ pub(crate) async fn complete(
};
// Run through the policy
let mut policy = policy_factory.instantiate().await?;
let res = policy
.evaluate_authorization_grant(&grant, client, &browser_session.user)
.await?;

View File

@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use axum::{
extract::{Form, State},
response::{Html, IntoResponse, Response},
@@ -22,7 +20,7 @@ use hyper::StatusCode;
use mas_axum_utils::{cookies::CookieJar, csrf::CsrfExt, SessionInfoExt};
use mas_data_model::{AuthorizationCode, Pkce};
use mas_keystore::Keystore;
use mas_policy::PolicyFactory;
use mas_policy::Policy;
use mas_router::{PostAuthAction, Route, UrlBuilder};
use mas_storage::{
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository},
@@ -94,7 +92,6 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_templates::TemplateError);
impl_from_error_for_route!(self::callback::CallbackDestinationError);
impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::InstantiateError);
impl_from_error_for_route!(mas_policy::EvaluationError);
#[derive(Deserialize)]
@@ -140,10 +137,10 @@ fn resolve_response_mode(
pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(policy_factory): State<Arc<PolicyFactory>>,
State(templates): State<Templates>,
State(key_store): State<Keystore>,
State(url_builder): State<UrlBuilder>,
policy: Policy,
mut repo: BoxRepository,
cookie_jar: CookieJar,
Form(params): Form<Params>,
@@ -346,7 +343,7 @@ pub(crate) async fn get(
&clock,
repo,
key_store,
&policy_factory,
policy,
url_builder,
grant,
&client,
@@ -393,7 +390,7 @@ pub(crate) async fn get(
&clock,
repo,
key_store,
&policy_factory,
policy,
url_builder,
grant,
&client,

View File

@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use axum::{
extract::{Form, Path, State},
response::{Html, IntoResponse, Response},
@@ -25,7 +23,7 @@ use mas_axum_utils::{
SessionInfoExt,
};
use mas_data_model::{AuthorizationGrantStage, Device};
use mas_policy::PolicyFactory;
use mas_policy::Policy;
use mas_router::{PostAuthAction, Route};
use mas_storage::{
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository},
@@ -61,7 +59,6 @@ pub enum RouteError {
impl_from_error_for_route!(mas_templates::TemplateError);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::InstantiateError);
impl_from_error_for_route!(mas_policy::EvaluationError);
impl IntoResponse for RouteError {
@@ -80,8 +77,8 @@ impl IntoResponse for RouteError {
pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(policy_factory): State<Arc<PolicyFactory>>,
State(templates): State<Templates>,
mut policy: Policy,
mut repo: BoxRepository,
cookie_jar: CookieJar,
Path(grant_id): Path<Ulid>,
@@ -109,7 +106,6 @@ pub(crate) async fn get(
if let Some(session) = maybe_session {
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let mut policy = policy_factory.instantiate().await?;
let res = policy
.evaluate_authorization_grant(&grant, &client, &session.user)
.await?;
@@ -146,7 +142,7 @@ pub(crate) async fn get(
pub(crate) async fn post(
mut rng: BoxRng,
clock: BoxClock,
State(policy_factory): State<Arc<PolicyFactory>>,
mut policy: Policy,
mut repo: BoxRepository,
cookie_jar: CookieJar,
Path(grant_id): Path<Ulid>,
@@ -176,7 +172,6 @@ pub(crate) async fn post(
.await?
.ok_or(RouteError::NoSuchClient)?;
let mut policy = policy_factory.instantiate().await?;
let res = policy
.evaluate_authorization_grant(&grant, &client, &session.user)
.await?;

View File

@@ -12,13 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use axum::{extract::State, response::IntoResponse, Json};
use hyper::StatusCode;
use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_keystore::Encrypter;
use mas_policy::{PolicyFactory, Violation};
use mas_policy::{Policy, Violation};
use mas_storage::{oauth2::OAuth2ClientRepository, BoxClock, BoxRepository, BoxRng};
use oauth2_types::{
errors::{ClientError, ClientErrorCode},
@@ -49,7 +47,6 @@ pub(crate) enum RouteError {
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::InstantiateError);
impl_from_error_for_route!(mas_policy::EvaluationError);
impl_from_error_for_route!(mas_keystore::aead::Error);
@@ -136,7 +133,7 @@ pub(crate) async fn post(
mut rng: BoxRng,
clock: BoxClock,
mut repo: BoxRepository,
State(policy_factory): State<Arc<PolicyFactory>>,
mut policy: Policy,
State(encrypter): State<Encrypter>,
body: Result<Json<ClientMetadata>, axum::extract::rejection::JsonRejection>,
) -> Result<impl IntoResponse, RouteError> {
@@ -148,7 +145,6 @@ pub(crate) async fn post(
// Validate the body
let metadata = body.validate()?;
let mut policy = policy_factory.instantiate().await?;
let res = policy.evaluate_client_registration(&metadata).await?;
if !res.valid() {
return Err(RouteError::PolicyDenied(res.violations));

View File

@@ -33,10 +33,10 @@ use hyper::{
use mas_axum_utils::{cookies::CookieManager, http_client_factory::HttpClientFactory};
use mas_keystore::{Encrypter, JsonWebKey, JsonWebKeySet, Keystore, PrivateKey};
use mas_matrix::{HomeserverConnection, MockHomeserverConnection};
use mas_policy::PolicyFactory;
use mas_policy::{InstantiateError, Policy, PolicyFactory};
use mas_router::{SimpleRoute, UrlBuilder};
use mas_storage::{clock::MockClock, BoxClock, BoxRepository, BoxRng, Repository};
use mas_storage_pg::PgRepository;
use mas_storage_pg::{DatabaseError, PgRepository};
use mas_templates::Templates;
use rand::SeedableRng;
use rand_chacha::ChaChaRng;
@@ -46,7 +46,7 @@ use tower::{Layer, Service, ServiceExt};
use url::Url;
use crate::{
app_state::RepositoryError,
app_state::ErrorWrapper,
passwords::{Hasher, PasswordManager},
upstream_oauth2::cache::MetadataCache,
MatrixHomeserver,
@@ -138,6 +138,7 @@ impl TestState {
let graphql_state = TestGraphQLState {
pool: pool.clone(),
policy_factory: Arc::clone(&policy_factory),
homeserver_connection,
rng: Arc::clone(&rng),
clock: Arc::clone(&clock),
@@ -202,7 +203,7 @@ impl TestState {
Response::from_parts(parts, body)
}
pub async fn repository(&self) -> Result<BoxRepository, RepositoryError> {
pub async fn repository(&self) -> Result<BoxRepository, DatabaseError> {
let repo = PgRepository::from_pool(&self.pool).await?;
Ok(repo
.map_err(mas_storage::RepositoryError::from_error)
@@ -243,6 +244,7 @@ impl TestState {
struct TestGraphQLState {
pool: PgPool,
homeserver_connection: MockHomeserverConnection,
policy_factory: Arc<PolicyFactory>,
clock: Arc<MockClock>,
rng: Arc<Mutex<ChaChaRng>>,
}
@@ -259,6 +261,10 @@ impl mas_graphql::State for TestGraphQLState {
.boxed())
}
async fn policy(&self) -> Result<Policy, InstantiateError> {
self.policy_factory.instantiate().await
}
fn homeserver_connection(&self) -> &dyn HomeserverConnection<Error = anyhow::Error> {
&self.homeserver_connection
}
@@ -316,12 +322,6 @@ impl FromRef<TestState> for MatrixHomeserver {
}
}
impl FromRef<TestState> for Arc<PolicyFactory> {
fn from_ref(input: &TestState) -> Self {
input.policy_factory.clone()
}
}
impl FromRef<TestState> for HttpClientFactory {
fn from_ref(input: &TestState) -> Self {
input.http_client_factory.clone()
@@ -374,7 +374,7 @@ impl FromRequestParts<TestState> for BoxRng {
#[async_trait]
impl FromRequestParts<TestState> for BoxRepository {
type Rejection = RepositoryError;
type Rejection = ErrorWrapper<mas_storage_pg::DatabaseError>;
async fn from_request_parts(
_parts: &mut axum::http::request::Parts,
@@ -387,6 +387,19 @@ impl FromRequestParts<TestState> for BoxRepository {
}
}
#[async_trait]
impl FromRequestParts<TestState> for Policy {
type Rejection = ErrorWrapper<mas_policy::InstantiateError>;
async fn from_request_parts(
_parts: &mut axum::http::request::Parts,
state: &TestState,
) -> Result<Self, Self::Rejection> {
let policy = state.policy_factory.instantiate().await?;
Ok(policy)
}
}
pub(crate) trait RequestBuilderExt {
/// Builds the request with the given JSON value as body.
fn json<T: Serialize>(self, body: T) -> hyper::Request<String>;

View File

@@ -21,13 +21,14 @@ use mas_axum_utils::{
csrf::{CsrfExt, ProtectedForm},
FancyError, SessionInfoExt,
};
use mas_policy::Policy;
use mas_router::Route;
use mas_storage::{
job::{JobRepositoryExt, VerifyEmailJob},
user::UserEmailRepository,
BoxClock, BoxRepository, BoxRng,
};
use mas_templates::{EmailAddContext, TemplateContext, Templates};
use mas_templates::{EmailAddContext, ErrorContext, TemplateContext, Templates};
use serde::Deserialize;
use crate::views::shared::OptionalPostAuthAction;
@@ -69,6 +70,7 @@ pub(crate) async fn post(
mut rng: BoxRng,
clock: BoxClock,
mut repo: BoxRepository,
mut policy: Policy,
cookie_jar: CookieJar,
Query(query): Query<OptionalPostAuthAction>,
Form(form): Form<ProtectedForm<EmailForm>>,
@@ -83,23 +85,56 @@ pub(crate) async fn post(
return Ok((cookie_jar, login.go()).into_response());
};
let user_email = repo
.user_email()
.add(&mut rng, &clock, &session.user, form.email)
.await?;
// XXX: we really should show human readable errors on the form here
let next = mas_router::AccountVerifyEmail::new(user_email.id);
let next = if let Some(action) = query.post_auth_action {
next.and_then(action)
// Validate the email address
if form.email.parse::<lettre::Address>().is_err() {
return Err(anyhow::anyhow!("Invalid email address").into());
}
// Run the email policy
let res = policy.evaluate_email(&form.email).await?;
if !res.valid() {
return Err(FancyError::new(
ErrorContext::new()
.with_description(format!("Email address {:?} denied by policy", form.email))
.with_details(format!("{res}")),
));
}
// Find an existing email address
let existing_user_email = repo.user_email().find(&session.user, &form.email).await?;
let user_email = if let Some(user_email) = existing_user_email {
user_email
} else {
next
let user_email = repo
.user_email()
.add(&mut rng, &clock, &session.user, form.email)
.await?;
user_email
};
repo.job()
.schedule_job(VerifyEmailJob::new(&user_email))
.await?;
// If the email was not confirmed, send a confirmation email & redirect to the
// verify page
let next = if user_email.confirmed_at.is_none() {
repo.job()
.schedule_job(VerifyEmailJob::new(&user_email))
.await?;
let next = mas_router::AccountVerifyEmail::new(user_email.id);
let next = if let Some(action) = query.post_auth_action {
next.and_then(action)
} else {
next
};
next.go()
} else {
query.go_next_or_default(&mas_router::Account)
};
repo.save().await?;
Ok((cookie_jar, next.go()).into_response())
Ok((cookie_jar, next).into_response())
}

View File

@@ -24,6 +24,7 @@ use mas_axum_utils::{
FancyError, SessionInfoExt,
};
use mas_data_model::BrowserSession;
use mas_policy::Policy;
use mas_router::Route;
use mas_storage::{
user::{BrowserSessionRepository, UserPasswordRepository},
@@ -93,6 +94,7 @@ pub(crate) async fn post(
clock: BoxClock,
State(password_manager): State<PasswordManager>,
State(templates): State<Templates>,
mut policy: Policy,
mut repo: BoxRepository,
cookie_jar: CookieJar,
Form(form): Form<ProtectedForm<ChangeForm>>,
@@ -119,6 +121,13 @@ pub(crate) async fn post(
.await?
.context("user has no password")?;
let res = policy.evaluate_password(&form.new_password).await?;
// TODO: display nice form errors
if !res.valid() {
return Err(anyhow::anyhow!("Password policy violation: {res}").into());
}
let password = Zeroizing::new(form.current_password.into_bytes());
let new_password = Zeroizing::new(form.new_password.into_bytes());
let new_password_confirm = Zeroizing::new(form.new_password_confirm.into_bytes());
@@ -133,7 +142,7 @@ pub(crate) async fn post(
// TODO: display nice form errors
if new_password != new_password_confirm {
return Err(anyhow::anyhow!("password mismatch").into());
return Err(anyhow::anyhow!("Password mismatch").into());
}
let (version, hashed_password) = password_manager.hash(&mut rng, new_password).await?;

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{str::FromStr, sync::Arc};
use std::str::FromStr;
use axum::{
extract::{Form, Query, State},
@@ -27,7 +27,7 @@ use mas_axum_utils::{
csrf::{CsrfExt, CsrfToken, ProtectedForm},
FancyError, SessionInfoExt,
};
use mas_policy::PolicyFactory;
use mas_policy::Policy;
use mas_router::Route;
use mas_storage::{
job::{JobRepositoryExt, ProvisionUserJob, VerifyEmailJob},
@@ -101,8 +101,8 @@ pub(crate) async fn post(
mut rng: BoxRng,
clock: BoxClock,
State(password_manager): State<PasswordManager>,
State(policy_factory): State<Arc<PolicyFactory>>,
State(templates): State<Templates>,
mut policy: Policy,
mut repo: BoxRepository,
Query(query): Query<OptionalPostAuthAction>,
cookie_jar: CookieJar,
@@ -148,7 +148,6 @@ pub(crate) async fn post(
state.add_error_on_field(RegisterFormField::PasswordConfirm, FieldError::Unspecified);
}
let mut policy = policy_factory.instantiate().await?;
let res = policy
.evaluate_register(&form.username, &form.password, &form.email)
.await?;

View File

@@ -43,6 +43,8 @@ fn write_schema<T: JsonSchema>(out_dir: Option<&Path>, file: &str) {
writer.flush().expect("Failed to flush writer");
}
/// Write the input schemas to the output directory.
/// They are then used in rego files to type check the input.
fn main() {
let output_root = std::env::var("OUT_DIR").map(PathBuf::from).ok();
let output_root = output_root.as_deref();

View File

@@ -16,6 +16,7 @@ use mas_data_model::{AuthorizationGrant, Client, User};
use oauth2_types::registration::VerifiedClientMetadata;
use serde::{Deserialize, Serialize};
/// A single violation of a policy.
#[derive(Deserialize, Debug)]
#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]
pub struct Violation {
@@ -23,19 +24,37 @@ pub struct Violation {
pub field: Option<String>,
}
/// The result of a policy evaluation.
#[derive(Deserialize, Debug)]
pub struct EvaluationResult {
#[serde(rename = "result")]
pub violations: Vec<Violation>,
}
impl std::fmt::Display for EvaluationResult {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut first = true;
for violation in &self.violations {
if first {
first = false;
} else {
write!(f, ", ")?;
}
write!(f, "{}", violation.msg)?;
}
Ok(())
}
}
impl EvaluationResult {
/// Returns true if the policy evaluation was successful.
#[must_use]
pub fn valid(&self) -> bool {
self.violations.is_empty()
}
}
/// Input for the user registration policy.
#[derive(Serialize, Debug)]
#[serde(tag = "registration_method", rename_all = "snake_case")]
#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]
@@ -47,6 +66,7 @@ pub enum RegisterInput<'a> {
},
}
/// Input for the client registration policy.
#[derive(Serialize, Debug)]
#[serde(rename_all = "snake_case")]
#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]
@@ -58,6 +78,7 @@ pub struct ClientRegistrationInput<'a> {
pub client_metadata: &'a VerifiedClientMetadata,
}
/// Input for the authorization grant policy.
#[derive(Serialize, Debug)]
#[serde(rename_all = "snake_case")]
#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]
@@ -81,6 +102,7 @@ pub struct AuthorizationGrantInput<'a> {
pub authorization_grant: &'a AuthorizationGrant,
}
/// Input for the email add policy.
#[derive(Serialize, Debug)]
#[serde(rename_all = "snake_case")]
#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]
@@ -88,6 +110,7 @@ pub struct EmailInput<'a> {
pub email: &'a str,
}
/// Input for the password set policy.
#[derive(Serialize, Debug)]
#[serde(rename_all = "snake_case")]
#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]

View File

@@ -147,6 +147,8 @@
"authorization_grant_entrypoint": "authorization_grant/violation",
"client_registration_entrypoint": "client_registration/violation",
"data": null,
"email_entrypoint": "email/violation",
"password_entrypoint": "password/violation",
"register_entrypoint": "register/violation",
"wasm_module": "./policies/policy.wasm"
},
@@ -1349,6 +1351,16 @@
"description": "Arbitrary data to pass to the policy",
"default": null
},
"email_entrypoint": {
"description": "Entrypoint to use when adding an email address",
"default": "email/violation",
"type": "string"
},
"password_entrypoint": {
"description": "Entrypoint to use when changing password",
"default": "password/violation",
"type": "string"
},
"register_entrypoint": {
"description": "Entrypoint to use when evaluating user registrations",
"default": "register/violation",

View File

@@ -28,6 +28,10 @@ type AddEmailPayload {
The user to whom the email address was added
"""
user: User
"""
The list of policy violations if the email address was denied
"""
violations: [String!]
}
"""
@@ -46,6 +50,10 @@ enum AddEmailStatus {
The email address is invalid
"""
INVALID
"""
The email address is not allowed by the policy
"""
DENIED
}
type Anonymous implements Node {

View File

@@ -30,6 +30,7 @@ const ADD_EMAIL_MUTATION = graphql(/* GraphQL */ `
mutation AddEmail($userId: ID!, $email: String!) {
addEmail(input: { userId: $userId, email: $email }) {
status
violations
email {
id
...UserEmail_email
@@ -79,6 +80,8 @@ const AddEmailForm: React.FC<{
const status = addEmailResult.data?.addEmail.status ?? null;
const emailExists = status === "EXISTS";
const emailInvalid = status === "INVALID";
const emailDenied = status === "DENIED";
const violations = addEmailResult.data?.addEmail.violations ?? [];
return (
<>
@@ -95,6 +98,17 @@ const AddEmailForm: React.FC<{
</Alert>
)}
{emailDenied && (
<Alert type="critical" title="Email denied by policy">
The entered email is not allowed by the server policy.
<ul>
{violations.map((violation, index) => (
<li key={index}> {violation}</li>
))}
</ul>
</Alert>
)}
<Field name="email" className="my-2">
<Label>Add email</Label>
<Control disabled={pending} inputMode="email" ref={fieldRef} />

View File

@@ -47,7 +47,7 @@ const documents = {
types.UserGreetingDocument,
"\n fragment UserHome_user on User {\n id\n\n primaryEmail {\n id\n ...UserEmail_email\n }\n\n confirmedEmails: emails(first: 0, state: CONFIRMED) {\n totalCount\n }\n\n unverifiedEmails: emails(first: 0, state: PENDING) {\n totalCount\n }\n\n browserSessions(first: 0, state: ACTIVE) {\n totalCount\n }\n\n oauth2Sessions(first: 0, state: ACTIVE) {\n totalCount\n }\n\n compatSessions(first: 0, state: ACTIVE) {\n totalCount\n }\n }\n":
types.UserHome_UserFragmentDoc,
"\n mutation AddEmail($userId: ID!, $email: String!) {\n addEmail(input: { userId: $userId, email: $email }) {\n status\n email {\n id\n ...UserEmail_email\n }\n }\n }\n":
"\n mutation AddEmail($userId: ID!, $email: String!) {\n addEmail(input: { userId: $userId, email: $email }) {\n status\n violations\n email {\n id\n ...UserEmail_email\n }\n }\n }\n":
types.AddEmailDocument,
"\n query UserEmailListQuery(\n $userId: ID!\n $first: Int\n $after: String\n $last: Int\n $before: String\n ) {\n user(id: $userId) {\n id\n\n emails(first: $first, after: $after, last: $last, before: $before) {\n edges {\n cursor\n node {\n id\n ...UserEmail_email\n }\n }\n totalCount\n pageInfo {\n hasNextPage\n hasPreviousPage\n startCursor\n endCursor\n }\n }\n }\n }\n":
types.UserEmailListQueryDocument,
@@ -191,8 +191,8 @@ export function graphql(
* The graphql function is used to parse GraphQL queries into a document that can be used by GraphQL clients.
*/
export function graphql(
source: "\n mutation AddEmail($userId: ID!, $email: String!) {\n addEmail(input: { userId: $userId, email: $email }) {\n status\n email {\n id\n ...UserEmail_email\n }\n }\n }\n",
): (typeof documents)["\n mutation AddEmail($userId: ID!, $email: String!) {\n addEmail(input: { userId: $userId, email: $email }) {\n status\n email {\n id\n ...UserEmail_email\n }\n }\n }\n"];
source: "\n mutation AddEmail($userId: ID!, $email: String!) {\n addEmail(input: { userId: $userId, email: $email }) {\n status\n violations\n email {\n id\n ...UserEmail_email\n }\n }\n }\n",
): (typeof documents)["\n mutation AddEmail($userId: ID!, $email: String!) {\n addEmail(input: { userId: $userId, email: $email }) {\n status\n violations\n email {\n id\n ...UserEmail_email\n }\n }\n }\n"];
/**
* The graphql function is used to parse GraphQL queries into a document that can be used by GraphQL clients.
*/

View File

@@ -54,12 +54,16 @@ export type AddEmailPayload = {
status: AddEmailStatus;
/** The user to whom the email address was added */
user?: Maybe<User>;
/** The list of policy violations if the email address was denied */
violations?: Maybe<Array<Scalars["String"]["output"]>>;
};
/** The status of the `addEmail` mutation */
export enum AddEmailStatus {
/** The email address was added */
Added = "ADDED",
/** The email address is not allowed by the policy */
Denied = "DENIED",
/** The email address already exists */
Exists = "EXISTS",
/** The email address is invalid */
@@ -1231,6 +1235,7 @@ export type AddEmailMutation = {
addEmail: {
__typename?: "AddEmailPayload";
status: AddEmailStatus;
violations?: Array<string> | null;
email?:
| ({ __typename?: "UserEmail"; id: string } & {
" $fragmentRefs"?: {
@@ -3129,6 +3134,7 @@ export const AddEmailDocument = {
kind: "SelectionSet",
selections: [
{ kind: "Field", name: { kind: "Name", value: "status" } },
{ kind: "Field", name: { kind: "Name", value: "violations" } },
{
kind: "Field",
name: { kind: "Name", value: "email" },

View File

@@ -42,6 +42,20 @@ export default {
},
args: [],
},
{
name: "violations",
type: {
kind: "LIST",
ofType: {
kind: "NON_NULL",
ofType: {
kind: "SCALAR",
name: "Any",
},
},
},
args: [],
},
],
interfaces: [],
},

View File

@@ -1,6 +1,7 @@
{
"$schema": "http://json-schema.org/draft-07/schema#",
"title": "AuthorizationGrantInput",
"description": "Input for the authorization grant policy.",
"type": "object",
"required": [
"authorization_grant",

View File

@@ -1,6 +1,7 @@
{
"$schema": "http://json-schema.org/draft-07/schema#",
"title": "ClientRegistrationInput",
"description": "Input for the client registration policy.",
"type": "object",
"required": [
"client_metadata"

View File

@@ -1,6 +1,7 @@
{
"$schema": "http://json-schema.org/draft-07/schema#",
"title": "EmailInput",
"description": "Input for the email add policy.",
"type": "object",
"required": [
"email"

View File

@@ -1,6 +1,7 @@
{
"$schema": "http://json-schema.org/draft-07/schema#",
"title": "PasswordInput",
"description": "Input for the password set policy.",
"type": "object",
"required": [
"password"

View File

@@ -1,6 +1,7 @@
{
"$schema": "http://json-schema.org/draft-07/schema#",
"title": "RegisterInput",
"description": "Input for the user registration policy.",
"oneOf": [
{
"type": "object",