1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Support for applying OPA policies during client registration

This commit is contained in:
Quentin Gliech
2022-06-02 16:30:34 +02:00
parent 959466a5ba
commit aab1f49374
16 changed files with 1153 additions and 28 deletions

View File

@ -36,6 +36,7 @@ mas-config = { path = "../config" }
mas-email = { path = "../email" }
mas-handlers = { path = "../handlers" }
mas-http = { path = "../http" }
mas-policy = { path = "../policy" }
mas-router = { path = "../router" }
mas-static-files = { path = "../static-files" }
mas-storage = { path = "../storage" }

View File

@ -25,6 +25,7 @@ use hyper::Server;
use mas_config::RootConfig;
use mas_email::{MailTransport, Mailer};
use mas_http::ServerLayer;
use mas_policy::PolicyFactory;
use mas_router::UrlBuilder;
use mas_storage::MIGRATOR;
use mas_tasks::TaskQueue;
@ -176,6 +177,20 @@ impl Options {
let encrypter = config.secrets.encrypter();
// Load and compile the WASM policies
let mut policy = tokio::fs::File::open(&config.policy.wasm_module)
.await
.context("failed to open OPA WASM policy file")?;
let policy_factory = PolicyFactory::load(
&mut policy,
config.policy.data.clone().unwrap_or_default(),
config.policy.login_entrypoint.clone(),
config.policy.register_entrypoint.clone(),
config.policy.client_registration_entrypoint.clone(),
)
.await?;
let policy_factory = Arc::new(policy_factory);
// Load and compile the templates
let templates = Templates::load_from_config(&config.templates)
.await
@ -217,6 +232,7 @@ impl Options {
&mailer,
&url_builder,
&matrix_config,
&policy_factory,
)
.fallback(static_files)
.layer(ServerLayer::default());

View File

@ -22,6 +22,7 @@ mod database;
mod email;
mod http;
mod matrix;
mod policy;
mod secrets;
mod telemetry;
mod templates;
@ -33,6 +34,7 @@ pub use self::{
email::{EmailConfig, EmailSmtpMode, EmailTransportConfig},
http::HttpConfig,
matrix::MatrixConfig,
policy::PolicyConfig,
secrets::{Encrypter, SecretsConfig},
telemetry::{
MetricsConfig, MetricsExporterConfig, Propagator, TelemetryConfig, TracingConfig,
@ -79,6 +81,10 @@ pub struct RootConfig {
/// Configuration related to the homeserver
#[serde(default)]
pub matrix: MatrixConfig,
/// Configuration related to the OPA policies
#[serde(default)]
pub policy: PolicyConfig,
}
#[async_trait]
@ -98,6 +104,7 @@ impl ConfigurationSection<'_> for RootConfig {
email: EmailConfig::generate().await?,
secrets: SecretsConfig::generate().await?,
matrix: MatrixConfig::generate().await?,
policy: PolicyConfig::generate().await?,
})
}
@ -112,6 +119,7 @@ impl ConfigurationSection<'_> for RootConfig {
email: EmailConfig::test(),
secrets: SecretsConfig::test(),
matrix: MatrixConfig::test(),
policy: PolicyConfig::test(),
}
}
}

View File

@ -0,0 +1,90 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::path::PathBuf;
use async_trait::async_trait;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_with::serde_as;
use super::ConfigurationSection;
fn default_wasm_module() -> PathBuf {
"./policies/policy.wasm".into()
}
fn default_client_registration_endpoint() -> String {
"client_registration/allow".to_string()
}
fn default_login_endpoint() -> String {
"login/allow".to_string()
}
fn default_register_endpoint() -> String {
"register/allow".to_string()
}
/// Application secrets
#[serde_as]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct PolicyConfig {
/// Path to the WASM module
#[serde(default = "default_wasm_module")]
pub wasm_module: PathBuf,
/// Entrypoint to use when evaluating client registrations
#[serde(default = "default_client_registration_endpoint")]
pub client_registration_entrypoint: String,
/// Entrypoint to use when evaluating user logins
#[serde(default = "default_login_endpoint")]
pub login_entrypoint: String,
/// Entrypoint to use when evaluating user registrations
#[serde(default = "default_register_endpoint")]
pub register_entrypoint: String,
/// Arbitrary data to pass to the policy
#[serde(default)]
pub data: Option<serde_json::Value>,
}
impl Default for PolicyConfig {
fn default() -> Self {
Self {
wasm_module: default_wasm_module(),
client_registration_entrypoint: default_client_registration_endpoint(),
login_entrypoint: default_login_endpoint(),
register_entrypoint: default_register_endpoint(),
data: None,
}
}
}
#[async_trait]
impl ConfigurationSection<'_> for PolicyConfig {
fn path() -> &'static str {
"policy"
}
async fn generate() -> anyhow::Result<Self> {
Ok(Self::default())
}
fn test() -> Self {
Self::default()
}
}

View File

@ -62,6 +62,7 @@ mas-email = { path = "../email" }
mas-http = { path = "../http" }
mas-iana = { path = "../iana" }
mas-jose = { path = "../jose" }
mas-policy = { path = "../policy" }
mas-storage = { path = "../storage" }
mas-templates = { path = "../templates" }
mas-router = { path = "../router" }

View File

@ -34,6 +34,7 @@ use mas_config::{Encrypter, MatrixConfig};
use mas_email::Mailer;
use mas_http::CorsLayerExt;
use mas_jose::StaticKeystore;
use mas_policy::PolicyFactory;
use mas_router::{Route, UrlBuilder};
use mas_templates::{ErrorContext, Templates};
use sqlx::PgPool;
@ -46,7 +47,11 @@ mod oauth2;
mod views;
#[must_use]
#[allow(clippy::too_many_lines, clippy::missing_panics_doc)]
#[allow(
clippy::too_many_lines,
clippy::missing_panics_doc,
clippy::too_many_arguments
)]
pub fn router<B>(
pool: &PgPool,
templates: &Templates,
@ -55,6 +60,7 @@ pub fn router<B>(
mailer: &Mailer,
url_builder: &UrlBuilder,
matrix_config: &MatrixConfig,
policy_factory: &Arc<PolicyFactory>,
) -> Router<B>
where
B: HttpBody + Send + 'static,
@ -233,4 +239,5 @@ where
.layer(Extension(url_builder.clone()))
.layer(Extension(mailer.clone()))
.layer(Extension(matrix_config.clone()))
.layer(Extension(policy_factory.clone()))
}

View File

@ -12,9 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use axum::{response::IntoResponse, Extension, Json};
use hyper::StatusCode;
use mas_iana::oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod};
use mas_policy::PolicyFactory;
use mas_storage::oauth2::client::insert_client;
use oauth2_types::{
errors::{INVALID_CLIENT_METADATA, INVALID_REDIRECT_URI, SERVER_ERROR},
@ -31,11 +34,17 @@ pub(crate) enum RouteError {
#[error(transparent)]
Internal(Box<dyn std::error::Error + Send + Sync>),
#[error(transparent)]
Anyhow(#[from] anyhow::Error),
#[error("invalid redirect uri")]
InvalidRedirectUri,
#[error("invalid client metadata")]
InvalidClientMetadata,
#[error("denied by the policy")]
PolicyDenied,
}
impl From<sqlx::Error> for RouteError {
@ -47,9 +56,12 @@ impl From<sqlx::Error> for RouteError {
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
match self {
Self::Internal(_) => (StatusCode::INTERNAL_SERVER_ERROR, Json(SERVER_ERROR)),
Self::Internal(_) | Self::Anyhow(_) => {
(StatusCode::INTERNAL_SERVER_ERROR, Json(SERVER_ERROR))
}
Self::InvalidRedirectUri => (StatusCode::BAD_REQUEST, Json(INVALID_REDIRECT_URI)),
Self::InvalidClientMetadata => (StatusCode::BAD_REQUEST, Json(INVALID_CLIENT_METADATA)),
Self::PolicyDenied => (StatusCode::UNAUTHORIZED, Json(INVALID_CLIENT_METADATA)),
}
.into_response()
}
@ -58,6 +70,7 @@ impl IntoResponse for RouteError {
#[tracing::instrument(skip_all, err)]
pub(crate) async fn post(
Extension(pool): Extension<PgPool>,
Extension(policy_factory): Extension<Arc<PolicyFactory>>,
Json(body): Json<ClientMetadata>,
) -> Result<impl IntoResponse, RouteError> {
info!(?body, "Client registration");
@ -105,6 +118,12 @@ pub(crate) async fn post(
return Err(RouteError::InvalidClientMetadata);
}
let mut policy = policy_factory.instanciate().await?;
let allowed = policy.evaluate_client_registration(&body).await?;
if !allowed {
return Err(RouteError::PolicyDenied);
}
// Grab a txn
let mut txn = pool.begin().await?;

19
crates/policy/Cargo.toml Normal file
View File

@ -0,0 +1,19 @@
[package]
name = "mas-policy"
version = "0.1.0"
authors = ["Quentin Gliech <quenting@element.io>"]
edition = "2021"
license = "Apache-2.0"
[dependencies]
anyhow = "1.0.31"
opa-wasm = { git = "https://github.com/matrix-org/rust-opa-wasm.git" }
serde = { version = "1.0.31", features = ["derive"] }
serde_json = "1.0.31"
thiserror = "1.0.31"
tokio = { version = "1.18.2", features = ["io-util", "rt"] }
tracing = "0.1.34"
wasmtime = "0.37.0"
mas-data-model = { path = "../data-model" }
oauth2-types = { path = "../oauth2-types" }

175
crates/policy/src/lib.rs Normal file
View File

@ -0,0 +1,175 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::bail;
use oauth2_types::registration::ClientMetadata;
use opa_wasm::Runtime;
use serde::Deserialize;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt};
use wasmtime::{Config, Engine, Module, Store};
#[derive(Debug, Error)]
pub enum LoadError {
#[error("failed to read module")]
Read(#[from] tokio::io::Error),
#[error("failed to create WASM engine")]
Engine(#[source] anyhow::Error),
#[error("module compilation task crashed")]
CompilationTask(#[from] tokio::task::JoinError),
#[error("failed to compile WASM module")]
Compilation(#[source] anyhow::Error),
}
pub struct PolicyFactory {
engine: Engine,
module: Module,
data: serde_json::Value,
login_entrypoint: String,
register_entrypoint: String,
client_registration_entrypoint: String,
}
impl PolicyFactory {
pub async fn load(
mut source: impl AsyncRead + std::marker::Unpin,
data: serde_json::Value,
login_entrypoint: String,
register_entrypoint: String,
client_registration_entrypoint: String,
) -> Result<Self, LoadError> {
let mut config = Config::default();
config.async_support(true);
config.cranelift_opt_level(wasmtime::OptLevel::Speed);
let engine = Engine::new(&config).map_err(LoadError::Engine)?;
let mut buf = Vec::new();
source.read_to_end(&mut buf).await?;
let (engine, module) = tokio::task::spawn_blocking(move || {
let module = Module::new(&engine, buf);
(engine, module)
})
.await?;
let module = module.map_err(LoadError::Compilation)?;
Ok(Self {
engine,
module,
data,
login_entrypoint,
register_entrypoint,
client_registration_entrypoint,
})
}
pub async fn instanciate(&self) -> Result<Policy, anyhow::Error> {
let mut store = Store::new(&self.engine, ());
let runtime = Runtime::new(&mut store, &self.module).await?;
// Check that we have the required entrypoints
let entrypoints = runtime.entrypoints();
for e in [
self.login_entrypoint.as_str(),
self.register_entrypoint.as_str(),
] {
if !entrypoints.contains(e) {
bail!("missing entrypoint {e}")
}
}
let instance = runtime.with_data(&mut store, &self.data).await?;
Ok(Policy {
store,
instance,
login_entrypoint: self.login_entrypoint.clone(),
register_entrypoint: self.register_entrypoint.clone(),
client_registration_entrypoint: self.client_registration_entrypoint.clone(),
})
}
}
#[derive(Deserialize)]
struct EvaluationResult {
result: bool,
}
pub struct Policy {
store: Store<()>,
instance: opa_wasm::Policy,
login_entrypoint: String,
register_entrypoint: String,
client_registration_entrypoint: String,
}
impl Policy {
pub async fn evaluate_login(
&mut self,
user: &mas_data_model::User<()>,
) -> Result<bool, anyhow::Error> {
let user = serde_json::to_value(user)?;
let input = serde_json::json!({ "user": user });
let [res]: [EvaluationResult; 1] = self
.instance
.evaluate(&mut self.store, &self.login_entrypoint, &input)
.await?;
Ok(res.result)
}
pub async fn evaluate_register(
&mut self,
username: &str,
email: &str,
) -> Result<bool, anyhow::Error> {
let input = serde_json::json!({
"user": {
"username": username,
"email": email
}
});
let [res]: [EvaluationResult; 1] = self
.instance
.evaluate(&mut self.store, &self.register_entrypoint, &input)
.await?;
Ok(res.result)
}
pub async fn evaluate_client_registration(
&mut self,
client_metadata: &ClientMetadata,
) -> Result<bool, anyhow::Error> {
let client_metadata = serde_json::to_value(client_metadata)?;
let input = serde_json::json!({
"client_metadata": client_metadata,
});
let [res]: [EvaluationResult; 1] = self
.instance
.evaluate(
&mut self.store,
&self.client_registration_entrypoint,
&input,
)
.await?;
Ok(res.result)
}
}