You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-08-07 17:03:01 +03:00
policies: split the email & password policies and add jsonschema validation of the input
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -3047,6 +3047,7 @@ dependencies = [
|
||||
"mas-data-model",
|
||||
"oauth2-types",
|
||||
"opa-wasm",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror",
|
||||
|
@@ -97,12 +97,18 @@ pub async fn policy_factory_from_config(
|
||||
.await
|
||||
.context("failed to open OPA WASM policy file")?;
|
||||
|
||||
let entrypoints = mas_policy::Entrypoints {
|
||||
register: config.register_entrypoint.clone(),
|
||||
client_registration: config.client_registration_entrypoint.clone(),
|
||||
authorization_grant: config.authorization_grant_entrypoint.clone(),
|
||||
email: config.email_entrypoint.clone(),
|
||||
password: config.password_entrypoint.clone(),
|
||||
};
|
||||
|
||||
PolicyFactory::load(
|
||||
policy_file,
|
||||
config.data.clone().unwrap_or_default(),
|
||||
config.register_entrypoint.clone(),
|
||||
config.client_registration_entrypoint.clone(),
|
||||
config.authorization_grant_entrypoint.clone(),
|
||||
entrypoints,
|
||||
)
|
||||
.await
|
||||
.context("failed to load the policy")
|
||||
|
@@ -48,6 +48,14 @@ fn default_authorization_grant_endpoint() -> String {
|
||||
"authorization_grant/violation".to_owned()
|
||||
}
|
||||
|
||||
fn default_password_endpoint() -> String {
|
||||
"password/violation".to_owned()
|
||||
}
|
||||
|
||||
fn default_email_endpoint() -> String {
|
||||
"email/violation".to_owned()
|
||||
}
|
||||
|
||||
/// Application secrets
|
||||
#[serde_as]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
@@ -69,6 +77,14 @@ pub struct PolicyConfig {
|
||||
#[serde(default = "default_authorization_grant_endpoint")]
|
||||
pub authorization_grant_entrypoint: String,
|
||||
|
||||
/// Entrypoint to use when changing password
|
||||
#[serde(default = "default_password_endpoint")]
|
||||
pub password_entrypoint: String,
|
||||
|
||||
/// Entrypoint to use when adding an email address
|
||||
#[serde(default = "default_email_endpoint")]
|
||||
pub email_entrypoint: String,
|
||||
|
||||
/// Arbitrary data to pass to the policy
|
||||
#[serde(default)]
|
||||
pub data: Option<serde_json::Value>,
|
||||
@@ -81,6 +97,8 @@ impl Default for PolicyConfig {
|
||||
client_registration_entrypoint: default_client_registration_endpoint(),
|
||||
register_entrypoint: default_register_endpoint(),
|
||||
authorization_grant_entrypoint: default_authorization_grant_endpoint(),
|
||||
password_entrypoint: default_password_endpoint(),
|
||||
email_entrypoint: default_email_endpoint(),
|
||||
data: None,
|
||||
}
|
||||
}
|
||||
|
@@ -76,7 +76,7 @@ impl IntoResponse for RouteError {
|
||||
impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl_from_error_for_route!(mas_templates::TemplateError);
|
||||
impl_from_error_for_route!(mas_policy::LoadError);
|
||||
impl_from_error_for_route!(mas_policy::InstanciateError);
|
||||
impl_from_error_for_route!(mas_policy::InstantiateError);
|
||||
impl_from_error_for_route!(mas_policy::EvaluationError);
|
||||
impl_from_error_for_route!(super::callback::IntoCallbackDestinationError);
|
||||
impl_from_error_for_route!(super::callback::CallbackDestinationError);
|
||||
@@ -187,7 +187,7 @@ pub enum GrantCompletionError {
|
||||
impl_from_error_for_route!(GrantCompletionError: mas_storage::RepositoryError);
|
||||
impl_from_error_for_route!(GrantCompletionError: super::callback::IntoCallbackDestinationError);
|
||||
impl_from_error_for_route!(GrantCompletionError: mas_policy::LoadError);
|
||||
impl_from_error_for_route!(GrantCompletionError: mas_policy::InstanciateError);
|
||||
impl_from_error_for_route!(GrantCompletionError: mas_policy::InstantiateError);
|
||||
impl_from_error_for_route!(GrantCompletionError: mas_policy::EvaluationError);
|
||||
impl_from_error_for_route!(GrantCompletionError: super::super::IdTokenSignatureError);
|
||||
|
||||
|
@@ -94,7 +94,7 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl_from_error_for_route!(mas_templates::TemplateError);
|
||||
impl_from_error_for_route!(self::callback::CallbackDestinationError);
|
||||
impl_from_error_for_route!(mas_policy::LoadError);
|
||||
impl_from_error_for_route!(mas_policy::InstanciateError);
|
||||
impl_from_error_for_route!(mas_policy::InstantiateError);
|
||||
impl_from_error_for_route!(mas_policy::EvaluationError);
|
||||
|
||||
#[derive(Deserialize)]
|
||||
|
@@ -61,7 +61,7 @@ pub enum RouteError {
|
||||
impl_from_error_for_route!(mas_templates::TemplateError);
|
||||
impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl_from_error_for_route!(mas_policy::LoadError);
|
||||
impl_from_error_for_route!(mas_policy::InstanciateError);
|
||||
impl_from_error_for_route!(mas_policy::InstantiateError);
|
||||
impl_from_error_for_route!(mas_policy::EvaluationError);
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
|
@@ -49,7 +49,7 @@ pub(crate) enum RouteError {
|
||||
|
||||
impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl_from_error_for_route!(mas_policy::LoadError);
|
||||
impl_from_error_for_route!(mas_policy::InstanciateError);
|
||||
impl_from_error_for_route!(mas_policy::InstantiateError);
|
||||
impl_from_error_for_route!(mas_policy::EvaluationError);
|
||||
impl_from_error_for_route!(mas_keystore::aead::Error);
|
||||
|
||||
|
@@ -117,14 +117,15 @@ impl TestState {
|
||||
let file =
|
||||
tokio::fs::File::open(workspace_root.join("policies").join("policy.wasm")).await?;
|
||||
|
||||
let policy_factory = PolicyFactory::load(
|
||||
file,
|
||||
serde_json::json!({}),
|
||||
"register/violation".to_owned(),
|
||||
"client_registration/violation".to_owned(),
|
||||
"authorization_grant/violation".to_owned(),
|
||||
)
|
||||
.await?;
|
||||
let entrypoints = mas_policy::Entrypoints {
|
||||
register: "register/violation".to_owned(),
|
||||
client_registration: "client_registration/violation".to_owned(),
|
||||
authorization_grant: "authorization_grant/violation".to_owned(),
|
||||
email: "email/violation".to_owned(),
|
||||
password: "password/violation".to_owned(),
|
||||
};
|
||||
|
||||
let policy_factory = PolicyFactory::load(file, serde_json::json!({}), entrypoints).await?;
|
||||
|
||||
let homeserver_connection = MockHomeserverConnection::new("example.com");
|
||||
|
||||
|
@@ -10,8 +10,9 @@ anyhow.workspace = true
|
||||
opa-wasm = { git = "https://github.com/matrix-org/rust-opa-wasm.git" }
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
schemars = {version = "0.8.1", optional = true }
|
||||
thiserror.workspace = true
|
||||
tokio = { version = "1.32.0", features = ["io-util"] }
|
||||
tokio = { version = "1.32.0", features = ["io-util", "rt"] }
|
||||
tracing.workspace = true
|
||||
wasmtime = { version = "12.0.1", default-features = false, features = ["async", "cranelift"] }
|
||||
|
||||
@@ -23,3 +24,8 @@ tokio = { version = "1.32.0", features = ["fs", "rt", "macros"] }
|
||||
|
||||
[features]
|
||||
cache = ["wasmtime/cache"]
|
||||
jsonschema = ["dep:schemars"]
|
||||
|
||||
[[bin]]
|
||||
name = "schema"
|
||||
required-features = ["jsonschema"]
|
55
crates/policy/src/bin/schema.rs
Normal file
55
crates/policy/src/bin/schema.rs
Normal file
@@ -0,0 +1,55 @@
|
||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use mas_policy::model::{
|
||||
AuthorizationGrantInput, ClientRegistrationInput, EmailInput, PasswordInput, RegisterInput,
|
||||
};
|
||||
use schemars::{gen::SchemaSettings, JsonSchema};
|
||||
|
||||
fn write_schema<T: JsonSchema>(out_dir: Option<&Path>, file: &str) {
|
||||
let mut writer: Box<dyn std::io::Write> = match out_dir {
|
||||
Some(out_dir) => {
|
||||
let path = out_dir.join(file);
|
||||
eprintln!("Writing to {path:?}");
|
||||
let file = std::fs::File::create(path).expect("Failed to create file");
|
||||
Box::new(std::io::BufWriter::new(file))
|
||||
}
|
||||
None => {
|
||||
eprintln!("--- {file} ---");
|
||||
Box::new(std::io::stdout())
|
||||
}
|
||||
};
|
||||
|
||||
let settings = SchemaSettings::draft07().with(|s| {
|
||||
s.option_nullable = false;
|
||||
s.option_add_null_type = false;
|
||||
});
|
||||
let generator = settings.into_generator();
|
||||
let schema = generator.into_root_schema_for::<T>();
|
||||
serde_json::to_writer_pretty(&mut writer, &schema).expect("Failed to serialize schema");
|
||||
writer.flush().expect("Failed to flush writer");
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let output_root = std::env::var("OUT_DIR").map(PathBuf::from).ok();
|
||||
let output_root = output_root.as_deref();
|
||||
|
||||
write_schema::<RegisterInput>(output_root, "register_input.json");
|
||||
write_schema::<ClientRegistrationInput>(output_root, "client_registration_input.json");
|
||||
write_schema::<AuthorizationGrantInput>(output_root, "authorization_grant_input.json");
|
||||
write_schema::<EmailInput>(output_root, "email_input.json");
|
||||
write_schema::<PasswordInput>(output_root, "password_input.json");
|
||||
}
|
@@ -1,4 +1,4 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2022-2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@@ -17,14 +17,20 @@
|
||||
#![warn(clippy::pedantic)]
|
||||
#![allow(clippy::missing_errors_doc)]
|
||||
|
||||
pub mod model;
|
||||
|
||||
use mas_data_model::{AuthorizationGrant, Client, User};
|
||||
use oauth2_types::registration::VerifiedClientMetadata;
|
||||
use opa_wasm::Runtime;
|
||||
use serde::Deserialize;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt};
|
||||
use wasmtime::{Config, Engine, Module, Store};
|
||||
|
||||
use self::model::{
|
||||
AuthorizationGrantInput, ClientRegistrationInput, EmailInput, PasswordInput, RegisterInput,
|
||||
};
|
||||
pub use self::model::{EvaluationResult, Violation};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum LoadError {
|
||||
#[error("failed to read module")]
|
||||
@@ -40,7 +46,7 @@ pub enum LoadError {
|
||||
Compilation(#[source] anyhow::Error),
|
||||
|
||||
#[error("failed to instantiate a test instance")]
|
||||
Instantiate(#[source] InstanciateError),
|
||||
Instantiate(#[source] InstantiateError),
|
||||
|
||||
#[cfg(feature = "cache")]
|
||||
#[error("could not load wasmtime cache configuration")]
|
||||
@@ -48,7 +54,7 @@ pub enum LoadError {
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum InstanciateError {
|
||||
pub enum InstantiateError {
|
||||
#[error("failed to create WASM runtime")]
|
||||
Runtime(#[source] anyhow::Error),
|
||||
|
||||
@@ -59,13 +65,33 @@ pub enum InstanciateError {
|
||||
LoadData(#[source] anyhow::Error),
|
||||
}
|
||||
|
||||
/// Holds the entrypoint of each policy
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Entrypoints {
|
||||
pub register: String,
|
||||
pub client_registration: String,
|
||||
pub authorization_grant: String,
|
||||
pub email: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
impl Entrypoints {
|
||||
fn all(&self) -> [&str; 5] {
|
||||
[
|
||||
self.register.as_str(),
|
||||
self.client_registration.as_str(),
|
||||
self.authorization_grant.as_str(),
|
||||
self.email.as_str(),
|
||||
self.password.as_str(),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PolicyFactory {
|
||||
engine: Engine,
|
||||
module: Module,
|
||||
data: serde_json::Value,
|
||||
register_entrypoint: String,
|
||||
client_registration_entrypoint: String,
|
||||
authorization_grant_endpoint: String,
|
||||
entrypoints: Entrypoints,
|
||||
}
|
||||
|
||||
impl PolicyFactory {
|
||||
@@ -73,9 +99,7 @@ impl PolicyFactory {
|
||||
pub async fn load(
|
||||
mut source: impl AsyncRead + std::marker::Unpin,
|
||||
data: serde_json::Value,
|
||||
register_entrypoint: String,
|
||||
client_registration_entrypoint: String,
|
||||
authorization_grant_endpoint: String,
|
||||
entrypoints: Entrypoints,
|
||||
) -> Result<Self, LoadError> {
|
||||
let mut config = Config::default();
|
||||
config.async_support(true);
|
||||
@@ -103,9 +127,7 @@ impl PolicyFactory {
|
||||
engine,
|
||||
module,
|
||||
data,
|
||||
register_entrypoint,
|
||||
client_registration_entrypoint,
|
||||
authorization_grant_endpoint,
|
||||
entrypoints,
|
||||
};
|
||||
|
||||
// Try to instantiate
|
||||
@@ -118,22 +140,18 @@ impl PolicyFactory {
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "policy.instantiate", skip_all, err)]
|
||||
pub async fn instantiate(&self) -> Result<Policy, InstanciateError> {
|
||||
pub async fn instantiate(&self) -> Result<Policy, InstantiateError> {
|
||||
let mut store = Store::new(&self.engine, ());
|
||||
let runtime = Runtime::new(&mut store, &self.module)
|
||||
.await
|
||||
.map_err(InstanciateError::Runtime)?;
|
||||
.map_err(InstantiateError::Runtime)?;
|
||||
|
||||
// Check that we have the required entrypoints
|
||||
let entrypoints = runtime.entrypoints();
|
||||
let policy_entrypoints = runtime.entrypoints();
|
||||
|
||||
for e in [
|
||||
self.register_entrypoint.as_str(),
|
||||
self.client_registration_entrypoint.as_str(),
|
||||
self.authorization_grant_endpoint.as_str(),
|
||||
] {
|
||||
if !entrypoints.contains(e) {
|
||||
return Err(InstanciateError::MissingEntrypoint {
|
||||
for e in self.entrypoints.all() {
|
||||
if !policy_entrypoints.contains(e) {
|
||||
return Err(InstantiateError::MissingEntrypoint {
|
||||
entrypoint: e.to_owned(),
|
||||
});
|
||||
}
|
||||
@@ -142,43 +160,20 @@ impl PolicyFactory {
|
||||
let instance = runtime
|
||||
.with_data(&mut store, &self.data)
|
||||
.await
|
||||
.map_err(InstanciateError::LoadData)?;
|
||||
.map_err(InstantiateError::LoadData)?;
|
||||
|
||||
Ok(Policy {
|
||||
store,
|
||||
instance,
|
||||
register_entrypoint: self.register_entrypoint.clone(),
|
||||
client_registration_entrypoint: self.client_registration_entrypoint.clone(),
|
||||
authorization_grant_endpoint: self.authorization_grant_endpoint.clone(),
|
||||
entrypoints: self.entrypoints.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct Violation {
|
||||
pub msg: String,
|
||||
pub field: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct EvaluationResult {
|
||||
#[serde(rename = "result")]
|
||||
pub violations: Vec<Violation>,
|
||||
}
|
||||
|
||||
impl EvaluationResult {
|
||||
#[must_use]
|
||||
pub fn valid(&self) -> bool {
|
||||
self.violations.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Policy {
|
||||
store: Store<()>,
|
||||
instance: opa_wasm::Policy<opa_wasm::DefaultContext>,
|
||||
register_entrypoint: String,
|
||||
client_registration_entrypoint: String,
|
||||
authorization_grant_endpoint: String,
|
||||
entrypoints: Entrypoints,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
@@ -189,11 +184,50 @@ pub enum EvaluationError {
|
||||
}
|
||||
|
||||
impl Policy {
|
||||
#[tracing::instrument(
|
||||
name = "policy.evaluate_email",
|
||||
skip_all,
|
||||
fields(
|
||||
input.email = email,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn evaluate_email(
|
||||
&mut self,
|
||||
email: &str,
|
||||
) -> Result<EvaluationResult, EvaluationError> {
|
||||
let input = EmailInput { email };
|
||||
|
||||
let [res]: [EvaluationResult; 1] = self
|
||||
.instance
|
||||
.evaluate(&mut self.store, &self.entrypoints.email, &input)
|
||||
.await?;
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "policy.evaluate_password", skip_all, err)]
|
||||
pub async fn evaluate_password(
|
||||
&mut self,
|
||||
password: &str,
|
||||
) -> Result<EvaluationResult, EvaluationError> {
|
||||
let input = PasswordInput { password };
|
||||
|
||||
let [res]: [EvaluationResult; 1] = self
|
||||
.instance
|
||||
.evaluate(&mut self.store, &self.entrypoints.password, &input)
|
||||
.await?;
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "policy.evaluate.register",
|
||||
skip_all,
|
||||
fields(
|
||||
data.username = username,
|
||||
input.registration_method = "password",
|
||||
input.user.username = username,
|
||||
input.user.email = email,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
@@ -203,17 +237,15 @@ impl Policy {
|
||||
password: &str,
|
||||
email: &str,
|
||||
) -> Result<EvaluationResult, EvaluationError> {
|
||||
let input = serde_json::json!({
|
||||
"user": {
|
||||
"username": username,
|
||||
"password": password,
|
||||
"email": email
|
||||
}
|
||||
});
|
||||
let input = RegisterInput::Password {
|
||||
username,
|
||||
password,
|
||||
email,
|
||||
};
|
||||
|
||||
let [res]: [EvaluationResult; 1] = self
|
||||
.instance
|
||||
.evaluate(&mut self.store, &self.register_entrypoint, &input)
|
||||
.evaluate(&mut self.store, &self.entrypoints.register, &input)
|
||||
.await?;
|
||||
|
||||
Ok(res)
|
||||
@@ -224,16 +256,13 @@ impl Policy {
|
||||
&mut self,
|
||||
client_metadata: &VerifiedClientMetadata,
|
||||
) -> Result<EvaluationResult, EvaluationError> {
|
||||
let client_metadata = serde_json::to_value(client_metadata)?;
|
||||
let input = serde_json::json!({
|
||||
"client_metadata": client_metadata,
|
||||
});
|
||||
let input = ClientRegistrationInput { client_metadata };
|
||||
|
||||
let [res]: [EvaluationResult; 1] = self
|
||||
.instance
|
||||
.evaluate(
|
||||
&mut self.store,
|
||||
&self.client_registration_entrypoint,
|
||||
&self.entrypoints.client_registration,
|
||||
&input,
|
||||
)
|
||||
.await?;
|
||||
@@ -245,9 +274,9 @@ impl Policy {
|
||||
name = "policy.evaluate.authorization_grant",
|
||||
skip_all,
|
||||
fields(
|
||||
data.authorization_grant.id = %authorization_grant.id,
|
||||
data.client.id = %client.id,
|
||||
data.user.id = %user.id,
|
||||
input.authorization_grant.id = %authorization_grant.id,
|
||||
input.client.id = %client.id,
|
||||
input.user.id = %user.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
@@ -257,17 +286,19 @@ impl Policy {
|
||||
client: &Client,
|
||||
user: &User,
|
||||
) -> Result<EvaluationResult, EvaluationError> {
|
||||
let authorization_grant = serde_json::to_value(authorization_grant)?;
|
||||
let user = serde_json::to_value(user)?;
|
||||
let input = serde_json::json!({
|
||||
"authorization_grant": authorization_grant,
|
||||
"client": client,
|
||||
"user": user,
|
||||
});
|
||||
let input = AuthorizationGrantInput {
|
||||
user,
|
||||
client,
|
||||
authorization_grant,
|
||||
};
|
||||
|
||||
let [res]: [EvaluationResult; 1] = self
|
||||
.instance
|
||||
.evaluate(&mut self.store, &self.authorization_grant_endpoint, &input)
|
||||
.evaluate(
|
||||
&mut self.store,
|
||||
&self.entrypoints.authorization_grant,
|
||||
&input,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(res)
|
||||
@@ -294,15 +325,15 @@ mod tests {
|
||||
|
||||
let file = tokio::fs::File::open(path).await.unwrap();
|
||||
|
||||
let factory = PolicyFactory::load(
|
||||
file,
|
||||
data,
|
||||
"register/violation".to_owned(),
|
||||
"client_registration/violation".to_owned(),
|
||||
"authorization_grant/violation".to_owned(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let entrypoints = Entrypoints {
|
||||
register: "register/violation".to_owned(),
|
||||
client_registration: "client_registration/violation".to_owned(),
|
||||
authorization_grant: "authorization_grant/violation".to_owned(),
|
||||
email: "email/violation".to_owned(),
|
||||
password: "password/violation".to_owned(),
|
||||
};
|
||||
|
||||
let factory = PolicyFactory::load(file, data, entrypoints).await.unwrap();
|
||||
|
||||
let mut policy = factory.instantiate().await.unwrap();
|
||||
|
||||
|
96
crates/policy/src/model.rs
Normal file
96
crates/policy/src/model.rs
Normal file
@@ -0,0 +1,96 @@
|
||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use mas_data_model::{AuthorizationGrant, Client, User};
|
||||
use oauth2_types::registration::VerifiedClientMetadata;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]
|
||||
pub struct Violation {
|
||||
pub msg: String,
|
||||
pub field: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct EvaluationResult {
|
||||
#[serde(rename = "result")]
|
||||
pub violations: Vec<Violation>,
|
||||
}
|
||||
|
||||
impl EvaluationResult {
|
||||
#[must_use]
|
||||
pub fn valid(&self) -> bool {
|
||||
self.violations.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
#[serde(tag = "registration_method", rename_all = "snake_case")]
|
||||
#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]
|
||||
pub enum RegisterInput<'a> {
|
||||
Password {
|
||||
username: &'a str,
|
||||
password: &'a str,
|
||||
email: &'a str,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]
|
||||
pub struct ClientRegistrationInput<'a> {
|
||||
#[cfg_attr(
|
||||
feature = "jsonschema",
|
||||
schemars(with = "std::collections::HashMap<String, serde_json::Value>")
|
||||
)]
|
||||
pub client_metadata: &'a VerifiedClientMetadata,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]
|
||||
pub struct AuthorizationGrantInput<'a> {
|
||||
#[cfg_attr(
|
||||
feature = "jsonschema",
|
||||
schemars(with = "std::collections::HashMap<String, serde_json::Value>")
|
||||
)]
|
||||
pub user: &'a User,
|
||||
|
||||
#[cfg_attr(
|
||||
feature = "jsonschema",
|
||||
schemars(with = "std::collections::HashMap<String, serde_json::Value>")
|
||||
)]
|
||||
pub client: &'a Client,
|
||||
|
||||
#[cfg_attr(
|
||||
feature = "jsonschema",
|
||||
schemars(with = "std::collections::HashMap<String, serde_json::Value>")
|
||||
)]
|
||||
pub authorization_grant: &'a AuthorizationGrant,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]
|
||||
pub struct EmailInput<'a> {
|
||||
pub email: &'a str,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]
|
||||
pub struct PasswordInput<'a> {
|
||||
pub password: &'a str,
|
||||
}
|
@@ -6,10 +6,12 @@ export SQLX_OFFLINE=1
|
||||
BASE_DIR="$(dirname "$0")/.."
|
||||
CONFIG_SCHEMA="${BASE_DIR}/docs/config.schema.json"
|
||||
GRAPHQL_SCHEMA="${BASE_DIR}/frontend/schema.graphql"
|
||||
POLICIES_SCHEMA="${BASE_DIR}/policies/schema/"
|
||||
|
||||
set -x
|
||||
cargo run -p mas-config > "${CONFIG_SCHEMA}"
|
||||
cargo run -p mas-graphql > "${GRAPHQL_SCHEMA}"
|
||||
OUT_DIR="${POLICIES_SCHEMA}" cargo run -p mas-policy --features jsonschema
|
||||
|
||||
cd "${BASE_DIR}/frontend"
|
||||
npm run generate
|
||||
|
@@ -1,6 +1,13 @@
|
||||
# Set to 1 to run OPA through Docker
|
||||
DOCKER := 0
|
||||
OPA_DOCKER_IMAGE := docker.io/openpolicyagent/opa:0.55.0
|
||||
OPA_DOCKER_IMAGE := docker.io/openpolicyagent/opa:0.55.0-debug
|
||||
|
||||
INPUTS := \
|
||||
client_registration.rego \
|
||||
register.rego \
|
||||
authorization_grant.rego \
|
||||
password.rego \
|
||||
email.rego
|
||||
|
||||
ifeq ($(DOCKER), 0)
|
||||
OPA := opa
|
||||
@@ -10,11 +17,13 @@ else
|
||||
OPA_RW := docker run -i -v $(shell pwd):/policies -w /policies --rm $(OPA_DOCKER_IMAGE)
|
||||
endif
|
||||
|
||||
policy.wasm: client_registration.rego register.rego authorization_grant.rego
|
||||
policy.wasm: $(INPUTS)
|
||||
$(OPA_RW) build -t wasm \
|
||||
-e "client_registration/violation" \
|
||||
-e "register/violation" \
|
||||
-e "authorization_grant/violation" \
|
||||
-e "password/violation" \
|
||||
-e "email/violation" \
|
||||
$^
|
||||
tar xzf bundle.tar.gz /policy.wasm
|
||||
$(RM) bundle.tar.gz
|
||||
@@ -26,7 +35,7 @@ fmt:
|
||||
|
||||
.PHONY: test
|
||||
test:
|
||||
$(OPA) test -v ./*.rego
|
||||
$(OPA) test --schema ./schema/ -v ./*.rego
|
||||
|
||||
.PHONY: coverage
|
||||
coverage:
|
||||
|
@@ -1,3 +1,6 @@
|
||||
# METADATA
|
||||
# schemas:
|
||||
# - input: schema["authorization_grant_input"]
|
||||
package authorization_grant
|
||||
|
||||
import future.keywords.in
|
||||
|
@@ -1,3 +1,6 @@
|
||||
# METADATA
|
||||
# schemas:
|
||||
# - input: schema["client_registration_input"]
|
||||
package client_registration
|
||||
|
||||
import future.keywords.in
|
||||
|
35
policies/email.rego
Normal file
35
policies/email.rego
Normal file
@@ -0,0 +1,35 @@
|
||||
# METADATA
|
||||
# schemas:
|
||||
# - input: schema["email_input"]
|
||||
package email
|
||||
|
||||
import future.keywords.in
|
||||
|
||||
default allow := false
|
||||
|
||||
allow {
|
||||
count(violation) == 0
|
||||
}
|
||||
|
||||
# Allow any domains if the data.allowed_domains array is not set
|
||||
email_domain_allowed {
|
||||
not data.allowed_domains
|
||||
}
|
||||
|
||||
# Allow an email only if its domain is in the list of allowed domains
|
||||
email_domain_allowed {
|
||||
[_, domain] := split(input.email, "@")
|
||||
some allowed_domain in data.allowed_domains
|
||||
glob.match(allowed_domain, ["."], domain)
|
||||
}
|
||||
|
||||
violation[{"msg": "email domain is not allowed"}] {
|
||||
not email_domain_allowed
|
||||
}
|
||||
|
||||
# Deny emails with their domain in the domains banlist
|
||||
violation[{"msg": "email domain is banned"}] {
|
||||
[_, domain] := split(input.email, "@")
|
||||
some banned_domain in data.banned_domains
|
||||
glob.match(banned_domain, ["."], domain)
|
||||
}
|
30
policies/password.rego
Normal file
30
policies/password.rego
Normal file
@@ -0,0 +1,30 @@
|
||||
# METADATA
|
||||
# schemas:
|
||||
# - input: schema["password_input"]
|
||||
package password
|
||||
|
||||
default allow := false
|
||||
|
||||
allow {
|
||||
count(violation) == 0
|
||||
}
|
||||
|
||||
violation[{"msg": msg}] {
|
||||
count(input.password) < data.passwords.min_length
|
||||
msg := sprintf("needs to be at least %d characters", [data.passwords.min_length])
|
||||
}
|
||||
|
||||
violation[{"msg": "requires at least one number"}] {
|
||||
data.passwords.require_number
|
||||
not regex.match("[0-9]", input.password)
|
||||
}
|
||||
|
||||
violation[{"msg": "requires at least one lowercase letter"}] {
|
||||
data.passwords.require_lowercase
|
||||
not regex.match("[a-z]", input.password)
|
||||
}
|
||||
|
||||
violation[{"msg": "requires at least one uppercase letter"}] {
|
||||
data.passwords.require_uppercase
|
||||
not regex.match("[A-Z]", input.password)
|
||||
}
|
@@ -1,5 +1,11 @@
|
||||
# METADATA
|
||||
# schemas:
|
||||
# - input: schema["register_input"]
|
||||
package register
|
||||
|
||||
import data.email as email_policy
|
||||
import data.password as password_policy
|
||||
|
||||
import future.keywords.in
|
||||
|
||||
default allow := false
|
||||
@@ -9,52 +15,24 @@ allow {
|
||||
}
|
||||
|
||||
violation[{"field": "username", "msg": "username too short"}] {
|
||||
count(input.user.username) <= 2
|
||||
count(input.username) <= 2
|
||||
}
|
||||
|
||||
violation[{"field": "username", "msg": "username too long"}] {
|
||||
count(input.user.username) >= 15
|
||||
count(input.username) >= 15
|
||||
}
|
||||
|
||||
violation[{"field": "password", "msg": msg}] {
|
||||
count(input.user.password) < data.passwords.min_length
|
||||
msg := sprintf("needs to be at least %d characters", [data.passwords.min_length])
|
||||
violation[object.union({"field": "password"}, v)] {
|
||||
# Check if the registration method is password
|
||||
input.registration_method == "password"
|
||||
|
||||
# Get the violation object from the password policy
|
||||
some v in password_policy.violation
|
||||
}
|
||||
|
||||
violation[{"field": "password", "msg": "requires at least one number"}] {
|
||||
data.passwords.require_number
|
||||
not regex.match("[0-9]", input.user.password)
|
||||
}
|
||||
|
||||
violation[{"field": "password", "msg": "requires at least one lowercase letter"}] {
|
||||
data.passwords.require_lowercase
|
||||
not regex.match("[a-z]", input.user.password)
|
||||
}
|
||||
|
||||
violation[{"field": "password", "msg": "requires at least one uppercase letter"}] {
|
||||
data.passwords.require_uppercase
|
||||
not regex.match("[A-Z]", input.user.password)
|
||||
}
|
||||
|
||||
# Allow any domains if the data.allowed_domains array is not set
|
||||
email_domain_allowed {
|
||||
not data.allowed_domains
|
||||
}
|
||||
|
||||
# Allow an email only if its domain is in the list of allowed domains
|
||||
email_domain_allowed {
|
||||
[_, domain] := split(input.user.email, "@")
|
||||
some allowed_domain in data.allowed_domains
|
||||
glob.match(allowed_domain, ["."], domain)
|
||||
}
|
||||
|
||||
violation[{"field": "email", "msg": "email domain not allowed"}] {
|
||||
not email_domain_allowed
|
||||
}
|
||||
|
||||
# Deny emails with their domain in the domains banlist
|
||||
violation[{"field": "email", "msg": "email domain not allowed"}] {
|
||||
[_, domain] := split(input.user.email, "@")
|
||||
some banned_domain in data.banned_domains
|
||||
glob.match(banned_domain, ["."], domain)
|
||||
# Check if the email is valid using the email policy
|
||||
# and add the email field to the violation object
|
||||
violation[object.union({"field": "email"}, v)] {
|
||||
# Get the violation object from the email policy
|
||||
some v in email_policy.violation
|
||||
}
|
||||
|
@@ -1,72 +1,85 @@
|
||||
package register
|
||||
|
||||
mock_user := {"username": "hello", "password": "Hunter2", "email": "hello@staging.element.io"}
|
||||
mock_registration := {
|
||||
"registration_method": "password",
|
||||
"username": "hello",
|
||||
"password": "Hunter2",
|
||||
"email": "hello@staging.element.io",
|
||||
}
|
||||
|
||||
test_allow_all_domains {
|
||||
allow with input.user as mock_user
|
||||
allow with input as mock_registration
|
||||
}
|
||||
|
||||
test_allowed_domain {
|
||||
allow with input.user as mock_user
|
||||
allow with input as mock_registration
|
||||
with data.allowed_domains as ["*.element.io"]
|
||||
}
|
||||
|
||||
test_not_allowed_domain {
|
||||
not allow with input.user as mock_user
|
||||
not allow with input as mock_registration
|
||||
with data.allowed_domains as ["example.com"]
|
||||
}
|
||||
|
||||
test_banned_domain {
|
||||
not allow with input.user as mock_user
|
||||
not allow with input as mock_registration
|
||||
with data.banned_domains as ["*.element.io"]
|
||||
}
|
||||
|
||||
test_banned_subdomain {
|
||||
not allow with input.user as mock_user
|
||||
not allow with input as mock_registration
|
||||
with data.allowed_domains as ["*.element.io"]
|
||||
with data.banned_domains as ["staging.element.io"]
|
||||
}
|
||||
|
||||
test_short_username {
|
||||
not allow with input.user as {"username": "a", "email": "hello@element.io"}
|
||||
not allow with input as {"username": "a", "email": "hello@element.io"}
|
||||
}
|
||||
|
||||
test_long_username {
|
||||
not allow with input.user as {"username": "aaaaaaaaaaaaaaaaaaaaaaaaaaaaa", "email": "hello@element.io"}
|
||||
not allow with input as {"username": "aaaaaaaaaaaaaaaaaaaaaaaaaaaaa", "email": "hello@element.io"}
|
||||
}
|
||||
|
||||
test_password_require_number {
|
||||
allow with input.user as mock_user
|
||||
allow with input as mock_registration
|
||||
with input.registration_method as "password"
|
||||
with data.passwords.require_number as true
|
||||
|
||||
not allow with input.user as mock_user
|
||||
with input.user.password as "hunter"
|
||||
not allow with input as mock_registration
|
||||
with input.registration_method as "password"
|
||||
with input.password as "hunter"
|
||||
with data.passwords.require_number as true
|
||||
}
|
||||
|
||||
test_password_require_lowercase {
|
||||
allow with input.user as mock_user
|
||||
allow with input as mock_registration
|
||||
with input.registration_method as "password"
|
||||
with data.passwords.require_lowercase as true
|
||||
|
||||
not allow with input.user as mock_user
|
||||
with input.user.password as "HUNTER2"
|
||||
not allow with input as mock_registration
|
||||
with input.registration_method as "password"
|
||||
with input.password as "HUNTER2"
|
||||
with data.passwords.require_lowercase as true
|
||||
}
|
||||
|
||||
test_password_require_uppercase {
|
||||
allow with input.user as mock_user
|
||||
allow with input as mock_registration
|
||||
with input.registration_method as "password"
|
||||
with data.passwords.require_uppercase as true
|
||||
|
||||
not allow with input.user as mock_user
|
||||
with input.user.password as "hunter2"
|
||||
not allow with input as mock_registration
|
||||
with input.registration_method as "password"
|
||||
with input.password as "hunter2"
|
||||
with data.passwords.require_uppercase as true
|
||||
}
|
||||
|
||||
test_password_min_length {
|
||||
allow with input.user as mock_user
|
||||
allow with input as mock_registration
|
||||
with input.registration_method as "password"
|
||||
with data.passwords.min_length as 6
|
||||
|
||||
not allow with input.user as mock_user
|
||||
with input.user.password as "short"
|
||||
not allow with input as mock_registration
|
||||
with input.registration_method as "password"
|
||||
with input.password as "short"
|
||||
with data.passwords.min_length as 6
|
||||
}
|
||||
|
24
policies/schema/authorization_grant_input.json
Normal file
24
policies/schema/authorization_grant_input.json
Normal file
@@ -0,0 +1,24 @@
|
||||
{
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"title": "AuthorizationGrantInput",
|
||||
"type": "object",
|
||||
"required": [
|
||||
"authorization_grant",
|
||||
"client",
|
||||
"user"
|
||||
],
|
||||
"properties": {
|
||||
"authorization_grant": {
|
||||
"type": "object",
|
||||
"additionalProperties": true
|
||||
},
|
||||
"client": {
|
||||
"type": "object",
|
||||
"additionalProperties": true
|
||||
},
|
||||
"user": {
|
||||
"type": "object",
|
||||
"additionalProperties": true
|
||||
}
|
||||
}
|
||||
}
|
14
policies/schema/client_registration_input.json
Normal file
14
policies/schema/client_registration_input.json
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"title": "ClientRegistrationInput",
|
||||
"type": "object",
|
||||
"required": [
|
||||
"client_metadata"
|
||||
],
|
||||
"properties": {
|
||||
"client_metadata": {
|
||||
"type": "object",
|
||||
"additionalProperties": true
|
||||
}
|
||||
}
|
||||
}
|
13
policies/schema/email_input.json
Normal file
13
policies/schema/email_input.json
Normal file
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"title": "EmailInput",
|
||||
"type": "object",
|
||||
"required": [
|
||||
"email"
|
||||
],
|
||||
"properties": {
|
||||
"email": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
13
policies/schema/password_input.json
Normal file
13
policies/schema/password_input.json
Normal file
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"title": "PasswordInput",
|
||||
"type": "object",
|
||||
"required": [
|
||||
"password"
|
||||
],
|
||||
"properties": {
|
||||
"password": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
32
policies/schema/register_input.json
Normal file
32
policies/schema/register_input.json
Normal file
@@ -0,0 +1,32 @@
|
||||
{
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"title": "RegisterInput",
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "object",
|
||||
"required": [
|
||||
"email",
|
||||
"password",
|
||||
"registration_method",
|
||||
"username"
|
||||
],
|
||||
"properties": {
|
||||
"email": {
|
||||
"type": "string"
|
||||
},
|
||||
"password": {
|
||||
"type": "string"
|
||||
},
|
||||
"registration_method": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"password"
|
||||
]
|
||||
},
|
||||
"username": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
Reference in New Issue
Block a user