1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Get rid of warp

This commit is contained in:
Quentin Gliech
2022-04-06 15:40:16 +02:00
parent 9cd63f6cf1
commit 4e31fc6c84
30 changed files with 3 additions and 2312 deletions

View File

@ -16,7 +16,6 @@ tower = { version = "0.4.12", features = ["full"] }
hyper = { version = "0.14.17", features = ["full"] }
serde_yaml = "0.8.23"
serde_json = "1.0.79"
warp = "0.3.2"
url = "2.2.2"
argon2 = { version = "0.3.4", features = ["password-hash"] }
reqwest = { version = "0.11.10", features = ["rustls-tls"], default-features = false, optional = true }
@ -42,7 +41,6 @@ mas-http = { path = "../http" }
mas-storage = { path = "../storage" }
mas-tasks = { path = "../tasks" }
mas-templates = { path = "../templates" }
mas-warp-utils = { path = "../warp-utils" }
mas-axum-utils = { path = "../axum-utils" }
[dev-dependencies]

View File

@ -40,7 +40,7 @@ pub fn setup(config: &TelemetryConfig) -> anyhow::Result<Option<Tracer>> {
// The CORS filter needs to know what headers it should whitelist for
// CORS-protected requests.
mas_warp_utils::filters::cors::set_propagator(&propagator);
// TODO mas_warp_utils::filters::cors::set_propagator(&propagator);
global::set_text_map_propagator(propagator);
let tracer = tracer(&config.tracing.exporter)?;

View File

@ -20,7 +20,6 @@ thiserror = "1.0.30"
anyhow = "1.0.56"
# Web server
warp = "0.3.2"
hyper = { version = "0.14.17", features = ["full"] }
tower = "0.4.12"
axum = "0.4.8"
@ -67,7 +66,6 @@ mas-jose = { path = "../jose" }
mas-static-files = { path = "../static-files" }
mas-storage = { path = "../storage" }
mas-templates = { path = "../templates" }
mas-warp-utils = { path = "../warp-utils" }
[dev-dependencies]
indoc = "1.0.4"

View File

@ -16,7 +16,7 @@
#![deny(clippy::all, rustdoc::broken_intra_doc_links)]
#![warn(clippy::pedantic)]
#![allow(
clippy::unused_async // Some warp filters need that
clippy::unused_async // Some axum handlers need that
)]
use std::sync::Arc;

View File

@ -14,7 +14,6 @@ serde_json = "1.0.79"
thiserror = "1.0.30"
anyhow = "1.0.56"
tracing = "0.1.32"
warp = "0.3.2"
# Password hashing
argon2 = { version = "0.3.4", features = ["password-hash"] }

View File

@ -21,7 +21,6 @@ use oauth2_types::requests::GrantType;
use sqlx::{PgConnection, PgExecutor};
use thiserror::Error;
use url::Url;
use warp::reject::Reject;
use crate::PostgresqlBackend;
@ -79,8 +78,6 @@ impl ClientFetchError {
}
}
impl Reject for ClientFetchError {}
impl TryInto<Client<PostgresqlBackend>> for OAuth2ClientLookup {
type Error = ClientFetchError;

View File

@ -19,7 +19,6 @@ use mas_data_model::{
};
use sqlx::{PgConnection, PgExecutor};
use thiserror::Error;
use warp::reject::Reject;
use super::client::{lookup_client_by_client_id, ClientFetchError};
use crate::{DatabaseInconsistencyError, IdAndCreationTime, PostgresqlBackend};
@ -87,8 +86,6 @@ pub enum RefreshTokenLookupError {
Conversion(#[from] DatabaseInconsistencyError),
}
impl Reject for RefreshTokenLookupError {}
impl RefreshTokenLookupError {
#[must_use]
pub fn not_found(&self) -> bool {

View File

@ -27,7 +27,6 @@ use sqlx::{postgres::types::PgInterval, Acquire, PgExecutor, Postgres, Transacti
use thiserror::Error;
use tokio::task;
use tracing::{info_span, Instrument};
use warp::reject::Reject;
use super::{DatabaseInconsistencyError, PostgresqlBackend};
use crate::IdAndCreationTime;
@ -117,8 +116,6 @@ pub enum ActiveSessionLookupError {
Conversion(#[from] DatabaseInconsistencyError),
}
impl Reject for ActiveSessionLookupError {}
impl ActiveSessionLookupError {
#[must_use]
pub fn not_found(&self) -> bool {

View File

@ -21,7 +21,6 @@ serde_json = "1.0.79"
serde_urlencoded = "0.7.1"
url = "2.2.2"
warp = "0.3.2"
oauth2-types = { path = "../oauth2-types" }
mas-data-model = { path = "../data-model" }

View File

@ -271,8 +271,6 @@ pub enum TemplateError {
},
}
impl warp::reject::Reject for TemplateError {}
register_templates! {
extra = {
"components/button.html",

View File

@ -1,42 +0,0 @@
[package]
name = "mas-warp-utils"
version = "0.1.0"
authors = ["Quentin Gliech <quenting@element.io>"]
edition = "2021"
license = "Apache-2.0"
[dependencies]
tokio = { version = "1.17.0", features = ["macros"] }
headers = "0.3.7"
cookie = "0.16.0"
warp = "0.3.2"
hyper = { version = "0.14.17", features = ["full"] }
thiserror = "1.0.30"
anyhow = "1.0.56"
sqlx = { version = "0.5.11", features = ["runtime-tokio-rustls", "postgres"] }
chrono = { version = "0.4.19", features = ["serde"] }
serde = { version = "1.0.136", features = ["derive"] }
serde_with = { version = "1.12.0", features = ["hex", "chrono"] }
serde_json = "1.0.79"
serde_urlencoded = "0.7.1"
data-encoding = "2.3.2"
once_cell = "1.10.0"
tracing = "0.1.32"
opentelemetry = "0.17.0"
rand = "0.8.5"
mime = "0.3.16"
bincode = "1.3.3"
crc = "2.1.0"
url = "2.2.2"
http = "0.2.6"
http-body = "0.4.4"
tower = { version = "0.4.12", features = ["util"] }
oauth2-types = { path = "../oauth2-types" }
mas-config = { path = "../config" }
mas-templates = { path = "../templates" }
mas-data-model = { path = "../data-model" }
mas-storage = { path = "../storage" }
mas-jose = { path = "../jose" }
mas-iana = { path = "../iana" }
mas-http = { path = "../http" }

View File

@ -1,42 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Helper to deal with various unstructured errors in application code
use warp::{reject::Reject, Rejection};
#[derive(Debug)]
pub(crate) struct WrappedError(anyhow::Error);
impl warp::reject::Reject for WrappedError {}
/// Wrap any error in a [`Rejection`]
pub fn wrapped_error<T: Into<anyhow::Error>>(e: T) -> impl Reject {
WrappedError(e.into())
}
/// Extension trait that wraps errors in [`Rejection`]s
pub trait WrapError<T> {
/// Wrap transform the [`Result`] error type to a [`Rejection`]
fn wrap_error(self) -> Result<T, Rejection>;
}
impl<T, E> WrapError<T> for Result<T, E>
where
E: Into<anyhow::Error>,
{
fn wrap_error(self) -> Result<T, Rejection> {
self.map_err(|e| warp::reject::custom(WrappedError(e.into())))
}
}

View File

@ -1,156 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Authenticate an endpoint with an access token as bearer authorization token
use headers::{authorization::Bearer, Authorization};
use hyper::StatusCode;
use mas_data_model::{AccessToken, Session, TokenFormatError, TokenType};
use mas_storage::{
oauth2::access_token::{lookup_active_access_token, AccessTokenLookupError},
PostgresqlBackend,
};
use sqlx::{pool::PoolConnection, PgPool, Postgres};
use thiserror::Error;
use warp::{
reject::{MissingHeader, Reject},
reply::{with_header, with_status},
Filter, Rejection, Reply,
};
use super::{
database::connection,
headers::{typed_header, InvalidTypedHeader},
};
use crate::errors::wrapped_error;
/// Bearer token authentication failed
///
/// This is recoverable with [`recover_unauthorized`]
#[derive(Debug, Error)]
pub enum AuthenticationError {
/// The bearer token has an invalid format
#[error("invalid token format")]
TokenFormat(#[from] TokenFormatError),
/// The bearer token is not an access token
#[error("invalid token type {0:?}, expected an access token")]
WrongTokenType(TokenType),
/// The access token was not found in the database
#[error("unknown token")]
TokenNotFound(#[source] AccessTokenLookupError),
/// The `Authorization` header is missing
#[error("missing authorization header")]
MissingAuthorizationHeader,
/// The `Authorization` header is invalid
#[error("invalid authorization header")]
InvalidAuthorizationHeader,
}
impl Reject for AuthenticationError {}
/// Authenticate a request using an access token as a bearer authorization
///
/// # Rejections
///
/// This can reject with either a [`AuthenticationError`] or with a generic
/// wrapped sqlx error.
#[must_use]
pub fn authentication(
pool: &PgPool,
) -> impl Filter<
Extract = (AccessToken<PostgresqlBackend>, Session<PostgresqlBackend>),
Error = Rejection,
> + Clone
+ Send
+ Sync
+ 'static {
connection(pool)
.and(typed_header())
.and_then(authenticate)
.recover(recover)
.unify()
.untuple_one()
}
fn ensure<T: Clone + Send + Sync + 'static>(t: T) -> T {
t
}
async fn authenticate(
mut conn: PoolConnection<Postgres>,
auth: Authorization<Bearer>,
) -> Result<(AccessToken<PostgresqlBackend>, Session<PostgresqlBackend>), Rejection> {
let token = auth.0.token();
let token_type = TokenType::check(token).map_err(AuthenticationError::TokenFormat)?;
if token_type != TokenType::AccessToken {
return Err(AuthenticationError::WrongTokenType(token_type).into());
}
let (token, session) = lookup_active_access_token(&mut conn, token)
.await
.map_err(|e| {
if e.not_found() {
// This error happens if the token was not found and should be recovered
warp::reject::custom(AuthenticationError::TokenNotFound(e))
} else {
// This is a generic database error that we want to propagate
warp::reject::custom(wrapped_error(e))
}
})?;
let session = ensure(session);
let token = ensure(token);
Ok((token, session))
}
/// Transform the rejections from the [`with_typed_header`] filter
async fn recover(
rejection: Rejection,
) -> Result<(AccessToken<PostgresqlBackend>, Session<PostgresqlBackend>), Rejection> {
if rejection.find::<MissingHeader>().is_some() {
return Err(warp::reject::custom(
AuthenticationError::MissingAuthorizationHeader,
));
}
if rejection.find::<InvalidTypedHeader>().is_some() {
return Err(warp::reject::custom(
AuthenticationError::InvalidAuthorizationHeader,
));
}
Err(rejection)
}
/// Recover from an [`AuthenticationError`] with a `WWW-Authenticate` header, as
/// per [RFC6750]. This is not intended for user-facing endpoints.
///
/// [RFC6750]: https://www.rfc-editor.org/rfc/rfc6750.html
pub async fn recover_unauthorized(rejection: Rejection) -> Result<Box<dyn Reply>, Rejection> {
if rejection.find::<AuthenticationError>().is_some() {
// TODO: have the issuer/realm here
let reply = "invalid token";
let reply = with_status(reply, StatusCode::UNAUTHORIZED);
let reply = with_header(reply, "WWW-Authenticate", r#"Bearer error="invalid_token""#);
return Ok(Box::new(reply));
}
Err(rejection)
}

View File

@ -1,772 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Handle client authentication
use std::collections::HashMap;
use data_encoding::BASE64;
use headers::{authorization::Basic, Authorization};
use mas_config::Encrypter;
use mas_data_model::{Client, JwksOrJwksUri, StorageBackend};
use mas_http::HttpServiceExt;
use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_jose::{
claims::{TimeOptions, AUD, EXP, IAT, ISS, JTI, NBF, SUB},
DecodedJsonWebToken, DynamicJwksStore, Either, JsonWebKeySet, JsonWebTokenParts, SharedSecret,
StaticJwksStore, VerifyingKeystore,
};
use mas_storage::{
oauth2::client::{lookup_client_by_client_id, ClientFetchError},
PostgresqlBackend,
};
use serde::{de::DeserializeOwned, Deserialize};
use sqlx::{pool::PoolConnection, PgPool, Postgres};
use thiserror::Error;
use tower::{BoxError, ServiceExt};
use warp::{reject::Reject, Filter, Rejection};
use super::{database::connection, headers::typed_header};
use crate::errors::WrapError;
/// Protect an enpoint with client authentication
#[must_use]
pub fn client_authentication<T: DeserializeOwned + Send + 'static>(
pool: &PgPool,
encrypter: &Encrypter,
audience: String,
) -> impl Filter<
Extract = (
OAuthClientAuthenticationMethod,
Client<PostgresqlBackend>,
T,
),
Error = Rejection,
> + Clone
+ Send
+ Sync
+ 'static {
let encrypter = encrypter.clone();
// First, extract the client credentials
let credentials = typed_header()
.and(warp::body::form())
// Either from the "Authorization" header
.map(|auth: Authorization<Basic>, body: T| {
let client_id = auth.0.username().to_string();
let client_secret = Some(auth.0.password().to_string());
(
ClientCredentials::Pair {
via: CredentialsVia::AuthorizationHeader,
client_id,
client_secret,
},
body,
)
})
// Or from the form body
.or(warp::body::form().map(|form: ClientAuthForm<T>| {
let ClientAuthForm { credentials, body } = form;
(credentials, body)
}))
.unify()
.untuple_one();
warp::any()
.and(connection(pool))
.and(warp::any().map(move || encrypter.clone()))
.and(warp::any().map(move || audience.clone()))
.and(credentials)
.and_then(authenticate_client)
.untuple_one()
}
#[derive(Error, Debug)]
enum ClientAuthenticationError {
#[error("wrong client secret for client {client_id:?}")]
ClientSecretMismatch { client_id: String },
#[error("could not fetch client {client_id:?}")]
ClientFetch {
client_id: String,
source: ClientFetchError,
},
#[error("client {client_id:?} has an invalid client secret")]
InvalidClientSecret {
client_id: String,
source: anyhow::Error,
},
#[error("client {client_id:?} has an invalid JWKS")]
InvalidJwks { client_id: String },
#[error("wrong client authentication method for client {client_id:?}")]
WrongAuthenticationMethod { client_id: String },
#[error("wrong audience in client assertion: expected {expected:?}")]
MissingAudience { expected: String },
#[error("invalid client assertion")]
InvalidAssertion,
}
impl Reject for ClientAuthenticationError {}
fn decrypt_client_secret<T: StorageBackend>(
client: &Client<T>,
encrypter: &Encrypter,
) -> anyhow::Result<Vec<u8>> {
let encrypted_client_secret = client
.encrypted_client_secret
.as_ref()
.ok_or_else(|| anyhow::anyhow!("missing encrypted_client_secret field"))?;
let encrypted_client_secret = BASE64.decode(encrypted_client_secret.as_bytes())?;
let nonce: &[u8; 12] = encrypted_client_secret
.get(0..12)
.ok_or_else(|| anyhow::anyhow!("invalid payload serialization"))?
.try_into()?;
let payload = encrypted_client_secret
.get(12..)
.ok_or_else(|| anyhow::anyhow!("invalid payload serialization"))?;
let decrypted_client_secret = encrypter.decrypt(nonce, payload)?;
Ok(decrypted_client_secret)
}
fn jwks_key_store(jwks: &JwksOrJwksUri) -> Either<StaticJwksStore, DynamicJwksStore> {
// Assert that the output is both a VerifyingKeystore and Send
fn assert<T: Send + VerifyingKeystore>(t: T) -> T {
t
}
let inner = match jwks {
JwksOrJwksUri::Jwks(jwks) => Either::Left(StaticJwksStore::new(jwks.clone())),
JwksOrJwksUri::JwksUri(uri) => {
let uri = uri.clone();
// TODO: get the client from somewhere else?
let exporter = mas_http::client("fetch-jwks")
.json::<JsonWebKeySet>()
.map_request(move |_: ()| {
http::Request::builder()
.method("GET")
// TODO: change the Uri type in config to avoid reparsing here
.uri(uri.to_string())
.body(http_body::Empty::new())
.unwrap()
})
.map_response(http::Response::into_body)
.map_err(BoxError::from)
.boxed_clone();
Either::Right(DynamicJwksStore::new(exporter))
}
};
assert(inner)
}
#[allow(clippy::too_many_lines)]
#[tracing::instrument(skip_all, fields(enduser.id), err(Debug))]
async fn authenticate_client<T>(
mut conn: PoolConnection<Postgres>,
encrypter: Encrypter,
audience: String,
credentials: ClientCredentials,
body: T,
) -> Result<
(
OAuthClientAuthenticationMethod,
Client<PostgresqlBackend>,
T,
),
Rejection,
> {
let (auth_method, client) = match credentials {
ClientCredentials::Pair {
client_id,
client_secret,
via,
} => {
let client = lookup_client_by_client_id(&mut *conn, &client_id)
.await
.map_err(|source| ClientAuthenticationError::ClientFetch {
client_id: client_id.clone(),
source,
})?;
let auth_method = client.token_endpoint_auth_method.ok_or(
ClientAuthenticationError::WrongAuthenticationMethod {
client_id: client.client_id.clone(),
},
)?;
// Let's match the authentication method
match (auth_method, client_secret, via) {
(OAuthClientAuthenticationMethod::None, None, _) => {}
(
OAuthClientAuthenticationMethod::ClientSecretBasic,
Some(client_secret),
CredentialsVia::AuthorizationHeader,
)
| (
OAuthClientAuthenticationMethod::ClientSecretPost,
Some(client_secret),
CredentialsVia::FormBody,
) => {
let decrypted =
decrypt_client_secret(&client, &encrypter).map_err(|source| {
ClientAuthenticationError::InvalidClientSecret {
client_id: client.client_id.clone(),
source,
}
})?;
if client_secret.as_bytes() != decrypted {
return Err(warp::reject::custom(
ClientAuthenticationError::ClientSecretMismatch {
client_id: client.client_id,
},
));
}
}
_ => {
return Err(warp::reject::custom(
ClientAuthenticationError::WrongAuthenticationMethod {
client_id: client.client_id,
},
));
}
}
(auth_method, client)
}
ClientCredentials::Assertion {
client_id,
client_assertion_type: ClientAssertionType::JwtBearer,
client_assertion,
} => {
let token: JsonWebTokenParts = client_assertion.parse().wrap_error()?;
let decoded: DecodedJsonWebToken<HashMap<String, serde_json::Value>> =
token.decode().wrap_error()?;
let time_options = TimeOptions::default()
.freeze()
.leeway(chrono::Duration::minutes(1));
let mut claims = decoded.claims().clone();
let iss = ISS.extract_required(&mut claims).wrap_error()?;
let sub = SUB.extract_required(&mut claims).wrap_error()?;
let aud = AUD.extract_required(&mut claims).wrap_error()?;
// Validate the times
let _exp = EXP
.extract_required_with_options(&mut claims, &time_options)
.wrap_error()?;
let _nbf = NBF
.extract_optional_with_options(&mut claims, &time_options)
.wrap_error()?;
let _iat = IAT
.extract_optional_with_options(&mut claims, &time_options)
.wrap_error()?;
// TODO: validate the JTI
let _jti = JTI.extract_optional(&mut claims).wrap_error()?;
// client_id might have been passed as parameter. If not, it should be inferred
// from the token, as per rfc7521 sec. 4.2
let client_id = client_id.as_ref().unwrap_or(&sub);
let client = lookup_client_by_client_id(&mut *conn, client_id)
.await
.map_err(|source| ClientAuthenticationError::ClientFetch {
client_id: client_id.to_string(),
source,
})?;
let auth_method = client.token_endpoint_auth_method.ok_or(
ClientAuthenticationError::WrongAuthenticationMethod {
client_id: client.client_id.clone(),
},
)?;
match auth_method {
OAuthClientAuthenticationMethod::ClientSecretJwt => {
let client_secret =
decrypt_client_secret(&client, &encrypter).map_err(|source| {
ClientAuthenticationError::InvalidClientSecret {
client_id: client.client_id.clone(),
source,
}
})?;
let store = SharedSecret::new(&client_secret);
let fut = token.verify(decoded.header(), &store);
fut.await.wrap_error()?;
}
OAuthClientAuthenticationMethod::PrivateKeyJwt => {
let jwks = client.jwks.as_ref().ok_or_else(|| {
ClientAuthenticationError::InvalidJwks {
client_id: client.client_id.clone(),
}
})?;
let store = jwks_key_store(jwks);
let fut = token.verify(decoded.header(), &store);
fut.await.wrap_error()?;
}
_ => {
return Err(warp::reject::custom(
ClientAuthenticationError::WrongAuthenticationMethod {
client_id: client.client_id,
},
));
}
}
// rfc7523 sec. 3.3: the audience is the URL being called
if !aud.contains(&audience) {
return Err(
ClientAuthenticationError::MissingAudience { expected: audience }.into(),
);
}
// rfc7523 sec. 3.1 & 3.2: both the issuer and the subject must
// match the client_id
if iss != sub || &iss != client_id {
return Err(ClientAuthenticationError::InvalidAssertion.into());
}
(auth_method, client)
}
};
tracing::Span::current().record("enduser.id", &client.client_id.as_str());
Ok((auth_method, client, body))
}
#[derive(Deserialize)]
enum ClientAssertionType {
#[serde(rename = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")]
JwtBearer,
}
enum CredentialsVia {
FormBody,
AuthorizationHeader,
}
impl Default for CredentialsVia {
fn default() -> Self {
Self::FormBody
}
}
#[derive(Deserialize)]
#[serde(untagged)]
enum ClientCredentials {
// Order here is important: serde tries to deserialize enum variants in order, so if "Pair"
// was before "Assertion", a client_assertion with a client_id would match the "Pair"
// variant first
Assertion {
client_id: Option<String>,
client_assertion_type: ClientAssertionType,
client_assertion: String,
},
Pair {
#[serde(skip)]
via: CredentialsVia,
client_id: String,
client_secret: Option<String>,
},
}
#[derive(Deserialize)]
struct ClientAuthForm<T> {
#[serde(flatten)]
credentials: ClientCredentials,
#[serde(flatten)]
body: T,
}
/* TODO: all secrets are broken because there is no way to mock the DB yet
#[cfg(test)]
mod tests {
use headers::authorization::Credentials;
use mas_config::{ClientAuthMethodConfig, ConfigurationSection};
use mas_jose::{SigningKeystore, StaticKeystore};
use serde_json::json;
use tower::{Service, ServiceExt};
use super::*;
// Long client_secret to support it as a HS512 key
const CLIENT_SECRET: &str = "leek2zaeyeb8thai7piehea3vah6ool9oanin9aeraThuci9EeghaekaiD1upe4Quoh7xeMae2meitohj0Waaveiwaorah1yazohr6Vae7iebeiRaWene5IeWeeciezu";
fn client_private_keystore() -> StaticKeystore {
let mut store = StaticKeystore::new();
store.add_test_rsa_key().unwrap();
store.add_test_ecdsa_key().unwrap();
store
}
async fn oauth2_config() -> ClientsConfig {
let mut config = ClientsConfig::test();
config.push(ClientConfig {
client_id: "public".to_string(),
client_auth_method: ClientAuthMethodConfig::None,
redirect_uris: Vec::new(),
});
config.push(ClientConfig {
client_id: "secret-basic".to_string(),
client_auth_method: ClientAuthMethodConfig::ClientSecretBasic {
client_secret: CLIENT_SECRET.to_string(),
},
redirect_uris: Vec::new(),
});
config.push(ClientConfig {
client_id: "secret-post".to_string(),
client_auth_method: ClientAuthMethodConfig::ClientSecretPost {
client_secret: CLIENT_SECRET.to_string(),
},
redirect_uris: Vec::new(),
});
config.push(ClientConfig {
client_id: "secret-jwt".to_string(),
client_auth_method: ClientAuthMethodConfig::ClientSecretJwt {
client_secret: CLIENT_SECRET.to_string(),
},
redirect_uris: Vec::new(),
});
config.push(ClientConfig {
client_id: "secret-jwt-2".to_string(),
client_auth_method: ClientAuthMethodConfig::ClientSecretJwt {
client_secret: CLIENT_SECRET.to_string(),
},
redirect_uris: Vec::new(),
});
let store = client_private_keystore();
let jwks = (&store).ready().await.unwrap().call(()).await.unwrap();
//let jwks = store.export_jwks().await.unwrap();
config.push(ClientConfig {
client_id: "private-key-jwt".to_string(),
client_auth_method: ClientAuthMethodConfig::PrivateKeyJwt(jwks.clone().into()),
redirect_uris: Vec::new(),
});
config.push(ClientConfig {
client_id: "private-key-jwt-2".to_string(),
client_auth_method: ClientAuthMethodConfig::PrivateKeyJwt(jwks.into()),
redirect_uris: Vec::new(),
});
config
}
#[derive(Deserialize)]
struct Form {
foo: String,
bar: String,
}
#[tokio::test]
async fn client_secret_jwt_hs256() {
client_secret_jwt("HS256").await;
}
#[tokio::test]
async fn client_secret_jwt_hs384() {
client_secret_jwt("HS384").await;
}
#[tokio::test]
async fn client_secret_jwt_hs512() {
client_secret_jwt("HS512").await;
}
fn client_claims(
client_id: &str,
audience: &str,
iat: chrono::DateTime<chrono::Utc>,
) -> HashMap<String, serde_json::Value> {
let mut claims = HashMap::new();
let exp = iat + chrono::Duration::minutes(1);
ISS.insert(&mut claims, client_id).unwrap();
SUB.insert(&mut claims, client_id).unwrap();
AUD.insert(&mut claims, vec![audience.to_string()]).unwrap();
IAT.insert(&mut claims, iat).unwrap();
NBF.insert(&mut claims, iat).unwrap();
EXP.insert(&mut claims, exp).unwrap();
claims
}
async fn client_secret_jwt(alg: &str) {
let alg = alg.parse().unwrap();
let audience = "https://example.com/token";
let filter = client_authentication::<Form>(&oauth2_config().await, audience.to_string());
let store = SharedSecret::new(&CLIENT_SECRET);
let claims = client_claims("secret-jwt", audience, chrono::Utc::now());
let header = store.prepare_header(alg).await.expect("JWT header");
let jwt = DecodedJsonWebToken::new(header, claims);
let jwt = jwt.sign(&store).await.expect("signed token");
let jwt = jwt.serialize();
// TODO: test failing cases
// - expired token
// - "not before" in the future
// - subject/issuer mismatch
// - audience mismatch
// - wrong secret/signature
let (auth, client, body) = warp::test::request()
.method("POST")
.header("Content-Type", mime::APPLICATION_WWW_FORM_URLENCODED.to_string())
.body(serde_urlencoded::to_string(json!({
"client_id": "secret-jwt",
"client_assertion": jwt,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"foo": "baz",
"bar": "foobar",
})).unwrap())
.filter(&filter)
.await
.unwrap();
assert_eq!(auth, OAuthClientAuthenticationMethod::ClientSecretJwt);
assert_eq!(client.client_id, "secret-jwt");
assert_eq!(body.foo, "baz");
assert_eq!(body.bar, "foobar");
// Without client_id
let res = warp::test::request()
.method("POST")
.header("Content-Type", mime::APPLICATION_WWW_FORM_URLENCODED.to_string())
.body(serde_urlencoded::to_string(json!({
"client_assertion": jwt,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"foo": "baz",
"bar": "foobar",
})).unwrap())
.filter(&filter)
.await;
assert!(res.is_ok());
// client_id mismatch
let res = warp::test::request()
.method("POST")
.body(serde_urlencoded::to_string(json!({
"client_id": "secret-jwt-2",
"client_assertion": jwt,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"foo": "baz",
"bar": "foobar",
})).unwrap())
.filter(&filter)
.await;
assert!(res.is_err());
}
#[tokio::test]
async fn client_secret_jwt_rs256() {
private_key_jwt("RS256").await;
}
#[tokio::test]
async fn client_secret_jwt_rs384() {
private_key_jwt("RS384").await;
}
#[tokio::test]
async fn client_secret_jwt_rs512() {
private_key_jwt("RS512").await;
}
#[tokio::test]
async fn client_secret_jwt_es256() {
private_key_jwt("ES256").await;
}
async fn private_key_jwt(alg: &str) {
let alg = alg.parse().unwrap();
let audience = "https://example.com/token";
let filter = client_authentication::<Form>(&oauth2_config().await, audience.to_string());
let store = client_private_keystore();
let claims = client_claims("private-key-jwt", audience, chrono::Utc::now());
let header = store.prepare_header(alg).await.expect("JWT header");
let jwt = DecodedJsonWebToken::new(header, claims);
let jwt = jwt.sign(&store).await.expect("signed token");
let jwt = jwt.serialize();
// TODO: test failing cases
// - expired token
// - "not before" in the future
// - subject/issuer mismatch
// - audience mismatch
// - wrong secret/signature
let (auth, client, body) = warp::test::request()
.method("POST")
.header("Content-Type", mime::APPLICATION_WWW_FORM_URLENCODED.to_string())
.body(serde_urlencoded::to_string(json!({
"client_id": "private-key-jwt",
"client_assertion": jwt,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"foo": "baz",
"bar": "foobar",
})).unwrap())
.filter(&filter)
.await
.unwrap();
assert_eq!(auth, OAuthClientAuthenticationMethod::PrivateKeyJwt);
assert_eq!(client.client_id, "private-key-jwt");
assert_eq!(body.foo, "baz");
assert_eq!(body.bar, "foobar");
// Without client_id
let res = warp::test::request()
.method("POST")
.header("Content-Type", mime::APPLICATION_WWW_FORM_URLENCODED.to_string())
.body(serde_urlencoded::to_string(json!({
"client_assertion": jwt,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"foo": "baz",
"bar": "foobar",
})).unwrap())
.filter(&filter)
.await;
assert!(res.is_ok());
// client_id mismatch
let res = warp::test::request()
.method("POST")
.body(serde_urlencoded::to_string(json!({
"client_id": "private-key-jwt-2",
"client_assertion": jwt,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"foo": "baz",
"bar": "foobar",
})).unwrap())
.filter(&filter)
.await;
assert!(res.is_err());
}
#[tokio::test]
async fn client_secret_post() {
let filter = client_authentication::<Form>(
&oauth2_config().await,
"https://example.com/token".to_string(),
);
let (auth, client, body) = warp::test::request()
.method("POST")
.header(
"Content-Type",
mime::APPLICATION_WWW_FORM_URLENCODED.to_string(),
)
.body(
serde_urlencoded::to_string(json!({
"client_id": "secret-post",
"client_secret": CLIENT_SECRET,
"foo": "baz",
"bar": "foobar",
}))
.unwrap(),
)
.filter(&filter)
.await
.unwrap();
assert_eq!(auth, OAuthClientAuthenticationMethod::ClientSecretPost);
assert_eq!(client.client_id, "secret-post");
assert_eq!(body.foo, "baz");
assert_eq!(body.bar, "foobar");
}
#[tokio::test]
async fn client_secret_basic() {
let filter = client_authentication::<Form>(
&oauth2_config().await,
"https://example.com/token".to_string(),
);
let auth = Authorization::basic("secret-basic", CLIENT_SECRET);
let (auth, client, body) = warp::test::request()
.method("POST")
.header(
"Content-Type",
mime::APPLICATION_WWW_FORM_URLENCODED.to_string(),
)
.header("Authorization", auth.0.encode())
.body(
serde_urlencoded::to_string(json!({
"foo": "baz",
"bar": "foobar",
}))
.unwrap(),
)
.filter(&filter)
.await
.unwrap();
assert_eq!(auth, OAuthClientAuthenticationMethod::ClientSecretBasic);
assert_eq!(client.client_id, "secret-basic");
assert_eq!(body.foo, "baz");
assert_eq!(body.bar, "foobar");
}
#[tokio::test]
async fn none() {
let filter = client_authentication::<Form>(
&oauth2_config().await,
"https://example.com/token".to_string(),
);
let (auth, client, body) = warp::test::request()
.method("POST")
.header(
"Content-Type",
mime::APPLICATION_WWW_FORM_URLENCODED.to_string(),
)
.body(
serde_urlencoded::to_string(json!({
"client_id": "public",
"foo": "baz",
"bar": "foobar",
}))
.unwrap(),
)
.filter(&filter)
.await
.unwrap();
assert_eq!(auth, OAuthClientAuthenticationMethod::None);
assert_eq!(client.client_id, "public");
assert_eq!(body.foo, "baz");
assert_eq!(body.bar, "foobar");
}
}
*/

View File

@ -1,193 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Deal with encrypted cookies
use std::{convert::Infallible, marker::PhantomData};
use cookie::{Cookie, SameSite};
use data_encoding::BASE64URL_NOPAD;
use headers::{Header, HeaderValue, SetCookie};
use mas_config::Encrypter;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use thiserror::Error;
use warp::{
reject::{InvalidHeader, MissingCookie, Reject},
Filter, Rejection, Reply,
};
use super::none_on_error;
use crate::{
errors::WrapError,
reply::{with_typed_header, WithTypedHeader},
};
/// Unable to decrypt the cookie
#[derive(Debug, Error)]
pub struct CookieDecryptionError<T: EncryptableCookieValue>(
#[source] anyhow::Error,
// This [`std::marker::PhantomData`] records what kind of cookie it was trying to save.
// This then use when displaying the error.
PhantomData<T>,
);
impl<T> Reject for CookieDecryptionError<T> where T: EncryptableCookieValue + 'static {}
impl<T: EncryptableCookieValue> From<anyhow::Error> for CookieDecryptionError<T> {
fn from(e: anyhow::Error) -> Self {
Self(e, PhantomData)
}
}
impl<T: EncryptableCookieValue> std::fmt::Display for CookieDecryptionError<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "failed to decrypt cookie {}", T::cookie_key())
}
}
fn decryption_error<T>(e: anyhow::Error) -> Rejection
where
T: EncryptableCookieValue + 'static,
{
let e: CookieDecryptionError<T> = e.into();
warp::reject::custom(e)
}
#[derive(Serialize, Deserialize)]
struct EncryptedCookie {
nonce: [u8; 12],
ciphertext: Vec<u8>,
}
impl EncryptedCookie {
/// Encrypt from a given key
fn encrypt<T: Serialize>(payload: T, encrypter: &Encrypter) -> anyhow::Result<Self> {
let message = bincode::serialize(&payload)?;
let nonce: [u8; 12] = rand::random();
let ciphertext = encrypter.encrypt(&nonce, &message)?;
Ok(Self { nonce, ciphertext })
}
/// Decrypt the content of the cookie from a given key
fn decrypt<T: DeserializeOwned>(&self, encrypter: &Encrypter) -> anyhow::Result<T> {
let message = encrypter.decrypt(&self.nonce, &self.ciphertext)?;
let token = bincode::deserialize(&message)?;
Ok(token)
}
/// Encode the encrypted cookie to be then saved as a cookie
fn to_cookie_value(&self) -> anyhow::Result<String> {
let raw = bincode::serialize(self)?;
Ok(BASE64URL_NOPAD.encode(&raw))
}
fn from_cookie_value(value: &str) -> anyhow::Result<Self> {
let raw = BASE64URL_NOPAD.decode(value.as_bytes())?;
let content = bincode::deserialize(&raw)?;
Ok(content)
}
}
/// Extract an optional encrypted cookie
#[must_use]
pub fn maybe_encrypted<T>(
encrypter: &Encrypter,
) -> impl Filter<Extract = (Option<T>,), Error = Rejection> + Clone + Send + Sync + 'static
where
T: DeserializeOwned + EncryptableCookieValue + 'static,
{
encrypted(encrypter)
.map(Some)
.recover(none_on_error::<T, InvalidHeader>)
.unify()
.recover(none_on_error::<T, MissingCookie>)
.unify()
.recover(none_on_error::<T, CookieDecryptionError<T>>)
.unify()
}
/// Extract an encrypted cookie
///
/// # Rejections
///
/// This can reject with either a [`warp::reject::MissingCookie`] or a
/// [`CookieDecryptionError`]
#[must_use]
pub fn encrypted<T>(
encrypter: &Encrypter,
) -> impl Filter<Extract = (T,), Error = Rejection> + Clone + Send + Sync + 'static
where
T: DeserializeOwned + EncryptableCookieValue + 'static,
{
let encrypter = encrypter.clone();
warp::cookie::cookie(T::cookie_key()).and_then(move |value: String| {
let encrypter = encrypter.clone();
async move {
let encrypted_payload =
EncryptedCookie::from_cookie_value(&value).map_err(decryption_error::<T>)?;
let decrypted_payload = encrypted_payload
.decrypt(&encrypter)
.map_err(decryption_error::<T>)?;
Ok::<_, Rejection>(decrypted_payload)
}
})
}
/// Get an [`EncryptedCookieSaver`] to help saving an [`EncryptableCookieValue`]
#[must_use]
pub fn encrypted_cookie_saver(
encrypter: &Encrypter,
) -> impl Filter<Extract = (EncryptedCookieSaver,), Error = Infallible> + Clone + Send + Sync + 'static
{
let encrypter = encrypter.clone();
warp::any().map(move || EncryptedCookieSaver {
encrypter: encrypter.clone(),
})
}
/// A cookie that can be encrypted with a well-known cookie key
pub trait EncryptableCookieValue: Serialize + Send + Sync + std::fmt::Debug {
/// What key should be used for this cookie
fn cookie_key() -> &'static str;
}
/// An opaque structure which helps encrypting a cookie and attach it to a reply
pub struct EncryptedCookieSaver {
encrypter: Encrypter,
}
impl EncryptedCookieSaver {
/// Save an [`EncryptableCookieValue`]
pub fn save_encrypted<T: EncryptableCookieValue, R: Reply>(
&self,
cookie: &T,
reply: R,
) -> Result<WithTypedHeader<R, SetCookie>, Rejection> {
let encrypted = EncryptedCookie::encrypt(cookie, &self.encrypter)
.wrap_error()?
.to_cookie_value()
.wrap_error()?;
// TODO: make those options customizable
let value = Cookie::build(T::cookie_key(), encrypted)
.http_only(true)
.same_site(SameSite::Lax)
.finish()
.to_string();
let header = SetCookie::decode(&mut [HeaderValue::from_str(&value).wrap_error()?].iter())
.wrap_error()?;
Ok(with_typed_header(header, reply))
}
}

View File

@ -1,42 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Wrapper around [`warp::filters::cors`]
use std::string::ToString;
use once_cell::sync::OnceCell;
static PROPAGATOR_HEADERS: OnceCell<Vec<String>> = OnceCell::new();
/// Notify the CORS filter what opentelemetry propagators are being used. This
/// helps whitelisting headers in CORS requests.
pub fn set_propagator(propagator: &dyn opentelemetry::propagation::TextMapPropagator) {
let headers = propagator.fields().map(ToString::to_string).collect();
tracing::debug!(
?headers,
"Headers allowed in CORS requests for trace propagators set"
);
PROPAGATOR_HEADERS
.set(headers)
.expect(concat!(module_path!(), "::set_propagator was called twice"));
}
/// Create a wrapping filter that exposes CORS behavior for a wrapped filter.
#[must_use]
pub fn cors() -> warp::filters::cors::Builder {
warp::filters::cors::cors()
.allow_any_origin()
.allow_headers(PROPAGATOR_HEADERS.get().unwrap_or(&Vec::new()))
}

View File

@ -1,185 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Stateless CSRF protection middleware based on a chacha20-poly1305 encrypted
//! and signed token
use chrono::{DateTime, Duration, Utc};
use data_encoding::{DecodeError, BASE64URL_NOPAD};
use mas_config::{CsrfConfig, Encrypter};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_with::{serde_as, TimestampSeconds};
use thiserror::Error;
use warp::{reject::Reject, Filter, Rejection};
use super::cookies::EncryptableCookieValue;
/// Failed to validate CSRF token
#[derive(Debug, Error)]
pub enum CsrfError {
/// The token in the form did not match the token in the cookie
#[error("CSRF token mismatch")]
Mismatch,
/// The token expired
#[error("CSRF token expired")]
Expired,
/// Failed to decode the token
#[error("could not decode CSRF token")]
Decode(#[from] DecodeError),
}
impl Reject for CsrfError {}
/// A CSRF token
#[serde_as]
#[derive(Serialize, Deserialize, Debug)]
pub struct CsrfToken {
#[serde_as(as = "TimestampSeconds<i64>")]
expiration: DateTime<Utc>,
token: [u8; 32],
}
impl CsrfToken {
/// Create a new token from a defined value valid for a specified duration
fn new(token: [u8; 32], ttl: Duration) -> Self {
let expiration = Utc::now() + ttl;
Self { expiration, token }
}
/// Generate a new random token valid for a specified duration
fn generate(ttl: Duration) -> Self {
let token = rand::random();
Self::new(token, ttl)
}
/// Generate a new token with the same value but an up to date expiration
fn refresh(self, ttl: Duration) -> Self {
Self::new(self.token, ttl)
}
/// Get the value to include in HTML forms
#[must_use]
pub fn form_value(&self) -> String {
BASE64URL_NOPAD.encode(&self.token[..])
}
/// Verifies that the value got from an HTML form matches this token
pub fn verify_form_value(&self, form_value: &str) -> Result<(), CsrfError> {
let form_value = BASE64URL_NOPAD.decode(form_value.as_bytes())?;
if self.token[..] == form_value {
Ok(())
} else {
Err(CsrfError::Mismatch)
}
}
fn verify_expiration(self) -> Result<Self, CsrfError> {
if Utc::now() < self.expiration {
Ok(self)
} else {
Err(CsrfError::Expired)
}
}
}
impl EncryptableCookieValue for CsrfToken {
fn cookie_key() -> &'static str {
"csrf"
}
}
/// A CSRF-protected form
#[derive(Deserialize)]
struct CsrfForm<T> {
csrf: String,
#[serde(flatten)]
inner: T,
}
impl<T> CsrfForm<T> {
fn verify_csrf(self, token: &CsrfToken) -> Result<T, CsrfError> {
// Verify CSRF from request
token.verify_form_value(&self.csrf)?;
Ok(self.inner)
}
}
fn csrf_token(
encrypter: &Encrypter,
) -> impl Filter<Extract = (CsrfToken,), Error = Rejection> + Clone + Send + Sync + 'static {
super::cookies::encrypted(encrypter).and_then(move |token: CsrfToken| async move {
let verified = token.verify_expiration()?;
Ok::<_, Rejection>(verified)
})
}
/// Extract an up-to-date CSRF token to include in forms
///
/// Routes using this should not forget to reply the updated CSRF cookie using
/// an [`EncryptedCookieSaver`][`super::cookies::EncryptedCookieSaver`] obtained
/// with [`encrypted_cookie_saver`][`super::cookies::encrypted_cookie_saver`]
#[must_use]
pub fn updated_csrf_token(
encrypter: &Encrypter,
csrf_config: &CsrfConfig,
) -> impl Filter<Extract = (CsrfToken,), Error = Rejection> + Clone + Send + Sync + 'static {
let ttl = csrf_config.ttl;
super::cookies::maybe_encrypted(encrypter).and_then(
move |maybe_token: Option<CsrfToken>| async move {
// Explicitely specify the "Error" type here to have the `?` operation working
Ok::<_, Rejection>(
maybe_token
// Verify its TTL (but do not hard-error if it expired)
.and_then(|token| token.verify_expiration().ok())
.map_or_else(
// Generate a new token if no valid one were found
|| CsrfToken::generate(ttl),
// Else, refresh the expiration of the token
|token| token.refresh(ttl),
),
)
},
)
}
/// Extract values from a CSRF-protected form
///
/// # Rejections
///
/// This can reject with:
///
/// - [`warp::filters::body::BodyDeserializeError`] if the overall form failed
/// to decode
/// - [`CsrfError`] if the CSRF token was invalid or expired
/// - [`warp::reject::MissingCookie`] if the CSRF cookie was missing
/// - [`super::cookies::CookieDecryptionError`] if the cookie failed to decrypt
///
/// TODO: we might want to unify the last three rejections in one
#[must_use]
pub fn protected_form<T>(
encrypter: &Encrypter,
) -> impl Filter<Extract = (T,), Error = Rejection> + Clone + Send + Sync + 'static
where
T: DeserializeOwned + Send + 'static,
{
csrf_token(encrypter).and(warp::body::form()).and_then(
|csrf_token: CsrfToken, protected_form: CsrfForm<T>| async move {
let form = protected_form.verify_csrf(&csrf_token)?;
Ok::<_, Rejection>(form)
},
)
}

View File

@ -1,61 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Database-related filters to grab connections and start transactions from the
//! connection pool
use std::convert::Infallible;
use sqlx::{
pool::{Pool, PoolConnection},
Database, Transaction,
};
use warp::{Filter, Rejection};
use crate::errors::WrapError;
fn with_pool<T: Database>(
pool: &Pool<T>,
) -> impl Filter<Extract = (Pool<T>,), Error = Infallible> + Clone + Send + Sync + 'static {
let pool = pool.clone();
warp::any().map(move || pool.clone())
}
/// Acquire a connection to the database
pub fn connection<T: Database>(
pool: &Pool<T>,
) -> impl Filter<Extract = (PoolConnection<T>,), Error = Rejection> + Clone + Send + Sync + 'static
{
with_pool(pool).and_then(acquire_connection)
}
async fn acquire_connection<T: Database>(pool: Pool<T>) -> Result<PoolConnection<T>, Rejection> {
let conn = pool.acquire().await.wrap_error()?;
Ok(conn)
}
/// Start a database transaction
pub fn transaction<T: Database>(
pool: &Pool<T>,
) -> impl Filter<Extract = (Transaction<'static, T>,), Error = Rejection> + Clone + Send + Sync + 'static
{
with_pool(pool).and_then(acquire_transaction)
}
async fn acquire_transaction<T: Database>(
pool: Pool<T>,
) -> Result<Transaction<'static, T>, Rejection> {
let txn = pool.begin().await.wrap_error()?;
Ok(txn)
}

View File

@ -1,43 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Deal with typed headers from the [`headers`] crate
use headers::{Header, HeaderValue};
use thiserror::Error;
use warp::{reject::Reject, Filter, Rejection};
/// Failed to decode typed header
#[derive(Debug, Error)]
#[error("could not decode header {1}")]
pub struct InvalidTypedHeader(#[source] headers::Error, &'static str);
impl Reject for InvalidTypedHeader {}
/// Extract a typed header from the request
///
/// # Rejections
///
/// This can reject with either a [`warp::reject::MissingHeader`] or a
/// [`InvalidTypedHeader`].
pub fn typed_header<T: Header + Send + 'static>(
) -> impl Filter<Extract = (T,), Error = Rejection> + Clone + Send + Sync + 'static {
warp::header::value(T::name().as_str()).and_then(decode_typed_header)
}
async fn decode_typed_header<T: Header>(header: HeaderValue) -> Result<T, Rejection> {
let mut it = std::iter::once(&header);
let decoded = T::decode(&mut it).map_err(|e| InvalidTypedHeader(e, T::name().as_str()))?;
Ok(decoded)
}

View File

@ -1,72 +0,0 @@
// Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Set of [`warp`] filters
#![allow(clippy::unused_async)] // Some warp filters need that
#![deny(missing_docs)]
pub mod authenticate;
pub mod client;
pub mod cookies;
pub mod cors;
pub mod csrf;
pub mod database;
pub mod headers;
pub mod session;
pub mod trace;
pub mod url_builder;
use std::convert::Infallible;
use mas_templates::Templates;
use warp::{Filter, Rejection};
pub use self::csrf::CsrfToken;
/// Get the [`Templates`]
#[must_use]
pub fn with_templates(
templates: &Templates,
) -> impl Filter<Extract = (Templates,), Error = Infallible> + Clone + Send + Sync + 'static {
let templates = templates.clone();
warp::any().map(move || templates.clone())
}
/// Recover a particular rejection type with a `None` option variant
///
/// # Example
///
/// ```rust
/// extern crate warp;
///
/// use warp::{filters::header::header, reject::MissingHeader, Filter};
///
/// use mas_warp_utils::filters::none_on_error;
///
/// header("Content-Length")
/// .map(Some)
/// .recover(none_on_error::<_, MissingHeader>)
/// .unify()
/// .map(|length: Option<u64>| {
/// format!("header: {:?}", length)
/// });
/// ```
pub async fn none_on_error<T, E: 'static>(rejection: Rejection) -> Result<Option<T>, Rejection> {
if rejection.find::<E>().is_some() {
Ok(None)
} else {
Err(rejection)
}
}

View File

@ -1,162 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Load user sessions from the database
use mas_config::Encrypter;
use mas_data_model::BrowserSession;
use mas_storage::{
user::{lookup_active_session, ActiveSessionLookupError},
PostgresqlBackend,
};
use serde::{Deserialize, Serialize};
use sqlx::{pool::PoolConnection, Executor, PgPool, Postgres};
use thiserror::Error;
use tracing::warn;
use warp::{
reject::{InvalidHeader, MissingCookie, Reject},
Filter, Rejection,
};
use super::{
cookies::{encrypted, CookieDecryptionError, EncryptableCookieValue},
database::connection,
none_on_error,
};
/// The session is missing or failed to load
#[derive(Error, Debug)]
pub enum SessionLoadError {
/// No session cookie was found
#[error("missing session cookie")]
MissingCookie,
/// The session cookie is invalid
#[error("unable to parse or decrypt session cookie")]
InvalidCookie,
/// The session is unknown or inactive
#[error("unknown or inactive session")]
UnknownSession,
}
impl Reject for SessionLoadError {}
/// An encrypted cookie to save the session ID
#[derive(Serialize, Deserialize, Debug)]
pub struct SessionCookie {
current: i64,
}
impl SessionCookie {
/// Forge the cookie from a [`BrowserSession`]
#[must_use]
pub fn from_session(session: &BrowserSession<PostgresqlBackend>) -> Self {
Self {
current: session.data,
}
}
/// Load the [`BrowserSession`] from database
pub async fn load_session(
&self,
executor: impl Executor<'_, Database = Postgres>,
) -> Result<BrowserSession<PostgresqlBackend>, ActiveSessionLookupError> {
let res = lookup_active_session(executor, self.current).await?;
Ok(res)
}
}
impl EncryptableCookieValue for SessionCookie {
fn cookie_key() -> &'static str {
"session"
}
}
/// Extract a user session information if logged in
#[must_use]
pub fn optional_session(
pool: &PgPool,
encrypter: &Encrypter,
) -> impl Filter<Extract = (Option<BrowserSession<PostgresqlBackend>>,), Error = Rejection>
+ Clone
+ Send
+ Sync
+ 'static {
session(pool, encrypter)
.map(Some)
.recover(none_on_error::<_, SessionLoadError>)
.unify()
}
/// Extract a user session information, rejecting if not logged in
///
/// # Rejections
///
/// This filter will reject with a [`SessionLoadError`] when the session is
/// inactive or missing. It will reject with a wrapped error on other database
/// failures.
#[must_use]
pub fn session(
pool: &PgPool,
encrypter: &Encrypter,
) -> impl Filter<Extract = (BrowserSession<PostgresqlBackend>,), Error = Rejection>
+ Clone
+ Send
+ Sync
+ 'static {
encrypted(encrypter)
.and(connection(pool))
.and_then(load_session)
.recover(recover)
.unify()
}
async fn load_session(
session: SessionCookie,
mut conn: PoolConnection<Postgres>,
) -> Result<BrowserSession<PostgresqlBackend>, Rejection> {
let session_info = session.load_session(&mut conn).await?;
Ok(session_info)
}
/// Recover from expected rejections, to transform them into a
/// [`SessionLoadError`]
async fn recover<T>(rejection: Rejection) -> Result<T, Rejection> {
if let Some(e) = rejection.find::<ActiveSessionLookupError>() {
if e.not_found() {
return Err(warp::reject::custom(SessionLoadError::UnknownSession));
}
// If we're here, there is a real database error that should be
// propagated
}
if let Some(e) = rejection.find::<InvalidHeader>() {
if e.name() == "cookie" {
return Err(warp::reject::custom(SessionLoadError::MissingCookie));
}
}
if let Some(_e) = rejection.find::<MissingCookie>() {
return Err(warp::reject::custom(SessionLoadError::MissingCookie));
}
if let Some(error) = rejection.find::<CookieDecryptionError<SessionCookie>>() {
warn!(?error, "could not decrypt session cookie");
return Err(warp::reject::custom(SessionLoadError::InvalidCookie));
}
Err(rejection)
}

View File

@ -1,35 +0,0 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Route tracing utility
use std::convert::Infallible;
use warp::Filter;
/// Set the name of that route
#[must_use]
pub fn name(
name: &'static str,
) -> impl Filter<Extract = (), Error = Infallible> + Clone + Send + Sync + 'static {
warp::any()
.map(move || {
// TODO: update_name has a weird signature, which is already fixed in
// opentelemetry-rust, just not released yet
// TODO: we should find another way to classify requests. Span::update_name has
// impacts on sampling and should not be used
opentelemetry::trace::get_active_span(|s| s.update_name::<String>(name.to_string()));
})
.untuple_one()
}

View File

@ -1,121 +0,0 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Utility to build URLs
// TODO: move this somewhere else
use std::convert::Infallible;
use mas_config::HttpConfig;
use url::Url;
use warp::Filter;
impl From<&HttpConfig> for UrlBuilder {
fn from(config: &HttpConfig) -> Self {
Self::new(config.public_base.clone())
}
}
/// Helps building absolute URLs
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct UrlBuilder {
base: Url,
}
impl UrlBuilder {
/// Create a new [`UrlBuilder`] from a base URL
#[must_use]
pub fn new(base: Url) -> Self {
Self { base }
}
/// OIDC issuer
#[must_use]
pub fn oidc_issuer(&self) -> Url {
self.base.clone()
}
/// OIDC dicovery document URL
#[must_use]
pub fn oidc_discovery(&self) -> Url {
self.base
.join(".well-known/openid-configuration")
.expect("build URL")
}
/// OAuth 2.0 authorization endpoint
#[must_use]
pub fn oauth_authorization_endpoint(&self) -> Url {
self.base.join("oauth2/authorize").expect("build URL")
}
/// OAuth 2.0 token endpoint
#[must_use]
pub fn oauth_token_endpoint(&self) -> Url {
self.base.join("oauth2/token").expect("build URL")
}
/// OAuth 2.0 introspection endpoint
#[must_use]
pub fn oauth_introspection_endpoint(&self) -> Url {
self.base.join("oauth2/introspect").expect("build URL")
}
/// OAuth 2.0 introspection endpoint
#[must_use]
pub fn oidc_userinfo_endpoint(&self) -> Url {
self.base.join("oauth2/userinfo").expect("build URL")
}
/// JWKS URI
#[must_use]
pub fn jwks_uri(&self) -> Url {
self.base.join("oauth2/keys.json").expect("build URL")
}
/// Email verification URL
#[must_use]
pub fn email_verification(&self, code: &str) -> Url {
self.base
.join("verify/")
.expect("build URL")
.join(code)
.expect("build URL")
}
}
/// Injects an [`UrlBuilder`] to help building absolute URLs
#[must_use]
pub fn url_builder(
config: &HttpConfig,
) -> impl Filter<Extract = (UrlBuilder,), Error = Infallible> + Clone + Send + Sync + 'static {
let builder: UrlBuilder = config.into();
warp::any().map(move || builder.clone())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn build_email_verification_url() {
let base = Url::parse("https://example.com/").unwrap();
let builder = UrlBuilder::new(base);
assert_eq!(
builder.email_verification("123456abcdef").as_str(),
"https://example.com/verify/123456abcdef"
);
}
}

View File

@ -1,24 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Various warp filters and replies
#![forbid(unsafe_code)]
#![deny(clippy::all, missing_docs, rustdoc::broken_intra_doc_links)]
#![warn(clippy::pedantic)]
#![allow(clippy::module_name_repetitions, clippy::missing_errors_doc)]
pub mod errors;
pub mod filters;
pub mod reply;

View File

@ -1,54 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Reply with a typed header from the [`headers`] crate.
//!
//! ```rust
//! extern crate headers;
//! extern crate warp;
//!
//! use warp::Reply;
//! use mas_warp_utils::reply::with_typed_header;
//!
//! let reply = r#"{"hello": "world"}"#;
//! let reply = with_typed_header(headers::ContentType::json(), reply);;
//! let response = reply.into_response();
//! assert_eq!(response.headers().get("Content-Type").unwrap().to_str().unwrap(), "application/json");
//! ```
use headers::{Header, HeaderMapExt};
use warp::Reply;
/// Add a typed header to a reply
pub fn with_typed_header<R, H>(header: H, reply: R) -> WithTypedHeader<R, H> {
WithTypedHeader { reply, header }
}
/// A reply with a typed header set
pub struct WithTypedHeader<R, H> {
reply: R,
header: H,
}
impl<R, H> Reply for WithTypedHeader<R, H>
where
R: Reply,
H: Header + Send,
{
fn into_response(self) -> warp::reply::Response {
let mut res = self.reply.into_response();
res.headers_mut().typed_insert(self.header);
res
}
}

View File

@ -1,21 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Set of wrappers for [`warp::Reply`]
#![deny(missing_docs)]
pub mod headers;
pub use self::headers::{with_typed_header, WithTypedHeader};