You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
Better CORS filter to allow OTEL propagator headers
This commit is contained in:
2
Cargo.lock
generated
2
Cargo.lock
generated
@ -1464,6 +1464,8 @@ dependencies = [
|
|||||||
"mas-data-model",
|
"mas-data-model",
|
||||||
"mime",
|
"mime",
|
||||||
"oauth2-types",
|
"oauth2-types",
|
||||||
|
"once_cell",
|
||||||
|
"opentelemetry",
|
||||||
"password-hash",
|
"password-hash",
|
||||||
"pkcs8",
|
"pkcs8",
|
||||||
"rand 0.8.4",
|
"rand 0.8.4",
|
||||||
|
@ -30,7 +30,12 @@ use opentelemetry_semantic_conventions as semcov;
|
|||||||
|
|
||||||
pub fn setup(config: &TelemetryConfig) -> anyhow::Result<Option<Tracer>> {
|
pub fn setup(config: &TelemetryConfig) -> anyhow::Result<Option<Tracer>> {
|
||||||
global::set_error_handler(|e| tracing::error!("{}", e))?;
|
global::set_error_handler(|e| tracing::error!("{}", e))?;
|
||||||
global::set_text_map_propagator(propagator());
|
let propagator = propagator();
|
||||||
|
|
||||||
|
// The CORS filter needs to know what headers it should whitelist for
|
||||||
|
// CORS-protected requests.
|
||||||
|
mas_core::filters::cors::set_propagator(&propagator);
|
||||||
|
global::set_text_map_propagator(propagator);
|
||||||
|
|
||||||
let tracer = tracer(&config.tracing)?;
|
let tracer = tracer(&config.tracing)?;
|
||||||
meter(&config.metrics)?;
|
meter(&config.metrics)?;
|
||||||
|
@ -67,6 +67,8 @@ cookie = "0.15.1"
|
|||||||
oauth2-types = { path = "../oauth2-types", features = ["sqlx_type"] }
|
oauth2-types = { path = "../oauth2-types", features = ["sqlx_type"] }
|
||||||
mas-config = { path = "../config" }
|
mas-config = { path = "../config" }
|
||||||
mas-data-model = { path = "../data-model" }
|
mas-data-model = { path = "../data-model" }
|
||||||
|
opentelemetry = "0.16.0"
|
||||||
|
once_cell = "1.8.0"
|
||||||
|
|
||||||
[dependencies.jwt-compact]
|
[dependencies.jwt-compact]
|
||||||
# Waiting on the next release because of the bump of the `rsa` dependency
|
# Waiting on the next release because of the bump of the `rsa` dependency
|
||||||
|
37
crates/core/src/filters/cors.rs
Normal file
37
crates/core/src/filters/cors.rs
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
//! Wrapper around [`warp::filters::cors`]
|
||||||
|
|
||||||
|
use std::string::ToString;
|
||||||
|
|
||||||
|
use once_cell::sync::OnceCell;
|
||||||
|
|
||||||
|
static PROPAGATOR_HEADERS: OnceCell<Vec<String>> = OnceCell::new();
|
||||||
|
|
||||||
|
/// Notify the CORS filter what opentelemetry propagators are being used. This
|
||||||
|
/// helps whitelisting headers in CORS requests.
|
||||||
|
pub fn set_propagator(propagator: &dyn opentelemetry::propagation::TextMapPropagator) {
|
||||||
|
PROPAGATOR_HEADERS
|
||||||
|
.set(propagator.fields().map(ToString::to_string).collect())
|
||||||
|
.expect(concat!(module_path!(), "::set_propagator was called twice"));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a wrapping filter that exposes CORS behavior for a wrapped filter.
|
||||||
|
#[must_use]
|
||||||
|
pub fn cors() -> warp::filters::cors::Builder {
|
||||||
|
warp::filters::cors::cors()
|
||||||
|
.allow_any_origin()
|
||||||
|
.allow_headers(PROPAGATOR_HEADERS.get().unwrap_or(&Vec::new()))
|
||||||
|
}
|
@ -1,200 +0,0 @@
|
|||||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
use std::{cmp::Reverse, future::Future, pin::Pin};
|
|
||||||
|
|
||||||
use mime::{Mime, STAR};
|
|
||||||
use serde::Serialize;
|
|
||||||
use tera::Context;
|
|
||||||
use tide::{
|
|
||||||
http::headers::{ACCEPT, LOCATION},
|
|
||||||
Body, Request, StatusCode,
|
|
||||||
};
|
|
||||||
use tracing::debug;
|
|
||||||
|
|
||||||
use crate::{state::State, templates::common_context};
|
|
||||||
|
|
||||||
/// Get the weight parameter for a mime type from 0 to 1000
|
|
||||||
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
|
|
||||||
fn get_weight(mime: &Mime) -> usize {
|
|
||||||
let q = mime
|
|
||||||
.get_param("q")
|
|
||||||
.map_or(1.0_f64, |q| q.as_str().parse().unwrap_or(0.0))
|
|
||||||
.min(1.0)
|
|
||||||
.max(0.0);
|
|
||||||
|
|
||||||
// Weight have a 3 digit precision so we can multiply by 1000 and cast to
|
|
||||||
// int. Sign loss should not happen here because of the min/max up there and
|
|
||||||
// truncation does not matter here.
|
|
||||||
(q * 1000.0) as _
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Find what content type should be used for a given request
|
|
||||||
fn preferred_mime_type<'a>(
|
|
||||||
request: &Request<State>,
|
|
||||||
supported_types: &'a [Mime],
|
|
||||||
) -> Option<&'a Mime> {
|
|
||||||
let accept = request.header(ACCEPT)?;
|
|
||||||
// Parse the Accept header as a list of mime types with their associated
|
|
||||||
// weight
|
|
||||||
let accepted_types: Vec<(Mime, usize)> = {
|
|
||||||
let v: Option<Vec<_>> = accept
|
|
||||||
.into_iter()
|
|
||||||
.flat_map(|value| value.as_str().split(','))
|
|
||||||
.map(|mime| {
|
|
||||||
mime.trim().parse().ok().map(|mime| {
|
|
||||||
let q = get_weight(&mime);
|
|
||||||
(mime, q)
|
|
||||||
})
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
let mut v = v?;
|
|
||||||
v.sort_by_key(|(_, weight)| Reverse(*weight));
|
|
||||||
v
|
|
||||||
};
|
|
||||||
|
|
||||||
// For each supported content type, find out if it is accepted with what
|
|
||||||
// weight and specificity
|
|
||||||
let mut types: Vec<_> = supported_types
|
|
||||||
.iter()
|
|
||||||
.enumerate()
|
|
||||||
.filter_map(|(index, supported)| {
|
|
||||||
accepted_types.iter().find_map(|(accepted, weight)| {
|
|
||||||
if accepted.type_() == supported.type_()
|
|
||||||
&& accepted.subtype() == supported.subtype()
|
|
||||||
{
|
|
||||||
// Accept: text/html
|
|
||||||
Some((supported, *weight, 2_usize, index))
|
|
||||||
} else if accepted.type_() == supported.type_() && accepted.subtype() == STAR {
|
|
||||||
// Accept: text/*
|
|
||||||
Some((supported, *weight, 1, index))
|
|
||||||
} else if accepted.type_() == STAR && accepted.subtype() == STAR {
|
|
||||||
// Accept: */*
|
|
||||||
Some((supported, *weight, 0, index))
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
types.sort_by_key(|(_, weight, specificity, index)| {
|
|
||||||
(Reverse(*weight), Reverse(*specificity), *index)
|
|
||||||
});
|
|
||||||
|
|
||||||
types.first().map(|(mime, _, _, _)| *mime)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
struct ErrorContext {
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
code: Option<String>,
|
|
||||||
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
description: Option<String>,
|
|
||||||
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
details: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ErrorContext {
|
|
||||||
fn should_render(&self) -> bool {
|
|
||||||
self.code.is_some() || self.description.is_some() || self.details.is_some()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn middleware<'a>(
|
|
||||||
request: tide::Request<State>,
|
|
||||||
next: tide::Next<'a, State>,
|
|
||||||
) -> Pin<Box<dyn Future<Output = tide::Result> + Send + 'a>> {
|
|
||||||
Box::pin(async {
|
|
||||||
let content_type = preferred_mime_type(
|
|
||||||
&request,
|
|
||||||
&[mime::TEXT_PLAIN, mime::TEXT_HTML, mime::APPLICATION_JSON],
|
|
||||||
);
|
|
||||||
debug!("Content-Type from Accept: {:?}", content_type);
|
|
||||||
|
|
||||||
// TODO: We should not clone here
|
|
||||||
let templates = request.state().templates().clone();
|
|
||||||
|
|
||||||
// TODO: This context should probably be comptuted somewhere else
|
|
||||||
let pctx = common_context(&request).await?.clone();
|
|
||||||
|
|
||||||
let mut response = next.run(request).await;
|
|
||||||
|
|
||||||
// Find out what message should be displayed from the response status
|
|
||||||
// code
|
|
||||||
let (code, description) = match response.status() {
|
|
||||||
StatusCode::NotFound => (Some("Not found".to_string()), None),
|
|
||||||
StatusCode::MethodNotAllowed => (Some("Method not allowed".to_string()), None),
|
|
||||||
StatusCode::Found
|
|
||||||
| StatusCode::PermanentRedirect
|
|
||||||
| StatusCode::TemporaryRedirect
|
|
||||||
| StatusCode::SeeOther => {
|
|
||||||
let description = response.header(LOCATION).map(|loc| format!("To {}", loc));
|
|
||||||
(Some("Redirecting".to_string()), description)
|
|
||||||
}
|
|
||||||
StatusCode::InternalServerError => (Some("Internal server error".to_string()), None),
|
|
||||||
_ => (None, None),
|
|
||||||
};
|
|
||||||
|
|
||||||
// If there is an error associated to the response, format it in a nice
|
|
||||||
// way with a backtrace if we have one
|
|
||||||
let details = response.take_error().map(|err| {
|
|
||||||
format!(
|
|
||||||
"{:?}{}",
|
|
||||||
err,
|
|
||||||
err.backtrace()
|
|
||||||
.map(|bt| format!("\nBacktrace:\n{}", bt.to_string()))
|
|
||||||
.unwrap_or_default()
|
|
||||||
)
|
|
||||||
});
|
|
||||||
|
|
||||||
let error_context = ErrorContext {
|
|
||||||
code,
|
|
||||||
description,
|
|
||||||
details,
|
|
||||||
};
|
|
||||||
|
|
||||||
// This is the case if one of the code, description or details is not
|
|
||||||
// None
|
|
||||||
if error_context.should_render() {
|
|
||||||
match content_type {
|
|
||||||
Some(c) if c == &mime::APPLICATION_JSON => {
|
|
||||||
response.set_body(Body::from_json(&error_context)?);
|
|
||||||
response.set_content_type("application/json");
|
|
||||||
}
|
|
||||||
Some(c) if c == &mime::TEXT_HTML => {
|
|
||||||
let mut ctx = Context::from_serialize(&error_context)?;
|
|
||||||
ctx.extend(pctx);
|
|
||||||
response.set_body(templates.render("error.html", &ctx)?);
|
|
||||||
response.set_content_type("text/html");
|
|
||||||
}
|
|
||||||
Some(c) if c == &mime::TEXT_PLAIN => {
|
|
||||||
let mut ctx = Context::from_serialize(&error_context)?;
|
|
||||||
ctx.extend(pctx);
|
|
||||||
response.set_body(templates.render("error.txt", &ctx)?);
|
|
||||||
response.set_content_type("text/plain");
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
response.set_body("Unsupported Content-Type in Accept header");
|
|
||||||
response.set_content_type("text/plain");
|
|
||||||
response.set_status(StatusCode::NotAcceptable);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(response)
|
|
||||||
})
|
|
||||||
}
|
|
@ -17,11 +17,11 @@
|
|||||||
#![allow(clippy::unused_async)] // Some warp filters need that
|
#![allow(clippy::unused_async)] // Some warp filters need that
|
||||||
#![deny(missing_docs)]
|
#![deny(missing_docs)]
|
||||||
|
|
||||||
pub mod csrf;
|
|
||||||
// mod errors;
|
|
||||||
pub mod authenticate;
|
pub mod authenticate;
|
||||||
pub mod client;
|
pub mod client;
|
||||||
pub mod cookies;
|
pub mod cookies;
|
||||||
|
pub mod cors;
|
||||||
|
pub mod csrf;
|
||||||
pub mod database;
|
pub mod database;
|
||||||
pub mod headers;
|
pub mod headers;
|
||||||
pub mod session;
|
pub mod session;
|
||||||
|
@ -15,6 +15,7 @@
|
|||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
|
|
||||||
use hyper::Method;
|
use hyper::Method;
|
||||||
|
use mas_config::OAuth2Config;
|
||||||
use oauth2_types::{
|
use oauth2_types::{
|
||||||
oidc::Metadata,
|
oidc::Metadata,
|
||||||
pkce::CodeChallengeMethod,
|
pkce::CodeChallengeMethod,
|
||||||
@ -22,7 +23,7 @@ use oauth2_types::{
|
|||||||
};
|
};
|
||||||
use warp::{Filter, Rejection, Reply};
|
use warp::{Filter, Rejection, Reply};
|
||||||
|
|
||||||
use crate::config::OAuth2Config;
|
use crate::filters::cors::cors;
|
||||||
|
|
||||||
pub(super) fn filter(
|
pub(super) fn filter(
|
||||||
config: &OAuth2Config,
|
config: &OAuth2Config,
|
||||||
@ -87,15 +88,9 @@ pub(super) fn filter(
|
|||||||
code_challenge_methods_supported,
|
code_challenge_methods_supported,
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO: get the headers list from the global opentelemetry propagators
|
|
||||||
let cors = warp::cors()
|
|
||||||
.allow_method(Method::GET)
|
|
||||||
.allow_any_origin()
|
|
||||||
.allow_headers(["traceparent"]);
|
|
||||||
|
|
||||||
warp::path!(".well-known" / "openid-configuration").and(
|
warp::path!(".well-known" / "openid-configuration").and(
|
||||||
warp::get()
|
warp::get()
|
||||||
.map(move || warp::reply::json(&metadata))
|
.map(move || warp::reply::json(&metadata))
|
||||||
.with(cors),
|
.with(cors().allow_method(Method::GET)),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -13,6 +13,7 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
|
use hyper::Method;
|
||||||
use oauth2_types::requests::{IntrospectionRequest, IntrospectionResponse, TokenTypeHint};
|
use oauth2_types::requests::{IntrospectionRequest, IntrospectionResponse, TokenTypeHint};
|
||||||
use sqlx::{pool::PoolConnection, PgPool, Postgres};
|
use sqlx::{pool::PoolConnection, PgPool, Postgres};
|
||||||
use tracing::{info, warn};
|
use tracing::{info, warn};
|
||||||
@ -23,6 +24,7 @@ use crate::{
|
|||||||
errors::WrapError,
|
errors::WrapError,
|
||||||
filters::{
|
filters::{
|
||||||
client::{client_authentication, ClientAuthentication},
|
client::{client_authentication, ClientAuthentication},
|
||||||
|
cors::cors,
|
||||||
database::connection,
|
database::connection,
|
||||||
},
|
},
|
||||||
storage::oauth2::{access_token::lookup_access_token, refresh_token::lookup_refresh_token},
|
storage::oauth2::{access_token::lookup_access_token, refresh_token::lookup_refresh_token},
|
||||||
@ -33,12 +35,14 @@ pub fn filter(
|
|||||||
pool: &PgPool,
|
pool: &PgPool,
|
||||||
oauth2_config: &OAuth2Config,
|
oauth2_config: &OAuth2Config,
|
||||||
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||||
warp::path!("oauth2" / "introspect")
|
warp::path!("oauth2" / "introspect").and(
|
||||||
.and(warp::post())
|
warp::post()
|
||||||
.and(connection(pool))
|
.and(connection(pool))
|
||||||
.and(client_authentication(oauth2_config))
|
.and(client_authentication(oauth2_config))
|
||||||
.and_then(introspect)
|
.and_then(introspect)
|
||||||
.recover(recover)
|
.recover(recover)
|
||||||
|
.with(cors().allow_method(Method::POST)),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
const INACTIVE: IntrospectionResponse = IntrospectionResponse {
|
const INACTIVE: IntrospectionResponse = IntrospectionResponse {
|
||||||
|
@ -12,19 +12,20 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
use hyper::Method;
|
||||||
|
use mas_config::OAuth2Config;
|
||||||
use warp::{Filter, Rejection, Reply};
|
use warp::{Filter, Rejection, Reply};
|
||||||
|
|
||||||
use crate::config::OAuth2Config;
|
use crate::filters::cors::cors;
|
||||||
|
|
||||||
pub(super) fn filter(
|
pub(super) fn filter(
|
||||||
config: &OAuth2Config,
|
config: &OAuth2Config,
|
||||||
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||||
let jwks = config.keys.to_public_jwks();
|
let jwks = config.keys.to_public_jwks();
|
||||||
|
|
||||||
let cors = warp::cors().allow_any_origin();
|
warp::path!("oauth2" / "keys.json").and(
|
||||||
|
warp::get()
|
||||||
warp::path!("oauth2" / "keys.json")
|
.map(move || warp::reply::json(&jwks))
|
||||||
.and(warp::get())
|
.with(cors().allow_method(Method::GET)),
|
||||||
.map(move || warp::reply::json(&jwks))
|
)
|
||||||
.with(cors)
|
|
||||||
}
|
}
|
||||||
|
@ -16,7 +16,7 @@ use anyhow::Context;
|
|||||||
use chrono::Duration;
|
use chrono::Duration;
|
||||||
use data_encoding::BASE64URL_NOPAD;
|
use data_encoding::BASE64URL_NOPAD;
|
||||||
use headers::{CacheControl, Pragma};
|
use headers::{CacheControl, Pragma};
|
||||||
use hyper::StatusCode;
|
use hyper::{Method, StatusCode};
|
||||||
use jwt_compact::{Claims, Header, TimeOptions};
|
use jwt_compact::{Claims, Header, TimeOptions};
|
||||||
use oauth2_types::{
|
use oauth2_types::{
|
||||||
errors::{
|
errors::{
|
||||||
@ -44,6 +44,7 @@ use crate::{
|
|||||||
errors::WrapError,
|
errors::WrapError,
|
||||||
filters::{
|
filters::{
|
||||||
client::{client_authentication, ClientAuthentication},
|
client::{client_authentication, ClientAuthentication},
|
||||||
|
cors::cors,
|
||||||
database::connection,
|
database::connection,
|
||||||
with_keys,
|
with_keys,
|
||||||
},
|
},
|
||||||
@ -92,14 +93,16 @@ pub fn filter(
|
|||||||
oauth2_config: &OAuth2Config,
|
oauth2_config: &OAuth2Config,
|
||||||
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||||
let issuer = oauth2_config.issuer.clone();
|
let issuer = oauth2_config.issuer.clone();
|
||||||
warp::path!("oauth2" / "token")
|
warp::path!("oauth2" / "token").and(
|
||||||
.and(warp::post())
|
warp::post()
|
||||||
.and(client_authentication(oauth2_config))
|
.and(client_authentication(oauth2_config))
|
||||||
.and(with_keys(oauth2_config))
|
.and(with_keys(oauth2_config))
|
||||||
.and(warp::any().map(move || issuer.clone()))
|
.and(warp::any().map(move || issuer.clone()))
|
||||||
.and(connection(pool))
|
.and(connection(pool))
|
||||||
.and_then(token)
|
.and_then(token)
|
||||||
.recover(recover)
|
.recover(recover)
|
||||||
|
.with(cors().allow_method(Method::POST)),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn recover(rejection: Rejection) -> Result<impl Reply, Rejection> {
|
async fn recover(rejection: Rejection) -> Result<impl Reply, Rejection> {
|
||||||
|
@ -12,13 +12,17 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
use hyper::Method;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use sqlx::PgPool;
|
use sqlx::PgPool;
|
||||||
use warp::{Filter, Rejection, Reply};
|
use warp::{Filter, Rejection, Reply};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
config::OAuth2Config,
|
config::OAuth2Config,
|
||||||
filters::authenticate::{authentication, recover_unauthorized},
|
filters::{
|
||||||
|
authenticate::{authentication, recover_unauthorized},
|
||||||
|
cors::cors,
|
||||||
|
},
|
||||||
storage::oauth2::access_token::OAuth2AccessTokenLookup,
|
storage::oauth2::access_token::OAuth2AccessTokenLookup,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -31,11 +35,15 @@ pub(super) fn filter(
|
|||||||
pool: &PgPool,
|
pool: &PgPool,
|
||||||
_config: &OAuth2Config,
|
_config: &OAuth2Config,
|
||||||
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||||
warp::path!("oauth2" / "userinfo")
|
warp::path!("oauth2" / "userinfo").and(
|
||||||
.and(warp::get().or(warp::post()).unify())
|
warp::get()
|
||||||
.and(authentication(pool))
|
.or(warp::post())
|
||||||
.and_then(userinfo)
|
.unify()
|
||||||
.recover(recover_unauthorized)
|
.and(authentication(pool))
|
||||||
|
.and_then(userinfo)
|
||||||
|
.recover(recover_unauthorized)
|
||||||
|
.with(cors().allow_methods([Method::GET, Method::POST])),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn userinfo(token: OAuth2AccessTokenLookup) -> Result<impl Reply, Rejection> {
|
async fn userinfo(token: OAuth2AccessTokenLookup) -> Result<impl Reply, Rejection> {
|
||||||
|
Reference in New Issue
Block a user