1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-06 06:02:40 +03:00

Switch the policies to a violation list based approach

This allows policies to give proper feedback on form fields
This commit is contained in:
Quentin Gliech
2022-06-03 11:59:56 +02:00
parent 88c2625dc0
commit 7c8893e596
9 changed files with 185 additions and 64 deletions

View File

@@ -22,15 +22,15 @@ use serde_with::serde_as;
use super::ConfigurationSection; use super::ConfigurationSection;
fn default_client_registration_endpoint() -> String { fn default_client_registration_endpoint() -> String {
"client_registration/allow".to_string() "client_registration/violation".to_string()
} }
fn default_login_endpoint() -> String { fn default_login_endpoint() -> String {
"login/allow".to_string() "login/violation".to_string()
} }
fn default_register_endpoint() -> String { fn default_register_endpoint() -> String {
"register/allow".to_string() "register/violation".to_string()
} }
/// Application secrets /// Application secrets

View File

@@ -119,8 +119,8 @@ pub(crate) async fn post(
} }
let mut policy = policy_factory.instantiate().await?; let mut policy = policy_factory.instantiate().await?;
let allowed = policy.evaluate_client_registration(&body).await?; let res = policy.evaluate_client_registration(&body).await?;
if !allowed { if !res.valid() {
return Err(RouteError::PolicyDenied); return Err(RouteError::PolicyDenied);
} }

View File

@@ -136,8 +136,24 @@ pub(crate) async fn post(
.evaluate_register(&form.username, &form.email) .evaluate_register(&form.username, &form.email)
.await?; .await?;
if !res { for violation in res.violations {
state.add_error_on_form(FormError::Policy); match violation.field.as_deref() {
Some("email") => state.add_error_on_field(
RegisterFormField::Email,
FieldError::Policy {
message: violation.msg,
},
),
Some("username") => state.add_error_on_field(
RegisterFormField::Username,
FieldError::Policy {
message: violation.msg,
},
),
_ => state.add_error_on_form(FormError::Policy {
message: violation.msg,
}),
}
} }
state state

View File

@@ -3,19 +3,21 @@ DOCKER := 0
ifeq ($(DOCKER), 0) ifeq ($(DOCKER), 0)
OPA := opa OPA := opa
OPA_RW := opa
else else
OPA := docker run -v $(shell pwd):/policies -w /policies --rm docker.io/openpolicyagent/opa:0.40.0 OPA := docker run -v $(shell pwd):/policies:ro -w /policies --rm docker.io/openpolicyagent/opa:0.40.0
OPA_RW := docker run -v $(shell pwd):/policies -w /policies --rm docker.io/openpolicyagent/opa:0.40.0
endif endif
policy.wasm: client_registration.rego login.rego register.rego policy.wasm: client_registration.rego login.rego register.rego
$(OPA) build -t wasm -e "client_registration/allow" -e "login/allow" -e "register/allow" $^ $(OPA_RW) build -t wasm -e "client_registration/violation" -e "login/violation" -e "register/violation" $^
tar xzf bundle.tar.gz /policy.wasm tar xzf bundle.tar.gz /policy.wasm
$(RM) bundle.tar.gz $(RM) bundle.tar.gz
touch $@ touch $@
.PHONY: fmt .PHONY: fmt
fmt: fmt:
$(OPA) fmt -w . $(OPA_RW) fmt -w .
.PHONY: test .PHONY: test
test: test:

View File

@@ -2,17 +2,54 @@ package client_registration
import future.keywords.in import future.keywords.in
default allow := false
allow {
count(violation) == 0
}
secure_url(x) { secure_url(x) {
is_string(x) is_string(x)
startswith(x, "https://") startswith(x, "https://")
} }
default allow := false violation[{"msg": "missing client_uri"}] {
not input.client_metadata.client_uri
allow { }
secure_url(input.client_metadata.client_uri)
secure_url(input.client_metadata.tos_uri) violation[{"msg": "invalid client_uri"}] {
secure_url(input.client_metadata.policy_uri) not secure_url(input.client_metadata.client_uri)
some redirect_uri in input.client_metadata.redirect_uris }
secure_url(redirect_uri)
violation[{"msg": "missing tos_uri"}] {
not input.client_metadata.tos_uri
}
violation[{"msg": "invalid tos_uri"}] {
not secure_url(input.client_metadata.tos_uri)
}
violation[{"msg": "missing policy_uri"}] {
not input.client_metadata.policy_uri
}
violation[{"msg": "invalid policy_uri"}] {
not secure_url(input.client_metadata.policy_uri)
}
violation[{"msg": "missing redirect_uris"}] {
not input.client_metadata.redirect_uris
}
violation[{"msg": "invalid redirect_uris"}] {
not is_array(input.client_metadata.redirect_uris)
}
violation[{"msg": "empty redirect_uris"}] {
count(input.client_metadata.redirect_uris) == 0
}
violation[{"msg": "invalid redirect_uri"}] {
some redirect_uri in input.client_metadata.redirect_uris
not secure_url(redirect_uri)
} }

View File

@@ -1,3 +1,3 @@
package login package login
allow := true violation := []

View File

@@ -127,8 +127,21 @@ impl PolicyFactory {
} }
#[derive(Deserialize)] #[derive(Deserialize)]
struct EvaluationResult { pub struct Violation {
result: bool, pub msg: String,
pub field: Option<String>,
}
#[derive(Deserialize)]
pub struct EvaluationResult {
#[serde(rename = "result")]
pub violations: Vec<Violation>,
}
impl EvaluationResult {
pub fn valid(&self) -> bool {
self.violations.is_empty()
}
} }
#[derive(Debug)] #[derive(Debug)]
@@ -145,7 +158,7 @@ impl Policy {
pub async fn evaluate_login( pub async fn evaluate_login(
&mut self, &mut self,
user: &mas_data_model::User<()>, user: &mas_data_model::User<()>,
) -> Result<bool, anyhow::Error> { ) -> Result<EvaluationResult, anyhow::Error> {
let user = serde_json::to_value(user)?; let user = serde_json::to_value(user)?;
let input = serde_json::json!({ "user": user }); let input = serde_json::json!({ "user": user });
@@ -154,7 +167,7 @@ impl Policy {
.evaluate(&mut self.store, &self.login_entrypoint, &input) .evaluate(&mut self.store, &self.login_entrypoint, &input)
.await?; .await?;
Ok(res.result) Ok(res)
} }
#[tracing::instrument] #[tracing::instrument]
@@ -162,7 +175,7 @@ impl Policy {
&mut self, &mut self,
username: &str, username: &str,
email: &str, email: &str,
) -> Result<bool, anyhow::Error> { ) -> Result<EvaluationResult, anyhow::Error> {
let input = serde_json::json!({ let input = serde_json::json!({
"user": { "user": {
"username": username, "username": username,
@@ -175,14 +188,14 @@ impl Policy {
.evaluate(&mut self.store, &self.register_entrypoint, &input) .evaluate(&mut self.store, &self.register_entrypoint, &input)
.await?; .await?;
Ok(res.result) Ok(res)
} }
#[tracing::instrument] #[tracing::instrument]
pub async fn evaluate_client_registration( pub async fn evaluate_client_registration(
&mut self, &mut self,
client_metadata: &ClientMetadata, client_metadata: &ClientMetadata,
) -> Result<bool, anyhow::Error> { ) -> Result<EvaluationResult, anyhow::Error> {
let client_metadata = serde_json::to_value(client_metadata)?; let client_metadata = serde_json::to_value(client_metadata)?;
let input = serde_json::json!({ let input = serde_json::json!({
"client_metadata": client_metadata, "client_metadata": client_metadata,
@@ -197,6 +210,47 @@ impl Policy {
) )
.await?; .await?;
Ok(res.result) Ok(res)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_register() {
let factory = PolicyFactory::load(
default_wasm_policy(),
serde_json::json!({
"allowed_domains": ["element.io", "*.element.io"],
"banned_domains": ["staging.element.io"],
}),
"login/violation".to_string(),
"register/violation".to_string(),
"client_registration/violation".to_string(),
)
.await
.unwrap();
let mut policy = factory.instantiate().await.unwrap();
let res = policy
.evaluate_register("hello", "hello@example.com")
.await
.unwrap();
assert!(!res.valid());
let res = policy
.evaluate_register("hello", "hello@foo.element.io")
.await
.unwrap();
assert!(res.valid());
let res = policy
.evaluate_register("hello", "hello@staging.element.io")
.await
.unwrap();
assert!(!res.valid());
} }
} }

View File

@@ -38,6 +38,12 @@ pub enum FieldError {
/// That value already exists /// That value already exists
Exists, Exists,
/// Denied by the policy
Policy {
/// Message for this policy violation
message: String,
},
} }
/// An error on the whole form /// An error on the whole form
@@ -54,7 +60,10 @@ pub enum FormError {
Internal, Internal,
/// Denied by the policy /// Denied by the policy
Policy, Policy {
/// Message for this policy violation
message: String,
},
} }
#[derive(Debug, Default, Serialize)] #[derive(Debug, Default, Serialize)]

View File

@@ -30,13 +30,14 @@ limitations under the License.
{% endif %} {% endif %}
<label class="flex flex-col block {{ class }}"> <label class="flex flex-col block {{ class }}">
<div <div class="mx-2 -mb-3 -mt-2 leading-5 px-1 z-10 self-start bg-white dark:bg-black-900 border-white border-1 dark:border-2 dark:border-black-900 rounded-full text-sm {{ text_color }}">{{ label }}</div>
class="mx-2 -mb-3 -mt-2 leading-5 px-1 z-10 self-start bg-white dark:bg-black-900 border-white border-1 dark:border-2 dark:border-black-900 rounded-full text-sm {{ text_color }}">
{{ label }}</div>
<input name="{{ name }}" <input name="{{ name }}"
class="z-0 px-3 py-2 bg-white dark:bg-black-900 rounded-lg {{ border_color }} border-1 dark:border-2 focus:border-accent focus:ring-0 focus:outline-0" class="z-0 px-3 py-2 bg-white dark:bg-black-900 rounded-lg {{ border_color }} border-1 dark:border-2 focus:border-accent focus:ring-0 focus:outline-0"
type="{{ type }}" inputmode="{{ inputmode }}" {% if autocomplete %} autocomplete="{{ autocomplete }}" {% endif %} {% type="{{ type }}"
if state.value %} value="{{ state.value }}" {% endif %} /> inputmode="{{ inputmode }}"
{% if autocomplete %} autocomplete="{{ autocomplete }}" {% endif %}
{% if state.value %} value="{{ state.value }}" {% endif %}
/>
{% if state.errors is not empty %} {% if state.errors is not empty %}
{% for error in state.errors %} {% for error in state.errors %}
@@ -46,6 +47,8 @@ limitations under the License.
This field is required This field is required
{% elif error.kind == "exists" and name == "username" %} {% elif error.kind == "exists" and name == "username" %}
This username is already taken This username is already taken
{% elif error.kind == "policy" %}
Denied by policy: {{ error.message }}
{% else %} {% else %}
{{ error.kind }} {{ error.kind }}
{% endif %} {% endif %}