diff --git a/crates/config/src/sections/policy.rs b/crates/config/src/sections/policy.rs index 04742c83..5bdd6a0d 100644 --- a/crates/config/src/sections/policy.rs +++ b/crates/config/src/sections/policy.rs @@ -22,15 +22,15 @@ use serde_with::serde_as; use super::ConfigurationSection; fn default_client_registration_endpoint() -> String { - "client_registration/allow".to_string() + "client_registration/violation".to_string() } fn default_login_endpoint() -> String { - "login/allow".to_string() + "login/violation".to_string() } fn default_register_endpoint() -> String { - "register/allow".to_string() + "register/violation".to_string() } /// Application secrets diff --git a/crates/handlers/src/oauth2/registration.rs b/crates/handlers/src/oauth2/registration.rs index b07e2f75..18bcef0d 100644 --- a/crates/handlers/src/oauth2/registration.rs +++ b/crates/handlers/src/oauth2/registration.rs @@ -119,8 +119,8 @@ pub(crate) async fn post( } let mut policy = policy_factory.instantiate().await?; - let allowed = policy.evaluate_client_registration(&body).await?; - if !allowed { + let res = policy.evaluate_client_registration(&body).await?; + if !res.valid() { return Err(RouteError::PolicyDenied); } diff --git a/crates/handlers/src/views/register.rs b/crates/handlers/src/views/register.rs index c8d7a255..77c02693 100644 --- a/crates/handlers/src/views/register.rs +++ b/crates/handlers/src/views/register.rs @@ -136,8 +136,24 @@ pub(crate) async fn post( .evaluate_register(&form.username, &form.email) .await?; - if !res { - state.add_error_on_form(FormError::Policy); + for violation in res.violations { + match violation.field.as_deref() { + Some("email") => state.add_error_on_field( + RegisterFormField::Email, + FieldError::Policy { + message: violation.msg, + }, + ), + Some("username") => state.add_error_on_field( + RegisterFormField::Username, + FieldError::Policy { + message: violation.msg, + }, + ), + _ => state.add_error_on_form(FormError::Policy { + message: violation.msg, + }), + } } state diff --git a/crates/policy/policies/Makefile b/crates/policy/policies/Makefile index 7a4aa757..199cc3a6 100644 --- a/crates/policy/policies/Makefile +++ b/crates/policy/policies/Makefile @@ -3,19 +3,21 @@ DOCKER := 0 ifeq ($(DOCKER), 0) OPA := opa + OPA_RW := opa else - OPA := docker run -v $(shell pwd):/policies -w /policies --rm docker.io/openpolicyagent/opa:0.40.0 + OPA := docker run -v $(shell pwd):/policies:ro -w /policies --rm docker.io/openpolicyagent/opa:0.40.0 + OPA_RW := docker run -v $(shell pwd):/policies -w /policies --rm docker.io/openpolicyagent/opa:0.40.0 endif policy.wasm: client_registration.rego login.rego register.rego - $(OPA) build -t wasm -e "client_registration/allow" -e "login/allow" -e "register/allow" $^ + $(OPA_RW) build -t wasm -e "client_registration/violation" -e "login/violation" -e "register/violation" $^ tar xzf bundle.tar.gz /policy.wasm $(RM) bundle.tar.gz touch $@ .PHONY: fmt fmt: - $(OPA) fmt -w . + $(OPA_RW) fmt -w . .PHONY: test test: diff --git a/crates/policy/policies/client_registration.rego b/crates/policy/policies/client_registration.rego index 3f0acc07..ad8e24ce 100644 --- a/crates/policy/policies/client_registration.rego +++ b/crates/policy/policies/client_registration.rego @@ -2,17 +2,54 @@ package client_registration import future.keywords.in +default allow := false + +allow { + count(violation) == 0 +} + secure_url(x) { is_string(x) startswith(x, "https://") } -default allow := false - -allow { - secure_url(input.client_metadata.client_uri) - secure_url(input.client_metadata.tos_uri) - secure_url(input.client_metadata.policy_uri) - some redirect_uri in input.client_metadata.redirect_uris - secure_url(redirect_uri) +violation[{"msg": "missing client_uri"}] { + not input.client_metadata.client_uri +} + +violation[{"msg": "invalid client_uri"}] { + not secure_url(input.client_metadata.client_uri) +} + +violation[{"msg": "missing tos_uri"}] { + not input.client_metadata.tos_uri +} + +violation[{"msg": "invalid tos_uri"}] { + not secure_url(input.client_metadata.tos_uri) +} + +violation[{"msg": "missing policy_uri"}] { + not input.client_metadata.policy_uri +} + +violation[{"msg": "invalid policy_uri"}] { + not secure_url(input.client_metadata.policy_uri) +} + +violation[{"msg": "missing redirect_uris"}] { + not input.client_metadata.redirect_uris +} + +violation[{"msg": "invalid redirect_uris"}] { + not is_array(input.client_metadata.redirect_uris) +} + +violation[{"msg": "empty redirect_uris"}] { + count(input.client_metadata.redirect_uris) == 0 +} + +violation[{"msg": "invalid redirect_uri"}] { + some redirect_uri in input.client_metadata.redirect_uris + not secure_url(redirect_uri) } diff --git a/crates/policy/policies/login.rego b/crates/policy/policies/login.rego index 8154fff2..94dac6df 100644 --- a/crates/policy/policies/login.rego +++ b/crates/policy/policies/login.rego @@ -1,3 +1,3 @@ package login -allow := true +violation := [] diff --git a/crates/policy/src/lib.rs b/crates/policy/src/lib.rs index e39a52b7..86835622 100644 --- a/crates/policy/src/lib.rs +++ b/crates/policy/src/lib.rs @@ -127,8 +127,21 @@ impl PolicyFactory { } #[derive(Deserialize)] -struct EvaluationResult { - result: bool, +pub struct Violation { + pub msg: String, + pub field: Option, +} + +#[derive(Deserialize)] +pub struct EvaluationResult { + #[serde(rename = "result")] + pub violations: Vec, +} + +impl EvaluationResult { + pub fn valid(&self) -> bool { + self.violations.is_empty() + } } #[derive(Debug)] @@ -145,7 +158,7 @@ impl Policy { pub async fn evaluate_login( &mut self, user: &mas_data_model::User<()>, - ) -> Result { + ) -> Result { let user = serde_json::to_value(user)?; let input = serde_json::json!({ "user": user }); @@ -154,7 +167,7 @@ impl Policy { .evaluate(&mut self.store, &self.login_entrypoint, &input) .await?; - Ok(res.result) + Ok(res) } #[tracing::instrument] @@ -162,7 +175,7 @@ impl Policy { &mut self, username: &str, email: &str, - ) -> Result { + ) -> Result { let input = serde_json::json!({ "user": { "username": username, @@ -175,14 +188,14 @@ impl Policy { .evaluate(&mut self.store, &self.register_entrypoint, &input) .await?; - Ok(res.result) + Ok(res) } #[tracing::instrument] pub async fn evaluate_client_registration( &mut self, client_metadata: &ClientMetadata, - ) -> Result { + ) -> Result { let client_metadata = serde_json::to_value(client_metadata)?; let input = serde_json::json!({ "client_metadata": client_metadata, @@ -197,6 +210,47 @@ impl Policy { ) .await?; - Ok(res.result) + Ok(res) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_register() { + let factory = PolicyFactory::load( + default_wasm_policy(), + serde_json::json!({ + "allowed_domains": ["element.io", "*.element.io"], + "banned_domains": ["staging.element.io"], + }), + "login/violation".to_string(), + "register/violation".to_string(), + "client_registration/violation".to_string(), + ) + .await + .unwrap(); + + let mut policy = factory.instantiate().await.unwrap(); + + let res = policy + .evaluate_register("hello", "hello@example.com") + .await + .unwrap(); + assert!(!res.valid()); + + let res = policy + .evaluate_register("hello", "hello@foo.element.io") + .await + .unwrap(); + assert!(res.valid()); + + let res = policy + .evaluate_register("hello", "hello@staging.element.io") + .await + .unwrap(); + assert!(!res.valid()); } } diff --git a/crates/templates/src/forms.rs b/crates/templates/src/forms.rs index 0ff8583e..26000eda 100644 --- a/crates/templates/src/forms.rs +++ b/crates/templates/src/forms.rs @@ -38,6 +38,12 @@ pub enum FieldError { /// That value already exists Exists, + + /// Denied by the policy + Policy { + /// Message for this policy violation + message: String, + }, } /// An error on the whole form @@ -54,7 +60,10 @@ pub enum FormError { Internal, /// Denied by the policy - Policy, + Policy { + /// Message for this policy violation + message: String, + }, } #[derive(Debug, Default, Serialize)] diff --git a/crates/templates/src/res/components/field.html b/crates/templates/src/res/components/field.html index bdbf6a03..6298256f 100644 --- a/crates/templates/src/res/components/field.html +++ b/crates/templates/src/res/components/field.html @@ -5,7 +5,7 @@ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at -http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, @@ -15,43 +15,46 @@ limitations under the License. #} {% macro input(label, name, type="text", form_state=false, autocomplete=false, class="", inputmode="text") %} -{% if not form_state %} -{% set form_state = dict(errors=[], fields=dict()) %} -{% endif %} + {% if not form_state %} + {% set form_state = dict(errors=[], fields=dict()) %} + {% endif %} -{% set state = form_state.fields[name] | default(value=dict(errors=[], value="")) %} - -{% if state.errors is not empty %} -{% set border_color = "border-alert" %} -{% set text_color = "text-alert" %} -{% else %} -{% set border_color = "border-grey-50 dark:border-grey-450" %} -{% set text_color = "text-black-800 dark:text-grey-300" %} -{% endif %} - - + {% endmacro input %}