1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-06 06:02:40 +03:00

Disallow OAuth 2.0 use of the GraphQL API by default

This commit is contained in:
Quentin Gliech
2024-08-07 17:36:54 +02:00
parent eb4072f3c3
commit 1bdad262cd
6 changed files with 70 additions and 14 deletions

View File

@@ -201,9 +201,13 @@ pub fn build_router(
mas_config::HttpResource::Human => { mas_config::HttpResource::Human => {
router.merge(mas_handlers::human_router::<AppState>(templates.clone())) router.merge(mas_handlers::human_router::<AppState>(templates.clone()))
} }
mas_config::HttpResource::GraphQL { playground } => { mas_config::HttpResource::GraphQL {
router.merge(mas_handlers::graphql_router::<AppState>(*playground)) playground,
} undocumented_oauth2_access,
} => router.merge(mas_handlers::graphql_router::<AppState>(
*playground,
*undocumented_oauth2_access,
)),
mas_config::HttpResource::Assets { path } => { mas_config::HttpResource::Assets { path } => {
let static_service = ServeDir::new(path) let static_service = ServeDir::new(path)
.append_index_html_on_directories(false) .append_index_html_on_directories(false)

View File

@@ -291,8 +291,12 @@ pub enum Resource {
/// GraphQL endpoint /// GraphQL endpoint
GraphQL { GraphQL {
/// Enabled the GraphQL playground /// Enabled the GraphQL playground
#[serde(default)] #[serde(default, skip_serializing_if = "std::ops::Not::not")]
playground: bool, playground: bool,
/// Allow access for OAuth 2.0 clients (undocumented)
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
undocumented_oauth2_access: bool,
}, },
/// OAuth-related APIs /// OAuth-related APIs
@@ -379,7 +383,10 @@ impl Default for HttpConfig {
Resource::Human, Resource::Human,
Resource::OAuth, Resource::OAuth,
Resource::Compat, Resource::Compat,
Resource::GraphQL { playground: true }, Resource::GraphQL {
playground: false,
undocumented_oauth2_access: false,
},
Resource::Assets { Resource::Assets {
path: http_listener_assets_path_default(), path: http_listener_assets_path_default(),
}, },

View File

@@ -27,7 +27,7 @@ use axum::{
extract::{RawQuery, State as AxumState}, extract::{RawQuery, State as AxumState},
http::StatusCode, http::StatusCode,
response::{Html, IntoResponse, Response}, response::{Html, IntoResponse, Response},
Json, Extension, Json,
}; };
use axum_extra::typed_header::TypedHeader; use axum_extra::typed_header::TypedHeader;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
@@ -65,6 +65,13 @@ use crate::{impl_from_error_for_route, passwords::PasswordManager, BoundActivity
#[cfg(test)] #[cfg(test)]
mod tests; mod tests;
/// Extra parameters we get from the listener configuration, because they are
/// per-listener options. We pass them through request extensions.
#[derive(Debug, Clone)]
pub struct ExtraRouterParameters {
pub undocumented_oauth2_access: bool,
}
struct GraphQLState { struct GraphQLState {
pool: PgPool, pool: PgPool,
homeserver_connection: Arc<dyn HomeserverConnection<Error = anyhow::Error>>, homeserver_connection: Arc<dyn HomeserverConnection<Error = anyhow::Error>>,
@@ -217,6 +224,7 @@ impl IntoResponse for RouteError {
} }
async fn get_requester( async fn get_requester(
undocumented_oauth2_access: bool,
clock: &impl Clock, clock: &impl Clock,
activity_tracker: &BoundActivityTracker, activity_tracker: &BoundActivityTracker,
mut repo: BoxRepository, mut repo: BoxRepository,
@@ -224,6 +232,11 @@ async fn get_requester(
token: Option<&str>, token: Option<&str>,
) -> Result<Requester, RouteError> { ) -> Result<Requester, RouteError> {
let requester = if let Some(token) = token { let requester = if let Some(token) = token {
// If we haven't enabled undocumented_oauth2_access on the listener, we bail out
if !undocumented_oauth2_access {
return Err(RouteError::InvalidToken);
}
let token = repo let token = repo
.oauth2_access_token() .oauth2_access_token()
.find_by_token(token) .find_by_token(token)
@@ -281,6 +294,9 @@ async fn get_requester(
pub async fn post( pub async fn post(
AxumState(schema): AxumState<Schema>, AxumState(schema): AxumState<Schema>,
Extension(ExtraRouterParameters {
undocumented_oauth2_access,
}): Extension<ExtraRouterParameters>,
clock: BoxClock, clock: BoxClock,
repo: BoxRepository, repo: BoxRepository,
activity_tracker: BoundActivityTracker, activity_tracker: BoundActivityTracker,
@@ -294,7 +310,15 @@ pub async fn post(
.as_ref() .as_ref()
.map(|TypedHeader(Authorization(bearer))| bearer.token()); .map(|TypedHeader(Authorization(bearer))| bearer.token());
let (session_info, _cookie_jar) = cookie_jar.session_info(); let (session_info, _cookie_jar) = cookie_jar.session_info();
let requester = get_requester(&clock, &activity_tracker, repo, session_info, token).await?; let requester = get_requester(
undocumented_oauth2_access,
&clock,
&activity_tracker,
repo,
session_info,
token,
)
.await?;
let content_type = content_type.map(|TypedHeader(h)| h.to_string()); let content_type = content_type.map(|TypedHeader(h)| h.to_string());
@@ -323,6 +347,9 @@ pub async fn post(
pub async fn get( pub async fn get(
AxumState(schema): AxumState<Schema>, AxumState(schema): AxumState<Schema>,
Extension(ExtraRouterParameters {
undocumented_oauth2_access,
}): Extension<ExtraRouterParameters>,
clock: BoxClock, clock: BoxClock,
repo: BoxRepository, repo: BoxRepository,
activity_tracker: BoundActivityTracker, activity_tracker: BoundActivityTracker,
@@ -334,7 +361,15 @@ pub async fn get(
.as_ref() .as_ref()
.map(|TypedHeader(Authorization(bearer))| bearer.token()); .map(|TypedHeader(Authorization(bearer))| bearer.token());
let (session_info, _cookie_jar) = cookie_jar.session_info(); let (session_info, _cookie_jar) = cookie_jar.session_info();
let requester = get_requester(&clock, &activity_tracker, repo, session_info, token).await?; let requester = get_requester(
undocumented_oauth2_access,
&clock,
&activity_tracker,
repo,
session_info,
token,
)
.await?;
let request = let request =
async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(requester); async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(requester);

View File

@@ -30,8 +30,9 @@ use axum::{
http::Method, http::Method,
response::{Html, IntoResponse}, response::{Html, IntoResponse},
routing::{get, post}, routing::{get, post},
Router, Extension, Router,
}; };
use graphql::ExtraRouterParameters;
use headers::HeaderName; use headers::HeaderName;
use hyper::{ use hyper::{
header::{ header::{
@@ -108,7 +109,7 @@ where
Router::new().route(mas_router::Healthcheck::route(), get(self::health::get)) Router::new().route(mas_router::Healthcheck::route(), get(self::health::get))
} }
pub fn graphql_router<S>(playground: bool) -> Router<S> pub fn graphql_router<S>(playground: bool, undocumented_oauth2_access: bool) -> Router<S>
where where
S: Clone + Send + Sync + 'static, S: Clone + Send + Sync + 'static,
graphql::Schema: FromRef<S>, graphql::Schema: FromRef<S>,
@@ -123,6 +124,11 @@ where
mas_router::GraphQL::route(), mas_router::GraphQL::route(),
get(self::graphql::get).post(self::graphql::post), get(self::graphql::get).post(self::graphql::post),
) )
// Pass the undocumented_oauth2_access parameter through the request extension, as it is
// per-listener
.layer(Extension(ExtraRouterParameters {
undocumented_oauth2_access,
}))
.layer( .layer(
CorsLayer::new() CorsLayer::new()
.allow_origin(Any) .allow_origin(Any)

View File

@@ -249,7 +249,9 @@ impl TestState {
.merge(crate::api_router()) .merge(crate::api_router())
.merge(crate::compat_router()) .merge(crate::compat_router())
.merge(crate::human_router(self.templates.clone())) .merge(crate::human_router(self.templates.clone()))
.merge(crate::graphql_router(false)) // We enable undocumented_oauth2_access for the tests, as it is easier to query the API
// with it
.merge(crate::graphql_router(false, true))
.merge(crate::admin_api_router().1) .merge(crate::admin_api_router().1)
.with_state(self.clone()) .with_state(self.clone())
.into_service(); .into_service();

View File

@@ -35,8 +35,7 @@
"name": "compat" "name": "compat"
}, },
{ {
"name": "graphql", "name": "graphql"
"playground": true
}, },
{ {
"name": "assets" "name": "assets"
@@ -742,7 +741,10 @@
}, },
"playground": { "playground": {
"description": "Enabled the GraphQL playground", "description": "Enabled the GraphQL playground",
"default": false, "type": "boolean"
},
"undocumented_oauth2_access": {
"description": "Allow access for OAuth 2.0 clients (undocumented)",
"type": "boolean" "type": "boolean"
} }
} }