1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Get rid of warp

This commit is contained in:
Quentin Gliech
2022-04-06 15:40:16 +02:00
parent 9cd63f6cf1
commit 4e31fc6c84
30 changed files with 3 additions and 2312 deletions

168
Cargo.lock generated
View File

@ -672,16 +672,6 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "buf_redux"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b953a6887648bb07a535631f2bc00fbdb2a2216f135552cb3f534ed136b9c07f"
dependencies = [
"memchr",
"safemem",
]
[[package]] [[package]]
name = "bumpalo" name = "bumpalo"
version = "3.9.1" version = "3.9.1"
@ -1992,7 +1982,6 @@ dependencies = [
"mas-storage", "mas-storage",
"mas-tasks", "mas-tasks",
"mas-templates", "mas-templates",
"mas-warp-utils",
"opentelemetry", "opentelemetry",
"opentelemetry-jaeger", "opentelemetry-jaeger",
"opentelemetry-otlp", "opentelemetry-otlp",
@ -2010,7 +1999,6 @@ dependencies = [
"tracing-opentelemetry", "tracing-opentelemetry",
"tracing-subscriber", "tracing-subscriber",
"url", "url",
"warp",
"watchman_client", "watchman_client",
] ]
@ -2102,7 +2090,6 @@ dependencies = [
"mas-static-files", "mas-static-files",
"mas-storage", "mas-storage",
"mas-templates", "mas-templates",
"mas-warp-utils",
"mime", "mime",
"oauth2-types", "oauth2-types",
"pkcs8", "pkcs8",
@ -2119,7 +2106,6 @@ dependencies = [
"tower", "tower",
"tracing", "tracing",
"url", "url",
"warp",
] ]
[[package]] [[package]]
@ -2240,7 +2226,6 @@ dependencies = [
"tokio", "tokio",
"tracing", "tracing",
"url", "url",
"warp",
] ]
[[package]] [[package]]
@ -2272,46 +2257,6 @@ dependencies = [
"tokio", "tokio",
"tracing", "tracing",
"url", "url",
"warp",
]
[[package]]
name = "mas-warp-utils"
version = "0.1.0"
dependencies = [
"anyhow",
"bincode",
"chrono",
"cookie",
"crc",
"data-encoding",
"headers",
"http",
"http-body",
"hyper",
"mas-config",
"mas-data-model",
"mas-http",
"mas-iana",
"mas-jose",
"mas-storage",
"mas-templates",
"mime",
"oauth2-types",
"once_cell",
"opentelemetry",
"rand",
"serde",
"serde_json",
"serde_urlencoded",
"serde_with",
"sqlx",
"thiserror",
"tokio",
"tower",
"tracing",
"url",
"warp",
] ]
[[package]] [[package]]
@ -2418,24 +2363,6 @@ version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a" checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a"
[[package]]
name = "multipart"
version = "0.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00dec633863867f29cb39df64a397cdf4a6354708ddd7759f70c7fb51c5f9182"
dependencies = [
"buf_redux",
"httparse",
"log",
"mime",
"mime_guess",
"quick-error",
"rand",
"safemem",
"tempfile",
"twoway",
]
[[package]] [[package]]
name = "nom" name = "nom"
version = "7.1.0" version = "7.1.0"
@ -3120,12 +3047,6 @@ dependencies = [
"prost", "prost",
] ]
[[package]]
name = "quick-error"
version = "1.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0"
[[package]] [[package]]
name = "quote" name = "quote"
version = "1.0.15" version = "1.0.15"
@ -3436,12 +3357,6 @@ version = "1.0.9"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73b4b750c782965c211b42f022f59af1fbceabdd026623714f104152f1ec149f" checksum = "73b4b750c782965c211b42f022f59af1fbceabdd026623714f104152f1ec149f"
[[package]]
name = "safemem"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef703b7cb59335eae2eb93ceb664c0eb7ea6bf567079d843e09420219668e072"
[[package]] [[package]]
name = "same-file" name = "same-file"
version = "1.0.6" version = "1.0.6"
@ -3487,12 +3402,6 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "scoped-tls"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ea6a9290e3c9cf0f18145ef7ffa62d68ee0bf5fcd651017e586dc7fd5da448c2"
[[package]] [[package]]
name = "scopeguard" name = "scopeguard"
version = "1.1.0" version = "1.1.0"
@ -4206,19 +4115,6 @@ dependencies = [
"tokio", "tokio",
] ]
[[package]]
name = "tokio-tungstenite"
version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "511de3f85caf1c98983545490c3d09685fa8eb634e57eec22bb4db271f46cbd8"
dependencies = [
"futures-util",
"log",
"pin-project",
"tokio",
"tungstenite",
]
[[package]] [[package]]
name = "tokio-util" name = "tokio-util"
version = "0.6.9" version = "0.6.9"
@ -4451,34 +4347,6 @@ version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642" checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642"
[[package]]
name = "tungstenite"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0b2d8558abd2e276b0a8df5c05a2ec762609344191e5fd23e292c910e9165b5"
dependencies = [
"base64",
"byteorder",
"bytes 1.1.0",
"http",
"httparse",
"log",
"rand",
"sha-1 0.9.8",
"thiserror",
"url",
"utf-8",
]
[[package]]
name = "twoway"
version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59b11b2b5241ba34be09c3cc85a36e56e48f9888862e19cedf23336d35316ed1"
dependencies = [
"memchr",
]
[[package]] [[package]]
name = "typed-builder" name = "typed-builder"
version = "0.9.1" version = "0.9.1"
@ -4644,12 +4512,6 @@ version = "1.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a1f0175e03a0973cf4afd476bef05c26e228520400eb1fd473ad417b1c00ffb" checksum = "5a1f0175e03a0973cf4afd476bef05c26e228520400eb1fd473ad417b1c00ffb"
[[package]]
name = "utf-8"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
[[package]] [[package]]
name = "valuable" name = "valuable"
version = "0.1.0" version = "0.1.0"
@ -4683,36 +4545,6 @@ dependencies = [
"try-lock", "try-lock",
] ]
[[package]]
name = "warp"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3cef4e1e9114a4b7f1ac799f16ce71c14de5778500c5450ec6b7b920c55b587e"
dependencies = [
"bytes 1.1.0",
"futures-channel",
"futures-util",
"headers",
"http",
"hyper",
"log",
"mime",
"mime_guess",
"multipart",
"percent-encoding",
"pin-project",
"scoped-tls",
"serde",
"serde_json",
"serde_urlencoded",
"tokio",
"tokio-stream",
"tokio-tungstenite",
"tokio-util 0.6.9",
"tower-service",
"tracing",
]
[[package]] [[package]]
name = "wasi" name = "wasi"
version = "0.10.0+wasi-snapshot-preview1" version = "0.10.0+wasi-snapshot-preview1"

View File

@ -16,7 +16,6 @@ tower = { version = "0.4.12", features = ["full"] }
hyper = { version = "0.14.17", features = ["full"] } hyper = { version = "0.14.17", features = ["full"] }
serde_yaml = "0.8.23" serde_yaml = "0.8.23"
serde_json = "1.0.79" serde_json = "1.0.79"
warp = "0.3.2"
url = "2.2.2" url = "2.2.2"
argon2 = { version = "0.3.4", features = ["password-hash"] } argon2 = { version = "0.3.4", features = ["password-hash"] }
reqwest = { version = "0.11.10", features = ["rustls-tls"], default-features = false, optional = true } reqwest = { version = "0.11.10", features = ["rustls-tls"], default-features = false, optional = true }
@ -42,7 +41,6 @@ mas-http = { path = "../http" }
mas-storage = { path = "../storage" } mas-storage = { path = "../storage" }
mas-tasks = { path = "../tasks" } mas-tasks = { path = "../tasks" }
mas-templates = { path = "../templates" } mas-templates = { path = "../templates" }
mas-warp-utils = { path = "../warp-utils" }
mas-axum-utils = { path = "../axum-utils" } mas-axum-utils = { path = "../axum-utils" }
[dev-dependencies] [dev-dependencies]

View File

@ -40,7 +40,7 @@ pub fn setup(config: &TelemetryConfig) -> anyhow::Result<Option<Tracer>> {
// The CORS filter needs to know what headers it should whitelist for // The CORS filter needs to know what headers it should whitelist for
// CORS-protected requests. // CORS-protected requests.
mas_warp_utils::filters::cors::set_propagator(&propagator); // TODO mas_warp_utils::filters::cors::set_propagator(&propagator);
global::set_text_map_propagator(propagator); global::set_text_map_propagator(propagator);
let tracer = tracer(&config.tracing.exporter)?; let tracer = tracer(&config.tracing.exporter)?;

View File

@ -20,7 +20,6 @@ thiserror = "1.0.30"
anyhow = "1.0.56" anyhow = "1.0.56"
# Web server # Web server
warp = "0.3.2"
hyper = { version = "0.14.17", features = ["full"] } hyper = { version = "0.14.17", features = ["full"] }
tower = "0.4.12" tower = "0.4.12"
axum = "0.4.8" axum = "0.4.8"
@ -67,7 +66,6 @@ mas-jose = { path = "../jose" }
mas-static-files = { path = "../static-files" } mas-static-files = { path = "../static-files" }
mas-storage = { path = "../storage" } mas-storage = { path = "../storage" }
mas-templates = { path = "../templates" } mas-templates = { path = "../templates" }
mas-warp-utils = { path = "../warp-utils" }
[dev-dependencies] [dev-dependencies]
indoc = "1.0.4" indoc = "1.0.4"

View File

@ -16,7 +16,7 @@
#![deny(clippy::all, rustdoc::broken_intra_doc_links)] #![deny(clippy::all, rustdoc::broken_intra_doc_links)]
#![warn(clippy::pedantic)] #![warn(clippy::pedantic)]
#![allow( #![allow(
clippy::unused_async // Some warp filters need that clippy::unused_async // Some axum handlers need that
)] )]
use std::sync::Arc; use std::sync::Arc;

View File

@ -14,7 +14,6 @@ serde_json = "1.0.79"
thiserror = "1.0.30" thiserror = "1.0.30"
anyhow = "1.0.56" anyhow = "1.0.56"
tracing = "0.1.32" tracing = "0.1.32"
warp = "0.3.2"
# Password hashing # Password hashing
argon2 = { version = "0.3.4", features = ["password-hash"] } argon2 = { version = "0.3.4", features = ["password-hash"] }

View File

@ -21,7 +21,6 @@ use oauth2_types::requests::GrantType;
use sqlx::{PgConnection, PgExecutor}; use sqlx::{PgConnection, PgExecutor};
use thiserror::Error; use thiserror::Error;
use url::Url; use url::Url;
use warp::reject::Reject;
use crate::PostgresqlBackend; use crate::PostgresqlBackend;
@ -79,8 +78,6 @@ impl ClientFetchError {
} }
} }
impl Reject for ClientFetchError {}
impl TryInto<Client<PostgresqlBackend>> for OAuth2ClientLookup { impl TryInto<Client<PostgresqlBackend>> for OAuth2ClientLookup {
type Error = ClientFetchError; type Error = ClientFetchError;

View File

@ -19,7 +19,6 @@ use mas_data_model::{
}; };
use sqlx::{PgConnection, PgExecutor}; use sqlx::{PgConnection, PgExecutor};
use thiserror::Error; use thiserror::Error;
use warp::reject::Reject;
use super::client::{lookup_client_by_client_id, ClientFetchError}; use super::client::{lookup_client_by_client_id, ClientFetchError};
use crate::{DatabaseInconsistencyError, IdAndCreationTime, PostgresqlBackend}; use crate::{DatabaseInconsistencyError, IdAndCreationTime, PostgresqlBackend};
@ -87,8 +86,6 @@ pub enum RefreshTokenLookupError {
Conversion(#[from] DatabaseInconsistencyError), Conversion(#[from] DatabaseInconsistencyError),
} }
impl Reject for RefreshTokenLookupError {}
impl RefreshTokenLookupError { impl RefreshTokenLookupError {
#[must_use] #[must_use]
pub fn not_found(&self) -> bool { pub fn not_found(&self) -> bool {

View File

@ -27,7 +27,6 @@ use sqlx::{postgres::types::PgInterval, Acquire, PgExecutor, Postgres, Transacti
use thiserror::Error; use thiserror::Error;
use tokio::task; use tokio::task;
use tracing::{info_span, Instrument}; use tracing::{info_span, Instrument};
use warp::reject::Reject;
use super::{DatabaseInconsistencyError, PostgresqlBackend}; use super::{DatabaseInconsistencyError, PostgresqlBackend};
use crate::IdAndCreationTime; use crate::IdAndCreationTime;
@ -117,8 +116,6 @@ pub enum ActiveSessionLookupError {
Conversion(#[from] DatabaseInconsistencyError), Conversion(#[from] DatabaseInconsistencyError),
} }
impl Reject for ActiveSessionLookupError {}
impl ActiveSessionLookupError { impl ActiveSessionLookupError {
#[must_use] #[must_use]
pub fn not_found(&self) -> bool { pub fn not_found(&self) -> bool {

View File

@ -21,7 +21,6 @@ serde_json = "1.0.79"
serde_urlencoded = "0.7.1" serde_urlencoded = "0.7.1"
url = "2.2.2" url = "2.2.2"
warp = "0.3.2"
oauth2-types = { path = "../oauth2-types" } oauth2-types = { path = "../oauth2-types" }
mas-data-model = { path = "../data-model" } mas-data-model = { path = "../data-model" }

View File

@ -271,8 +271,6 @@ pub enum TemplateError {
}, },
} }
impl warp::reject::Reject for TemplateError {}
register_templates! { register_templates! {
extra = { extra = {
"components/button.html", "components/button.html",

View File

@ -1,42 +0,0 @@
[package]
name = "mas-warp-utils"
version = "0.1.0"
authors = ["Quentin Gliech <quenting@element.io>"]
edition = "2021"
license = "Apache-2.0"
[dependencies]
tokio = { version = "1.17.0", features = ["macros"] }
headers = "0.3.7"
cookie = "0.16.0"
warp = "0.3.2"
hyper = { version = "0.14.17", features = ["full"] }
thiserror = "1.0.30"
anyhow = "1.0.56"
sqlx = { version = "0.5.11", features = ["runtime-tokio-rustls", "postgres"] }
chrono = { version = "0.4.19", features = ["serde"] }
serde = { version = "1.0.136", features = ["derive"] }
serde_with = { version = "1.12.0", features = ["hex", "chrono"] }
serde_json = "1.0.79"
serde_urlencoded = "0.7.1"
data-encoding = "2.3.2"
once_cell = "1.10.0"
tracing = "0.1.32"
opentelemetry = "0.17.0"
rand = "0.8.5"
mime = "0.3.16"
bincode = "1.3.3"
crc = "2.1.0"
url = "2.2.2"
http = "0.2.6"
http-body = "0.4.4"
tower = { version = "0.4.12", features = ["util"] }
oauth2-types = { path = "../oauth2-types" }
mas-config = { path = "../config" }
mas-templates = { path = "../templates" }
mas-data-model = { path = "../data-model" }
mas-storage = { path = "../storage" }
mas-jose = { path = "../jose" }
mas-iana = { path = "../iana" }
mas-http = { path = "../http" }

View File

@ -1,42 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Helper to deal with various unstructured errors in application code
use warp::{reject::Reject, Rejection};
#[derive(Debug)]
pub(crate) struct WrappedError(anyhow::Error);
impl warp::reject::Reject for WrappedError {}
/// Wrap any error in a [`Rejection`]
pub fn wrapped_error<T: Into<anyhow::Error>>(e: T) -> impl Reject {
WrappedError(e.into())
}
/// Extension trait that wraps errors in [`Rejection`]s
pub trait WrapError<T> {
/// Wrap transform the [`Result`] error type to a [`Rejection`]
fn wrap_error(self) -> Result<T, Rejection>;
}
impl<T, E> WrapError<T> for Result<T, E>
where
E: Into<anyhow::Error>,
{
fn wrap_error(self) -> Result<T, Rejection> {
self.map_err(|e| warp::reject::custom(WrappedError(e.into())))
}
}

View File

@ -1,156 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Authenticate an endpoint with an access token as bearer authorization token
use headers::{authorization::Bearer, Authorization};
use hyper::StatusCode;
use mas_data_model::{AccessToken, Session, TokenFormatError, TokenType};
use mas_storage::{
oauth2::access_token::{lookup_active_access_token, AccessTokenLookupError},
PostgresqlBackend,
};
use sqlx::{pool::PoolConnection, PgPool, Postgres};
use thiserror::Error;
use warp::{
reject::{MissingHeader, Reject},
reply::{with_header, with_status},
Filter, Rejection, Reply,
};
use super::{
database::connection,
headers::{typed_header, InvalidTypedHeader},
};
use crate::errors::wrapped_error;
/// Bearer token authentication failed
///
/// This is recoverable with [`recover_unauthorized`]
#[derive(Debug, Error)]
pub enum AuthenticationError {
/// The bearer token has an invalid format
#[error("invalid token format")]
TokenFormat(#[from] TokenFormatError),
/// The bearer token is not an access token
#[error("invalid token type {0:?}, expected an access token")]
WrongTokenType(TokenType),
/// The access token was not found in the database
#[error("unknown token")]
TokenNotFound(#[source] AccessTokenLookupError),
/// The `Authorization` header is missing
#[error("missing authorization header")]
MissingAuthorizationHeader,
/// The `Authorization` header is invalid
#[error("invalid authorization header")]
InvalidAuthorizationHeader,
}
impl Reject for AuthenticationError {}
/// Authenticate a request using an access token as a bearer authorization
///
/// # Rejections
///
/// This can reject with either a [`AuthenticationError`] or with a generic
/// wrapped sqlx error.
#[must_use]
pub fn authentication(
pool: &PgPool,
) -> impl Filter<
Extract = (AccessToken<PostgresqlBackend>, Session<PostgresqlBackend>),
Error = Rejection,
> + Clone
+ Send
+ Sync
+ 'static {
connection(pool)
.and(typed_header())
.and_then(authenticate)
.recover(recover)
.unify()
.untuple_one()
}
fn ensure<T: Clone + Send + Sync + 'static>(t: T) -> T {
t
}
async fn authenticate(
mut conn: PoolConnection<Postgres>,
auth: Authorization<Bearer>,
) -> Result<(AccessToken<PostgresqlBackend>, Session<PostgresqlBackend>), Rejection> {
let token = auth.0.token();
let token_type = TokenType::check(token).map_err(AuthenticationError::TokenFormat)?;
if token_type != TokenType::AccessToken {
return Err(AuthenticationError::WrongTokenType(token_type).into());
}
let (token, session) = lookup_active_access_token(&mut conn, token)
.await
.map_err(|e| {
if e.not_found() {
// This error happens if the token was not found and should be recovered
warp::reject::custom(AuthenticationError::TokenNotFound(e))
} else {
// This is a generic database error that we want to propagate
warp::reject::custom(wrapped_error(e))
}
})?;
let session = ensure(session);
let token = ensure(token);
Ok((token, session))
}
/// Transform the rejections from the [`with_typed_header`] filter
async fn recover(
rejection: Rejection,
) -> Result<(AccessToken<PostgresqlBackend>, Session<PostgresqlBackend>), Rejection> {
if rejection.find::<MissingHeader>().is_some() {
return Err(warp::reject::custom(
AuthenticationError::MissingAuthorizationHeader,
));
}
if rejection.find::<InvalidTypedHeader>().is_some() {
return Err(warp::reject::custom(
AuthenticationError::InvalidAuthorizationHeader,
));
}
Err(rejection)
}
/// Recover from an [`AuthenticationError`] with a `WWW-Authenticate` header, as
/// per [RFC6750]. This is not intended for user-facing endpoints.
///
/// [RFC6750]: https://www.rfc-editor.org/rfc/rfc6750.html
pub async fn recover_unauthorized(rejection: Rejection) -> Result<Box<dyn Reply>, Rejection> {
if rejection.find::<AuthenticationError>().is_some() {
// TODO: have the issuer/realm here
let reply = "invalid token";
let reply = with_status(reply, StatusCode::UNAUTHORIZED);
let reply = with_header(reply, "WWW-Authenticate", r#"Bearer error="invalid_token""#);
return Ok(Box::new(reply));
}
Err(rejection)
}

View File

@ -1,772 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Handle client authentication
use std::collections::HashMap;
use data_encoding::BASE64;
use headers::{authorization::Basic, Authorization};
use mas_config::Encrypter;
use mas_data_model::{Client, JwksOrJwksUri, StorageBackend};
use mas_http::HttpServiceExt;
use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_jose::{
claims::{TimeOptions, AUD, EXP, IAT, ISS, JTI, NBF, SUB},
DecodedJsonWebToken, DynamicJwksStore, Either, JsonWebKeySet, JsonWebTokenParts, SharedSecret,
StaticJwksStore, VerifyingKeystore,
};
use mas_storage::{
oauth2::client::{lookup_client_by_client_id, ClientFetchError},
PostgresqlBackend,
};
use serde::{de::DeserializeOwned, Deserialize};
use sqlx::{pool::PoolConnection, PgPool, Postgres};
use thiserror::Error;
use tower::{BoxError, ServiceExt};
use warp::{reject::Reject, Filter, Rejection};
use super::{database::connection, headers::typed_header};
use crate::errors::WrapError;
/// Protect an enpoint with client authentication
#[must_use]
pub fn client_authentication<T: DeserializeOwned + Send + 'static>(
pool: &PgPool,
encrypter: &Encrypter,
audience: String,
) -> impl Filter<
Extract = (
OAuthClientAuthenticationMethod,
Client<PostgresqlBackend>,
T,
),
Error = Rejection,
> + Clone
+ Send
+ Sync
+ 'static {
let encrypter = encrypter.clone();
// First, extract the client credentials
let credentials = typed_header()
.and(warp::body::form())
// Either from the "Authorization" header
.map(|auth: Authorization<Basic>, body: T| {
let client_id = auth.0.username().to_string();
let client_secret = Some(auth.0.password().to_string());
(
ClientCredentials::Pair {
via: CredentialsVia::AuthorizationHeader,
client_id,
client_secret,
},
body,
)
})
// Or from the form body
.or(warp::body::form().map(|form: ClientAuthForm<T>| {
let ClientAuthForm { credentials, body } = form;
(credentials, body)
}))
.unify()
.untuple_one();
warp::any()
.and(connection(pool))
.and(warp::any().map(move || encrypter.clone()))
.and(warp::any().map(move || audience.clone()))
.and(credentials)
.and_then(authenticate_client)
.untuple_one()
}
#[derive(Error, Debug)]
enum ClientAuthenticationError {
#[error("wrong client secret for client {client_id:?}")]
ClientSecretMismatch { client_id: String },
#[error("could not fetch client {client_id:?}")]
ClientFetch {
client_id: String,
source: ClientFetchError,
},
#[error("client {client_id:?} has an invalid client secret")]
InvalidClientSecret {
client_id: String,
source: anyhow::Error,
},
#[error("client {client_id:?} has an invalid JWKS")]
InvalidJwks { client_id: String },
#[error("wrong client authentication method for client {client_id:?}")]
WrongAuthenticationMethod { client_id: String },
#[error("wrong audience in client assertion: expected {expected:?}")]
MissingAudience { expected: String },
#[error("invalid client assertion")]
InvalidAssertion,
}
impl Reject for ClientAuthenticationError {}
fn decrypt_client_secret<T: StorageBackend>(
client: &Client<T>,
encrypter: &Encrypter,
) -> anyhow::Result<Vec<u8>> {
let encrypted_client_secret = client
.encrypted_client_secret
.as_ref()
.ok_or_else(|| anyhow::anyhow!("missing encrypted_client_secret field"))?;
let encrypted_client_secret = BASE64.decode(encrypted_client_secret.as_bytes())?;
let nonce: &[u8; 12] = encrypted_client_secret
.get(0..12)
.ok_or_else(|| anyhow::anyhow!("invalid payload serialization"))?
.try_into()?;
let payload = encrypted_client_secret
.get(12..)
.ok_or_else(|| anyhow::anyhow!("invalid payload serialization"))?;
let decrypted_client_secret = encrypter.decrypt(nonce, payload)?;
Ok(decrypted_client_secret)
}
fn jwks_key_store(jwks: &JwksOrJwksUri) -> Either<StaticJwksStore, DynamicJwksStore> {
// Assert that the output is both a VerifyingKeystore and Send
fn assert<T: Send + VerifyingKeystore>(t: T) -> T {
t
}
let inner = match jwks {
JwksOrJwksUri::Jwks(jwks) => Either::Left(StaticJwksStore::new(jwks.clone())),
JwksOrJwksUri::JwksUri(uri) => {
let uri = uri.clone();
// TODO: get the client from somewhere else?
let exporter = mas_http::client("fetch-jwks")
.json::<JsonWebKeySet>()
.map_request(move |_: ()| {
http::Request::builder()
.method("GET")
// TODO: change the Uri type in config to avoid reparsing here
.uri(uri.to_string())
.body(http_body::Empty::new())
.unwrap()
})
.map_response(http::Response::into_body)
.map_err(BoxError::from)
.boxed_clone();
Either::Right(DynamicJwksStore::new(exporter))
}
};
assert(inner)
}
#[allow(clippy::too_many_lines)]
#[tracing::instrument(skip_all, fields(enduser.id), err(Debug))]
async fn authenticate_client<T>(
mut conn: PoolConnection<Postgres>,
encrypter: Encrypter,
audience: String,
credentials: ClientCredentials,
body: T,
) -> Result<
(
OAuthClientAuthenticationMethod,
Client<PostgresqlBackend>,
T,
),
Rejection,
> {
let (auth_method, client) = match credentials {
ClientCredentials::Pair {
client_id,
client_secret,
via,
} => {
let client = lookup_client_by_client_id(&mut *conn, &client_id)
.await
.map_err(|source| ClientAuthenticationError::ClientFetch {
client_id: client_id.clone(),
source,
})?;
let auth_method = client.token_endpoint_auth_method.ok_or(
ClientAuthenticationError::WrongAuthenticationMethod {
client_id: client.client_id.clone(),
},
)?;
// Let's match the authentication method
match (auth_method, client_secret, via) {
(OAuthClientAuthenticationMethod::None, None, _) => {}
(
OAuthClientAuthenticationMethod::ClientSecretBasic,
Some(client_secret),
CredentialsVia::AuthorizationHeader,
)
| (
OAuthClientAuthenticationMethod::ClientSecretPost,
Some(client_secret),
CredentialsVia::FormBody,
) => {
let decrypted =
decrypt_client_secret(&client, &encrypter).map_err(|source| {
ClientAuthenticationError::InvalidClientSecret {
client_id: client.client_id.clone(),
source,
}
})?;
if client_secret.as_bytes() != decrypted {
return Err(warp::reject::custom(
ClientAuthenticationError::ClientSecretMismatch {
client_id: client.client_id,
},
));
}
}
_ => {
return Err(warp::reject::custom(
ClientAuthenticationError::WrongAuthenticationMethod {
client_id: client.client_id,
},
));
}
}
(auth_method, client)
}
ClientCredentials::Assertion {
client_id,
client_assertion_type: ClientAssertionType::JwtBearer,
client_assertion,
} => {
let token: JsonWebTokenParts = client_assertion.parse().wrap_error()?;
let decoded: DecodedJsonWebToken<HashMap<String, serde_json::Value>> =
token.decode().wrap_error()?;
let time_options = TimeOptions::default()
.freeze()
.leeway(chrono::Duration::minutes(1));
let mut claims = decoded.claims().clone();
let iss = ISS.extract_required(&mut claims).wrap_error()?;
let sub = SUB.extract_required(&mut claims).wrap_error()?;
let aud = AUD.extract_required(&mut claims).wrap_error()?;
// Validate the times
let _exp = EXP
.extract_required_with_options(&mut claims, &time_options)
.wrap_error()?;
let _nbf = NBF
.extract_optional_with_options(&mut claims, &time_options)
.wrap_error()?;
let _iat = IAT
.extract_optional_with_options(&mut claims, &time_options)
.wrap_error()?;
// TODO: validate the JTI
let _jti = JTI.extract_optional(&mut claims).wrap_error()?;
// client_id might have been passed as parameter. If not, it should be inferred
// from the token, as per rfc7521 sec. 4.2
let client_id = client_id.as_ref().unwrap_or(&sub);
let client = lookup_client_by_client_id(&mut *conn, client_id)
.await
.map_err(|source| ClientAuthenticationError::ClientFetch {
client_id: client_id.to_string(),
source,
})?;
let auth_method = client.token_endpoint_auth_method.ok_or(
ClientAuthenticationError::WrongAuthenticationMethod {
client_id: client.client_id.clone(),
},
)?;
match auth_method {
OAuthClientAuthenticationMethod::ClientSecretJwt => {
let client_secret =
decrypt_client_secret(&client, &encrypter).map_err(|source| {
ClientAuthenticationError::InvalidClientSecret {
client_id: client.client_id.clone(),
source,
}
})?;
let store = SharedSecret::new(&client_secret);
let fut = token.verify(decoded.header(), &store);
fut.await.wrap_error()?;
}
OAuthClientAuthenticationMethod::PrivateKeyJwt => {
let jwks = client.jwks.as_ref().ok_or_else(|| {
ClientAuthenticationError::InvalidJwks {
client_id: client.client_id.clone(),
}
})?;
let store = jwks_key_store(jwks);
let fut = token.verify(decoded.header(), &store);
fut.await.wrap_error()?;
}
_ => {
return Err(warp::reject::custom(
ClientAuthenticationError::WrongAuthenticationMethod {
client_id: client.client_id,
},
));
}
}
// rfc7523 sec. 3.3: the audience is the URL being called
if !aud.contains(&audience) {
return Err(
ClientAuthenticationError::MissingAudience { expected: audience }.into(),
);
}
// rfc7523 sec. 3.1 & 3.2: both the issuer and the subject must
// match the client_id
if iss != sub || &iss != client_id {
return Err(ClientAuthenticationError::InvalidAssertion.into());
}
(auth_method, client)
}
};
tracing::Span::current().record("enduser.id", &client.client_id.as_str());
Ok((auth_method, client, body))
}
#[derive(Deserialize)]
enum ClientAssertionType {
#[serde(rename = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")]
JwtBearer,
}
enum CredentialsVia {
FormBody,
AuthorizationHeader,
}
impl Default for CredentialsVia {
fn default() -> Self {
Self::FormBody
}
}
#[derive(Deserialize)]
#[serde(untagged)]
enum ClientCredentials {
// Order here is important: serde tries to deserialize enum variants in order, so if "Pair"
// was before "Assertion", a client_assertion with a client_id would match the "Pair"
// variant first
Assertion {
client_id: Option<String>,
client_assertion_type: ClientAssertionType,
client_assertion: String,
},
Pair {
#[serde(skip)]
via: CredentialsVia,
client_id: String,
client_secret: Option<String>,
},
}
#[derive(Deserialize)]
struct ClientAuthForm<T> {
#[serde(flatten)]
credentials: ClientCredentials,
#[serde(flatten)]
body: T,
}
/* TODO: all secrets are broken because there is no way to mock the DB yet
#[cfg(test)]
mod tests {
use headers::authorization::Credentials;
use mas_config::{ClientAuthMethodConfig, ConfigurationSection};
use mas_jose::{SigningKeystore, StaticKeystore};
use serde_json::json;
use tower::{Service, ServiceExt};
use super::*;
// Long client_secret to support it as a HS512 key
const CLIENT_SECRET: &str = "leek2zaeyeb8thai7piehea3vah6ool9oanin9aeraThuci9EeghaekaiD1upe4Quoh7xeMae2meitohj0Waaveiwaorah1yazohr6Vae7iebeiRaWene5IeWeeciezu";
fn client_private_keystore() -> StaticKeystore {
let mut store = StaticKeystore::new();
store.add_test_rsa_key().unwrap();
store.add_test_ecdsa_key().unwrap();
store
}
async fn oauth2_config() -> ClientsConfig {
let mut config = ClientsConfig::test();
config.push(ClientConfig {
client_id: "public".to_string(),
client_auth_method: ClientAuthMethodConfig::None,
redirect_uris: Vec::new(),
});
config.push(ClientConfig {
client_id: "secret-basic".to_string(),
client_auth_method: ClientAuthMethodConfig::ClientSecretBasic {
client_secret: CLIENT_SECRET.to_string(),
},
redirect_uris: Vec::new(),
});
config.push(ClientConfig {
client_id: "secret-post".to_string(),
client_auth_method: ClientAuthMethodConfig::ClientSecretPost {
client_secret: CLIENT_SECRET.to_string(),
},
redirect_uris: Vec::new(),
});
config.push(ClientConfig {
client_id: "secret-jwt".to_string(),
client_auth_method: ClientAuthMethodConfig::ClientSecretJwt {
client_secret: CLIENT_SECRET.to_string(),
},
redirect_uris: Vec::new(),
});
config.push(ClientConfig {
client_id: "secret-jwt-2".to_string(),
client_auth_method: ClientAuthMethodConfig::ClientSecretJwt {
client_secret: CLIENT_SECRET.to_string(),
},
redirect_uris: Vec::new(),
});
let store = client_private_keystore();
let jwks = (&store).ready().await.unwrap().call(()).await.unwrap();
//let jwks = store.export_jwks().await.unwrap();
config.push(ClientConfig {
client_id: "private-key-jwt".to_string(),
client_auth_method: ClientAuthMethodConfig::PrivateKeyJwt(jwks.clone().into()),
redirect_uris: Vec::new(),
});
config.push(ClientConfig {
client_id: "private-key-jwt-2".to_string(),
client_auth_method: ClientAuthMethodConfig::PrivateKeyJwt(jwks.into()),
redirect_uris: Vec::new(),
});
config
}
#[derive(Deserialize)]
struct Form {
foo: String,
bar: String,
}
#[tokio::test]
async fn client_secret_jwt_hs256() {
client_secret_jwt("HS256").await;
}
#[tokio::test]
async fn client_secret_jwt_hs384() {
client_secret_jwt("HS384").await;
}
#[tokio::test]
async fn client_secret_jwt_hs512() {
client_secret_jwt("HS512").await;
}
fn client_claims(
client_id: &str,
audience: &str,
iat: chrono::DateTime<chrono::Utc>,
) -> HashMap<String, serde_json::Value> {
let mut claims = HashMap::new();
let exp = iat + chrono::Duration::minutes(1);
ISS.insert(&mut claims, client_id).unwrap();
SUB.insert(&mut claims, client_id).unwrap();
AUD.insert(&mut claims, vec![audience.to_string()]).unwrap();
IAT.insert(&mut claims, iat).unwrap();
NBF.insert(&mut claims, iat).unwrap();
EXP.insert(&mut claims, exp).unwrap();
claims
}
async fn client_secret_jwt(alg: &str) {
let alg = alg.parse().unwrap();
let audience = "https://example.com/token";
let filter = client_authentication::<Form>(&oauth2_config().await, audience.to_string());
let store = SharedSecret::new(&CLIENT_SECRET);
let claims = client_claims("secret-jwt", audience, chrono::Utc::now());
let header = store.prepare_header(alg).await.expect("JWT header");
let jwt = DecodedJsonWebToken::new(header, claims);
let jwt = jwt.sign(&store).await.expect("signed token");
let jwt = jwt.serialize();
// TODO: test failing cases
// - expired token
// - "not before" in the future
// - subject/issuer mismatch
// - audience mismatch
// - wrong secret/signature
let (auth, client, body) = warp::test::request()
.method("POST")
.header("Content-Type", mime::APPLICATION_WWW_FORM_URLENCODED.to_string())
.body(serde_urlencoded::to_string(json!({
"client_id": "secret-jwt",
"client_assertion": jwt,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"foo": "baz",
"bar": "foobar",
})).unwrap())
.filter(&filter)
.await
.unwrap();
assert_eq!(auth, OAuthClientAuthenticationMethod::ClientSecretJwt);
assert_eq!(client.client_id, "secret-jwt");
assert_eq!(body.foo, "baz");
assert_eq!(body.bar, "foobar");
// Without client_id
let res = warp::test::request()
.method("POST")
.header("Content-Type", mime::APPLICATION_WWW_FORM_URLENCODED.to_string())
.body(serde_urlencoded::to_string(json!({
"client_assertion": jwt,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"foo": "baz",
"bar": "foobar",
})).unwrap())
.filter(&filter)
.await;
assert!(res.is_ok());
// client_id mismatch
let res = warp::test::request()
.method("POST")
.body(serde_urlencoded::to_string(json!({
"client_id": "secret-jwt-2",
"client_assertion": jwt,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"foo": "baz",
"bar": "foobar",
})).unwrap())
.filter(&filter)
.await;
assert!(res.is_err());
}
#[tokio::test]
async fn client_secret_jwt_rs256() {
private_key_jwt("RS256").await;
}
#[tokio::test]
async fn client_secret_jwt_rs384() {
private_key_jwt("RS384").await;
}
#[tokio::test]
async fn client_secret_jwt_rs512() {
private_key_jwt("RS512").await;
}
#[tokio::test]
async fn client_secret_jwt_es256() {
private_key_jwt("ES256").await;
}
async fn private_key_jwt(alg: &str) {
let alg = alg.parse().unwrap();
let audience = "https://example.com/token";
let filter = client_authentication::<Form>(&oauth2_config().await, audience.to_string());
let store = client_private_keystore();
let claims = client_claims("private-key-jwt", audience, chrono::Utc::now());
let header = store.prepare_header(alg).await.expect("JWT header");
let jwt = DecodedJsonWebToken::new(header, claims);
let jwt = jwt.sign(&store).await.expect("signed token");
let jwt = jwt.serialize();
// TODO: test failing cases
// - expired token
// - "not before" in the future
// - subject/issuer mismatch
// - audience mismatch
// - wrong secret/signature
let (auth, client, body) = warp::test::request()
.method("POST")
.header("Content-Type", mime::APPLICATION_WWW_FORM_URLENCODED.to_string())
.body(serde_urlencoded::to_string(json!({
"client_id": "private-key-jwt",
"client_assertion": jwt,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"foo": "baz",
"bar": "foobar",
})).unwrap())
.filter(&filter)
.await
.unwrap();
assert_eq!(auth, OAuthClientAuthenticationMethod::PrivateKeyJwt);
assert_eq!(client.client_id, "private-key-jwt");
assert_eq!(body.foo, "baz");
assert_eq!(body.bar, "foobar");
// Without client_id
let res = warp::test::request()
.method("POST")
.header("Content-Type", mime::APPLICATION_WWW_FORM_URLENCODED.to_string())
.body(serde_urlencoded::to_string(json!({
"client_assertion": jwt,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"foo": "baz",
"bar": "foobar",
})).unwrap())
.filter(&filter)
.await;
assert!(res.is_ok());
// client_id mismatch
let res = warp::test::request()
.method("POST")
.body(serde_urlencoded::to_string(json!({
"client_id": "private-key-jwt-2",
"client_assertion": jwt,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"foo": "baz",
"bar": "foobar",
})).unwrap())
.filter(&filter)
.await;
assert!(res.is_err());
}
#[tokio::test]
async fn client_secret_post() {
let filter = client_authentication::<Form>(
&oauth2_config().await,
"https://example.com/token".to_string(),
);
let (auth, client, body) = warp::test::request()
.method("POST")
.header(
"Content-Type",
mime::APPLICATION_WWW_FORM_URLENCODED.to_string(),
)
.body(
serde_urlencoded::to_string(json!({
"client_id": "secret-post",
"client_secret": CLIENT_SECRET,
"foo": "baz",
"bar": "foobar",
}))
.unwrap(),
)
.filter(&filter)
.await
.unwrap();
assert_eq!(auth, OAuthClientAuthenticationMethod::ClientSecretPost);
assert_eq!(client.client_id, "secret-post");
assert_eq!(body.foo, "baz");
assert_eq!(body.bar, "foobar");
}
#[tokio::test]
async fn client_secret_basic() {
let filter = client_authentication::<Form>(
&oauth2_config().await,
"https://example.com/token".to_string(),
);
let auth = Authorization::basic("secret-basic", CLIENT_SECRET);
let (auth, client, body) = warp::test::request()
.method("POST")
.header(
"Content-Type",
mime::APPLICATION_WWW_FORM_URLENCODED.to_string(),
)
.header("Authorization", auth.0.encode())
.body(
serde_urlencoded::to_string(json!({
"foo": "baz",
"bar": "foobar",
}))
.unwrap(),
)
.filter(&filter)
.await
.unwrap();
assert_eq!(auth, OAuthClientAuthenticationMethod::ClientSecretBasic);
assert_eq!(client.client_id, "secret-basic");
assert_eq!(body.foo, "baz");
assert_eq!(body.bar, "foobar");
}
#[tokio::test]
async fn none() {
let filter = client_authentication::<Form>(
&oauth2_config().await,
"https://example.com/token".to_string(),
);
let (auth, client, body) = warp::test::request()
.method("POST")
.header(
"Content-Type",
mime::APPLICATION_WWW_FORM_URLENCODED.to_string(),
)
.body(
serde_urlencoded::to_string(json!({
"client_id": "public",
"foo": "baz",
"bar": "foobar",
}))
.unwrap(),
)
.filter(&filter)
.await
.unwrap();
assert_eq!(auth, OAuthClientAuthenticationMethod::None);
assert_eq!(client.client_id, "public");
assert_eq!(body.foo, "baz");
assert_eq!(body.bar, "foobar");
}
}
*/

View File

@ -1,193 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Deal with encrypted cookies
use std::{convert::Infallible, marker::PhantomData};
use cookie::{Cookie, SameSite};
use data_encoding::BASE64URL_NOPAD;
use headers::{Header, HeaderValue, SetCookie};
use mas_config::Encrypter;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use thiserror::Error;
use warp::{
reject::{InvalidHeader, MissingCookie, Reject},
Filter, Rejection, Reply,
};
use super::none_on_error;
use crate::{
errors::WrapError,
reply::{with_typed_header, WithTypedHeader},
};
/// Unable to decrypt the cookie
#[derive(Debug, Error)]
pub struct CookieDecryptionError<T: EncryptableCookieValue>(
#[source] anyhow::Error,
// This [`std::marker::PhantomData`] records what kind of cookie it was trying to save.
// This then use when displaying the error.
PhantomData<T>,
);
impl<T> Reject for CookieDecryptionError<T> where T: EncryptableCookieValue + 'static {}
impl<T: EncryptableCookieValue> From<anyhow::Error> for CookieDecryptionError<T> {
fn from(e: anyhow::Error) -> Self {
Self(e, PhantomData)
}
}
impl<T: EncryptableCookieValue> std::fmt::Display for CookieDecryptionError<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "failed to decrypt cookie {}", T::cookie_key())
}
}
fn decryption_error<T>(e: anyhow::Error) -> Rejection
where
T: EncryptableCookieValue + 'static,
{
let e: CookieDecryptionError<T> = e.into();
warp::reject::custom(e)
}
#[derive(Serialize, Deserialize)]
struct EncryptedCookie {
nonce: [u8; 12],
ciphertext: Vec<u8>,
}
impl EncryptedCookie {
/// Encrypt from a given key
fn encrypt<T: Serialize>(payload: T, encrypter: &Encrypter) -> anyhow::Result<Self> {
let message = bincode::serialize(&payload)?;
let nonce: [u8; 12] = rand::random();
let ciphertext = encrypter.encrypt(&nonce, &message)?;
Ok(Self { nonce, ciphertext })
}
/// Decrypt the content of the cookie from a given key
fn decrypt<T: DeserializeOwned>(&self, encrypter: &Encrypter) -> anyhow::Result<T> {
let message = encrypter.decrypt(&self.nonce, &self.ciphertext)?;
let token = bincode::deserialize(&message)?;
Ok(token)
}
/// Encode the encrypted cookie to be then saved as a cookie
fn to_cookie_value(&self) -> anyhow::Result<String> {
let raw = bincode::serialize(self)?;
Ok(BASE64URL_NOPAD.encode(&raw))
}
fn from_cookie_value(value: &str) -> anyhow::Result<Self> {
let raw = BASE64URL_NOPAD.decode(value.as_bytes())?;
let content = bincode::deserialize(&raw)?;
Ok(content)
}
}
/// Extract an optional encrypted cookie
#[must_use]
pub fn maybe_encrypted<T>(
encrypter: &Encrypter,
) -> impl Filter<Extract = (Option<T>,), Error = Rejection> + Clone + Send + Sync + 'static
where
T: DeserializeOwned + EncryptableCookieValue + 'static,
{
encrypted(encrypter)
.map(Some)
.recover(none_on_error::<T, InvalidHeader>)
.unify()
.recover(none_on_error::<T, MissingCookie>)
.unify()
.recover(none_on_error::<T, CookieDecryptionError<T>>)
.unify()
}
/// Extract an encrypted cookie
///
/// # Rejections
///
/// This can reject with either a [`warp::reject::MissingCookie`] or a
/// [`CookieDecryptionError`]
#[must_use]
pub fn encrypted<T>(
encrypter: &Encrypter,
) -> impl Filter<Extract = (T,), Error = Rejection> + Clone + Send + Sync + 'static
where
T: DeserializeOwned + EncryptableCookieValue + 'static,
{
let encrypter = encrypter.clone();
warp::cookie::cookie(T::cookie_key()).and_then(move |value: String| {
let encrypter = encrypter.clone();
async move {
let encrypted_payload =
EncryptedCookie::from_cookie_value(&value).map_err(decryption_error::<T>)?;
let decrypted_payload = encrypted_payload
.decrypt(&encrypter)
.map_err(decryption_error::<T>)?;
Ok::<_, Rejection>(decrypted_payload)
}
})
}
/// Get an [`EncryptedCookieSaver`] to help saving an [`EncryptableCookieValue`]
#[must_use]
pub fn encrypted_cookie_saver(
encrypter: &Encrypter,
) -> impl Filter<Extract = (EncryptedCookieSaver,), Error = Infallible> + Clone + Send + Sync + 'static
{
let encrypter = encrypter.clone();
warp::any().map(move || EncryptedCookieSaver {
encrypter: encrypter.clone(),
})
}
/// A cookie that can be encrypted with a well-known cookie key
pub trait EncryptableCookieValue: Serialize + Send + Sync + std::fmt::Debug {
/// What key should be used for this cookie
fn cookie_key() -> &'static str;
}
/// An opaque structure which helps encrypting a cookie and attach it to a reply
pub struct EncryptedCookieSaver {
encrypter: Encrypter,
}
impl EncryptedCookieSaver {
/// Save an [`EncryptableCookieValue`]
pub fn save_encrypted<T: EncryptableCookieValue, R: Reply>(
&self,
cookie: &T,
reply: R,
) -> Result<WithTypedHeader<R, SetCookie>, Rejection> {
let encrypted = EncryptedCookie::encrypt(cookie, &self.encrypter)
.wrap_error()?
.to_cookie_value()
.wrap_error()?;
// TODO: make those options customizable
let value = Cookie::build(T::cookie_key(), encrypted)
.http_only(true)
.same_site(SameSite::Lax)
.finish()
.to_string();
let header = SetCookie::decode(&mut [HeaderValue::from_str(&value).wrap_error()?].iter())
.wrap_error()?;
Ok(with_typed_header(header, reply))
}
}

View File

@ -1,42 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Wrapper around [`warp::filters::cors`]
use std::string::ToString;
use once_cell::sync::OnceCell;
static PROPAGATOR_HEADERS: OnceCell<Vec<String>> = OnceCell::new();
/// Notify the CORS filter what opentelemetry propagators are being used. This
/// helps whitelisting headers in CORS requests.
pub fn set_propagator(propagator: &dyn opentelemetry::propagation::TextMapPropagator) {
let headers = propagator.fields().map(ToString::to_string).collect();
tracing::debug!(
?headers,
"Headers allowed in CORS requests for trace propagators set"
);
PROPAGATOR_HEADERS
.set(headers)
.expect(concat!(module_path!(), "::set_propagator was called twice"));
}
/// Create a wrapping filter that exposes CORS behavior for a wrapped filter.
#[must_use]
pub fn cors() -> warp::filters::cors::Builder {
warp::filters::cors::cors()
.allow_any_origin()
.allow_headers(PROPAGATOR_HEADERS.get().unwrap_or(&Vec::new()))
}

View File

@ -1,185 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Stateless CSRF protection middleware based on a chacha20-poly1305 encrypted
//! and signed token
use chrono::{DateTime, Duration, Utc};
use data_encoding::{DecodeError, BASE64URL_NOPAD};
use mas_config::{CsrfConfig, Encrypter};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_with::{serde_as, TimestampSeconds};
use thiserror::Error;
use warp::{reject::Reject, Filter, Rejection};
use super::cookies::EncryptableCookieValue;
/// Failed to validate CSRF token
#[derive(Debug, Error)]
pub enum CsrfError {
/// The token in the form did not match the token in the cookie
#[error("CSRF token mismatch")]
Mismatch,
/// The token expired
#[error("CSRF token expired")]
Expired,
/// Failed to decode the token
#[error("could not decode CSRF token")]
Decode(#[from] DecodeError),
}
impl Reject for CsrfError {}
/// A CSRF token
#[serde_as]
#[derive(Serialize, Deserialize, Debug)]
pub struct CsrfToken {
#[serde_as(as = "TimestampSeconds<i64>")]
expiration: DateTime<Utc>,
token: [u8; 32],
}
impl CsrfToken {
/// Create a new token from a defined value valid for a specified duration
fn new(token: [u8; 32], ttl: Duration) -> Self {
let expiration = Utc::now() + ttl;
Self { expiration, token }
}
/// Generate a new random token valid for a specified duration
fn generate(ttl: Duration) -> Self {
let token = rand::random();
Self::new(token, ttl)
}
/// Generate a new token with the same value but an up to date expiration
fn refresh(self, ttl: Duration) -> Self {
Self::new(self.token, ttl)
}
/// Get the value to include in HTML forms
#[must_use]
pub fn form_value(&self) -> String {
BASE64URL_NOPAD.encode(&self.token[..])
}
/// Verifies that the value got from an HTML form matches this token
pub fn verify_form_value(&self, form_value: &str) -> Result<(), CsrfError> {
let form_value = BASE64URL_NOPAD.decode(form_value.as_bytes())?;
if self.token[..] == form_value {
Ok(())
} else {
Err(CsrfError::Mismatch)
}
}
fn verify_expiration(self) -> Result<Self, CsrfError> {
if Utc::now() < self.expiration {
Ok(self)
} else {
Err(CsrfError::Expired)
}
}
}
impl EncryptableCookieValue for CsrfToken {
fn cookie_key() -> &'static str {
"csrf"
}
}
/// A CSRF-protected form
#[derive(Deserialize)]
struct CsrfForm<T> {
csrf: String,
#[serde(flatten)]
inner: T,
}
impl<T> CsrfForm<T> {
fn verify_csrf(self, token: &CsrfToken) -> Result<T, CsrfError> {
// Verify CSRF from request
token.verify_form_value(&self.csrf)?;
Ok(self.inner)
}
}
fn csrf_token(
encrypter: &Encrypter,
) -> impl Filter<Extract = (CsrfToken,), Error = Rejection> + Clone + Send + Sync + 'static {
super::cookies::encrypted(encrypter).and_then(move |token: CsrfToken| async move {
let verified = token.verify_expiration()?;
Ok::<_, Rejection>(verified)
})
}
/// Extract an up-to-date CSRF token to include in forms
///
/// Routes using this should not forget to reply the updated CSRF cookie using
/// an [`EncryptedCookieSaver`][`super::cookies::EncryptedCookieSaver`] obtained
/// with [`encrypted_cookie_saver`][`super::cookies::encrypted_cookie_saver`]
#[must_use]
pub fn updated_csrf_token(
encrypter: &Encrypter,
csrf_config: &CsrfConfig,
) -> impl Filter<Extract = (CsrfToken,), Error = Rejection> + Clone + Send + Sync + 'static {
let ttl = csrf_config.ttl;
super::cookies::maybe_encrypted(encrypter).and_then(
move |maybe_token: Option<CsrfToken>| async move {
// Explicitely specify the "Error" type here to have the `?` operation working
Ok::<_, Rejection>(
maybe_token
// Verify its TTL (but do not hard-error if it expired)
.and_then(|token| token.verify_expiration().ok())
.map_or_else(
// Generate a new token if no valid one were found
|| CsrfToken::generate(ttl),
// Else, refresh the expiration of the token
|token| token.refresh(ttl),
),
)
},
)
}
/// Extract values from a CSRF-protected form
///
/// # Rejections
///
/// This can reject with:
///
/// - [`warp::filters::body::BodyDeserializeError`] if the overall form failed
/// to decode
/// - [`CsrfError`] if the CSRF token was invalid or expired
/// - [`warp::reject::MissingCookie`] if the CSRF cookie was missing
/// - [`super::cookies::CookieDecryptionError`] if the cookie failed to decrypt
///
/// TODO: we might want to unify the last three rejections in one
#[must_use]
pub fn protected_form<T>(
encrypter: &Encrypter,
) -> impl Filter<Extract = (T,), Error = Rejection> + Clone + Send + Sync + 'static
where
T: DeserializeOwned + Send + 'static,
{
csrf_token(encrypter).and(warp::body::form()).and_then(
|csrf_token: CsrfToken, protected_form: CsrfForm<T>| async move {
let form = protected_form.verify_csrf(&csrf_token)?;
Ok::<_, Rejection>(form)
},
)
}

View File

@ -1,61 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Database-related filters to grab connections and start transactions from the
//! connection pool
use std::convert::Infallible;
use sqlx::{
pool::{Pool, PoolConnection},
Database, Transaction,
};
use warp::{Filter, Rejection};
use crate::errors::WrapError;
fn with_pool<T: Database>(
pool: &Pool<T>,
) -> impl Filter<Extract = (Pool<T>,), Error = Infallible> + Clone + Send + Sync + 'static {
let pool = pool.clone();
warp::any().map(move || pool.clone())
}
/// Acquire a connection to the database
pub fn connection<T: Database>(
pool: &Pool<T>,
) -> impl Filter<Extract = (PoolConnection<T>,), Error = Rejection> + Clone + Send + Sync + 'static
{
with_pool(pool).and_then(acquire_connection)
}
async fn acquire_connection<T: Database>(pool: Pool<T>) -> Result<PoolConnection<T>, Rejection> {
let conn = pool.acquire().await.wrap_error()?;
Ok(conn)
}
/// Start a database transaction
pub fn transaction<T: Database>(
pool: &Pool<T>,
) -> impl Filter<Extract = (Transaction<'static, T>,), Error = Rejection> + Clone + Send + Sync + 'static
{
with_pool(pool).and_then(acquire_transaction)
}
async fn acquire_transaction<T: Database>(
pool: Pool<T>,
) -> Result<Transaction<'static, T>, Rejection> {
let txn = pool.begin().await.wrap_error()?;
Ok(txn)
}

View File

@ -1,43 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Deal with typed headers from the [`headers`] crate
use headers::{Header, HeaderValue};
use thiserror::Error;
use warp::{reject::Reject, Filter, Rejection};
/// Failed to decode typed header
#[derive(Debug, Error)]
#[error("could not decode header {1}")]
pub struct InvalidTypedHeader(#[source] headers::Error, &'static str);
impl Reject for InvalidTypedHeader {}
/// Extract a typed header from the request
///
/// # Rejections
///
/// This can reject with either a [`warp::reject::MissingHeader`] or a
/// [`InvalidTypedHeader`].
pub fn typed_header<T: Header + Send + 'static>(
) -> impl Filter<Extract = (T,), Error = Rejection> + Clone + Send + Sync + 'static {
warp::header::value(T::name().as_str()).and_then(decode_typed_header)
}
async fn decode_typed_header<T: Header>(header: HeaderValue) -> Result<T, Rejection> {
let mut it = std::iter::once(&header);
let decoded = T::decode(&mut it).map_err(|e| InvalidTypedHeader(e, T::name().as_str()))?;
Ok(decoded)
}

View File

@ -1,72 +0,0 @@
// Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Set of [`warp`] filters
#![allow(clippy::unused_async)] // Some warp filters need that
#![deny(missing_docs)]
pub mod authenticate;
pub mod client;
pub mod cookies;
pub mod cors;
pub mod csrf;
pub mod database;
pub mod headers;
pub mod session;
pub mod trace;
pub mod url_builder;
use std::convert::Infallible;
use mas_templates::Templates;
use warp::{Filter, Rejection};
pub use self::csrf::CsrfToken;
/// Get the [`Templates`]
#[must_use]
pub fn with_templates(
templates: &Templates,
) -> impl Filter<Extract = (Templates,), Error = Infallible> + Clone + Send + Sync + 'static {
let templates = templates.clone();
warp::any().map(move || templates.clone())
}
/// Recover a particular rejection type with a `None` option variant
///
/// # Example
///
/// ```rust
/// extern crate warp;
///
/// use warp::{filters::header::header, reject::MissingHeader, Filter};
///
/// use mas_warp_utils::filters::none_on_error;
///
/// header("Content-Length")
/// .map(Some)
/// .recover(none_on_error::<_, MissingHeader>)
/// .unify()
/// .map(|length: Option<u64>| {
/// format!("header: {:?}", length)
/// });
/// ```
pub async fn none_on_error<T, E: 'static>(rejection: Rejection) -> Result<Option<T>, Rejection> {
if rejection.find::<E>().is_some() {
Ok(None)
} else {
Err(rejection)
}
}

View File

@ -1,162 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Load user sessions from the database
use mas_config::Encrypter;
use mas_data_model::BrowserSession;
use mas_storage::{
user::{lookup_active_session, ActiveSessionLookupError},
PostgresqlBackend,
};
use serde::{Deserialize, Serialize};
use sqlx::{pool::PoolConnection, Executor, PgPool, Postgres};
use thiserror::Error;
use tracing::warn;
use warp::{
reject::{InvalidHeader, MissingCookie, Reject},
Filter, Rejection,
};
use super::{
cookies::{encrypted, CookieDecryptionError, EncryptableCookieValue},
database::connection,
none_on_error,
};
/// The session is missing or failed to load
#[derive(Error, Debug)]
pub enum SessionLoadError {
/// No session cookie was found
#[error("missing session cookie")]
MissingCookie,
/// The session cookie is invalid
#[error("unable to parse or decrypt session cookie")]
InvalidCookie,
/// The session is unknown or inactive
#[error("unknown or inactive session")]
UnknownSession,
}
impl Reject for SessionLoadError {}
/// An encrypted cookie to save the session ID
#[derive(Serialize, Deserialize, Debug)]
pub struct SessionCookie {
current: i64,
}
impl SessionCookie {
/// Forge the cookie from a [`BrowserSession`]
#[must_use]
pub fn from_session(session: &BrowserSession<PostgresqlBackend>) -> Self {
Self {
current: session.data,
}
}
/// Load the [`BrowserSession`] from database
pub async fn load_session(
&self,
executor: impl Executor<'_, Database = Postgres>,
) -> Result<BrowserSession<PostgresqlBackend>, ActiveSessionLookupError> {
let res = lookup_active_session(executor, self.current).await?;
Ok(res)
}
}
impl EncryptableCookieValue for SessionCookie {
fn cookie_key() -> &'static str {
"session"
}
}
/// Extract a user session information if logged in
#[must_use]
pub fn optional_session(
pool: &PgPool,
encrypter: &Encrypter,
) -> impl Filter<Extract = (Option<BrowserSession<PostgresqlBackend>>,), Error = Rejection>
+ Clone
+ Send
+ Sync
+ 'static {
session(pool, encrypter)
.map(Some)
.recover(none_on_error::<_, SessionLoadError>)
.unify()
}
/// Extract a user session information, rejecting if not logged in
///
/// # Rejections
///
/// This filter will reject with a [`SessionLoadError`] when the session is
/// inactive or missing. It will reject with a wrapped error on other database
/// failures.
#[must_use]
pub fn session(
pool: &PgPool,
encrypter: &Encrypter,
) -> impl Filter<Extract = (BrowserSession<PostgresqlBackend>,), Error = Rejection>
+ Clone
+ Send
+ Sync
+ 'static {
encrypted(encrypter)
.and(connection(pool))
.and_then(load_session)
.recover(recover)
.unify()
}
async fn load_session(
session: SessionCookie,
mut conn: PoolConnection<Postgres>,
) -> Result<BrowserSession<PostgresqlBackend>, Rejection> {
let session_info = session.load_session(&mut conn).await?;
Ok(session_info)
}
/// Recover from expected rejections, to transform them into a
/// [`SessionLoadError`]
async fn recover<T>(rejection: Rejection) -> Result<T, Rejection> {
if let Some(e) = rejection.find::<ActiveSessionLookupError>() {
if e.not_found() {
return Err(warp::reject::custom(SessionLoadError::UnknownSession));
}
// If we're here, there is a real database error that should be
// propagated
}
if let Some(e) = rejection.find::<InvalidHeader>() {
if e.name() == "cookie" {
return Err(warp::reject::custom(SessionLoadError::MissingCookie));
}
}
if let Some(_e) = rejection.find::<MissingCookie>() {
return Err(warp::reject::custom(SessionLoadError::MissingCookie));
}
if let Some(error) = rejection.find::<CookieDecryptionError<SessionCookie>>() {
warn!(?error, "could not decrypt session cookie");
return Err(warp::reject::custom(SessionLoadError::InvalidCookie));
}
Err(rejection)
}

View File

@ -1,35 +0,0 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Route tracing utility
use std::convert::Infallible;
use warp::Filter;
/// Set the name of that route
#[must_use]
pub fn name(
name: &'static str,
) -> impl Filter<Extract = (), Error = Infallible> + Clone + Send + Sync + 'static {
warp::any()
.map(move || {
// TODO: update_name has a weird signature, which is already fixed in
// opentelemetry-rust, just not released yet
// TODO: we should find another way to classify requests. Span::update_name has
// impacts on sampling and should not be used
opentelemetry::trace::get_active_span(|s| s.update_name::<String>(name.to_string()));
})
.untuple_one()
}

View File

@ -1,121 +0,0 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Utility to build URLs
// TODO: move this somewhere else
use std::convert::Infallible;
use mas_config::HttpConfig;
use url::Url;
use warp::Filter;
impl From<&HttpConfig> for UrlBuilder {
fn from(config: &HttpConfig) -> Self {
Self::new(config.public_base.clone())
}
}
/// Helps building absolute URLs
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct UrlBuilder {
base: Url,
}
impl UrlBuilder {
/// Create a new [`UrlBuilder`] from a base URL
#[must_use]
pub fn new(base: Url) -> Self {
Self { base }
}
/// OIDC issuer
#[must_use]
pub fn oidc_issuer(&self) -> Url {
self.base.clone()
}
/// OIDC dicovery document URL
#[must_use]
pub fn oidc_discovery(&self) -> Url {
self.base
.join(".well-known/openid-configuration")
.expect("build URL")
}
/// OAuth 2.0 authorization endpoint
#[must_use]
pub fn oauth_authorization_endpoint(&self) -> Url {
self.base.join("oauth2/authorize").expect("build URL")
}
/// OAuth 2.0 token endpoint
#[must_use]
pub fn oauth_token_endpoint(&self) -> Url {
self.base.join("oauth2/token").expect("build URL")
}
/// OAuth 2.0 introspection endpoint
#[must_use]
pub fn oauth_introspection_endpoint(&self) -> Url {
self.base.join("oauth2/introspect").expect("build URL")
}
/// OAuth 2.0 introspection endpoint
#[must_use]
pub fn oidc_userinfo_endpoint(&self) -> Url {
self.base.join("oauth2/userinfo").expect("build URL")
}
/// JWKS URI
#[must_use]
pub fn jwks_uri(&self) -> Url {
self.base.join("oauth2/keys.json").expect("build URL")
}
/// Email verification URL
#[must_use]
pub fn email_verification(&self, code: &str) -> Url {
self.base
.join("verify/")
.expect("build URL")
.join(code)
.expect("build URL")
}
}
/// Injects an [`UrlBuilder`] to help building absolute URLs
#[must_use]
pub fn url_builder(
config: &HttpConfig,
) -> impl Filter<Extract = (UrlBuilder,), Error = Infallible> + Clone + Send + Sync + 'static {
let builder: UrlBuilder = config.into();
warp::any().map(move || builder.clone())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn build_email_verification_url() {
let base = Url::parse("https://example.com/").unwrap();
let builder = UrlBuilder::new(base);
assert_eq!(
builder.email_verification("123456abcdef").as_str(),
"https://example.com/verify/123456abcdef"
);
}
}

View File

@ -1,24 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Various warp filters and replies
#![forbid(unsafe_code)]
#![deny(clippy::all, missing_docs, rustdoc::broken_intra_doc_links)]
#![warn(clippy::pedantic)]
#![allow(clippy::module_name_repetitions, clippy::missing_errors_doc)]
pub mod errors;
pub mod filters;
pub mod reply;

View File

@ -1,54 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Reply with a typed header from the [`headers`] crate.
//!
//! ```rust
//! extern crate headers;
//! extern crate warp;
//!
//! use warp::Reply;
//! use mas_warp_utils::reply::with_typed_header;
//!
//! let reply = r#"{"hello": "world"}"#;
//! let reply = with_typed_header(headers::ContentType::json(), reply);;
//! let response = reply.into_response();
//! assert_eq!(response.headers().get("Content-Type").unwrap().to_str().unwrap(), "application/json");
//! ```
use headers::{Header, HeaderMapExt};
use warp::Reply;
/// Add a typed header to a reply
pub fn with_typed_header<R, H>(header: H, reply: R) -> WithTypedHeader<R, H> {
WithTypedHeader { reply, header }
}
/// A reply with a typed header set
pub struct WithTypedHeader<R, H> {
reply: R,
header: H,
}
impl<R, H> Reply for WithTypedHeader<R, H>
where
R: Reply,
H: Header + Send,
{
fn into_response(self) -> warp::reply::Response {
let mut res = self.reply.into_response();
res.headers_mut().typed_insert(self.header);
res
}
}

View File

@ -1,21 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Set of wrappers for [`warp::Reply`]
#![deny(missing_docs)]
pub mod headers;
pub use self::headers::{with_typed_header, WithTypedHeader};

View File

@ -20,5 +20,5 @@
- [Architecture](./development/architecture.md) - [Architecture](./development/architecture.md)
- [Database](./development/database.md) - [Database](./development/database.md)
- [Routing with `warp`](./development/warp.md) - [Routing with `axum`]()
- [Templates]() - [Templates]()

View File

@ -20,7 +20,6 @@ This includes:
- `mas-static-files`: Frontend static files (CSS/JS). Includes some frontend tooling - `mas-static-files`: Frontend static files (CSS/JS). Includes some frontend tooling
- `mas-storage`: Interactions with the database - `mas-storage`: Interactions with the database
- `mas-tasks`: Asynchronous task runner and scheduler - `mas-tasks`: Asynchronous task runner and scheduler
- `mas-warp-utils`: Various filters and utilities for the `warp` web framework
- `oauth2-types`: Useful structures and types to deal with OAuth 2.0/OpenID Connect endpoints. This might end up published as a standalone library as it can be useful in other contexts. - `oauth2-types`: Useful structures and types to deal with OAuth 2.0/OpenID Connect endpoints. This might end up published as a standalone library as it can be useful in other contexts.
## Important crates ## Important crates
@ -74,11 +73,6 @@ Both crates work well together and complement each other.
Interactions with the database are done through [`sqlx`](https://github.com/launchbadge/sqlx), an async, pure-Rust SQL library with compile-time check of queries. Interactions with the database are done through [`sqlx`](https://github.com/launchbadge/sqlx), an async, pure-Rust SQL library with compile-time check of queries.
It also handles schema migrations. It also handles schema migrations.
### Web framework: `warp`
[`warp`](https://docs.rs/warp/*/warp/) is an easy, macro-free web framework.
Its composability makes a lot of sense when implementing OAuth 2.0 endpoints, because of the need to deal with a lot of different scenarios.
### Templates: `tera` ### Templates: `tera`
[Tera](https://tera.netlify.app/) was chosen as template engine for its simplicity as well as its ability to load templates at runtime. [Tera](https://tera.netlify.app/) was chosen as template engine for its simplicity as well as its ability to load templates at runtime.

View File

@ -1,93 +0,0 @@
# `warp`
**Warning: this document is not up to date**
Warp has a pretty unique approach in terms of routing.
It does not have a central router, rather a chain of filters composed together.
It encourages writing reusable filters to handle stuff like authentication, extracting user sessions, starting database transactions, etc.
Everything related to `warp` currently lives in the `mas-core` crate:
- `crates/core/src/`
- `handlers/`: The actual handlers for each route
- `oauth2/`: Everything related to OAuth 2.0/OIDC endpoints
- `views/`: HTML views (login, registration, account management, etc.)
- `filters/`: Reusable, composable filters
- `reply/`: Composable replies
## Defining a new endpoint
We usually keep one endpoint per file and use module roots to combine the filters of endpoints.
This is how it looks like in the current hierarchy at time of writing:
- `mod.rs`: combines the filters from `oauth2`, `views` and `health`
- `oauth2/`
- `mod.rs`: combines filters from `authorization`, `discovery`, etc.
- `authorization.rs`: handles `GET /oauth2/authorize` and `GET /oauth2/authorize/step`
- `discovery.rs`: handles `GET /.well-known/openid-configuration`
- ...
- `views/`
- `mod.rs`: combines the filters from `index`, `login`, `logout`, etc.
- `index.rs`: handles `GET /`
- `login.rs`: handles `GET /login` and `POST /login`
- `logout.rs`: handles `POST /logout`
- ...
- `health.rs`: handles `GET /health`
All filters are functions that take their dependencies (the database connection pool, the template engine, etc.) as parameters and return an `impl warp::Filter<Extract = (impl warp::Reply,)>`.
```rust
// crates/core/src/handlers/hello.rs
// Don't be scared by the type at the end, just copy-paste it
pub(super) fn filter(
pool: &PgPool,
templates: &Templates,
cookies_config: &CookiesConfig,
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
// Handles `GET /hello/:param`
warp::path!("hello" / String)
.and(warp::get())
// Pass the template engine
.and(with_templates(templates))
// Extract the current user session
.and(optional_session(pool, cookies_config))
.and_then(get)
}
async fn get(
// Parameter from the route
parameter: String,
// Template engine
templates: Templates,
// The current user session
session: Option<SessionInfo>,
) -> Result<impl Reply, Rejection> {
let ctx = SomeTemplateContext::new(parameter)
.maybe_with_session(session);
let content = templates.render_something(&ctx)?;
let reply = html(content);
Ok(reply)
}
```
And then, it can be attached to the root handler:
```rust
// crates/core/src/handlers/mod.rs
use self::{health::filter as health, oauth2::filter as oauth2, hello::filter as hello};
pub fn root(
pool: &PgPool,
templates: &Templates,
config: &RootConfig,
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
health(pool)
.or(oauth2(pool, templates, &config.oauth2, &config.cookies))
// Attach it here, passing the right dependencies
.or(hello(pool, templates, &config.cookies))
}
```