You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-31 09:24:31 +03:00
Initial GraphQL API
This commit is contained in:
201
Cargo.lock
generated
201
Cargo.lock
generated
@ -2,6 +2,16 @@
|
||||
# It is not intended for manual editing.
|
||||
version = 3
|
||||
|
||||
[[package]]
|
||||
name = "Inflector"
|
||||
version = "0.11.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fe438c63458706e03479442743baae6c88256498e6431708f6dfc520a26515d3"
|
||||
dependencies = [
|
||||
"lazy_static",
|
||||
"regex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "addr2line"
|
||||
version = "0.17.0"
|
||||
@ -119,6 +129,12 @@ version = "0.7.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8da52d66c7071e2e3fa2a1e5c6d088fec47b593032b254f5e980de8ea54454d6"
|
||||
|
||||
[[package]]
|
||||
name = "ascii_utils"
|
||||
version = "0.9.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "71938f30533e4d95a6d17aa530939da3842c2ab6f4f84b9dae68447e4129f74a"
|
||||
|
||||
[[package]]
|
||||
name = "assert_matches"
|
||||
version = "1.5.0"
|
||||
@ -139,6 +155,81 @@ dependencies = [
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-graphql"
|
||||
version = "4.0.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d9ed522678d412d77effe47b3c82314ac36952a35e6e852093dd48287c421f80"
|
||||
dependencies = [
|
||||
"async-graphql-derive",
|
||||
"async-graphql-parser",
|
||||
"async-graphql-value",
|
||||
"async-stream",
|
||||
"async-trait",
|
||||
"base64",
|
||||
"bytes 1.2.1",
|
||||
"chrono",
|
||||
"fast_chemail",
|
||||
"fnv",
|
||||
"futures-util",
|
||||
"http",
|
||||
"indexmap",
|
||||
"mime",
|
||||
"multer",
|
||||
"num-traits",
|
||||
"once_cell",
|
||||
"pin-project-lite",
|
||||
"regex",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_urlencoded",
|
||||
"static_assertions",
|
||||
"tempfile",
|
||||
"thiserror",
|
||||
"tracing",
|
||||
"tracing-futures",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-graphql-derive"
|
||||
version = "4.0.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c121a894495d7d3fc3d4e15e0a9843e422e4d1d9e3c514d8062a1c94b35b005d"
|
||||
dependencies = [
|
||||
"Inflector",
|
||||
"async-graphql-parser",
|
||||
"darling",
|
||||
"proc-macro-crate",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-graphql-parser"
|
||||
version = "4.0.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6b6c386f398145c6180206c1869c2279f5a3d45db5be4e0266148c6ac5c6ad68"
|
||||
dependencies = [
|
||||
"async-graphql-value",
|
||||
"pest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-graphql-value"
|
||||
version = "4.0.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7a941b499fead4a3fb5392cabf42446566d18c86313f69f2deab69560394d65f"
|
||||
dependencies = [
|
||||
"bytes 1.2.1",
|
||||
"indexmap",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-stream"
|
||||
version = "0.3.3"
|
||||
@ -524,6 +615,7 @@ checksum = "d2628a243073c55aef15a1c1fe45c87f21b84f9e89ca9e7b262a180d3d03543d"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"axum-core 0.3.0-rc.2",
|
||||
"base64",
|
||||
"bitflags",
|
||||
"bytes 1.2.1",
|
||||
"futures-util",
|
||||
@ -540,8 +632,10 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_urlencoded",
|
||||
"sha-1",
|
||||
"sync_wrapper",
|
||||
"tokio",
|
||||
"tokio-tungstenite",
|
||||
"tower",
|
||||
"tower-http",
|
||||
"tower-layer",
|
||||
@ -1498,6 +1592,15 @@ version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7"
|
||||
|
||||
[[package]]
|
||||
name = "fast_chemail"
|
||||
version = "0.9.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "495a39d30d624c2caabe6312bfead73e7717692b44e0b32df168c275a2e8e9e4"
|
||||
dependencies = [
|
||||
"ascii_utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fastrand"
|
||||
version = "1.8.0"
|
||||
@ -2231,7 +2334,7 @@ version = "1.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
|
||||
dependencies = [
|
||||
"spin",
|
||||
"spin 0.5.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -2495,10 +2598,12 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"argon2",
|
||||
"async-graphql",
|
||||
"axum 0.6.0-rc.2",
|
||||
"axum-extra",
|
||||
"axum-macros",
|
||||
"chrono",
|
||||
"futures-util",
|
||||
"headers",
|
||||
"hyper",
|
||||
"indoc",
|
||||
@ -2525,6 +2630,7 @@ dependencies = [
|
||||
"sqlx",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tower",
|
||||
"tower-http",
|
||||
"tracing",
|
||||
@ -2877,6 +2983,24 @@ dependencies = [
|
||||
"windows-sys 0.42.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "multer"
|
||||
version = "2.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6ed4198ce7a4cbd2a57af78d28c6fbb57d81ac5f1d6ad79ac6c5587419cbdf22"
|
||||
dependencies = [
|
||||
"bytes 1.2.1",
|
||||
"encoding_rs",
|
||||
"futures-util",
|
||||
"http",
|
||||
"httparse",
|
||||
"log",
|
||||
"memchr",
|
||||
"mime",
|
||||
"spin 0.9.4",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "multimap"
|
||||
version = "0.8.3"
|
||||
@ -3578,6 +3702,17 @@ dependencies = [
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro-crate"
|
||||
version = "1.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "eda0fc3b0fb7c975631757e14d9049da17374063edb6ebbcbc54d880d4fe94e9"
|
||||
dependencies = [
|
||||
"once_cell",
|
||||
"thiserror",
|
||||
"toml",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro-error"
|
||||
version = "1.0.4"
|
||||
@ -3878,7 +4013,7 @@ dependencies = [
|
||||
"cc",
|
||||
"libc",
|
||||
"once_cell",
|
||||
"spin",
|
||||
"spin 0.5.2",
|
||||
"untrusted",
|
||||
"web-sys",
|
||||
"winapi",
|
||||
@ -4246,6 +4381,17 @@ dependencies = [
|
||||
"unsafe-libyaml",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha-1"
|
||||
version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "028f48d513f9678cda28f6e4064755b3fbb2af6acd672f2c209b62323f7aea0f"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cpufeatures",
|
||||
"digest 0.10.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha1"
|
||||
version = "0.10.5"
|
||||
@ -4367,6 +4513,12 @@ version = "0.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"
|
||||
|
||||
[[package]]
|
||||
name = "spin"
|
||||
version = "0.9.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7f6002a767bff9e83f8eeecf883ecb8011875a21ae8da43bffb817a57e78cc09"
|
||||
|
||||
[[package]]
|
||||
name = "spki"
|
||||
version = "0.6.0"
|
||||
@ -4498,6 +4650,12 @@ version = "1.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
|
||||
|
||||
[[package]]
|
||||
name = "static_assertions"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
|
||||
|
||||
[[package]]
|
||||
name = "stringprep"
|
||||
version = "0.1.2"
|
||||
@ -4803,6 +4961,18 @@ dependencies = [
|
||||
"tokio-stream",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-tungstenite"
|
||||
version = "0.17.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f714dd15bead90401d77e04243611caec13726c2408afd5b31901dfcdcb3b181"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"log",
|
||||
"tokio",
|
||||
"tungstenite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-util"
|
||||
version = "0.6.10"
|
||||
@ -4998,6 +5168,8 @@ version = "0.2.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "97d095ae15e245a057c8e8451bab9b3ee1e1f68e9ba2b4fbc18d0ac5237835f2"
|
||||
dependencies = [
|
||||
"futures 0.3.25",
|
||||
"futures-task",
|
||||
"pin-project",
|
||||
"tracing",
|
||||
]
|
||||
@ -5060,6 +5232,25 @@ version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642"
|
||||
|
||||
[[package]]
|
||||
name = "tungstenite"
|
||||
version = "0.17.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e27992fd6a8c29ee7eef28fc78349aa244134e10ad447ce3b9f0ac0ed0fa4ce0"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"byteorder",
|
||||
"bytes 1.2.1",
|
||||
"http",
|
||||
"httparse",
|
||||
"log",
|
||||
"rand",
|
||||
"sha-1",
|
||||
"thiserror",
|
||||
"url",
|
||||
"utf-8",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typed-builder"
|
||||
version = "0.9.1"
|
||||
@ -5241,6 +5432,12 @@ version = "2.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e8db7427f936968176eaa7cdf81b7f98b980b18495ec28f1b5791ac3bfe3eea9"
|
||||
|
||||
[[package]]
|
||||
name = "utf-8"
|
||||
version = "0.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
|
||||
|
||||
[[package]]
|
||||
name = "uuid"
|
||||
version = "1.2.1"
|
||||
|
@ -25,7 +25,7 @@ use ulid::Ulid;
|
||||
use crate::CookieExt;
|
||||
|
||||
/// An encrypted cookie to save the session ID
|
||||
#[derive(Serialize, Deserialize, Debug, Default)]
|
||||
#[derive(Serialize, Deserialize, Debug, Default, Clone)]
|
||||
pub struct SessionInfo {
|
||||
current: Option<Ulid>,
|
||||
}
|
||||
|
@ -203,6 +203,8 @@ impl Options {
|
||||
.context("could not watch for templates changes")?;
|
||||
}
|
||||
|
||||
let graphql_schema = mas_handlers::graphql_schema(&pool);
|
||||
|
||||
let state = Arc::new(AppState {
|
||||
pool,
|
||||
templates,
|
||||
@ -212,6 +214,7 @@ impl Options {
|
||||
mailer,
|
||||
homeserver,
|
||||
policy_factory,
|
||||
graphql_schema,
|
||||
});
|
||||
|
||||
let mut fd_manager = listenfd::ListenFd::from_env();
|
||||
|
@ -31,7 +31,7 @@ use rustls::ServerConfig;
|
||||
pub fn build_router<B>(state: &Arc<AppState>, resources: &[HttpResource]) -> Router<AppState, B>
|
||||
where
|
||||
B: HttpBody + Send + 'static,
|
||||
<B as HttpBody>::Data: Send,
|
||||
<B as HttpBody>::Data: Into<axum::body::Bytes> + Send,
|
||||
<B as HttpBody>::Error: std::error::Error + Send + Sync,
|
||||
{
|
||||
let mut router = Router::with_state_arc(state.clone());
|
||||
@ -50,6 +50,9 @@ where
|
||||
mas_config::HttpResource::Human => {
|
||||
router.merge(mas_handlers::human_router(state.clone()))
|
||||
}
|
||||
mas_config::HttpResource::GraphQL { playground } => {
|
||||
router.merge(mas_handlers::graphql_router(state.clone(), *playground))
|
||||
}
|
||||
mas_config::HttpResource::Static { web_root } => {
|
||||
let handler = mas_static_files::service(web_root);
|
||||
router.nest(mas_router::StaticAsset::route(), handler)
|
||||
|
@ -234,6 +234,13 @@ pub enum Resource {
|
||||
/// Pages destined to be viewed by humans
|
||||
Human,
|
||||
|
||||
/// GraphQL endpoint
|
||||
GraphQL {
|
||||
/// Enabled the GraphQL playground
|
||||
#[serde(default)]
|
||||
playground: bool,
|
||||
},
|
||||
|
||||
/// OAuth-related APIs
|
||||
OAuth,
|
||||
|
||||
@ -300,6 +307,7 @@ impl Default for HttpConfig {
|
||||
Resource::Human,
|
||||
Resource::OAuth,
|
||||
Resource::Compat,
|
||||
Resource::GraphQL { playground: true },
|
||||
Resource::Static { web_root: None },
|
||||
],
|
||||
tls: None,
|
||||
|
@ -8,6 +8,8 @@ license = "Apache-2.0"
|
||||
[dependencies]
|
||||
# Async runtime
|
||||
tokio = { version = "1.21.2", features = ["macros"] }
|
||||
tokio-stream = "0.1.11"
|
||||
futures-util = "0.3.25"
|
||||
|
||||
# Logging and tracing
|
||||
tracing = "0.1.37"
|
||||
@ -20,10 +22,12 @@ anyhow = "1.0.66"
|
||||
hyper = { version = "0.14.22", features = ["full"] }
|
||||
tower = "0.4.13"
|
||||
tower-http = { version = "0.3.4", features = ["cors"] }
|
||||
axum = "0.6.0-rc.2"
|
||||
axum = { version = "0.6.0-rc.2", features = ["ws"] }
|
||||
axum-macros = "0.3.0-rc.1"
|
||||
axum-extra = { version = "0.4.0-rc.1", features = ["cookie-private"] }
|
||||
|
||||
async-graphql = { version = "4.0.16", features = ["tracing", "apollo_tracing"] }
|
||||
|
||||
# Emails
|
||||
lettre = { version = "0.10.1", default-features = false, features = ["builder"] }
|
||||
|
||||
|
@ -22,7 +22,7 @@ use mas_router::UrlBuilder;
|
||||
use mas_templates::Templates;
|
||||
use sqlx::PgPool;
|
||||
|
||||
use crate::MatrixHomeserver;
|
||||
use crate::{GraphQLSchema, MatrixHomeserver};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
@ -34,6 +34,7 @@ pub struct AppState {
|
||||
pub mailer: Mailer,
|
||||
pub homeserver: MatrixHomeserver,
|
||||
pub policy_factory: Arc<PolicyFactory>,
|
||||
pub graphql_schema: GraphQLSchema,
|
||||
}
|
||||
|
||||
impl FromRef<AppState> for PgPool {
|
||||
@ -42,6 +43,12 @@ impl FromRef<AppState> for PgPool {
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRef<AppState> for GraphQLSchema {
|
||||
fn from_ref(input: &AppState) -> Self {
|
||||
input.graphql_schema.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRef<AppState> for Templates {
|
||||
fn from_ref(input: &AppState) -> Self {
|
||||
input.templates.clone()
|
||||
|
251
crates/handlers/src/graphql/mod.rs
Normal file
251
crates/handlers/src/graphql/mod.rs
Normal file
@ -0,0 +1,251 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{borrow::Cow, str::FromStr, time::Duration};
|
||||
|
||||
use async_graphql::{
|
||||
extensions::{ApolloTracing, Tracing},
|
||||
futures_util::TryStreamExt,
|
||||
http::{
|
||||
playground_source, GraphQLPlaygroundConfig, MultipartOptions, WebSocketProtocols,
|
||||
WsMessage, ALL_WEBSOCKET_PROTOCOLS,
|
||||
},
|
||||
Context, Data, EmptyMutation,
|
||||
};
|
||||
use axum::{
|
||||
extract::{
|
||||
ws::{CloseFrame, Message},
|
||||
BodyStream, RawQuery, State, WebSocketUpgrade,
|
||||
},
|
||||
response::{Html, IntoResponse, Response},
|
||||
Json, TypedHeader,
|
||||
};
|
||||
use axum_extra::extract::PrivateCookieJar;
|
||||
use futures_util::{SinkExt, Stream, StreamExt};
|
||||
use headers::{ContentType, Header, HeaderValue};
|
||||
use hyper::header::{CACHE_CONTROL, SEC_WEBSOCKET_PROTOCOL};
|
||||
use mas_axum_utils::{FancyError, SessionInfo, SessionInfoExt};
|
||||
use mas_keystore::Encrypter;
|
||||
use sqlx::PgPool;
|
||||
use tracing::{info_span, Instrument};
|
||||
|
||||
pub type Schema = async_graphql::Schema<Query, EmptyMutation, Subscription>;
|
||||
|
||||
#[must_use]
|
||||
pub fn schema(pool: &PgPool) -> Schema {
|
||||
async_graphql::Schema::build(Query::new(pool), EmptyMutation, Subscription)
|
||||
.extension(Tracing)
|
||||
.extension(ApolloTracing)
|
||||
.finish()
|
||||
}
|
||||
|
||||
fn span_for_graphql_request(request: &async_graphql::Request) -> tracing::Span {
|
||||
let span = info_span!(
|
||||
"GraphQL operation",
|
||||
"otel.name" = tracing::field::Empty,
|
||||
"otel.kind" = "server",
|
||||
"graphql.document" = request.query,
|
||||
"graphql.operation.name" = tracing::field::Empty,
|
||||
);
|
||||
|
||||
if let Some(name) = &request.operation_name {
|
||||
span.record("otel.name", name);
|
||||
span.record("graphql.operation.name", name);
|
||||
}
|
||||
|
||||
span
|
||||
}
|
||||
|
||||
pub async fn post(
|
||||
State(schema): State<Schema>,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
content_type: Option<TypedHeader<ContentType>>,
|
||||
body: BodyStream,
|
||||
) -> Result<impl IntoResponse, FancyError> {
|
||||
let content_type = content_type.map(|TypedHeader(h)| h.to_string());
|
||||
|
||||
let (session_info, _cookie_jar) = cookie_jar.session_info();
|
||||
|
||||
let request = async_graphql::http::receive_batch_body(
|
||||
content_type,
|
||||
body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
|
||||
.into_async_read(),
|
||||
MultipartOptions::default(),
|
||||
)
|
||||
.await? // XXX: this should probably return another error response?
|
||||
.data(session_info);
|
||||
|
||||
let response = match request {
|
||||
async_graphql::BatchRequest::Single(request) => {
|
||||
let span = span_for_graphql_request(&request);
|
||||
let response = schema.execute(request).instrument(span).await;
|
||||
async_graphql::BatchResponse::Single(response)
|
||||
}
|
||||
async_graphql::BatchRequest::Batch(requests) => async_graphql::BatchResponse::Batch(
|
||||
futures_util::stream::iter(requests.into_iter())
|
||||
.then(|request| {
|
||||
let span = span_for_graphql_request(&request);
|
||||
schema.execute(request).instrument(span)
|
||||
})
|
||||
.collect()
|
||||
.await,
|
||||
),
|
||||
};
|
||||
|
||||
let cache_control = response
|
||||
.cache_control()
|
||||
.value()
|
||||
.and_then(|v| HeaderValue::from_str(&v).ok())
|
||||
.map(|h| [(CACHE_CONTROL, h)]);
|
||||
|
||||
let headers = response.http_headers();
|
||||
|
||||
Ok((headers, cache_control, Json(response)))
|
||||
}
|
||||
|
||||
pub async fn get(
|
||||
State(schema): State<Schema>,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
RawQuery(query): RawQuery,
|
||||
) -> Result<impl IntoResponse, FancyError> {
|
||||
let (session_info, _cookie_jar) = cookie_jar.session_info();
|
||||
let request =
|
||||
async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(session_info);
|
||||
|
||||
let span = span_for_graphql_request(&request);
|
||||
let response = schema.execute(request).instrument(span).await;
|
||||
|
||||
let cache_control = response
|
||||
.cache_control
|
||||
.value()
|
||||
.and_then(|v| HeaderValue::from_str(&v).ok())
|
||||
.map(|h| [(CACHE_CONTROL, h)]);
|
||||
|
||||
let headers = response.http_headers.clone();
|
||||
|
||||
Ok((headers, cache_control, Json(response)))
|
||||
}
|
||||
|
||||
pub struct SecWebsocketProtocol(WebSocketProtocols);
|
||||
|
||||
impl Header for SecWebsocketProtocol {
|
||||
fn name() -> &'static headers::HeaderName {
|
||||
&SEC_WEBSOCKET_PROTOCOL
|
||||
}
|
||||
|
||||
fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
|
||||
where
|
||||
Self: Sized,
|
||||
I: Iterator<Item = &'i HeaderValue>,
|
||||
{
|
||||
values
|
||||
.filter_map(|value| value.to_str().ok())
|
||||
.flat_map(|value| value.split(','))
|
||||
.find_map(|p| WebSocketProtocols::from_str(p.trim()).ok())
|
||||
.map(Self)
|
||||
.ok_or_else(headers::Error::invalid)
|
||||
}
|
||||
|
||||
fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
|
||||
if let Ok(v) = HeaderValue::from_str(self.0.sec_websocket_protocol()) {
|
||||
values.extend(std::iter::once(v));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn ws(
|
||||
State(schema): State<Schema>,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
TypedHeader(SecWebsocketProtocol(protocol)): TypedHeader<SecWebsocketProtocol>,
|
||||
websocket: WebSocketUpgrade,
|
||||
) -> Response {
|
||||
let (session_info, _cookie_jar) = cookie_jar.session_info();
|
||||
websocket
|
||||
.protocols(ALL_WEBSOCKET_PROTOCOLS)
|
||||
.on_upgrade(move |ws| async move {
|
||||
let (mut sink, stream) = ws.split();
|
||||
let stream = stream
|
||||
.take_while(|res| std::future::ready(res.is_ok()))
|
||||
.map(Result::unwrap)
|
||||
.filter_map(|msg| {
|
||||
if let Message::Text(_) | Message::Binary(_) = msg {
|
||||
std::future::ready(Some(msg.into_data()))
|
||||
} else {
|
||||
std::future::ready(None)
|
||||
}
|
||||
});
|
||||
|
||||
let mut data = Data::default();
|
||||
data.insert(session_info);
|
||||
|
||||
let mut stream = async_graphql::http::WebSocket::new(schema.clone(), stream, protocol)
|
||||
.connection_data(data)
|
||||
.map(|msg| match msg {
|
||||
WsMessage::Text(text) => Message::Text(text),
|
||||
WsMessage::Close(code, status) => Message::Close(Some(CloseFrame {
|
||||
code,
|
||||
reason: Cow::from(status),
|
||||
})),
|
||||
});
|
||||
|
||||
while let Some(item) = stream.next().await {
|
||||
let _res = sink.send(item).await;
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn playground() -> impl IntoResponse {
|
||||
Html(playground_source(
|
||||
GraphQLPlaygroundConfig::new("/graphql")
|
||||
.subscription_endpoint("/graphql/ws")
|
||||
.with_setting("request.credentials", "include"),
|
||||
))
|
||||
}
|
||||
|
||||
pub struct Query {
|
||||
database: PgPool,
|
||||
}
|
||||
|
||||
impl Query {
|
||||
fn new(pool: &PgPool) -> Self {
|
||||
Self {
|
||||
database: pool.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_graphql::Object]
|
||||
impl Query {
|
||||
async fn username(&self, ctx: &Context<'_>) -> Result<Option<String>, async_graphql::Error> {
|
||||
let mut conn = self.database.acquire().await?;
|
||||
let session_info = ctx.data::<SessionInfo>()?;
|
||||
let session = session_info.load_session(&mut conn).await?;
|
||||
|
||||
Ok(session.map(|s| s.user.username))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Subscription;
|
||||
|
||||
#[async_graphql::Subscription]
|
||||
impl Subscription {
|
||||
async fn integers(&self, #[graphql(default = 1)] step: i32) -> impl Stream<Item = i32> {
|
||||
let mut value = 0;
|
||||
tokio_stream::wrappers::IntervalStream::new(tokio::time::interval(Duration::from_secs(1)))
|
||||
.map(move |_| {
|
||||
value += step;
|
||||
value
|
||||
})
|
||||
}
|
||||
}
|
@ -28,7 +28,7 @@ use std::{convert::Infallible, sync::Arc, time::Duration};
|
||||
|
||||
use anyhow::Context;
|
||||
use axum::{
|
||||
body::HttpBody,
|
||||
body::{Bytes, HttpBody},
|
||||
extract::FromRef,
|
||||
response::{Html, IntoResponse},
|
||||
routing::{get, on, post, MethodFilter},
|
||||
@ -49,13 +49,17 @@ use tower_http::cors::{Any, CorsLayer};
|
||||
|
||||
mod app_state;
|
||||
mod compat;
|
||||
mod graphql;
|
||||
mod health;
|
||||
mod oauth2;
|
||||
mod views;
|
||||
|
||||
pub use compat::MatrixHomeserver;
|
||||
|
||||
pub use self::app_state::AppState;
|
||||
pub use self::{
|
||||
app_state::AppState,
|
||||
graphql::{schema as graphql_schema, Schema as GraphQLSchema},
|
||||
};
|
||||
|
||||
#[must_use]
|
||||
pub fn empty_router<S, B>(state: Arc<S>) -> Router<S, B>
|
||||
@ -76,6 +80,30 @@ where
|
||||
Router::with_state_arc(state).route(mas_router::Healthcheck::route(), get(self::health::get))
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn graphql_router<S, B>(state: Arc<S>, playground: bool) -> Router<S, B>
|
||||
where
|
||||
B: HttpBody + Send + 'static,
|
||||
<B as HttpBody>::Data: Into<Bytes>,
|
||||
<B as HttpBody>::Error: std::error::Error + Send + Sync,
|
||||
S: Send + Sync + 'static,
|
||||
GraphQLSchema: FromRef<S>,
|
||||
Encrypter: FromRef<S>,
|
||||
{
|
||||
let mut router = Router::with_state_arc(state)
|
||||
.route(
|
||||
"/graphql",
|
||||
get(self::graphql::get).post(self::graphql::post),
|
||||
)
|
||||
.route("/graphql/ws", get(self::graphql::ws));
|
||||
|
||||
if playground {
|
||||
router = router.route("/graphql/playground", get(self::graphql::playground));
|
||||
}
|
||||
|
||||
router
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn discovery_router<S, B>(state: Arc<S>) -> Router<S, B>
|
||||
where
|
||||
@ -305,7 +333,7 @@ where
|
||||
pub fn router<S, B>(state: Arc<S>) -> Router<S, B>
|
||||
where
|
||||
B: HttpBody + Send + 'static,
|
||||
<B as HttpBody>::Data: Send,
|
||||
<B as HttpBody>::Data: Into<Bytes> + Send,
|
||||
<B as HttpBody>::Error: std::error::Error + Send + Sync,
|
||||
S: Send + Sync + 'static,
|
||||
Keystore: FromRef<S>,
|
||||
@ -316,10 +344,12 @@ where
|
||||
Templates: FromRef<S>,
|
||||
Mailer: FromRef<S>,
|
||||
MatrixHomeserver: FromRef<S>,
|
||||
GraphQLSchema: FromRef<S>,
|
||||
{
|
||||
let healthcheck_router = healthcheck_router(state.clone());
|
||||
let discovery_router = discovery_router(state.clone());
|
||||
let api_router = api_router(state.clone());
|
||||
let graphql_router = graphql_router(state.clone(), true);
|
||||
let compat_router = compat_router(state.clone());
|
||||
let human_router = human_router(state.clone());
|
||||
|
||||
@ -328,6 +358,7 @@ where
|
||||
.merge(discovery_router)
|
||||
.merge(human_router)
|
||||
.merge(api_router)
|
||||
.merge(graphql_router)
|
||||
.merge(compat_router)
|
||||
}
|
||||
|
||||
@ -352,6 +383,8 @@ async fn test_state(pool: PgPool) -> Result<Arc<AppState>, anyhow::Error> {
|
||||
let policy_factory = PolicyFactory::load_default(serde_json::json!({})).await?;
|
||||
let policy_factory = Arc::new(policy_factory);
|
||||
|
||||
let graphql_schema = graphql_schema(&pool);
|
||||
|
||||
Ok(Arc::new(AppState {
|
||||
pool,
|
||||
templates,
|
||||
@ -361,6 +394,7 @@ async fn test_state(pool: PgPool) -> Result<Arc<AppState>, anyhow::Error> {
|
||||
mailer,
|
||||
homeserver,
|
||||
policy_factory,
|
||||
graphql_schema,
|
||||
}))
|
||||
}
|
||||
|
||||
|
@ -171,12 +171,14 @@ where
|
||||
hyper::server::conn::Http::new()
|
||||
.http2_only(true)
|
||||
.serve_connection(stream, service)
|
||||
.with_upgrades()
|
||||
.await?;
|
||||
} else {
|
||||
hyper::server::conn::Http::new()
|
||||
.http1_only(true)
|
||||
.http1_keep_alive(false)
|
||||
.serve_connection(stream, service)
|
||||
.with_upgrades()
|
||||
.await?;
|
||||
};
|
||||
|
||||
|
Reference in New Issue
Block a user