You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
Get rid of warp
This commit is contained in:
@ -16,7 +16,6 @@ tower = { version = "0.4.12", features = ["full"] }
|
||||
hyper = { version = "0.14.17", features = ["full"] }
|
||||
serde_yaml = "0.8.23"
|
||||
serde_json = "1.0.79"
|
||||
warp = "0.3.2"
|
||||
url = "2.2.2"
|
||||
argon2 = { version = "0.3.4", features = ["password-hash"] }
|
||||
reqwest = { version = "0.11.10", features = ["rustls-tls"], default-features = false, optional = true }
|
||||
@ -42,7 +41,6 @@ mas-http = { path = "../http" }
|
||||
mas-storage = { path = "../storage" }
|
||||
mas-tasks = { path = "../tasks" }
|
||||
mas-templates = { path = "../templates" }
|
||||
mas-warp-utils = { path = "../warp-utils" }
|
||||
mas-axum-utils = { path = "../axum-utils" }
|
||||
|
||||
[dev-dependencies]
|
||||
|
@ -40,7 +40,7 @@ pub fn setup(config: &TelemetryConfig) -> anyhow::Result<Option<Tracer>> {
|
||||
|
||||
// The CORS filter needs to know what headers it should whitelist for
|
||||
// CORS-protected requests.
|
||||
mas_warp_utils::filters::cors::set_propagator(&propagator);
|
||||
// TODO mas_warp_utils::filters::cors::set_propagator(&propagator);
|
||||
global::set_text_map_propagator(propagator);
|
||||
|
||||
let tracer = tracer(&config.tracing.exporter)?;
|
||||
|
@ -20,7 +20,6 @@ thiserror = "1.0.30"
|
||||
anyhow = "1.0.56"
|
||||
|
||||
# Web server
|
||||
warp = "0.3.2"
|
||||
hyper = { version = "0.14.17", features = ["full"] }
|
||||
tower = "0.4.12"
|
||||
axum = "0.4.8"
|
||||
@ -67,7 +66,6 @@ mas-jose = { path = "../jose" }
|
||||
mas-static-files = { path = "../static-files" }
|
||||
mas-storage = { path = "../storage" }
|
||||
mas-templates = { path = "../templates" }
|
||||
mas-warp-utils = { path = "../warp-utils" }
|
||||
|
||||
[dev-dependencies]
|
||||
indoc = "1.0.4"
|
||||
|
@ -16,7 +16,7 @@
|
||||
#![deny(clippy::all, rustdoc::broken_intra_doc_links)]
|
||||
#![warn(clippy::pedantic)]
|
||||
#![allow(
|
||||
clippy::unused_async // Some warp filters need that
|
||||
clippy::unused_async // Some axum handlers need that
|
||||
)]
|
||||
|
||||
use std::sync::Arc;
|
||||
|
@ -14,7 +14,6 @@ serde_json = "1.0.79"
|
||||
thiserror = "1.0.30"
|
||||
anyhow = "1.0.56"
|
||||
tracing = "0.1.32"
|
||||
warp = "0.3.2"
|
||||
|
||||
# Password hashing
|
||||
argon2 = { version = "0.3.4", features = ["password-hash"] }
|
||||
|
@ -21,7 +21,6 @@ use oauth2_types::requests::GrantType;
|
||||
use sqlx::{PgConnection, PgExecutor};
|
||||
use thiserror::Error;
|
||||
use url::Url;
|
||||
use warp::reject::Reject;
|
||||
|
||||
use crate::PostgresqlBackend;
|
||||
|
||||
@ -79,8 +78,6 @@ impl ClientFetchError {
|
||||
}
|
||||
}
|
||||
|
||||
impl Reject for ClientFetchError {}
|
||||
|
||||
impl TryInto<Client<PostgresqlBackend>> for OAuth2ClientLookup {
|
||||
type Error = ClientFetchError;
|
||||
|
||||
|
@ -19,7 +19,6 @@ use mas_data_model::{
|
||||
};
|
||||
use sqlx::{PgConnection, PgExecutor};
|
||||
use thiserror::Error;
|
||||
use warp::reject::Reject;
|
||||
|
||||
use super::client::{lookup_client_by_client_id, ClientFetchError};
|
||||
use crate::{DatabaseInconsistencyError, IdAndCreationTime, PostgresqlBackend};
|
||||
@ -87,8 +86,6 @@ pub enum RefreshTokenLookupError {
|
||||
Conversion(#[from] DatabaseInconsistencyError),
|
||||
}
|
||||
|
||||
impl Reject for RefreshTokenLookupError {}
|
||||
|
||||
impl RefreshTokenLookupError {
|
||||
#[must_use]
|
||||
pub fn not_found(&self) -> bool {
|
||||
|
@ -27,7 +27,6 @@ use sqlx::{postgres::types::PgInterval, Acquire, PgExecutor, Postgres, Transacti
|
||||
use thiserror::Error;
|
||||
use tokio::task;
|
||||
use tracing::{info_span, Instrument};
|
||||
use warp::reject::Reject;
|
||||
|
||||
use super::{DatabaseInconsistencyError, PostgresqlBackend};
|
||||
use crate::IdAndCreationTime;
|
||||
@ -117,8 +116,6 @@ pub enum ActiveSessionLookupError {
|
||||
Conversion(#[from] DatabaseInconsistencyError),
|
||||
}
|
||||
|
||||
impl Reject for ActiveSessionLookupError {}
|
||||
|
||||
impl ActiveSessionLookupError {
|
||||
#[must_use]
|
||||
pub fn not_found(&self) -> bool {
|
||||
|
@ -21,7 +21,6 @@ serde_json = "1.0.79"
|
||||
serde_urlencoded = "0.7.1"
|
||||
|
||||
url = "2.2.2"
|
||||
warp = "0.3.2"
|
||||
|
||||
oauth2-types = { path = "../oauth2-types" }
|
||||
mas-data-model = { path = "../data-model" }
|
||||
|
@ -271,8 +271,6 @@ pub enum TemplateError {
|
||||
},
|
||||
}
|
||||
|
||||
impl warp::reject::Reject for TemplateError {}
|
||||
|
||||
register_templates! {
|
||||
extra = {
|
||||
"components/button.html",
|
||||
|
@ -1,42 +0,0 @@
|
||||
[package]
|
||||
name = "mas-warp-utils"
|
||||
version = "0.1.0"
|
||||
authors = ["Quentin Gliech <quenting@element.io>"]
|
||||
edition = "2021"
|
||||
license = "Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
tokio = { version = "1.17.0", features = ["macros"] }
|
||||
headers = "0.3.7"
|
||||
cookie = "0.16.0"
|
||||
warp = "0.3.2"
|
||||
hyper = { version = "0.14.17", features = ["full"] }
|
||||
thiserror = "1.0.30"
|
||||
anyhow = "1.0.56"
|
||||
sqlx = { version = "0.5.11", features = ["runtime-tokio-rustls", "postgres"] }
|
||||
chrono = { version = "0.4.19", features = ["serde"] }
|
||||
serde = { version = "1.0.136", features = ["derive"] }
|
||||
serde_with = { version = "1.12.0", features = ["hex", "chrono"] }
|
||||
serde_json = "1.0.79"
|
||||
serde_urlencoded = "0.7.1"
|
||||
data-encoding = "2.3.2"
|
||||
once_cell = "1.10.0"
|
||||
tracing = "0.1.32"
|
||||
opentelemetry = "0.17.0"
|
||||
rand = "0.8.5"
|
||||
mime = "0.3.16"
|
||||
bincode = "1.3.3"
|
||||
crc = "2.1.0"
|
||||
url = "2.2.2"
|
||||
http = "0.2.6"
|
||||
http-body = "0.4.4"
|
||||
tower = { version = "0.4.12", features = ["util"] }
|
||||
|
||||
oauth2-types = { path = "../oauth2-types" }
|
||||
mas-config = { path = "../config" }
|
||||
mas-templates = { path = "../templates" }
|
||||
mas-data-model = { path = "../data-model" }
|
||||
mas-storage = { path = "../storage" }
|
||||
mas-jose = { path = "../jose" }
|
||||
mas-iana = { path = "../iana" }
|
||||
mas-http = { path = "../http" }
|
@ -1,42 +0,0 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Helper to deal with various unstructured errors in application code
|
||||
|
||||
use warp::{reject::Reject, Rejection};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct WrappedError(anyhow::Error);
|
||||
|
||||
impl warp::reject::Reject for WrappedError {}
|
||||
|
||||
/// Wrap any error in a [`Rejection`]
|
||||
pub fn wrapped_error<T: Into<anyhow::Error>>(e: T) -> impl Reject {
|
||||
WrappedError(e.into())
|
||||
}
|
||||
|
||||
/// Extension trait that wraps errors in [`Rejection`]s
|
||||
pub trait WrapError<T> {
|
||||
/// Wrap transform the [`Result`] error type to a [`Rejection`]
|
||||
fn wrap_error(self) -> Result<T, Rejection>;
|
||||
}
|
||||
|
||||
impl<T, E> WrapError<T> for Result<T, E>
|
||||
where
|
||||
E: Into<anyhow::Error>,
|
||||
{
|
||||
fn wrap_error(self) -> Result<T, Rejection> {
|
||||
self.map_err(|e| warp::reject::custom(WrappedError(e.into())))
|
||||
}
|
||||
}
|
@ -1,156 +0,0 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Authenticate an endpoint with an access token as bearer authorization token
|
||||
|
||||
use headers::{authorization::Bearer, Authorization};
|
||||
use hyper::StatusCode;
|
||||
use mas_data_model::{AccessToken, Session, TokenFormatError, TokenType};
|
||||
use mas_storage::{
|
||||
oauth2::access_token::{lookup_active_access_token, AccessTokenLookupError},
|
||||
PostgresqlBackend,
|
||||
};
|
||||
use sqlx::{pool::PoolConnection, PgPool, Postgres};
|
||||
use thiserror::Error;
|
||||
use warp::{
|
||||
reject::{MissingHeader, Reject},
|
||||
reply::{with_header, with_status},
|
||||
Filter, Rejection, Reply,
|
||||
};
|
||||
|
||||
use super::{
|
||||
database::connection,
|
||||
headers::{typed_header, InvalidTypedHeader},
|
||||
};
|
||||
use crate::errors::wrapped_error;
|
||||
|
||||
/// Bearer token authentication failed
|
||||
///
|
||||
/// This is recoverable with [`recover_unauthorized`]
|
||||
#[derive(Debug, Error)]
|
||||
pub enum AuthenticationError {
|
||||
/// The bearer token has an invalid format
|
||||
#[error("invalid token format")]
|
||||
TokenFormat(#[from] TokenFormatError),
|
||||
|
||||
/// The bearer token is not an access token
|
||||
#[error("invalid token type {0:?}, expected an access token")]
|
||||
WrongTokenType(TokenType),
|
||||
|
||||
/// The access token was not found in the database
|
||||
#[error("unknown token")]
|
||||
TokenNotFound(#[source] AccessTokenLookupError),
|
||||
|
||||
/// The `Authorization` header is missing
|
||||
#[error("missing authorization header")]
|
||||
MissingAuthorizationHeader,
|
||||
|
||||
/// The `Authorization` header is invalid
|
||||
#[error("invalid authorization header")]
|
||||
InvalidAuthorizationHeader,
|
||||
}
|
||||
|
||||
impl Reject for AuthenticationError {}
|
||||
|
||||
/// Authenticate a request using an access token as a bearer authorization
|
||||
///
|
||||
/// # Rejections
|
||||
///
|
||||
/// This can reject with either a [`AuthenticationError`] or with a generic
|
||||
/// wrapped sqlx error.
|
||||
#[must_use]
|
||||
pub fn authentication(
|
||||
pool: &PgPool,
|
||||
) -> impl Filter<
|
||||
Extract = (AccessToken<PostgresqlBackend>, Session<PostgresqlBackend>),
|
||||
Error = Rejection,
|
||||
> + Clone
|
||||
+ Send
|
||||
+ Sync
|
||||
+ 'static {
|
||||
connection(pool)
|
||||
.and(typed_header())
|
||||
.and_then(authenticate)
|
||||
.recover(recover)
|
||||
.unify()
|
||||
.untuple_one()
|
||||
}
|
||||
|
||||
fn ensure<T: Clone + Send + Sync + 'static>(t: T) -> T {
|
||||
t
|
||||
}
|
||||
|
||||
async fn authenticate(
|
||||
mut conn: PoolConnection<Postgres>,
|
||||
auth: Authorization<Bearer>,
|
||||
) -> Result<(AccessToken<PostgresqlBackend>, Session<PostgresqlBackend>), Rejection> {
|
||||
let token = auth.0.token();
|
||||
let token_type = TokenType::check(token).map_err(AuthenticationError::TokenFormat)?;
|
||||
|
||||
if token_type != TokenType::AccessToken {
|
||||
return Err(AuthenticationError::WrongTokenType(token_type).into());
|
||||
}
|
||||
|
||||
let (token, session) = lookup_active_access_token(&mut conn, token)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
if e.not_found() {
|
||||
// This error happens if the token was not found and should be recovered
|
||||
warp::reject::custom(AuthenticationError::TokenNotFound(e))
|
||||
} else {
|
||||
// This is a generic database error that we want to propagate
|
||||
warp::reject::custom(wrapped_error(e))
|
||||
}
|
||||
})?;
|
||||
|
||||
let session = ensure(session);
|
||||
let token = ensure(token);
|
||||
|
||||
Ok((token, session))
|
||||
}
|
||||
|
||||
/// Transform the rejections from the [`with_typed_header`] filter
|
||||
async fn recover(
|
||||
rejection: Rejection,
|
||||
) -> Result<(AccessToken<PostgresqlBackend>, Session<PostgresqlBackend>), Rejection> {
|
||||
if rejection.find::<MissingHeader>().is_some() {
|
||||
return Err(warp::reject::custom(
|
||||
AuthenticationError::MissingAuthorizationHeader,
|
||||
));
|
||||
}
|
||||
|
||||
if rejection.find::<InvalidTypedHeader>().is_some() {
|
||||
return Err(warp::reject::custom(
|
||||
AuthenticationError::InvalidAuthorizationHeader,
|
||||
));
|
||||
}
|
||||
|
||||
Err(rejection)
|
||||
}
|
||||
|
||||
/// Recover from an [`AuthenticationError`] with a `WWW-Authenticate` header, as
|
||||
/// per [RFC6750]. This is not intended for user-facing endpoints.
|
||||
///
|
||||
/// [RFC6750]: https://www.rfc-editor.org/rfc/rfc6750.html
|
||||
pub async fn recover_unauthorized(rejection: Rejection) -> Result<Box<dyn Reply>, Rejection> {
|
||||
if rejection.find::<AuthenticationError>().is_some() {
|
||||
// TODO: have the issuer/realm here
|
||||
let reply = "invalid token";
|
||||
let reply = with_status(reply, StatusCode::UNAUTHORIZED);
|
||||
let reply = with_header(reply, "WWW-Authenticate", r#"Bearer error="invalid_token""#);
|
||||
return Ok(Box::new(reply));
|
||||
}
|
||||
|
||||
Err(rejection)
|
||||
}
|
@ -1,772 +0,0 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Handle client authentication
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use data_encoding::BASE64;
|
||||
use headers::{authorization::Basic, Authorization};
|
||||
use mas_config::Encrypter;
|
||||
use mas_data_model::{Client, JwksOrJwksUri, StorageBackend};
|
||||
use mas_http::HttpServiceExt;
|
||||
use mas_iana::oauth::OAuthClientAuthenticationMethod;
|
||||
use mas_jose::{
|
||||
claims::{TimeOptions, AUD, EXP, IAT, ISS, JTI, NBF, SUB},
|
||||
DecodedJsonWebToken, DynamicJwksStore, Either, JsonWebKeySet, JsonWebTokenParts, SharedSecret,
|
||||
StaticJwksStore, VerifyingKeystore,
|
||||
};
|
||||
use mas_storage::{
|
||||
oauth2::client::{lookup_client_by_client_id, ClientFetchError},
|
||||
PostgresqlBackend,
|
||||
};
|
||||
use serde::{de::DeserializeOwned, Deserialize};
|
||||
use sqlx::{pool::PoolConnection, PgPool, Postgres};
|
||||
use thiserror::Error;
|
||||
use tower::{BoxError, ServiceExt};
|
||||
use warp::{reject::Reject, Filter, Rejection};
|
||||
|
||||
use super::{database::connection, headers::typed_header};
|
||||
use crate::errors::WrapError;
|
||||
|
||||
/// Protect an enpoint with client authentication
|
||||
#[must_use]
|
||||
pub fn client_authentication<T: DeserializeOwned + Send + 'static>(
|
||||
pool: &PgPool,
|
||||
encrypter: &Encrypter,
|
||||
audience: String,
|
||||
) -> impl Filter<
|
||||
Extract = (
|
||||
OAuthClientAuthenticationMethod,
|
||||
Client<PostgresqlBackend>,
|
||||
T,
|
||||
),
|
||||
Error = Rejection,
|
||||
> + Clone
|
||||
+ Send
|
||||
+ Sync
|
||||
+ 'static {
|
||||
let encrypter = encrypter.clone();
|
||||
|
||||
// First, extract the client credentials
|
||||
let credentials = typed_header()
|
||||
.and(warp::body::form())
|
||||
// Either from the "Authorization" header
|
||||
.map(|auth: Authorization<Basic>, body: T| {
|
||||
let client_id = auth.0.username().to_string();
|
||||
let client_secret = Some(auth.0.password().to_string());
|
||||
(
|
||||
ClientCredentials::Pair {
|
||||
via: CredentialsVia::AuthorizationHeader,
|
||||
client_id,
|
||||
client_secret,
|
||||
},
|
||||
body,
|
||||
)
|
||||
})
|
||||
// Or from the form body
|
||||
.or(warp::body::form().map(|form: ClientAuthForm<T>| {
|
||||
let ClientAuthForm { credentials, body } = form;
|
||||
|
||||
(credentials, body)
|
||||
}))
|
||||
.unify()
|
||||
.untuple_one();
|
||||
|
||||
warp::any()
|
||||
.and(connection(pool))
|
||||
.and(warp::any().map(move || encrypter.clone()))
|
||||
.and(warp::any().map(move || audience.clone()))
|
||||
.and(credentials)
|
||||
.and_then(authenticate_client)
|
||||
.untuple_one()
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
enum ClientAuthenticationError {
|
||||
#[error("wrong client secret for client {client_id:?}")]
|
||||
ClientSecretMismatch { client_id: String },
|
||||
|
||||
#[error("could not fetch client {client_id:?}")]
|
||||
ClientFetch {
|
||||
client_id: String,
|
||||
source: ClientFetchError,
|
||||
},
|
||||
|
||||
#[error("client {client_id:?} has an invalid client secret")]
|
||||
InvalidClientSecret {
|
||||
client_id: String,
|
||||
source: anyhow::Error,
|
||||
},
|
||||
|
||||
#[error("client {client_id:?} has an invalid JWKS")]
|
||||
InvalidJwks { client_id: String },
|
||||
|
||||
#[error("wrong client authentication method for client {client_id:?}")]
|
||||
WrongAuthenticationMethod { client_id: String },
|
||||
|
||||
#[error("wrong audience in client assertion: expected {expected:?}")]
|
||||
MissingAudience { expected: String },
|
||||
|
||||
#[error("invalid client assertion")]
|
||||
InvalidAssertion,
|
||||
}
|
||||
|
||||
impl Reject for ClientAuthenticationError {}
|
||||
|
||||
fn decrypt_client_secret<T: StorageBackend>(
|
||||
client: &Client<T>,
|
||||
encrypter: &Encrypter,
|
||||
) -> anyhow::Result<Vec<u8>> {
|
||||
let encrypted_client_secret = client
|
||||
.encrypted_client_secret
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing encrypted_client_secret field"))?;
|
||||
|
||||
let encrypted_client_secret = BASE64.decode(encrypted_client_secret.as_bytes())?;
|
||||
|
||||
let nonce: &[u8; 12] = encrypted_client_secret
|
||||
.get(0..12)
|
||||
.ok_or_else(|| anyhow::anyhow!("invalid payload serialization"))?
|
||||
.try_into()?;
|
||||
|
||||
let payload = encrypted_client_secret
|
||||
.get(12..)
|
||||
.ok_or_else(|| anyhow::anyhow!("invalid payload serialization"))?;
|
||||
|
||||
let decrypted_client_secret = encrypter.decrypt(nonce, payload)?;
|
||||
|
||||
Ok(decrypted_client_secret)
|
||||
}
|
||||
|
||||
fn jwks_key_store(jwks: &JwksOrJwksUri) -> Either<StaticJwksStore, DynamicJwksStore> {
|
||||
// Assert that the output is both a VerifyingKeystore and Send
|
||||
fn assert<T: Send + VerifyingKeystore>(t: T) -> T {
|
||||
t
|
||||
}
|
||||
|
||||
let inner = match jwks {
|
||||
JwksOrJwksUri::Jwks(jwks) => Either::Left(StaticJwksStore::new(jwks.clone())),
|
||||
JwksOrJwksUri::JwksUri(uri) => {
|
||||
let uri = uri.clone();
|
||||
|
||||
// TODO: get the client from somewhere else?
|
||||
let exporter = mas_http::client("fetch-jwks")
|
||||
.json::<JsonWebKeySet>()
|
||||
.map_request(move |_: ()| {
|
||||
http::Request::builder()
|
||||
.method("GET")
|
||||
// TODO: change the Uri type in config to avoid reparsing here
|
||||
.uri(uri.to_string())
|
||||
.body(http_body::Empty::new())
|
||||
.unwrap()
|
||||
})
|
||||
.map_response(http::Response::into_body)
|
||||
.map_err(BoxError::from)
|
||||
.boxed_clone();
|
||||
|
||||
Either::Right(DynamicJwksStore::new(exporter))
|
||||
}
|
||||
};
|
||||
|
||||
assert(inner)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
#[tracing::instrument(skip_all, fields(enduser.id), err(Debug))]
|
||||
async fn authenticate_client<T>(
|
||||
mut conn: PoolConnection<Postgres>,
|
||||
encrypter: Encrypter,
|
||||
audience: String,
|
||||
credentials: ClientCredentials,
|
||||
body: T,
|
||||
) -> Result<
|
||||
(
|
||||
OAuthClientAuthenticationMethod,
|
||||
Client<PostgresqlBackend>,
|
||||
T,
|
||||
),
|
||||
Rejection,
|
||||
> {
|
||||
let (auth_method, client) = match credentials {
|
||||
ClientCredentials::Pair {
|
||||
client_id,
|
||||
client_secret,
|
||||
via,
|
||||
} => {
|
||||
let client = lookup_client_by_client_id(&mut *conn, &client_id)
|
||||
.await
|
||||
.map_err(|source| ClientAuthenticationError::ClientFetch {
|
||||
client_id: client_id.clone(),
|
||||
source,
|
||||
})?;
|
||||
|
||||
let auth_method = client.token_endpoint_auth_method.ok_or(
|
||||
ClientAuthenticationError::WrongAuthenticationMethod {
|
||||
client_id: client.client_id.clone(),
|
||||
},
|
||||
)?;
|
||||
|
||||
// Let's match the authentication method
|
||||
match (auth_method, client_secret, via) {
|
||||
(OAuthClientAuthenticationMethod::None, None, _) => {}
|
||||
(
|
||||
OAuthClientAuthenticationMethod::ClientSecretBasic,
|
||||
Some(client_secret),
|
||||
CredentialsVia::AuthorizationHeader,
|
||||
)
|
||||
| (
|
||||
OAuthClientAuthenticationMethod::ClientSecretPost,
|
||||
Some(client_secret),
|
||||
CredentialsVia::FormBody,
|
||||
) => {
|
||||
let decrypted =
|
||||
decrypt_client_secret(&client, &encrypter).map_err(|source| {
|
||||
ClientAuthenticationError::InvalidClientSecret {
|
||||
client_id: client.client_id.clone(),
|
||||
source,
|
||||
}
|
||||
})?;
|
||||
|
||||
if client_secret.as_bytes() != decrypted {
|
||||
return Err(warp::reject::custom(
|
||||
ClientAuthenticationError::ClientSecretMismatch {
|
||||
client_id: client.client_id,
|
||||
},
|
||||
));
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
return Err(warp::reject::custom(
|
||||
ClientAuthenticationError::WrongAuthenticationMethod {
|
||||
client_id: client.client_id,
|
||||
},
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
(auth_method, client)
|
||||
}
|
||||
ClientCredentials::Assertion {
|
||||
client_id,
|
||||
client_assertion_type: ClientAssertionType::JwtBearer,
|
||||
client_assertion,
|
||||
} => {
|
||||
let token: JsonWebTokenParts = client_assertion.parse().wrap_error()?;
|
||||
let decoded: DecodedJsonWebToken<HashMap<String, serde_json::Value>> =
|
||||
token.decode().wrap_error()?;
|
||||
|
||||
let time_options = TimeOptions::default()
|
||||
.freeze()
|
||||
.leeway(chrono::Duration::minutes(1));
|
||||
|
||||
let mut claims = decoded.claims().clone();
|
||||
let iss = ISS.extract_required(&mut claims).wrap_error()?;
|
||||
let sub = SUB.extract_required(&mut claims).wrap_error()?;
|
||||
let aud = AUD.extract_required(&mut claims).wrap_error()?;
|
||||
|
||||
// Validate the times
|
||||
let _exp = EXP
|
||||
.extract_required_with_options(&mut claims, &time_options)
|
||||
.wrap_error()?;
|
||||
let _nbf = NBF
|
||||
.extract_optional_with_options(&mut claims, &time_options)
|
||||
.wrap_error()?;
|
||||
let _iat = IAT
|
||||
.extract_optional_with_options(&mut claims, &time_options)
|
||||
.wrap_error()?;
|
||||
|
||||
// TODO: validate the JTI
|
||||
let _jti = JTI.extract_optional(&mut claims).wrap_error()?;
|
||||
|
||||
// client_id might have been passed as parameter. If not, it should be inferred
|
||||
// from the token, as per rfc7521 sec. 4.2
|
||||
let client_id = client_id.as_ref().unwrap_or(&sub);
|
||||
|
||||
let client = lookup_client_by_client_id(&mut *conn, client_id)
|
||||
.await
|
||||
.map_err(|source| ClientAuthenticationError::ClientFetch {
|
||||
client_id: client_id.to_string(),
|
||||
source,
|
||||
})?;
|
||||
|
||||
let auth_method = client.token_endpoint_auth_method.ok_or(
|
||||
ClientAuthenticationError::WrongAuthenticationMethod {
|
||||
client_id: client.client_id.clone(),
|
||||
},
|
||||
)?;
|
||||
|
||||
match auth_method {
|
||||
OAuthClientAuthenticationMethod::ClientSecretJwt => {
|
||||
let client_secret =
|
||||
decrypt_client_secret(&client, &encrypter).map_err(|source| {
|
||||
ClientAuthenticationError::InvalidClientSecret {
|
||||
client_id: client.client_id.clone(),
|
||||
source,
|
||||
}
|
||||
})?;
|
||||
|
||||
let store = SharedSecret::new(&client_secret);
|
||||
let fut = token.verify(decoded.header(), &store);
|
||||
fut.await.wrap_error()?;
|
||||
}
|
||||
OAuthClientAuthenticationMethod::PrivateKeyJwt => {
|
||||
let jwks = client.jwks.as_ref().ok_or_else(|| {
|
||||
ClientAuthenticationError::InvalidJwks {
|
||||
client_id: client.client_id.clone(),
|
||||
}
|
||||
})?;
|
||||
|
||||
let store = jwks_key_store(jwks);
|
||||
let fut = token.verify(decoded.header(), &store);
|
||||
fut.await.wrap_error()?;
|
||||
}
|
||||
_ => {
|
||||
return Err(warp::reject::custom(
|
||||
ClientAuthenticationError::WrongAuthenticationMethod {
|
||||
client_id: client.client_id,
|
||||
},
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// rfc7523 sec. 3.3: the audience is the URL being called
|
||||
if !aud.contains(&audience) {
|
||||
return Err(
|
||||
ClientAuthenticationError::MissingAudience { expected: audience }.into(),
|
||||
);
|
||||
}
|
||||
|
||||
// rfc7523 sec. 3.1 & 3.2: both the issuer and the subject must
|
||||
// match the client_id
|
||||
if iss != sub || &iss != client_id {
|
||||
return Err(ClientAuthenticationError::InvalidAssertion.into());
|
||||
}
|
||||
|
||||
(auth_method, client)
|
||||
}
|
||||
};
|
||||
|
||||
tracing::Span::current().record("enduser.id", &client.client_id.as_str());
|
||||
|
||||
Ok((auth_method, client, body))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
enum ClientAssertionType {
|
||||
#[serde(rename = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")]
|
||||
JwtBearer,
|
||||
}
|
||||
|
||||
enum CredentialsVia {
|
||||
FormBody,
|
||||
AuthorizationHeader,
|
||||
}
|
||||
|
||||
impl Default for CredentialsVia {
|
||||
fn default() -> Self {
|
||||
Self::FormBody
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum ClientCredentials {
|
||||
// Order here is important: serde tries to deserialize enum variants in order, so if "Pair"
|
||||
// was before "Assertion", a client_assertion with a client_id would match the "Pair"
|
||||
// variant first
|
||||
Assertion {
|
||||
client_id: Option<String>,
|
||||
client_assertion_type: ClientAssertionType,
|
||||
client_assertion: String,
|
||||
},
|
||||
Pair {
|
||||
#[serde(skip)]
|
||||
via: CredentialsVia,
|
||||
client_id: String,
|
||||
client_secret: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ClientAuthForm<T> {
|
||||
#[serde(flatten)]
|
||||
credentials: ClientCredentials,
|
||||
|
||||
#[serde(flatten)]
|
||||
body: T,
|
||||
}
|
||||
|
||||
/* TODO: all secrets are broken because there is no way to mock the DB yet
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use headers::authorization::Credentials;
|
||||
use mas_config::{ClientAuthMethodConfig, ConfigurationSection};
|
||||
use mas_jose::{SigningKeystore, StaticKeystore};
|
||||
use serde_json::json;
|
||||
use tower::{Service, ServiceExt};
|
||||
|
||||
use super::*;
|
||||
|
||||
// Long client_secret to support it as a HS512 key
|
||||
const CLIENT_SECRET: &str = "leek2zaeyeb8thai7piehea3vah6ool9oanin9aeraThuci9EeghaekaiD1upe4Quoh7xeMae2meitohj0Waaveiwaorah1yazohr6Vae7iebeiRaWene5IeWeeciezu";
|
||||
|
||||
fn client_private_keystore() -> StaticKeystore {
|
||||
let mut store = StaticKeystore::new();
|
||||
store.add_test_rsa_key().unwrap();
|
||||
store.add_test_ecdsa_key().unwrap();
|
||||
store
|
||||
}
|
||||
|
||||
async fn oauth2_config() -> ClientsConfig {
|
||||
let mut config = ClientsConfig::test();
|
||||
config.push(ClientConfig {
|
||||
client_id: "public".to_string(),
|
||||
client_auth_method: ClientAuthMethodConfig::None,
|
||||
redirect_uris: Vec::new(),
|
||||
});
|
||||
config.push(ClientConfig {
|
||||
client_id: "secret-basic".to_string(),
|
||||
client_auth_method: ClientAuthMethodConfig::ClientSecretBasic {
|
||||
client_secret: CLIENT_SECRET.to_string(),
|
||||
},
|
||||
redirect_uris: Vec::new(),
|
||||
});
|
||||
config.push(ClientConfig {
|
||||
client_id: "secret-post".to_string(),
|
||||
client_auth_method: ClientAuthMethodConfig::ClientSecretPost {
|
||||
client_secret: CLIENT_SECRET.to_string(),
|
||||
},
|
||||
redirect_uris: Vec::new(),
|
||||
});
|
||||
config.push(ClientConfig {
|
||||
client_id: "secret-jwt".to_string(),
|
||||
client_auth_method: ClientAuthMethodConfig::ClientSecretJwt {
|
||||
client_secret: CLIENT_SECRET.to_string(),
|
||||
},
|
||||
redirect_uris: Vec::new(),
|
||||
});
|
||||
config.push(ClientConfig {
|
||||
client_id: "secret-jwt-2".to_string(),
|
||||
client_auth_method: ClientAuthMethodConfig::ClientSecretJwt {
|
||||
client_secret: CLIENT_SECRET.to_string(),
|
||||
},
|
||||
redirect_uris: Vec::new(),
|
||||
});
|
||||
|
||||
let store = client_private_keystore();
|
||||
let jwks = (&store).ready().await.unwrap().call(()).await.unwrap();
|
||||
//let jwks = store.export_jwks().await.unwrap();
|
||||
config.push(ClientConfig {
|
||||
client_id: "private-key-jwt".to_string(),
|
||||
client_auth_method: ClientAuthMethodConfig::PrivateKeyJwt(jwks.clone().into()),
|
||||
redirect_uris: Vec::new(),
|
||||
});
|
||||
config.push(ClientConfig {
|
||||
client_id: "private-key-jwt-2".to_string(),
|
||||
client_auth_method: ClientAuthMethodConfig::PrivateKeyJwt(jwks.into()),
|
||||
redirect_uris: Vec::new(),
|
||||
});
|
||||
config
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Form {
|
||||
foo: String,
|
||||
bar: String,
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn client_secret_jwt_hs256() {
|
||||
client_secret_jwt("HS256").await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn client_secret_jwt_hs384() {
|
||||
client_secret_jwt("HS384").await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn client_secret_jwt_hs512() {
|
||||
client_secret_jwt("HS512").await;
|
||||
}
|
||||
|
||||
fn client_claims(
|
||||
client_id: &str,
|
||||
audience: &str,
|
||||
iat: chrono::DateTime<chrono::Utc>,
|
||||
) -> HashMap<String, serde_json::Value> {
|
||||
let mut claims = HashMap::new();
|
||||
let exp = iat + chrono::Duration::minutes(1);
|
||||
|
||||
ISS.insert(&mut claims, client_id).unwrap();
|
||||
SUB.insert(&mut claims, client_id).unwrap();
|
||||
AUD.insert(&mut claims, vec![audience.to_string()]).unwrap();
|
||||
IAT.insert(&mut claims, iat).unwrap();
|
||||
NBF.insert(&mut claims, iat).unwrap();
|
||||
EXP.insert(&mut claims, exp).unwrap();
|
||||
|
||||
claims
|
||||
}
|
||||
|
||||
async fn client_secret_jwt(alg: &str) {
|
||||
let alg = alg.parse().unwrap();
|
||||
let audience = "https://example.com/token";
|
||||
let filter = client_authentication::<Form>(&oauth2_config().await, audience.to_string());
|
||||
|
||||
let store = SharedSecret::new(&CLIENT_SECRET);
|
||||
let claims = client_claims("secret-jwt", audience, chrono::Utc::now());
|
||||
let header = store.prepare_header(alg).await.expect("JWT header");
|
||||
let jwt = DecodedJsonWebToken::new(header, claims);
|
||||
let jwt = jwt.sign(&store).await.expect("signed token");
|
||||
let jwt = jwt.serialize();
|
||||
|
||||
// TODO: test failing cases
|
||||
// - expired token
|
||||
// - "not before" in the future
|
||||
// - subject/issuer mismatch
|
||||
// - audience mismatch
|
||||
// - wrong secret/signature
|
||||
|
||||
let (auth, client, body) = warp::test::request()
|
||||
.method("POST")
|
||||
.header("Content-Type", mime::APPLICATION_WWW_FORM_URLENCODED.to_string())
|
||||
.body(serde_urlencoded::to_string(json!({
|
||||
"client_id": "secret-jwt",
|
||||
"client_assertion": jwt,
|
||||
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
||||
"foo": "baz",
|
||||
"bar": "foobar",
|
||||
})).unwrap())
|
||||
.filter(&filter)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(auth, OAuthClientAuthenticationMethod::ClientSecretJwt);
|
||||
assert_eq!(client.client_id, "secret-jwt");
|
||||
assert_eq!(body.foo, "baz");
|
||||
assert_eq!(body.bar, "foobar");
|
||||
|
||||
// Without client_id
|
||||
let res = warp::test::request()
|
||||
.method("POST")
|
||||
.header("Content-Type", mime::APPLICATION_WWW_FORM_URLENCODED.to_string())
|
||||
.body(serde_urlencoded::to_string(json!({
|
||||
"client_assertion": jwt,
|
||||
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
||||
"foo": "baz",
|
||||
"bar": "foobar",
|
||||
})).unwrap())
|
||||
.filter(&filter)
|
||||
.await;
|
||||
assert!(res.is_ok());
|
||||
|
||||
// client_id mismatch
|
||||
let res = warp::test::request()
|
||||
.method("POST")
|
||||
.body(serde_urlencoded::to_string(json!({
|
||||
"client_id": "secret-jwt-2",
|
||||
"client_assertion": jwt,
|
||||
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
||||
"foo": "baz",
|
||||
"bar": "foobar",
|
||||
})).unwrap())
|
||||
.filter(&filter)
|
||||
.await;
|
||||
assert!(res.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn client_secret_jwt_rs256() {
|
||||
private_key_jwt("RS256").await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn client_secret_jwt_rs384() {
|
||||
private_key_jwt("RS384").await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn client_secret_jwt_rs512() {
|
||||
private_key_jwt("RS512").await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn client_secret_jwt_es256() {
|
||||
private_key_jwt("ES256").await;
|
||||
}
|
||||
|
||||
async fn private_key_jwt(alg: &str) {
|
||||
let alg = alg.parse().unwrap();
|
||||
let audience = "https://example.com/token";
|
||||
let filter = client_authentication::<Form>(&oauth2_config().await, audience.to_string());
|
||||
|
||||
let store = client_private_keystore();
|
||||
let claims = client_claims("private-key-jwt", audience, chrono::Utc::now());
|
||||
let header = store.prepare_header(alg).await.expect("JWT header");
|
||||
let jwt = DecodedJsonWebToken::new(header, claims);
|
||||
let jwt = jwt.sign(&store).await.expect("signed token");
|
||||
let jwt = jwt.serialize();
|
||||
|
||||
// TODO: test failing cases
|
||||
// - expired token
|
||||
// - "not before" in the future
|
||||
// - subject/issuer mismatch
|
||||
// - audience mismatch
|
||||
// - wrong secret/signature
|
||||
|
||||
let (auth, client, body) = warp::test::request()
|
||||
.method("POST")
|
||||
.header("Content-Type", mime::APPLICATION_WWW_FORM_URLENCODED.to_string())
|
||||
.body(serde_urlencoded::to_string(json!({
|
||||
"client_id": "private-key-jwt",
|
||||
"client_assertion": jwt,
|
||||
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
||||
"foo": "baz",
|
||||
"bar": "foobar",
|
||||
})).unwrap())
|
||||
.filter(&filter)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(auth, OAuthClientAuthenticationMethod::PrivateKeyJwt);
|
||||
assert_eq!(client.client_id, "private-key-jwt");
|
||||
assert_eq!(body.foo, "baz");
|
||||
assert_eq!(body.bar, "foobar");
|
||||
|
||||
// Without client_id
|
||||
let res = warp::test::request()
|
||||
.method("POST")
|
||||
.header("Content-Type", mime::APPLICATION_WWW_FORM_URLENCODED.to_string())
|
||||
.body(serde_urlencoded::to_string(json!({
|
||||
"client_assertion": jwt,
|
||||
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
||||
"foo": "baz",
|
||||
"bar": "foobar",
|
||||
})).unwrap())
|
||||
.filter(&filter)
|
||||
.await;
|
||||
assert!(res.is_ok());
|
||||
|
||||
// client_id mismatch
|
||||
let res = warp::test::request()
|
||||
.method("POST")
|
||||
.body(serde_urlencoded::to_string(json!({
|
||||
"client_id": "private-key-jwt-2",
|
||||
"client_assertion": jwt,
|
||||
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
||||
"foo": "baz",
|
||||
"bar": "foobar",
|
||||
})).unwrap())
|
||||
.filter(&filter)
|
||||
.await;
|
||||
assert!(res.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn client_secret_post() {
|
||||
let filter = client_authentication::<Form>(
|
||||
&oauth2_config().await,
|
||||
"https://example.com/token".to_string(),
|
||||
);
|
||||
|
||||
let (auth, client, body) = warp::test::request()
|
||||
.method("POST")
|
||||
.header(
|
||||
"Content-Type",
|
||||
mime::APPLICATION_WWW_FORM_URLENCODED.to_string(),
|
||||
)
|
||||
.body(
|
||||
serde_urlencoded::to_string(json!({
|
||||
"client_id": "secret-post",
|
||||
"client_secret": CLIENT_SECRET,
|
||||
"foo": "baz",
|
||||
"bar": "foobar",
|
||||
}))
|
||||
.unwrap(),
|
||||
)
|
||||
.filter(&filter)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(auth, OAuthClientAuthenticationMethod::ClientSecretPost);
|
||||
assert_eq!(client.client_id, "secret-post");
|
||||
assert_eq!(body.foo, "baz");
|
||||
assert_eq!(body.bar, "foobar");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn client_secret_basic() {
|
||||
let filter = client_authentication::<Form>(
|
||||
&oauth2_config().await,
|
||||
"https://example.com/token".to_string(),
|
||||
);
|
||||
|
||||
let auth = Authorization::basic("secret-basic", CLIENT_SECRET);
|
||||
let (auth, client, body) = warp::test::request()
|
||||
.method("POST")
|
||||
.header(
|
||||
"Content-Type",
|
||||
mime::APPLICATION_WWW_FORM_URLENCODED.to_string(),
|
||||
)
|
||||
.header("Authorization", auth.0.encode())
|
||||
.body(
|
||||
serde_urlencoded::to_string(json!({
|
||||
"foo": "baz",
|
||||
"bar": "foobar",
|
||||
}))
|
||||
.unwrap(),
|
||||
)
|
||||
.filter(&filter)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(auth, OAuthClientAuthenticationMethod::ClientSecretBasic);
|
||||
assert_eq!(client.client_id, "secret-basic");
|
||||
assert_eq!(body.foo, "baz");
|
||||
assert_eq!(body.bar, "foobar");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn none() {
|
||||
let filter = client_authentication::<Form>(
|
||||
&oauth2_config().await,
|
||||
"https://example.com/token".to_string(),
|
||||
);
|
||||
|
||||
let (auth, client, body) = warp::test::request()
|
||||
.method("POST")
|
||||
.header(
|
||||
"Content-Type",
|
||||
mime::APPLICATION_WWW_FORM_URLENCODED.to_string(),
|
||||
)
|
||||
.body(
|
||||
serde_urlencoded::to_string(json!({
|
||||
"client_id": "public",
|
||||
"foo": "baz",
|
||||
"bar": "foobar",
|
||||
}))
|
||||
.unwrap(),
|
||||
)
|
||||
.filter(&filter)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(auth, OAuthClientAuthenticationMethod::None);
|
||||
assert_eq!(client.client_id, "public");
|
||||
assert_eq!(body.foo, "baz");
|
||||
assert_eq!(body.bar, "foobar");
|
||||
}
|
||||
}
|
||||
*/
|
@ -1,193 +0,0 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Deal with encrypted cookies
|
||||
|
||||
use std::{convert::Infallible, marker::PhantomData};
|
||||
|
||||
use cookie::{Cookie, SameSite};
|
||||
use data_encoding::BASE64URL_NOPAD;
|
||||
use headers::{Header, HeaderValue, SetCookie};
|
||||
use mas_config::Encrypter;
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
use warp::{
|
||||
reject::{InvalidHeader, MissingCookie, Reject},
|
||||
Filter, Rejection, Reply,
|
||||
};
|
||||
|
||||
use super::none_on_error;
|
||||
use crate::{
|
||||
errors::WrapError,
|
||||
reply::{with_typed_header, WithTypedHeader},
|
||||
};
|
||||
|
||||
/// Unable to decrypt the cookie
|
||||
#[derive(Debug, Error)]
|
||||
pub struct CookieDecryptionError<T: EncryptableCookieValue>(
|
||||
#[source] anyhow::Error,
|
||||
// This [`std::marker::PhantomData`] records what kind of cookie it was trying to save.
|
||||
// This then use when displaying the error.
|
||||
PhantomData<T>,
|
||||
);
|
||||
|
||||
impl<T> Reject for CookieDecryptionError<T> where T: EncryptableCookieValue + 'static {}
|
||||
|
||||
impl<T: EncryptableCookieValue> From<anyhow::Error> for CookieDecryptionError<T> {
|
||||
fn from(e: anyhow::Error) -> Self {
|
||||
Self(e, PhantomData)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: EncryptableCookieValue> std::fmt::Display for CookieDecryptionError<T> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "failed to decrypt cookie {}", T::cookie_key())
|
||||
}
|
||||
}
|
||||
|
||||
fn decryption_error<T>(e: anyhow::Error) -> Rejection
|
||||
where
|
||||
T: EncryptableCookieValue + 'static,
|
||||
{
|
||||
let e: CookieDecryptionError<T> = e.into();
|
||||
warp::reject::custom(e)
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct EncryptedCookie {
|
||||
nonce: [u8; 12],
|
||||
ciphertext: Vec<u8>,
|
||||
}
|
||||
|
||||
impl EncryptedCookie {
|
||||
/// Encrypt from a given key
|
||||
fn encrypt<T: Serialize>(payload: T, encrypter: &Encrypter) -> anyhow::Result<Self> {
|
||||
let message = bincode::serialize(&payload)?;
|
||||
let nonce: [u8; 12] = rand::random();
|
||||
let ciphertext = encrypter.encrypt(&nonce, &message)?;
|
||||
Ok(Self { nonce, ciphertext })
|
||||
}
|
||||
|
||||
/// Decrypt the content of the cookie from a given key
|
||||
fn decrypt<T: DeserializeOwned>(&self, encrypter: &Encrypter) -> anyhow::Result<T> {
|
||||
let message = encrypter.decrypt(&self.nonce, &self.ciphertext)?;
|
||||
let token = bincode::deserialize(&message)?;
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
/// Encode the encrypted cookie to be then saved as a cookie
|
||||
fn to_cookie_value(&self) -> anyhow::Result<String> {
|
||||
let raw = bincode::serialize(self)?;
|
||||
Ok(BASE64URL_NOPAD.encode(&raw))
|
||||
}
|
||||
|
||||
fn from_cookie_value(value: &str) -> anyhow::Result<Self> {
|
||||
let raw = BASE64URL_NOPAD.decode(value.as_bytes())?;
|
||||
let content = bincode::deserialize(&raw)?;
|
||||
Ok(content)
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract an optional encrypted cookie
|
||||
#[must_use]
|
||||
pub fn maybe_encrypted<T>(
|
||||
encrypter: &Encrypter,
|
||||
) -> impl Filter<Extract = (Option<T>,), Error = Rejection> + Clone + Send + Sync + 'static
|
||||
where
|
||||
T: DeserializeOwned + EncryptableCookieValue + 'static,
|
||||
{
|
||||
encrypted(encrypter)
|
||||
.map(Some)
|
||||
.recover(none_on_error::<T, InvalidHeader>)
|
||||
.unify()
|
||||
.recover(none_on_error::<T, MissingCookie>)
|
||||
.unify()
|
||||
.recover(none_on_error::<T, CookieDecryptionError<T>>)
|
||||
.unify()
|
||||
}
|
||||
|
||||
/// Extract an encrypted cookie
|
||||
///
|
||||
/// # Rejections
|
||||
///
|
||||
/// This can reject with either a [`warp::reject::MissingCookie`] or a
|
||||
/// [`CookieDecryptionError`]
|
||||
#[must_use]
|
||||
pub fn encrypted<T>(
|
||||
encrypter: &Encrypter,
|
||||
) -> impl Filter<Extract = (T,), Error = Rejection> + Clone + Send + Sync + 'static
|
||||
where
|
||||
T: DeserializeOwned + EncryptableCookieValue + 'static,
|
||||
{
|
||||
let encrypter = encrypter.clone();
|
||||
warp::cookie::cookie(T::cookie_key()).and_then(move |value: String| {
|
||||
let encrypter = encrypter.clone();
|
||||
async move {
|
||||
let encrypted_payload =
|
||||
EncryptedCookie::from_cookie_value(&value).map_err(decryption_error::<T>)?;
|
||||
let decrypted_payload = encrypted_payload
|
||||
.decrypt(&encrypter)
|
||||
.map_err(decryption_error::<T>)?;
|
||||
Ok::<_, Rejection>(decrypted_payload)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Get an [`EncryptedCookieSaver`] to help saving an [`EncryptableCookieValue`]
|
||||
#[must_use]
|
||||
pub fn encrypted_cookie_saver(
|
||||
encrypter: &Encrypter,
|
||||
) -> impl Filter<Extract = (EncryptedCookieSaver,), Error = Infallible> + Clone + Send + Sync + 'static
|
||||
{
|
||||
let encrypter = encrypter.clone();
|
||||
warp::any().map(move || EncryptedCookieSaver {
|
||||
encrypter: encrypter.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
/// A cookie that can be encrypted with a well-known cookie key
|
||||
pub trait EncryptableCookieValue: Serialize + Send + Sync + std::fmt::Debug {
|
||||
/// What key should be used for this cookie
|
||||
fn cookie_key() -> &'static str;
|
||||
}
|
||||
|
||||
/// An opaque structure which helps encrypting a cookie and attach it to a reply
|
||||
pub struct EncryptedCookieSaver {
|
||||
encrypter: Encrypter,
|
||||
}
|
||||
|
||||
impl EncryptedCookieSaver {
|
||||
/// Save an [`EncryptableCookieValue`]
|
||||
pub fn save_encrypted<T: EncryptableCookieValue, R: Reply>(
|
||||
&self,
|
||||
cookie: &T,
|
||||
reply: R,
|
||||
) -> Result<WithTypedHeader<R, SetCookie>, Rejection> {
|
||||
let encrypted = EncryptedCookie::encrypt(cookie, &self.encrypter)
|
||||
.wrap_error()?
|
||||
.to_cookie_value()
|
||||
.wrap_error()?;
|
||||
|
||||
// TODO: make those options customizable
|
||||
let value = Cookie::build(T::cookie_key(), encrypted)
|
||||
.http_only(true)
|
||||
.same_site(SameSite::Lax)
|
||||
.finish()
|
||||
.to_string();
|
||||
|
||||
let header = SetCookie::decode(&mut [HeaderValue::from_str(&value).wrap_error()?].iter())
|
||||
.wrap_error()?;
|
||||
Ok(with_typed_header(header, reply))
|
||||
}
|
||||
}
|
@ -1,42 +0,0 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Wrapper around [`warp::filters::cors`]
|
||||
|
||||
use std::string::ToString;
|
||||
|
||||
use once_cell::sync::OnceCell;
|
||||
|
||||
static PROPAGATOR_HEADERS: OnceCell<Vec<String>> = OnceCell::new();
|
||||
|
||||
/// Notify the CORS filter what opentelemetry propagators are being used. This
|
||||
/// helps whitelisting headers in CORS requests.
|
||||
pub fn set_propagator(propagator: &dyn opentelemetry::propagation::TextMapPropagator) {
|
||||
let headers = propagator.fields().map(ToString::to_string).collect();
|
||||
tracing::debug!(
|
||||
?headers,
|
||||
"Headers allowed in CORS requests for trace propagators set"
|
||||
);
|
||||
PROPAGATOR_HEADERS
|
||||
.set(headers)
|
||||
.expect(concat!(module_path!(), "::set_propagator was called twice"));
|
||||
}
|
||||
|
||||
/// Create a wrapping filter that exposes CORS behavior for a wrapped filter.
|
||||
#[must_use]
|
||||
pub fn cors() -> warp::filters::cors::Builder {
|
||||
warp::filters::cors::cors()
|
||||
.allow_any_origin()
|
||||
.allow_headers(PROPAGATOR_HEADERS.get().unwrap_or(&Vec::new()))
|
||||
}
|
@ -1,185 +0,0 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Stateless CSRF protection middleware based on a chacha20-poly1305 encrypted
|
||||
//! and signed token
|
||||
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use data_encoding::{DecodeError, BASE64URL_NOPAD};
|
||||
use mas_config::{CsrfConfig, Encrypter};
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use serde_with::{serde_as, TimestampSeconds};
|
||||
use thiserror::Error;
|
||||
use warp::{reject::Reject, Filter, Rejection};
|
||||
|
||||
use super::cookies::EncryptableCookieValue;
|
||||
|
||||
/// Failed to validate CSRF token
|
||||
#[derive(Debug, Error)]
|
||||
pub enum CsrfError {
|
||||
/// The token in the form did not match the token in the cookie
|
||||
#[error("CSRF token mismatch")]
|
||||
Mismatch,
|
||||
|
||||
/// The token expired
|
||||
#[error("CSRF token expired")]
|
||||
Expired,
|
||||
|
||||
/// Failed to decode the token
|
||||
#[error("could not decode CSRF token")]
|
||||
Decode(#[from] DecodeError),
|
||||
}
|
||||
|
||||
impl Reject for CsrfError {}
|
||||
|
||||
/// A CSRF token
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct CsrfToken {
|
||||
#[serde_as(as = "TimestampSeconds<i64>")]
|
||||
expiration: DateTime<Utc>,
|
||||
token: [u8; 32],
|
||||
}
|
||||
|
||||
impl CsrfToken {
|
||||
/// Create a new token from a defined value valid for a specified duration
|
||||
fn new(token: [u8; 32], ttl: Duration) -> Self {
|
||||
let expiration = Utc::now() + ttl;
|
||||
Self { expiration, token }
|
||||
}
|
||||
|
||||
/// Generate a new random token valid for a specified duration
|
||||
fn generate(ttl: Duration) -> Self {
|
||||
let token = rand::random();
|
||||
Self::new(token, ttl)
|
||||
}
|
||||
|
||||
/// Generate a new token with the same value but an up to date expiration
|
||||
fn refresh(self, ttl: Duration) -> Self {
|
||||
Self::new(self.token, ttl)
|
||||
}
|
||||
|
||||
/// Get the value to include in HTML forms
|
||||
#[must_use]
|
||||
pub fn form_value(&self) -> String {
|
||||
BASE64URL_NOPAD.encode(&self.token[..])
|
||||
}
|
||||
|
||||
/// Verifies that the value got from an HTML form matches this token
|
||||
pub fn verify_form_value(&self, form_value: &str) -> Result<(), CsrfError> {
|
||||
let form_value = BASE64URL_NOPAD.decode(form_value.as_bytes())?;
|
||||
if self.token[..] == form_value {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(CsrfError::Mismatch)
|
||||
}
|
||||
}
|
||||
|
||||
fn verify_expiration(self) -> Result<Self, CsrfError> {
|
||||
if Utc::now() < self.expiration {
|
||||
Ok(self)
|
||||
} else {
|
||||
Err(CsrfError::Expired)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EncryptableCookieValue for CsrfToken {
|
||||
fn cookie_key() -> &'static str {
|
||||
"csrf"
|
||||
}
|
||||
}
|
||||
|
||||
/// A CSRF-protected form
|
||||
#[derive(Deserialize)]
|
||||
struct CsrfForm<T> {
|
||||
csrf: String,
|
||||
|
||||
#[serde(flatten)]
|
||||
inner: T,
|
||||
}
|
||||
|
||||
impl<T> CsrfForm<T> {
|
||||
fn verify_csrf(self, token: &CsrfToken) -> Result<T, CsrfError> {
|
||||
// Verify CSRF from request
|
||||
token.verify_form_value(&self.csrf)?;
|
||||
Ok(self.inner)
|
||||
}
|
||||
}
|
||||
|
||||
fn csrf_token(
|
||||
encrypter: &Encrypter,
|
||||
) -> impl Filter<Extract = (CsrfToken,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
super::cookies::encrypted(encrypter).and_then(move |token: CsrfToken| async move {
|
||||
let verified = token.verify_expiration()?;
|
||||
Ok::<_, Rejection>(verified)
|
||||
})
|
||||
}
|
||||
|
||||
/// Extract an up-to-date CSRF token to include in forms
|
||||
///
|
||||
/// Routes using this should not forget to reply the updated CSRF cookie using
|
||||
/// an [`EncryptedCookieSaver`][`super::cookies::EncryptedCookieSaver`] obtained
|
||||
/// with [`encrypted_cookie_saver`][`super::cookies::encrypted_cookie_saver`]
|
||||
#[must_use]
|
||||
pub fn updated_csrf_token(
|
||||
encrypter: &Encrypter,
|
||||
csrf_config: &CsrfConfig,
|
||||
) -> impl Filter<Extract = (CsrfToken,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
let ttl = csrf_config.ttl;
|
||||
super::cookies::maybe_encrypted(encrypter).and_then(
|
||||
move |maybe_token: Option<CsrfToken>| async move {
|
||||
// Explicitely specify the "Error" type here to have the `?` operation working
|
||||
Ok::<_, Rejection>(
|
||||
maybe_token
|
||||
// Verify its TTL (but do not hard-error if it expired)
|
||||
.and_then(|token| token.verify_expiration().ok())
|
||||
.map_or_else(
|
||||
// Generate a new token if no valid one were found
|
||||
|| CsrfToken::generate(ttl),
|
||||
// Else, refresh the expiration of the token
|
||||
|token| token.refresh(ttl),
|
||||
),
|
||||
)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// Extract values from a CSRF-protected form
|
||||
///
|
||||
/// # Rejections
|
||||
///
|
||||
/// This can reject with:
|
||||
///
|
||||
/// - [`warp::filters::body::BodyDeserializeError`] if the overall form failed
|
||||
/// to decode
|
||||
/// - [`CsrfError`] if the CSRF token was invalid or expired
|
||||
/// - [`warp::reject::MissingCookie`] if the CSRF cookie was missing
|
||||
/// - [`super::cookies::CookieDecryptionError`] if the cookie failed to decrypt
|
||||
///
|
||||
/// TODO: we might want to unify the last three rejections in one
|
||||
#[must_use]
|
||||
pub fn protected_form<T>(
|
||||
encrypter: &Encrypter,
|
||||
) -> impl Filter<Extract = (T,), Error = Rejection> + Clone + Send + Sync + 'static
|
||||
where
|
||||
T: DeserializeOwned + Send + 'static,
|
||||
{
|
||||
csrf_token(encrypter).and(warp::body::form()).and_then(
|
||||
|csrf_token: CsrfToken, protected_form: CsrfForm<T>| async move {
|
||||
let form = protected_form.verify_csrf(&csrf_token)?;
|
||||
Ok::<_, Rejection>(form)
|
||||
},
|
||||
)
|
||||
}
|
@ -1,61 +0,0 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Database-related filters to grab connections and start transactions from the
|
||||
//! connection pool
|
||||
|
||||
use std::convert::Infallible;
|
||||
|
||||
use sqlx::{
|
||||
pool::{Pool, PoolConnection},
|
||||
Database, Transaction,
|
||||
};
|
||||
use warp::{Filter, Rejection};
|
||||
|
||||
use crate::errors::WrapError;
|
||||
|
||||
fn with_pool<T: Database>(
|
||||
pool: &Pool<T>,
|
||||
) -> impl Filter<Extract = (Pool<T>,), Error = Infallible> + Clone + Send + Sync + 'static {
|
||||
let pool = pool.clone();
|
||||
warp::any().map(move || pool.clone())
|
||||
}
|
||||
|
||||
/// Acquire a connection to the database
|
||||
pub fn connection<T: Database>(
|
||||
pool: &Pool<T>,
|
||||
) -> impl Filter<Extract = (PoolConnection<T>,), Error = Rejection> + Clone + Send + Sync + 'static
|
||||
{
|
||||
with_pool(pool).and_then(acquire_connection)
|
||||
}
|
||||
|
||||
async fn acquire_connection<T: Database>(pool: Pool<T>) -> Result<PoolConnection<T>, Rejection> {
|
||||
let conn = pool.acquire().await.wrap_error()?;
|
||||
Ok(conn)
|
||||
}
|
||||
|
||||
/// Start a database transaction
|
||||
pub fn transaction<T: Database>(
|
||||
pool: &Pool<T>,
|
||||
) -> impl Filter<Extract = (Transaction<'static, T>,), Error = Rejection> + Clone + Send + Sync + 'static
|
||||
{
|
||||
with_pool(pool).and_then(acquire_transaction)
|
||||
}
|
||||
|
||||
async fn acquire_transaction<T: Database>(
|
||||
pool: Pool<T>,
|
||||
) -> Result<Transaction<'static, T>, Rejection> {
|
||||
let txn = pool.begin().await.wrap_error()?;
|
||||
Ok(txn)
|
||||
}
|
@ -1,43 +0,0 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Deal with typed headers from the [`headers`] crate
|
||||
|
||||
use headers::{Header, HeaderValue};
|
||||
use thiserror::Error;
|
||||
use warp::{reject::Reject, Filter, Rejection};
|
||||
|
||||
/// Failed to decode typed header
|
||||
#[derive(Debug, Error)]
|
||||
#[error("could not decode header {1}")]
|
||||
pub struct InvalidTypedHeader(#[source] headers::Error, &'static str);
|
||||
|
||||
impl Reject for InvalidTypedHeader {}
|
||||
|
||||
/// Extract a typed header from the request
|
||||
///
|
||||
/// # Rejections
|
||||
///
|
||||
/// This can reject with either a [`warp::reject::MissingHeader`] or a
|
||||
/// [`InvalidTypedHeader`].
|
||||
pub fn typed_header<T: Header + Send + 'static>(
|
||||
) -> impl Filter<Extract = (T,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
warp::header::value(T::name().as_str()).and_then(decode_typed_header)
|
||||
}
|
||||
|
||||
async fn decode_typed_header<T: Header>(header: HeaderValue) -> Result<T, Rejection> {
|
||||
let mut it = std::iter::once(&header);
|
||||
let decoded = T::decode(&mut it).map_err(|e| InvalidTypedHeader(e, T::name().as_str()))?;
|
||||
Ok(decoded)
|
||||
}
|
@ -1,72 +0,0 @@
|
||||
// Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Set of [`warp`] filters
|
||||
|
||||
#![allow(clippy::unused_async)] // Some warp filters need that
|
||||
#![deny(missing_docs)]
|
||||
|
||||
pub mod authenticate;
|
||||
pub mod client;
|
||||
pub mod cookies;
|
||||
pub mod cors;
|
||||
pub mod csrf;
|
||||
pub mod database;
|
||||
pub mod headers;
|
||||
pub mod session;
|
||||
pub mod trace;
|
||||
pub mod url_builder;
|
||||
|
||||
use std::convert::Infallible;
|
||||
|
||||
use mas_templates::Templates;
|
||||
use warp::{Filter, Rejection};
|
||||
|
||||
pub use self::csrf::CsrfToken;
|
||||
|
||||
/// Get the [`Templates`]
|
||||
#[must_use]
|
||||
pub fn with_templates(
|
||||
templates: &Templates,
|
||||
) -> impl Filter<Extract = (Templates,), Error = Infallible> + Clone + Send + Sync + 'static {
|
||||
let templates = templates.clone();
|
||||
warp::any().map(move || templates.clone())
|
||||
}
|
||||
|
||||
/// Recover a particular rejection type with a `None` option variant
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// extern crate warp;
|
||||
///
|
||||
/// use warp::{filters::header::header, reject::MissingHeader, Filter};
|
||||
///
|
||||
/// use mas_warp_utils::filters::none_on_error;
|
||||
///
|
||||
/// header("Content-Length")
|
||||
/// .map(Some)
|
||||
/// .recover(none_on_error::<_, MissingHeader>)
|
||||
/// .unify()
|
||||
/// .map(|length: Option<u64>| {
|
||||
/// format!("header: {:?}", length)
|
||||
/// });
|
||||
/// ```
|
||||
pub async fn none_on_error<T, E: 'static>(rejection: Rejection) -> Result<Option<T>, Rejection> {
|
||||
if rejection.find::<E>().is_some() {
|
||||
Ok(None)
|
||||
} else {
|
||||
Err(rejection)
|
||||
}
|
||||
}
|
@ -1,162 +0,0 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Load user sessions from the database
|
||||
|
||||
use mas_config::Encrypter;
|
||||
use mas_data_model::BrowserSession;
|
||||
use mas_storage::{
|
||||
user::{lookup_active_session, ActiveSessionLookupError},
|
||||
PostgresqlBackend,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{pool::PoolConnection, Executor, PgPool, Postgres};
|
||||
use thiserror::Error;
|
||||
use tracing::warn;
|
||||
use warp::{
|
||||
reject::{InvalidHeader, MissingCookie, Reject},
|
||||
Filter, Rejection,
|
||||
};
|
||||
|
||||
use super::{
|
||||
cookies::{encrypted, CookieDecryptionError, EncryptableCookieValue},
|
||||
database::connection,
|
||||
none_on_error,
|
||||
};
|
||||
|
||||
/// The session is missing or failed to load
|
||||
#[derive(Error, Debug)]
|
||||
pub enum SessionLoadError {
|
||||
/// No session cookie was found
|
||||
#[error("missing session cookie")]
|
||||
MissingCookie,
|
||||
|
||||
/// The session cookie is invalid
|
||||
#[error("unable to parse or decrypt session cookie")]
|
||||
InvalidCookie,
|
||||
|
||||
/// The session is unknown or inactive
|
||||
#[error("unknown or inactive session")]
|
||||
UnknownSession,
|
||||
}
|
||||
|
||||
impl Reject for SessionLoadError {}
|
||||
|
||||
/// An encrypted cookie to save the session ID
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct SessionCookie {
|
||||
current: i64,
|
||||
}
|
||||
|
||||
impl SessionCookie {
|
||||
/// Forge the cookie from a [`BrowserSession`]
|
||||
#[must_use]
|
||||
pub fn from_session(session: &BrowserSession<PostgresqlBackend>) -> Self {
|
||||
Self {
|
||||
current: session.data,
|
||||
}
|
||||
}
|
||||
|
||||
/// Load the [`BrowserSession`] from database
|
||||
pub async fn load_session(
|
||||
&self,
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
) -> Result<BrowserSession<PostgresqlBackend>, ActiveSessionLookupError> {
|
||||
let res = lookup_active_session(executor, self.current).await?;
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
impl EncryptableCookieValue for SessionCookie {
|
||||
fn cookie_key() -> &'static str {
|
||||
"session"
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract a user session information if logged in
|
||||
#[must_use]
|
||||
pub fn optional_session(
|
||||
pool: &PgPool,
|
||||
encrypter: &Encrypter,
|
||||
) -> impl Filter<Extract = (Option<BrowserSession<PostgresqlBackend>>,), Error = Rejection>
|
||||
+ Clone
|
||||
+ Send
|
||||
+ Sync
|
||||
+ 'static {
|
||||
session(pool, encrypter)
|
||||
.map(Some)
|
||||
.recover(none_on_error::<_, SessionLoadError>)
|
||||
.unify()
|
||||
}
|
||||
|
||||
/// Extract a user session information, rejecting if not logged in
|
||||
///
|
||||
/// # Rejections
|
||||
///
|
||||
/// This filter will reject with a [`SessionLoadError`] when the session is
|
||||
/// inactive or missing. It will reject with a wrapped error on other database
|
||||
/// failures.
|
||||
#[must_use]
|
||||
pub fn session(
|
||||
pool: &PgPool,
|
||||
encrypter: &Encrypter,
|
||||
) -> impl Filter<Extract = (BrowserSession<PostgresqlBackend>,), Error = Rejection>
|
||||
+ Clone
|
||||
+ Send
|
||||
+ Sync
|
||||
+ 'static {
|
||||
encrypted(encrypter)
|
||||
.and(connection(pool))
|
||||
.and_then(load_session)
|
||||
.recover(recover)
|
||||
.unify()
|
||||
}
|
||||
|
||||
async fn load_session(
|
||||
session: SessionCookie,
|
||||
mut conn: PoolConnection<Postgres>,
|
||||
) -> Result<BrowserSession<PostgresqlBackend>, Rejection> {
|
||||
let session_info = session.load_session(&mut conn).await?;
|
||||
Ok(session_info)
|
||||
}
|
||||
|
||||
/// Recover from expected rejections, to transform them into a
|
||||
/// [`SessionLoadError`]
|
||||
async fn recover<T>(rejection: Rejection) -> Result<T, Rejection> {
|
||||
if let Some(e) = rejection.find::<ActiveSessionLookupError>() {
|
||||
if e.not_found() {
|
||||
return Err(warp::reject::custom(SessionLoadError::UnknownSession));
|
||||
}
|
||||
|
||||
// If we're here, there is a real database error that should be
|
||||
// propagated
|
||||
}
|
||||
|
||||
if let Some(e) = rejection.find::<InvalidHeader>() {
|
||||
if e.name() == "cookie" {
|
||||
return Err(warp::reject::custom(SessionLoadError::MissingCookie));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(_e) = rejection.find::<MissingCookie>() {
|
||||
return Err(warp::reject::custom(SessionLoadError::MissingCookie));
|
||||
}
|
||||
|
||||
if let Some(error) = rejection.find::<CookieDecryptionError<SessionCookie>>() {
|
||||
warn!(?error, "could not decrypt session cookie");
|
||||
return Err(warp::reject::custom(SessionLoadError::InvalidCookie));
|
||||
}
|
||||
|
||||
Err(rejection)
|
||||
}
|
@ -1,35 +0,0 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Route tracing utility
|
||||
|
||||
use std::convert::Infallible;
|
||||
|
||||
use warp::Filter;
|
||||
|
||||
/// Set the name of that route
|
||||
#[must_use]
|
||||
pub fn name(
|
||||
name: &'static str,
|
||||
) -> impl Filter<Extract = (), Error = Infallible> + Clone + Send + Sync + 'static {
|
||||
warp::any()
|
||||
.map(move || {
|
||||
// TODO: update_name has a weird signature, which is already fixed in
|
||||
// opentelemetry-rust, just not released yet
|
||||
// TODO: we should find another way to classify requests. Span::update_name has
|
||||
// impacts on sampling and should not be used
|
||||
opentelemetry::trace::get_active_span(|s| s.update_name::<String>(name.to_string()));
|
||||
})
|
||||
.untuple_one()
|
||||
}
|
@ -1,121 +0,0 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Utility to build URLs
|
||||
|
||||
// TODO: move this somewhere else
|
||||
|
||||
use std::convert::Infallible;
|
||||
|
||||
use mas_config::HttpConfig;
|
||||
use url::Url;
|
||||
use warp::Filter;
|
||||
|
||||
impl From<&HttpConfig> for UrlBuilder {
|
||||
fn from(config: &HttpConfig) -> Self {
|
||||
Self::new(config.public_base.clone())
|
||||
}
|
||||
}
|
||||
|
||||
/// Helps building absolute URLs
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct UrlBuilder {
|
||||
base: Url,
|
||||
}
|
||||
|
||||
impl UrlBuilder {
|
||||
/// Create a new [`UrlBuilder`] from a base URL
|
||||
#[must_use]
|
||||
pub fn new(base: Url) -> Self {
|
||||
Self { base }
|
||||
}
|
||||
|
||||
/// OIDC issuer
|
||||
#[must_use]
|
||||
pub fn oidc_issuer(&self) -> Url {
|
||||
self.base.clone()
|
||||
}
|
||||
|
||||
/// OIDC dicovery document URL
|
||||
#[must_use]
|
||||
pub fn oidc_discovery(&self) -> Url {
|
||||
self.base
|
||||
.join(".well-known/openid-configuration")
|
||||
.expect("build URL")
|
||||
}
|
||||
|
||||
/// OAuth 2.0 authorization endpoint
|
||||
#[must_use]
|
||||
pub fn oauth_authorization_endpoint(&self) -> Url {
|
||||
self.base.join("oauth2/authorize").expect("build URL")
|
||||
}
|
||||
|
||||
/// OAuth 2.0 token endpoint
|
||||
#[must_use]
|
||||
pub fn oauth_token_endpoint(&self) -> Url {
|
||||
self.base.join("oauth2/token").expect("build URL")
|
||||
}
|
||||
|
||||
/// OAuth 2.0 introspection endpoint
|
||||
#[must_use]
|
||||
pub fn oauth_introspection_endpoint(&self) -> Url {
|
||||
self.base.join("oauth2/introspect").expect("build URL")
|
||||
}
|
||||
|
||||
/// OAuth 2.0 introspection endpoint
|
||||
#[must_use]
|
||||
pub fn oidc_userinfo_endpoint(&self) -> Url {
|
||||
self.base.join("oauth2/userinfo").expect("build URL")
|
||||
}
|
||||
|
||||
/// JWKS URI
|
||||
#[must_use]
|
||||
pub fn jwks_uri(&self) -> Url {
|
||||
self.base.join("oauth2/keys.json").expect("build URL")
|
||||
}
|
||||
|
||||
/// Email verification URL
|
||||
#[must_use]
|
||||
pub fn email_verification(&self, code: &str) -> Url {
|
||||
self.base
|
||||
.join("verify/")
|
||||
.expect("build URL")
|
||||
.join(code)
|
||||
.expect("build URL")
|
||||
}
|
||||
}
|
||||
|
||||
/// Injects an [`UrlBuilder`] to help building absolute URLs
|
||||
#[must_use]
|
||||
pub fn url_builder(
|
||||
config: &HttpConfig,
|
||||
) -> impl Filter<Extract = (UrlBuilder,), Error = Infallible> + Clone + Send + Sync + 'static {
|
||||
let builder: UrlBuilder = config.into();
|
||||
warp::any().map(move || builder.clone())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn build_email_verification_url() {
|
||||
let base = Url::parse("https://example.com/").unwrap();
|
||||
let builder = UrlBuilder::new(base);
|
||||
assert_eq!(
|
||||
builder.email_verification("123456abcdef").as_str(),
|
||||
"https://example.com/verify/123456abcdef"
|
||||
);
|
||||
}
|
||||
}
|
@ -1,24 +0,0 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Various warp filters and replies
|
||||
|
||||
#![forbid(unsafe_code)]
|
||||
#![deny(clippy::all, missing_docs, rustdoc::broken_intra_doc_links)]
|
||||
#![warn(clippy::pedantic)]
|
||||
#![allow(clippy::module_name_repetitions, clippy::missing_errors_doc)]
|
||||
|
||||
pub mod errors;
|
||||
pub mod filters;
|
||||
pub mod reply;
|
@ -1,54 +0,0 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Reply with a typed header from the [`headers`] crate.
|
||||
//!
|
||||
//! ```rust
|
||||
//! extern crate headers;
|
||||
//! extern crate warp;
|
||||
//!
|
||||
//! use warp::Reply;
|
||||
//! use mas_warp_utils::reply::with_typed_header;
|
||||
//!
|
||||
//! let reply = r#"{"hello": "world"}"#;
|
||||
//! let reply = with_typed_header(headers::ContentType::json(), reply);;
|
||||
//! let response = reply.into_response();
|
||||
//! assert_eq!(response.headers().get("Content-Type").unwrap().to_str().unwrap(), "application/json");
|
||||
//! ```
|
||||
|
||||
use headers::{Header, HeaderMapExt};
|
||||
use warp::Reply;
|
||||
|
||||
/// Add a typed header to a reply
|
||||
pub fn with_typed_header<R, H>(header: H, reply: R) -> WithTypedHeader<R, H> {
|
||||
WithTypedHeader { reply, header }
|
||||
}
|
||||
|
||||
/// A reply with a typed header set
|
||||
pub struct WithTypedHeader<R, H> {
|
||||
reply: R,
|
||||
header: H,
|
||||
}
|
||||
|
||||
impl<R, H> Reply for WithTypedHeader<R, H>
|
||||
where
|
||||
R: Reply,
|
||||
H: Header + Send,
|
||||
{
|
||||
fn into_response(self) -> warp::reply::Response {
|
||||
let mut res = self.reply.into_response();
|
||||
res.headers_mut().typed_insert(self.header);
|
||||
res
|
||||
}
|
||||
}
|
@ -1,21 +0,0 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Set of wrappers for [`warp::Reply`]
|
||||
|
||||
#![deny(missing_docs)]
|
||||
|
||||
pub mod headers;
|
||||
|
||||
pub use self::headers::{with_typed_header, WithTypedHeader};
|
Reference in New Issue
Block a user