You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-11-20 12:02:22 +03:00
Basic tests of the GraphQL API
This commit is contained in:
302
crates/handlers/src/graphql/mod.rs
Normal file
302
crates/handlers/src/graphql/mod.rs
Normal file
@@ -0,0 +1,302 @@
|
||||
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_graphql::{
|
||||
extensions::{ApolloTracing, Tracing},
|
||||
http::{playground_source, GraphQLPlaygroundConfig, MultipartOptions},
|
||||
};
|
||||
use axum::{
|
||||
async_trait,
|
||||
extract::{BodyStream, RawQuery, State},
|
||||
http::StatusCode,
|
||||
response::{Html, IntoResponse, Response},
|
||||
Json, TypedHeader,
|
||||
};
|
||||
use axum_extra::extract::PrivateCookieJar;
|
||||
use futures_util::TryStreamExt;
|
||||
use headers::{authorization::Bearer, Authorization, ContentType, HeaderValue};
|
||||
use hyper::header::CACHE_CONTROL;
|
||||
use mas_axum_utils::{FancyError, SessionInfo, SessionInfoExt};
|
||||
use mas_graphql::{Requester, Schema};
|
||||
use mas_keystore::Encrypter;
|
||||
use mas_matrix::HomeserverConnection;
|
||||
use mas_storage::{
|
||||
BoxClock, BoxRepository, BoxRng, Clock, Repository, RepositoryError, SystemClock,
|
||||
};
|
||||
use mas_storage_pg::PgRepository;
|
||||
use rand::{thread_rng, SeedableRng};
|
||||
use rand_chacha::ChaChaRng;
|
||||
use sqlx::PgPool;
|
||||
use tracing::{info_span, Instrument};
|
||||
|
||||
use crate::impl_from_error_for_route;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
struct GraphQLState {
|
||||
pool: PgPool,
|
||||
homeserver_connection: Arc<dyn HomeserverConnection<Error = anyhow::Error>>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl mas_graphql::State for GraphQLState {
|
||||
async fn repository(&self) -> Result<BoxRepository, RepositoryError> {
|
||||
let repo = PgRepository::from_pool(&self.pool)
|
||||
.await
|
||||
.map_err(RepositoryError::from_error)?;
|
||||
|
||||
Ok(repo.map_err(RepositoryError::from_error).boxed())
|
||||
}
|
||||
|
||||
fn homeserver_connection(&self) -> &dyn HomeserverConnection<Error = anyhow::Error> {
|
||||
self.homeserver_connection.as_ref()
|
||||
}
|
||||
|
||||
fn clock(&self) -> BoxClock {
|
||||
let clock = SystemClock::default();
|
||||
Box::new(clock)
|
||||
}
|
||||
|
||||
fn rng(&self) -> BoxRng {
|
||||
#[allow(clippy::disallowed_methods)]
|
||||
let rng = thread_rng();
|
||||
|
||||
let rng = ChaChaRng::from_rng(rng).expect("Failed to seed rng");
|
||||
Box::new(rng)
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn schema(
|
||||
pool: &PgPool,
|
||||
homeserver_connection: impl HomeserverConnection<Error = anyhow::Error> + 'static,
|
||||
) -> Schema {
|
||||
let state = GraphQLState {
|
||||
pool: pool.clone(),
|
||||
homeserver_connection: Arc::new(homeserver_connection),
|
||||
};
|
||||
let state: mas_graphql::BoxState = Box::new(state);
|
||||
|
||||
mas_graphql::schema_builder()
|
||||
.extension(Tracing)
|
||||
.extension(ApolloTracing)
|
||||
.data(state)
|
||||
.finish()
|
||||
}
|
||||
|
||||
fn span_for_graphql_request(request: &async_graphql::Request) -> tracing::Span {
|
||||
let span = info_span!(
|
||||
"GraphQL operation",
|
||||
"otel.name" = tracing::field::Empty,
|
||||
"otel.kind" = "server",
|
||||
"graphql.document" = request.query,
|
||||
"graphql.operation.name" = tracing::field::Empty,
|
||||
);
|
||||
|
||||
if let Some(name) = &request.operation_name {
|
||||
span.record("otel.name", name);
|
||||
span.record("graphql.operation.name", name);
|
||||
}
|
||||
|
||||
span
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum RouteError {
|
||||
#[error(transparent)]
|
||||
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
|
||||
|
||||
#[error("Loading of some database objects failed")]
|
||||
LoadFailed,
|
||||
|
||||
#[error("Invalid access token")]
|
||||
InvalidToken,
|
||||
|
||||
#[error("Missing scope")]
|
||||
MissingScope,
|
||||
|
||||
#[error(transparent)]
|
||||
ParseRequest(#[from] async_graphql::ParseRequestError),
|
||||
}
|
||||
|
||||
impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> Response {
|
||||
sentry::capture_error(&self);
|
||||
|
||||
match self {
|
||||
e @ (Self::Internal(_) | Self::LoadFailed) => {
|
||||
let error = async_graphql::Error::new_with_source(e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({"errors": [error]})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
Self::InvalidToken => {
|
||||
let error = async_graphql::Error::new("Invalid token");
|
||||
(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(serde_json::json!({"errors": [error]})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
Self::MissingScope => {
|
||||
let error = async_graphql::Error::new("Missing urn:mas:graphql:* scope");
|
||||
(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(serde_json::json!({"errors": [error]})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
Self::ParseRequest(e) => {
|
||||
let error = async_graphql::Error::new_with_source(e);
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({"errors": [error]})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_requester(
|
||||
clock: &impl Clock,
|
||||
mut repo: BoxRepository,
|
||||
session_info: SessionInfo,
|
||||
token: Option<&str>,
|
||||
) -> Result<Requester, RouteError> {
|
||||
let requester = if let Some(token) = token {
|
||||
let token = repo
|
||||
.oauth2_access_token()
|
||||
.find_by_token(token)
|
||||
.await?
|
||||
.ok_or(RouteError::InvalidToken)?;
|
||||
|
||||
let session = repo
|
||||
.oauth2_session()
|
||||
.lookup(token.session_id)
|
||||
.await?
|
||||
.ok_or(RouteError::LoadFailed)?;
|
||||
|
||||
// XXX: The user_id should really be directly on the OAuth session
|
||||
let browser_session = repo
|
||||
.browser_session()
|
||||
.lookup(session.user_session_id)
|
||||
.await?
|
||||
.ok_or(RouteError::LoadFailed)?;
|
||||
|
||||
let user = browser_session.user;
|
||||
|
||||
if !token.is_valid(clock.now()) || !session.is_valid() || !user.is_valid() {
|
||||
return Err(RouteError::InvalidToken);
|
||||
}
|
||||
|
||||
if !session.scope.contains("urn:mas:graphql:*") {
|
||||
return Err(RouteError::MissingScope);
|
||||
}
|
||||
|
||||
Requester::OAuth2Session(session, user)
|
||||
} else {
|
||||
let maybe_session = session_info.load_session(&mut repo).await?;
|
||||
Requester::from(maybe_session)
|
||||
};
|
||||
repo.cancel().await?;
|
||||
Ok(requester)
|
||||
}
|
||||
|
||||
pub async fn post(
|
||||
State(schema): State<Schema>,
|
||||
clock: BoxClock,
|
||||
repo: BoxRepository,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
content_type: Option<TypedHeader<ContentType>>,
|
||||
authorization: Option<TypedHeader<Authorization<Bearer>>>,
|
||||
body: BodyStream,
|
||||
) -> Result<impl IntoResponse, RouteError> {
|
||||
let token = authorization
|
||||
.as_ref()
|
||||
.map(|TypedHeader(Authorization(bearer))| bearer.token());
|
||||
let (session_info, _cookie_jar) = cookie_jar.session_info();
|
||||
let requester = get_requester(&clock, repo, session_info, token).await?;
|
||||
|
||||
let content_type = content_type.map(|TypedHeader(h)| h.to_string());
|
||||
|
||||
let request = async_graphql::http::receive_body(
|
||||
content_type,
|
||||
body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
|
||||
.into_async_read(),
|
||||
MultipartOptions::default(),
|
||||
)
|
||||
.await?
|
||||
.data(requester); // XXX: this should probably return another error response?
|
||||
|
||||
let span = span_for_graphql_request(&request);
|
||||
let response = schema.execute(request).instrument(span).await;
|
||||
|
||||
let cache_control = response
|
||||
.cache_control
|
||||
.value()
|
||||
.and_then(|v| HeaderValue::from_str(&v).ok())
|
||||
.map(|h| [(CACHE_CONTROL, h)]);
|
||||
|
||||
let headers = response.http_headers.clone();
|
||||
|
||||
Ok((headers, cache_control, Json(response)))
|
||||
}
|
||||
|
||||
pub async fn get(
|
||||
State(schema): State<Schema>,
|
||||
clock: BoxClock,
|
||||
repo: BoxRepository,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
authorization: Option<TypedHeader<Authorization<Bearer>>>,
|
||||
RawQuery(query): RawQuery,
|
||||
) -> Result<impl IntoResponse, FancyError> {
|
||||
let token = authorization
|
||||
.as_ref()
|
||||
.map(|TypedHeader(Authorization(bearer))| bearer.token());
|
||||
let (session_info, _cookie_jar) = cookie_jar.session_info();
|
||||
let requester = get_requester(&clock, repo, session_info, token).await?;
|
||||
|
||||
let request =
|
||||
async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(requester);
|
||||
|
||||
let span = span_for_graphql_request(&request);
|
||||
let response = schema.execute(request).instrument(span).await;
|
||||
|
||||
let cache_control = response
|
||||
.cache_control
|
||||
.value()
|
||||
.and_then(|v| HeaderValue::from_str(&v).ok())
|
||||
.map(|h| [(CACHE_CONTROL, h)]);
|
||||
|
||||
let headers = response.http_headers.clone();
|
||||
|
||||
Ok((headers, cache_control, Json(response)))
|
||||
}
|
||||
|
||||
pub async fn playground() -> impl IntoResponse {
|
||||
Html(playground_source(
|
||||
GraphQLPlaygroundConfig::new("/graphql").with_setting("request.credentials", "include"),
|
||||
))
|
||||
}
|
||||
203
crates/handlers/src/graphql/tests.rs
Normal file
203
crates/handlers/src/graphql/tests.rs
Normal file
@@ -0,0 +1,203 @@
|
||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use axum::http::Request;
|
||||
use hyper::StatusCode;
|
||||
use mas_data_model::AuthorizationCode;
|
||||
use mas_router::SimpleRoute;
|
||||
use oauth2_types::{
|
||||
registration::ClientRegistrationResponse,
|
||||
requests::{AccessTokenResponse, ResponseMode},
|
||||
scope::{Scope, ScopeToken, OPENID},
|
||||
};
|
||||
use sqlx::PgPool;
|
||||
|
||||
use crate::test_utils::{init_tracing, RequestBuilderExt, ResponseExt, TestState};
|
||||
|
||||
const GRAPHQL_SCOPE: ScopeToken = ScopeToken::from_static("urn:mas:graphql:*");
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct GraphQLResponse {
|
||||
data: serde_json::Value,
|
||||
errors: Option<Vec<serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
|
||||
async fn test_anonymous_viewer(pool: PgPool) {
|
||||
init_tracing();
|
||||
let state = TestState::from_pool(pool).await.unwrap();
|
||||
|
||||
let req = Request::post("/graphql").json(serde_json::json!({
|
||||
"query": r#"
|
||||
query {
|
||||
viewer {
|
||||
__typename
|
||||
}
|
||||
}
|
||||
"#,
|
||||
}));
|
||||
|
||||
let response = state.request(req).await;
|
||||
response.assert_status(StatusCode::OK);
|
||||
let response: GraphQLResponse = response.json();
|
||||
|
||||
assert_eq!(response.errors, None);
|
||||
assert_eq!(
|
||||
response.data,
|
||||
serde_json::json!({
|
||||
"viewer": {
|
||||
"__typename": "Anonymous",
|
||||
},
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
|
||||
async fn test_oauth2_viewer(pool: PgPool) {
|
||||
init_tracing();
|
||||
let state = TestState::from_pool(pool).await.unwrap();
|
||||
|
||||
// Start by creating a user, a client and a token
|
||||
// XXX: this is a lot of boilerplate just to get an access token!
|
||||
|
||||
// Provision a client
|
||||
let request =
|
||||
Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
|
||||
"client_uri": "https://example.com/",
|
||||
"redirect_uris": ["https://example.com/callback"],
|
||||
"contacts": ["contact@example.com"],
|
||||
"token_endpoint_auth_method": "none",
|
||||
"response_types": ["code"],
|
||||
"grant_types": ["authorization_code"],
|
||||
}));
|
||||
|
||||
let response = state.request(request).await;
|
||||
response.assert_status(StatusCode::CREATED);
|
||||
|
||||
let ClientRegistrationResponse { client_id, .. } = response.json();
|
||||
|
||||
// Let's provision a user and create a session for them.
|
||||
let mut repo = state.repository().await.unwrap();
|
||||
|
||||
let user = repo
|
||||
.user()
|
||||
.add(&mut state.rng(), &state.clock, "alice".to_owned())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let browser_session = repo
|
||||
.browser_session()
|
||||
.add(&mut state.rng(), &state.clock, &user)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Lookup the client in the database.
|
||||
let client = repo
|
||||
.oauth2_client()
|
||||
.find_by_client_id(&client_id)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
// Start a grant
|
||||
let code = "thisisaverysecurecode";
|
||||
let grant = repo
|
||||
.oauth2_authorization_grant()
|
||||
.add(
|
||||
&mut state.rng(),
|
||||
&state.clock,
|
||||
&client,
|
||||
"https://example.com/redirect".parse().unwrap(),
|
||||
Scope::from_iter([OPENID, GRAPHQL_SCOPE]),
|
||||
Some(AuthorizationCode {
|
||||
code: code.to_owned(),
|
||||
pkce: None,
|
||||
}),
|
||||
Some("state".to_owned()),
|
||||
Some("nonce".to_owned()),
|
||||
None,
|
||||
ResponseMode::Query,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let session = repo
|
||||
.oauth2_session()
|
||||
.add(
|
||||
&mut state.rng(),
|
||||
&state.clock,
|
||||
&client,
|
||||
&browser_session,
|
||||
grant.scope.clone(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// And fulfill it
|
||||
let grant = repo
|
||||
.oauth2_authorization_grant()
|
||||
.fulfill(&state.clock, &session, grant)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
repo.save().await.unwrap();
|
||||
|
||||
// Now call the token endpoint to get an access token.
|
||||
let request = Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"redirect_uri": grant.redirect_uri,
|
||||
"client_id": client.client_id,
|
||||
}));
|
||||
|
||||
let response = state.request(request).await;
|
||||
response.assert_status(StatusCode::OK);
|
||||
|
||||
let AccessTokenResponse { access_token, .. } = response.json();
|
||||
|
||||
let req = Request::post("/graphql")
|
||||
.bearer(&access_token)
|
||||
.json(serde_json::json!({
|
||||
"query": r#"
|
||||
query {
|
||||
viewer {
|
||||
__typename
|
||||
|
||||
... on User {
|
||||
id
|
||||
username
|
||||
}
|
||||
}
|
||||
}
|
||||
"#,
|
||||
}));
|
||||
|
||||
let response = state.request(req).await;
|
||||
response.assert_status(StatusCode::OK);
|
||||
let response: GraphQLResponse = response.json();
|
||||
|
||||
assert_eq!(response.errors, None);
|
||||
assert_eq!(
|
||||
response.data,
|
||||
serde_json::json!({
|
||||
"viewer": {
|
||||
"__typename": "User",
|
||||
"id": format!("user:{id}", id = user.id),
|
||||
"username": "alice",
|
||||
},
|
||||
})
|
||||
);
|
||||
}
|
||||
Reference in New Issue
Block a user