diff --git a/crates/cli/src/app_state.rs b/crates/cli/src/app_state.rs index 78753d46..1480319c 100644 --- a/crates/cli/src/app_state.rs +++ b/crates/cli/src/app_state.rs @@ -21,10 +21,12 @@ use axum::{ use ipnetwork::IpNetwork; use mas_handlers::{ passwords::PasswordManager, ActivityTracker, BoundActivityTracker, CookieManager, ErrorWrapper, - HttpClientFactory, MatrixHomeserver, MetadataCache, SiteConfig, + HttpClientFactory, MetadataCache, SiteConfig, }; use mas_i18n::Translator; use mas_keystore::{Encrypter, Keystore}; +use mas_matrix::BoxHomeserverConnection; +use mas_matrix_synapse::SynapseConnection; use mas_policy::{Policy, PolicyFactory}; use mas_router::UrlBuilder; use mas_storage::{BoxClock, BoxRepository, BoxRng, Repository, SystemClock}; @@ -45,7 +47,7 @@ pub struct AppState { pub cookie_manager: CookieManager, pub encrypter: Encrypter, pub url_builder: UrlBuilder, - pub homeserver: MatrixHomeserver, + pub homeserver_connection: SynapseConnection, pub policy_factory: Arc, pub graphql_schema: mas_graphql::Schema, pub http_client_factory: HttpClientFactory, @@ -177,12 +179,6 @@ impl FromRef for UrlBuilder { } } -impl FromRef for MatrixHomeserver { - fn from_ref(input: &AppState) -> Self { - input.homeserver.clone() - } -} - impl FromRef for HttpClientFactory { fn from_ref(input: &AppState) -> Self { input.http_client_factory.clone() @@ -213,6 +209,12 @@ impl FromRef for SiteConfig { } } +impl FromRef for BoxHomeserverConnection { + fn from_ref(input: &AppState) -> Self { + Box::new(input.homeserver_connection.clone()) + } +} + #[async_trait] impl FromRequestParts for BoxClock { type Rejection = Infallible; diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index 14765ab6..c8b5b091 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -18,9 +18,7 @@ use anyhow::Context; use clap::Parser; use itertools::Itertools; use mas_config::AppConfig; -use mas_handlers::{ - ActivityTracker, CookieManager, HttpClientFactory, MatrixHomeserver, MetadataCache, SiteConfig, -}; +use mas_handlers::{ActivityTracker, CookieManager, HttpClientFactory, MetadataCache, SiteConfig}; use mas_listener::{server::Server, shutdown::ShutdownStream}; use mas_matrix_synapse::SynapseConnection; use mas_router::UrlBuilder; @@ -123,6 +121,13 @@ impl Options { let http_client_factory = HttpClientFactory::new().await?; + let homeserver_connection = SynapseConnection::new( + config.matrix.homeserver.clone(), + config.matrix.endpoint.clone(), + config.matrix.secret.clone(), + http_client_factory.clone(), + ); + if !self.no_worker { let mailer = mailer_from_config(&config.email, &templates)?; mailer.test_connection().await?; @@ -132,19 +137,13 @@ impl Options { let worker_name = Alphanumeric.sample_string(&mut rng, 10); info!(worker_name, "Starting task worker"); - let conn = SynapseConnection::new( - config.matrix.homeserver.clone(), - config.matrix.endpoint.clone(), - config.matrix.secret.clone(), - http_client_factory.clone(), - ); - let monitor = mas_tasks::init(&worker_name, &pool, &mailer, conn).await?; + let monitor = + mas_tasks::init(&worker_name, &pool, &mailer, homeserver_connection.clone()) + .await?; // TODO: grab the handle tokio::spawn(monitor.run()); } - let homeserver = MatrixHomeserver::new(config.matrix.homeserver.clone()); - let listeners_config = config.http.listeners.clone(); let password_manager = password_manager_from_config(&config.passwords).await?; @@ -152,13 +151,6 @@ impl Options { // The upstream OIDC metadata cache let metadata_cache = MetadataCache::new(); - let conn = SynapseConnection::new( - config.matrix.homeserver.clone(), - config.matrix.endpoint.clone(), - config.matrix.secret.clone(), - http_client_factory.clone(), - ); - let site_config = SiteConfig { tos_uri: config.branding.tos_uri.clone(), access_token_ttl: config.experimental.access_token_ttl, @@ -176,7 +168,8 @@ impl Options { // Listen for SIGHUP register_sighup(&templates, &activity_tracker)?; - let graphql_schema = mas_handlers::graphql_schema(&pool, &policy_factory, conn); + let graphql_schema = + mas_handlers::graphql_schema(&pool, &policy_factory, homeserver_connection.clone()); let state = { let mut s = AppState { @@ -187,7 +180,7 @@ impl Options { cookie_manager, encrypter, url_builder, - homeserver, + homeserver_connection, policy_factory, graphql_schema, http_client_factory, diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index 7eaf81c9..4da7483f 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -17,6 +17,7 @@ use chrono::Duration; use hyper::StatusCode; use mas_axum_utils::sentry::SentryEventID; use mas_data_model::{CompatSession, CompatSsoLoginState, Device, TokenType, User, UserAgent}; +use mas_matrix::BoxHomeserverConnection; use mas_storage::{ compat::{ CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, @@ -32,7 +33,7 @@ use serde_with::{serde_as, skip_serializing_none, DurationMilliSeconds}; use thiserror::Error; use zeroize::Zeroizing; -use super::{MatrixError, MatrixHomeserver}; +use super::MatrixError; use crate::{ impl_from_error_for_route, passwords::PasswordManager, site_config::SiteConfig, BoundActivityTracker, @@ -215,7 +216,7 @@ pub(crate) async fn post( State(password_manager): State, mut repo: BoxRepository, activity_tracker: BoundActivityTracker, - State(homeserver): State, + State(homeserver): State, State(site_config): State, user_agent: Option>, Json(input): Json, @@ -254,7 +255,7 @@ pub(crate) async fn post( .await?; } - let user_id = format!("@{username}:{homeserver}", username = user.username); + let user_id = homeserver.mxid(&user.username); // If the client asked for a refreshable token, make it expire let expires_in = if input.refresh_token { diff --git a/crates/handlers/src/compat/mod.rs b/crates/handlers/src/compat/mod.rs index df26cc49..3ae2030c 100644 --- a/crates/handlers/src/compat/mod.rs +++ b/crates/handlers/src/compat/mod.rs @@ -22,22 +22,6 @@ pub(crate) mod login_sso_redirect; pub(crate) mod logout; pub(crate) mod refresh; -#[derive(Debug, Clone)] -pub struct MatrixHomeserver(String); - -impl MatrixHomeserver { - #[must_use] - pub const fn new(hs: String) -> Self { - Self(hs) - } -} - -impl std::fmt::Display for MatrixHomeserver { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - self.0.fmt(f) - } -} - #[derive(Debug, Serialize)] struct MatrixError { errcode: &'static str, diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index 8c8dbe9f..cf1548d6 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -43,6 +43,7 @@ use hyper::{ use mas_axum_utils::{cookies::CookieJar, FancyError}; use mas_http::CorsLayerExt; use mas_keystore::{Encrypter, Keystore}; +use mas_matrix::BoxHomeserverConnection; use mas_policy::Policy; use mas_router::{Route, UrlBuilder}; use mas_storage::{BoxClock, BoxRepository, BoxRng}; @@ -88,7 +89,6 @@ pub use mas_axum_utils::{ pub use self::{ activity_tracker::{ActivityTracker, Bound as BoundActivityTracker}, - compat::MatrixHomeserver, graphql::schema as graphql_schema, preferred_language::PreferredLanguage, site_config::SiteConfig, @@ -253,7 +253,7 @@ where S: Clone + Send + Sync + 'static, UrlBuilder: FromRef, SiteConfig: FromRef, - MatrixHomeserver: FromRef, + BoxHomeserverConnection: FromRef, PasswordManager: FromRef, BoundActivityTracker: FromRequestParts, BoxRepository: FromRequestParts, diff --git a/crates/handlers/src/test_utils.rs b/crates/handlers/src/test_utils.rs index f441ccb2..2c78d30d 100644 --- a/crates/handlers/src/test_utils.rs +++ b/crates/handlers/src/test_utils.rs @@ -38,7 +38,7 @@ use mas_axum_utils::{ }; use mas_i18n::Translator; use mas_keystore::{Encrypter, JsonWebKey, JsonWebKeySet, Keystore, PrivateKey}; -use mas_matrix::{HomeserverConnection, MockHomeserverConnection}; +use mas_matrix::{BoxHomeserverConnection, HomeserverConnection, MockHomeserverConnection}; use mas_policy::{InstantiateError, Policy, PolicyFactory}; use mas_router::{SimpleRoute, UrlBuilder}; use mas_storage::{clock::MockClock, BoxClock, BoxRepository, BoxRng, Repository}; @@ -55,7 +55,7 @@ use crate::{ passwords::{Hasher, PasswordManager}, site_config::SiteConfig, upstream_oauth2::cache::MetadataCache, - ActivityTracker, BoundActivityTracker, MatrixHomeserver, + ActivityTracker, BoundActivityTracker, }; // This might fail if it's not the first time it's being called, which is fine, @@ -99,7 +99,7 @@ pub(crate) struct TestState { pub metadata_cache: MetadataCache, pub encrypter: Encrypter, pub url_builder: UrlBuilder, - pub homeserver: MatrixHomeserver, + pub homeserver_connection: Arc, pub policy_factory: Arc, pub graphql_schema: mas_graphql::Schema, pub http_client_factory: HttpClientFactory, @@ -148,11 +148,9 @@ impl TestState { let password_manager = PasswordManager::new([(1, Hasher::argon2id(None))])?; - let homeserver = MatrixHomeserver::new("example.com".to_owned()); - let policy_factory = policy_factory(serde_json::json!({})).await?; - let homeserver_connection = MockHomeserverConnection::new("example.com"); + let homeserver_connection = Arc::new(MockHomeserverConnection::new("example.com")); let http_client_factory = HttpClientFactory::new().await?; @@ -167,7 +165,7 @@ impl TestState { let graphql_state = TestGraphQLState { pool: pool.clone(), policy_factory: Arc::clone(&policy_factory), - homeserver_connection, + homeserver_connection: Arc::clone(&homeserver_connection), rng: Arc::clone(&rng), clock: Arc::clone(&clock), }; @@ -186,7 +184,7 @@ impl TestState { metadata_cache, encrypter, url_builder, - homeserver, + homeserver_connection, policy_factory, graphql_schema, http_client_factory, @@ -281,7 +279,7 @@ impl TestState { struct TestGraphQLState { pool: PgPool, - homeserver_connection: MockHomeserverConnection, + homeserver_connection: Arc, policy_factory: Arc, clock: Arc, rng: Arc>, @@ -360,12 +358,6 @@ impl FromRef for UrlBuilder { } } -impl FromRef for MatrixHomeserver { - fn from_ref(input: &TestState) -> Self { - input.homeserver.clone() - } -} - impl FromRef for HttpClientFactory { fn from_ref(input: &TestState) -> Self { input.http_client_factory.clone() @@ -396,6 +388,12 @@ impl FromRef for SiteConfig { } } +impl FromRef for BoxHomeserverConnection { + fn from_ref(input: &TestState) -> Self { + Box::new(input.homeserver_connection.clone()) + } +} + #[async_trait] impl FromRequestParts for ActivityTracker { type Rejection = Infallible; diff --git a/crates/matrix-synapse/src/lib.rs b/crates/matrix-synapse/src/lib.rs index 54575700..f58c1714 100644 --- a/crates/matrix-synapse/src/lib.rs +++ b/crates/matrix-synapse/src/lib.rs @@ -22,6 +22,7 @@ use url::Url; static SYNAPSE_AUTH_PROVIDER: &str = "oauth-delegated"; +#[derive(Clone)] pub struct SynapseConnection { homeserver: String, endpoint: Url, diff --git a/crates/matrix/src/lib.rs b/crates/matrix/src/lib.rs index 5e1589e3..26b921fa 100644 --- a/crates/matrix/src/lib.rs +++ b/crates/matrix/src/lib.rs @@ -14,8 +14,14 @@ mod mock; +use std::sync::Arc; + pub use self::mock::HomeserverConnection as MockHomeserverConnection; +// TODO: this should probably be another error type by default +pub type BoxHomeserverConnection = + Box>; + #[derive(Debug)] pub struct MatrixUser { pub displayname: Option, @@ -351,3 +357,49 @@ impl HomeserverConnection for &T (**self).allow_cross_signing_reset(mxid).await } } + +// Implement for Arc where T: HomeserverConnection +#[async_trait::async_trait] +impl HomeserverConnection for Arc { + type Error = T::Error; + + fn homeserver(&self) -> &str { + (**self).homeserver() + } + + async fn query_user(&self, mxid: &str) -> Result { + (**self).query_user(mxid).await + } + + async fn provision_user(&self, request: &ProvisionRequest) -> Result { + (**self).provision_user(request).await + } + + async fn is_localpart_available(&self, localpart: &str) -> Result { + (**self).is_localpart_available(localpart).await + } + + async fn create_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> { + (**self).create_device(mxid, device_id).await + } + + async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> { + (**self).delete_device(mxid, device_id).await + } + + async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), Self::Error> { + (**self).delete_user(mxid, erase).await + } + + async fn set_displayname(&self, mxid: &str, displayname: &str) -> Result<(), Self::Error> { + (**self).set_displayname(mxid, displayname).await + } + + async fn unset_displayname(&self, mxid: &str) -> Result<(), Self::Error> { + (**self).unset_displayname(mxid).await + } + + async fn allow_cross_signing_reset(&self, mxid: &str) -> Result<(), Self::Error> { + (**self).allow_cross_signing_reset(mxid).await + } +}