1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-07 17:03:01 +03:00

Make the HomeserverConnection available in handlers

This commit is contained in:
Quentin Gliech
2024-02-28 09:58:27 +01:00
parent 20dd5ca311
commit 4aeb446061
8 changed files with 96 additions and 65 deletions

View File

@@ -21,10 +21,12 @@ use axum::{
use ipnetwork::IpNetwork; use ipnetwork::IpNetwork;
use mas_handlers::{ use mas_handlers::{
passwords::PasswordManager, ActivityTracker, BoundActivityTracker, CookieManager, ErrorWrapper, passwords::PasswordManager, ActivityTracker, BoundActivityTracker, CookieManager, ErrorWrapper,
HttpClientFactory, MatrixHomeserver, MetadataCache, SiteConfig, HttpClientFactory, MetadataCache, SiteConfig,
}; };
use mas_i18n::Translator; use mas_i18n::Translator;
use mas_keystore::{Encrypter, Keystore}; use mas_keystore::{Encrypter, Keystore};
use mas_matrix::BoxHomeserverConnection;
use mas_matrix_synapse::SynapseConnection;
use mas_policy::{Policy, PolicyFactory}; use mas_policy::{Policy, PolicyFactory};
use mas_router::UrlBuilder; use mas_router::UrlBuilder;
use mas_storage::{BoxClock, BoxRepository, BoxRng, Repository, SystemClock}; use mas_storage::{BoxClock, BoxRepository, BoxRng, Repository, SystemClock};
@@ -45,7 +47,7 @@ pub struct AppState {
pub cookie_manager: CookieManager, pub cookie_manager: CookieManager,
pub encrypter: Encrypter, pub encrypter: Encrypter,
pub url_builder: UrlBuilder, pub url_builder: UrlBuilder,
pub homeserver: MatrixHomeserver, pub homeserver_connection: SynapseConnection,
pub policy_factory: Arc<PolicyFactory>, pub policy_factory: Arc<PolicyFactory>,
pub graphql_schema: mas_graphql::Schema, pub graphql_schema: mas_graphql::Schema,
pub http_client_factory: HttpClientFactory, pub http_client_factory: HttpClientFactory,
@@ -177,12 +179,6 @@ impl FromRef<AppState> for UrlBuilder {
} }
} }
impl FromRef<AppState> for MatrixHomeserver {
fn from_ref(input: &AppState) -> Self {
input.homeserver.clone()
}
}
impl FromRef<AppState> for HttpClientFactory { impl FromRef<AppState> for HttpClientFactory {
fn from_ref(input: &AppState) -> Self { fn from_ref(input: &AppState) -> Self {
input.http_client_factory.clone() input.http_client_factory.clone()
@@ -213,6 +209,12 @@ impl FromRef<AppState> for SiteConfig {
} }
} }
impl FromRef<AppState> for BoxHomeserverConnection {
fn from_ref(input: &AppState) -> Self {
Box::new(input.homeserver_connection.clone())
}
}
#[async_trait] #[async_trait]
impl FromRequestParts<AppState> for BoxClock { impl FromRequestParts<AppState> for BoxClock {
type Rejection = Infallible; type Rejection = Infallible;

View File

@@ -18,9 +18,7 @@ use anyhow::Context;
use clap::Parser; use clap::Parser;
use itertools::Itertools; use itertools::Itertools;
use mas_config::AppConfig; use mas_config::AppConfig;
use mas_handlers::{ use mas_handlers::{ActivityTracker, CookieManager, HttpClientFactory, MetadataCache, SiteConfig};
ActivityTracker, CookieManager, HttpClientFactory, MatrixHomeserver, MetadataCache, SiteConfig,
};
use mas_listener::{server::Server, shutdown::ShutdownStream}; use mas_listener::{server::Server, shutdown::ShutdownStream};
use mas_matrix_synapse::SynapseConnection; use mas_matrix_synapse::SynapseConnection;
use mas_router::UrlBuilder; use mas_router::UrlBuilder;
@@ -123,6 +121,13 @@ impl Options {
let http_client_factory = HttpClientFactory::new().await?; let http_client_factory = HttpClientFactory::new().await?;
let homeserver_connection = SynapseConnection::new(
config.matrix.homeserver.clone(),
config.matrix.endpoint.clone(),
config.matrix.secret.clone(),
http_client_factory.clone(),
);
if !self.no_worker { if !self.no_worker {
let mailer = mailer_from_config(&config.email, &templates)?; let mailer = mailer_from_config(&config.email, &templates)?;
mailer.test_connection().await?; mailer.test_connection().await?;
@@ -132,19 +137,13 @@ impl Options {
let worker_name = Alphanumeric.sample_string(&mut rng, 10); let worker_name = Alphanumeric.sample_string(&mut rng, 10);
info!(worker_name, "Starting task worker"); info!(worker_name, "Starting task worker");
let conn = SynapseConnection::new( let monitor =
config.matrix.homeserver.clone(), mas_tasks::init(&worker_name, &pool, &mailer, homeserver_connection.clone())
config.matrix.endpoint.clone(), .await?;
config.matrix.secret.clone(),
http_client_factory.clone(),
);
let monitor = mas_tasks::init(&worker_name, &pool, &mailer, conn).await?;
// TODO: grab the handle // TODO: grab the handle
tokio::spawn(monitor.run()); tokio::spawn(monitor.run());
} }
let homeserver = MatrixHomeserver::new(config.matrix.homeserver.clone());
let listeners_config = config.http.listeners.clone(); let listeners_config = config.http.listeners.clone();
let password_manager = password_manager_from_config(&config.passwords).await?; let password_manager = password_manager_from_config(&config.passwords).await?;
@@ -152,13 +151,6 @@ impl Options {
// The upstream OIDC metadata cache // The upstream OIDC metadata cache
let metadata_cache = MetadataCache::new(); let metadata_cache = MetadataCache::new();
let conn = SynapseConnection::new(
config.matrix.homeserver.clone(),
config.matrix.endpoint.clone(),
config.matrix.secret.clone(),
http_client_factory.clone(),
);
let site_config = SiteConfig { let site_config = SiteConfig {
tos_uri: config.branding.tos_uri.clone(), tos_uri: config.branding.tos_uri.clone(),
access_token_ttl: config.experimental.access_token_ttl, access_token_ttl: config.experimental.access_token_ttl,
@@ -176,7 +168,8 @@ impl Options {
// Listen for SIGHUP // Listen for SIGHUP
register_sighup(&templates, &activity_tracker)?; register_sighup(&templates, &activity_tracker)?;
let graphql_schema = mas_handlers::graphql_schema(&pool, &policy_factory, conn); let graphql_schema =
mas_handlers::graphql_schema(&pool, &policy_factory, homeserver_connection.clone());
let state = { let state = {
let mut s = AppState { let mut s = AppState {
@@ -187,7 +180,7 @@ impl Options {
cookie_manager, cookie_manager,
encrypter, encrypter,
url_builder, url_builder,
homeserver, homeserver_connection,
policy_factory, policy_factory,
graphql_schema, graphql_schema,
http_client_factory, http_client_factory,

View File

@@ -17,6 +17,7 @@ use chrono::Duration;
use hyper::StatusCode; use hyper::StatusCode;
use mas_axum_utils::sentry::SentryEventID; use mas_axum_utils::sentry::SentryEventID;
use mas_data_model::{CompatSession, CompatSsoLoginState, Device, TokenType, User, UserAgent}; use mas_data_model::{CompatSession, CompatSsoLoginState, Device, TokenType, User, UserAgent};
use mas_matrix::BoxHomeserverConnection;
use mas_storage::{ use mas_storage::{
compat::{ compat::{
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
@@ -32,7 +33,7 @@ use serde_with::{serde_as, skip_serializing_none, DurationMilliSeconds};
use thiserror::Error; use thiserror::Error;
use zeroize::Zeroizing; use zeroize::Zeroizing;
use super::{MatrixError, MatrixHomeserver}; use super::MatrixError;
use crate::{ use crate::{
impl_from_error_for_route, passwords::PasswordManager, site_config::SiteConfig, impl_from_error_for_route, passwords::PasswordManager, site_config::SiteConfig,
BoundActivityTracker, BoundActivityTracker,
@@ -215,7 +216,7 @@ pub(crate) async fn post(
State(password_manager): State<PasswordManager>, State(password_manager): State<PasswordManager>,
mut repo: BoxRepository, mut repo: BoxRepository,
activity_tracker: BoundActivityTracker, activity_tracker: BoundActivityTracker,
State(homeserver): State<MatrixHomeserver>, State(homeserver): State<BoxHomeserverConnection>,
State(site_config): State<SiteConfig>, State(site_config): State<SiteConfig>,
user_agent: Option<TypedHeader<headers::UserAgent>>, user_agent: Option<TypedHeader<headers::UserAgent>>,
Json(input): Json<RequestBody>, Json(input): Json<RequestBody>,
@@ -254,7 +255,7 @@ pub(crate) async fn post(
.await?; .await?;
} }
let user_id = format!("@{username}:{homeserver}", username = user.username); let user_id = homeserver.mxid(&user.username);
// If the client asked for a refreshable token, make it expire // If the client asked for a refreshable token, make it expire
let expires_in = if input.refresh_token { let expires_in = if input.refresh_token {

View File

@@ -22,22 +22,6 @@ pub(crate) mod login_sso_redirect;
pub(crate) mod logout; pub(crate) mod logout;
pub(crate) mod refresh; pub(crate) mod refresh;
#[derive(Debug, Clone)]
pub struct MatrixHomeserver(String);
impl MatrixHomeserver {
#[must_use]
pub const fn new(hs: String) -> Self {
Self(hs)
}
}
impl std::fmt::Display for MatrixHomeserver {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
struct MatrixError { struct MatrixError {
errcode: &'static str, errcode: &'static str,

View File

@@ -43,6 +43,7 @@ use hyper::{
use mas_axum_utils::{cookies::CookieJar, FancyError}; use mas_axum_utils::{cookies::CookieJar, FancyError};
use mas_http::CorsLayerExt; use mas_http::CorsLayerExt;
use mas_keystore::{Encrypter, Keystore}; use mas_keystore::{Encrypter, Keystore};
use mas_matrix::BoxHomeserverConnection;
use mas_policy::Policy; use mas_policy::Policy;
use mas_router::{Route, UrlBuilder}; use mas_router::{Route, UrlBuilder};
use mas_storage::{BoxClock, BoxRepository, BoxRng}; use mas_storage::{BoxClock, BoxRepository, BoxRng};
@@ -88,7 +89,6 @@ pub use mas_axum_utils::{
pub use self::{ pub use self::{
activity_tracker::{ActivityTracker, Bound as BoundActivityTracker}, activity_tracker::{ActivityTracker, Bound as BoundActivityTracker},
compat::MatrixHomeserver,
graphql::schema as graphql_schema, graphql::schema as graphql_schema,
preferred_language::PreferredLanguage, preferred_language::PreferredLanguage,
site_config::SiteConfig, site_config::SiteConfig,
@@ -253,7 +253,7 @@ where
S: Clone + Send + Sync + 'static, S: Clone + Send + Sync + 'static,
UrlBuilder: FromRef<S>, UrlBuilder: FromRef<S>,
SiteConfig: FromRef<S>, SiteConfig: FromRef<S>,
MatrixHomeserver: FromRef<S>, BoxHomeserverConnection: FromRef<S>,
PasswordManager: FromRef<S>, PasswordManager: FromRef<S>,
BoundActivityTracker: FromRequestParts<S>, BoundActivityTracker: FromRequestParts<S>,
BoxRepository: FromRequestParts<S>, BoxRepository: FromRequestParts<S>,

View File

@@ -38,7 +38,7 @@ use mas_axum_utils::{
}; };
use mas_i18n::Translator; use mas_i18n::Translator;
use mas_keystore::{Encrypter, JsonWebKey, JsonWebKeySet, Keystore, PrivateKey}; use mas_keystore::{Encrypter, JsonWebKey, JsonWebKeySet, Keystore, PrivateKey};
use mas_matrix::{HomeserverConnection, MockHomeserverConnection}; use mas_matrix::{BoxHomeserverConnection, HomeserverConnection, MockHomeserverConnection};
use mas_policy::{InstantiateError, Policy, PolicyFactory}; use mas_policy::{InstantiateError, Policy, PolicyFactory};
use mas_router::{SimpleRoute, UrlBuilder}; use mas_router::{SimpleRoute, UrlBuilder};
use mas_storage::{clock::MockClock, BoxClock, BoxRepository, BoxRng, Repository}; use mas_storage::{clock::MockClock, BoxClock, BoxRepository, BoxRng, Repository};
@@ -55,7 +55,7 @@ use crate::{
passwords::{Hasher, PasswordManager}, passwords::{Hasher, PasswordManager},
site_config::SiteConfig, site_config::SiteConfig,
upstream_oauth2::cache::MetadataCache, upstream_oauth2::cache::MetadataCache,
ActivityTracker, BoundActivityTracker, MatrixHomeserver, ActivityTracker, BoundActivityTracker,
}; };
// This might fail if it's not the first time it's being called, which is fine, // This might fail if it's not the first time it's being called, which is fine,
@@ -99,7 +99,7 @@ pub(crate) struct TestState {
pub metadata_cache: MetadataCache, pub metadata_cache: MetadataCache,
pub encrypter: Encrypter, pub encrypter: Encrypter,
pub url_builder: UrlBuilder, pub url_builder: UrlBuilder,
pub homeserver: MatrixHomeserver, pub homeserver_connection: Arc<MockHomeserverConnection>,
pub policy_factory: Arc<PolicyFactory>, pub policy_factory: Arc<PolicyFactory>,
pub graphql_schema: mas_graphql::Schema, pub graphql_schema: mas_graphql::Schema,
pub http_client_factory: HttpClientFactory, pub http_client_factory: HttpClientFactory,
@@ -148,11 +148,9 @@ impl TestState {
let password_manager = PasswordManager::new([(1, Hasher::argon2id(None))])?; let password_manager = PasswordManager::new([(1, Hasher::argon2id(None))])?;
let homeserver = MatrixHomeserver::new("example.com".to_owned());
let policy_factory = policy_factory(serde_json::json!({})).await?; let policy_factory = policy_factory(serde_json::json!({})).await?;
let homeserver_connection = MockHomeserverConnection::new("example.com"); let homeserver_connection = Arc::new(MockHomeserverConnection::new("example.com"));
let http_client_factory = HttpClientFactory::new().await?; let http_client_factory = HttpClientFactory::new().await?;
@@ -167,7 +165,7 @@ impl TestState {
let graphql_state = TestGraphQLState { let graphql_state = TestGraphQLState {
pool: pool.clone(), pool: pool.clone(),
policy_factory: Arc::clone(&policy_factory), policy_factory: Arc::clone(&policy_factory),
homeserver_connection, homeserver_connection: Arc::clone(&homeserver_connection),
rng: Arc::clone(&rng), rng: Arc::clone(&rng),
clock: Arc::clone(&clock), clock: Arc::clone(&clock),
}; };
@@ -186,7 +184,7 @@ impl TestState {
metadata_cache, metadata_cache,
encrypter, encrypter,
url_builder, url_builder,
homeserver, homeserver_connection,
policy_factory, policy_factory,
graphql_schema, graphql_schema,
http_client_factory, http_client_factory,
@@ -281,7 +279,7 @@ impl TestState {
struct TestGraphQLState { struct TestGraphQLState {
pool: PgPool, pool: PgPool,
homeserver_connection: MockHomeserverConnection, homeserver_connection: Arc<MockHomeserverConnection>,
policy_factory: Arc<PolicyFactory>, policy_factory: Arc<PolicyFactory>,
clock: Arc<MockClock>, clock: Arc<MockClock>,
rng: Arc<Mutex<ChaChaRng>>, rng: Arc<Mutex<ChaChaRng>>,
@@ -360,12 +358,6 @@ impl FromRef<TestState> for UrlBuilder {
} }
} }
impl FromRef<TestState> for MatrixHomeserver {
fn from_ref(input: &TestState) -> Self {
input.homeserver.clone()
}
}
impl FromRef<TestState> for HttpClientFactory { impl FromRef<TestState> for HttpClientFactory {
fn from_ref(input: &TestState) -> Self { fn from_ref(input: &TestState) -> Self {
input.http_client_factory.clone() input.http_client_factory.clone()
@@ -396,6 +388,12 @@ impl FromRef<TestState> for SiteConfig {
} }
} }
impl FromRef<TestState> for BoxHomeserverConnection {
fn from_ref(input: &TestState) -> Self {
Box::new(input.homeserver_connection.clone())
}
}
#[async_trait] #[async_trait]
impl FromRequestParts<TestState> for ActivityTracker { impl FromRequestParts<TestState> for ActivityTracker {
type Rejection = Infallible; type Rejection = Infallible;

View File

@@ -22,6 +22,7 @@ use url::Url;
static SYNAPSE_AUTH_PROVIDER: &str = "oauth-delegated"; static SYNAPSE_AUTH_PROVIDER: &str = "oauth-delegated";
#[derive(Clone)]
pub struct SynapseConnection { pub struct SynapseConnection {
homeserver: String, homeserver: String,
endpoint: Url, endpoint: Url,

View File

@@ -14,8 +14,14 @@
mod mock; mod mock;
use std::sync::Arc;
pub use self::mock::HomeserverConnection as MockHomeserverConnection; pub use self::mock::HomeserverConnection as MockHomeserverConnection;
// TODO: this should probably be another error type by default
pub type BoxHomeserverConnection<Error = anyhow::Error> =
Box<dyn HomeserverConnection<Error = Error>>;
#[derive(Debug)] #[derive(Debug)]
pub struct MatrixUser { pub struct MatrixUser {
pub displayname: Option<String>, pub displayname: Option<String>,
@@ -351,3 +357,49 @@ impl<T: HomeserverConnection + Send + Sync + ?Sized> HomeserverConnection for &T
(**self).allow_cross_signing_reset(mxid).await (**self).allow_cross_signing_reset(mxid).await
} }
} }
// Implement for Arc<T> where T: HomeserverConnection
#[async_trait::async_trait]
impl<T: HomeserverConnection + ?Sized> HomeserverConnection for Arc<T> {
type Error = T::Error;
fn homeserver(&self) -> &str {
(**self).homeserver()
}
async fn query_user(&self, mxid: &str) -> Result<MatrixUser, Self::Error> {
(**self).query_user(mxid).await
}
async fn provision_user(&self, request: &ProvisionRequest) -> Result<bool, Self::Error> {
(**self).provision_user(request).await
}
async fn is_localpart_available(&self, localpart: &str) -> Result<bool, Self::Error> {
(**self).is_localpart_available(localpart).await
}
async fn create_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> {
(**self).create_device(mxid, device_id).await
}
async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> {
(**self).delete_device(mxid, device_id).await
}
async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), Self::Error> {
(**self).delete_user(mxid, erase).await
}
async fn set_displayname(&self, mxid: &str, displayname: &str) -> Result<(), Self::Error> {
(**self).set_displayname(mxid, displayname).await
}
async fn unset_displayname(&self, mxid: &str) -> Result<(), Self::Error> {
(**self).unset_displayname(mxid).await
}
async fn allow_cross_signing_reset(&self, mxid: &str) -> Result<(), Self::Error> {
(**self).allow_cross_signing_reset(mxid).await
}
}