diff --git a/Cargo.lock b/Cargo.lock index 579b9c8b..fd42f90a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3362,9 +3362,11 @@ dependencies = [ name = "mas-matrix" version = "0.1.0" dependencies = [ + "anyhow", "async-trait", "http", "serde", + "tokio", "url", ] diff --git a/crates/handlers/src/test_utils.rs b/crates/handlers/src/test_utils.rs index edcdaebb..720d2dee 100644 --- a/crates/handlers/src/test_utils.rs +++ b/crates/handlers/src/test_utils.rs @@ -23,7 +23,7 @@ use headers::{Authorization, ContentType, HeaderMapExt, HeaderName}; use hyper::{header::CONTENT_TYPE, Request, Response, StatusCode}; use mas_axum_utils::http_client_factory::HttpClientFactory; use mas_keystore::{Encrypter, JsonWebKey, JsonWebKeySet, Keystore, PrivateKey}; -use mas_matrix::{HomeserverConnection, MatrixUser, ProvisionRequest}; +use mas_matrix::{HomeserverConnection, MatrixUser, MockHomeserverConnection, ProvisionRequest}; use mas_policy::PolicyFactory; use mas_router::{SimpleRoute, UrlBuilder}; use mas_storage::{clock::MockClock, BoxClock, BoxRepository, BoxRng, Repository}; @@ -69,40 +69,6 @@ pub(crate) struct TestState { pub rng: Arc>, } -/// A Mock implementation of a [`HomeserverConnection`], which never fails and -/// doesn't do anything. -struct MockHomeserverConnection { - homeserver: String, -} - -#[async_trait] -impl HomeserverConnection for MockHomeserverConnection { - type Error = anyhow::Error; - - fn homeserver(&self) -> &str { - &self.homeserver - } - - async fn query_user(&self, _mxid: &str) -> Result { - Ok(MatrixUser { - displayname: None, - avatar_url: None, - }) - } - - async fn provision_user(&self, _request: &ProvisionRequest) -> Result { - Ok(false) - } - - async fn create_device(&self, _mxid: &str, _device_id: &str) -> Result<(), Self::Error> { - Ok(()) - } - - async fn delete_device(&self, _mxid: &str, _device_id: &str) -> Result<(), Self::Error> { - Ok(()) - } -} - impl TestState { /// Create a new test state from the given database pool pub async fn from_pool(pool: PgPool) -> Result { @@ -145,9 +111,7 @@ impl TestState { ) .await?; - let homeserver_connection = MockHomeserverConnection { - homeserver: "example.com".to_owned(), - }; + let homeserver_connection = MockHomeserverConnection::new("example.com"); let policy_factory = Arc::new(policy_factory); diff --git a/crates/matrix/Cargo.toml b/crates/matrix/Cargo.toml index ef9777b6..bb162caa 100644 --- a/crates/matrix/Cargo.toml +++ b/crates/matrix/Cargo.toml @@ -6,7 +6,9 @@ edition = "2021" license = "Apache-2.0" [dependencies] +anyhow = "1.0.71" serde = { version = "1.0.177", features = ["derive"] } async-trait = "0.1.72" http = "0.2.9" +tokio = { version = "1.28.2", features = ["sync", "macros", "rt"] } url = "2.4.0" diff --git a/crates/matrix/src/lib.rs b/crates/matrix/src/lib.rs index df814452..b1e3e4e5 100644 --- a/crates/matrix/src/lib.rs +++ b/crates/matrix/src/lib.rs @@ -16,6 +16,10 @@ #![deny(clippy::all, clippy::str_to_string, rustdoc::broken_intra_doc_links)] #![warn(clippy::pedantic)] +mod mock; + +pub use self::mock::MockHomeserverConnection; + #[derive(Debug)] pub struct MatrixUser { pub displayname: Option, @@ -40,10 +44,10 @@ pub struct ProvisionRequest { impl ProvisionRequest { #[must_use] - pub fn new(mxid: String, sub: String) -> Self { + pub fn new(mxid: impl Into, sub: impl Into) -> Self { Self { - mxid, - sub, + mxid: mxid.into(), + sub: sub.into(), displayname: FieldAction::DoNothing, avatar_url: FieldAction::DoNothing, emails: FieldAction::DoNothing, diff --git a/crates/matrix/src/mock.rs b/crates/matrix/src/mock.rs new file mode 100644 index 00000000..b13aa0be --- /dev/null +++ b/crates/matrix/src/mock.rs @@ -0,0 +1,155 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::{HashMap, HashSet}; + +use anyhow::Context; +use async_trait::async_trait; +use tokio::sync::RwLock; + +use crate::{HomeserverConnection, MatrixUser, ProvisionRequest}; + +struct MockUser { + sub: String, + avatar_url: Option, + displayname: Option, + devices: HashSet, + emails: Option>, +} + +/// A Mock implementation of a [`HomeserverConnection`], which never fails and +/// doesn't do anything. +pub struct MockHomeserverConnection { + homeserver: String, + users: RwLock>, +} + +impl MockHomeserverConnection { + /// Create a new [`MockHomeserverConnection`]. + pub fn new(homeserver: H) -> Self + where + H: Into, + { + Self { + homeserver: homeserver.into(), + users: RwLock::new(HashMap::new()), + } + } +} + +#[async_trait] +impl HomeserverConnection for MockHomeserverConnection { + type Error = anyhow::Error; + + fn homeserver(&self) -> &str { + &self.homeserver + } + + async fn query_user(&self, mxid: &str) -> Result { + let users = self.users.read().await; + let user = users.get(mxid).context("User not found")?; + Ok(MatrixUser { + displayname: user.displayname.clone(), + avatar_url: user.avatar_url.clone(), + }) + } + + async fn provision_user(&self, request: &ProvisionRequest) -> Result { + let mut users = self.users.write().await; + let inserted = !users.contains_key(request.mxid()); + let user = users.entry(request.mxid().to_owned()).or_insert(MockUser { + sub: request.sub().to_owned(), + avatar_url: None, + displayname: None, + devices: HashSet::new(), + emails: None, + }); + + anyhow::ensure!( + user.sub == request.sub(), + "User already provisioned with different sub" + ); + + request.on_emails(|emails| { + user.emails = emails.map(ToOwned::to_owned); + }); + + request.on_displayname(|displayname| { + user.displayname = displayname.map(ToOwned::to_owned); + }); + + request.on_avatar_url(|avatar_url| { + user.avatar_url = avatar_url.map(ToOwned::to_owned); + }); + + Ok(inserted) + } + + async fn create_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> { + let mut users = self.users.write().await; + let user = users.get_mut(mxid).context("User not found")?; + user.devices.insert(device_id.to_owned()); + Ok(()) + } + + async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> { + let mut users = self.users.write().await; + let user = users.get_mut(mxid).context("User not found")?; + user.devices.remove(device_id); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_mock_connection() { + let conn = MockHomeserverConnection::new("example.org"); + + let mxid = "@test:example.org"; + let device = "test"; + assert_eq!(conn.homeserver(), "example.org"); + assert_eq!(conn.mxid("test"), mxid); + + assert!(conn.query_user(mxid).await.is_err()); + assert!(conn.create_device(mxid, device).await.is_err()); + assert!(conn.delete_device(mxid, device).await.is_err()); + + let request = ProvisionRequest::new("@test:example.org", "test") + .set_displayname("Test User".into()) + .set_avatar_url("mxc://example.org/1234567890".into()) + .set_emails(vec!["test@example.org".to_owned()]); + + let inserted = conn.provision_user(&request).await.unwrap(); + assert!(inserted); + + let user = conn.query_user("@test:example.org").await.unwrap(); + assert_eq!(user.displayname, Some("Test User".into())); + assert_eq!(user.avatar_url, Some("mxc://example.org/1234567890".into())); + + // Deleting a non-existent device should not fail + assert!(conn.delete_device(mxid, device).await.is_ok()); + + // Create the device + assert!(conn.create_device(mxid, device).await.is_ok()); + // Create the same device again + assert!(conn.create_device(mxid, device).await.is_ok()); + + // XXX: there is no API to query devices yet in the trait + // Delete the device + assert!(conn.delete_device(mxid, device).await.is_ok()); + } +}