1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Initial GraphQL API

This commit is contained in:
Quentin Gliech
2022-11-04 18:59:25 +01:00
parent 35e5a5a7a7
commit c13b0478e6
10 changed files with 518 additions and 9 deletions

201
Cargo.lock generated
View File

@ -2,6 +2,16 @@
# It is not intended for manual editing. # It is not intended for manual editing.
version = 3 version = 3
[[package]]
name = "Inflector"
version = "0.11.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fe438c63458706e03479442743baae6c88256498e6431708f6dfc520a26515d3"
dependencies = [
"lazy_static",
"regex",
]
[[package]] [[package]]
name = "addr2line" name = "addr2line"
version = "0.17.0" version = "0.17.0"
@ -119,6 +129,12 @@ version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8da52d66c7071e2e3fa2a1e5c6d088fec47b593032b254f5e980de8ea54454d6" checksum = "8da52d66c7071e2e3fa2a1e5c6d088fec47b593032b254f5e980de8ea54454d6"
[[package]]
name = "ascii_utils"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "71938f30533e4d95a6d17aa530939da3842c2ab6f4f84b9dae68447e4129f74a"
[[package]] [[package]]
name = "assert_matches" name = "assert_matches"
version = "1.5.0" version = "1.5.0"
@ -139,6 +155,81 @@ dependencies = [
"tokio", "tokio",
] ]
[[package]]
name = "async-graphql"
version = "4.0.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9ed522678d412d77effe47b3c82314ac36952a35e6e852093dd48287c421f80"
dependencies = [
"async-graphql-derive",
"async-graphql-parser",
"async-graphql-value",
"async-stream",
"async-trait",
"base64",
"bytes 1.2.1",
"chrono",
"fast_chemail",
"fnv",
"futures-util",
"http",
"indexmap",
"mime",
"multer",
"num-traits",
"once_cell",
"pin-project-lite",
"regex",
"serde",
"serde_json",
"serde_urlencoded",
"static_assertions",
"tempfile",
"thiserror",
"tracing",
"tracing-futures",
]
[[package]]
name = "async-graphql-derive"
version = "4.0.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c121a894495d7d3fc3d4e15e0a9843e422e4d1d9e3c514d8062a1c94b35b005d"
dependencies = [
"Inflector",
"async-graphql-parser",
"darling",
"proc-macro-crate",
"proc-macro2",
"quote",
"syn",
"thiserror",
]
[[package]]
name = "async-graphql-parser"
version = "4.0.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b6c386f398145c6180206c1869c2279f5a3d45db5be4e0266148c6ac5c6ad68"
dependencies = [
"async-graphql-value",
"pest",
"serde",
"serde_json",
]
[[package]]
name = "async-graphql-value"
version = "4.0.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a941b499fead4a3fb5392cabf42446566d18c86313f69f2deab69560394d65f"
dependencies = [
"bytes 1.2.1",
"indexmap",
"serde",
"serde_json",
]
[[package]] [[package]]
name = "async-stream" name = "async-stream"
version = "0.3.3" version = "0.3.3"
@ -524,6 +615,7 @@ checksum = "d2628a243073c55aef15a1c1fe45c87f21b84f9e89ca9e7b262a180d3d03543d"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"axum-core 0.3.0-rc.2", "axum-core 0.3.0-rc.2",
"base64",
"bitflags", "bitflags",
"bytes 1.2.1", "bytes 1.2.1",
"futures-util", "futures-util",
@ -540,8 +632,10 @@ dependencies = [
"serde", "serde",
"serde_json", "serde_json",
"serde_urlencoded", "serde_urlencoded",
"sha-1",
"sync_wrapper", "sync_wrapper",
"tokio", "tokio",
"tokio-tungstenite",
"tower", "tower",
"tower-http", "tower-http",
"tower-layer", "tower-layer",
@ -1498,6 +1592,15 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7" checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7"
[[package]]
name = "fast_chemail"
version = "0.9.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "495a39d30d624c2caabe6312bfead73e7717692b44e0b32df168c275a2e8e9e4"
dependencies = [
"ascii_utils",
]
[[package]] [[package]]
name = "fastrand" name = "fastrand"
version = "1.8.0" version = "1.8.0"
@ -2231,7 +2334,7 @@ version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
dependencies = [ dependencies = [
"spin", "spin 0.5.2",
] ]
[[package]] [[package]]
@ -2495,10 +2598,12 @@ version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"argon2", "argon2",
"async-graphql",
"axum 0.6.0-rc.2", "axum 0.6.0-rc.2",
"axum-extra", "axum-extra",
"axum-macros", "axum-macros",
"chrono", "chrono",
"futures-util",
"headers", "headers",
"hyper", "hyper",
"indoc", "indoc",
@ -2525,6 +2630,7 @@ dependencies = [
"sqlx", "sqlx",
"thiserror", "thiserror",
"tokio", "tokio",
"tokio-stream",
"tower", "tower",
"tower-http", "tower-http",
"tracing", "tracing",
@ -2877,6 +2983,24 @@ dependencies = [
"windows-sys 0.42.0", "windows-sys 0.42.0",
] ]
[[package]]
name = "multer"
version = "2.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ed4198ce7a4cbd2a57af78d28c6fbb57d81ac5f1d6ad79ac6c5587419cbdf22"
dependencies = [
"bytes 1.2.1",
"encoding_rs",
"futures-util",
"http",
"httparse",
"log",
"memchr",
"mime",
"spin 0.9.4",
"version_check",
]
[[package]] [[package]]
name = "multimap" name = "multimap"
version = "0.8.3" version = "0.8.3"
@ -3578,6 +3702,17 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "proc-macro-crate"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eda0fc3b0fb7c975631757e14d9049da17374063edb6ebbcbc54d880d4fe94e9"
dependencies = [
"once_cell",
"thiserror",
"toml",
]
[[package]] [[package]]
name = "proc-macro-error" name = "proc-macro-error"
version = "1.0.4" version = "1.0.4"
@ -3878,7 +4013,7 @@ dependencies = [
"cc", "cc",
"libc", "libc",
"once_cell", "once_cell",
"spin", "spin 0.5.2",
"untrusted", "untrusted",
"web-sys", "web-sys",
"winapi", "winapi",
@ -4246,6 +4381,17 @@ dependencies = [
"unsafe-libyaml", "unsafe-libyaml",
] ]
[[package]]
name = "sha-1"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "028f48d513f9678cda28f6e4064755b3fbb2af6acd672f2c209b62323f7aea0f"
dependencies = [
"cfg-if",
"cpufeatures",
"digest 0.10.5",
]
[[package]] [[package]]
name = "sha1" name = "sha1"
version = "0.10.5" version = "0.10.5"
@ -4367,6 +4513,12 @@ version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"
[[package]]
name = "spin"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f6002a767bff9e83f8eeecf883ecb8011875a21ae8da43bffb817a57e78cc09"
[[package]] [[package]]
name = "spki" name = "spki"
version = "0.6.0" version = "0.6.0"
@ -4498,6 +4650,12 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
[[package]]
name = "static_assertions"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
[[package]] [[package]]
name = "stringprep" name = "stringprep"
version = "0.1.2" version = "0.1.2"
@ -4803,6 +4961,18 @@ dependencies = [
"tokio-stream", "tokio-stream",
] ]
[[package]]
name = "tokio-tungstenite"
version = "0.17.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f714dd15bead90401d77e04243611caec13726c2408afd5b31901dfcdcb3b181"
dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite",
]
[[package]] [[package]]
name = "tokio-util" name = "tokio-util"
version = "0.6.10" version = "0.6.10"
@ -4998,6 +5168,8 @@ version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97d095ae15e245a057c8e8451bab9b3ee1e1f68e9ba2b4fbc18d0ac5237835f2" checksum = "97d095ae15e245a057c8e8451bab9b3ee1e1f68e9ba2b4fbc18d0ac5237835f2"
dependencies = [ dependencies = [
"futures 0.3.25",
"futures-task",
"pin-project", "pin-project",
"tracing", "tracing",
] ]
@ -5060,6 +5232,25 @@ version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642" checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642"
[[package]]
name = "tungstenite"
version = "0.17.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e27992fd6a8c29ee7eef28fc78349aa244134e10ad447ce3b9f0ac0ed0fa4ce0"
dependencies = [
"base64",
"byteorder",
"bytes 1.2.1",
"http",
"httparse",
"log",
"rand",
"sha-1",
"thiserror",
"url",
"utf-8",
]
[[package]] [[package]]
name = "typed-builder" name = "typed-builder"
version = "0.9.1" version = "0.9.1"
@ -5241,6 +5432,12 @@ version = "2.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8db7427f936968176eaa7cdf81b7f98b980b18495ec28f1b5791ac3bfe3eea9" checksum = "e8db7427f936968176eaa7cdf81b7f98b980b18495ec28f1b5791ac3bfe3eea9"
[[package]]
name = "utf-8"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
[[package]] [[package]]
name = "uuid" name = "uuid"
version = "1.2.1" version = "1.2.1"

View File

@ -25,7 +25,7 @@ use ulid::Ulid;
use crate::CookieExt; use crate::CookieExt;
/// An encrypted cookie to save the session ID /// An encrypted cookie to save the session ID
#[derive(Serialize, Deserialize, Debug, Default)] #[derive(Serialize, Deserialize, Debug, Default, Clone)]
pub struct SessionInfo { pub struct SessionInfo {
current: Option<Ulid>, current: Option<Ulid>,
} }

View File

@ -203,6 +203,8 @@ impl Options {
.context("could not watch for templates changes")?; .context("could not watch for templates changes")?;
} }
let graphql_schema = mas_handlers::graphql_schema(&pool);
let state = Arc::new(AppState { let state = Arc::new(AppState {
pool, pool,
templates, templates,
@ -212,6 +214,7 @@ impl Options {
mailer, mailer,
homeserver, homeserver,
policy_factory, policy_factory,
graphql_schema,
}); });
let mut fd_manager = listenfd::ListenFd::from_env(); let mut fd_manager = listenfd::ListenFd::from_env();

View File

@ -31,7 +31,7 @@ use rustls::ServerConfig;
pub fn build_router<B>(state: &Arc<AppState>, resources: &[HttpResource]) -> Router<AppState, B> pub fn build_router<B>(state: &Arc<AppState>, resources: &[HttpResource]) -> Router<AppState, B>
where where
B: HttpBody + Send + 'static, B: HttpBody + Send + 'static,
<B as HttpBody>::Data: Send, <B as HttpBody>::Data: Into<axum::body::Bytes> + Send,
<B as HttpBody>::Error: std::error::Error + Send + Sync, <B as HttpBody>::Error: std::error::Error + Send + Sync,
{ {
let mut router = Router::with_state_arc(state.clone()); let mut router = Router::with_state_arc(state.clone());
@ -50,6 +50,9 @@ where
mas_config::HttpResource::Human => { mas_config::HttpResource::Human => {
router.merge(mas_handlers::human_router(state.clone())) router.merge(mas_handlers::human_router(state.clone()))
} }
mas_config::HttpResource::GraphQL { playground } => {
router.merge(mas_handlers::graphql_router(state.clone(), *playground))
}
mas_config::HttpResource::Static { web_root } => { mas_config::HttpResource::Static { web_root } => {
let handler = mas_static_files::service(web_root); let handler = mas_static_files::service(web_root);
router.nest(mas_router::StaticAsset::route(), handler) router.nest(mas_router::StaticAsset::route(), handler)

View File

@ -234,6 +234,13 @@ pub enum Resource {
/// Pages destined to be viewed by humans /// Pages destined to be viewed by humans
Human, Human,
/// GraphQL endpoint
GraphQL {
/// Enabled the GraphQL playground
#[serde(default)]
playground: bool,
},
/// OAuth-related APIs /// OAuth-related APIs
OAuth, OAuth,
@ -300,6 +307,7 @@ impl Default for HttpConfig {
Resource::Human, Resource::Human,
Resource::OAuth, Resource::OAuth,
Resource::Compat, Resource::Compat,
Resource::GraphQL { playground: true },
Resource::Static { web_root: None }, Resource::Static { web_root: None },
], ],
tls: None, tls: None,

View File

@ -8,6 +8,8 @@ license = "Apache-2.0"
[dependencies] [dependencies]
# Async runtime # Async runtime
tokio = { version = "1.21.2", features = ["macros"] } tokio = { version = "1.21.2", features = ["macros"] }
tokio-stream = "0.1.11"
futures-util = "0.3.25"
# Logging and tracing # Logging and tracing
tracing = "0.1.37" tracing = "0.1.37"
@ -20,10 +22,12 @@ anyhow = "1.0.66"
hyper = { version = "0.14.22", features = ["full"] } hyper = { version = "0.14.22", features = ["full"] }
tower = "0.4.13" tower = "0.4.13"
tower-http = { version = "0.3.4", features = ["cors"] } tower-http = { version = "0.3.4", features = ["cors"] }
axum = "0.6.0-rc.2" axum = { version = "0.6.0-rc.2", features = ["ws"] }
axum-macros = "0.3.0-rc.1" axum-macros = "0.3.0-rc.1"
axum-extra = { version = "0.4.0-rc.1", features = ["cookie-private"] } axum-extra = { version = "0.4.0-rc.1", features = ["cookie-private"] }
async-graphql = { version = "4.0.16", features = ["tracing", "apollo_tracing"] }
# Emails # Emails
lettre = { version = "0.10.1", default-features = false, features = ["builder"] } lettre = { version = "0.10.1", default-features = false, features = ["builder"] }

View File

@ -22,7 +22,7 @@ use mas_router::UrlBuilder;
use mas_templates::Templates; use mas_templates::Templates;
use sqlx::PgPool; use sqlx::PgPool;
use crate::MatrixHomeserver; use crate::{GraphQLSchema, MatrixHomeserver};
#[derive(Clone)] #[derive(Clone)]
pub struct AppState { pub struct AppState {
@ -34,6 +34,7 @@ pub struct AppState {
pub mailer: Mailer, pub mailer: Mailer,
pub homeserver: MatrixHomeserver, pub homeserver: MatrixHomeserver,
pub policy_factory: Arc<PolicyFactory>, pub policy_factory: Arc<PolicyFactory>,
pub graphql_schema: GraphQLSchema,
} }
impl FromRef<AppState> for PgPool { impl FromRef<AppState> for PgPool {
@ -42,6 +43,12 @@ impl FromRef<AppState> for PgPool {
} }
} }
impl FromRef<AppState> for GraphQLSchema {
fn from_ref(input: &AppState) -> Self {
input.graphql_schema.clone()
}
}
impl FromRef<AppState> for Templates { impl FromRef<AppState> for Templates {
fn from_ref(input: &AppState) -> Self { fn from_ref(input: &AppState) -> Self {
input.templates.clone() input.templates.clone()

View File

@ -0,0 +1,251 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{borrow::Cow, str::FromStr, time::Duration};
use async_graphql::{
extensions::{ApolloTracing, Tracing},
futures_util::TryStreamExt,
http::{
playground_source, GraphQLPlaygroundConfig, MultipartOptions, WebSocketProtocols,
WsMessage, ALL_WEBSOCKET_PROTOCOLS,
},
Context, Data, EmptyMutation,
};
use axum::{
extract::{
ws::{CloseFrame, Message},
BodyStream, RawQuery, State, WebSocketUpgrade,
},
response::{Html, IntoResponse, Response},
Json, TypedHeader,
};
use axum_extra::extract::PrivateCookieJar;
use futures_util::{SinkExt, Stream, StreamExt};
use headers::{ContentType, Header, HeaderValue};
use hyper::header::{CACHE_CONTROL, SEC_WEBSOCKET_PROTOCOL};
use mas_axum_utils::{FancyError, SessionInfo, SessionInfoExt};
use mas_keystore::Encrypter;
use sqlx::PgPool;
use tracing::{info_span, Instrument};
pub type Schema = async_graphql::Schema<Query, EmptyMutation, Subscription>;
#[must_use]
pub fn schema(pool: &PgPool) -> Schema {
async_graphql::Schema::build(Query::new(pool), EmptyMutation, Subscription)
.extension(Tracing)
.extension(ApolloTracing)
.finish()
}
fn span_for_graphql_request(request: &async_graphql::Request) -> tracing::Span {
let span = info_span!(
"GraphQL operation",
"otel.name" = tracing::field::Empty,
"otel.kind" = "server",
"graphql.document" = request.query,
"graphql.operation.name" = tracing::field::Empty,
);
if let Some(name) = &request.operation_name {
span.record("otel.name", name);
span.record("graphql.operation.name", name);
}
span
}
pub async fn post(
State(schema): State<Schema>,
cookie_jar: PrivateCookieJar<Encrypter>,
content_type: Option<TypedHeader<ContentType>>,
body: BodyStream,
) -> Result<impl IntoResponse, FancyError> {
let content_type = content_type.map(|TypedHeader(h)| h.to_string());
let (session_info, _cookie_jar) = cookie_jar.session_info();
let request = async_graphql::http::receive_batch_body(
content_type,
body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
.into_async_read(),
MultipartOptions::default(),
)
.await? // XXX: this should probably return another error response?
.data(session_info);
let response = match request {
async_graphql::BatchRequest::Single(request) => {
let span = span_for_graphql_request(&request);
let response = schema.execute(request).instrument(span).await;
async_graphql::BatchResponse::Single(response)
}
async_graphql::BatchRequest::Batch(requests) => async_graphql::BatchResponse::Batch(
futures_util::stream::iter(requests.into_iter())
.then(|request| {
let span = span_for_graphql_request(&request);
schema.execute(request).instrument(span)
})
.collect()
.await,
),
};
let cache_control = response
.cache_control()
.value()
.and_then(|v| HeaderValue::from_str(&v).ok())
.map(|h| [(CACHE_CONTROL, h)]);
let headers = response.http_headers();
Ok((headers, cache_control, Json(response)))
}
pub async fn get(
State(schema): State<Schema>,
cookie_jar: PrivateCookieJar<Encrypter>,
RawQuery(query): RawQuery,
) -> Result<impl IntoResponse, FancyError> {
let (session_info, _cookie_jar) = cookie_jar.session_info();
let request =
async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(session_info);
let span = span_for_graphql_request(&request);
let response = schema.execute(request).instrument(span).await;
let cache_control = response
.cache_control
.value()
.and_then(|v| HeaderValue::from_str(&v).ok())
.map(|h| [(CACHE_CONTROL, h)]);
let headers = response.http_headers.clone();
Ok((headers, cache_control, Json(response)))
}
pub struct SecWebsocketProtocol(WebSocketProtocols);
impl Header for SecWebsocketProtocol {
fn name() -> &'static headers::HeaderName {
&SEC_WEBSOCKET_PROTOCOL
}
fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
where
Self: Sized,
I: Iterator<Item = &'i HeaderValue>,
{
values
.filter_map(|value| value.to_str().ok())
.flat_map(|value| value.split(','))
.find_map(|p| WebSocketProtocols::from_str(p.trim()).ok())
.map(Self)
.ok_or_else(headers::Error::invalid)
}
fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
if let Ok(v) = HeaderValue::from_str(self.0.sec_websocket_protocol()) {
values.extend(std::iter::once(v));
}
}
}
pub async fn ws(
State(schema): State<Schema>,
cookie_jar: PrivateCookieJar<Encrypter>,
TypedHeader(SecWebsocketProtocol(protocol)): TypedHeader<SecWebsocketProtocol>,
websocket: WebSocketUpgrade,
) -> Response {
let (session_info, _cookie_jar) = cookie_jar.session_info();
websocket
.protocols(ALL_WEBSOCKET_PROTOCOLS)
.on_upgrade(move |ws| async move {
let (mut sink, stream) = ws.split();
let stream = stream
.take_while(|res| std::future::ready(res.is_ok()))
.map(Result::unwrap)
.filter_map(|msg| {
if let Message::Text(_) | Message::Binary(_) = msg {
std::future::ready(Some(msg.into_data()))
} else {
std::future::ready(None)
}
});
let mut data = Data::default();
data.insert(session_info);
let mut stream = async_graphql::http::WebSocket::new(schema.clone(), stream, protocol)
.connection_data(data)
.map(|msg| match msg {
WsMessage::Text(text) => Message::Text(text),
WsMessage::Close(code, status) => Message::Close(Some(CloseFrame {
code,
reason: Cow::from(status),
})),
});
while let Some(item) = stream.next().await {
let _res = sink.send(item).await;
}
})
}
pub async fn playground() -> impl IntoResponse {
Html(playground_source(
GraphQLPlaygroundConfig::new("/graphql")
.subscription_endpoint("/graphql/ws")
.with_setting("request.credentials", "include"),
))
}
pub struct Query {
database: PgPool,
}
impl Query {
fn new(pool: &PgPool) -> Self {
Self {
database: pool.clone(),
}
}
}
#[async_graphql::Object]
impl Query {
async fn username(&self, ctx: &Context<'_>) -> Result<Option<String>, async_graphql::Error> {
let mut conn = self.database.acquire().await?;
let session_info = ctx.data::<SessionInfo>()?;
let session = session_info.load_session(&mut conn).await?;
Ok(session.map(|s| s.user.username))
}
}
pub struct Subscription;
#[async_graphql::Subscription]
impl Subscription {
async fn integers(&self, #[graphql(default = 1)] step: i32) -> impl Stream<Item = i32> {
let mut value = 0;
tokio_stream::wrappers::IntervalStream::new(tokio::time::interval(Duration::from_secs(1)))
.map(move |_| {
value += step;
value
})
}
}

View File

@ -28,7 +28,7 @@ use std::{convert::Infallible, sync::Arc, time::Duration};
use anyhow::Context; use anyhow::Context;
use axum::{ use axum::{
body::HttpBody, body::{Bytes, HttpBody},
extract::FromRef, extract::FromRef,
response::{Html, IntoResponse}, response::{Html, IntoResponse},
routing::{get, on, post, MethodFilter}, routing::{get, on, post, MethodFilter},
@ -49,13 +49,17 @@ use tower_http::cors::{Any, CorsLayer};
mod app_state; mod app_state;
mod compat; mod compat;
mod graphql;
mod health; mod health;
mod oauth2; mod oauth2;
mod views; mod views;
pub use compat::MatrixHomeserver; pub use compat::MatrixHomeserver;
pub use self::app_state::AppState; pub use self::{
app_state::AppState,
graphql::{schema as graphql_schema, Schema as GraphQLSchema},
};
#[must_use] #[must_use]
pub fn empty_router<S, B>(state: Arc<S>) -> Router<S, B> pub fn empty_router<S, B>(state: Arc<S>) -> Router<S, B>
@ -76,6 +80,30 @@ where
Router::with_state_arc(state).route(mas_router::Healthcheck::route(), get(self::health::get)) Router::with_state_arc(state).route(mas_router::Healthcheck::route(), get(self::health::get))
} }
#[must_use]
pub fn graphql_router<S, B>(state: Arc<S>, playground: bool) -> Router<S, B>
where
B: HttpBody + Send + 'static,
<B as HttpBody>::Data: Into<Bytes>,
<B as HttpBody>::Error: std::error::Error + Send + Sync,
S: Send + Sync + 'static,
GraphQLSchema: FromRef<S>,
Encrypter: FromRef<S>,
{
let mut router = Router::with_state_arc(state)
.route(
"/graphql",
get(self::graphql::get).post(self::graphql::post),
)
.route("/graphql/ws", get(self::graphql::ws));
if playground {
router = router.route("/graphql/playground", get(self::graphql::playground));
}
router
}
#[must_use] #[must_use]
pub fn discovery_router<S, B>(state: Arc<S>) -> Router<S, B> pub fn discovery_router<S, B>(state: Arc<S>) -> Router<S, B>
where where
@ -305,7 +333,7 @@ where
pub fn router<S, B>(state: Arc<S>) -> Router<S, B> pub fn router<S, B>(state: Arc<S>) -> Router<S, B>
where where
B: HttpBody + Send + 'static, B: HttpBody + Send + 'static,
<B as HttpBody>::Data: Send, <B as HttpBody>::Data: Into<Bytes> + Send,
<B as HttpBody>::Error: std::error::Error + Send + Sync, <B as HttpBody>::Error: std::error::Error + Send + Sync,
S: Send + Sync + 'static, S: Send + Sync + 'static,
Keystore: FromRef<S>, Keystore: FromRef<S>,
@ -316,10 +344,12 @@ where
Templates: FromRef<S>, Templates: FromRef<S>,
Mailer: FromRef<S>, Mailer: FromRef<S>,
MatrixHomeserver: FromRef<S>, MatrixHomeserver: FromRef<S>,
GraphQLSchema: FromRef<S>,
{ {
let healthcheck_router = healthcheck_router(state.clone()); let healthcheck_router = healthcheck_router(state.clone());
let discovery_router = discovery_router(state.clone()); let discovery_router = discovery_router(state.clone());
let api_router = api_router(state.clone()); let api_router = api_router(state.clone());
let graphql_router = graphql_router(state.clone(), true);
let compat_router = compat_router(state.clone()); let compat_router = compat_router(state.clone());
let human_router = human_router(state.clone()); let human_router = human_router(state.clone());
@ -328,6 +358,7 @@ where
.merge(discovery_router) .merge(discovery_router)
.merge(human_router) .merge(human_router)
.merge(api_router) .merge(api_router)
.merge(graphql_router)
.merge(compat_router) .merge(compat_router)
} }
@ -352,6 +383,8 @@ async fn test_state(pool: PgPool) -> Result<Arc<AppState>, anyhow::Error> {
let policy_factory = PolicyFactory::load_default(serde_json::json!({})).await?; let policy_factory = PolicyFactory::load_default(serde_json::json!({})).await?;
let policy_factory = Arc::new(policy_factory); let policy_factory = Arc::new(policy_factory);
let graphql_schema = graphql_schema(&pool);
Ok(Arc::new(AppState { Ok(Arc::new(AppState {
pool, pool,
templates, templates,
@ -361,6 +394,7 @@ async fn test_state(pool: PgPool) -> Result<Arc<AppState>, anyhow::Error> {
mailer, mailer,
homeserver, homeserver,
policy_factory, policy_factory,
graphql_schema,
})) }))
} }

View File

@ -171,12 +171,14 @@ where
hyper::server::conn::Http::new() hyper::server::conn::Http::new()
.http2_only(true) .http2_only(true)
.serve_connection(stream, service) .serve_connection(stream, service)
.with_upgrades()
.await?; .await?;
} else { } else {
hyper::server::conn::Http::new() hyper::server::conn::Http::new()
.http1_only(true) .http1_only(true)
.http1_keep_alive(false) .http1_keep_alive(false)
.serve_connection(stream, service) .serve_connection(stream, service)
.with_upgrades()
.await?; .await?;
}; };