You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-31 09:24:31 +03:00
Switch the policies to a violation list based approach
This allows policies to give proper feedback on form fields
This commit is contained in:
@ -22,15 +22,15 @@ use serde_with::serde_as;
|
|||||||
use super::ConfigurationSection;
|
use super::ConfigurationSection;
|
||||||
|
|
||||||
fn default_client_registration_endpoint() -> String {
|
fn default_client_registration_endpoint() -> String {
|
||||||
"client_registration/allow".to_string()
|
"client_registration/violation".to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_login_endpoint() -> String {
|
fn default_login_endpoint() -> String {
|
||||||
"login/allow".to_string()
|
"login/violation".to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_register_endpoint() -> String {
|
fn default_register_endpoint() -> String {
|
||||||
"register/allow".to_string()
|
"register/violation".to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Application secrets
|
/// Application secrets
|
||||||
|
@ -119,8 +119,8 @@ pub(crate) async fn post(
|
|||||||
}
|
}
|
||||||
|
|
||||||
let mut policy = policy_factory.instantiate().await?;
|
let mut policy = policy_factory.instantiate().await?;
|
||||||
let allowed = policy.evaluate_client_registration(&body).await?;
|
let res = policy.evaluate_client_registration(&body).await?;
|
||||||
if !allowed {
|
if !res.valid() {
|
||||||
return Err(RouteError::PolicyDenied);
|
return Err(RouteError::PolicyDenied);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -136,8 +136,24 @@ pub(crate) async fn post(
|
|||||||
.evaluate_register(&form.username, &form.email)
|
.evaluate_register(&form.username, &form.email)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
if !res {
|
for violation in res.violations {
|
||||||
state.add_error_on_form(FormError::Policy);
|
match violation.field.as_deref() {
|
||||||
|
Some("email") => state.add_error_on_field(
|
||||||
|
RegisterFormField::Email,
|
||||||
|
FieldError::Policy {
|
||||||
|
message: violation.msg,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
Some("username") => state.add_error_on_field(
|
||||||
|
RegisterFormField::Username,
|
||||||
|
FieldError::Policy {
|
||||||
|
message: violation.msg,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
_ => state.add_error_on_form(FormError::Policy {
|
||||||
|
message: violation.msg,
|
||||||
|
}),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
state
|
state
|
||||||
|
@ -3,19 +3,21 @@ DOCKER := 0
|
|||||||
|
|
||||||
ifeq ($(DOCKER), 0)
|
ifeq ($(DOCKER), 0)
|
||||||
OPA := opa
|
OPA := opa
|
||||||
|
OPA_RW := opa
|
||||||
else
|
else
|
||||||
OPA := docker run -v $(shell pwd):/policies -w /policies --rm docker.io/openpolicyagent/opa:0.40.0
|
OPA := docker run -v $(shell pwd):/policies:ro -w /policies --rm docker.io/openpolicyagent/opa:0.40.0
|
||||||
|
OPA_RW := docker run -v $(shell pwd):/policies -w /policies --rm docker.io/openpolicyagent/opa:0.40.0
|
||||||
endif
|
endif
|
||||||
|
|
||||||
policy.wasm: client_registration.rego login.rego register.rego
|
policy.wasm: client_registration.rego login.rego register.rego
|
||||||
$(OPA) build -t wasm -e "client_registration/allow" -e "login/allow" -e "register/allow" $^
|
$(OPA_RW) build -t wasm -e "client_registration/violation" -e "login/violation" -e "register/violation" $^
|
||||||
tar xzf bundle.tar.gz /policy.wasm
|
tar xzf bundle.tar.gz /policy.wasm
|
||||||
$(RM) bundle.tar.gz
|
$(RM) bundle.tar.gz
|
||||||
touch $@
|
touch $@
|
||||||
|
|
||||||
.PHONY: fmt
|
.PHONY: fmt
|
||||||
fmt:
|
fmt:
|
||||||
$(OPA) fmt -w .
|
$(OPA_RW) fmt -w .
|
||||||
|
|
||||||
.PHONY: test
|
.PHONY: test
|
||||||
test:
|
test:
|
||||||
|
@ -2,17 +2,54 @@ package client_registration
|
|||||||
|
|
||||||
import future.keywords.in
|
import future.keywords.in
|
||||||
|
|
||||||
|
default allow := false
|
||||||
|
|
||||||
|
allow {
|
||||||
|
count(violation) == 0
|
||||||
|
}
|
||||||
|
|
||||||
secure_url(x) {
|
secure_url(x) {
|
||||||
is_string(x)
|
is_string(x)
|
||||||
startswith(x, "https://")
|
startswith(x, "https://")
|
||||||
}
|
}
|
||||||
|
|
||||||
default allow := false
|
violation[{"msg": "missing client_uri"}] {
|
||||||
|
not input.client_metadata.client_uri
|
||||||
allow {
|
}
|
||||||
secure_url(input.client_metadata.client_uri)
|
|
||||||
secure_url(input.client_metadata.tos_uri)
|
violation[{"msg": "invalid client_uri"}] {
|
||||||
secure_url(input.client_metadata.policy_uri)
|
not secure_url(input.client_metadata.client_uri)
|
||||||
some redirect_uri in input.client_metadata.redirect_uris
|
}
|
||||||
secure_url(redirect_uri)
|
|
||||||
|
violation[{"msg": "missing tos_uri"}] {
|
||||||
|
not input.client_metadata.tos_uri
|
||||||
|
}
|
||||||
|
|
||||||
|
violation[{"msg": "invalid tos_uri"}] {
|
||||||
|
not secure_url(input.client_metadata.tos_uri)
|
||||||
|
}
|
||||||
|
|
||||||
|
violation[{"msg": "missing policy_uri"}] {
|
||||||
|
not input.client_metadata.policy_uri
|
||||||
|
}
|
||||||
|
|
||||||
|
violation[{"msg": "invalid policy_uri"}] {
|
||||||
|
not secure_url(input.client_metadata.policy_uri)
|
||||||
|
}
|
||||||
|
|
||||||
|
violation[{"msg": "missing redirect_uris"}] {
|
||||||
|
not input.client_metadata.redirect_uris
|
||||||
|
}
|
||||||
|
|
||||||
|
violation[{"msg": "invalid redirect_uris"}] {
|
||||||
|
not is_array(input.client_metadata.redirect_uris)
|
||||||
|
}
|
||||||
|
|
||||||
|
violation[{"msg": "empty redirect_uris"}] {
|
||||||
|
count(input.client_metadata.redirect_uris) == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
violation[{"msg": "invalid redirect_uri"}] {
|
||||||
|
some redirect_uri in input.client_metadata.redirect_uris
|
||||||
|
not secure_url(redirect_uri)
|
||||||
}
|
}
|
||||||
|
@ -1,3 +1,3 @@
|
|||||||
package login
|
package login
|
||||||
|
|
||||||
allow := true
|
violation := []
|
||||||
|
@ -127,8 +127,21 @@ impl PolicyFactory {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct EvaluationResult {
|
pub struct Violation {
|
||||||
result: bool,
|
pub msg: String,
|
||||||
|
pub field: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
pub struct EvaluationResult {
|
||||||
|
#[serde(rename = "result")]
|
||||||
|
pub violations: Vec<Violation>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EvaluationResult {
|
||||||
|
pub fn valid(&self) -> bool {
|
||||||
|
self.violations.is_empty()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@ -145,7 +158,7 @@ impl Policy {
|
|||||||
pub async fn evaluate_login(
|
pub async fn evaluate_login(
|
||||||
&mut self,
|
&mut self,
|
||||||
user: &mas_data_model::User<()>,
|
user: &mas_data_model::User<()>,
|
||||||
) -> Result<bool, anyhow::Error> {
|
) -> Result<EvaluationResult, anyhow::Error> {
|
||||||
let user = serde_json::to_value(user)?;
|
let user = serde_json::to_value(user)?;
|
||||||
let input = serde_json::json!({ "user": user });
|
let input = serde_json::json!({ "user": user });
|
||||||
|
|
||||||
@ -154,7 +167,7 @@ impl Policy {
|
|||||||
.evaluate(&mut self.store, &self.login_entrypoint, &input)
|
.evaluate(&mut self.store, &self.login_entrypoint, &input)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
Ok(res.result)
|
Ok(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument]
|
#[tracing::instrument]
|
||||||
@ -162,7 +175,7 @@ impl Policy {
|
|||||||
&mut self,
|
&mut self,
|
||||||
username: &str,
|
username: &str,
|
||||||
email: &str,
|
email: &str,
|
||||||
) -> Result<bool, anyhow::Error> {
|
) -> Result<EvaluationResult, anyhow::Error> {
|
||||||
let input = serde_json::json!({
|
let input = serde_json::json!({
|
||||||
"user": {
|
"user": {
|
||||||
"username": username,
|
"username": username,
|
||||||
@ -175,14 +188,14 @@ impl Policy {
|
|||||||
.evaluate(&mut self.store, &self.register_entrypoint, &input)
|
.evaluate(&mut self.store, &self.register_entrypoint, &input)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
Ok(res.result)
|
Ok(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument]
|
#[tracing::instrument]
|
||||||
pub async fn evaluate_client_registration(
|
pub async fn evaluate_client_registration(
|
||||||
&mut self,
|
&mut self,
|
||||||
client_metadata: &ClientMetadata,
|
client_metadata: &ClientMetadata,
|
||||||
) -> Result<bool, anyhow::Error> {
|
) -> Result<EvaluationResult, anyhow::Error> {
|
||||||
let client_metadata = serde_json::to_value(client_metadata)?;
|
let client_metadata = serde_json::to_value(client_metadata)?;
|
||||||
let input = serde_json::json!({
|
let input = serde_json::json!({
|
||||||
"client_metadata": client_metadata,
|
"client_metadata": client_metadata,
|
||||||
@ -197,6 +210,47 @@ impl Policy {
|
|||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
Ok(res.result)
|
Ok(res)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_register() {
|
||||||
|
let factory = PolicyFactory::load(
|
||||||
|
default_wasm_policy(),
|
||||||
|
serde_json::json!({
|
||||||
|
"allowed_domains": ["element.io", "*.element.io"],
|
||||||
|
"banned_domains": ["staging.element.io"],
|
||||||
|
}),
|
||||||
|
"login/violation".to_string(),
|
||||||
|
"register/violation".to_string(),
|
||||||
|
"client_registration/violation".to_string(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let mut policy = factory.instantiate().await.unwrap();
|
||||||
|
|
||||||
|
let res = policy
|
||||||
|
.evaluate_register("hello", "hello@example.com")
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(!res.valid());
|
||||||
|
|
||||||
|
let res = policy
|
||||||
|
.evaluate_register("hello", "hello@foo.element.io")
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(res.valid());
|
||||||
|
|
||||||
|
let res = policy
|
||||||
|
.evaluate_register("hello", "hello@staging.element.io")
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(!res.valid());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -38,6 +38,12 @@ pub enum FieldError {
|
|||||||
|
|
||||||
/// That value already exists
|
/// That value already exists
|
||||||
Exists,
|
Exists,
|
||||||
|
|
||||||
|
/// Denied by the policy
|
||||||
|
Policy {
|
||||||
|
/// Message for this policy violation
|
||||||
|
message: String,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
/// An error on the whole form
|
/// An error on the whole form
|
||||||
@ -54,7 +60,10 @@ pub enum FormError {
|
|||||||
Internal,
|
Internal,
|
||||||
|
|
||||||
/// Denied by the policy
|
/// Denied by the policy
|
||||||
Policy,
|
Policy {
|
||||||
|
/// Message for this policy violation
|
||||||
|
message: String,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Default, Serialize)]
|
#[derive(Debug, Default, Serialize)]
|
||||||
|
@ -5,7 +5,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
|
|||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
You may obtain a copy of the License at
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
Unless required by applicable law or agreed to in writing, software
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
@ -15,43 +15,46 @@ limitations under the License.
|
|||||||
#}
|
#}
|
||||||
|
|
||||||
{% macro input(label, name, type="text", form_state=false, autocomplete=false, class="", inputmode="text") %}
|
{% macro input(label, name, type="text", form_state=false, autocomplete=false, class="", inputmode="text") %}
|
||||||
{% if not form_state %}
|
{% if not form_state %}
|
||||||
{% set form_state = dict(errors=[], fields=dict()) %}
|
{% set form_state = dict(errors=[], fields=dict()) %}
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
{% set state = form_state.fields[name] | default(value=dict(errors=[], value="")) %}
|
{% set state = form_state.fields[name] | default(value=dict(errors=[], value="")) %}
|
||||||
|
|
||||||
{% if state.errors is not empty %}
|
|
||||||
{% set border_color = "border-alert" %}
|
|
||||||
{% set text_color = "text-alert" %}
|
|
||||||
{% else %}
|
|
||||||
{% set border_color = "border-grey-50 dark:border-grey-450" %}
|
|
||||||
{% set text_color = "text-black-800 dark:text-grey-300" %}
|
|
||||||
{% endif %}
|
|
||||||
|
|
||||||
<label class="flex flex-col block {{ class }}">
|
|
||||||
<div
|
|
||||||
class="mx-2 -mb-3 -mt-2 leading-5 px-1 z-10 self-start bg-white dark:bg-black-900 border-white border-1 dark:border-2 dark:border-black-900 rounded-full text-sm {{ text_color }}">
|
|
||||||
{{ label }}</div>
|
|
||||||
<input name="{{ name }}"
|
|
||||||
class="z-0 px-3 py-2 bg-white dark:bg-black-900 rounded-lg {{ border_color }} border-1 dark:border-2 focus:border-accent focus:ring-0 focus:outline-0"
|
|
||||||
type="{{ type }}" inputmode="{{ inputmode }}" {% if autocomplete %} autocomplete="{{ autocomplete }}" {% endif %} {%
|
|
||||||
if state.value %} value="{{ state.value }}" {% endif %} />
|
|
||||||
|
|
||||||
{% if state.errors is not empty %}
|
{% if state.errors is not empty %}
|
||||||
{% for error in state.errors %}
|
{% set border_color = "border-alert" %}
|
||||||
{% if error.kind != "unspecified" %}
|
{% set text_color = "text-alert" %}
|
||||||
<div class="mx-4 text-sm text-alert">
|
{% else %}
|
||||||
{% if error.kind == "required" %}
|
{% set border_color = "border-grey-50 dark:border-grey-450" %}
|
||||||
This field is required
|
{% set text_color = "text-black-800 dark:text-grey-300" %}
|
||||||
{% elif error.kind == "exists" and name == "username" %}
|
{% endif %}
|
||||||
This username is already taken
|
|
||||||
{% else %}
|
<label class="flex flex-col block {{ class }}">
|
||||||
{{ error.kind }}
|
<div class="mx-2 -mb-3 -mt-2 leading-5 px-1 z-10 self-start bg-white dark:bg-black-900 border-white border-1 dark:border-2 dark:border-black-900 rounded-full text-sm {{ text_color }}">{{ label }}</div>
|
||||||
|
<input name="{{ name }}"
|
||||||
|
class="z-0 px-3 py-2 bg-white dark:bg-black-900 rounded-lg {{ border_color }} border-1 dark:border-2 focus:border-accent focus:ring-0 focus:outline-0"
|
||||||
|
type="{{ type }}"
|
||||||
|
inputmode="{{ inputmode }}"
|
||||||
|
{% if autocomplete %} autocomplete="{{ autocomplete }}" {% endif %}
|
||||||
|
{% if state.value %} value="{{ state.value }}" {% endif %}
|
||||||
|
/>
|
||||||
|
|
||||||
|
{% if state.errors is not empty %}
|
||||||
|
{% for error in state.errors %}
|
||||||
|
{% if error.kind != "unspecified" %}
|
||||||
|
<div class="mx-4 text-sm text-alert">
|
||||||
|
{% if error.kind == "required" %}
|
||||||
|
This field is required
|
||||||
|
{% elif error.kind == "exists" and name == "username" %}
|
||||||
|
This username is already taken
|
||||||
|
{% elif error.kind == "policy" %}
|
||||||
|
Denied by policy: {{ error.message }}
|
||||||
|
{% else %}
|
||||||
|
{{ error.kind }}
|
||||||
|
{% endif %}
|
||||||
|
</div>
|
||||||
|
{% endif %}
|
||||||
|
{% endfor %}
|
||||||
{% endif %}
|
{% endif %}
|
||||||
</div>
|
</label>
|
||||||
{% endif %}
|
|
||||||
{% endfor %}
|
|
||||||
{% endif %}
|
|
||||||
</label>
|
|
||||||
{% endmacro input %}
|
{% endmacro input %}
|
||||||
|
Reference in New Issue
Block a user