1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-09 04:22:45 +03:00

Refactor the matrix connection logic

Also make the display name available through the graphql api
This commit is contained in:
Quentin Gliech
2023-06-16 11:31:01 +02:00
parent 2a514cf452
commit 4181cbc9d5
25 changed files with 763 additions and 231 deletions

33
Cargo.lock generated
View File

@@ -3145,6 +3145,8 @@ dependencies = [
"mas-http",
"mas-iana",
"mas-listener",
"mas-matrix",
"mas-matrix-synapse",
"mas-policy",
"mas-router",
"mas-spa",
@@ -3259,11 +3261,13 @@ dependencies = [
"chrono",
"lettre",
"mas-data-model",
"mas-matrix",
"mas-storage",
"oauth2-types",
"serde",
"thiserror",
"tokio",
"tower",
"tracing",
"ulid",
"url",
@@ -3295,6 +3299,7 @@ dependencies = [
"mas-iana",
"mas-jose",
"mas-keystore",
"mas-matrix",
"mas-oidc-client",
"mas-policy",
"mas-router",
@@ -3465,6 +3470,31 @@ dependencies = [
"tracing-subscriber",
]
[[package]]
name = "mas-matrix"
version = "0.1.0"
dependencies = [
"async-trait",
"http",
"serde",
"url",
]
[[package]]
name = "mas-matrix-synapse"
version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"http",
"mas-axum-utils",
"mas-http",
"mas-matrix",
"serde",
"tower",
"url",
]
[[package]]
name = "mas-oidc-client"
version = "0.1.0"
@@ -3605,10 +3635,9 @@ dependencies = [
"apalis-sql",
"async-trait",
"chrono",
"mas-axum-utils",
"mas-data-model",
"mas-email",
"mas-http",
"mas-matrix",
"mas-storage",
"mas-storage-pg",
"mas-tower",

View File

@@ -13,6 +13,7 @@ axum = "0.6.18"
camino = "1.1.4"
clap = { version = "4.3.3", features = ["derive"] }
dotenv = "0.15.0"
httpdate = "1.0.2"
hyper = { version = "0.14.26", features = ["full"] }
itertools = "0.10.5"
listenfd = "1.0.1"
@@ -52,6 +53,8 @@ mas-handlers = { path = "../handlers", default-features = false }
mas-http = { path = "../http", default-features = false, features = ["axum", "client"] }
mas-iana = { path = "../iana" }
mas-listener = { path = "../listener" }
mas-matrix = { path = "../matrix" }
mas-matrix-synapse = { path = "../matrix-synapse" }
mas-policy = { path = "../policy" }
mas-router = { path = "../router" }
mas-spa = { path = "../spa" }
@@ -61,7 +64,6 @@ mas-tasks = { path = "../tasks" }
mas-templates = { path = "../templates" }
mas-tower = { path = "../tower" }
oauth2-types = { path = "../oauth2-types" }
httpdate = "1.0.2"
[dev-dependencies]
indoc = "2.0.1"

View File

@@ -20,9 +20,9 @@ use itertools::Itertools;
use mas_config::RootConfig;
use mas_handlers::{AppState, HttpClientFactory, MatrixHomeserver};
use mas_listener::{server::Server, shutdown::ShutdownStream};
use mas_matrix_synapse::SynapseConnection;
use mas_router::UrlBuilder;
use mas_storage_pg::MIGRATOR;
use mas_tasks::HomeserverConnection;
use rand::{
distributions::{Alphanumeric, DistString},
thread_rng,
@@ -96,14 +96,17 @@ impl Options {
let mut rng = thread_rng();
let worker_name = Alphanumeric.sample_string(&mut rng, 10);
info!(worker_name, "Starting task worker");
// Maximum 50 outgoing HTTP requests at a time
let http_client_factory = HttpClientFactory::new(50);
let conn = HomeserverConnection::new(
info!(worker_name, "Starting task worker");
let conn = SynapseConnection::new(
config.matrix.homeserver.clone(),
config.matrix.endpoint.clone(),
config.matrix.secret.clone(),
http_client_factory,
);
let monitor = mas_tasks::init(&worker_name, &pool, &mailer, conn, &http_client_factory);
let monitor = mas_tasks::init(&worker_name, &pool, &mailer, conn);
// TODO: grab the handle
tokio::spawn(monitor.run());
}
@@ -114,7 +117,17 @@ impl Options {
let password_manager = password_manager_from_config(&config.passwords).await?;
// Explicitely the config to properly zeroize secret keys
// Maximum 50 outgoing HTTP requests at a time
let http_client_factory = HttpClientFactory::new(50);
let conn = SynapseConnection::new(
config.matrix.homeserver.clone(),
config.matrix.endpoint.clone(),
config.matrix.secret.clone(),
http_client_factory.clone(),
);
// Explicitly the config to properly zeroize secret keys
drop(config);
// Watch for changes in templates if the --watch flag is present
@@ -122,10 +135,7 @@ impl Options {
watch_templates(&templates).await?;
}
let graphql_schema = mas_handlers::graphql_schema(&pool);
// Maximum 50 outgoing HTTP requests at a time
let http_client_factory = HttpClientFactory::new(50);
let graphql_schema = mas_handlers::graphql_schema(&pool, conn);
let state = AppState {
pool,

View File

@@ -15,8 +15,8 @@
use clap::Parser;
use mas_config::RootConfig;
use mas_handlers::HttpClientFactory;
use mas_matrix_synapse::SynapseConnection;
use mas_router::UrlBuilder;
use mas_tasks::HomeserverConnection;
use rand::{
distributions::{Alphanumeric, DistString},
thread_rng,
@@ -46,10 +46,11 @@ impl Options {
mailer.test_connection().await?;
let http_client_factory = HttpClientFactory::new(50);
let conn = HomeserverConnection::new(
let conn = SynapseConnection::new(
config.matrix.homeserver.clone(),
config.matrix.endpoint.clone(),
config.matrix.secret.clone(),
http_client_factory,
);
drop(config);
@@ -59,7 +60,7 @@ impl Options {
let worker_name = Alphanumeric.sample_string(&mut rng, 10);
info!(worker_name, "Starting task scheduler");
let monitor = mas_tasks::init(&worker_name, &pool, &mailer, conn, &http_client_factory);
let monitor = mas_tasks::init(&worker_name, &pool, &mailer, conn);
span.exit();

View File

@@ -15,11 +15,13 @@ serde = { version = "1.0.164", features = ["derive"] }
thiserror = "1.0.40"
tokio = { version = "1.28.2", features = ["sync"] }
tracing = "0.1.37"
tower = { version = "0.4.13", features = ["util"] }
ulid = "1.0.0"
url = "2.4.0"
oauth2-types = { path = "../oauth2-types" }
mas-data-model = { path = "../data-model" }
mas-matrix = { path = "../matrix" }
mas-storage = { path = "../storage" }
[[bin]]

View File

@@ -0,0 +1,45 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_graphql::SimpleObject;
use mas_matrix::HomeserverConnection;
#[derive(SimpleObject)]
pub struct MatrixUser {
/// The Matrix ID of the user.
mxid: String,
/// The display name of the user, if any.
display_name: Option<String>,
/// The avatar URL of the user, if any.
avatar_url: Option<String>,
}
impl MatrixUser {
pub(crate) async fn load<C: HomeserverConnection + ?Sized>(
conn: &C,
user: &str,
) -> Result<MatrixUser, C::Error> {
let mxid = conn.mxid(user);
let info = conn.query_user(&mxid).await?;
Ok(MatrixUser {
mxid,
display_name: info.displayname,
avatar_url: info.avatar_url,
})
}
}

View File

@@ -18,6 +18,7 @@ use chrono::{DateTime, Utc};
mod browser_sessions;
mod compat_sessions;
mod cursor;
mod matrix;
mod node;
mod oauth;
mod upstream_oauth;

View File

@@ -29,7 +29,7 @@ use super::{
compat_sessions::CompatSsoLogin, BrowserSession, Cursor, NodeCursor, NodeType, OAuth2Session,
UpstreamOAuth2Link,
};
use crate::state::ContextExt;
use crate::{model::matrix::MatrixUser, state::ContextExt};
#[derive(Description)]
/// A user is an individual's account.
@@ -59,6 +59,13 @@ impl User {
&self.0.username
}
/// Access to the user's Matrix account information.
async fn matrix(&self, ctx: &Context<'_>) -> Result<MatrixUser, async_graphql::Error> {
let state = ctx.state();
let conn = state.homeserver_connection();
Ok(MatrixUser::load(conn, &self.0.username).await?)
}
/// Primary email address of the user.
async fn primary_email(
&self,

View File

@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use mas_matrix::HomeserverConnection;
use mas_storage::{BoxClock, BoxRepository, BoxRng, RepositoryError};
use crate::Requester;
@@ -19,6 +20,7 @@ use crate::Requester;
#[async_trait::async_trait]
pub trait State {
async fn repository(&self) -> Result<BoxRepository, RepositoryError>;
fn homeserver_connection(&self) -> &dyn HomeserverConnection<Error = anyhow::Error>;
fn clock(&self) -> BoxClock;
fn rng(&self) -> BoxRng;
}

View File

@@ -64,6 +64,7 @@ mas-http = { path = "../http", default-features = false }
mas-iana = { path = "../iana" }
mas-jose = { path = "../jose" }
mas-keystore = { path = "../keystore" }
mas-matrix = { path = "../matrix" }
mas-oidc-client = { path = "../oidc-client" }
mas-policy = { path = "../policy" }
mas-router = { path = "../router" }

View File

@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use async_graphql::{
extensions::{ApolloTracing, Tracing},
http::{playground_source, GraphQLPlaygroundConfig, MultipartOptions},
@@ -29,6 +31,7 @@ use hyper::header::CACHE_CONTROL;
use mas_axum_utils::{FancyError, SessionInfoExt};
use mas_graphql::{Requester, Schema};
use mas_keystore::Encrypter;
use mas_matrix::HomeserverConnection;
use mas_storage::{BoxClock, BoxRepository, BoxRng, Repository, RepositoryError, SystemClock};
use mas_storage_pg::PgRepository;
use rand::{thread_rng, SeedableRng};
@@ -38,6 +41,7 @@ use tracing::{info_span, Instrument};
struct GraphQLState {
pool: PgPool,
homeserver_connection: Arc<dyn HomeserverConnection<Error = anyhow::Error>>,
}
#[async_trait]
@@ -50,6 +54,10 @@ impl mas_graphql::State for GraphQLState {
Ok(repo.map_err(RepositoryError::from_error).boxed())
}
fn homeserver_connection(&self) -> &dyn HomeserverConnection<Error = anyhow::Error> {
self.homeserver_connection.as_ref()
}
fn clock(&self) -> BoxClock {
let clock = SystemClock::default();
Box::new(clock)
@@ -65,8 +73,14 @@ impl mas_graphql::State for GraphQLState {
}
#[must_use]
pub fn schema(pool: &PgPool) -> Schema {
let state = GraphQLState { pool: pool.clone() };
pub fn schema(
pool: &PgPool,
homeserver_connection: impl HomeserverConnection<Error = anyhow::Error> + 'static,
) -> Schema {
let state = GraphQLState {
pool: pool.clone(),
homeserver_connection: Arc::new(homeserver_connection),
};
let state: mas_graphql::BoxState = Box::new(state);
mas_graphql::schema_builder()

View File

@@ -23,6 +23,7 @@ use headers::{Authorization, ContentType, HeaderMapExt, HeaderName};
use hyper::{header::CONTENT_TYPE, Request, Response, StatusCode};
use mas_axum_utils::http_client_factory::HttpClientFactory;
use mas_keystore::{Encrypter, JsonWebKey, JsonWebKeySet, Keystore, PrivateKey};
use mas_matrix::{HomeserverConnection, MatrixUser, ProvisionRequest};
use mas_policy::PolicyFactory;
use mas_router::{SimpleRoute, UrlBuilder};
use mas_storage::{clock::MockClock, BoxClock, BoxRepository, BoxRng, Repository};
@@ -68,6 +69,40 @@ pub(crate) struct TestState {
pub rng: Arc<Mutex<ChaChaRng>>,
}
/// A Mock implementation of a [`HomeserverConnection`], which never fails and
/// doesn't do anything.
struct MockHomeserverConnection {
homeserver: String,
}
#[async_trait]
impl HomeserverConnection for MockHomeserverConnection {
type Error = anyhow::Error;
fn homeserver(&self) -> &str {
&self.homeserver
}
async fn query_user(&self, _mxid: &str) -> Result<MatrixUser, Self::Error> {
Ok(MatrixUser {
displayname: None,
avatar_url: None,
})
}
async fn provision_user(&self, _request: &ProvisionRequest) -> Result<bool, Self::Error> {
Ok(false)
}
async fn create_device(&self, _mxid: &str, _device_id: &str) -> Result<(), Self::Error> {
Ok(())
}
async fn delete_device(&self, _mxid: &str, _device_id: &str) -> Result<(), Self::Error> {
Ok(())
}
}
impl TestState {
/// Create a new test state from the given database pool
pub async fn from_pool(pool: PgPool) -> Result<Self, anyhow::Error> {
@@ -106,9 +141,13 @@ impl TestState {
)
.await?;
let homeserver_connection = MockHomeserverConnection {
homeserver: "example.com".to_owned(),
};
let policy_factory = Arc::new(policy_factory);
let graphql_schema = graphql_schema(&pool);
let graphql_schema = graphql_schema(&pool, homeserver_connection);
let http_client_factory = HttpClientFactory::new(10);

View File

@@ -0,0 +1,18 @@
[package]
name = "mas-matrix-synapse"
version = "0.1.0"
authors = ["Quentin Gliech <quenting@element.io>"]
edition = "2021"
license = "Apache-2.0"
[dependencies]
anyhow = "1.0.71"
async-trait = "0.1.68"
http = "0.2.9"
url = "2.4.0"
serde = { version = "1.0.164", features = ["derive"] }
tower = { version = "0.4.13", features = ["util"] }
mas-axum-utils = { path = "../axum-utils" }
mas-http = { path = "../http" }
mas-matrix = { path = "../matrix" }

View File

@@ -0,0 +1,256 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#![forbid(unsafe_code)]
#![deny(clippy::all, clippy::str_to_string, rustdoc::broken_intra_doc_links)]
#![warn(clippy::pedantic)]
use http::{header::AUTHORIZATION, request::Builder, Method, Request, StatusCode};
use mas_axum_utils::http_client_factory::HttpClientFactory;
use mas_http::{EmptyBody, HttpServiceExt};
use mas_matrix::{HomeserverConnection, MatrixUser, ProvisionRequest};
use serde::{Deserialize, Serialize};
use tower::{Service, ServiceExt};
use url::Url;
static SYNAPSE_AUTH_PROVIDER: &str = "oauth-delegated";
pub struct SynapseConnection {
homeserver: String,
endpoint: Url,
access_token: String,
http_client_factory: HttpClientFactory,
}
impl SynapseConnection {
#[must_use]
pub fn new(
homeserver: String,
endpoint: Url,
access_token: String,
http_client_factory: HttpClientFactory,
) -> Self {
Self {
homeserver,
endpoint,
access_token,
http_client_factory,
}
}
fn builder(&self, url: &str) -> Builder {
Request::builder()
.uri(
self.endpoint
.join(url)
.map(Url::into)
.unwrap_or(String::new()),
)
.header(AUTHORIZATION, format!("Bearer {}", self.access_token))
}
#[must_use]
pub fn post(&self, url: &str) -> Builder {
self.builder(url).method(Method::POST)
}
#[must_use]
pub fn get(&self, url: &str) -> Builder {
self.builder(url).method(Method::GET)
}
#[must_use]
pub fn put(&self, url: &str) -> Builder {
self.builder(url).method(Method::PUT)
}
#[must_use]
pub fn delete(&self, url: &str) -> Builder {
self.builder(url).method(Method::DELETE)
}
}
#[derive(Serialize, Deserialize)]
struct ExternalID {
auth_provider: String,
external_id: String,
}
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
enum ThreePIDMedium {
Email,
Msisdn,
}
#[derive(Serialize, Deserialize)]
struct ThreePID {
medium: ThreePIDMedium,
address: String,
}
#[derive(Default, Serialize, Deserialize)]
struct SynapseUser {
#[serde(
default,
rename = "displayname",
skip_serializing_if = "Option::is_none"
)]
display_name: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
avatar_url: Option<String>,
#[serde(default, rename = "threepids", skip_serializing_if = "Option::is_none")]
three_pids: Option<Vec<ThreePID>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
external_ids: Option<Vec<ExternalID>>,
}
#[derive(Serialize, Deserialize)]
struct SynapseDevice {
device_id: String,
}
#[async_trait::async_trait]
impl HomeserverConnection for SynapseConnection {
type Error = anyhow::Error;
fn homeserver(&self) -> &str {
&self.homeserver
}
async fn query_user(&self, mxid: &str) -> Result<MatrixUser, Self::Error> {
let mut client = self
.http_client_factory
.client()
.await?
.response_body_to_bytes()
.json_response();
let request = self
.get(&format!("_synapse/admin/v2/users/{mxid}"))
.body(EmptyBody::new())?;
let response = client.ready().await?.call(request).await?;
if response.status() != StatusCode::OK {
return Err(anyhow::anyhow!("Failed to query user from Synapse"));
}
let body: SynapseUser = response.into_body();
Ok(MatrixUser {
displayname: body.display_name,
avatar_url: body.avatar_url,
})
}
async fn provision_user(&self, request: &ProvisionRequest) -> Result<bool, Self::Error> {
let mut body = SynapseUser {
external_ids: Some(vec![ExternalID {
auth_provider: SYNAPSE_AUTH_PROVIDER.to_owned(),
external_id: request.sub().to_owned(),
}]),
..SynapseUser::default()
};
request
.on_displayname(|displayname| {
body.display_name = Some(displayname.unwrap_or_default().to_owned());
})
.on_avatar_url(|avatar_url| {
body.avatar_url = Some(avatar_url.unwrap_or_default().to_owned());
})
.on_emails(|emails| {
body.three_pids = Some(
emails
.unwrap_or_default()
.iter()
.map(|email| ThreePID {
medium: ThreePIDMedium::Email,
address: email.clone(),
})
.collect(),
);
});
let mut client = self
.http_client_factory
.client()
.await?
.request_bytes_to_body()
.json_request();
let request = self
.put(&format!(
"_synapse/admin/v2/users/{mxid}",
mxid = request.mxid()
))
.body(body)?;
let response = client.ready().await?.call(request).await?;
match response.status() {
StatusCode::CREATED => Ok(true),
StatusCode::OK => Ok(false),
code => Err(anyhow::anyhow!(
"Failed to provision user in Synapse: {}",
code
)),
}
}
async fn create_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> {
let mut client = self
.http_client_factory
.client()
.await?
.request_bytes_to_body()
.json_request();
let request = self
.post(&format!("_synapse/admin/v2/users/{mxid}/devices"))
.body(SynapseDevice {
device_id: device_id.to_owned(),
})?;
let response = client.ready().await?.call(request).await?;
if response.status() != StatusCode::CREATED {
return Err(anyhow::anyhow!("Failed to create device in Synapse"));
}
Ok(())
}
async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> {
let mut client = self.http_client_factory.client().await?;
let request = self
.delete(&format!(
"_synapse/admin/v2/users/{mxid}/devices/{device_id}"
))
.body(EmptyBody::new())?;
let response = client.ready().await?.call(request).await?;
if response.status() != StatusCode::OK {
return Err(anyhow::anyhow!("Failed to delete device in Synapse"));
}
Ok(())
}
}

12
crates/matrix/Cargo.toml Normal file
View File

@@ -0,0 +1,12 @@
[package]
name = "mas-matrix"
version = "0.1.0"
authors = ["Quentin Gliech <quenting@element.io>"]
edition = "2021"
license = "Apache-2.0"
[dependencies]
serde = { version = "1.0.164", features = ["derive"] }
async-trait = "0.1.68"
http = "0.2.9"
url = "2.4.0"

168
crates/matrix/src/lib.rs Normal file
View File

@@ -0,0 +1,168 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#![forbid(unsafe_code)]
#![deny(clippy::all, clippy::str_to_string, rustdoc::broken_intra_doc_links)]
#![warn(clippy::pedantic)]
#[derive(Debug)]
pub struct MatrixUser {
pub displayname: Option<String>,
pub avatar_url: Option<String>,
}
#[derive(Debug, Default)]
enum FieldAction<T> {
#[default]
DoNothing,
Set(T),
Unset,
}
pub struct ProvisionRequest {
mxid: String,
sub: String,
displayname: FieldAction<String>,
avatar_url: FieldAction<String>,
emails: FieldAction<Vec<String>>,
}
impl ProvisionRequest {
#[must_use]
pub fn new(mxid: String, sub: String) -> Self {
Self {
mxid,
sub,
displayname: FieldAction::DoNothing,
avatar_url: FieldAction::DoNothing,
emails: FieldAction::DoNothing,
}
}
#[must_use]
pub fn sub(&self) -> &str {
&self.sub
}
#[must_use]
pub fn mxid(&self) -> &str {
&self.mxid
}
#[must_use]
pub fn set_displayname(mut self, displayname: String) -> Self {
self.displayname = FieldAction::Set(displayname);
self
}
#[must_use]
pub fn unset_displayname(mut self) -> Self {
self.displayname = FieldAction::Unset;
self
}
pub fn on_displayname(&self, callback: impl FnOnce(Option<&str>)) -> &Self {
match &self.displayname {
FieldAction::DoNothing => callback(None),
FieldAction::Set(displayname) => callback(Some(displayname)),
FieldAction::Unset => {}
}
self
}
#[must_use]
pub fn set_avatar_url(mut self, avatar_url: String) -> Self {
self.avatar_url = FieldAction::Set(avatar_url);
self
}
#[must_use]
pub fn unset_avatar_url(mut self) -> Self {
self.avatar_url = FieldAction::Unset;
self
}
pub fn on_avatar_url(&self, callback: impl FnOnce(Option<&str>)) -> &Self {
match &self.avatar_url {
FieldAction::DoNothing => callback(None),
FieldAction::Set(avatar_url) => callback(Some(avatar_url)),
FieldAction::Unset => {}
}
self
}
#[must_use]
pub fn set_emails(mut self, emails: Vec<String>) -> Self {
self.emails = FieldAction::Set(emails);
self
}
#[must_use]
pub fn unset_emails(mut self) -> Self {
self.emails = FieldAction::Unset;
self
}
pub fn on_emails(&self, callback: impl FnOnce(Option<&[String]>)) -> &Self {
match &self.emails {
FieldAction::DoNothing => callback(None),
FieldAction::Set(emails) => callback(Some(emails)),
FieldAction::Unset => {}
}
self
}
}
#[async_trait::async_trait]
pub trait HomeserverConnection: Send + Sync {
type Error;
fn homeserver(&self) -> &str;
fn mxid(&self, localpart: &str) -> String {
format!("@{}:{}", localpart, self.homeserver())
}
async fn query_user(&self, mxid: &str) -> Result<MatrixUser, Self::Error>;
async fn provision_user(&self, request: &ProvisionRequest) -> Result<bool, Self::Error>;
async fn create_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error>;
async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error>;
}
#[async_trait::async_trait]
impl<T: HomeserverConnection + Send + Sync + ?Sized> HomeserverConnection for &T {
type Error = T::Error;
fn homeserver(&self) -> &str {
(**self).homeserver()
}
async fn query_user(&self, mxid: &str) -> Result<MatrixUser, Self::Error> {
(**self).query_user(mxid).await
}
async fn provision_user(&self, request: &ProvisionRequest) -> Result<bool, Self::Error> {
(**self).provision_user(request).await
}
async fn create_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> {
(**self).create_device(mxid, device_id).await
}
async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> {
(**self).delete_device(mxid, device_id).await
}
}

View File

@@ -24,10 +24,9 @@ ulid = "1.0.0"
url = "2.4.0"
serde = { version = "1.0.164", features = ["derive"] }
mas-axum-utils = { path = "../axum-utils" }
mas-data-model = { path = "../data-model" }
mas-email = { path = "../email" }
mas-http = { path = "../http" }
mas-matrix = { path = "../matrix" }
mas-storage = { path = "../storage" }
mas-storage-pg = { path = "../storage-pg" }
mas-tower = { path = "../tower" }

View File

@@ -20,9 +20,8 @@ use std::sync::Arc;
use apalis_core::{executor::TokioExecutor, layers::extensions::Extension, monitor::Monitor};
use apalis_sql::postgres::PostgresStorage;
use mas_axum_utils::http_client_factory::HttpClientFactory;
use mas_email::Mailer;
use mas_http::{ClientInitError, ClientService, TracedClient};
use mas_matrix::HomeserverConnection;
use mas_storage::{BoxClock, BoxRepository, Repository, SystemClock};
use mas_storage_pg::{DatabaseError, PgRepository};
use rand::SeedableRng;
@@ -34,15 +33,12 @@ mod email;
mod matrix;
mod utils;
pub use self::matrix::HomeserverConnection;
#[derive(Clone)]
struct State {
pool: Pool<Postgres>,
mailer: Mailer,
clock: SystemClock,
homeserver: Arc<HomeserverConnection>,
http_client_factory: HttpClientFactory,
homeserver: Arc<dyn HomeserverConnection<Error = anyhow::Error>>,
}
impl State {
@@ -50,15 +46,13 @@ impl State {
pool: Pool<Postgres>,
clock: SystemClock,
mailer: Mailer,
homeserver: HomeserverConnection,
http_client_factory: HttpClientFactory,
homeserver: impl HomeserverConnection<Error = anyhow::Error> + 'static,
) -> Self {
Self {
pool,
mailer,
clock,
homeserver: Arc::new(homeserver),
http_client_factory,
}
}
@@ -97,16 +91,8 @@ impl State {
Ok(repo)
}
pub fn matrix_connection(&self) -> &HomeserverConnection {
&self.homeserver
}
pub async fn http_client<B>(&self) -> Result<ClientService<TracedClient<B>>, ClientInitError>
where
B: mas_axum_utils::axum::body::HttpBody + Send,
B::Data: Send,
{
self.http_client_factory.client().await
pub fn matrix_connection(&self) -> &dyn HomeserverConnection<Error = anyhow::Error> {
self.homeserver.as_ref()
}
}
@@ -127,15 +113,13 @@ pub fn init(
name: &str,
pool: &Pool<Postgres>,
mailer: &Mailer,
homeserver: HomeserverConnection,
http_client_factory: &HttpClientFactory,
homeserver: impl HomeserverConnection<Error = anyhow::Error> + 'static,
) -> Monitor<TokioExecutor> {
let state = State::new(
pool.clone(),
SystemClock::default(),
mailer.clone(),
homeserver,
http_client_factory.clone(),
);
let monitor = Monitor::new().executor(TokioExecutor::new());
let monitor = self::database::register(name, monitor, &state);

View File

@@ -21,73 +21,19 @@ use apalis_core::{
monitor::Monitor,
storage::builder::WithStorage,
};
use mas_axum_utils::axum::{
headers::{Authorization, HeaderMapExt},
http::{Request, StatusCode},
};
use mas_http::{EmptyBody, HttpServiceExt};
use mas_matrix::ProvisionRequest;
use mas_storage::{
job::{DeleteDeviceJob, JobWithSpanContext, ProvisionDeviceJob, ProvisionUserJob},
user::{UserEmailRepository, UserRepository},
RepositoryAccess,
};
use serde::{Deserialize, Serialize};
use tower::{Service, ServiceExt};
use tracing::{info, info_span, Instrument};
use url::Url;
use tracing::info;
use crate::{
utils::{metrics_layer, trace_layer},
JobContextExt, State,
};
pub struct HomeserverConnection {
homeserver: String,
endpoint: Url,
access_token: String,
}
impl HomeserverConnection {
#[must_use]
pub fn new(homeserver: String, endpoint: Url, access_token: String) -> Self {
Self {
homeserver,
endpoint,
access_token,
}
}
}
#[derive(Serialize, Deserialize)]
struct ExternalID {
pub auth_provider: String,
pub external_id: String,
}
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
enum ThreePIDMedium {
Email,
Msisdn,
}
#[derive(Serialize, Deserialize)]
struct ThreePID {
pub medium: ThreePIDMedium,
pub address: String,
}
#[derive(Serialize, Deserialize)]
struct UserRequest {
#[serde(rename = "displayname")]
pub display_name: String,
#[serde(rename = "threepids")]
pub three_pids: Vec<ThreePID>,
pub external_ids: Vec<ExternalID>,
}
/// Job to provision a user on the Matrix homeserver.
/// This works by doing a PUT request to the /_synapse/admin/v2/users/{user_id}
/// endpoint.
@@ -103,11 +49,6 @@ async fn provision_user(
) -> Result<(), anyhow::Error> {
let state = ctx.state();
let matrix = state.matrix_connection();
let mut client = state
.http_client()
.await?
.request_bytes_to_body()
.json_request();
let mut repo = state.repository().await?;
let user = repo
@@ -116,73 +57,30 @@ async fn provision_user(
.await?
.context("User not found")?;
// XXX: there is a lot that could go wrong in terms of encoding here
let mxid = format!(
"@{localpart}:{homeserver}",
localpart = user.username,
homeserver = matrix.homeserver
);
let three_pids = repo
let mxid = matrix.mxid(&user.username);
let emails = repo
.user_email()
.all(&user)
.await?
.into_iter()
.filter_map(|email| {
if email.confirmed_at.is_some() {
Some(ThreePID {
medium: ThreePIDMedium::Email,
address: email.email,
})
} else {
None
}
})
.filter(|email| email.confirmed_at.is_some())
.map(|email| email.email)
.collect();
let display_name = user.username.clone();
let body = UserRequest {
display_name,
three_pids,
external_ids: vec![ExternalID {
auth_provider: "oauth-delegated".to_owned(),
external_id: user.sub,
}],
};
repo.cancel().await?;
let path = format!("_synapse/admin/v2/users/{mxid}",);
let mut req = Request::put(matrix.endpoint.join(&path)?.as_str());
req.headers_mut()
.context("Failed to get headers")?
.typed_insert(Authorization::bearer(&matrix.access_token)?);
let request = ProvisionRequest::new(mxid.clone(), user.sub.clone()).set_emails(emails);
let created = matrix.provision_user(&request).await?;
let req = req.body(body).context("Failed to build request")?;
let response = client
.ready()
.await?
.call(req)
.instrument(info_span!("matrix.provision_user"))
.await?;
match response.status() {
StatusCode::CREATED => info!(%user.id, %mxid, "User created"),
StatusCode::OK => info!(%user.id, %mxid, "User updated"),
// TODO: Better error handling
code => anyhow::bail!("Failed to provision user. Status code: {code}"),
if created {
info!(%user.id, %mxid, "User created");
} else {
info!(%user.id, %mxid, "User updated");
}
Ok(())
}
#[derive(Serialize, Deserialize)]
struct DeviceRequest {
device_id: String,
}
/// Job to provision a device on the Matrix homeserver.
/// This works by doing a POST request to the
/// /_synapse/admin/v2/users/{user_id}/devices endpoint.
@@ -201,11 +99,6 @@ async fn provision_device(
) -> Result<(), anyhow::Error> {
let state = ctx.state();
let matrix = state.matrix_connection();
let mut client = state
.http_client()
.await?
.request_bytes_to_body()
.json_request();
let mut repo = state.repository().await?;
let user = repo
@@ -214,38 +107,10 @@ async fn provision_device(
.await?
.context("User not found")?;
// XXX: there is a lot that could go wrong in terms of encoding here
let mxid = format!(
"@{localpart}:{homeserver}",
localpart = user.username,
homeserver = matrix.homeserver
);
let mxid = matrix.mxid(&user.username);
let path = format!("_synapse/admin/v2/users/{mxid}/devices");
let mut req = Request::post(matrix.endpoint.join(&path)?.as_str());
req.headers_mut()
.context("Failed to get headers")?
.typed_insert(Authorization::bearer(&matrix.access_token)?);
let req = req
.body(DeviceRequest {
device_id: job.device_id().to_owned(),
})
.context("Failed to build request")?;
let response = client
.ready()
.await?
.call(req)
.instrument(info_span!("matrix.create_device"))
.await?;
match response.status() {
StatusCode::CREATED => {
info!(%user.id, %mxid, device.id = job.device_id(), "Device created");
}
code => anyhow::bail!("Failed to provision device. Status code: {code}"),
}
matrix.create_device(&mxid, job.device_id()).await?;
info!(%user.id, %mxid, device.id = job.device_id(), "Device created");
Ok(())
}
@@ -268,7 +133,6 @@ async fn delete_device(
) -> Result<(), anyhow::Error> {
let state = ctx.state();
let matrix = state.matrix_connection();
let mut client = state.http_client().await?;
let mut repo = state.repository().await?;
let user = repo
@@ -277,37 +141,10 @@ async fn delete_device(
.await?
.context("User not found")?;
// XXX: there is a lot that could go wrong in terms of encoding here
let mxid = format!(
"@{localpart}:{homeserver}",
localpart = user.username,
homeserver = matrix.homeserver
);
let mxid = matrix.mxid(&user.username);
let path = format!(
"_synapse/admin/v2/users/{mxid}/devices/{device_id}",
device_id = job.device_id()
);
let mut req = Request::delete(matrix.endpoint.join(&path)?.as_str());
req.headers_mut()
.context("Failed to get headers")?
.typed_insert(Authorization::bearer(&matrix.access_token)?);
let req = req
.body(EmptyBody::new())
.context("Failed to build request")?;
let response = client
.ready()
.await?
.call(req)
.instrument(info_span!("matrix.delete_device"))
.await?;
match response.status() {
StatusCode::OK => info!(%user.id, %mxid, "Device deleted"),
code => anyhow::bail!("Failed to delete device. Status code: {code}"),
};
matrix.delete_device(&mxid, job.device_id()).await?;
info!(%user.id, %mxid, device.id = job.device_id(), "Device deleted");
Ok(())
}

View File

@@ -97,7 +97,8 @@ module.exports = {
{
exceptions: {
// The '*Connection', '*Edge', '*Payload' and 'PageInfo' types don't have IDs
types: ["PageInfo"],
// XXX: Maybe the MatrixUser type should have an ID?
types: ["PageInfo", "MatrixUser"],
suffixes: ["Connection", "Edge", "Payload"],
},
},

View File

@@ -293,6 +293,21 @@ enum EndOAuth2SessionStatus {
NOT_FOUND
}
type MatrixUser {
"""
The Matrix ID of the user.
"""
mxid: String!
"""
The display name of the user, if any.
"""
displayName: String
"""
The avatar URL of the user, if any.
"""
avatarUrl: String
}
"""
The mutations root of the GraphQL interface.
"""
@@ -758,6 +773,10 @@ type User implements Node {
"""
username: String!
"""
Access to the user's Matrix account information.
"""
matrix: MatrixUser!
"""
Primary email address of the user.
"""
primaryEmail: UserEmail

View File

@@ -24,6 +24,10 @@ const QUERY = graphql(/* GraphQL */ `
user(id: $userId) {
id
username
matrix {
mxid
displayName
}
}
}
`);
@@ -44,9 +48,9 @@ const UserGreeting: React.FC<{ userId: string }> = ({ userId }) => {
return (
<header className="oidc_Header">
<Heading size="xl" weight="semibold">
John Doe
{result.data.user.matrix.displayName || result.data.user.username}
</Heading>
<Body size="lg">{result.data.user.username}</Body>
<Body size="lg">{result.data.user.matrix.mxid}</Body>
</header>
);
}

View File

@@ -51,7 +51,7 @@ const documents = {
types.UserEmailListQueryDocument,
"\n query UserPrimaryEmail($userId: ID!) {\n user(id: $userId) {\n id\n primaryEmail {\n id\n }\n }\n }\n":
types.UserPrimaryEmailDocument,
"\n query UserGreeting($userId: ID!) {\n user(id: $userId) {\n id\n username\n }\n }\n":
"\n query UserGreeting($userId: ID!) {\n user(id: $userId) {\n id\n username\n matrix {\n mxid\n displayName\n }\n }\n }\n":
types.UserGreetingDocument,
"\n query BrowserSessionQuery($id: ID!) {\n browserSession(id: $id) {\n id\n createdAt\n lastAuthentication {\n id\n createdAt\n }\n user {\n id\n username\n }\n }\n }\n":
types.BrowserSessionQueryDocument,
@@ -191,8 +191,8 @@ export function graphql(
* The graphql function is used to parse GraphQL queries into a document that can be used by GraphQL clients.
*/
export function graphql(
source: "\n query UserGreeting($userId: ID!) {\n user(id: $userId) {\n id\n username\n }\n }\n"
): typeof documents["\n query UserGreeting($userId: ID!) {\n user(id: $userId) {\n id\n username\n }\n }\n"];
source: "\n query UserGreeting($userId: ID!) {\n user(id: $userId) {\n id\n username\n matrix {\n mxid\n displayName\n }\n }\n }\n"
): typeof documents["\n query UserGreeting($userId: ID!) {\n user(id: $userId) {\n id\n username\n matrix {\n mxid\n displayName\n }\n }\n }\n"];
/**
* The graphql function is used to parse GraphQL queries into a document that can be used by GraphQL clients.
*/

View File

@@ -228,6 +228,16 @@ export enum EndOAuth2SessionStatus {
NotFound = "NOT_FOUND",
}
export type MatrixUser = {
__typename?: "MatrixUser";
/** The avatar URL of the user, if any. */
avatarUrl?: Maybe<Scalars["String"]["output"]>;
/** The display name of the user, if any. */
displayName?: Maybe<Scalars["String"]["output"]>;
/** The Matrix ID of the user. */
mxid: Scalars["String"]["output"];
};
/** The mutations root of the GraphQL interface. */
export type Mutation = {
__typename?: "Mutation";
@@ -591,6 +601,8 @@ export type User = Node & {
emails: UserEmailConnection;
/** ID of the object. */
id: Scalars["ID"]["output"];
/** Access to the user's Matrix account information. */
matrix: MatrixUser;
/** Get the list of OAuth 2.0 sessions, chronologically sorted */
oauth2Sessions: Oauth2SessionConnection;
/** Primary email address of the user. */
@@ -1088,7 +1100,16 @@ export type UserGreetingQueryVariables = Exact<{
export type UserGreetingQuery = {
__typename?: "Query";
user?: { __typename?: "User"; id: string; username: string } | null;
user?: {
__typename?: "User";
id: string;
username: string;
matrix: {
__typename?: "MatrixUser";
mxid: string;
displayName?: string | null;
};
} | null;
};
export type BrowserSessionQueryQueryVariables = Exact<{
@@ -3028,6 +3049,20 @@ export const UserGreetingDocument = {
selections: [
{ kind: "Field", name: { kind: "Name", value: "id" } },
{ kind: "Field", name: { kind: "Name", value: "username" } },
{
kind: "Field",
name: { kind: "Name", value: "matrix" },
selectionSet: {
kind: "SelectionSet",
selections: [
{ kind: "Field", name: { kind: "Name", value: "mxid" } },
{
kind: "Field",
name: { kind: "Name", value: "displayName" },
},
],
},
},
],
},
},

View File

@@ -576,6 +576,40 @@ export default {
],
interfaces: [],
},
{
kind: "OBJECT",
name: "MatrixUser",
fields: [
{
name: "avatarUrl",
type: {
kind: "SCALAR",
name: "Any",
},
args: [],
},
{
name: "displayName",
type: {
kind: "SCALAR",
name: "Any",
},
args: [],
},
{
name: "mxid",
type: {
kind: "NON_NULL",
ofType: {
kind: "SCALAR",
name: "Any",
},
},
args: [],
},
],
interfaces: [],
},
{
kind: "OBJECT",
name: "Mutation",
@@ -1887,6 +1921,18 @@ export default {
},
args: [],
},
{
name: "matrix",
type: {
kind: "NON_NULL",
ofType: {
kind: "OBJECT",
name: "MatrixUser",
ofType: null,
},
},
args: [],
},
{
name: "oauth2Sessions",
type: {