You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-31 09:24:31 +03:00
Switch the policies to a violation list based approach
This allows policies to give proper feedback on form fields
This commit is contained in:
@ -22,15 +22,15 @@ use serde_with::serde_as;
|
||||
use super::ConfigurationSection;
|
||||
|
||||
fn default_client_registration_endpoint() -> String {
|
||||
"client_registration/allow".to_string()
|
||||
"client_registration/violation".to_string()
|
||||
}
|
||||
|
||||
fn default_login_endpoint() -> String {
|
||||
"login/allow".to_string()
|
||||
"login/violation".to_string()
|
||||
}
|
||||
|
||||
fn default_register_endpoint() -> String {
|
||||
"register/allow".to_string()
|
||||
"register/violation".to_string()
|
||||
}
|
||||
|
||||
/// Application secrets
|
||||
|
@ -119,8 +119,8 @@ pub(crate) async fn post(
|
||||
}
|
||||
|
||||
let mut policy = policy_factory.instantiate().await?;
|
||||
let allowed = policy.evaluate_client_registration(&body).await?;
|
||||
if !allowed {
|
||||
let res = policy.evaluate_client_registration(&body).await?;
|
||||
if !res.valid() {
|
||||
return Err(RouteError::PolicyDenied);
|
||||
}
|
||||
|
||||
|
@ -136,8 +136,24 @@ pub(crate) async fn post(
|
||||
.evaluate_register(&form.username, &form.email)
|
||||
.await?;
|
||||
|
||||
if !res {
|
||||
state.add_error_on_form(FormError::Policy);
|
||||
for violation in res.violations {
|
||||
match violation.field.as_deref() {
|
||||
Some("email") => state.add_error_on_field(
|
||||
RegisterFormField::Email,
|
||||
FieldError::Policy {
|
||||
message: violation.msg,
|
||||
},
|
||||
),
|
||||
Some("username") => state.add_error_on_field(
|
||||
RegisterFormField::Username,
|
||||
FieldError::Policy {
|
||||
message: violation.msg,
|
||||
},
|
||||
),
|
||||
_ => state.add_error_on_form(FormError::Policy {
|
||||
message: violation.msg,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
state
|
||||
|
@ -3,19 +3,21 @@ DOCKER := 0
|
||||
|
||||
ifeq ($(DOCKER), 0)
|
||||
OPA := opa
|
||||
OPA_RW := opa
|
||||
else
|
||||
OPA := docker run -v $(shell pwd):/policies -w /policies --rm docker.io/openpolicyagent/opa:0.40.0
|
||||
OPA := docker run -v $(shell pwd):/policies:ro -w /policies --rm docker.io/openpolicyagent/opa:0.40.0
|
||||
OPA_RW := docker run -v $(shell pwd):/policies -w /policies --rm docker.io/openpolicyagent/opa:0.40.0
|
||||
endif
|
||||
|
||||
policy.wasm: client_registration.rego login.rego register.rego
|
||||
$(OPA) build -t wasm -e "client_registration/allow" -e "login/allow" -e "register/allow" $^
|
||||
$(OPA_RW) build -t wasm -e "client_registration/violation" -e "login/violation" -e "register/violation" $^
|
||||
tar xzf bundle.tar.gz /policy.wasm
|
||||
$(RM) bundle.tar.gz
|
||||
touch $@
|
||||
|
||||
.PHONY: fmt
|
||||
fmt:
|
||||
$(OPA) fmt -w .
|
||||
$(OPA_RW) fmt -w .
|
||||
|
||||
.PHONY: test
|
||||
test:
|
||||
|
@ -2,17 +2,54 @@ package client_registration
|
||||
|
||||
import future.keywords.in
|
||||
|
||||
default allow := false
|
||||
|
||||
allow {
|
||||
count(violation) == 0
|
||||
}
|
||||
|
||||
secure_url(x) {
|
||||
is_string(x)
|
||||
startswith(x, "https://")
|
||||
}
|
||||
|
||||
default allow := false
|
||||
|
||||
allow {
|
||||
secure_url(input.client_metadata.client_uri)
|
||||
secure_url(input.client_metadata.tos_uri)
|
||||
secure_url(input.client_metadata.policy_uri)
|
||||
some redirect_uri in input.client_metadata.redirect_uris
|
||||
secure_url(redirect_uri)
|
||||
violation[{"msg": "missing client_uri"}] {
|
||||
not input.client_metadata.client_uri
|
||||
}
|
||||
|
||||
violation[{"msg": "invalid client_uri"}] {
|
||||
not secure_url(input.client_metadata.client_uri)
|
||||
}
|
||||
|
||||
violation[{"msg": "missing tos_uri"}] {
|
||||
not input.client_metadata.tos_uri
|
||||
}
|
||||
|
||||
violation[{"msg": "invalid tos_uri"}] {
|
||||
not secure_url(input.client_metadata.tos_uri)
|
||||
}
|
||||
|
||||
violation[{"msg": "missing policy_uri"}] {
|
||||
not input.client_metadata.policy_uri
|
||||
}
|
||||
|
||||
violation[{"msg": "invalid policy_uri"}] {
|
||||
not secure_url(input.client_metadata.policy_uri)
|
||||
}
|
||||
|
||||
violation[{"msg": "missing redirect_uris"}] {
|
||||
not input.client_metadata.redirect_uris
|
||||
}
|
||||
|
||||
violation[{"msg": "invalid redirect_uris"}] {
|
||||
not is_array(input.client_metadata.redirect_uris)
|
||||
}
|
||||
|
||||
violation[{"msg": "empty redirect_uris"}] {
|
||||
count(input.client_metadata.redirect_uris) == 0
|
||||
}
|
||||
|
||||
violation[{"msg": "invalid redirect_uri"}] {
|
||||
some redirect_uri in input.client_metadata.redirect_uris
|
||||
not secure_url(redirect_uri)
|
||||
}
|
||||
|
@ -1,3 +1,3 @@
|
||||
package login
|
||||
|
||||
allow := true
|
||||
violation := []
|
||||
|
@ -127,8 +127,21 @@ impl PolicyFactory {
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct EvaluationResult {
|
||||
result: bool,
|
||||
pub struct Violation {
|
||||
pub msg: String,
|
||||
pub field: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct EvaluationResult {
|
||||
#[serde(rename = "result")]
|
||||
pub violations: Vec<Violation>,
|
||||
}
|
||||
|
||||
impl EvaluationResult {
|
||||
pub fn valid(&self) -> bool {
|
||||
self.violations.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@ -145,7 +158,7 @@ impl Policy {
|
||||
pub async fn evaluate_login(
|
||||
&mut self,
|
||||
user: &mas_data_model::User<()>,
|
||||
) -> Result<bool, anyhow::Error> {
|
||||
) -> Result<EvaluationResult, anyhow::Error> {
|
||||
let user = serde_json::to_value(user)?;
|
||||
let input = serde_json::json!({ "user": user });
|
||||
|
||||
@ -154,7 +167,7 @@ impl Policy {
|
||||
.evaluate(&mut self.store, &self.login_entrypoint, &input)
|
||||
.await?;
|
||||
|
||||
Ok(res.result)
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
@ -162,7 +175,7 @@ impl Policy {
|
||||
&mut self,
|
||||
username: &str,
|
||||
email: &str,
|
||||
) -> Result<bool, anyhow::Error> {
|
||||
) -> Result<EvaluationResult, anyhow::Error> {
|
||||
let input = serde_json::json!({
|
||||
"user": {
|
||||
"username": username,
|
||||
@ -175,14 +188,14 @@ impl Policy {
|
||||
.evaluate(&mut self.store, &self.register_entrypoint, &input)
|
||||
.await?;
|
||||
|
||||
Ok(res.result)
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn evaluate_client_registration(
|
||||
&mut self,
|
||||
client_metadata: &ClientMetadata,
|
||||
) -> Result<bool, anyhow::Error> {
|
||||
) -> Result<EvaluationResult, anyhow::Error> {
|
||||
let client_metadata = serde_json::to_value(client_metadata)?;
|
||||
let input = serde_json::json!({
|
||||
"client_metadata": client_metadata,
|
||||
@ -197,6 +210,47 @@ impl Policy {
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(res.result)
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_register() {
|
||||
let factory = PolicyFactory::load(
|
||||
default_wasm_policy(),
|
||||
serde_json::json!({
|
||||
"allowed_domains": ["element.io", "*.element.io"],
|
||||
"banned_domains": ["staging.element.io"],
|
||||
}),
|
||||
"login/violation".to_string(),
|
||||
"register/violation".to_string(),
|
||||
"client_registration/violation".to_string(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut policy = factory.instantiate().await.unwrap();
|
||||
|
||||
let res = policy
|
||||
.evaluate_register("hello", "hello@example.com")
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!res.valid());
|
||||
|
||||
let res = policy
|
||||
.evaluate_register("hello", "hello@foo.element.io")
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(res.valid());
|
||||
|
||||
let res = policy
|
||||
.evaluate_register("hello", "hello@staging.element.io")
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!res.valid());
|
||||
}
|
||||
}
|
||||
|
@ -38,6 +38,12 @@ pub enum FieldError {
|
||||
|
||||
/// That value already exists
|
||||
Exists,
|
||||
|
||||
/// Denied by the policy
|
||||
Policy {
|
||||
/// Message for this policy violation
|
||||
message: String,
|
||||
},
|
||||
}
|
||||
|
||||
/// An error on the whole form
|
||||
@ -54,7 +60,10 @@ pub enum FormError {
|
||||
Internal,
|
||||
|
||||
/// Denied by the policy
|
||||
Policy,
|
||||
Policy {
|
||||
/// Message for this policy violation
|
||||
message: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize)]
|
||||
|
@ -5,7 +5,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
@ -15,28 +15,29 @@ limitations under the License.
|
||||
#}
|
||||
|
||||
{% macro input(label, name, type="text", form_state=false, autocomplete=false, class="", inputmode="text") %}
|
||||
{% if not form_state %}
|
||||
{% set form_state = dict(errors=[], fields=dict()) %}
|
||||
{% endif %}
|
||||
{% if not form_state %}
|
||||
{% set form_state = dict(errors=[], fields=dict()) %}
|
||||
{% endif %}
|
||||
|
||||
{% set state = form_state.fields[name] | default(value=dict(errors=[], value="")) %}
|
||||
{% set state = form_state.fields[name] | default(value=dict(errors=[], value="")) %}
|
||||
|
||||
{% if state.errors is not empty %}
|
||||
{% set border_color = "border-alert" %}
|
||||
{% set text_color = "text-alert" %}
|
||||
{% else %}
|
||||
{% set border_color = "border-grey-50 dark:border-grey-450" %}
|
||||
{% set text_color = "text-black-800 dark:text-grey-300" %}
|
||||
{% endif %}
|
||||
{% if state.errors is not empty %}
|
||||
{% set border_color = "border-alert" %}
|
||||
{% set text_color = "text-alert" %}
|
||||
{% else %}
|
||||
{% set border_color = "border-grey-50 dark:border-grey-450" %}
|
||||
{% set text_color = "text-black-800 dark:text-grey-300" %}
|
||||
{% endif %}
|
||||
|
||||
<label class="flex flex-col block {{ class }}">
|
||||
<div
|
||||
class="mx-2 -mb-3 -mt-2 leading-5 px-1 z-10 self-start bg-white dark:bg-black-900 border-white border-1 dark:border-2 dark:border-black-900 rounded-full text-sm {{ text_color }}">
|
||||
{{ label }}</div>
|
||||
<label class="flex flex-col block {{ class }}">
|
||||
<div class="mx-2 -mb-3 -mt-2 leading-5 px-1 z-10 self-start bg-white dark:bg-black-900 border-white border-1 dark:border-2 dark:border-black-900 rounded-full text-sm {{ text_color }}">{{ label }}</div>
|
||||
<input name="{{ name }}"
|
||||
class="z-0 px-3 py-2 bg-white dark:bg-black-900 rounded-lg {{ border_color }} border-1 dark:border-2 focus:border-accent focus:ring-0 focus:outline-0"
|
||||
type="{{ type }}" inputmode="{{ inputmode }}" {% if autocomplete %} autocomplete="{{ autocomplete }}" {% endif %} {%
|
||||
if state.value %} value="{{ state.value }}" {% endif %} />
|
||||
type="{{ type }}"
|
||||
inputmode="{{ inputmode }}"
|
||||
{% if autocomplete %} autocomplete="{{ autocomplete }}" {% endif %}
|
||||
{% if state.value %} value="{{ state.value }}" {% endif %}
|
||||
/>
|
||||
|
||||
{% if state.errors is not empty %}
|
||||
{% for error in state.errors %}
|
||||
@ -46,6 +47,8 @@ limitations under the License.
|
||||
This field is required
|
||||
{% elif error.kind == "exists" and name == "username" %}
|
||||
This username is already taken
|
||||
{% elif error.kind == "policy" %}
|
||||
Denied by policy: {{ error.message }}
|
||||
{% else %}
|
||||
{{ error.kind }}
|
||||
{% endif %}
|
||||
@ -53,5 +56,5 @@ limitations under the License.
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
</label>
|
||||
</label>
|
||||
{% endmacro input %}
|
||||
|
Reference in New Issue
Block a user