diff --git a/Cargo.lock b/Cargo.lock index 347eb3be..a6533dce 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3047,6 +3047,7 @@ dependencies = [ "mas-data-model", "oauth2-types", "opa-wasm", + "schemars", "serde", "serde_json", "thiserror", diff --git a/crates/cli/src/util.rs b/crates/cli/src/util.rs index 5eb934ba..2ebe2efe 100644 --- a/crates/cli/src/util.rs +++ b/crates/cli/src/util.rs @@ -97,12 +97,18 @@ pub async fn policy_factory_from_config( .await .context("failed to open OPA WASM policy file")?; + let entrypoints = mas_policy::Entrypoints { + register: config.register_entrypoint.clone(), + client_registration: config.client_registration_entrypoint.clone(), + authorization_grant: config.authorization_grant_entrypoint.clone(), + email: config.email_entrypoint.clone(), + password: config.password_entrypoint.clone(), + }; + PolicyFactory::load( policy_file, config.data.clone().unwrap_or_default(), - config.register_entrypoint.clone(), - config.client_registration_entrypoint.clone(), - config.authorization_grant_entrypoint.clone(), + entrypoints, ) .await .context("failed to load the policy") diff --git a/crates/config/src/sections/policy.rs b/crates/config/src/sections/policy.rs index 9317cfb9..b3e14954 100644 --- a/crates/config/src/sections/policy.rs +++ b/crates/config/src/sections/policy.rs @@ -48,6 +48,14 @@ fn default_authorization_grant_endpoint() -> String { "authorization_grant/violation".to_owned() } +fn default_password_endpoint() -> String { + "password/violation".to_owned() +} + +fn default_email_endpoint() -> String { + "email/violation".to_owned() +} + /// Application secrets #[serde_as] #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] @@ -69,6 +77,14 @@ pub struct PolicyConfig { #[serde(default = "default_authorization_grant_endpoint")] pub authorization_grant_entrypoint: String, + /// Entrypoint to use when changing password + #[serde(default = "default_password_endpoint")] + pub password_entrypoint: String, + + /// Entrypoint to use when adding an email address + #[serde(default = "default_email_endpoint")] + pub email_entrypoint: String, + /// Arbitrary data to pass to the policy #[serde(default)] pub data: Option, @@ -81,6 +97,8 @@ impl Default for PolicyConfig { client_registration_entrypoint: default_client_registration_endpoint(), register_entrypoint: default_register_endpoint(), authorization_grant_entrypoint: default_authorization_grant_endpoint(), + password_entrypoint: default_password_endpoint(), + email_entrypoint: default_email_endpoint(), data: None, } } diff --git a/crates/handlers/src/oauth2/authorization/complete.rs b/crates/handlers/src/oauth2/authorization/complete.rs index 687c1c64..6c5b7b7f 100644 --- a/crates/handlers/src/oauth2/authorization/complete.rs +++ b/crates/handlers/src/oauth2/authorization/complete.rs @@ -76,7 +76,7 @@ impl IntoResponse for RouteError { impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_templates::TemplateError); impl_from_error_for_route!(mas_policy::LoadError); -impl_from_error_for_route!(mas_policy::InstanciateError); +impl_from_error_for_route!(mas_policy::InstantiateError); impl_from_error_for_route!(mas_policy::EvaluationError); impl_from_error_for_route!(super::callback::IntoCallbackDestinationError); impl_from_error_for_route!(super::callback::CallbackDestinationError); @@ -187,7 +187,7 @@ pub enum GrantCompletionError { impl_from_error_for_route!(GrantCompletionError: mas_storage::RepositoryError); impl_from_error_for_route!(GrantCompletionError: super::callback::IntoCallbackDestinationError); impl_from_error_for_route!(GrantCompletionError: mas_policy::LoadError); -impl_from_error_for_route!(GrantCompletionError: mas_policy::InstanciateError); +impl_from_error_for_route!(GrantCompletionError: mas_policy::InstantiateError); impl_from_error_for_route!(GrantCompletionError: mas_policy::EvaluationError); impl_from_error_for_route!(GrantCompletionError: super::super::IdTokenSignatureError); diff --git a/crates/handlers/src/oauth2/authorization/mod.rs b/crates/handlers/src/oauth2/authorization/mod.rs index ecc3af3d..8fec59fe 100644 --- a/crates/handlers/src/oauth2/authorization/mod.rs +++ b/crates/handlers/src/oauth2/authorization/mod.rs @@ -94,7 +94,7 @@ impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_templates::TemplateError); impl_from_error_for_route!(self::callback::CallbackDestinationError); impl_from_error_for_route!(mas_policy::LoadError); -impl_from_error_for_route!(mas_policy::InstanciateError); +impl_from_error_for_route!(mas_policy::InstantiateError); impl_from_error_for_route!(mas_policy::EvaluationError); #[derive(Deserialize)] diff --git a/crates/handlers/src/oauth2/consent.rs b/crates/handlers/src/oauth2/consent.rs index 916c3681..85acb82f 100644 --- a/crates/handlers/src/oauth2/consent.rs +++ b/crates/handlers/src/oauth2/consent.rs @@ -61,7 +61,7 @@ pub enum RouteError { impl_from_error_for_route!(mas_templates::TemplateError); impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_policy::LoadError); -impl_from_error_for_route!(mas_policy::InstanciateError); +impl_from_error_for_route!(mas_policy::InstantiateError); impl_from_error_for_route!(mas_policy::EvaluationError); impl IntoResponse for RouteError { diff --git a/crates/handlers/src/oauth2/registration.rs b/crates/handlers/src/oauth2/registration.rs index e7859fb1..d2f2c5e3 100644 --- a/crates/handlers/src/oauth2/registration.rs +++ b/crates/handlers/src/oauth2/registration.rs @@ -49,7 +49,7 @@ pub(crate) enum RouteError { impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_policy::LoadError); -impl_from_error_for_route!(mas_policy::InstanciateError); +impl_from_error_for_route!(mas_policy::InstantiateError); impl_from_error_for_route!(mas_policy::EvaluationError); impl_from_error_for_route!(mas_keystore::aead::Error); diff --git a/crates/handlers/src/test_utils.rs b/crates/handlers/src/test_utils.rs index c06a8ba5..ca3a581f 100644 --- a/crates/handlers/src/test_utils.rs +++ b/crates/handlers/src/test_utils.rs @@ -117,14 +117,15 @@ impl TestState { let file = tokio::fs::File::open(workspace_root.join("policies").join("policy.wasm")).await?; - let policy_factory = PolicyFactory::load( - file, - serde_json::json!({}), - "register/violation".to_owned(), - "client_registration/violation".to_owned(), - "authorization_grant/violation".to_owned(), - ) - .await?; + let entrypoints = mas_policy::Entrypoints { + register: "register/violation".to_owned(), + client_registration: "client_registration/violation".to_owned(), + authorization_grant: "authorization_grant/violation".to_owned(), + email: "email/violation".to_owned(), + password: "password/violation".to_owned(), + }; + + let policy_factory = PolicyFactory::load(file, serde_json::json!({}), entrypoints).await?; let homeserver_connection = MockHomeserverConnection::new("example.com"); diff --git a/crates/policy/Cargo.toml b/crates/policy/Cargo.toml index 3c779e3c..5f25991f 100644 --- a/crates/policy/Cargo.toml +++ b/crates/policy/Cargo.toml @@ -10,8 +10,9 @@ anyhow.workspace = true opa-wasm = { git = "https://github.com/matrix-org/rust-opa-wasm.git" } serde.workspace = true serde_json.workspace = true +schemars = {version = "0.8.1", optional = true } thiserror.workspace = true -tokio = { version = "1.32.0", features = ["io-util"] } +tokio = { version = "1.32.0", features = ["io-util", "rt"] } tracing.workspace = true wasmtime = { version = "12.0.1", default-features = false, features = ["async", "cranelift"] } @@ -23,3 +24,8 @@ tokio = { version = "1.32.0", features = ["fs", "rt", "macros"] } [features] cache = ["wasmtime/cache"] +jsonschema = ["dep:schemars"] + +[[bin]] +name = "schema" +required-features = ["jsonschema"] \ No newline at end of file diff --git a/crates/policy/src/bin/schema.rs b/crates/policy/src/bin/schema.rs new file mode 100644 index 00000000..7742a28a --- /dev/null +++ b/crates/policy/src/bin/schema.rs @@ -0,0 +1,55 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::path::{Path, PathBuf}; + +use mas_policy::model::{ + AuthorizationGrantInput, ClientRegistrationInput, EmailInput, PasswordInput, RegisterInput, +}; +use schemars::{gen::SchemaSettings, JsonSchema}; + +fn write_schema(out_dir: Option<&Path>, file: &str) { + let mut writer: Box = match out_dir { + Some(out_dir) => { + let path = out_dir.join(file); + eprintln!("Writing to {path:?}"); + let file = std::fs::File::create(path).expect("Failed to create file"); + Box::new(std::io::BufWriter::new(file)) + } + None => { + eprintln!("--- {file} ---"); + Box::new(std::io::stdout()) + } + }; + + let settings = SchemaSettings::draft07().with(|s| { + s.option_nullable = false; + s.option_add_null_type = false; + }); + let generator = settings.into_generator(); + let schema = generator.into_root_schema_for::(); + serde_json::to_writer_pretty(&mut writer, &schema).expect("Failed to serialize schema"); + writer.flush().expect("Failed to flush writer"); +} + +fn main() { + let output_root = std::env::var("OUT_DIR").map(PathBuf::from).ok(); + let output_root = output_root.as_deref(); + + write_schema::(output_root, "register_input.json"); + write_schema::(output_root, "client_registration_input.json"); + write_schema::(output_root, "authorization_grant_input.json"); + write_schema::(output_root, "email_input.json"); + write_schema::(output_root, "password_input.json"); +} diff --git a/crates/policy/src/lib.rs b/crates/policy/src/lib.rs index 665afe5f..a9d44c48 100644 --- a/crates/policy/src/lib.rs +++ b/crates/policy/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -17,14 +17,20 @@ #![warn(clippy::pedantic)] #![allow(clippy::missing_errors_doc)] +pub mod model; + use mas_data_model::{AuthorizationGrant, Client, User}; use oauth2_types::registration::VerifiedClientMetadata; use opa_wasm::Runtime; -use serde::Deserialize; use thiserror::Error; use tokio::io::{AsyncRead, AsyncReadExt}; use wasmtime::{Config, Engine, Module, Store}; +use self::model::{ + AuthorizationGrantInput, ClientRegistrationInput, EmailInput, PasswordInput, RegisterInput, +}; +pub use self::model::{EvaluationResult, Violation}; + #[derive(Debug, Error)] pub enum LoadError { #[error("failed to read module")] @@ -40,7 +46,7 @@ pub enum LoadError { Compilation(#[source] anyhow::Error), #[error("failed to instantiate a test instance")] - Instantiate(#[source] InstanciateError), + Instantiate(#[source] InstantiateError), #[cfg(feature = "cache")] #[error("could not load wasmtime cache configuration")] @@ -48,7 +54,7 @@ pub enum LoadError { } #[derive(Debug, Error)] -pub enum InstanciateError { +pub enum InstantiateError { #[error("failed to create WASM runtime")] Runtime(#[source] anyhow::Error), @@ -59,13 +65,33 @@ pub enum InstanciateError { LoadData(#[source] anyhow::Error), } +/// Holds the entrypoint of each policy +#[derive(Debug, Clone)] +pub struct Entrypoints { + pub register: String, + pub client_registration: String, + pub authorization_grant: String, + pub email: String, + pub password: String, +} + +impl Entrypoints { + fn all(&self) -> [&str; 5] { + [ + self.register.as_str(), + self.client_registration.as_str(), + self.authorization_grant.as_str(), + self.email.as_str(), + self.password.as_str(), + ] + } +} + pub struct PolicyFactory { engine: Engine, module: Module, data: serde_json::Value, - register_entrypoint: String, - client_registration_entrypoint: String, - authorization_grant_endpoint: String, + entrypoints: Entrypoints, } impl PolicyFactory { @@ -73,9 +99,7 @@ impl PolicyFactory { pub async fn load( mut source: impl AsyncRead + std::marker::Unpin, data: serde_json::Value, - register_entrypoint: String, - client_registration_entrypoint: String, - authorization_grant_endpoint: String, + entrypoints: Entrypoints, ) -> Result { let mut config = Config::default(); config.async_support(true); @@ -103,9 +127,7 @@ impl PolicyFactory { engine, module, data, - register_entrypoint, - client_registration_entrypoint, - authorization_grant_endpoint, + entrypoints, }; // Try to instantiate @@ -118,22 +140,18 @@ impl PolicyFactory { } #[tracing::instrument(name = "policy.instantiate", skip_all, err)] - pub async fn instantiate(&self) -> Result { + pub async fn instantiate(&self) -> Result { let mut store = Store::new(&self.engine, ()); let runtime = Runtime::new(&mut store, &self.module) .await - .map_err(InstanciateError::Runtime)?; + .map_err(InstantiateError::Runtime)?; // Check that we have the required entrypoints - let entrypoints = runtime.entrypoints(); + let policy_entrypoints = runtime.entrypoints(); - for e in [ - self.register_entrypoint.as_str(), - self.client_registration_entrypoint.as_str(), - self.authorization_grant_endpoint.as_str(), - ] { - if !entrypoints.contains(e) { - return Err(InstanciateError::MissingEntrypoint { + for e in self.entrypoints.all() { + if !policy_entrypoints.contains(e) { + return Err(InstantiateError::MissingEntrypoint { entrypoint: e.to_owned(), }); } @@ -142,43 +160,20 @@ impl PolicyFactory { let instance = runtime .with_data(&mut store, &self.data) .await - .map_err(InstanciateError::LoadData)?; + .map_err(InstantiateError::LoadData)?; Ok(Policy { store, instance, - register_entrypoint: self.register_entrypoint.clone(), - client_registration_entrypoint: self.client_registration_entrypoint.clone(), - authorization_grant_endpoint: self.authorization_grant_endpoint.clone(), + entrypoints: self.entrypoints.clone(), }) } } -#[derive(Deserialize, Debug)] -pub struct Violation { - pub msg: String, - pub field: Option, -} - -#[derive(Deserialize, Debug)] -pub struct EvaluationResult { - #[serde(rename = "result")] - pub violations: Vec, -} - -impl EvaluationResult { - #[must_use] - pub fn valid(&self) -> bool { - self.violations.is_empty() - } -} - pub struct Policy { store: Store<()>, instance: opa_wasm::Policy, - register_entrypoint: String, - client_registration_entrypoint: String, - authorization_grant_endpoint: String, + entrypoints: Entrypoints, } #[derive(Debug, Error)] @@ -189,11 +184,50 @@ pub enum EvaluationError { } impl Policy { + #[tracing::instrument( + name = "policy.evaluate_email", + skip_all, + fields( + input.email = email, + ), + err, + )] + pub async fn evaluate_email( + &mut self, + email: &str, + ) -> Result { + let input = EmailInput { email }; + + let [res]: [EvaluationResult; 1] = self + .instance + .evaluate(&mut self.store, &self.entrypoints.email, &input) + .await?; + + Ok(res) + } + + #[tracing::instrument(name = "policy.evaluate_password", skip_all, err)] + pub async fn evaluate_password( + &mut self, + password: &str, + ) -> Result { + let input = PasswordInput { password }; + + let [res]: [EvaluationResult; 1] = self + .instance + .evaluate(&mut self.store, &self.entrypoints.password, &input) + .await?; + + Ok(res) + } + #[tracing::instrument( name = "policy.evaluate.register", skip_all, fields( - data.username = username, + input.registration_method = "password", + input.user.username = username, + input.user.email = email, ), err, )] @@ -203,17 +237,15 @@ impl Policy { password: &str, email: &str, ) -> Result { - let input = serde_json::json!({ - "user": { - "username": username, - "password": password, - "email": email - } - }); + let input = RegisterInput::Password { + username, + password, + email, + }; let [res]: [EvaluationResult; 1] = self .instance - .evaluate(&mut self.store, &self.register_entrypoint, &input) + .evaluate(&mut self.store, &self.entrypoints.register, &input) .await?; Ok(res) @@ -224,16 +256,13 @@ impl Policy { &mut self, client_metadata: &VerifiedClientMetadata, ) -> Result { - let client_metadata = serde_json::to_value(client_metadata)?; - let input = serde_json::json!({ - "client_metadata": client_metadata, - }); + let input = ClientRegistrationInput { client_metadata }; let [res]: [EvaluationResult; 1] = self .instance .evaluate( &mut self.store, - &self.client_registration_entrypoint, + &self.entrypoints.client_registration, &input, ) .await?; @@ -245,9 +274,9 @@ impl Policy { name = "policy.evaluate.authorization_grant", skip_all, fields( - data.authorization_grant.id = %authorization_grant.id, - data.client.id = %client.id, - data.user.id = %user.id, + input.authorization_grant.id = %authorization_grant.id, + input.client.id = %client.id, + input.user.id = %user.id, ), err, )] @@ -257,17 +286,19 @@ impl Policy { client: &Client, user: &User, ) -> Result { - let authorization_grant = serde_json::to_value(authorization_grant)?; - let user = serde_json::to_value(user)?; - let input = serde_json::json!({ - "authorization_grant": authorization_grant, - "client": client, - "user": user, - }); + let input = AuthorizationGrantInput { + user, + client, + authorization_grant, + }; let [res]: [EvaluationResult; 1] = self .instance - .evaluate(&mut self.store, &self.authorization_grant_endpoint, &input) + .evaluate( + &mut self.store, + &self.entrypoints.authorization_grant, + &input, + ) .await?; Ok(res) @@ -294,15 +325,15 @@ mod tests { let file = tokio::fs::File::open(path).await.unwrap(); - let factory = PolicyFactory::load( - file, - data, - "register/violation".to_owned(), - "client_registration/violation".to_owned(), - "authorization_grant/violation".to_owned(), - ) - .await - .unwrap(); + let entrypoints = Entrypoints { + register: "register/violation".to_owned(), + client_registration: "client_registration/violation".to_owned(), + authorization_grant: "authorization_grant/violation".to_owned(), + email: "email/violation".to_owned(), + password: "password/violation".to_owned(), + }; + + let factory = PolicyFactory::load(file, data, entrypoints).await.unwrap(); let mut policy = factory.instantiate().await.unwrap(); diff --git a/crates/policy/src/model.rs b/crates/policy/src/model.rs new file mode 100644 index 00000000..3cc9ff1f --- /dev/null +++ b/crates/policy/src/model.rs @@ -0,0 +1,96 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use mas_data_model::{AuthorizationGrant, Client, User}; +use oauth2_types::registration::VerifiedClientMetadata; +use serde::{Deserialize, Serialize}; + +#[derive(Deserialize, Debug)] +#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] +pub struct Violation { + pub msg: String, + pub field: Option, +} + +#[derive(Deserialize, Debug)] +pub struct EvaluationResult { + #[serde(rename = "result")] + pub violations: Vec, +} + +impl EvaluationResult { + #[must_use] + pub fn valid(&self) -> bool { + self.violations.is_empty() + } +} + +#[derive(Serialize, Debug)] +#[serde(tag = "registration_method", rename_all = "snake_case")] +#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] +pub enum RegisterInput<'a> { + Password { + username: &'a str, + password: &'a str, + email: &'a str, + }, +} + +#[derive(Serialize, Debug)] +#[serde(rename_all = "snake_case")] +#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] +pub struct ClientRegistrationInput<'a> { + #[cfg_attr( + feature = "jsonschema", + schemars(with = "std::collections::HashMap") + )] + pub client_metadata: &'a VerifiedClientMetadata, +} + +#[derive(Serialize, Debug)] +#[serde(rename_all = "snake_case")] +#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] +pub struct AuthorizationGrantInput<'a> { + #[cfg_attr( + feature = "jsonschema", + schemars(with = "std::collections::HashMap") + )] + pub user: &'a User, + + #[cfg_attr( + feature = "jsonschema", + schemars(with = "std::collections::HashMap") + )] + pub client: &'a Client, + + #[cfg_attr( + feature = "jsonschema", + schemars(with = "std::collections::HashMap") + )] + pub authorization_grant: &'a AuthorizationGrant, +} + +#[derive(Serialize, Debug)] +#[serde(rename_all = "snake_case")] +#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] +pub struct EmailInput<'a> { + pub email: &'a str, +} + +#[derive(Serialize, Debug)] +#[serde(rename_all = "snake_case")] +#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] +pub struct PasswordInput<'a> { + pub password: &'a str, +} diff --git a/misc/update.sh b/misc/update.sh index 52d7ac36..91e275df 100644 --- a/misc/update.sh +++ b/misc/update.sh @@ -6,10 +6,12 @@ export SQLX_OFFLINE=1 BASE_DIR="$(dirname "$0")/.." CONFIG_SCHEMA="${BASE_DIR}/docs/config.schema.json" GRAPHQL_SCHEMA="${BASE_DIR}/frontend/schema.graphql" +POLICIES_SCHEMA="${BASE_DIR}/policies/schema/" set -x cargo run -p mas-config > "${CONFIG_SCHEMA}" cargo run -p mas-graphql > "${GRAPHQL_SCHEMA}" +OUT_DIR="${POLICIES_SCHEMA}" cargo run -p mas-policy --features jsonschema cd "${BASE_DIR}/frontend" npm run generate diff --git a/policies/Makefile b/policies/Makefile index d110c21d..2c9ff7c9 100644 --- a/policies/Makefile +++ b/policies/Makefile @@ -1,6 +1,13 @@ # Set to 1 to run OPA through Docker DOCKER := 0 -OPA_DOCKER_IMAGE := docker.io/openpolicyagent/opa:0.55.0 +OPA_DOCKER_IMAGE := docker.io/openpolicyagent/opa:0.55.0-debug + +INPUTS := \ + client_registration.rego \ + register.rego \ + authorization_grant.rego \ + password.rego \ + email.rego ifeq ($(DOCKER), 0) OPA := opa @@ -10,11 +17,13 @@ else OPA_RW := docker run -i -v $(shell pwd):/policies -w /policies --rm $(OPA_DOCKER_IMAGE) endif -policy.wasm: client_registration.rego register.rego authorization_grant.rego +policy.wasm: $(INPUTS) $(OPA_RW) build -t wasm \ -e "client_registration/violation" \ -e "register/violation" \ -e "authorization_grant/violation" \ + -e "password/violation" \ + -e "email/violation" \ $^ tar xzf bundle.tar.gz /policy.wasm $(RM) bundle.tar.gz @@ -26,7 +35,7 @@ fmt: .PHONY: test test: - $(OPA) test -v ./*.rego + $(OPA) test --schema ./schema/ -v ./*.rego .PHONY: coverage coverage: diff --git a/policies/authorization_grant.rego b/policies/authorization_grant.rego index 2fd0c717..d59c6c57 100644 --- a/policies/authorization_grant.rego +++ b/policies/authorization_grant.rego @@ -1,3 +1,6 @@ +# METADATA +# schemas: +# - input: schema["authorization_grant_input"] package authorization_grant import future.keywords.in diff --git a/policies/client_registration.rego b/policies/client_registration.rego index 7ea671f0..a41375cf 100644 --- a/policies/client_registration.rego +++ b/policies/client_registration.rego @@ -1,3 +1,6 @@ +# METADATA +# schemas: +# - input: schema["client_registration_input"] package client_registration import future.keywords.in diff --git a/policies/email.rego b/policies/email.rego new file mode 100644 index 00000000..fecad108 --- /dev/null +++ b/policies/email.rego @@ -0,0 +1,35 @@ +# METADATA +# schemas: +# - input: schema["email_input"] +package email + +import future.keywords.in + +default allow := false + +allow { + count(violation) == 0 +} + +# Allow any domains if the data.allowed_domains array is not set +email_domain_allowed { + not data.allowed_domains +} + +# Allow an email only if its domain is in the list of allowed domains +email_domain_allowed { + [_, domain] := split(input.email, "@") + some allowed_domain in data.allowed_domains + glob.match(allowed_domain, ["."], domain) +} + +violation[{"msg": "email domain is not allowed"}] { + not email_domain_allowed +} + +# Deny emails with their domain in the domains banlist +violation[{"msg": "email domain is banned"}] { + [_, domain] := split(input.email, "@") + some banned_domain in data.banned_domains + glob.match(banned_domain, ["."], domain) +} diff --git a/policies/password.rego b/policies/password.rego new file mode 100644 index 00000000..bae1c215 --- /dev/null +++ b/policies/password.rego @@ -0,0 +1,30 @@ +# METADATA +# schemas: +# - input: schema["password_input"] +package password + +default allow := false + +allow { + count(violation) == 0 +} + +violation[{"msg": msg}] { + count(input.password) < data.passwords.min_length + msg := sprintf("needs to be at least %d characters", [data.passwords.min_length]) +} + +violation[{"msg": "requires at least one number"}] { + data.passwords.require_number + not regex.match("[0-9]", input.password) +} + +violation[{"msg": "requires at least one lowercase letter"}] { + data.passwords.require_lowercase + not regex.match("[a-z]", input.password) +} + +violation[{"msg": "requires at least one uppercase letter"}] { + data.passwords.require_uppercase + not regex.match("[A-Z]", input.password) +} diff --git a/policies/register.rego b/policies/register.rego index 391fc37b..b15e0fdc 100644 --- a/policies/register.rego +++ b/policies/register.rego @@ -1,5 +1,11 @@ +# METADATA +# schemas: +# - input: schema["register_input"] package register +import data.email as email_policy +import data.password as password_policy + import future.keywords.in default allow := false @@ -9,52 +15,24 @@ allow { } violation[{"field": "username", "msg": "username too short"}] { - count(input.user.username) <= 2 + count(input.username) <= 2 } violation[{"field": "username", "msg": "username too long"}] { - count(input.user.username) >= 15 + count(input.username) >= 15 } -violation[{"field": "password", "msg": msg}] { - count(input.user.password) < data.passwords.min_length - msg := sprintf("needs to be at least %d characters", [data.passwords.min_length]) +violation[object.union({"field": "password"}, v)] { + # Check if the registration method is password + input.registration_method == "password" + + # Get the violation object from the password policy + some v in password_policy.violation } -violation[{"field": "password", "msg": "requires at least one number"}] { - data.passwords.require_number - not regex.match("[0-9]", input.user.password) -} - -violation[{"field": "password", "msg": "requires at least one lowercase letter"}] { - data.passwords.require_lowercase - not regex.match("[a-z]", input.user.password) -} - -violation[{"field": "password", "msg": "requires at least one uppercase letter"}] { - data.passwords.require_uppercase - not regex.match("[A-Z]", input.user.password) -} - -# Allow any domains if the data.allowed_domains array is not set -email_domain_allowed { - not data.allowed_domains -} - -# Allow an email only if its domain is in the list of allowed domains -email_domain_allowed { - [_, domain] := split(input.user.email, "@") - some allowed_domain in data.allowed_domains - glob.match(allowed_domain, ["."], domain) -} - -violation[{"field": "email", "msg": "email domain not allowed"}] { - not email_domain_allowed -} - -# Deny emails with their domain in the domains banlist -violation[{"field": "email", "msg": "email domain not allowed"}] { - [_, domain] := split(input.user.email, "@") - some banned_domain in data.banned_domains - glob.match(banned_domain, ["."], domain) +# Check if the email is valid using the email policy +# and add the email field to the violation object +violation[object.union({"field": "email"}, v)] { + # Get the violation object from the email policy + some v in email_policy.violation } diff --git a/policies/register_test.rego b/policies/register_test.rego index d2b042fe..70acea87 100644 --- a/policies/register_test.rego +++ b/policies/register_test.rego @@ -1,72 +1,85 @@ package register -mock_user := {"username": "hello", "password": "Hunter2", "email": "hello@staging.element.io"} +mock_registration := { + "registration_method": "password", + "username": "hello", + "password": "Hunter2", + "email": "hello@staging.element.io", +} test_allow_all_domains { - allow with input.user as mock_user + allow with input as mock_registration } test_allowed_domain { - allow with input.user as mock_user + allow with input as mock_registration with data.allowed_domains as ["*.element.io"] } test_not_allowed_domain { - not allow with input.user as mock_user + not allow with input as mock_registration with data.allowed_domains as ["example.com"] } test_banned_domain { - not allow with input.user as mock_user + not allow with input as mock_registration with data.banned_domains as ["*.element.io"] } test_banned_subdomain { - not allow with input.user as mock_user + not allow with input as mock_registration with data.allowed_domains as ["*.element.io"] with data.banned_domains as ["staging.element.io"] } test_short_username { - not allow with input.user as {"username": "a", "email": "hello@element.io"} + not allow with input as {"username": "a", "email": "hello@element.io"} } test_long_username { - not allow with input.user as {"username": "aaaaaaaaaaaaaaaaaaaaaaaaaaaaa", "email": "hello@element.io"} + not allow with input as {"username": "aaaaaaaaaaaaaaaaaaaaaaaaaaaaa", "email": "hello@element.io"} } test_password_require_number { - allow with input.user as mock_user + allow with input as mock_registration + with input.registration_method as "password" with data.passwords.require_number as true - not allow with input.user as mock_user - with input.user.password as "hunter" + not allow with input as mock_registration + with input.registration_method as "password" + with input.password as "hunter" with data.passwords.require_number as true } test_password_require_lowercase { - allow with input.user as mock_user + allow with input as mock_registration + with input.registration_method as "password" with data.passwords.require_lowercase as true - not allow with input.user as mock_user - with input.user.password as "HUNTER2" + not allow with input as mock_registration + with input.registration_method as "password" + with input.password as "HUNTER2" with data.passwords.require_lowercase as true } test_password_require_uppercase { - allow with input.user as mock_user + allow with input as mock_registration + with input.registration_method as "password" with data.passwords.require_uppercase as true - not allow with input.user as mock_user - with input.user.password as "hunter2" + not allow with input as mock_registration + with input.registration_method as "password" + with input.password as "hunter2" with data.passwords.require_uppercase as true } test_password_min_length { - allow with input.user as mock_user + allow with input as mock_registration + with input.registration_method as "password" with data.passwords.min_length as 6 - not allow with input.user as mock_user - with input.user.password as "short" + not allow with input as mock_registration + with input.registration_method as "password" + with input.password as "short" with data.passwords.min_length as 6 } diff --git a/policies/schema/authorization_grant_input.json b/policies/schema/authorization_grant_input.json new file mode 100644 index 00000000..a1a49a8d --- /dev/null +++ b/policies/schema/authorization_grant_input.json @@ -0,0 +1,24 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "AuthorizationGrantInput", + "type": "object", + "required": [ + "authorization_grant", + "client", + "user" + ], + "properties": { + "authorization_grant": { + "type": "object", + "additionalProperties": true + }, + "client": { + "type": "object", + "additionalProperties": true + }, + "user": { + "type": "object", + "additionalProperties": true + } + } +} \ No newline at end of file diff --git a/policies/schema/client_registration_input.json b/policies/schema/client_registration_input.json new file mode 100644 index 00000000..7261068e --- /dev/null +++ b/policies/schema/client_registration_input.json @@ -0,0 +1,14 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "ClientRegistrationInput", + "type": "object", + "required": [ + "client_metadata" + ], + "properties": { + "client_metadata": { + "type": "object", + "additionalProperties": true + } + } +} \ No newline at end of file diff --git a/policies/schema/email_input.json b/policies/schema/email_input.json new file mode 100644 index 00000000..487eb4b9 --- /dev/null +++ b/policies/schema/email_input.json @@ -0,0 +1,13 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "EmailInput", + "type": "object", + "required": [ + "email" + ], + "properties": { + "email": { + "type": "string" + } + } +} \ No newline at end of file diff --git a/policies/schema/password_input.json b/policies/schema/password_input.json new file mode 100644 index 00000000..d85b2862 --- /dev/null +++ b/policies/schema/password_input.json @@ -0,0 +1,13 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "PasswordInput", + "type": "object", + "required": [ + "password" + ], + "properties": { + "password": { + "type": "string" + } + } +} \ No newline at end of file diff --git a/policies/schema/register_input.json b/policies/schema/register_input.json new file mode 100644 index 00000000..d77ce66e --- /dev/null +++ b/policies/schema/register_input.json @@ -0,0 +1,32 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "RegisterInput", + "oneOf": [ + { + "type": "object", + "required": [ + "email", + "password", + "registration_method", + "username" + ], + "properties": { + "email": { + "type": "string" + }, + "password": { + "type": "string" + }, + "registration_method": { + "type": "string", + "enum": [ + "password" + ] + }, + "username": { + "type": "string" + } + } + } + ] +} \ No newline at end of file