From 8a2be43fe794582b64e1caa9d87cdbcf202af2ad Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 4 Apr 2023 17:06:10 +0200 Subject: [PATCH] Proactively provision users on registration & sync threepids --- Cargo.lock | 3 + crates/axum-utils/src/http_client_factory.rs | 2 +- crates/axum-utils/src/lib.rs | 2 + crates/cli/src/commands/server.rs | 9 +- crates/cli/src/commands/worker.rs | 12 +- crates/config/src/sections/matrix.rs | 39 ++-- crates/config/src/sections/mod.rs | 1 - crates/handlers/src/upstream_oauth2/link.rs | 8 +- .../handlers/src/views/account/emails/mod.rs | 8 +- .../src/views/account/emails/verify.rs | 10 +- crates/handlers/src/views/register.rs | 6 +- crates/http/src/client.rs | 2 +- crates/storage/src/job.rs | 96 ++++++++- crates/tasks/Cargo.toml | 3 + crates/tasks/src/lib.rs | 51 ++++- crates/tasks/src/matrix.rs | 184 ++++++++++++++++++ 16 files changed, 411 insertions(+), 25 deletions(-) create mode 100644 crates/tasks/src/matrix.rs diff --git a/Cargo.lock b/Cargo.lock index 057c68b9..9f6c5e96 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3502,8 +3502,10 @@ dependencies = [ "apalis-sql", "async-trait", "chrono", + "mas-axum-utils", "mas-data-model", "mas-email", + "mas-http", "mas-storage", "mas-storage-pg", "rand 0.8.5", @@ -3515,6 +3517,7 @@ dependencies = [ "tracing", "tracing-opentelemetry", "ulid", + "url", ] [[package]] diff --git a/crates/axum-utils/src/http_client_factory.rs b/crates/axum-utils/src/http_client_factory.rs index 6eb5f940..a143616a 100644 --- a/crates/axum-utils/src/http_client_factory.rs +++ b/crates/axum-utils/src/http_client_factory.rs @@ -48,7 +48,7 @@ impl HttpClientFactory { operation: &'static str, ) -> Result>, ClientInitError> where - B: axum::body::HttpBody + Send + Sync + 'static, + B: axum::body::HttpBody + Send, B::Data: Send, { let client = mas_http::make_traced_client::().await?; diff --git a/crates/axum-utils/src/lib.rs b/crates/axum-utils/src/lib.rs index d03fe4db..90fb6aba 100644 --- a/crates/axum-utils/src/lib.rs +++ b/crates/axum-utils/src/lib.rs @@ -31,6 +31,8 @@ pub mod jwt; pub mod session; pub mod user_authorization; +pub use axum; + pub use self::{ cookies::CookieExt, fancy_error::FancyError, diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index 312c9f76..90c08407 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -22,6 +22,7 @@ use mas_handlers::{AppState, HttpClientFactory, MatrixHomeserver}; use mas_listener::{server::Server, shutdown::ShutdownStream}; use mas_router::UrlBuilder; use mas_storage_pg::MIGRATOR; +use mas_tasks::HomeserverConnection; use rand::{ distributions::{Alphanumeric, DistString}, thread_rng, @@ -96,7 +97,13 @@ impl Options { let worker_name = Alphanumeric.sample_string(&mut rng, 10); info!(worker_name, "Starting task worker"); - let monitor = mas_tasks::init(&worker_name, &pool, &mailer); + let http_client_factory = HttpClientFactory::new(50); + let conn = HomeserverConnection::new( + config.matrix.homeserver.clone(), + config.matrix.endpoint.clone(), + config.matrix.secret.clone(), + ); + let monitor = mas_tasks::init(&worker_name, &pool, &mailer, conn, &http_client_factory); // TODO: grab the handle tokio::spawn(monitor.run()); } diff --git a/crates/cli/src/commands/worker.rs b/crates/cli/src/commands/worker.rs index fa7ebdb2..a90aaf59 100644 --- a/crates/cli/src/commands/worker.rs +++ b/crates/cli/src/commands/worker.rs @@ -14,7 +14,9 @@ use clap::Parser; use mas_config::RootConfig; +use mas_handlers::HttpClientFactory; use mas_router::UrlBuilder; +use mas_tasks::HomeserverConnection; use rand::{ distributions::{Alphanumeric, DistString}, thread_rng, @@ -42,6 +44,14 @@ impl Options { let mailer = mailer_from_config(&config.email, &templates).await?; mailer.test_connection().await?; + + let http_client_factory = HttpClientFactory::new(50); + let conn = HomeserverConnection::new( + config.matrix.homeserver.clone(), + config.matrix.endpoint.clone(), + config.matrix.secret.clone(), + ); + drop(config); #[allow(clippy::disallowed_methods)] @@ -49,7 +59,7 @@ impl Options { let worker_name = Alphanumeric.sample_string(&mut rng, 10); info!(worker_name, "Starting task scheduler"); - let monitor = mas_tasks::init(&worker_name, &pool, &mailer); + let monitor = mas_tasks::init(&worker_name, &pool, &mailer, conn, &http_client_factory); span.exit(); diff --git a/crates/config/src/sections/matrix.rs b/crates/config/src/sections/matrix.rs index 0a3866c6..3647c9ef 100644 --- a/crates/config/src/sections/matrix.rs +++ b/crates/config/src/sections/matrix.rs @@ -13,10 +13,14 @@ // limitations under the License. use async_trait::async_trait; -use rand::Rng; +use rand::{ + distributions::{Alphanumeric, DistString}, + Rng, +}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use serde_with::serde_as; +use url::Url; use super::ConfigurationSection; @@ -24,6 +28,10 @@ fn default_homeserver() -> String { "localhost:8008".to_owned() } +fn default_endpoint() -> Url { + Url::parse("http://localhost:8008/").unwrap() +} + /// Configuration related to the Matrix homeserver #[serde_as] #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] @@ -31,14 +39,13 @@ pub struct MatrixConfig { /// Time-to-live of a CSRF token in seconds #[serde(default = "default_homeserver")] pub homeserver: String, -} -impl Default for MatrixConfig { - fn default() -> Self { - Self { - homeserver: default_homeserver(), - } - } + /// Shared secret to use for calls to the admin API + pub secret: String, + + /// The base URL of the homeserver's client API + #[serde(default = "default_endpoint")] + pub endpoint: Url, } #[async_trait] @@ -47,15 +54,23 @@ impl ConfigurationSection<'_> for MatrixConfig { "matrix" } - async fn generate(_rng: R) -> anyhow::Result + async fn generate(mut rng: R) -> anyhow::Result where R: Rng + Send, { - Ok(Self::default()) + Ok(Self { + homeserver: default_homeserver(), + secret: Alphanumeric.sample_string(&mut rng, 32), + endpoint: default_endpoint(), + }) } fn test() -> Self { - Self::default() + Self { + homeserver: default_homeserver(), + secret: "test".to_owned(), + endpoint: default_endpoint(), + } } } @@ -73,12 +88,14 @@ mod tests { r#" matrix: homeserver: matrix.org + secret: test "#, )?; let config = MatrixConfig::load_from_file("config.yaml")?; assert_eq!(config.homeserver, "matrix.org".to_owned()); + assert_eq!(config.secret, "test".to_owned()); Ok(()) }); diff --git a/crates/config/src/sections/mod.rs b/crates/config/src/sections/mod.rs index 017f3e3d..96e8cbc9 100644 --- a/crates/config/src/sections/mod.rs +++ b/crates/config/src/sections/mod.rs @@ -89,7 +89,6 @@ pub struct RootConfig { pub passwords: PasswordsConfig, /// Configuration related to the homeserver - #[serde(default)] pub matrix: MatrixConfig, /// Configuration related to the OPA policies diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index 567b1746..8f8488d2 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -25,9 +25,10 @@ use mas_axum_utils::{ }; use mas_keystore::Encrypter; use mas_storage::{ + job::{JobRepositoryExt, ProvisionUserJob}, upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository}, user::{BrowserSessionRepository, UserRepository}, - BoxClock, BoxRepository, BoxRng, + BoxClock, BoxRepository, BoxRng, RepositoryAccess, }; use mas_templates::{ EmptyContext, TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister, @@ -285,6 +286,11 @@ pub(crate) async fn post( (None, None, FormData::Register { username }) => { let user = repo.user().add(&mut rng, &clock, username).await?; + + repo.job() + .schedule_job(ProvisionUserJob::new(&user)) + .await?; + repo.upstream_oauth_link() .associate_to_user(&link, &user) .await?; diff --git a/crates/handlers/src/views/account/emails/mod.rs b/crates/handlers/src/views/account/emails/mod.rs index 4aa38ed9..a1ee661f 100644 --- a/crates/handlers/src/views/account/emails/mod.rs +++ b/crates/handlers/src/views/account/emails/mod.rs @@ -26,7 +26,7 @@ use mas_data_model::BrowserSession; use mas_keystore::Encrypter; use mas_router::Route; use mas_storage::{ - job::{JobRepositoryExt, VerifyEmailJob}, + job::{JobRepositoryExt, ProvisionUserJob, VerifyEmailJob}, user::UserEmailRepository, BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess, }; @@ -179,6 +179,12 @@ pub(crate) async fn post( } }; + // XXX: It shouldn't hurt to do this even if the user didn't change their emails + // in a meaningful way + repo.job() + .schedule_job(ProvisionUserJob::new(&session.user)) + .await?; + let reply = render( &mut rng, &clock, diff --git a/crates/handlers/src/views/account/emails/verify.rs b/crates/handlers/src/views/account/emails/verify.rs index 5dcb3aa2..51765e92 100644 --- a/crates/handlers/src/views/account/emails/verify.rs +++ b/crates/handlers/src/views/account/emails/verify.rs @@ -24,7 +24,11 @@ use mas_axum_utils::{ }; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::{user::UserEmailRepository, BoxClock, BoxRepository, BoxRng}; +use mas_storage::{ + job::{JobRepositoryExt, ProvisionUserJob}, + user::UserEmailRepository, + BoxClock, BoxRepository, BoxRng, RepositoryAccess, +}; use mas_templates::{EmailVerificationPageContext, TemplateContext, Templates}; use serde::Deserialize; use ulid::Ulid; @@ -133,6 +137,10 @@ pub(crate) async fn post( .mark_as_verified(&clock, user_email) .await?; + repo.job() + .schedule_job(ProvisionUserJob::new(&session.user)) + .await?; + repo.save().await?; let destination = query.go_next_or_default(&mas_router::AccountEmails); diff --git a/crates/handlers/src/views/register.rs b/crates/handlers/src/views/register.rs index 9f1484c5..61b4c8bb 100644 --- a/crates/handlers/src/views/register.rs +++ b/crates/handlers/src/views/register.rs @@ -28,7 +28,7 @@ use mas_keystore::Encrypter; use mas_policy::PolicyFactory; use mas_router::Route; use mas_storage::{ - job::{JobRepositoryExt, VerifyEmailJob}, + job::{JobRepositoryExt, ProvisionUserJob, VerifyEmailJob}, user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, BoxClock, BoxRepository, BoxRng, RepositoryAccess, }; @@ -205,6 +205,10 @@ pub(crate) async fn post( .schedule_job(VerifyEmailJob::new(&user_email)) .await?; + repo.job() + .schedule_job(ProvisionUserJob::new(&user)) + .await?; + repo.save().await?; let cookie_jar = cookie_jar.set_session(&session); diff --git a/crates/http/src/client.rs b/crates/http/src/client.rs index 19d9d726..f23db552 100644 --- a/crates/http/src/client.rs +++ b/crates/http/src/client.rs @@ -158,7 +158,7 @@ where /// Returns an error if it failed to load the TLS certificates pub async fn make_traced_client() -> Result, ClientInitError> where - B: http_body::Body + Send + 'static, + B: http_body::Body + Send, B::Data: Send, { let https = make_traced_connector().await?; diff --git a/crates/storage/src/job.rs b/crates/storage/src/job.rs index 78db872d..3eb246a4 100644 --- a/crates/storage/src/job.rs +++ b/crates/storage/src/job.rs @@ -216,7 +216,7 @@ where mod jobs { // XXX: Move this somewhere else? use apalis_core::job::Job; - use mas_data_model::UserEmail; + use mas_data_model::{User, UserEmail}; use serde::{Deserialize, Serialize}; use ulid::Ulid; @@ -245,6 +245,98 @@ mod jobs { impl Job for VerifyEmailJob { const NAME: &'static str = "verify-email"; } + + /// A job to provision the user on the homeserver. + #[derive(Serialize, Deserialize, Debug, Clone)] + pub struct ProvisionUserJob { + user_id: Ulid, + } + + impl ProvisionUserJob { + /// Create a new job to provision the user on the homeserver. + #[must_use] + pub fn new(user: &User) -> Self { + Self { user_id: user.id } + } + + /// The ID of the user to provision. + #[must_use] + pub fn user_id(&self) -> Ulid { + self.user_id + } + } + + impl Job for ProvisionUserJob { + const NAME: &'static str = "provision-user"; + } + + /// A job to provision a device for a user on the homeserver. + #[derive(Serialize, Deserialize, Debug, Clone)] + pub struct ProvisionDeviceJob { + user_id: Ulid, + device_id: String, + } + + impl ProvisionDeviceJob { + /// Create a new job to provision a device for a user on the homeserver. + #[must_use] + pub fn new(user: &User, device_id: &str) -> Self { + Self { + user_id: user.id, + device_id: device_id.to_owned(), + } + } + + /// The ID of the user to provision the device for. + #[must_use] + pub fn user_id(&self) -> Ulid { + self.user_id + } + + /// The ID of the device to provision. + #[must_use] + pub fn device_id(&self) -> &str { + &self.device_id + } + } + + impl Job for ProvisionDeviceJob { + const NAME: &'static str = "provision-device"; + } + + /// A job to delete a device for a user on the homeserver. + #[derive(Serialize, Deserialize, Debug, Clone)] + pub struct DeleteDeviceJob { + user_id: Ulid, + device_id: String, + } + + impl DeleteDeviceJob { + /// Create a new job to delete a device for a user on the homeserver. + #[must_use] + pub fn new(user: &User, device_id: &str) -> Self { + Self { + user_id: user.id, + device_id: device_id.to_owned(), + } + } + + /// The ID of the user to delete the device for. + #[must_use] + pub fn user_id(&self) -> Ulid { + self.user_id + } + + /// The ID of the device to delete. + #[must_use] + pub fn device_id(&self) -> &str { + &self.device_id + } + } + + impl Job for DeleteDeviceJob { + const NAME: &'static str = "delete-device"; + } } -pub use self::jobs::VerifyEmailJob; +pub use self::jobs::{DeleteDeviceJob, ProvisionDeviceJob, ProvisionUserJob, VerifyEmailJob}; diff --git a/crates/tasks/Cargo.toml b/crates/tasks/Cargo.toml index 328a4672..33cdbe7e 100644 --- a/crates/tasks/Cargo.toml +++ b/crates/tasks/Cargo.toml @@ -20,9 +20,12 @@ tower = "0.4.13" tracing = "0.1.37" tracing-opentelemetry = "0.18.0" ulid = "1.0.0" +url = "2.3.1" serde = { version = "1.0.159", features = ["derive"] } +mas-axum-utils = { path = "../axum-utils" } mas-storage = { path = "../storage" } mas-storage-pg = { path = "../storage-pg" } mas-email = { path = "../email" } +mas-http = { path = "../http" } mas-data-model = { path = "../data-model" } diff --git a/crates/tasks/src/lib.rs b/crates/tasks/src/lib.rs index 4bea7f93..7474c745 100644 --- a/crates/tasks/src/lib.rs +++ b/crates/tasks/src/lib.rs @@ -16,9 +16,13 @@ #![deny(clippy::all, clippy::str_to_string, rustdoc::broken_intra_doc_links)] #![warn(clippy::pedantic)] +use std::sync::Arc; + use apalis_core::{executor::TokioExecutor, layers::extensions::Extension, monitor::Monitor}; use apalis_sql::postgres::PostgresStorage; +use mas_axum_utils::http_client_factory::HttpClientFactory; use mas_email::Mailer; +use mas_http::{ClientInitError, ClientService, TracedClient}; use mas_storage::{BoxClock, BoxRepository, Repository, SystemClock}; use mas_storage_pg::{DatabaseError, PgRepository}; use rand::SeedableRng; @@ -28,20 +32,33 @@ use tracing::debug; mod database; mod email; mod layers; +mod matrix; + +pub use self::matrix::HomeserverConnection; #[derive(Clone)] struct State { pool: Pool, mailer: Mailer, clock: SystemClock, + homeserver: Arc, + http_client_factory: HttpClientFactory, } impl State { - pub fn new(pool: Pool, clock: SystemClock, mailer: Mailer) -> Self { + pub fn new( + pool: Pool, + clock: SystemClock, + mailer: Mailer, + homeserver: HomeserverConnection, + http_client_factory: HttpClientFactory, + ) -> Self { Self { pool, mailer, clock, + homeserver: Arc::new(homeserver), + http_client_factory, } } @@ -81,6 +98,21 @@ impl State { Ok(repo) } + + pub fn matrix_connection(&self) -> &HomeserverConnection { + &self.homeserver + } + + pub async fn http_client( + &self, + operation: &'static str, + ) -> Result>, ClientInitError> + where + B: mas_axum_utils::axum::body::HttpBody + Send, + B::Data: Send, + { + self.http_client_factory.client(operation).await + } } trait JobContextExt { @@ -96,11 +128,24 @@ impl JobContextExt for apalis_core::context::JobContext { } #[must_use] -pub fn init(name: &str, pool: &Pool, mailer: &Mailer) -> Monitor { - let state = State::new(pool.clone(), SystemClock::default(), mailer.clone()); +pub fn init( + name: &str, + pool: &Pool, + mailer: &Mailer, + homeserver: HomeserverConnection, + http_client_factory: &HttpClientFactory, +) -> Monitor { + let state = State::new( + pool.clone(), + SystemClock::default(), + mailer.clone(), + homeserver, + http_client_factory.clone(), + ); let monitor = Monitor::new(); let monitor = self::database::register(name, monitor, &state); let monitor = self::email::register(name, monitor, &state); + let monitor = self::matrix::register(name, monitor, &state); debug!(?monitor, "workers registered"); monitor } diff --git a/crates/tasks/src/matrix.rs b/crates/tasks/src/matrix.rs new file mode 100644 index 00000000..d3c91665 --- /dev/null +++ b/crates/tasks/src/matrix.rs @@ -0,0 +1,184 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use anyhow::Context; +use apalis_core::{ + builder::{WorkerBuilder, WorkerFactory}, + context::JobContext, + executor::TokioExecutor, + job::Job, + job_fn::job_fn, + monitor::Monitor, + storage::builder::WithStorage, +}; +use mas_axum_utils::axum::{ + headers::{Authorization, HeaderMapExt}, + http::{Request, StatusCode}, +}; +use mas_http::HttpServiceExt; +use mas_storage::{ + job::{JobWithSpanContext, ProvisionUserJob}, + user::{UserEmailRepository, UserRepository}, + RepositoryAccess, +}; +use serde::{Deserialize, Serialize}; +use tower::{Service, ServiceExt}; +use tracing::{info, info_span, Instrument}; +use url::Url; + +use crate::{layers::TracingLayer, JobContextExt, State}; + +pub struct HomeserverConnection { + homeserver: String, + endpoint: Url, + access_token: String, +} + +impl HomeserverConnection { + pub fn new(homeserver: String, endpoint: Url, access_token: String) -> Self { + Self { + homeserver, + endpoint, + access_token, + } + } +} + +#[derive(Serialize, Deserialize)] +struct ExternalID { + pub auth_provider: String, + pub external_id: String, +} + +#[derive(Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +enum ThreePIDMedium { + Email, + MSISDN, +} + +#[derive(Serialize, Deserialize)] +struct ThreePID { + pub medium: ThreePIDMedium, + pub address: String, +} + +#[derive(Serialize, Deserialize)] +struct UserRequest { + #[serde(rename = "displayname")] + pub display_name: String, + + #[serde(rename = "threepids")] + pub three_pids: Vec, + + pub external_ids: Vec, +} + +#[tracing::instrument( + name = "job.provision_user" + fields(user.id = %job.user_id()), + skip_all, + err(Debug), +)] +async fn provision_user( + job: JobWithSpanContext, + ctx: JobContext, +) -> Result<(), anyhow::Error> { + let state = ctx.state(); + let matrix = state.matrix_connection(); + let mut client = state + .http_client("provision-matrix-user") + .await? + .request_bytes_to_body() + .json_request(); + let mut repo = state.repository().await?; + + let user = repo + .user() + .lookup(job.user_id()) + .await? + .context("User not found")?; + + let mxid = format!("@{}:{}", user.username, matrix.homeserver); + + let three_pids = repo + .user_email() + .all(&user) + .await? + .into_iter() + .filter_map(|email| { + if email.confirmed_at.is_some() { + Some(ThreePID { + medium: ThreePIDMedium::Email, + address: email.email, + }) + } else { + None + } + }) + .collect(); + + let display_name = user.username.clone(); + + let body = UserRequest { + display_name, + three_pids, + external_ids: vec![ExternalID { + auth_provider: "oauth-delegated".to_string(), + external_id: user.sub, + }], + }; + + repo.cancel().await?; + + let mut req = Request::put( + matrix + .endpoint + .join("_synapse/admin/v2/users/")? + .join(&mxid)? + .as_str(), + ); + req.headers_mut() + .context("Failed to get headers")? + .typed_insert(Authorization::bearer(&matrix.access_token)?); + + let req = req.body(body).context("Failed to build request")?; + + let span = info_span!("matrix.provision_user", %mxid); + let response = client.ready().await?.call(req).instrument(span).await?; + + match response.status() { + StatusCode::CREATED => info!(%user.id, %mxid, "User created"), + StatusCode::OK => info!(%user.id, %mxid, "User updated"), + // TODO: Better error handling + code => anyhow::bail!("Failed to provision user. Status code: {code}"), + } + + Ok(()) +} + +pub(crate) fn register( + suffix: &str, + monitor: Monitor, + state: &State, +) -> Monitor { + let storage = state.store(); + let worker_name = format!("{job}-{suffix}", job = ProvisionUserJob::NAME); + let worker = WorkerBuilder::new(worker_name) + .layer(state.inject()) + .layer(TracingLayer::new()) + .with_storage(storage) + .build(job_fn(provision_user)); + monitor.register(worker) +}