1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Switch the policies to a violation list based approach

This allows policies to give proper feedback on form fields
This commit is contained in:
Quentin Gliech
2022-06-03 11:59:56 +02:00
parent 88c2625dc0
commit 7c8893e596
9 changed files with 185 additions and 64 deletions

View File

@ -22,15 +22,15 @@ use serde_with::serde_as;
use super::ConfigurationSection; use super::ConfigurationSection;
fn default_client_registration_endpoint() -> String { fn default_client_registration_endpoint() -> String {
"client_registration/allow".to_string() "client_registration/violation".to_string()
} }
fn default_login_endpoint() -> String { fn default_login_endpoint() -> String {
"login/allow".to_string() "login/violation".to_string()
} }
fn default_register_endpoint() -> String { fn default_register_endpoint() -> String {
"register/allow".to_string() "register/violation".to_string()
} }
/// Application secrets /// Application secrets

View File

@ -119,8 +119,8 @@ pub(crate) async fn post(
} }
let mut policy = policy_factory.instantiate().await?; let mut policy = policy_factory.instantiate().await?;
let allowed = policy.evaluate_client_registration(&body).await?; let res = policy.evaluate_client_registration(&body).await?;
if !allowed { if !res.valid() {
return Err(RouteError::PolicyDenied); return Err(RouteError::PolicyDenied);
} }

View File

@ -136,8 +136,24 @@ pub(crate) async fn post(
.evaluate_register(&form.username, &form.email) .evaluate_register(&form.username, &form.email)
.await?; .await?;
if !res { for violation in res.violations {
state.add_error_on_form(FormError::Policy); match violation.field.as_deref() {
Some("email") => state.add_error_on_field(
RegisterFormField::Email,
FieldError::Policy {
message: violation.msg,
},
),
Some("username") => state.add_error_on_field(
RegisterFormField::Username,
FieldError::Policy {
message: violation.msg,
},
),
_ => state.add_error_on_form(FormError::Policy {
message: violation.msg,
}),
}
} }
state state

View File

@ -3,19 +3,21 @@ DOCKER := 0
ifeq ($(DOCKER), 0) ifeq ($(DOCKER), 0)
OPA := opa OPA := opa
OPA_RW := opa
else else
OPA := docker run -v $(shell pwd):/policies -w /policies --rm docker.io/openpolicyagent/opa:0.40.0 OPA := docker run -v $(shell pwd):/policies:ro -w /policies --rm docker.io/openpolicyagent/opa:0.40.0
OPA_RW := docker run -v $(shell pwd):/policies -w /policies --rm docker.io/openpolicyagent/opa:0.40.0
endif endif
policy.wasm: client_registration.rego login.rego register.rego policy.wasm: client_registration.rego login.rego register.rego
$(OPA) build -t wasm -e "client_registration/allow" -e "login/allow" -e "register/allow" $^ $(OPA_RW) build -t wasm -e "client_registration/violation" -e "login/violation" -e "register/violation" $^
tar xzf bundle.tar.gz /policy.wasm tar xzf bundle.tar.gz /policy.wasm
$(RM) bundle.tar.gz $(RM) bundle.tar.gz
touch $@ touch $@
.PHONY: fmt .PHONY: fmt
fmt: fmt:
$(OPA) fmt -w . $(OPA_RW) fmt -w .
.PHONY: test .PHONY: test
test: test:

View File

@ -2,17 +2,54 @@ package client_registration
import future.keywords.in import future.keywords.in
default allow := false
allow {
count(violation) == 0
}
secure_url(x) { secure_url(x) {
is_string(x) is_string(x)
startswith(x, "https://") startswith(x, "https://")
} }
default allow := false violation[{"msg": "missing client_uri"}] {
not input.client_metadata.client_uri
allow { }
secure_url(input.client_metadata.client_uri)
secure_url(input.client_metadata.tos_uri) violation[{"msg": "invalid client_uri"}] {
secure_url(input.client_metadata.policy_uri) not secure_url(input.client_metadata.client_uri)
some redirect_uri in input.client_metadata.redirect_uris }
secure_url(redirect_uri)
violation[{"msg": "missing tos_uri"}] {
not input.client_metadata.tos_uri
}
violation[{"msg": "invalid tos_uri"}] {
not secure_url(input.client_metadata.tos_uri)
}
violation[{"msg": "missing policy_uri"}] {
not input.client_metadata.policy_uri
}
violation[{"msg": "invalid policy_uri"}] {
not secure_url(input.client_metadata.policy_uri)
}
violation[{"msg": "missing redirect_uris"}] {
not input.client_metadata.redirect_uris
}
violation[{"msg": "invalid redirect_uris"}] {
not is_array(input.client_metadata.redirect_uris)
}
violation[{"msg": "empty redirect_uris"}] {
count(input.client_metadata.redirect_uris) == 0
}
violation[{"msg": "invalid redirect_uri"}] {
some redirect_uri in input.client_metadata.redirect_uris
not secure_url(redirect_uri)
} }

View File

@ -1,3 +1,3 @@
package login package login
allow := true violation := []

View File

@ -127,8 +127,21 @@ impl PolicyFactory {
} }
#[derive(Deserialize)] #[derive(Deserialize)]
struct EvaluationResult { pub struct Violation {
result: bool, pub msg: String,
pub field: Option<String>,
}
#[derive(Deserialize)]
pub struct EvaluationResult {
#[serde(rename = "result")]
pub violations: Vec<Violation>,
}
impl EvaluationResult {
pub fn valid(&self) -> bool {
self.violations.is_empty()
}
} }
#[derive(Debug)] #[derive(Debug)]
@ -145,7 +158,7 @@ impl Policy {
pub async fn evaluate_login( pub async fn evaluate_login(
&mut self, &mut self,
user: &mas_data_model::User<()>, user: &mas_data_model::User<()>,
) -> Result<bool, anyhow::Error> { ) -> Result<EvaluationResult, anyhow::Error> {
let user = serde_json::to_value(user)?; let user = serde_json::to_value(user)?;
let input = serde_json::json!({ "user": user }); let input = serde_json::json!({ "user": user });
@ -154,7 +167,7 @@ impl Policy {
.evaluate(&mut self.store, &self.login_entrypoint, &input) .evaluate(&mut self.store, &self.login_entrypoint, &input)
.await?; .await?;
Ok(res.result) Ok(res)
} }
#[tracing::instrument] #[tracing::instrument]
@ -162,7 +175,7 @@ impl Policy {
&mut self, &mut self,
username: &str, username: &str,
email: &str, email: &str,
) -> Result<bool, anyhow::Error> { ) -> Result<EvaluationResult, anyhow::Error> {
let input = serde_json::json!({ let input = serde_json::json!({
"user": { "user": {
"username": username, "username": username,
@ -175,14 +188,14 @@ impl Policy {
.evaluate(&mut self.store, &self.register_entrypoint, &input) .evaluate(&mut self.store, &self.register_entrypoint, &input)
.await?; .await?;
Ok(res.result) Ok(res)
} }
#[tracing::instrument] #[tracing::instrument]
pub async fn evaluate_client_registration( pub async fn evaluate_client_registration(
&mut self, &mut self,
client_metadata: &ClientMetadata, client_metadata: &ClientMetadata,
) -> Result<bool, anyhow::Error> { ) -> Result<EvaluationResult, anyhow::Error> {
let client_metadata = serde_json::to_value(client_metadata)?; let client_metadata = serde_json::to_value(client_metadata)?;
let input = serde_json::json!({ let input = serde_json::json!({
"client_metadata": client_metadata, "client_metadata": client_metadata,
@ -197,6 +210,47 @@ impl Policy {
) )
.await?; .await?;
Ok(res.result) Ok(res)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_register() {
let factory = PolicyFactory::load(
default_wasm_policy(),
serde_json::json!({
"allowed_domains": ["element.io", "*.element.io"],
"banned_domains": ["staging.element.io"],
}),
"login/violation".to_string(),
"register/violation".to_string(),
"client_registration/violation".to_string(),
)
.await
.unwrap();
let mut policy = factory.instantiate().await.unwrap();
let res = policy
.evaluate_register("hello", "hello@example.com")
.await
.unwrap();
assert!(!res.valid());
let res = policy
.evaluate_register("hello", "hello@foo.element.io")
.await
.unwrap();
assert!(res.valid());
let res = policy
.evaluate_register("hello", "hello@staging.element.io")
.await
.unwrap();
assert!(!res.valid());
} }
} }

View File

@ -38,6 +38,12 @@ pub enum FieldError {
/// That value already exists /// That value already exists
Exists, Exists,
/// Denied by the policy
Policy {
/// Message for this policy violation
message: String,
},
} }
/// An error on the whole form /// An error on the whole form
@ -54,7 +60,10 @@ pub enum FormError {
Internal, Internal,
/// Denied by the policy /// Denied by the policy
Policy, Policy {
/// Message for this policy violation
message: String,
},
} }
#[derive(Debug, Default, Serialize)] #[derive(Debug, Default, Serialize)]

View File

@ -5,7 +5,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
@ -15,43 +15,46 @@ limitations under the License.
#} #}
{% macro input(label, name, type="text", form_state=false, autocomplete=false, class="", inputmode="text") %} {% macro input(label, name, type="text", form_state=false, autocomplete=false, class="", inputmode="text") %}
{% if not form_state %} {% if not form_state %}
{% set form_state = dict(errors=[], fields=dict()) %} {% set form_state = dict(errors=[], fields=dict()) %}
{% endif %} {% endif %}
{% set state = form_state.fields[name] | default(value=dict(errors=[], value="")) %} {% set state = form_state.fields[name] | default(value=dict(errors=[], value="")) %}
{% if state.errors is not empty %}
{% set border_color = "border-alert" %}
{% set text_color = "text-alert" %}
{% else %}
{% set border_color = "border-grey-50 dark:border-grey-450" %}
{% set text_color = "text-black-800 dark:text-grey-300" %}
{% endif %}
<label class="flex flex-col block {{ class }}">
<div
class="mx-2 -mb-3 -mt-2 leading-5 px-1 z-10 self-start bg-white dark:bg-black-900 border-white border-1 dark:border-2 dark:border-black-900 rounded-full text-sm {{ text_color }}">
{{ label }}</div>
<input name="{{ name }}"
class="z-0 px-3 py-2 bg-white dark:bg-black-900 rounded-lg {{ border_color }} border-1 dark:border-2 focus:border-accent focus:ring-0 focus:outline-0"
type="{{ type }}" inputmode="{{ inputmode }}" {% if autocomplete %} autocomplete="{{ autocomplete }}" {% endif %} {%
if state.value %} value="{{ state.value }}" {% endif %} />
{% if state.errors is not empty %} {% if state.errors is not empty %}
{% for error in state.errors %} {% set border_color = "border-alert" %}
{% if error.kind != "unspecified" %} {% set text_color = "text-alert" %}
<div class="mx-4 text-sm text-alert"> {% else %}
{% if error.kind == "required" %} {% set border_color = "border-grey-50 dark:border-grey-450" %}
This field is required {% set text_color = "text-black-800 dark:text-grey-300" %}
{% elif error.kind == "exists" and name == "username" %} {% endif %}
This username is already taken
{% else %} <label class="flex flex-col block {{ class }}">
{{ error.kind }} <div class="mx-2 -mb-3 -mt-2 leading-5 px-1 z-10 self-start bg-white dark:bg-black-900 border-white border-1 dark:border-2 dark:border-black-900 rounded-full text-sm {{ text_color }}">{{ label }}</div>
<input name="{{ name }}"
class="z-0 px-3 py-2 bg-white dark:bg-black-900 rounded-lg {{ border_color }} border-1 dark:border-2 focus:border-accent focus:ring-0 focus:outline-0"
type="{{ type }}"
inputmode="{{ inputmode }}"
{% if autocomplete %} autocomplete="{{ autocomplete }}" {% endif %}
{% if state.value %} value="{{ state.value }}" {% endif %}
/>
{% if state.errors is not empty %}
{% for error in state.errors %}
{% if error.kind != "unspecified" %}
<div class="mx-4 text-sm text-alert">
{% if error.kind == "required" %}
This field is required
{% elif error.kind == "exists" and name == "username" %}
This username is already taken
{% elif error.kind == "policy" %}
Denied by policy: {{ error.message }}
{% else %}
{{ error.kind }}
{% endif %}
</div>
{% endif %}
{% endfor %}
{% endif %} {% endif %}
</div> </label>
{% endif %}
{% endfor %}
{% endif %}
</label>
{% endmacro input %} {% endmacro input %}