You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-31 09:24:31 +03:00
Axum migration: signed cookies, errors, CSRF tokens, sessions
This commit is contained in:
89
Cargo.lock
generated
89
Cargo.lock
generated
@ -27,6 +27,32 @@ dependencies = [
|
||||
"rand_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aes"
|
||||
version = "0.7.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9e8b47f52ea9bae42228d07ec09eb676433d7c4ed1ebdf0f1d1c29ed446f1ab8"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cipher",
|
||||
"cpufeatures",
|
||||
"opaque-debug 0.3.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aes-gcm"
|
||||
version = "0.9.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "df5f85a83a7d8b0442b6aa7b504b8212c1733da07b98aae43d4bc21b2cb3cdf6"
|
||||
dependencies = [
|
||||
"aead",
|
||||
"aes",
|
||||
"cipher",
|
||||
"ctr",
|
||||
"ghash",
|
||||
"subtle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ahash"
|
||||
version = "0.7.6"
|
||||
@ -820,6 +846,14 @@ version = "0.16.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "94d4706de1b0fa5b132270cddffa8585166037822e260a944fe161acd137ca05"
|
||||
dependencies = [
|
||||
"aes-gcm",
|
||||
"base64",
|
||||
"hkdf",
|
||||
"hmac 0.12.1",
|
||||
"percent-encoding",
|
||||
"rand",
|
||||
"sha2 0.10.2",
|
||||
"subtle",
|
||||
"time 0.3.7",
|
||||
"version_check",
|
||||
]
|
||||
@ -966,6 +1000,15 @@ dependencies = [
|
||||
"sct 0.6.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ctr"
|
||||
version = "0.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "049bb91fb4aaf0e3c7efa6cd5ef877dbbbd15b39dad06d9948de4ec8a75761ea"
|
||||
dependencies = [
|
||||
"cipher",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling"
|
||||
version = "0.13.1"
|
||||
@ -1356,6 +1399,16 @@ dependencies = [
|
||||
"wasi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ghash"
|
||||
version = "0.4.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1583cc1656d7839fd3732b80cf4f38850336cdb9b8ded1cd399ca62958de3c99"
|
||||
dependencies = [
|
||||
"opaque-debug 0.3.0",
|
||||
"polyval",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gimli"
|
||||
version = "0.26.1"
|
||||
@ -1503,6 +1556,15 @@ version = "0.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
|
||||
|
||||
[[package]]
|
||||
name = "hkdf"
|
||||
version = "0.12.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "791a029f6b9fc27657f6f188ec6e5e43f6911f6f878e0dc5501396e09809d437"
|
||||
dependencies = [
|
||||
"hmac 0.12.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hmac"
|
||||
version = "0.11.0"
|
||||
@ -1870,10 +1932,23 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"axum",
|
||||
"bincode",
|
||||
"chrono",
|
||||
"cookie",
|
||||
"data-encoding",
|
||||
"futures-util",
|
||||
"headers",
|
||||
"http",
|
||||
"mas-data-model",
|
||||
"mas-storage",
|
||||
"mas-templates",
|
||||
"rand",
|
||||
"serde",
|
||||
"serde_with",
|
||||
"sqlx",
|
||||
"thiserror",
|
||||
"tracing",
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -1889,6 +1964,7 @@ dependencies = [
|
||||
"futures 0.3.21",
|
||||
"hyper",
|
||||
"indoc",
|
||||
"mas-axum-utils",
|
||||
"mas-config",
|
||||
"mas-email",
|
||||
"mas-handlers",
|
||||
@ -1926,6 +2002,7 @@ dependencies = [
|
||||
"async-trait",
|
||||
"chacha20poly1305",
|
||||
"chrono",
|
||||
"cookie",
|
||||
"elliptic-curve",
|
||||
"figment",
|
||||
"indoc",
|
||||
@ -2902,6 +2979,18 @@ dependencies = [
|
||||
"universal-hash",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "polyval"
|
||||
version = "0.5.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8419d2b623c7c0896ff2d5d96e2cb4ede590fed28fcc34934f4c33c036e620a1"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cpufeatures",
|
||||
"opaque-debug 0.3.0",
|
||||
"universal-hash",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ppv-lite86"
|
||||
version = "0.2.16"
|
||||
|
28
crates/axum-utils/Cargo.toml
Normal file
28
crates/axum-utils/Cargo.toml
Normal file
@ -0,0 +1,28 @@
|
||||
[package]
|
||||
name = "mas-axum-utils"
|
||||
version = "0.1.0"
|
||||
authors = ["Quentin Gliech <quenting@element.io>"]
|
||||
edition = "2021"
|
||||
license = "Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
async-trait = "0.1.52"
|
||||
axum = "0.4.8"
|
||||
bincode = "1.3.3"
|
||||
chrono = "0.4.19"
|
||||
cookie = { version = "0.16.0", features = ["signed", "private", "percent-encode"] }
|
||||
data-encoding = "2.3.2"
|
||||
futures-util = "0.3.21"
|
||||
headers = "0.3.7"
|
||||
http = "0.2.6"
|
||||
rand = "0.8.5"
|
||||
serde = "1.0.136"
|
||||
serde_with = "1.12.0"
|
||||
sqlx = "0.5.11"
|
||||
thiserror = "1.0.30"
|
||||
tracing = "0.1.32"
|
||||
url = "2.2.2"
|
||||
|
||||
mas-templates = { path = "../templates" }
|
||||
mas-storage = { path = "../storage" }
|
||||
mas-data-model = { path = "../data-model" }
|
159
crates/axum-utils/src/cookies.rs
Normal file
159
crates/axum-utils/src/cookies.rs
Normal file
@ -0,0 +1,159 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Private (encrypted) cookie jar, based on axum-extra's cookie jar
|
||||
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use axum::extract::{Extension, FromRequest, RequestParts};
|
||||
pub use cookie::Cookie;
|
||||
use data_encoding::BASE64URL_NOPAD;
|
||||
use headers::HeaderMap;
|
||||
use http::header::{COOKIE, SET_COOKIE};
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use thiserror::Error;
|
||||
|
||||
pub struct PrivateCookieJar<K = cookie::Key> {
|
||||
jar: cookie::CookieJar,
|
||||
key: cookie::Key,
|
||||
_marker: PhantomData<K>,
|
||||
}
|
||||
|
||||
impl<K> PrivateCookieJar<K> {
|
||||
pub fn get(&self, name: &str) -> Option<Cookie<'static>> {
|
||||
self.private_jar().get(name)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn remove(mut self, cookie: Cookie<'static>) -> Self {
|
||||
self.private_jar_mut().remove(cookie);
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
#[allow(clippy::should_implement_trait)]
|
||||
pub fn add(mut self, cookie: Cookie<'static>) -> Self {
|
||||
self.private_jar_mut().add(cookie);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn decrypt(&self, cookie: Cookie<'static>) -> Option<Cookie<'static>> {
|
||||
self.private_jar().decrypt(cookie)
|
||||
}
|
||||
|
||||
fn private_jar(&self) -> cookie::PrivateJar<&'_ cookie::CookieJar> {
|
||||
self.jar.private(&self.key)
|
||||
}
|
||||
|
||||
fn private_jar_mut(&mut self) -> cookie::PrivateJar<&'_ mut cookie::CookieJar> {
|
||||
self.jar.private_mut(&self.key)
|
||||
}
|
||||
|
||||
pub fn set_cookies(self, headers: &mut HeaderMap) {
|
||||
for cookie in self.jar.delta() {
|
||||
if let Ok(header_value) = cookie.encoded().to_string().parse() {
|
||||
headers.append(SET_COOKIE, header_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn headers(self) -> HeaderMap {
|
||||
let mut headers = HeaderMap::new();
|
||||
self.set_cookies(&mut headers);
|
||||
headers
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<B, K> FromRequest<B> for PrivateCookieJar<K>
|
||||
where
|
||||
B: Send,
|
||||
K: Into<cookie::Key> + Clone + Send + Sync + 'static,
|
||||
{
|
||||
type Rejection = <Extension<K> as FromRequest<B>>::Rejection;
|
||||
|
||||
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
||||
let Extension(key): Extension<K> = Extension::from_request(req).await?;
|
||||
let key = key.into();
|
||||
|
||||
let mut jar = cookie::CookieJar::new();
|
||||
let mut private_jar = jar.private_mut(&key);
|
||||
|
||||
// TODO: remove this when axum 0.5 gets released
|
||||
// https://github.com/tokio-rs/axum/pull/698
|
||||
let empty_headers = HeaderMap::new();
|
||||
|
||||
let cookies = req
|
||||
.headers()
|
||||
.unwrap_or(&empty_headers)
|
||||
.get_all(COOKIE)
|
||||
.into_iter()
|
||||
.filter_map(|value| value.to_str().ok())
|
||||
.flat_map(|value| value.split(';'))
|
||||
.filter_map(|cookie| Cookie::parse_encoded(cookie.to_owned()).ok());
|
||||
|
||||
for cookie in cookies {
|
||||
if let Some(cookie) = private_jar.decrypt(cookie) {
|
||||
private_jar.add_original(cookie);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
jar,
|
||||
key,
|
||||
_marker: PhantomData,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("could not decode cookie")]
|
||||
pub enum CookieDecodeError {
|
||||
Deserialize(#[from] bincode::Error),
|
||||
Decode(#[from] data_encoding::DecodeError),
|
||||
}
|
||||
|
||||
pub trait CookieExt {
|
||||
fn decode<T>(&self) -> Result<T, CookieDecodeError>
|
||||
where
|
||||
T: DeserializeOwned;
|
||||
|
||||
fn encode<T>(self, t: &T) -> Self
|
||||
where
|
||||
T: Serialize;
|
||||
}
|
||||
|
||||
impl<'a> CookieExt for Cookie<'a> {
|
||||
fn decode<T>(&self) -> Result<T, CookieDecodeError>
|
||||
where
|
||||
T: DeserializeOwned,
|
||||
{
|
||||
let bytes = BASE64URL_NOPAD.decode(self.value().as_bytes())?;
|
||||
|
||||
let decoded = bincode::deserialize(&bytes)?;
|
||||
|
||||
Ok(decoded)
|
||||
}
|
||||
|
||||
fn encode<T>(mut self, t: &T) -> Self
|
||||
where
|
||||
T: Serialize,
|
||||
{
|
||||
let bytes = bincode::serialize(t).unwrap();
|
||||
let encoded = BASE64URL_NOPAD.encode(&bytes);
|
||||
self.set_value(encoded);
|
||||
self
|
||||
}
|
||||
}
|
111
crates/axum-utils/src/csrf.rs
Normal file
111
crates/axum-utils/src/csrf.rs
Normal file
@ -0,0 +1,111 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use cookie::Cookie;
|
||||
use data_encoding::{DecodeError, BASE64URL_NOPAD};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{serde_as, TimestampSeconds};
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::{CookieExt, PrivateCookieJar};
|
||||
|
||||
/// Failed to validate CSRF token
|
||||
#[derive(Debug, Error)]
|
||||
pub enum CsrfError {
|
||||
/// The token in the form did not match the token in the cookie
|
||||
#[error("CSRF token mismatch")]
|
||||
Mismatch,
|
||||
|
||||
/// The token expired
|
||||
#[error("CSRF token expired")]
|
||||
Expired,
|
||||
|
||||
/// Failed to decode the token
|
||||
#[error("could not decode CSRF token")]
|
||||
Decode(#[from] DecodeError),
|
||||
}
|
||||
|
||||
/// A CSRF token
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct CsrfToken {
|
||||
#[serde_as(as = "TimestampSeconds<i64>")]
|
||||
expiration: DateTime<Utc>,
|
||||
token: [u8; 32],
|
||||
}
|
||||
|
||||
impl CsrfToken {
|
||||
/// Create a new token from a defined value valid for a specified duration
|
||||
fn new(token: [u8; 32], ttl: Duration) -> Self {
|
||||
let expiration = Utc::now() + ttl;
|
||||
Self { expiration, token }
|
||||
}
|
||||
|
||||
/// Generate a new random token valid for a specified duration
|
||||
fn generate(ttl: Duration) -> Self {
|
||||
let token = rand::random();
|
||||
Self::new(token, ttl)
|
||||
}
|
||||
|
||||
/// Generate a new token with the same value but an up to date expiration
|
||||
fn refresh(self, ttl: Duration) -> Self {
|
||||
Self::new(self.token, ttl)
|
||||
}
|
||||
|
||||
/// Get the value to include in HTML forms
|
||||
#[must_use]
|
||||
pub fn form_value(&self) -> String {
|
||||
BASE64URL_NOPAD.encode(&self.token[..])
|
||||
}
|
||||
|
||||
/// Verifies that the value got from an HTML form matches this token
|
||||
pub fn verify_form_value(&self, form_value: &str) -> Result<(), CsrfError> {
|
||||
let form_value = BASE64URL_NOPAD.decode(form_value.as_bytes())?;
|
||||
if self.token[..] == form_value {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(CsrfError::Mismatch)
|
||||
}
|
||||
}
|
||||
|
||||
fn verify_expiration(self) -> Result<Self, CsrfError> {
|
||||
if Utc::now() < self.expiration {
|
||||
Ok(self)
|
||||
} else {
|
||||
Err(CsrfError::Expired)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait CsrfExt {
|
||||
fn csrf_token(self) -> (CsrfToken, Self);
|
||||
}
|
||||
|
||||
impl<K> CsrfExt for PrivateCookieJar<K> {
|
||||
fn csrf_token(self) -> (CsrfToken, Self) {
|
||||
let jar = self;
|
||||
let cookie = jar.get("csrf").unwrap_or_else(|| Cookie::new("csrf", ""));
|
||||
let new_token = cookie
|
||||
.decode()
|
||||
.ok()
|
||||
.and_then(|token: CsrfToken| token.verify_expiration().ok())
|
||||
.unwrap_or_else(|| CsrfToken::generate(Duration::hours(1)))
|
||||
.refresh(Duration::hours(1));
|
||||
|
||||
let cookie = cookie.encode(&new_token);
|
||||
let jar = jar.add(cookie);
|
||||
(new_token, jar)
|
||||
}
|
||||
}
|
102
crates/axum-utils/src/fancy_error.rs
Normal file
102
crates/axum-utils/src/fancy_error.rs
Normal file
@ -0,0 +1,102 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{convert::Infallible, error::Error};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
body::{HttpBody, StreamBody},
|
||||
extract::{Extension, FromRequest, RequestParts},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use futures_util::FutureExt;
|
||||
use headers::{ContentType, HeaderMapExt};
|
||||
use mas_templates::{ErrorContext, Templates};
|
||||
use sqlx::PgPool;
|
||||
|
||||
struct DatabaseConnection(sqlx::pool::PoolConnection<sqlx::Postgres>);
|
||||
|
||||
#[async_trait]
|
||||
impl<B> FromRequest<B> for DatabaseConnection
|
||||
where
|
||||
B: Send,
|
||||
{
|
||||
type Rejection = FancyError;
|
||||
|
||||
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
||||
let Extension(templates) = Extension::<Templates>::from_request(req)
|
||||
.await
|
||||
.map_err(internal_error)?;
|
||||
|
||||
let Extension(pool) = Extension::<PgPool>::from_request(req)
|
||||
.await
|
||||
.map_err(fancy_error(templates))?;
|
||||
|
||||
let conn = pool.acquire().await.map_err(internal_error)?;
|
||||
|
||||
Ok(Self(conn))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fancy_error<E: Error + 'static>(templates: Templates) -> impl Fn(E) -> FancyError {
|
||||
move |error: E| FancyError {
|
||||
templates: Some(templates.clone()),
|
||||
error: Box::new(error),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn internal_error<E: Error + 'static>(error: E) -> FancyError
|
||||
where
|
||||
E: Error,
|
||||
{
|
||||
FancyError {
|
||||
templates: None,
|
||||
error: Box::new(error),
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FancyError {
|
||||
templates: Option<Templates>,
|
||||
error: Box<dyn Error>,
|
||||
}
|
||||
|
||||
impl IntoResponse for FancyError {
|
||||
fn into_response(self) -> Response {
|
||||
let error = format!("{}", self.error);
|
||||
let context = ErrorContext::new().with_description(error.clone());
|
||||
let body = match self.templates {
|
||||
Some(templates) => {
|
||||
let stream = (async move {
|
||||
Ok::<_, Infallible>(match templates.render_error(&context).await {
|
||||
Ok(s) => s,
|
||||
Err(_e) => "failed to render error template".to_string(),
|
||||
})
|
||||
})
|
||||
.into_stream();
|
||||
|
||||
StreamBody::new(stream).boxed_unsync()
|
||||
}
|
||||
None => axum::body::Full::from(error)
|
||||
.map_err(|_e| unreachable!())
|
||||
.boxed_unsync(),
|
||||
};
|
||||
|
||||
let mut res = Response::new(body);
|
||||
*res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
res.headers_mut().typed_insert(ContentType::html());
|
||||
res
|
||||
}
|
||||
}
|
||||
|
26
crates/axum-utils/src/lib.rs
Normal file
26
crates/axum-utils/src/lib.rs
Normal file
@ -0,0 +1,26 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
pub mod cookies;
|
||||
pub mod csrf;
|
||||
pub mod fancy_error;
|
||||
pub mod session;
|
||||
pub mod url_builder;
|
||||
|
||||
pub use self::{
|
||||
cookies::{Cookie, CookieExt, PrivateCookieJar},
|
||||
fancy_error::{fancy_error, internal_error, FancyError},
|
||||
session::{SessionInfo, SessionInfoExt},
|
||||
url_builder::UrlBuilder,
|
||||
};
|
87
crates/axum-utils/src/session.rs
Normal file
87
crates/axum-utils/src/session.rs
Normal file
@ -0,0 +1,87 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use cookie::Cookie;
|
||||
use mas_data_model::BrowserSession;
|
||||
use mas_storage::{
|
||||
user::{lookup_active_session, ActiveSessionLookupError},
|
||||
PostgresqlBackend,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{Executor, Postgres};
|
||||
|
||||
use crate::{CookieExt, PrivateCookieJar};
|
||||
|
||||
/// An encrypted cookie to save the session ID
|
||||
#[derive(Serialize, Deserialize, Debug, Default)]
|
||||
pub struct SessionInfo {
|
||||
current: Option<i64>,
|
||||
}
|
||||
|
||||
impl SessionInfo {
|
||||
/// Forge the cookie from a [`BrowserSession`]
|
||||
#[must_use]
|
||||
pub fn from_session(session: &BrowserSession<PostgresqlBackend>) -> Self {
|
||||
Self {
|
||||
current: Some(session.data),
|
||||
}
|
||||
}
|
||||
|
||||
/// Load the [`BrowserSession`] from database
|
||||
pub async fn load_session(
|
||||
&self,
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
) -> Result<Option<BrowserSession<PostgresqlBackend>>, ActiveSessionLookupError> {
|
||||
let session_id = if let Some(id) = self.current {
|
||||
id
|
||||
} else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let res = lookup_active_session(executor, session_id).await?;
|
||||
Ok(Some(res))
|
||||
}
|
||||
}
|
||||
|
||||
pub trait SessionInfoExt {
|
||||
fn session_info(self) -> (SessionInfo, Self);
|
||||
fn update_session_info(self, info: &SessionInfo) -> Self;
|
||||
fn set_session(self, session: &BrowserSession<PostgresqlBackend>) -> Self
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
let session_info = SessionInfo::from_session(session);
|
||||
self.update_session_info(&session_info)
|
||||
}
|
||||
}
|
||||
|
||||
impl<K> SessionInfoExt for PrivateCookieJar<K> {
|
||||
fn session_info(self) -> (SessionInfo, Self) {
|
||||
let jar = self;
|
||||
let cookie = jar
|
||||
.get("session")
|
||||
.unwrap_or_else(|| Cookie::new("session", ""));
|
||||
let session_info = cookie.decode().unwrap_or_default();
|
||||
|
||||
let cookie = cookie.encode(&session_info);
|
||||
let jar = jar.add(cookie);
|
||||
(session_info, jar)
|
||||
}
|
||||
|
||||
fn update_session_info(self, info: &SessionInfo) -> Self {
|
||||
let cookie = Cookie::new("session", "");
|
||||
let cookie = cookie.encode(&info);
|
||||
self.add(cookie)
|
||||
}
|
||||
}
|
100
crates/axum-utils/src/url_builder.rs
Normal file
100
crates/axum-utils/src/url_builder.rs
Normal file
@ -0,0 +1,100 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Utility to build URLs
|
||||
|
||||
use url::Url;
|
||||
|
||||
/// Helps building absolute URLs
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct UrlBuilder {
|
||||
base: Url,
|
||||
}
|
||||
|
||||
impl UrlBuilder {
|
||||
/// Create a new [`UrlBuilder`] from a base URL
|
||||
#[must_use]
|
||||
pub fn new(base: Url) -> Self {
|
||||
Self { base }
|
||||
}
|
||||
|
||||
/// OIDC issuer
|
||||
#[must_use]
|
||||
pub fn oidc_issuer(&self) -> Url {
|
||||
self.base.clone()
|
||||
}
|
||||
|
||||
/// OIDC dicovery document URL
|
||||
#[must_use]
|
||||
pub fn oidc_discovery(&self) -> Url {
|
||||
self.base
|
||||
.join(".well-known/openid-configuration")
|
||||
.expect("build URL")
|
||||
}
|
||||
|
||||
/// OAuth 2.0 authorization endpoint
|
||||
#[must_use]
|
||||
pub fn oauth_authorization_endpoint(&self) -> Url {
|
||||
self.base.join("oauth2/authorize").expect("build URL")
|
||||
}
|
||||
|
||||
/// OAuth 2.0 token endpoint
|
||||
#[must_use]
|
||||
pub fn oauth_token_endpoint(&self) -> Url {
|
||||
self.base.join("oauth2/token").expect("build URL")
|
||||
}
|
||||
|
||||
/// OAuth 2.0 introspection endpoint
|
||||
#[must_use]
|
||||
pub fn oauth_introspection_endpoint(&self) -> Url {
|
||||
self.base.join("oauth2/introspect").expect("build URL")
|
||||
}
|
||||
|
||||
/// OAuth 2.0 introspection endpoint
|
||||
#[must_use]
|
||||
pub fn oidc_userinfo_endpoint(&self) -> Url {
|
||||
self.base.join("oauth2/userinfo").expect("build URL")
|
||||
}
|
||||
|
||||
/// JWKS URI
|
||||
#[must_use]
|
||||
pub fn jwks_uri(&self) -> Url {
|
||||
self.base.join("oauth2/keys.json").expect("build URL")
|
||||
}
|
||||
|
||||
/// Email verification URL
|
||||
#[must_use]
|
||||
pub fn email_verification(&self, code: &str) -> Url {
|
||||
self.base
|
||||
.join("verify/")
|
||||
.expect("build URL")
|
||||
.join(code)
|
||||
.expect("build URL")
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn build_email_verification_url() {
|
||||
let base = Url::parse("https://example.com/").unwrap();
|
||||
let builder = UrlBuilder::new(base);
|
||||
assert_eq!(
|
||||
builder.email_verification("123456abcdef").as_str(),
|
||||
"https://example.com/verify/123456abcdef"
|
||||
);
|
||||
}
|
||||
}
|
@ -43,6 +43,7 @@ mas-storage = { path = "../storage" }
|
||||
mas-tasks = { path = "../tasks" }
|
||||
mas-templates = { path = "../templates" }
|
||||
mas-warp-utils = { path = "../warp-utils" }
|
||||
mas-axum-utils = { path = "../axum-utils" }
|
||||
|
||||
[dev-dependencies]
|
||||
indoc = "1.0.4"
|
||||
|
@ -22,12 +22,12 @@ use anyhow::Context;
|
||||
use clap::Parser;
|
||||
use futures::{future::TryFutureExt, stream::TryStreamExt};
|
||||
use hyper::Server;
|
||||
use mas_axum_utils::UrlBuilder;
|
||||
use mas_config::RootConfig;
|
||||
use mas_email::{MailTransport, Mailer};
|
||||
use mas_storage::MIGRATOR;
|
||||
use mas_tasks::TaskQueue;
|
||||
use mas_templates::Templates;
|
||||
use tower::{make::Shared, Layer};
|
||||
use tracing::{error, info};
|
||||
|
||||
#[derive(Parser, Debug, Default)]
|
||||
@ -191,6 +191,11 @@ impl Options {
|
||||
&config.email.reply_to,
|
||||
);
|
||||
|
||||
let url_builder = UrlBuilder::new(config.http.public_base.clone());
|
||||
|
||||
// Explicitely the config to properly zeroize secret keys
|
||||
drop(config);
|
||||
|
||||
// Watch for changes in templates if the --watch flag is present
|
||||
if self.watch {
|
||||
let client = watchman_client::Connector::new()
|
||||
@ -203,11 +208,14 @@ impl Options {
|
||||
.context("could not watch for templates changes")?;
|
||||
}
|
||||
|
||||
let router =
|
||||
mas_handlers::router(&pool, &templates, &key_store, &encrypter, &mailer, &config);
|
||||
|
||||
// Explicitely the config to properly zeroize secret keys
|
||||
drop(config);
|
||||
let router = mas_handlers::router(
|
||||
&pool,
|
||||
&templates,
|
||||
&key_store,
|
||||
&encrypter,
|
||||
&mailer,
|
||||
&url_builder,
|
||||
);
|
||||
|
||||
info!("Listening on http://{}", listener.local_addr().unwrap());
|
||||
|
||||
|
@ -31,6 +31,7 @@ pkcs8 = { version = "0.8.0", features = ["pem"] }
|
||||
chacha20poly1305 = { version = "0.9.0", features = ["std"] }
|
||||
elliptic-curve = { version = "0.11.12", features = ["pem", "pkcs8"] }
|
||||
pem-rfc7468 = "0.3.1"
|
||||
cookie = { version = "0.16.0", features = ["private", "key-expansion"] }
|
||||
|
||||
indoc = "1.0.4"
|
||||
|
||||
|
@ -20,6 +20,7 @@ use chacha20poly1305::{
|
||||
aead::{generic_array::GenericArray, Aead, NewAead},
|
||||
ChaCha20Poly1305,
|
||||
};
|
||||
use cookie::Key;
|
||||
use mas_jose::StaticKeystore;
|
||||
use pkcs8::DecodePrivateKey;
|
||||
use rsa::{
|
||||
@ -37,6 +38,7 @@ use super::ConfigurationSection;
|
||||
/// Helps encrypting and decrypting data
|
||||
#[derive(Clone)]
|
||||
pub struct Encrypter {
|
||||
cookie_key: Arc<Key>,
|
||||
aead: Arc<ChaCha20Poly1305>,
|
||||
}
|
||||
|
||||
@ -44,10 +46,12 @@ impl Encrypter {
|
||||
/// Creates an [`Encrypter`] out of an encryption key
|
||||
#[must_use]
|
||||
pub fn new(key: &[u8; 32]) -> Self {
|
||||
let cookie_key = Key::derive_from(&key[..]);
|
||||
let cookie_key = Arc::new(cookie_key);
|
||||
let key = GenericArray::from_slice(key);
|
||||
let aead = ChaCha20Poly1305::new(key);
|
||||
let aead = Arc::new(aead);
|
||||
Self { aead }
|
||||
Self { cookie_key, aead }
|
||||
}
|
||||
|
||||
/// Encrypt a payload
|
||||
@ -73,6 +77,12 @@ impl Encrypter {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Encrypter> for Key {
|
||||
fn from(e: Encrypter) -> Self {
|
||||
e.cookie_key.as_ref().clone()
|
||||
}
|
||||
}
|
||||
|
||||
fn example_secret() -> &'static str {
|
||||
"0000111122223333444455556666777788889999aaaabbbbccccddddeeeeffff"
|
||||
}
|
||||
|
@ -22,6 +22,7 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::{extract::Extension, routing::get, Router};
|
||||
use mas_axum_utils::UrlBuilder;
|
||||
use mas_config::{Encrypter, RootConfig};
|
||||
use mas_email::Mailer;
|
||||
use mas_jose::StaticKeystore;
|
||||
@ -60,13 +61,14 @@ pub fn root(
|
||||
filter.with(warp::log(module_path!())).boxed()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn router<B: Send + 'static>(
|
||||
pool: &PgPool,
|
||||
templates: &Templates,
|
||||
key_store: &Arc<StaticKeystore>,
|
||||
encrypter: &Encrypter,
|
||||
mailer: &Mailer,
|
||||
config: &RootConfig,
|
||||
url_builder: &UrlBuilder,
|
||||
) -> Router<B> {
|
||||
Router::new()
|
||||
.route("/", get(self::views::index::get))
|
||||
@ -75,5 +77,6 @@ pub fn router<B: Send + 'static>(
|
||||
.layer(Extension(templates.clone()))
|
||||
.layer(Extension(key_store.clone()))
|
||||
.layer(Extension(encrypter.clone()))
|
||||
.layer(Extension(url_builder.clone()))
|
||||
.layer(Extension(mailer.clone()))
|
||||
}
|
||||
|
@ -12,79 +12,43 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::str::FromStr;
|
||||
|
||||
use axum::{
|
||||
extract::Extension,
|
||||
response::{Html, IntoResponse},
|
||||
};
|
||||
use mas_axum_utils::{fancy_error, FancyError};
|
||||
use mas_config::{CsrfConfig, Encrypter, HttpConfig};
|
||||
use mas_data_model::BrowserSession;
|
||||
use mas_storage::PostgresqlBackend;
|
||||
use mas_templates::{IndexContext, TemplateContext, Templates};
|
||||
use mas_warp_utils::filters::{
|
||||
self,
|
||||
cookies::{encrypted_cookie_saver, EncryptedCookieSaver},
|
||||
csrf::updated_csrf_token,
|
||||
session::optional_session,
|
||||
url_builder::{url_builder, UrlBuilder},
|
||||
with_templates, CsrfToken,
|
||||
use mas_axum_utils::{
|
||||
csrf::CsrfExt, fancy_error, FancyError, PrivateCookieJar, SessionInfoExt, UrlBuilder,
|
||||
};
|
||||
use mas_config::Encrypter;
|
||||
use mas_templates::{IndexContext, TemplateContext, Templates};
|
||||
use sqlx::PgPool;
|
||||
use url::Url;
|
||||
use warp::{filters::BoxedFilter, reply::html, Filter, Rejection, Reply};
|
||||
|
||||
/*
|
||||
pub(super) fn filter(
|
||||
pool: &PgPool,
|
||||
templates: &Templates,
|
||||
encrypter: &Encrypter,
|
||||
http_config: &HttpConfig,
|
||||
csrf_config: &CsrfConfig,
|
||||
) -> BoxedFilter<(Box<dyn Reply>,)> {
|
||||
warp::path::end()
|
||||
.and(filters::trace::name("GET /"))
|
||||
.and(warp::get())
|
||||
.and(url_builder(http_config))
|
||||
.and(with_templates(templates))
|
||||
.and(encrypted_cookie_saver(encrypter))
|
||||
.and(updated_csrf_token(encrypter, csrf_config))
|
||||
.and(optional_session(pool, encrypter))
|
||||
.and_then(get)
|
||||
.boxed()
|
||||
}
|
||||
|
||||
async fn get(
|
||||
url_builder: UrlBuilder,
|
||||
templates: Templates,
|
||||
cookie_saver: EncryptedCookieSaver,
|
||||
csrf_token: CsrfToken,
|
||||
maybe_session: Option<BrowserSession<PostgresqlBackend>>,
|
||||
) -> Result<Box<dyn Reply>, Rejection> {
|
||||
let ctx = IndexContext::new(url_builder.oidc_discovery())
|
||||
.maybe_with_session(maybe_session)
|
||||
.with_csrf(csrf_token.form_value());
|
||||
|
||||
let content = templates.render_index(&ctx).await?;
|
||||
let reply = html(content);
|
||||
let reply = cookie_saver.save_encrypted(&csrf_token, reply)?;
|
||||
Ok(Box::new(reply))
|
||||
}
|
||||
*/
|
||||
|
||||
pub async fn get(
|
||||
Extension(templates): Extension<Templates>,
|
||||
Extension(url_builder): Extension<UrlBuilder>,
|
||||
Extension(pool): Extension<PgPool>,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
) -> Result<impl IntoResponse, FancyError> {
|
||||
let ctx = IndexContext::new(
|
||||
Url::from_str("https://example.com/.well-known/openid-discovery").unwrap(),
|
||||
)
|
||||
.maybe_with_session::<PostgresqlBackend>(None)
|
||||
.with_csrf("csrf_token".to_string());
|
||||
let mut conn = pool
|
||||
.acquire()
|
||||
.await
|
||||
.map_err(fancy_error(templates.clone()))?;
|
||||
|
||||
let (csrf_token, cookie_jar) = cookie_jar.csrf_token();
|
||||
let (session_info, cookie_jar) = cookie_jar.session_info();
|
||||
let session = session_info
|
||||
.load_session(&mut conn)
|
||||
.await
|
||||
.map_err(fancy_error(templates.clone()))?;
|
||||
|
||||
let ctx = IndexContext::new(url_builder.oidc_discovery())
|
||||
.maybe_with_session(session)
|
||||
.with_csrf(csrf_token.form_value());
|
||||
|
||||
let content = templates
|
||||
.render_index(&ctx)
|
||||
.await
|
||||
.map_err(fancy_error(templates))?;
|
||||
Ok(Html(content))
|
||||
|
||||
Ok((cookie_jar.headers(), Html(content)))
|
||||
}
|
||||
|
@ -56,13 +56,14 @@ pub trait TemplateContext: Serialize {
|
||||
}
|
||||
|
||||
/// Attach a CSRF token to the template context
|
||||
fn with_csrf(self, csrf_token: String) -> WithCsrf<Self>
|
||||
fn with_csrf<C>(self, csrf_token: C) -> WithCsrf<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
C: ToString,
|
||||
{
|
||||
// TODO: make this method use a CsrfToken again
|
||||
WithCsrf {
|
||||
csrf_token,
|
||||
csrf_token: csrf_token.to_string(),
|
||||
inner: self,
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user