You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-31 09:24:31 +03:00
Support for applying OPA policies during client registration
This commit is contained in:
@ -36,6 +36,7 @@ mas-config = { path = "../config" }
|
||||
mas-email = { path = "../email" }
|
||||
mas-handlers = { path = "../handlers" }
|
||||
mas-http = { path = "../http" }
|
||||
mas-policy = { path = "../policy" }
|
||||
mas-router = { path = "../router" }
|
||||
mas-static-files = { path = "../static-files" }
|
||||
mas-storage = { path = "../storage" }
|
||||
|
@ -25,6 +25,7 @@ use hyper::Server;
|
||||
use mas_config::RootConfig;
|
||||
use mas_email::{MailTransport, Mailer};
|
||||
use mas_http::ServerLayer;
|
||||
use mas_policy::PolicyFactory;
|
||||
use mas_router::UrlBuilder;
|
||||
use mas_storage::MIGRATOR;
|
||||
use mas_tasks::TaskQueue;
|
||||
@ -176,6 +177,20 @@ impl Options {
|
||||
|
||||
let encrypter = config.secrets.encrypter();
|
||||
|
||||
// Load and compile the WASM policies
|
||||
let mut policy = tokio::fs::File::open(&config.policy.wasm_module)
|
||||
.await
|
||||
.context("failed to open OPA WASM policy file")?;
|
||||
let policy_factory = PolicyFactory::load(
|
||||
&mut policy,
|
||||
config.policy.data.clone().unwrap_or_default(),
|
||||
config.policy.login_entrypoint.clone(),
|
||||
config.policy.register_entrypoint.clone(),
|
||||
config.policy.client_registration_entrypoint.clone(),
|
||||
)
|
||||
.await?;
|
||||
let policy_factory = Arc::new(policy_factory);
|
||||
|
||||
// Load and compile the templates
|
||||
let templates = Templates::load_from_config(&config.templates)
|
||||
.await
|
||||
@ -217,6 +232,7 @@ impl Options {
|
||||
&mailer,
|
||||
&url_builder,
|
||||
&matrix_config,
|
||||
&policy_factory,
|
||||
)
|
||||
.fallback(static_files)
|
||||
.layer(ServerLayer::default());
|
||||
|
@ -22,6 +22,7 @@ mod database;
|
||||
mod email;
|
||||
mod http;
|
||||
mod matrix;
|
||||
mod policy;
|
||||
mod secrets;
|
||||
mod telemetry;
|
||||
mod templates;
|
||||
@ -33,6 +34,7 @@ pub use self::{
|
||||
email::{EmailConfig, EmailSmtpMode, EmailTransportConfig},
|
||||
http::HttpConfig,
|
||||
matrix::MatrixConfig,
|
||||
policy::PolicyConfig,
|
||||
secrets::{Encrypter, SecretsConfig},
|
||||
telemetry::{
|
||||
MetricsConfig, MetricsExporterConfig, Propagator, TelemetryConfig, TracingConfig,
|
||||
@ -79,6 +81,10 @@ pub struct RootConfig {
|
||||
/// Configuration related to the homeserver
|
||||
#[serde(default)]
|
||||
pub matrix: MatrixConfig,
|
||||
|
||||
/// Configuration related to the OPA policies
|
||||
#[serde(default)]
|
||||
pub policy: PolicyConfig,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@ -98,6 +104,7 @@ impl ConfigurationSection<'_> for RootConfig {
|
||||
email: EmailConfig::generate().await?,
|
||||
secrets: SecretsConfig::generate().await?,
|
||||
matrix: MatrixConfig::generate().await?,
|
||||
policy: PolicyConfig::generate().await?,
|
||||
})
|
||||
}
|
||||
|
||||
@ -112,6 +119,7 @@ impl ConfigurationSection<'_> for RootConfig {
|
||||
email: EmailConfig::test(),
|
||||
secrets: SecretsConfig::test(),
|
||||
matrix: MatrixConfig::test(),
|
||||
policy: PolicyConfig::test(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
90
crates/config/src/sections/policy.rs
Normal file
90
crates/config/src/sections/policy.rs
Normal file
@ -0,0 +1,90 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::path::PathBuf;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::serde_as;
|
||||
|
||||
use super::ConfigurationSection;
|
||||
|
||||
fn default_wasm_module() -> PathBuf {
|
||||
"./policies/policy.wasm".into()
|
||||
}
|
||||
|
||||
fn default_client_registration_endpoint() -> String {
|
||||
"client_registration/allow".to_string()
|
||||
}
|
||||
|
||||
fn default_login_endpoint() -> String {
|
||||
"login/allow".to_string()
|
||||
}
|
||||
|
||||
fn default_register_endpoint() -> String {
|
||||
"register/allow".to_string()
|
||||
}
|
||||
|
||||
/// Application secrets
|
||||
#[serde_as]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct PolicyConfig {
|
||||
/// Path to the WASM module
|
||||
#[serde(default = "default_wasm_module")]
|
||||
pub wasm_module: PathBuf,
|
||||
|
||||
/// Entrypoint to use when evaluating client registrations
|
||||
#[serde(default = "default_client_registration_endpoint")]
|
||||
pub client_registration_entrypoint: String,
|
||||
|
||||
/// Entrypoint to use when evaluating user logins
|
||||
#[serde(default = "default_login_endpoint")]
|
||||
pub login_entrypoint: String,
|
||||
|
||||
/// Entrypoint to use when evaluating user registrations
|
||||
#[serde(default = "default_register_endpoint")]
|
||||
pub register_entrypoint: String,
|
||||
|
||||
/// Arbitrary data to pass to the policy
|
||||
#[serde(default)]
|
||||
pub data: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
impl Default for PolicyConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
wasm_module: default_wasm_module(),
|
||||
client_registration_entrypoint: default_client_registration_endpoint(),
|
||||
login_entrypoint: default_login_endpoint(),
|
||||
register_entrypoint: default_register_endpoint(),
|
||||
data: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ConfigurationSection<'_> for PolicyConfig {
|
||||
fn path() -> &'static str {
|
||||
"policy"
|
||||
}
|
||||
|
||||
async fn generate() -> anyhow::Result<Self> {
|
||||
Ok(Self::default())
|
||||
}
|
||||
|
||||
fn test() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
}
|
@ -62,6 +62,7 @@ mas-email = { path = "../email" }
|
||||
mas-http = { path = "../http" }
|
||||
mas-iana = { path = "../iana" }
|
||||
mas-jose = { path = "../jose" }
|
||||
mas-policy = { path = "../policy" }
|
||||
mas-storage = { path = "../storage" }
|
||||
mas-templates = { path = "../templates" }
|
||||
mas-router = { path = "../router" }
|
||||
|
@ -34,6 +34,7 @@ use mas_config::{Encrypter, MatrixConfig};
|
||||
use mas_email::Mailer;
|
||||
use mas_http::CorsLayerExt;
|
||||
use mas_jose::StaticKeystore;
|
||||
use mas_policy::PolicyFactory;
|
||||
use mas_router::{Route, UrlBuilder};
|
||||
use mas_templates::{ErrorContext, Templates};
|
||||
use sqlx::PgPool;
|
||||
@ -46,7 +47,11 @@ mod oauth2;
|
||||
mod views;
|
||||
|
||||
#[must_use]
|
||||
#[allow(clippy::too_many_lines, clippy::missing_panics_doc)]
|
||||
#[allow(
|
||||
clippy::too_many_lines,
|
||||
clippy::missing_panics_doc,
|
||||
clippy::too_many_arguments
|
||||
)]
|
||||
pub fn router<B>(
|
||||
pool: &PgPool,
|
||||
templates: &Templates,
|
||||
@ -55,6 +60,7 @@ pub fn router<B>(
|
||||
mailer: &Mailer,
|
||||
url_builder: &UrlBuilder,
|
||||
matrix_config: &MatrixConfig,
|
||||
policy_factory: &Arc<PolicyFactory>,
|
||||
) -> Router<B>
|
||||
where
|
||||
B: HttpBody + Send + 'static,
|
||||
@ -233,4 +239,5 @@ where
|
||||
.layer(Extension(url_builder.clone()))
|
||||
.layer(Extension(mailer.clone()))
|
||||
.layer(Extension(matrix_config.clone()))
|
||||
.layer(Extension(policy_factory.clone()))
|
||||
}
|
||||
|
@ -12,9 +12,12 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::{response::IntoResponse, Extension, Json};
|
||||
use hyper::StatusCode;
|
||||
use mas_iana::oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod};
|
||||
use mas_policy::PolicyFactory;
|
||||
use mas_storage::oauth2::client::insert_client;
|
||||
use oauth2_types::{
|
||||
errors::{INVALID_CLIENT_METADATA, INVALID_REDIRECT_URI, SERVER_ERROR},
|
||||
@ -31,11 +34,17 @@ pub(crate) enum RouteError {
|
||||
#[error(transparent)]
|
||||
Internal(Box<dyn std::error::Error + Send + Sync>),
|
||||
|
||||
#[error(transparent)]
|
||||
Anyhow(#[from] anyhow::Error),
|
||||
|
||||
#[error("invalid redirect uri")]
|
||||
InvalidRedirectUri,
|
||||
|
||||
#[error("invalid client metadata")]
|
||||
InvalidClientMetadata,
|
||||
|
||||
#[error("denied by the policy")]
|
||||
PolicyDenied,
|
||||
}
|
||||
|
||||
impl From<sqlx::Error> for RouteError {
|
||||
@ -47,9 +56,12 @@ impl From<sqlx::Error> for RouteError {
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
match self {
|
||||
Self::Internal(_) => (StatusCode::INTERNAL_SERVER_ERROR, Json(SERVER_ERROR)),
|
||||
Self::Internal(_) | Self::Anyhow(_) => {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, Json(SERVER_ERROR))
|
||||
}
|
||||
Self::InvalidRedirectUri => (StatusCode::BAD_REQUEST, Json(INVALID_REDIRECT_URI)),
|
||||
Self::InvalidClientMetadata => (StatusCode::BAD_REQUEST, Json(INVALID_CLIENT_METADATA)),
|
||||
Self::PolicyDenied => (StatusCode::UNAUTHORIZED, Json(INVALID_CLIENT_METADATA)),
|
||||
}
|
||||
.into_response()
|
||||
}
|
||||
@ -58,6 +70,7 @@ impl IntoResponse for RouteError {
|
||||
#[tracing::instrument(skip_all, err)]
|
||||
pub(crate) async fn post(
|
||||
Extension(pool): Extension<PgPool>,
|
||||
Extension(policy_factory): Extension<Arc<PolicyFactory>>,
|
||||
Json(body): Json<ClientMetadata>,
|
||||
) -> Result<impl IntoResponse, RouteError> {
|
||||
info!(?body, "Client registration");
|
||||
@ -105,6 +118,12 @@ pub(crate) async fn post(
|
||||
return Err(RouteError::InvalidClientMetadata);
|
||||
}
|
||||
|
||||
let mut policy = policy_factory.instanciate().await?;
|
||||
let allowed = policy.evaluate_client_registration(&body).await?;
|
||||
if !allowed {
|
||||
return Err(RouteError::PolicyDenied);
|
||||
}
|
||||
|
||||
// Grab a txn
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
|
19
crates/policy/Cargo.toml
Normal file
19
crates/policy/Cargo.toml
Normal file
@ -0,0 +1,19 @@
|
||||
[package]
|
||||
name = "mas-policy"
|
||||
version = "0.1.0"
|
||||
authors = ["Quentin Gliech <quenting@element.io>"]
|
||||
edition = "2021"
|
||||
license = "Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0.31"
|
||||
opa-wasm = { git = "https://github.com/matrix-org/rust-opa-wasm.git" }
|
||||
serde = { version = "1.0.31", features = ["derive"] }
|
||||
serde_json = "1.0.31"
|
||||
thiserror = "1.0.31"
|
||||
tokio = { version = "1.18.2", features = ["io-util", "rt"] }
|
||||
tracing = "0.1.34"
|
||||
wasmtime = "0.37.0"
|
||||
|
||||
mas-data-model = { path = "../data-model" }
|
||||
oauth2-types = { path = "../oauth2-types" }
|
175
crates/policy/src/lib.rs
Normal file
175
crates/policy/src/lib.rs
Normal file
@ -0,0 +1,175 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use anyhow::bail;
|
||||
use oauth2_types::registration::ClientMetadata;
|
||||
use opa_wasm::Runtime;
|
||||
use serde::Deserialize;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt};
|
||||
use wasmtime::{Config, Engine, Module, Store};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum LoadError {
|
||||
#[error("failed to read module")]
|
||||
Read(#[from] tokio::io::Error),
|
||||
|
||||
#[error("failed to create WASM engine")]
|
||||
Engine(#[source] anyhow::Error),
|
||||
|
||||
#[error("module compilation task crashed")]
|
||||
CompilationTask(#[from] tokio::task::JoinError),
|
||||
|
||||
#[error("failed to compile WASM module")]
|
||||
Compilation(#[source] anyhow::Error),
|
||||
}
|
||||
|
||||
pub struct PolicyFactory {
|
||||
engine: Engine,
|
||||
module: Module,
|
||||
data: serde_json::Value,
|
||||
login_entrypoint: String,
|
||||
register_entrypoint: String,
|
||||
client_registration_entrypoint: String,
|
||||
}
|
||||
|
||||
impl PolicyFactory {
|
||||
pub async fn load(
|
||||
mut source: impl AsyncRead + std::marker::Unpin,
|
||||
data: serde_json::Value,
|
||||
login_entrypoint: String,
|
||||
register_entrypoint: String,
|
||||
client_registration_entrypoint: String,
|
||||
) -> Result<Self, LoadError> {
|
||||
let mut config = Config::default();
|
||||
config.async_support(true);
|
||||
config.cranelift_opt_level(wasmtime::OptLevel::Speed);
|
||||
let engine = Engine::new(&config).map_err(LoadError::Engine)?;
|
||||
let mut buf = Vec::new();
|
||||
source.read_to_end(&mut buf).await?;
|
||||
let (engine, module) = tokio::task::spawn_blocking(move || {
|
||||
let module = Module::new(&engine, buf);
|
||||
(engine, module)
|
||||
})
|
||||
.await?;
|
||||
let module = module.map_err(LoadError::Compilation)?;
|
||||
|
||||
Ok(Self {
|
||||
engine,
|
||||
module,
|
||||
data,
|
||||
login_entrypoint,
|
||||
register_entrypoint,
|
||||
client_registration_entrypoint,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn instanciate(&self) -> Result<Policy, anyhow::Error> {
|
||||
let mut store = Store::new(&self.engine, ());
|
||||
let runtime = Runtime::new(&mut store, &self.module).await?;
|
||||
|
||||
// Check that we have the required entrypoints
|
||||
let entrypoints = runtime.entrypoints();
|
||||
|
||||
for e in [
|
||||
self.login_entrypoint.as_str(),
|
||||
self.register_entrypoint.as_str(),
|
||||
] {
|
||||
if !entrypoints.contains(e) {
|
||||
bail!("missing entrypoint {e}")
|
||||
}
|
||||
}
|
||||
|
||||
let instance = runtime.with_data(&mut store, &self.data).await?;
|
||||
|
||||
Ok(Policy {
|
||||
store,
|
||||
instance,
|
||||
login_entrypoint: self.login_entrypoint.clone(),
|
||||
register_entrypoint: self.register_entrypoint.clone(),
|
||||
client_registration_entrypoint: self.client_registration_entrypoint.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct EvaluationResult {
|
||||
result: bool,
|
||||
}
|
||||
|
||||
pub struct Policy {
|
||||
store: Store<()>,
|
||||
instance: opa_wasm::Policy,
|
||||
login_entrypoint: String,
|
||||
register_entrypoint: String,
|
||||
client_registration_entrypoint: String,
|
||||
}
|
||||
|
||||
impl Policy {
|
||||
pub async fn evaluate_login(
|
||||
&mut self,
|
||||
user: &mas_data_model::User<()>,
|
||||
) -> Result<bool, anyhow::Error> {
|
||||
let user = serde_json::to_value(user)?;
|
||||
let input = serde_json::json!({ "user": user });
|
||||
|
||||
let [res]: [EvaluationResult; 1] = self
|
||||
.instance
|
||||
.evaluate(&mut self.store, &self.login_entrypoint, &input)
|
||||
.await?;
|
||||
|
||||
Ok(res.result)
|
||||
}
|
||||
|
||||
pub async fn evaluate_register(
|
||||
&mut self,
|
||||
username: &str,
|
||||
email: &str,
|
||||
) -> Result<bool, anyhow::Error> {
|
||||
let input = serde_json::json!({
|
||||
"user": {
|
||||
"username": username,
|
||||
"email": email
|
||||
}
|
||||
});
|
||||
|
||||
let [res]: [EvaluationResult; 1] = self
|
||||
.instance
|
||||
.evaluate(&mut self.store, &self.register_entrypoint, &input)
|
||||
.await?;
|
||||
|
||||
Ok(res.result)
|
||||
}
|
||||
|
||||
pub async fn evaluate_client_registration(
|
||||
&mut self,
|
||||
client_metadata: &ClientMetadata,
|
||||
) -> Result<bool, anyhow::Error> {
|
||||
let client_metadata = serde_json::to_value(client_metadata)?;
|
||||
let input = serde_json::json!({
|
||||
"client_metadata": client_metadata,
|
||||
});
|
||||
|
||||
let [res]: [EvaluationResult; 1] = self
|
||||
.instance
|
||||
.evaluate(
|
||||
&mut self.store,
|
||||
&self.client_registration_entrypoint,
|
||||
&input,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(res.result)
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user