diff --git a/crates/handlers/src/graphql.rs b/crates/handlers/src/graphql/mod.rs similarity index 99% rename from crates/handlers/src/graphql.rs rename to crates/handlers/src/graphql/mod.rs index 84488574..74b4bd6d 100644 --- a/crates/handlers/src/graphql.rs +++ b/crates/handlers/src/graphql/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -44,6 +44,9 @@ use tracing::{info_span, Instrument}; use crate::impl_from_error_for_route; +#[cfg(test)] +mod tests; + struct GraphQLState { pool: PgPool, homeserver_connection: Arc>, diff --git a/crates/handlers/src/graphql/tests.rs b/crates/handlers/src/graphql/tests.rs new file mode 100644 index 00000000..f9eb5cb2 --- /dev/null +++ b/crates/handlers/src/graphql/tests.rs @@ -0,0 +1,203 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use axum::http::Request; +use hyper::StatusCode; +use mas_data_model::AuthorizationCode; +use mas_router::SimpleRoute; +use oauth2_types::{ + registration::ClientRegistrationResponse, + requests::{AccessTokenResponse, ResponseMode}, + scope::{Scope, ScopeToken, OPENID}, +}; +use sqlx::PgPool; + +use crate::test_utils::{init_tracing, RequestBuilderExt, ResponseExt, TestState}; + +const GRAPHQL_SCOPE: ScopeToken = ScopeToken::from_static("urn:mas:graphql:*"); + +#[derive(serde::Deserialize)] +struct GraphQLResponse { + data: serde_json::Value, + errors: Option>, +} + +#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] +async fn test_anonymous_viewer(pool: PgPool) { + init_tracing(); + let state = TestState::from_pool(pool).await.unwrap(); + + let req = Request::post("/graphql").json(serde_json::json!({ + "query": r#" + query { + viewer { + __typename + } + } + "#, + })); + + let response = state.request(req).await; + response.assert_status(StatusCode::OK); + let response: GraphQLResponse = response.json(); + + assert_eq!(response.errors, None); + assert_eq!( + response.data, + serde_json::json!({ + "viewer": { + "__typename": "Anonymous", + }, + }) + ); +} + +#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] +async fn test_oauth2_viewer(pool: PgPool) { + init_tracing(); + let state = TestState::from_pool(pool).await.unwrap(); + + // Start by creating a user, a client and a token + // XXX: this is a lot of boilerplate just to get an access token! + + // Provision a client + let request = + Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({ + "client_uri": "https://example.com/", + "redirect_uris": ["https://example.com/callback"], + "contacts": ["contact@example.com"], + "token_endpoint_auth_method": "none", + "response_types": ["code"], + "grant_types": ["authorization_code"], + })); + + let response = state.request(request).await; + response.assert_status(StatusCode::CREATED); + + let ClientRegistrationResponse { client_id, .. } = response.json(); + + // Let's provision a user and create a session for them. + let mut repo = state.repository().await.unwrap(); + + let user = repo + .user() + .add(&mut state.rng(), &state.clock, "alice".to_owned()) + .await + .unwrap(); + + let browser_session = repo + .browser_session() + .add(&mut state.rng(), &state.clock, &user) + .await + .unwrap(); + + // Lookup the client in the database. + let client = repo + .oauth2_client() + .find_by_client_id(&client_id) + .await + .unwrap() + .unwrap(); + + // Start a grant + let code = "thisisaverysecurecode"; + let grant = repo + .oauth2_authorization_grant() + .add( + &mut state.rng(), + &state.clock, + &client, + "https://example.com/redirect".parse().unwrap(), + Scope::from_iter([OPENID, GRAPHQL_SCOPE]), + Some(AuthorizationCode { + code: code.to_owned(), + pkce: None, + }), + Some("state".to_owned()), + Some("nonce".to_owned()), + None, + ResponseMode::Query, + false, + false, + ) + .await + .unwrap(); + + let session = repo + .oauth2_session() + .add( + &mut state.rng(), + &state.clock, + &client, + &browser_session, + grant.scope.clone(), + ) + .await + .unwrap(); + + // And fulfill it + let grant = repo + .oauth2_authorization_grant() + .fulfill(&state.clock, &session, grant) + .await + .unwrap(); + + repo.save().await.unwrap(); + + // Now call the token endpoint to get an access token. + let request = Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({ + "grant_type": "authorization_code", + "code": code, + "redirect_uri": grant.redirect_uri, + "client_id": client.client_id, + })); + + let response = state.request(request).await; + response.assert_status(StatusCode::OK); + + let AccessTokenResponse { access_token, .. } = response.json(); + + let req = Request::post("/graphql") + .bearer(&access_token) + .json(serde_json::json!({ + "query": r#" + query { + viewer { + __typename + + ... on User { + id + username + } + } + } + "#, + })); + + let response = state.request(req).await; + response.assert_status(StatusCode::OK); + let response: GraphQLResponse = response.json(); + + assert_eq!(response.errors, None); + assert_eq!( + response.data, + serde_json::json!({ + "viewer": { + "__typename": "User", + "id": format!("user:{id}", id = user.id), + "username": "alice", + }, + }) + ); +} diff --git a/crates/handlers/src/test_utils.rs b/crates/handlers/src/test_utils.rs index 8b2a4c51..18510155 100644 --- a/crates/handlers/src/test_utils.rs +++ b/crates/handlers/src/test_utils.rs @@ -12,18 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{convert::Infallible, sync::Arc}; +use std::{ + convert::Infallible, + sync::{Arc, Mutex}, +}; use axum::{ async_trait, - body::HttpBody, + body::{Bytes, HttpBody}, extract::{FromRef, FromRequestParts}, }; use headers::{Authorization, ContentType, HeaderMapExt, HeaderName}; use hyper::{header::CONTENT_TYPE, Request, Response, StatusCode}; use mas_axum_utils::http_client_factory::HttpClientFactory; use mas_keystore::{Encrypter, JsonWebKey, JsonWebKeySet, Keystore, PrivateKey}; -use mas_matrix::MockHomeserverConnection; +use mas_matrix::{HomeserverConnection, MockHomeserverConnection}; use mas_policy::PolicyFactory; use mas_router::{SimpleRoute, UrlBuilder}; use mas_storage::{clock::MockClock, BoxClock, BoxRepository, BoxRng, Repository}; @@ -33,12 +36,10 @@ use rand::SeedableRng; use rand_chacha::ChaChaRng; use serde::{de::DeserializeOwned, Serialize}; use sqlx::PgPool; -use tokio::sync::Mutex; use tower::{Service, ServiceExt}; use crate::{ app_state::RepositoryError, - graphql_schema, passwords::{Hasher, PasswordManager}, MatrixHomeserver, }; @@ -115,13 +116,21 @@ impl TestState { let policy_factory = Arc::new(policy_factory); - let graphql_schema = graphql_schema(&pool, homeserver_connection); - let http_client_factory = HttpClientFactory::new(10); let clock = Arc::new(MockClock::default()); let rng = Arc::new(Mutex::new(ChaChaRng::seed_from_u64(42))); + let graphql_state = TestGraphQLState { + pool: pool.clone(), + homeserver_connection, + rng: Arc::clone(&rng), + clock: Arc::clone(&clock), + }; + let state: mas_graphql::BoxState = Box::new(graphql_state); + + let graphql_schema = mas_graphql::schema_builder().data(state).finish(); + Ok(Self { pool, templates, @@ -141,6 +150,8 @@ impl TestState { pub async fn request(&self, request: Request) -> Response where B: HttpBody + Send + 'static, + ::Data: Into, + ::Error: std::error::Error + Send + Sync, B::Error: std::error::Error + Send + Sync, B::Data: Send, { @@ -149,6 +160,7 @@ impl TestState { .merge(crate::api_router()) .merge(crate::compat_router()) .merge(crate::human_router(self.templates.clone())) + .merge(crate::graphql_router(false)) .with_state(self.clone()); // Both unwrap are on Infallible, so this is safe @@ -211,6 +223,40 @@ impl TestState { } } +struct TestGraphQLState { + pool: PgPool, + homeserver_connection: MockHomeserverConnection, + clock: Arc, + rng: Arc>, +} + +#[async_trait] +impl mas_graphql::State for TestGraphQLState { + async fn repository(&self) -> Result { + let repo = PgRepository::from_pool(&self.pool) + .await + .map_err(mas_storage::RepositoryError::from_error)?; + + Ok(repo + .map_err(mas_storage::RepositoryError::from_error) + .boxed()) + } + + fn homeserver_connection(&self) -> &dyn HomeserverConnection { + &self.homeserver_connection + } + + fn clock(&self) -> BoxClock { + Box::new(self.clock.clone()) + } + + fn rng(&self) -> BoxRng { + let mut parent_rng = self.rng.lock().expect("Failed to lock RNG"); + let rng = ChaChaRng::from_rng(&mut *parent_rng).expect("Failed to seed RNG"); + Box::new(rng) + } +} + impl FromRef for PgPool { fn from_ref(input: &TestState) -> Self { input.pool.clone() @@ -291,7 +337,7 @@ impl FromRequestParts for BoxRng { _parts: &mut axum::http::request::Parts, state: &TestState, ) -> Result { - let mut parent_rng = state.rng.lock().await; + let mut parent_rng = state.rng.lock().expect("Failed to lock RNG"); let rng = ChaChaRng::from_rng(&mut *parent_rng).expect("Failed to seed RNG"); Ok(Box::new(rng)) }