You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
Get rid of warp
This commit is contained in:
168
Cargo.lock
generated
168
Cargo.lock
generated
@ -672,16 +672,6 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "buf_redux"
|
||||
version = "0.8.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b953a6887648bb07a535631f2bc00fbdb2a2216f135552cb3f534ed136b9c07f"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
"safemem",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bumpalo"
|
||||
version = "3.9.1"
|
||||
@ -1992,7 +1982,6 @@ dependencies = [
|
||||
"mas-storage",
|
||||
"mas-tasks",
|
||||
"mas-templates",
|
||||
"mas-warp-utils",
|
||||
"opentelemetry",
|
||||
"opentelemetry-jaeger",
|
||||
"opentelemetry-otlp",
|
||||
@ -2010,7 +1999,6 @@ dependencies = [
|
||||
"tracing-opentelemetry",
|
||||
"tracing-subscriber",
|
||||
"url",
|
||||
"warp",
|
||||
"watchman_client",
|
||||
]
|
||||
|
||||
@ -2102,7 +2090,6 @@ dependencies = [
|
||||
"mas-static-files",
|
||||
"mas-storage",
|
||||
"mas-templates",
|
||||
"mas-warp-utils",
|
||||
"mime",
|
||||
"oauth2-types",
|
||||
"pkcs8",
|
||||
@ -2119,7 +2106,6 @@ dependencies = [
|
||||
"tower",
|
||||
"tracing",
|
||||
"url",
|
||||
"warp",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -2240,7 +2226,6 @@ dependencies = [
|
||||
"tokio",
|
||||
"tracing",
|
||||
"url",
|
||||
"warp",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -2272,46 +2257,6 @@ dependencies = [
|
||||
"tokio",
|
||||
"tracing",
|
||||
"url",
|
||||
"warp",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mas-warp-utils"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bincode",
|
||||
"chrono",
|
||||
"cookie",
|
||||
"crc",
|
||||
"data-encoding",
|
||||
"headers",
|
||||
"http",
|
||||
"http-body",
|
||||
"hyper",
|
||||
"mas-config",
|
||||
"mas-data-model",
|
||||
"mas-http",
|
||||
"mas-iana",
|
||||
"mas-jose",
|
||||
"mas-storage",
|
||||
"mas-templates",
|
||||
"mime",
|
||||
"oauth2-types",
|
||||
"once_cell",
|
||||
"opentelemetry",
|
||||
"rand",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_urlencoded",
|
||||
"serde_with",
|
||||
"sqlx",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tower",
|
||||
"tracing",
|
||||
"url",
|
||||
"warp",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -2418,24 +2363,6 @@ version = "0.8.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a"
|
||||
|
||||
[[package]]
|
||||
name = "multipart"
|
||||
version = "0.18.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "00dec633863867f29cb39df64a397cdf4a6354708ddd7759f70c7fb51c5f9182"
|
||||
dependencies = [
|
||||
"buf_redux",
|
||||
"httparse",
|
||||
"log",
|
||||
"mime",
|
||||
"mime_guess",
|
||||
"quick-error",
|
||||
"rand",
|
||||
"safemem",
|
||||
"tempfile",
|
||||
"twoway",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nom"
|
||||
version = "7.1.0"
|
||||
@ -3120,12 +3047,6 @@ dependencies = [
|
||||
"prost",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quick-error"
|
||||
version = "1.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0"
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.15"
|
||||
@ -3436,12 +3357,6 @@ version = "1.0.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "73b4b750c782965c211b42f022f59af1fbceabdd026623714f104152f1ec149f"
|
||||
|
||||
[[package]]
|
||||
name = "safemem"
|
||||
version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ef703b7cb59335eae2eb93ceb664c0eb7ea6bf567079d843e09420219668e072"
|
||||
|
||||
[[package]]
|
||||
name = "same-file"
|
||||
version = "1.0.6"
|
||||
@ -3487,12 +3402,6 @@ dependencies = [
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "scoped-tls"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ea6a9290e3c9cf0f18145ef7ffa62d68ee0bf5fcd651017e586dc7fd5da448c2"
|
||||
|
||||
[[package]]
|
||||
name = "scopeguard"
|
||||
version = "1.1.0"
|
||||
@ -4206,19 +4115,6 @@ dependencies = [
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-tungstenite"
|
||||
version = "0.15.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "511de3f85caf1c98983545490c3d09685fa8eb634e57eec22bb4db271f46cbd8"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"log",
|
||||
"pin-project",
|
||||
"tokio",
|
||||
"tungstenite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-util"
|
||||
version = "0.6.9"
|
||||
@ -4451,34 +4347,6 @@ version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642"
|
||||
|
||||
[[package]]
|
||||
name = "tungstenite"
|
||||
version = "0.14.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a0b2d8558abd2e276b0a8df5c05a2ec762609344191e5fd23e292c910e9165b5"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"byteorder",
|
||||
"bytes 1.1.0",
|
||||
"http",
|
||||
"httparse",
|
||||
"log",
|
||||
"rand",
|
||||
"sha-1 0.9.8",
|
||||
"thiserror",
|
||||
"url",
|
||||
"utf-8",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "twoway"
|
||||
version = "0.1.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "59b11b2b5241ba34be09c3cc85a36e56e48f9888862e19cedf23336d35316ed1"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typed-builder"
|
||||
version = "0.9.1"
|
||||
@ -4644,12 +4512,6 @@ version = "1.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5a1f0175e03a0973cf4afd476bef05c26e228520400eb1fd473ad417b1c00ffb"
|
||||
|
||||
[[package]]
|
||||
name = "utf-8"
|
||||
version = "0.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
|
||||
|
||||
[[package]]
|
||||
name = "valuable"
|
||||
version = "0.1.0"
|
||||
@ -4683,36 +4545,6 @@ dependencies = [
|
||||
"try-lock",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "warp"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3cef4e1e9114a4b7f1ac799f16ce71c14de5778500c5450ec6b7b920c55b587e"
|
||||
dependencies = [
|
||||
"bytes 1.1.0",
|
||||
"futures-channel",
|
||||
"futures-util",
|
||||
"headers",
|
||||
"http",
|
||||
"hyper",
|
||||
"log",
|
||||
"mime",
|
||||
"mime_guess",
|
||||
"multipart",
|
||||
"percent-encoding",
|
||||
"pin-project",
|
||||
"scoped-tls",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_urlencoded",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tokio-tungstenite",
|
||||
"tokio-util 0.6.9",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasi"
|
||||
version = "0.10.0+wasi-snapshot-preview1"
|
||||
|
@ -16,7 +16,6 @@ tower = { version = "0.4.12", features = ["full"] }
|
||||
hyper = { version = "0.14.17", features = ["full"] }
|
||||
serde_yaml = "0.8.23"
|
||||
serde_json = "1.0.79"
|
||||
warp = "0.3.2"
|
||||
url = "2.2.2"
|
||||
argon2 = { version = "0.3.4", features = ["password-hash"] }
|
||||
reqwest = { version = "0.11.10", features = ["rustls-tls"], default-features = false, optional = true }
|
||||
@ -42,7 +41,6 @@ mas-http = { path = "../http" }
|
||||
mas-storage = { path = "../storage" }
|
||||
mas-tasks = { path = "../tasks" }
|
||||
mas-templates = { path = "../templates" }
|
||||
mas-warp-utils = { path = "../warp-utils" }
|
||||
mas-axum-utils = { path = "../axum-utils" }
|
||||
|
||||
[dev-dependencies]
|
||||
|
@ -40,7 +40,7 @@ pub fn setup(config: &TelemetryConfig) -> anyhow::Result<Option<Tracer>> {
|
||||
|
||||
// The CORS filter needs to know what headers it should whitelist for
|
||||
// CORS-protected requests.
|
||||
mas_warp_utils::filters::cors::set_propagator(&propagator);
|
||||
// TODO mas_warp_utils::filters::cors::set_propagator(&propagator);
|
||||
global::set_text_map_propagator(propagator);
|
||||
|
||||
let tracer = tracer(&config.tracing.exporter)?;
|
||||
|
@ -20,7 +20,6 @@ thiserror = "1.0.30"
|
||||
anyhow = "1.0.56"
|
||||
|
||||
# Web server
|
||||
warp = "0.3.2"
|
||||
hyper = { version = "0.14.17", features = ["full"] }
|
||||
tower = "0.4.12"
|
||||
axum = "0.4.8"
|
||||
@ -67,7 +66,6 @@ mas-jose = { path = "../jose" }
|
||||
mas-static-files = { path = "../static-files" }
|
||||
mas-storage = { path = "../storage" }
|
||||
mas-templates = { path = "../templates" }
|
||||
mas-warp-utils = { path = "../warp-utils" }
|
||||
|
||||
[dev-dependencies]
|
||||
indoc = "1.0.4"
|
||||
|
@ -16,7 +16,7 @@
|
||||
#![deny(clippy::all, rustdoc::broken_intra_doc_links)]
|
||||
#![warn(clippy::pedantic)]
|
||||
#![allow(
|
||||
clippy::unused_async // Some warp filters need that
|
||||
clippy::unused_async // Some axum handlers need that
|
||||
)]
|
||||
|
||||
use std::sync::Arc;
|
||||
|
@ -14,7 +14,6 @@ serde_json = "1.0.79"
|
||||
thiserror = "1.0.30"
|
||||
anyhow = "1.0.56"
|
||||
tracing = "0.1.32"
|
||||
warp = "0.3.2"
|
||||
|
||||
# Password hashing
|
||||
argon2 = { version = "0.3.4", features = ["password-hash"] }
|
||||
|
@ -21,7 +21,6 @@ use oauth2_types::requests::GrantType;
|
||||
use sqlx::{PgConnection, PgExecutor};
|
||||
use thiserror::Error;
|
||||
use url::Url;
|
||||
use warp::reject::Reject;
|
||||
|
||||
use crate::PostgresqlBackend;
|
||||
|
||||
@ -79,8 +78,6 @@ impl ClientFetchError {
|
||||
}
|
||||
}
|
||||
|
||||
impl Reject for ClientFetchError {}
|
||||
|
||||
impl TryInto<Client<PostgresqlBackend>> for OAuth2ClientLookup {
|
||||
type Error = ClientFetchError;
|
||||
|
||||
|
@ -19,7 +19,6 @@ use mas_data_model::{
|
||||
};
|
||||
use sqlx::{PgConnection, PgExecutor};
|
||||
use thiserror::Error;
|
||||
use warp::reject::Reject;
|
||||
|
||||
use super::client::{lookup_client_by_client_id, ClientFetchError};
|
||||
use crate::{DatabaseInconsistencyError, IdAndCreationTime, PostgresqlBackend};
|
||||
@ -87,8 +86,6 @@ pub enum RefreshTokenLookupError {
|
||||
Conversion(#[from] DatabaseInconsistencyError),
|
||||
}
|
||||
|
||||
impl Reject for RefreshTokenLookupError {}
|
||||
|
||||
impl RefreshTokenLookupError {
|
||||
#[must_use]
|
||||
pub fn not_found(&self) -> bool {
|
||||
|
@ -27,7 +27,6 @@ use sqlx::{postgres::types::PgInterval, Acquire, PgExecutor, Postgres, Transacti
|
||||
use thiserror::Error;
|
||||
use tokio::task;
|
||||
use tracing::{info_span, Instrument};
|
||||
use warp::reject::Reject;
|
||||
|
||||
use super::{DatabaseInconsistencyError, PostgresqlBackend};
|
||||
use crate::IdAndCreationTime;
|
||||
@ -117,8 +116,6 @@ pub enum ActiveSessionLookupError {
|
||||
Conversion(#[from] DatabaseInconsistencyError),
|
||||
}
|
||||
|
||||
impl Reject for ActiveSessionLookupError {}
|
||||
|
||||
impl ActiveSessionLookupError {
|
||||
#[must_use]
|
||||
pub fn not_found(&self) -> bool {
|
||||
|
@ -21,7 +21,6 @@ serde_json = "1.0.79"
|
||||
serde_urlencoded = "0.7.1"
|
||||
|
||||
url = "2.2.2"
|
||||
warp = "0.3.2"
|
||||
|
||||
oauth2-types = { path = "../oauth2-types" }
|
||||
mas-data-model = { path = "../data-model" }
|
||||
|
@ -271,8 +271,6 @@ pub enum TemplateError {
|
||||
},
|
||||
}
|
||||
|
||||
impl warp::reject::Reject for TemplateError {}
|
||||
|
||||
register_templates! {
|
||||
extra = {
|
||||
"components/button.html",
|
||||
|
@ -1,42 +0,0 @@
|
||||
[package]
|
||||
name = "mas-warp-utils"
|
||||
version = "0.1.0"
|
||||
authors = ["Quentin Gliech <quenting@element.io>"]
|
||||
edition = "2021"
|
||||
license = "Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
tokio = { version = "1.17.0", features = ["macros"] }
|
||||
headers = "0.3.7"
|
||||
cookie = "0.16.0"
|
||||
warp = "0.3.2"
|
||||
hyper = { version = "0.14.17", features = ["full"] }
|
||||
thiserror = "1.0.30"
|
||||
anyhow = "1.0.56"
|
||||
sqlx = { version = "0.5.11", features = ["runtime-tokio-rustls", "postgres"] }
|
||||
chrono = { version = "0.4.19", features = ["serde"] }
|
||||
serde = { version = "1.0.136", features = ["derive"] }
|
||||
serde_with = { version = "1.12.0", features = ["hex", "chrono"] }
|
||||
serde_json = "1.0.79"
|
||||
serde_urlencoded = "0.7.1"
|
||||
data-encoding = "2.3.2"
|
||||
once_cell = "1.10.0"
|
||||
tracing = "0.1.32"
|
||||
opentelemetry = "0.17.0"
|
||||
rand = "0.8.5"
|
||||
mime = "0.3.16"
|
||||
bincode = "1.3.3"
|
||||
crc = "2.1.0"
|
||||
url = "2.2.2"
|
||||
http = "0.2.6"
|
||||
http-body = "0.4.4"
|
||||
tower = { version = "0.4.12", features = ["util"] }
|
||||
|
||||
oauth2-types = { path = "../oauth2-types" }
|
||||
mas-config = { path = "../config" }
|
||||
mas-templates = { path = "../templates" }
|
||||
mas-data-model = { path = "../data-model" }
|
||||
mas-storage = { path = "../storage" }
|
||||
mas-jose = { path = "../jose" }
|
||||
mas-iana = { path = "../iana" }
|
||||
mas-http = { path = "../http" }
|
@ -1,42 +0,0 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Helper to deal with various unstructured errors in application code
|
||||
|
||||
use warp::{reject::Reject, Rejection};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct WrappedError(anyhow::Error);
|
||||
|
||||
impl warp::reject::Reject for WrappedError {}
|
||||
|
||||
/// Wrap any error in a [`Rejection`]
|
||||
pub fn wrapped_error<T: Into<anyhow::Error>>(e: T) -> impl Reject {
|
||||
WrappedError(e.into())
|
||||
}
|
||||
|
||||
/// Extension trait that wraps errors in [`Rejection`]s
|
||||
pub trait WrapError<T> {
|
||||
/// Wrap transform the [`Result`] error type to a [`Rejection`]
|
||||
fn wrap_error(self) -> Result<T, Rejection>;
|
||||
}
|
||||
|
||||
impl<T, E> WrapError<T> for Result<T, E>
|
||||
where
|
||||
E: Into<anyhow::Error>,
|
||||
{
|
||||
fn wrap_error(self) -> Result<T, Rejection> {
|
||||
self.map_err(|e| warp::reject::custom(WrappedError(e.into())))
|
||||
}
|
||||
}
|
@ -1,156 +0,0 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Authenticate an endpoint with an access token as bearer authorization token
|
||||
|
||||
use headers::{authorization::Bearer, Authorization};
|
||||
use hyper::StatusCode;
|
||||
use mas_data_model::{AccessToken, Session, TokenFormatError, TokenType};
|
||||
use mas_storage::{
|
||||
oauth2::access_token::{lookup_active_access_token, AccessTokenLookupError},
|
||||
PostgresqlBackend,
|
||||
};
|
||||
use sqlx::{pool::PoolConnection, PgPool, Postgres};
|
||||
use thiserror::Error;
|
||||
use warp::{
|
||||
reject::{MissingHeader, Reject},
|
||||
reply::{with_header, with_status},
|
||||
Filter, Rejection, Reply,
|
||||
};
|
||||
|
||||
use super::{
|
||||
database::connection,
|
||||
headers::{typed_header, InvalidTypedHeader},
|
||||
};
|
||||
use crate::errors::wrapped_error;
|
||||
|
||||
/// Bearer token authentication failed
|
||||
///
|
||||
/// This is recoverable with [`recover_unauthorized`]
|
||||
#[derive(Debug, Error)]
|
||||
pub enum AuthenticationError {
|
||||
/// The bearer token has an invalid format
|
||||
#[error("invalid token format")]
|
||||
TokenFormat(#[from] TokenFormatError),
|
||||
|
||||
/// The bearer token is not an access token
|
||||
#[error("invalid token type {0:?}, expected an access token")]
|
||||
WrongTokenType(TokenType),
|
||||
|
||||
/// The access token was not found in the database
|
||||
#[error("unknown token")]
|
||||
TokenNotFound(#[source] AccessTokenLookupError),
|
||||
|
||||
/// The `Authorization` header is missing
|
||||
#[error("missing authorization header")]
|
||||
MissingAuthorizationHeader,
|
||||
|
||||
/// The `Authorization` header is invalid
|
||||
#[error("invalid authorization header")]
|
||||
InvalidAuthorizationHeader,
|
||||
}
|
||||
|
||||
impl Reject for AuthenticationError {}
|
||||
|
||||
/// Authenticate a request using an access token as a bearer authorization
|
||||
///
|
||||
/// # Rejections
|
||||
///
|
||||
/// This can reject with either a [`AuthenticationError`] or with a generic
|
||||
/// wrapped sqlx error.
|
||||
#[must_use]
|
||||
pub fn authentication(
|
||||
pool: &PgPool,
|
||||
) -> impl Filter<
|
||||
Extract = (AccessToken<PostgresqlBackend>, Session<PostgresqlBackend>),
|
||||
Error = Rejection,
|
||||
> + Clone
|
||||
+ Send
|
||||
+ Sync
|
||||
+ 'static {
|
||||
connection(pool)
|
||||
.and(typed_header())
|
||||
.and_then(authenticate)
|
||||
.recover(recover)
|
||||
.unify()
|
||||
.untuple_one()
|
||||
}
|
||||
|
||||
fn ensure<T: Clone + Send + Sync + 'static>(t: T) -> T {
|
||||
t
|
||||
}
|
||||
|
||||
async fn authenticate(
|
||||
mut conn: PoolConnection<Postgres>,
|
||||
auth: Authorization<Bearer>,
|
||||
) -> Result<(AccessToken<PostgresqlBackend>, Session<PostgresqlBackend>), Rejection> {
|
||||
let token = auth.0.token();
|
||||
let token_type = TokenType::check(token).map_err(AuthenticationError::TokenFormat)?;
|
||||
|
||||
if token_type != TokenType::AccessToken {
|
||||
return Err(AuthenticationError::WrongTokenType(token_type).into());
|
||||
}
|
||||
|
||||
let (token, session) = lookup_active_access_token(&mut conn, token)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
if e.not_found() {
|
||||
// This error happens if the token was not found and should be recovered
|
||||
warp::reject::custom(AuthenticationError::TokenNotFound(e))
|
||||
} else {
|
||||
// This is a generic database error that we want to propagate
|
||||
warp::reject::custom(wrapped_error(e))
|
||||
}
|
||||
})?;
|
||||
|
||||
let session = ensure(session);
|
||||
let token = ensure(token);
|
||||
|
||||
Ok((token, session))
|
||||
}
|
||||
|
||||
/// Transform the rejections from the [`with_typed_header`] filter
|
||||
async fn recover(
|
||||
rejection: Rejection,
|
||||
) -> Result<(AccessToken<PostgresqlBackend>, Session<PostgresqlBackend>), Rejection> {
|
||||
if rejection.find::<MissingHeader>().is_some() {
|
||||
return Err(warp::reject::custom(
|
||||
AuthenticationError::MissingAuthorizationHeader,
|
||||
));
|
||||
}
|
||||
|
||||
if rejection.find::<InvalidTypedHeader>().is_some() {
|
||||
return Err(warp::reject::custom(
|
||||
AuthenticationError::InvalidAuthorizationHeader,
|
||||
));
|
||||
}
|
||||
|
||||
Err(rejection)
|
||||
}
|
||||
|
||||
/// Recover from an [`AuthenticationError`] with a `WWW-Authenticate` header, as
|
||||
/// per [RFC6750]. This is not intended for user-facing endpoints.
|
||||
///
|
||||
/// [RFC6750]: https://www.rfc-editor.org/rfc/rfc6750.html
|
||||
pub async fn recover_unauthorized(rejection: Rejection) -> Result<Box<dyn Reply>, Rejection> {
|
||||
if rejection.find::<AuthenticationError>().is_some() {
|
||||
// TODO: have the issuer/realm here
|
||||
let reply = "invalid token";
|
||||
let reply = with_status(reply, StatusCode::UNAUTHORIZED);
|
||||
let reply = with_header(reply, "WWW-Authenticate", r#"Bearer error="invalid_token""#);
|
||||
return Ok(Box::new(reply));
|
||||
}
|
||||
|
||||
Err(rejection)
|
||||
}
|
@ -1,772 +0,0 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Handle client authentication
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use data_encoding::BASE64;
|
||||
use headers::{authorization::Basic, Authorization};
|
||||
use mas_config::Encrypter;
|
||||
use mas_data_model::{Client, JwksOrJwksUri, StorageBackend};
|
||||
use mas_http::HttpServiceExt;
|
||||
use mas_iana::oauth::OAuthClientAuthenticationMethod;
|
||||
use mas_jose::{
|
||||
claims::{TimeOptions, AUD, EXP, IAT, ISS, JTI, NBF, SUB},
|
||||
DecodedJsonWebToken, DynamicJwksStore, Either, JsonWebKeySet, JsonWebTokenParts, SharedSecret,
|
||||
StaticJwksStore, VerifyingKeystore,
|
||||
};
|
||||
use mas_storage::{
|
||||
oauth2::client::{lookup_client_by_client_id, ClientFetchError},
|
||||
PostgresqlBackend,
|
||||
};
|
||||
use serde::{de::DeserializeOwned, Deserialize};
|
||||
use sqlx::{pool::PoolConnection, PgPool, Postgres};
|
||||
use thiserror::Error;
|
||||
use tower::{BoxError, ServiceExt};
|
||||
use warp::{reject::Reject, Filter, Rejection};
|
||||
|
||||
use super::{database::connection, headers::typed_header};
|
||||
use crate::errors::WrapError;
|
||||
|
||||
/// Protect an enpoint with client authentication
|
||||
#[must_use]
|
||||
pub fn client_authentication<T: DeserializeOwned + Send + 'static>(
|
||||
pool: &PgPool,
|
||||
encrypter: &Encrypter,
|
||||
audience: String,
|
||||
) -> impl Filter<
|
||||
Extract = (
|
||||
OAuthClientAuthenticationMethod,
|
||||
Client<PostgresqlBackend>,
|
||||
T,
|
||||
),
|
||||
Error = Rejection,
|
||||
> + Clone
|
||||
+ Send
|
||||
+ Sync
|
||||
+ 'static {
|
||||
let encrypter = encrypter.clone();
|
||||
|
||||
// First, extract the client credentials
|
||||
let credentials = typed_header()
|
||||
.and(warp::body::form())
|
||||
// Either from the "Authorization" header
|
||||
.map(|auth: Authorization<Basic>, body: T| {
|
||||
let client_id = auth.0.username().to_string();
|
||||
let client_secret = Some(auth.0.password().to_string());
|
||||
(
|
||||
ClientCredentials::Pair {
|
||||
via: CredentialsVia::AuthorizationHeader,
|
||||
client_id,
|
||||
client_secret,
|
||||
},
|
||||
body,
|
||||
)
|
||||
})
|
||||
// Or from the form body
|
||||
.or(warp::body::form().map(|form: ClientAuthForm<T>| {
|
||||
let ClientAuthForm { credentials, body } = form;
|
||||
|
||||
(credentials, body)
|
||||
}))
|
||||
.unify()
|
||||
.untuple_one();
|
||||
|
||||
warp::any()
|
||||
.and(connection(pool))
|
||||
.and(warp::any().map(move || encrypter.clone()))
|
||||
.and(warp::any().map(move || audience.clone()))
|
||||
.and(credentials)
|
||||
.and_then(authenticate_client)
|
||||
.untuple_one()
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
enum ClientAuthenticationError {
|
||||
#[error("wrong client secret for client {client_id:?}")]
|
||||
ClientSecretMismatch { client_id: String },
|
||||
|
||||
#[error("could not fetch client {client_id:?}")]
|
||||
ClientFetch {
|
||||
client_id: String,
|
||||
source: ClientFetchError,
|
||||
},
|
||||
|
||||
#[error("client {client_id:?} has an invalid client secret")]
|
||||
InvalidClientSecret {
|
||||
client_id: String,
|
||||
source: anyhow::Error,
|
||||
},
|
||||
|
||||
#[error("client {client_id:?} has an invalid JWKS")]
|
||||
InvalidJwks { client_id: String },
|
||||
|
||||
#[error("wrong client authentication method for client {client_id:?}")]
|
||||
WrongAuthenticationMethod { client_id: String },
|
||||
|
||||
#[error("wrong audience in client assertion: expected {expected:?}")]
|
||||
MissingAudience { expected: String },
|
||||
|
||||
#[error("invalid client assertion")]
|
||||
InvalidAssertion,
|
||||
}
|
||||
|
||||
impl Reject for ClientAuthenticationError {}
|
||||
|
||||
fn decrypt_client_secret<T: StorageBackend>(
|
||||
client: &Client<T>,
|
||||
encrypter: &Encrypter,
|
||||
) -> anyhow::Result<Vec<u8>> {
|
||||
let encrypted_client_secret = client
|
||||
.encrypted_client_secret
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing encrypted_client_secret field"))?;
|
||||
|
||||
let encrypted_client_secret = BASE64.decode(encrypted_client_secret.as_bytes())?;
|
||||
|
||||
let nonce: &[u8; 12] = encrypted_client_secret
|
||||
.get(0..12)
|
||||
.ok_or_else(|| anyhow::anyhow!("invalid payload serialization"))?
|
||||
.try_into()?;
|
||||
|
||||
let payload = encrypted_client_secret
|
||||
.get(12..)
|
||||
.ok_or_else(|| anyhow::anyhow!("invalid payload serialization"))?;
|
||||
|
||||
let decrypted_client_secret = encrypter.decrypt(nonce, payload)?;
|
||||
|
||||
Ok(decrypted_client_secret)
|
||||
}
|
||||
|
||||
fn jwks_key_store(jwks: &JwksOrJwksUri) -> Either<StaticJwksStore, DynamicJwksStore> {
|
||||
// Assert that the output is both a VerifyingKeystore and Send
|
||||
fn assert<T: Send + VerifyingKeystore>(t: T) -> T {
|
||||
t
|
||||
}
|
||||
|
||||
let inner = match jwks {
|
||||
JwksOrJwksUri::Jwks(jwks) => Either::Left(StaticJwksStore::new(jwks.clone())),
|
||||
JwksOrJwksUri::JwksUri(uri) => {
|
||||
let uri = uri.clone();
|
||||
|
||||
// TODO: get the client from somewhere else?
|
||||
let exporter = mas_http::client("fetch-jwks")
|
||||
.json::<JsonWebKeySet>()
|
||||
.map_request(move |_: ()| {
|
||||
http::Request::builder()
|
||||
.method("GET")
|
||||
// TODO: change the Uri type in config to avoid reparsing here
|
||||
.uri(uri.to_string())
|
||||
.body(http_body::Empty::new())
|
||||
.unwrap()
|
||||
})
|
||||
.map_response(http::Response::into_body)
|
||||
.map_err(BoxError::from)
|
||||
.boxed_clone();
|
||||
|
||||
Either::Right(DynamicJwksStore::new(exporter))
|
||||
}
|
||||
};
|
||||
|
||||
assert(inner)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
#[tracing::instrument(skip_all, fields(enduser.id), err(Debug))]
|
||||
async fn authenticate_client<T>(
|
||||
mut conn: PoolConnection<Postgres>,
|
||||
encrypter: Encrypter,
|
||||
audience: String,
|
||||
credentials: ClientCredentials,
|
||||
body: T,
|
||||
) -> Result<
|
||||
(
|
||||
OAuthClientAuthenticationMethod,
|
||||
Client<PostgresqlBackend>,
|
||||
T,
|
||||
),
|
||||
Rejection,
|
||||
> {
|
||||
let (auth_method, client) = match credentials {
|
||||
ClientCredentials::Pair {
|
||||
client_id,
|
||||
client_secret,
|
||||
via,
|
||||
} => {
|
||||
let client = lookup_client_by_client_id(&mut *conn, &client_id)
|
||||
.await
|
||||
.map_err(|source| ClientAuthenticationError::ClientFetch {
|
||||
client_id: client_id.clone(),
|
||||
source,
|
||||
})?;
|
||||
|
||||
let auth_method = client.token_endpoint_auth_method.ok_or(
|
||||
ClientAuthenticationError::WrongAuthenticationMethod {
|
||||
client_id: client.client_id.clone(),
|
||||
},
|
||||
)?;
|
||||
|
||||
// Let's match the authentication method
|
||||
match (auth_method, client_secret, via) {
|
||||
(OAuthClientAuthenticationMethod::None, None, _) => {}
|
||||
(
|
||||
OAuthClientAuthenticationMethod::ClientSecretBasic,
|
||||
Some(client_secret),
|
||||
CredentialsVia::AuthorizationHeader,
|
||||
)
|
||||
| (
|
||||
OAuthClientAuthenticationMethod::ClientSecretPost,
|
||||
Some(client_secret),
|
||||
CredentialsVia::FormBody,
|
||||
) => {
|
||||
let decrypted =
|
||||
decrypt_client_secret(&client, &encrypter).map_err(|source| {
|
||||
ClientAuthenticationError::InvalidClientSecret {
|
||||
client_id: client.client_id.clone(),
|
||||
source,
|
||||
}
|
||||
})?;
|
||||
|
||||
if client_secret.as_bytes() != decrypted {
|
||||
return Err(warp::reject::custom(
|
||||
ClientAuthenticationError::ClientSecretMismatch {
|
||||
client_id: client.client_id,
|
||||
},
|
||||
));
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
return Err(warp::reject::custom(
|
||||
ClientAuthenticationError::WrongAuthenticationMethod {
|
||||
client_id: client.client_id,
|
||||
},
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
(auth_method, client)
|
||||
}
|
||||
ClientCredentials::Assertion {
|
||||
client_id,
|
||||
client_assertion_type: ClientAssertionType::JwtBearer,
|
||||
client_assertion,
|
||||
} => {
|
||||
let token: JsonWebTokenParts = client_assertion.parse().wrap_error()?;
|
||||
let decoded: DecodedJsonWebToken<HashMap<String, serde_json::Value>> =
|
||||
token.decode().wrap_error()?;
|
||||
|
||||
let time_options = TimeOptions::default()
|
||||
.freeze()
|
||||
.leeway(chrono::Duration::minutes(1));
|
||||
|
||||
let mut claims = decoded.claims().clone();
|
||||
let iss = ISS.extract_required(&mut claims).wrap_error()?;
|
||||
let sub = SUB.extract_required(&mut claims).wrap_error()?;
|
||||
let aud = AUD.extract_required(&mut claims).wrap_error()?;
|
||||
|
||||
// Validate the times
|
||||
let _exp = EXP
|
||||
.extract_required_with_options(&mut claims, &time_options)
|
||||
.wrap_error()?;
|
||||
let _nbf = NBF
|
||||
.extract_optional_with_options(&mut claims, &time_options)
|
||||
.wrap_error()?;
|
||||
let _iat = IAT
|
||||
.extract_optional_with_options(&mut claims, &time_options)
|
||||
.wrap_error()?;
|
||||
|
||||
// TODO: validate the JTI
|
||||
let _jti = JTI.extract_optional(&mut claims).wrap_error()?;
|
||||
|
||||
// client_id might have been passed as parameter. If not, it should be inferred
|
||||
// from the token, as per rfc7521 sec. 4.2
|
||||
let client_id = client_id.as_ref().unwrap_or(&sub);
|
||||
|
||||
let client = lookup_client_by_client_id(&mut *conn, client_id)
|
||||
.await
|
||||
.map_err(|source| ClientAuthenticationError::ClientFetch {
|
||||
client_id: client_id.to_string(),
|
||||
source,
|
||||
})?;
|
||||
|
||||
let auth_method = client.token_endpoint_auth_method.ok_or(
|
||||
ClientAuthenticationError::WrongAuthenticationMethod {
|
||||
client_id: client.client_id.clone(),
|
||||
},
|
||||
)?;
|
||||
|
||||
match auth_method {
|
||||
OAuthClientAuthenticationMethod::ClientSecretJwt => {
|
||||
let client_secret =
|
||||
decrypt_client_secret(&client, &encrypter).map_err(|source| {
|
||||
ClientAuthenticationError::InvalidClientSecret {
|
||||
client_id: client.client_id.clone(),
|
||||
source,
|
||||
}
|
||||
})?;
|
||||
|
||||
let store = SharedSecret::new(&client_secret);
|
||||
let fut = token.verify(decoded.header(), &store);
|
||||
fut.await.wrap_error()?;
|
||||
}
|
||||
OAuthClientAuthenticationMethod::PrivateKeyJwt => {
|
||||
let jwks = client.jwks.as_ref().ok_or_else(|| {
|
||||
ClientAuthenticationError::InvalidJwks {
|
||||
client_id: client.client_id.clone(),
|
||||
}
|
||||
})?;
|
||||
|
||||
let store = jwks_key_store(jwks);
|
||||
let fut = token.verify(decoded.header(), &store);
|
||||
fut.await.wrap_error()?;
|
||||
}
|
||||
_ => {
|
||||
return Err(warp::reject::custom(
|
||||
ClientAuthenticationError::WrongAuthenticationMethod {
|
||||
client_id: client.client_id,
|
||||
},
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// rfc7523 sec. 3.3: the audience is the URL being called
|
||||
if !aud.contains(&audience) {
|
||||
return Err(
|
||||
ClientAuthenticationError::MissingAudience { expected: audience }.into(),
|
||||
);
|
||||
}
|
||||
|
||||
// rfc7523 sec. 3.1 & 3.2: both the issuer and the subject must
|
||||
// match the client_id
|
||||
if iss != sub || &iss != client_id {
|
||||
return Err(ClientAuthenticationError::InvalidAssertion.into());
|
||||
}
|
||||
|
||||
(auth_method, client)
|
||||
}
|
||||
};
|
||||
|
||||
tracing::Span::current().record("enduser.id", &client.client_id.as_str());
|
||||
|
||||
Ok((auth_method, client, body))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
enum ClientAssertionType {
|
||||
#[serde(rename = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")]
|
||||
JwtBearer,
|
||||
}
|
||||
|
||||
enum CredentialsVia {
|
||||
FormBody,
|
||||
AuthorizationHeader,
|
||||
}
|
||||
|
||||
impl Default for CredentialsVia {
|
||||
fn default() -> Self {
|
||||
Self::FormBody
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum ClientCredentials {
|
||||
// Order here is important: serde tries to deserialize enum variants in order, so if "Pair"
|
||||
// was before "Assertion", a client_assertion with a client_id would match the "Pair"
|
||||
// variant first
|
||||
Assertion {
|
||||
client_id: Option<String>,
|
||||
client_assertion_type: ClientAssertionType,
|
||||
client_assertion: String,
|
||||
},
|
||||
Pair {
|
||||
#[serde(skip)]
|
||||
via: CredentialsVia,
|
||||
client_id: String,
|
||||
client_secret: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ClientAuthForm<T> {
|
||||
#[serde(flatten)]
|
||||
credentials: ClientCredentials,
|
||||
|
||||
#[serde(flatten)]
|
||||
body: T,
|
||||
}
|
||||
|
||||
/* TODO: all secrets are broken because there is no way to mock the DB yet
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use headers::authorization::Credentials;
|
||||
use mas_config::{ClientAuthMethodConfig, ConfigurationSection};
|
||||
use mas_jose::{SigningKeystore, StaticKeystore};
|
||||
use serde_json::json;
|
||||
use tower::{Service, ServiceExt};
|
||||
|
||||
use super::*;
|
||||
|
||||
// Long client_secret to support it as a HS512 key
|
||||
const CLIENT_SECRET: &str = "leek2zaeyeb8thai7piehea3vah6ool9oanin9aeraThuci9EeghaekaiD1upe4Quoh7xeMae2meitohj0Waaveiwaorah1yazohr6Vae7iebeiRaWene5IeWeeciezu";
|
||||
|
||||
fn client_private_keystore() -> StaticKeystore {
|
||||
let mut store = StaticKeystore::new();
|
||||
store.add_test_rsa_key().unwrap();
|
||||
store.add_test_ecdsa_key().unwrap();
|
||||
store
|
||||
}
|
||||
|
||||
async fn oauth2_config() -> ClientsConfig {
|
||||
let mut config = ClientsConfig::test();
|
||||
config.push(ClientConfig {
|
||||
client_id: "public".to_string(),
|
||||
client_auth_method: ClientAuthMethodConfig::None,
|
||||
redirect_uris: Vec::new(),
|
||||
});
|
||||
config.push(ClientConfig {
|
||||
client_id: "secret-basic".to_string(),
|
||||
client_auth_method: ClientAuthMethodConfig::ClientSecretBasic {
|
||||
client_secret: CLIENT_SECRET.to_string(),
|
||||
},
|
||||
redirect_uris: Vec::new(),
|
||||
});
|
||||
config.push(ClientConfig {
|
||||
client_id: "secret-post".to_string(),
|
||||
client_auth_method: ClientAuthMethodConfig::ClientSecretPost {
|
||||
client_secret: CLIENT_SECRET.to_string(),
|
||||
},
|
||||
redirect_uris: Vec::new(),
|
||||
});
|
||||
config.push(ClientConfig {
|
||||
client_id: "secret-jwt".to_string(),
|
||||
client_auth_method: ClientAuthMethodConfig::ClientSecretJwt {
|
||||
client_secret: CLIENT_SECRET.to_string(),
|
||||
},
|
||||
redirect_uris: Vec::new(),
|
||||
});
|
||||
config.push(ClientConfig {
|
||||
client_id: "secret-jwt-2".to_string(),
|
||||
client_auth_method: ClientAuthMethodConfig::ClientSecretJwt {
|
||||
client_secret: CLIENT_SECRET.to_string(),
|
||||
},
|
||||
redirect_uris: Vec::new(),
|
||||
});
|
||||
|
||||
let store = client_private_keystore();
|
||||
let jwks = (&store).ready().await.unwrap().call(()).await.unwrap();
|
||||
//let jwks = store.export_jwks().await.unwrap();
|
||||
config.push(ClientConfig {
|
||||
client_id: "private-key-jwt".to_string(),
|
||||
client_auth_method: ClientAuthMethodConfig::PrivateKeyJwt(jwks.clone().into()),
|
||||
redirect_uris: Vec::new(),
|
||||
});
|
||||
config.push(ClientConfig {
|
||||
client_id: "private-key-jwt-2".to_string(),
|
||||
client_auth_method: ClientAuthMethodConfig::PrivateKeyJwt(jwks.into()),
|
||||
redirect_uris: Vec::new(),
|
||||
});
|
||||
config
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Form {
|
||||
foo: String,
|
||||
bar: String,
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn client_secret_jwt_hs256() {
|
||||
client_secret_jwt("HS256").await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn client_secret_jwt_hs384() {
|
||||
client_secret_jwt("HS384").await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn client_secret_jwt_hs512() {
|
||||
client_secret_jwt("HS512").await;
|
||||
}
|
||||
|
||||
fn client_claims(
|
||||
client_id: &str,
|
||||
audience: &str,
|
||||
iat: chrono::DateTime<chrono::Utc>,
|
||||
) -> HashMap<String, serde_json::Value> {
|
||||
let mut claims = HashMap::new();
|
||||
let exp = iat + chrono::Duration::minutes(1);
|
||||
|
||||
ISS.insert(&mut claims, client_id).unwrap();
|
||||
SUB.insert(&mut claims, client_id).unwrap();
|
||||
AUD.insert(&mut claims, vec![audience.to_string()]).unwrap();
|
||||
IAT.insert(&mut claims, iat).unwrap();
|
||||
NBF.insert(&mut claims, iat).unwrap();
|
||||
EXP.insert(&mut claims, exp).unwrap();
|
||||
|
||||
claims
|
||||
}
|
||||
|
||||
async fn client_secret_jwt(alg: &str) {
|
||||
let alg = alg.parse().unwrap();
|
||||
let audience = "https://example.com/token";
|
||||
let filter = client_authentication::<Form>(&oauth2_config().await, audience.to_string());
|
||||
|
||||
let store = SharedSecret::new(&CLIENT_SECRET);
|
||||
let claims = client_claims("secret-jwt", audience, chrono::Utc::now());
|
||||
let header = store.prepare_header(alg).await.expect("JWT header");
|
||||
let jwt = DecodedJsonWebToken::new(header, claims);
|
||||
let jwt = jwt.sign(&store).await.expect("signed token");
|
||||
let jwt = jwt.serialize();
|
||||
|
||||
// TODO: test failing cases
|
||||
// - expired token
|
||||
// - "not before" in the future
|
||||
// - subject/issuer mismatch
|
||||
// - audience mismatch
|
||||
// - wrong secret/signature
|
||||
|
||||
let (auth, client, body) = warp::test::request()
|
||||
.method("POST")
|
||||
.header("Content-Type", mime::APPLICATION_WWW_FORM_URLENCODED.to_string())
|
||||
.body(serde_urlencoded::to_string(json!({
|
||||
"client_id": "secret-jwt",
|
||||
"client_assertion": jwt,
|
||||
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
||||
"foo": "baz",
|
||||
"bar": "foobar",
|
||||
})).unwrap())
|
||||
.filter(&filter)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(auth, OAuthClientAuthenticationMethod::ClientSecretJwt);
|
||||
assert_eq!(client.client_id, "secret-jwt");
|
||||
assert_eq!(body.foo, "baz");
|
||||
assert_eq!(body.bar, "foobar");
|
||||
|
||||
// Without client_id
|
||||
let res = warp::test::request()
|
||||
.method("POST")
|
||||
.header("Content-Type", mime::APPLICATION_WWW_FORM_URLENCODED.to_string())
|
||||
.body(serde_urlencoded::to_string(json!({
|
||||
"client_assertion": jwt,
|
||||
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
||||
"foo": "baz",
|
||||
"bar": "foobar",
|
||||
})).unwrap())
|
||||
.filter(&filter)
|
||||
.await;
|
||||
assert!(res.is_ok());
|
||||
|
||||
// client_id mismatch
|
||||
let res = warp::test::request()
|
||||
.method("POST")
|
||||
.body(serde_urlencoded::to_string(json!({
|
||||
"client_id": "secret-jwt-2",
|
||||
"client_assertion": jwt,
|
||||
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
||||
"foo": "baz",
|
||||
"bar": "foobar",
|
||||
})).unwrap())
|
||||
.filter(&filter)
|
||||
.await;
|
||||
assert!(res.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn client_secret_jwt_rs256() {
|
||||
private_key_jwt("RS256").await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn client_secret_jwt_rs384() {
|
||||
private_key_jwt("RS384").await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn client_secret_jwt_rs512() {
|
||||
private_key_jwt("RS512").await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn client_secret_jwt_es256() {
|
||||
private_key_jwt("ES256").await;
|
||||
}
|
||||
|
||||
async fn private_key_jwt(alg: &str) {
|
||||
let alg = alg.parse().unwrap();
|
||||
let audience = "https://example.com/token";
|
||||
let filter = client_authentication::<Form>(&oauth2_config().await, audience.to_string());
|
||||
|
||||
let store = client_private_keystore();
|
||||
let claims = client_claims("private-key-jwt", audience, chrono::Utc::now());
|
||||
let header = store.prepare_header(alg).await.expect("JWT header");
|
||||
let jwt = DecodedJsonWebToken::new(header, claims);
|
||||
let jwt = jwt.sign(&store).await.expect("signed token");
|
||||
let jwt = jwt.serialize();
|
||||
|
||||
// TODO: test failing cases
|
||||
// - expired token
|
||||
// - "not before" in the future
|
||||
// - subject/issuer mismatch
|
||||
// - audience mismatch
|
||||
// - wrong secret/signature
|
||||
|
||||
let (auth, client, body) = warp::test::request()
|
||||
.method("POST")
|
||||
.header("Content-Type", mime::APPLICATION_WWW_FORM_URLENCODED.to_string())
|
||||
.body(serde_urlencoded::to_string(json!({
|
||||
"client_id": "private-key-jwt",
|
||||
"client_assertion": jwt,
|
||||
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
||||
"foo": "baz",
|
||||
"bar": "foobar",
|
||||
})).unwrap())
|
||||
.filter(&filter)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(auth, OAuthClientAuthenticationMethod::PrivateKeyJwt);
|
||||
assert_eq!(client.client_id, "private-key-jwt");
|
||||
assert_eq!(body.foo, "baz");
|
||||
assert_eq!(body.bar, "foobar");
|
||||
|
||||
// Without client_id
|
||||
let res = warp::test::request()
|
||||
.method("POST")
|
||||
.header("Content-Type", mime::APPLICATION_WWW_FORM_URLENCODED.to_string())
|
||||
.body(serde_urlencoded::to_string(json!({
|
||||
"client_assertion": jwt,
|
||||
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
||||
"foo": "baz",
|
||||
"bar": "foobar",
|
||||
})).unwrap())
|
||||
.filter(&filter)
|
||||
.await;
|
||||
assert!(res.is_ok());
|
||||
|
||||
// client_id mismatch
|
||||
let res = warp::test::request()
|
||||
.method("POST")
|
||||
.body(serde_urlencoded::to_string(json!({
|
||||
"client_id": "private-key-jwt-2",
|
||||
"client_assertion": jwt,
|
||||
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
||||
"foo": "baz",
|
||||
"bar": "foobar",
|
||||
})).unwrap())
|
||||
.filter(&filter)
|
||||
.await;
|
||||
assert!(res.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn client_secret_post() {
|
||||
let filter = client_authentication::<Form>(
|
||||
&oauth2_config().await,
|
||||
"https://example.com/token".to_string(),
|
||||
);
|
||||
|
||||
let (auth, client, body) = warp::test::request()
|
||||
.method("POST")
|
||||
.header(
|
||||
"Content-Type",
|
||||
mime::APPLICATION_WWW_FORM_URLENCODED.to_string(),
|
||||
)
|
||||
.body(
|
||||
serde_urlencoded::to_string(json!({
|
||||
"client_id": "secret-post",
|
||||
"client_secret": CLIENT_SECRET,
|
||||
"foo": "baz",
|
||||
"bar": "foobar",
|
||||
}))
|
||||
.unwrap(),
|
||||
)
|
||||
.filter(&filter)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(auth, OAuthClientAuthenticationMethod::ClientSecretPost);
|
||||
assert_eq!(client.client_id, "secret-post");
|
||||
assert_eq!(body.foo, "baz");
|
||||
assert_eq!(body.bar, "foobar");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn client_secret_basic() {
|
||||
let filter = client_authentication::<Form>(
|
||||
&oauth2_config().await,
|
||||
"https://example.com/token".to_string(),
|
||||
);
|
||||
|
||||
let auth = Authorization::basic("secret-basic", CLIENT_SECRET);
|
||||
let (auth, client, body) = warp::test::request()
|
||||
.method("POST")
|
||||
.header(
|
||||
"Content-Type",
|
||||
mime::APPLICATION_WWW_FORM_URLENCODED.to_string(),
|
||||
)
|
||||
.header("Authorization", auth.0.encode())
|
||||
.body(
|
||||
serde_urlencoded::to_string(json!({
|
||||
"foo": "baz",
|
||||
"bar": "foobar",
|
||||
}))
|
||||
.unwrap(),
|
||||
)
|
||||
.filter(&filter)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(auth, OAuthClientAuthenticationMethod::ClientSecretBasic);
|
||||
assert_eq!(client.client_id, "secret-basic");
|
||||
assert_eq!(body.foo, "baz");
|
||||
assert_eq!(body.bar, "foobar");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn none() {
|
||||
let filter = client_authentication::<Form>(
|
||||
&oauth2_config().await,
|
||||
"https://example.com/token".to_string(),
|
||||
);
|
||||
|
||||
let (auth, client, body) = warp::test::request()
|
||||
.method("POST")
|
||||
.header(
|
||||
"Content-Type",
|
||||
mime::APPLICATION_WWW_FORM_URLENCODED.to_string(),
|
||||
)
|
||||
.body(
|
||||
serde_urlencoded::to_string(json!({
|
||||
"client_id": "public",
|
||||
"foo": "baz",
|
||||
"bar": "foobar",
|
||||
}))
|
||||
.unwrap(),
|
||||
)
|
||||
.filter(&filter)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(auth, OAuthClientAuthenticationMethod::None);
|
||||
assert_eq!(client.client_id, "public");
|
||||
assert_eq!(body.foo, "baz");
|
||||
assert_eq!(body.bar, "foobar");
|
||||
}
|
||||
}
|
||||
*/
|
@ -1,193 +0,0 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Deal with encrypted cookies
|
||||
|
||||
use std::{convert::Infallible, marker::PhantomData};
|
||||
|
||||
use cookie::{Cookie, SameSite};
|
||||
use data_encoding::BASE64URL_NOPAD;
|
||||
use headers::{Header, HeaderValue, SetCookie};
|
||||
use mas_config::Encrypter;
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
use warp::{
|
||||
reject::{InvalidHeader, MissingCookie, Reject},
|
||||
Filter, Rejection, Reply,
|
||||
};
|
||||
|
||||
use super::none_on_error;
|
||||
use crate::{
|
||||
errors::WrapError,
|
||||
reply::{with_typed_header, WithTypedHeader},
|
||||
};
|
||||
|
||||
/// Unable to decrypt the cookie
|
||||
#[derive(Debug, Error)]
|
||||
pub struct CookieDecryptionError<T: EncryptableCookieValue>(
|
||||
#[source] anyhow::Error,
|
||||
// This [`std::marker::PhantomData`] records what kind of cookie it was trying to save.
|
||||
// This then use when displaying the error.
|
||||
PhantomData<T>,
|
||||
);
|
||||
|
||||
impl<T> Reject for CookieDecryptionError<T> where T: EncryptableCookieValue + 'static {}
|
||||
|
||||
impl<T: EncryptableCookieValue> From<anyhow::Error> for CookieDecryptionError<T> {
|
||||
fn from(e: anyhow::Error) -> Self {
|
||||
Self(e, PhantomData)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: EncryptableCookieValue> std::fmt::Display for CookieDecryptionError<T> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "failed to decrypt cookie {}", T::cookie_key())
|
||||
}
|
||||
}
|
||||
|
||||
fn decryption_error<T>(e: anyhow::Error) -> Rejection
|
||||
where
|
||||
T: EncryptableCookieValue + 'static,
|
||||
{
|
||||
let e: CookieDecryptionError<T> = e.into();
|
||||
warp::reject::custom(e)
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct EncryptedCookie {
|
||||
nonce: [u8; 12],
|
||||
ciphertext: Vec<u8>,
|
||||
}
|
||||
|
||||
impl EncryptedCookie {
|
||||
/// Encrypt from a given key
|
||||
fn encrypt<T: Serialize>(payload: T, encrypter: &Encrypter) -> anyhow::Result<Self> {
|
||||
let message = bincode::serialize(&payload)?;
|
||||
let nonce: [u8; 12] = rand::random();
|
||||
let ciphertext = encrypter.encrypt(&nonce, &message)?;
|
||||
Ok(Self { nonce, ciphertext })
|
||||
}
|
||||
|
||||
/// Decrypt the content of the cookie from a given key
|
||||
fn decrypt<T: DeserializeOwned>(&self, encrypter: &Encrypter) -> anyhow::Result<T> {
|
||||
let message = encrypter.decrypt(&self.nonce, &self.ciphertext)?;
|
||||
let token = bincode::deserialize(&message)?;
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
/// Encode the encrypted cookie to be then saved as a cookie
|
||||
fn to_cookie_value(&self) -> anyhow::Result<String> {
|
||||
let raw = bincode::serialize(self)?;
|
||||
Ok(BASE64URL_NOPAD.encode(&raw))
|
||||
}
|
||||
|
||||
fn from_cookie_value(value: &str) -> anyhow::Result<Self> {
|
||||
let raw = BASE64URL_NOPAD.decode(value.as_bytes())?;
|
||||
let content = bincode::deserialize(&raw)?;
|
||||
Ok(content)
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract an optional encrypted cookie
|
||||
#[must_use]
|
||||
pub fn maybe_encrypted<T>(
|
||||
encrypter: &Encrypter,
|
||||
) -> impl Filter<Extract = (Option<T>,), Error = Rejection> + Clone + Send + Sync + 'static
|
||||
where
|
||||
T: DeserializeOwned + EncryptableCookieValue + 'static,
|
||||
{
|
||||
encrypted(encrypter)
|
||||
.map(Some)
|
||||
.recover(none_on_error::<T, InvalidHeader>)
|
||||
.unify()
|
||||
.recover(none_on_error::<T, MissingCookie>)
|
||||
.unify()
|
||||
.recover(none_on_error::<T, CookieDecryptionError<T>>)
|
||||
.unify()
|
||||
}
|
||||
|
||||
/// Extract an encrypted cookie
|
||||
///
|
||||
/// # Rejections
|
||||
///
|
||||
/// This can reject with either a [`warp::reject::MissingCookie`] or a
|
||||
/// [`CookieDecryptionError`]
|
||||
#[must_use]
|
||||
pub fn encrypted<T>(
|
||||
encrypter: &Encrypter,
|
||||
) -> impl Filter<Extract = (T,), Error = Rejection> + Clone + Send + Sync + 'static
|
||||
where
|
||||
T: DeserializeOwned + EncryptableCookieValue + 'static,
|
||||
{
|
||||
let encrypter = encrypter.clone();
|
||||
warp::cookie::cookie(T::cookie_key()).and_then(move |value: String| {
|
||||
let encrypter = encrypter.clone();
|
||||
async move {
|
||||
let encrypted_payload =
|
||||
EncryptedCookie::from_cookie_value(&value).map_err(decryption_error::<T>)?;
|
||||
let decrypted_payload = encrypted_payload
|
||||
.decrypt(&encrypter)
|
||||
.map_err(decryption_error::<T>)?;
|
||||
Ok::<_, Rejection>(decrypted_payload)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Get an [`EncryptedCookieSaver`] to help saving an [`EncryptableCookieValue`]
|
||||
#[must_use]
|
||||
pub fn encrypted_cookie_saver(
|
||||
encrypter: &Encrypter,
|
||||
) -> impl Filter<Extract = (EncryptedCookieSaver,), Error = Infallible> + Clone + Send + Sync + 'static
|
||||
{
|
||||
let encrypter = encrypter.clone();
|
||||
warp::any().map(move || EncryptedCookieSaver {
|
||||
encrypter: encrypter.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
/// A cookie that can be encrypted with a well-known cookie key
|
||||
pub trait EncryptableCookieValue: Serialize + Send + Sync + std::fmt::Debug {
|
||||
/// What key should be used for this cookie
|
||||
fn cookie_key() -> &'static str;
|
||||
}
|
||||
|
||||
/// An opaque structure which helps encrypting a cookie and attach it to a reply
|
||||
pub struct EncryptedCookieSaver {
|
||||
encrypter: Encrypter,
|
||||
}
|
||||
|
||||
impl EncryptedCookieSaver {
|
||||
/// Save an [`EncryptableCookieValue`]
|
||||
pub fn save_encrypted<T: EncryptableCookieValue, R: Reply>(
|
||||
&self,
|
||||
cookie: &T,
|
||||
reply: R,
|
||||
) -> Result<WithTypedHeader<R, SetCookie>, Rejection> {
|
||||
let encrypted = EncryptedCookie::encrypt(cookie, &self.encrypter)
|
||||
.wrap_error()?
|
||||
.to_cookie_value()
|
||||
.wrap_error()?;
|
||||
|
||||
// TODO: make those options customizable
|
||||
let value = Cookie::build(T::cookie_key(), encrypted)
|
||||
.http_only(true)
|
||||
.same_site(SameSite::Lax)
|
||||
.finish()
|
||||
.to_string();
|
||||
|
||||
let header = SetCookie::decode(&mut [HeaderValue::from_str(&value).wrap_error()?].iter())
|
||||
.wrap_error()?;
|
||||
Ok(with_typed_header(header, reply))
|
||||
}
|
||||
}
|
@ -1,42 +0,0 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Wrapper around [`warp::filters::cors`]
|
||||
|
||||
use std::string::ToString;
|
||||
|
||||
use once_cell::sync::OnceCell;
|
||||
|
||||
static PROPAGATOR_HEADERS: OnceCell<Vec<String>> = OnceCell::new();
|
||||
|
||||
/// Notify the CORS filter what opentelemetry propagators are being used. This
|
||||
/// helps whitelisting headers in CORS requests.
|
||||
pub fn set_propagator(propagator: &dyn opentelemetry::propagation::TextMapPropagator) {
|
||||
let headers = propagator.fields().map(ToString::to_string).collect();
|
||||
tracing::debug!(
|
||||
?headers,
|
||||
"Headers allowed in CORS requests for trace propagators set"
|
||||
);
|
||||
PROPAGATOR_HEADERS
|
||||
.set(headers)
|
||||
.expect(concat!(module_path!(), "::set_propagator was called twice"));
|
||||
}
|
||||
|
||||
/// Create a wrapping filter that exposes CORS behavior for a wrapped filter.
|
||||
#[must_use]
|
||||
pub fn cors() -> warp::filters::cors::Builder {
|
||||
warp::filters::cors::cors()
|
||||
.allow_any_origin()
|
||||
.allow_headers(PROPAGATOR_HEADERS.get().unwrap_or(&Vec::new()))
|
||||
}
|
@ -1,185 +0,0 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Stateless CSRF protection middleware based on a chacha20-poly1305 encrypted
|
||||
//! and signed token
|
||||
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use data_encoding::{DecodeError, BASE64URL_NOPAD};
|
||||
use mas_config::{CsrfConfig, Encrypter};
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use serde_with::{serde_as, TimestampSeconds};
|
||||
use thiserror::Error;
|
||||
use warp::{reject::Reject, Filter, Rejection};
|
||||
|
||||
use super::cookies::EncryptableCookieValue;
|
||||
|
||||
/// Failed to validate CSRF token
|
||||
#[derive(Debug, Error)]
|
||||
pub enum CsrfError {
|
||||
/// The token in the form did not match the token in the cookie
|
||||
#[error("CSRF token mismatch")]
|
||||
Mismatch,
|
||||
|
||||
/// The token expired
|
||||
#[error("CSRF token expired")]
|
||||
Expired,
|
||||
|
||||
/// Failed to decode the token
|
||||
#[error("could not decode CSRF token")]
|
||||
Decode(#[from] DecodeError),
|
||||
}
|
||||
|
||||
impl Reject for CsrfError {}
|
||||
|
||||
/// A CSRF token
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct CsrfToken {
|
||||
#[serde_as(as = "TimestampSeconds<i64>")]
|
||||
expiration: DateTime<Utc>,
|
||||
token: [u8; 32],
|
||||
}
|
||||
|
||||
impl CsrfToken {
|
||||
/// Create a new token from a defined value valid for a specified duration
|
||||
fn new(token: [u8; 32], ttl: Duration) -> Self {
|
||||
let expiration = Utc::now() + ttl;
|
||||
Self { expiration, token }
|
||||
}
|
||||
|
||||
/// Generate a new random token valid for a specified duration
|
||||
fn generate(ttl: Duration) -> Self {
|
||||
let token = rand::random();
|
||||
Self::new(token, ttl)
|
||||
}
|
||||
|
||||
/// Generate a new token with the same value but an up to date expiration
|
||||
fn refresh(self, ttl: Duration) -> Self {
|
||||
Self::new(self.token, ttl)
|
||||
}
|
||||
|
||||
/// Get the value to include in HTML forms
|
||||
#[must_use]
|
||||
pub fn form_value(&self) -> String {
|
||||
BASE64URL_NOPAD.encode(&self.token[..])
|
||||
}
|
||||
|
||||
/// Verifies that the value got from an HTML form matches this token
|
||||
pub fn verify_form_value(&self, form_value: &str) -> Result<(), CsrfError> {
|
||||
let form_value = BASE64URL_NOPAD.decode(form_value.as_bytes())?;
|
||||
if self.token[..] == form_value {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(CsrfError::Mismatch)
|
||||
}
|
||||
}
|
||||
|
||||
fn verify_expiration(self) -> Result<Self, CsrfError> {
|
||||
if Utc::now() < self.expiration {
|
||||
Ok(self)
|
||||
} else {
|
||||
Err(CsrfError::Expired)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EncryptableCookieValue for CsrfToken {
|
||||
fn cookie_key() -> &'static str {
|
||||
"csrf"
|
||||
}
|
||||
}
|
||||
|
||||
/// A CSRF-protected form
|
||||
#[derive(Deserialize)]
|
||||
struct CsrfForm<T> {
|
||||
csrf: String,
|
||||
|
||||
#[serde(flatten)]
|
||||
inner: T,
|
||||
}
|
||||
|
||||
impl<T> CsrfForm<T> {
|
||||
fn verify_csrf(self, token: &CsrfToken) -> Result<T, CsrfError> {
|
||||
// Verify CSRF from request
|
||||
token.verify_form_value(&self.csrf)?;
|
||||
Ok(self.inner)
|
||||
}
|
||||
}
|
||||
|
||||
fn csrf_token(
|
||||
encrypter: &Encrypter,
|
||||
) -> impl Filter<Extract = (CsrfToken,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
super::cookies::encrypted(encrypter).and_then(move |token: CsrfToken| async move {
|
||||
let verified = token.verify_expiration()?;
|
||||
Ok::<_, Rejection>(verified)
|
||||
})
|
||||
}
|
||||
|
||||
/// Extract an up-to-date CSRF token to include in forms
|
||||
///
|
||||
/// Routes using this should not forget to reply the updated CSRF cookie using
|
||||
/// an [`EncryptedCookieSaver`][`super::cookies::EncryptedCookieSaver`] obtained
|
||||
/// with [`encrypted_cookie_saver`][`super::cookies::encrypted_cookie_saver`]
|
||||
#[must_use]
|
||||
pub fn updated_csrf_token(
|
||||
encrypter: &Encrypter,
|
||||
csrf_config: &CsrfConfig,
|
||||
) -> impl Filter<Extract = (CsrfToken,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
let ttl = csrf_config.ttl;
|
||||
super::cookies::maybe_encrypted(encrypter).and_then(
|
||||
move |maybe_token: Option<CsrfToken>| async move {
|
||||
// Explicitely specify the "Error" type here to have the `?` operation working
|
||||
Ok::<_, Rejection>(
|
||||
maybe_token
|
||||
// Verify its TTL (but do not hard-error if it expired)
|
||||
.and_then(|token| token.verify_expiration().ok())
|
||||
.map_or_else(
|
||||
// Generate a new token if no valid one were found
|
||||
|| CsrfToken::generate(ttl),
|
||||
// Else, refresh the expiration of the token
|
||||
|token| token.refresh(ttl),
|
||||
),
|
||||
)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// Extract values from a CSRF-protected form
|
||||
///
|
||||
/// # Rejections
|
||||
///
|
||||
/// This can reject with:
|
||||
///
|
||||
/// - [`warp::filters::body::BodyDeserializeError`] if the overall form failed
|
||||
/// to decode
|
||||
/// - [`CsrfError`] if the CSRF token was invalid or expired
|
||||
/// - [`warp::reject::MissingCookie`] if the CSRF cookie was missing
|
||||
/// - [`super::cookies::CookieDecryptionError`] if the cookie failed to decrypt
|
||||
///
|
||||
/// TODO: we might want to unify the last three rejections in one
|
||||
#[must_use]
|
||||
pub fn protected_form<T>(
|
||||
encrypter: &Encrypter,
|
||||
) -> impl Filter<Extract = (T,), Error = Rejection> + Clone + Send + Sync + 'static
|
||||
where
|
||||
T: DeserializeOwned + Send + 'static,
|
||||
{
|
||||
csrf_token(encrypter).and(warp::body::form()).and_then(
|
||||
|csrf_token: CsrfToken, protected_form: CsrfForm<T>| async move {
|
||||
let form = protected_form.verify_csrf(&csrf_token)?;
|
||||
Ok::<_, Rejection>(form)
|
||||
},
|
||||
)
|
||||
}
|
@ -1,61 +0,0 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Database-related filters to grab connections and start transactions from the
|
||||
//! connection pool
|
||||
|
||||
use std::convert::Infallible;
|
||||
|
||||
use sqlx::{
|
||||
pool::{Pool, PoolConnection},
|
||||
Database, Transaction,
|
||||
};
|
||||
use warp::{Filter, Rejection};
|
||||
|
||||
use crate::errors::WrapError;
|
||||
|
||||
fn with_pool<T: Database>(
|
||||
pool: &Pool<T>,
|
||||
) -> impl Filter<Extract = (Pool<T>,), Error = Infallible> + Clone + Send + Sync + 'static {
|
||||
let pool = pool.clone();
|
||||
warp::any().map(move || pool.clone())
|
||||
}
|
||||
|
||||
/// Acquire a connection to the database
|
||||
pub fn connection<T: Database>(
|
||||
pool: &Pool<T>,
|
||||
) -> impl Filter<Extract = (PoolConnection<T>,), Error = Rejection> + Clone + Send + Sync + 'static
|
||||
{
|
||||
with_pool(pool).and_then(acquire_connection)
|
||||
}
|
||||
|
||||
async fn acquire_connection<T: Database>(pool: Pool<T>) -> Result<PoolConnection<T>, Rejection> {
|
||||
let conn = pool.acquire().await.wrap_error()?;
|
||||
Ok(conn)
|
||||
}
|
||||
|
||||
/// Start a database transaction
|
||||
pub fn transaction<T: Database>(
|
||||
pool: &Pool<T>,
|
||||
) -> impl Filter<Extract = (Transaction<'static, T>,), Error = Rejection> + Clone + Send + Sync + 'static
|
||||
{
|
||||
with_pool(pool).and_then(acquire_transaction)
|
||||
}
|
||||
|
||||
async fn acquire_transaction<T: Database>(
|
||||
pool: Pool<T>,
|
||||
) -> Result<Transaction<'static, T>, Rejection> {
|
||||
let txn = pool.begin().await.wrap_error()?;
|
||||
Ok(txn)
|
||||
}
|
@ -1,43 +0,0 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Deal with typed headers from the [`headers`] crate
|
||||
|
||||
use headers::{Header, HeaderValue};
|
||||
use thiserror::Error;
|
||||
use warp::{reject::Reject, Filter, Rejection};
|
||||
|
||||
/// Failed to decode typed header
|
||||
#[derive(Debug, Error)]
|
||||
#[error("could not decode header {1}")]
|
||||
pub struct InvalidTypedHeader(#[source] headers::Error, &'static str);
|
||||
|
||||
impl Reject for InvalidTypedHeader {}
|
||||
|
||||
/// Extract a typed header from the request
|
||||
///
|
||||
/// # Rejections
|
||||
///
|
||||
/// This can reject with either a [`warp::reject::MissingHeader`] or a
|
||||
/// [`InvalidTypedHeader`].
|
||||
pub fn typed_header<T: Header + Send + 'static>(
|
||||
) -> impl Filter<Extract = (T,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
warp::header::value(T::name().as_str()).and_then(decode_typed_header)
|
||||
}
|
||||
|
||||
async fn decode_typed_header<T: Header>(header: HeaderValue) -> Result<T, Rejection> {
|
||||
let mut it = std::iter::once(&header);
|
||||
let decoded = T::decode(&mut it).map_err(|e| InvalidTypedHeader(e, T::name().as_str()))?;
|
||||
Ok(decoded)
|
||||
}
|
@ -1,72 +0,0 @@
|
||||
// Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Set of [`warp`] filters
|
||||
|
||||
#![allow(clippy::unused_async)] // Some warp filters need that
|
||||
#![deny(missing_docs)]
|
||||
|
||||
pub mod authenticate;
|
||||
pub mod client;
|
||||
pub mod cookies;
|
||||
pub mod cors;
|
||||
pub mod csrf;
|
||||
pub mod database;
|
||||
pub mod headers;
|
||||
pub mod session;
|
||||
pub mod trace;
|
||||
pub mod url_builder;
|
||||
|
||||
use std::convert::Infallible;
|
||||
|
||||
use mas_templates::Templates;
|
||||
use warp::{Filter, Rejection};
|
||||
|
||||
pub use self::csrf::CsrfToken;
|
||||
|
||||
/// Get the [`Templates`]
|
||||
#[must_use]
|
||||
pub fn with_templates(
|
||||
templates: &Templates,
|
||||
) -> impl Filter<Extract = (Templates,), Error = Infallible> + Clone + Send + Sync + 'static {
|
||||
let templates = templates.clone();
|
||||
warp::any().map(move || templates.clone())
|
||||
}
|
||||
|
||||
/// Recover a particular rejection type with a `None` option variant
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// extern crate warp;
|
||||
///
|
||||
/// use warp::{filters::header::header, reject::MissingHeader, Filter};
|
||||
///
|
||||
/// use mas_warp_utils::filters::none_on_error;
|
||||
///
|
||||
/// header("Content-Length")
|
||||
/// .map(Some)
|
||||
/// .recover(none_on_error::<_, MissingHeader>)
|
||||
/// .unify()
|
||||
/// .map(|length: Option<u64>| {
|
||||
/// format!("header: {:?}", length)
|
||||
/// });
|
||||
/// ```
|
||||
pub async fn none_on_error<T, E: 'static>(rejection: Rejection) -> Result<Option<T>, Rejection> {
|
||||
if rejection.find::<E>().is_some() {
|
||||
Ok(None)
|
||||
} else {
|
||||
Err(rejection)
|
||||
}
|
||||
}
|
@ -1,162 +0,0 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Load user sessions from the database
|
||||
|
||||
use mas_config::Encrypter;
|
||||
use mas_data_model::BrowserSession;
|
||||
use mas_storage::{
|
||||
user::{lookup_active_session, ActiveSessionLookupError},
|
||||
PostgresqlBackend,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{pool::PoolConnection, Executor, PgPool, Postgres};
|
||||
use thiserror::Error;
|
||||
use tracing::warn;
|
||||
use warp::{
|
||||
reject::{InvalidHeader, MissingCookie, Reject},
|
||||
Filter, Rejection,
|
||||
};
|
||||
|
||||
use super::{
|
||||
cookies::{encrypted, CookieDecryptionError, EncryptableCookieValue},
|
||||
database::connection,
|
||||
none_on_error,
|
||||
};
|
||||
|
||||
/// The session is missing or failed to load
|
||||
#[derive(Error, Debug)]
|
||||
pub enum SessionLoadError {
|
||||
/// No session cookie was found
|
||||
#[error("missing session cookie")]
|
||||
MissingCookie,
|
||||
|
||||
/// The session cookie is invalid
|
||||
#[error("unable to parse or decrypt session cookie")]
|
||||
InvalidCookie,
|
||||
|
||||
/// The session is unknown or inactive
|
||||
#[error("unknown or inactive session")]
|
||||
UnknownSession,
|
||||
}
|
||||
|
||||
impl Reject for SessionLoadError {}
|
||||
|
||||
/// An encrypted cookie to save the session ID
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct SessionCookie {
|
||||
current: i64,
|
||||
}
|
||||
|
||||
impl SessionCookie {
|
||||
/// Forge the cookie from a [`BrowserSession`]
|
||||
#[must_use]
|
||||
pub fn from_session(session: &BrowserSession<PostgresqlBackend>) -> Self {
|
||||
Self {
|
||||
current: session.data,
|
||||
}
|
||||
}
|
||||
|
||||
/// Load the [`BrowserSession`] from database
|
||||
pub async fn load_session(
|
||||
&self,
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
) -> Result<BrowserSession<PostgresqlBackend>, ActiveSessionLookupError> {
|
||||
let res = lookup_active_session(executor, self.current).await?;
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
impl EncryptableCookieValue for SessionCookie {
|
||||
fn cookie_key() -> &'static str {
|
||||
"session"
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract a user session information if logged in
|
||||
#[must_use]
|
||||
pub fn optional_session(
|
||||
pool: &PgPool,
|
||||
encrypter: &Encrypter,
|
||||
) -> impl Filter<Extract = (Option<BrowserSession<PostgresqlBackend>>,), Error = Rejection>
|
||||
+ Clone
|
||||
+ Send
|
||||
+ Sync
|
||||
+ 'static {
|
||||
session(pool, encrypter)
|
||||
.map(Some)
|
||||
.recover(none_on_error::<_, SessionLoadError>)
|
||||
.unify()
|
||||
}
|
||||
|
||||
/// Extract a user session information, rejecting if not logged in
|
||||
///
|
||||
/// # Rejections
|
||||
///
|
||||
/// This filter will reject with a [`SessionLoadError`] when the session is
|
||||
/// inactive or missing. It will reject with a wrapped error on other database
|
||||
/// failures.
|
||||
#[must_use]
|
||||
pub fn session(
|
||||
pool: &PgPool,
|
||||
encrypter: &Encrypter,
|
||||
) -> impl Filter<Extract = (BrowserSession<PostgresqlBackend>,), Error = Rejection>
|
||||
+ Clone
|
||||
+ Send
|
||||
+ Sync
|
||||
+ 'static {
|
||||
encrypted(encrypter)
|
||||
.and(connection(pool))
|
||||
.and_then(load_session)
|
||||
.recover(recover)
|
||||
.unify()
|
||||
}
|
||||
|
||||
async fn load_session(
|
||||
session: SessionCookie,
|
||||
mut conn: PoolConnection<Postgres>,
|
||||
) -> Result<BrowserSession<PostgresqlBackend>, Rejection> {
|
||||
let session_info = session.load_session(&mut conn).await?;
|
||||
Ok(session_info)
|
||||
}
|
||||
|
||||
/// Recover from expected rejections, to transform them into a
|
||||
/// [`SessionLoadError`]
|
||||
async fn recover<T>(rejection: Rejection) -> Result<T, Rejection> {
|
||||
if let Some(e) = rejection.find::<ActiveSessionLookupError>() {
|
||||
if e.not_found() {
|
||||
return Err(warp::reject::custom(SessionLoadError::UnknownSession));
|
||||
}
|
||||
|
||||
// If we're here, there is a real database error that should be
|
||||
// propagated
|
||||
}
|
||||
|
||||
if let Some(e) = rejection.find::<InvalidHeader>() {
|
||||
if e.name() == "cookie" {
|
||||
return Err(warp::reject::custom(SessionLoadError::MissingCookie));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(_e) = rejection.find::<MissingCookie>() {
|
||||
return Err(warp::reject::custom(SessionLoadError::MissingCookie));
|
||||
}
|
||||
|
||||
if let Some(error) = rejection.find::<CookieDecryptionError<SessionCookie>>() {
|
||||
warn!(?error, "could not decrypt session cookie");
|
||||
return Err(warp::reject::custom(SessionLoadError::InvalidCookie));
|
||||
}
|
||||
|
||||
Err(rejection)
|
||||
}
|
@ -1,35 +0,0 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Route tracing utility
|
||||
|
||||
use std::convert::Infallible;
|
||||
|
||||
use warp::Filter;
|
||||
|
||||
/// Set the name of that route
|
||||
#[must_use]
|
||||
pub fn name(
|
||||
name: &'static str,
|
||||
) -> impl Filter<Extract = (), Error = Infallible> + Clone + Send + Sync + 'static {
|
||||
warp::any()
|
||||
.map(move || {
|
||||
// TODO: update_name has a weird signature, which is already fixed in
|
||||
// opentelemetry-rust, just not released yet
|
||||
// TODO: we should find another way to classify requests. Span::update_name has
|
||||
// impacts on sampling and should not be used
|
||||
opentelemetry::trace::get_active_span(|s| s.update_name::<String>(name.to_string()));
|
||||
})
|
||||
.untuple_one()
|
||||
}
|
@ -1,121 +0,0 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Utility to build URLs
|
||||
|
||||
// TODO: move this somewhere else
|
||||
|
||||
use std::convert::Infallible;
|
||||
|
||||
use mas_config::HttpConfig;
|
||||
use url::Url;
|
||||
use warp::Filter;
|
||||
|
||||
impl From<&HttpConfig> for UrlBuilder {
|
||||
fn from(config: &HttpConfig) -> Self {
|
||||
Self::new(config.public_base.clone())
|
||||
}
|
||||
}
|
||||
|
||||
/// Helps building absolute URLs
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct UrlBuilder {
|
||||
base: Url,
|
||||
}
|
||||
|
||||
impl UrlBuilder {
|
||||
/// Create a new [`UrlBuilder`] from a base URL
|
||||
#[must_use]
|
||||
pub fn new(base: Url) -> Self {
|
||||
Self { base }
|
||||
}
|
||||
|
||||
/// OIDC issuer
|
||||
#[must_use]
|
||||
pub fn oidc_issuer(&self) -> Url {
|
||||
self.base.clone()
|
||||
}
|
||||
|
||||
/// OIDC dicovery document URL
|
||||
#[must_use]
|
||||
pub fn oidc_discovery(&self) -> Url {
|
||||
self.base
|
||||
.join(".well-known/openid-configuration")
|
||||
.expect("build URL")
|
||||
}
|
||||
|
||||
/// OAuth 2.0 authorization endpoint
|
||||
#[must_use]
|
||||
pub fn oauth_authorization_endpoint(&self) -> Url {
|
||||
self.base.join("oauth2/authorize").expect("build URL")
|
||||
}
|
||||
|
||||
/// OAuth 2.0 token endpoint
|
||||
#[must_use]
|
||||
pub fn oauth_token_endpoint(&self) -> Url {
|
||||
self.base.join("oauth2/token").expect("build URL")
|
||||
}
|
||||
|
||||
/// OAuth 2.0 introspection endpoint
|
||||
#[must_use]
|
||||
pub fn oauth_introspection_endpoint(&self) -> Url {
|
||||
self.base.join("oauth2/introspect").expect("build URL")
|
||||
}
|
||||
|
||||
/// OAuth 2.0 introspection endpoint
|
||||
#[must_use]
|
||||
pub fn oidc_userinfo_endpoint(&self) -> Url {
|
||||
self.base.join("oauth2/userinfo").expect("build URL")
|
||||
}
|
||||
|
||||
/// JWKS URI
|
||||
#[must_use]
|
||||
pub fn jwks_uri(&self) -> Url {
|
||||
self.base.join("oauth2/keys.json").expect("build URL")
|
||||
}
|
||||
|
||||
/// Email verification URL
|
||||
#[must_use]
|
||||
pub fn email_verification(&self, code: &str) -> Url {
|
||||
self.base
|
||||
.join("verify/")
|
||||
.expect("build URL")
|
||||
.join(code)
|
||||
.expect("build URL")
|
||||
}
|
||||
}
|
||||
|
||||
/// Injects an [`UrlBuilder`] to help building absolute URLs
|
||||
#[must_use]
|
||||
pub fn url_builder(
|
||||
config: &HttpConfig,
|
||||
) -> impl Filter<Extract = (UrlBuilder,), Error = Infallible> + Clone + Send + Sync + 'static {
|
||||
let builder: UrlBuilder = config.into();
|
||||
warp::any().map(move || builder.clone())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn build_email_verification_url() {
|
||||
let base = Url::parse("https://example.com/").unwrap();
|
||||
let builder = UrlBuilder::new(base);
|
||||
assert_eq!(
|
||||
builder.email_verification("123456abcdef").as_str(),
|
||||
"https://example.com/verify/123456abcdef"
|
||||
);
|
||||
}
|
||||
}
|
@ -1,24 +0,0 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Various warp filters and replies
|
||||
|
||||
#![forbid(unsafe_code)]
|
||||
#![deny(clippy::all, missing_docs, rustdoc::broken_intra_doc_links)]
|
||||
#![warn(clippy::pedantic)]
|
||||
#![allow(clippy::module_name_repetitions, clippy::missing_errors_doc)]
|
||||
|
||||
pub mod errors;
|
||||
pub mod filters;
|
||||
pub mod reply;
|
@ -1,54 +0,0 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Reply with a typed header from the [`headers`] crate.
|
||||
//!
|
||||
//! ```rust
|
||||
//! extern crate headers;
|
||||
//! extern crate warp;
|
||||
//!
|
||||
//! use warp::Reply;
|
||||
//! use mas_warp_utils::reply::with_typed_header;
|
||||
//!
|
||||
//! let reply = r#"{"hello": "world"}"#;
|
||||
//! let reply = with_typed_header(headers::ContentType::json(), reply);;
|
||||
//! let response = reply.into_response();
|
||||
//! assert_eq!(response.headers().get("Content-Type").unwrap().to_str().unwrap(), "application/json");
|
||||
//! ```
|
||||
|
||||
use headers::{Header, HeaderMapExt};
|
||||
use warp::Reply;
|
||||
|
||||
/// Add a typed header to a reply
|
||||
pub fn with_typed_header<R, H>(header: H, reply: R) -> WithTypedHeader<R, H> {
|
||||
WithTypedHeader { reply, header }
|
||||
}
|
||||
|
||||
/// A reply with a typed header set
|
||||
pub struct WithTypedHeader<R, H> {
|
||||
reply: R,
|
||||
header: H,
|
||||
}
|
||||
|
||||
impl<R, H> Reply for WithTypedHeader<R, H>
|
||||
where
|
||||
R: Reply,
|
||||
H: Header + Send,
|
||||
{
|
||||
fn into_response(self) -> warp::reply::Response {
|
||||
let mut res = self.reply.into_response();
|
||||
res.headers_mut().typed_insert(self.header);
|
||||
res
|
||||
}
|
||||
}
|
@ -1,21 +0,0 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Set of wrappers for [`warp::Reply`]
|
||||
|
||||
#![deny(missing_docs)]
|
||||
|
||||
pub mod headers;
|
||||
|
||||
pub use self::headers::{with_typed_header, WithTypedHeader};
|
@ -20,5 +20,5 @@
|
||||
|
||||
- [Architecture](./development/architecture.md)
|
||||
- [Database](./development/database.md)
|
||||
- [Routing with `warp`](./development/warp.md)
|
||||
- [Routing with `axum`]()
|
||||
- [Templates]()
|
||||
|
@ -20,7 +20,6 @@ This includes:
|
||||
- `mas-static-files`: Frontend static files (CSS/JS). Includes some frontend tooling
|
||||
- `mas-storage`: Interactions with the database
|
||||
- `mas-tasks`: Asynchronous task runner and scheduler
|
||||
- `mas-warp-utils`: Various filters and utilities for the `warp` web framework
|
||||
- `oauth2-types`: Useful structures and types to deal with OAuth 2.0/OpenID Connect endpoints. This might end up published as a standalone library as it can be useful in other contexts.
|
||||
|
||||
## Important crates
|
||||
@ -74,11 +73,6 @@ Both crates work well together and complement each other.
|
||||
Interactions with the database are done through [`sqlx`](https://github.com/launchbadge/sqlx), an async, pure-Rust SQL library with compile-time check of queries.
|
||||
It also handles schema migrations.
|
||||
|
||||
### Web framework: `warp`
|
||||
|
||||
[`warp`](https://docs.rs/warp/*/warp/) is an easy, macro-free web framework.
|
||||
Its composability makes a lot of sense when implementing OAuth 2.0 endpoints, because of the need to deal with a lot of different scenarios.
|
||||
|
||||
### Templates: `tera`
|
||||
|
||||
[Tera](https://tera.netlify.app/) was chosen as template engine for its simplicity as well as its ability to load templates at runtime.
|
||||
|
@ -1,93 +0,0 @@
|
||||
# `warp`
|
||||
|
||||
**Warning: this document is not up to date**
|
||||
|
||||
Warp has a pretty unique approach in terms of routing.
|
||||
It does not have a central router, rather a chain of filters composed together.
|
||||
|
||||
It encourages writing reusable filters to handle stuff like authentication, extracting user sessions, starting database transactions, etc.
|
||||
|
||||
Everything related to `warp` currently lives in the `mas-core` crate:
|
||||
|
||||
- `crates/core/src/`
|
||||
- `handlers/`: The actual handlers for each route
|
||||
- `oauth2/`: Everything related to OAuth 2.0/OIDC endpoints
|
||||
- `views/`: HTML views (login, registration, account management, etc.)
|
||||
- `filters/`: Reusable, composable filters
|
||||
- `reply/`: Composable replies
|
||||
|
||||
## Defining a new endpoint
|
||||
|
||||
We usually keep one endpoint per file and use module roots to combine the filters of endpoints.
|
||||
|
||||
This is how it looks like in the current hierarchy at time of writing:
|
||||
- `mod.rs`: combines the filters from `oauth2`, `views` and `health`
|
||||
- `oauth2/`
|
||||
- `mod.rs`: combines filters from `authorization`, `discovery`, etc.
|
||||
- `authorization.rs`: handles `GET /oauth2/authorize` and `GET /oauth2/authorize/step`
|
||||
- `discovery.rs`: handles `GET /.well-known/openid-configuration`
|
||||
- ...
|
||||
- `views/`
|
||||
- `mod.rs`: combines the filters from `index`, `login`, `logout`, etc.
|
||||
- `index.rs`: handles `GET /`
|
||||
- `login.rs`: handles `GET /login` and `POST /login`
|
||||
- `logout.rs`: handles `POST /logout`
|
||||
- ...
|
||||
- `health.rs`: handles `GET /health`
|
||||
|
||||
All filters are functions that take their dependencies (the database connection pool, the template engine, etc.) as parameters and return an `impl warp::Filter<Extract = (impl warp::Reply,)>`.
|
||||
|
||||
```rust
|
||||
// crates/core/src/handlers/hello.rs
|
||||
|
||||
// Don't be scared by the type at the end, just copy-paste it
|
||||
pub(super) fn filter(
|
||||
pool: &PgPool,
|
||||
templates: &Templates,
|
||||
cookies_config: &CookiesConfig,
|
||||
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
// Handles `GET /hello/:param`
|
||||
warp::path!("hello" / String)
|
||||
.and(warp::get())
|
||||
// Pass the template engine
|
||||
.and(with_templates(templates))
|
||||
// Extract the current user session
|
||||
.and(optional_session(pool, cookies_config))
|
||||
.and_then(get)
|
||||
}
|
||||
|
||||
async fn get(
|
||||
// Parameter from the route
|
||||
parameter: String,
|
||||
// Template engine
|
||||
templates: Templates,
|
||||
// The current user session
|
||||
session: Option<SessionInfo>,
|
||||
) -> Result<impl Reply, Rejection> {
|
||||
let ctx = SomeTemplateContext::new(parameter)
|
||||
.maybe_with_session(session);
|
||||
|
||||
let content = templates.render_something(&ctx)?;
|
||||
let reply = html(content);
|
||||
Ok(reply)
|
||||
}
|
||||
```
|
||||
|
||||
And then, it can be attached to the root handler:
|
||||
|
||||
```rust
|
||||
// crates/core/src/handlers/mod.rs
|
||||
|
||||
use self::{health::filter as health, oauth2::filter as oauth2, hello::filter as hello};
|
||||
|
||||
pub fn root(
|
||||
pool: &PgPool,
|
||||
templates: &Templates,
|
||||
config: &RootConfig,
|
||||
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
health(pool)
|
||||
.or(oauth2(pool, templates, &config.oauth2, &config.cookies))
|
||||
// Attach it here, passing the right dependencies
|
||||
.or(hello(pool, templates, &config.cookies))
|
||||
}
|
||||
```
|
Reference in New Issue
Block a user