diff --git a/Cargo.lock b/Cargo.lock index cc73ddc4..121a294f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -27,6 +27,32 @@ dependencies = [ "rand_core", ] +[[package]] +name = "aes" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e8b47f52ea9bae42228d07ec09eb676433d7c4ed1ebdf0f1d1c29ed446f1ab8" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", + "opaque-debug 0.3.0", +] + +[[package]] +name = "aes-gcm" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df5f85a83a7d8b0442b6aa7b504b8212c1733da07b98aae43d4bc21b2cb3cdf6" +dependencies = [ + "aead", + "aes", + "cipher", + "ctr", + "ghash", + "subtle", +] + [[package]] name = "ahash" version = "0.7.6" @@ -820,6 +846,14 @@ version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94d4706de1b0fa5b132270cddffa8585166037822e260a944fe161acd137ca05" dependencies = [ + "aes-gcm", + "base64", + "hkdf", + "hmac 0.12.1", + "percent-encoding", + "rand", + "sha2 0.10.2", + "subtle", "time 0.3.7", "version_check", ] @@ -966,6 +1000,15 @@ dependencies = [ "sct 0.6.1", ] +[[package]] +name = "ctr" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "049bb91fb4aaf0e3c7efa6cd5ef877dbbbd15b39dad06d9948de4ec8a75761ea" +dependencies = [ + "cipher", +] + [[package]] name = "darling" version = "0.13.1" @@ -1356,6 +1399,16 @@ dependencies = [ "wasi", ] +[[package]] +name = "ghash" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1583cc1656d7839fd3732b80cf4f38850336cdb9b8ded1cd399ca62958de3c99" +dependencies = [ + "opaque-debug 0.3.0", + "polyval", +] + [[package]] name = "gimli" version = "0.26.1" @@ -1503,6 +1556,15 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "hkdf" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "791a029f6b9fc27657f6f188ec6e5e43f6911f6f878e0dc5501396e09809d437" +dependencies = [ + "hmac 0.12.1", +] + [[package]] name = "hmac" version = "0.11.0" @@ -1870,10 +1932,23 @@ version = "0.1.0" dependencies = [ "async-trait", "axum", + "bincode", + "chrono", + "cookie", + "data-encoding", "futures-util", "headers", + "http", + "mas-data-model", + "mas-storage", "mas-templates", + "rand", + "serde", + "serde_with", "sqlx", + "thiserror", + "tracing", + "url", ] [[package]] @@ -1889,6 +1964,7 @@ dependencies = [ "futures 0.3.21", "hyper", "indoc", + "mas-axum-utils", "mas-config", "mas-email", "mas-handlers", @@ -1926,6 +2002,7 @@ dependencies = [ "async-trait", "chacha20poly1305", "chrono", + "cookie", "elliptic-curve", "figment", "indoc", @@ -2902,6 +2979,18 @@ dependencies = [ "universal-hash", ] +[[package]] +name = "polyval" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8419d2b623c7c0896ff2d5d96e2cb4ede590fed28fcc34934f4c33c036e620a1" +dependencies = [ + "cfg-if", + "cpufeatures", + "opaque-debug 0.3.0", + "universal-hash", +] + [[package]] name = "ppv-lite86" version = "0.2.16" diff --git a/crates/axum-utils/Cargo.toml b/crates/axum-utils/Cargo.toml new file mode 100644 index 00000000..3054b720 --- /dev/null +++ b/crates/axum-utils/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "mas-axum-utils" +version = "0.1.0" +authors = ["Quentin Gliech "] +edition = "2021" +license = "Apache-2.0" + +[dependencies] +async-trait = "0.1.52" +axum = "0.4.8" +bincode = "1.3.3" +chrono = "0.4.19" +cookie = { version = "0.16.0", features = ["signed", "private", "percent-encode"] } +data-encoding = "2.3.2" +futures-util = "0.3.21" +headers = "0.3.7" +http = "0.2.6" +rand = "0.8.5" +serde = "1.0.136" +serde_with = "1.12.0" +sqlx = "0.5.11" +thiserror = "1.0.30" +tracing = "0.1.32" +url = "2.2.2" + +mas-templates = { path = "../templates" } +mas-storage = { path = "../storage" } +mas-data-model = { path = "../data-model" } diff --git a/crates/axum-utils/src/cookies.rs b/crates/axum-utils/src/cookies.rs new file mode 100644 index 00000000..a6b867c6 --- /dev/null +++ b/crates/axum-utils/src/cookies.rs @@ -0,0 +1,159 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Private (encrypted) cookie jar, based on axum-extra's cookie jar + +use std::marker::PhantomData; + +use async_trait::async_trait; +use axum::extract::{Extension, FromRequest, RequestParts}; +pub use cookie::Cookie; +use data_encoding::BASE64URL_NOPAD; +use headers::HeaderMap; +use http::header::{COOKIE, SET_COOKIE}; +use serde::{de::DeserializeOwned, Serialize}; +use thiserror::Error; + +pub struct PrivateCookieJar { + jar: cookie::CookieJar, + key: cookie::Key, + _marker: PhantomData, +} + +impl PrivateCookieJar { + pub fn get(&self, name: &str) -> Option> { + self.private_jar().get(name) + } + + #[must_use] + pub fn remove(mut self, cookie: Cookie<'static>) -> Self { + self.private_jar_mut().remove(cookie); + self + } + + #[must_use] + #[allow(clippy::should_implement_trait)] + pub fn add(mut self, cookie: Cookie<'static>) -> Self { + self.private_jar_mut().add(cookie); + self + } + + pub fn decrypt(&self, cookie: Cookie<'static>) -> Option> { + self.private_jar().decrypt(cookie) + } + + fn private_jar(&self) -> cookie::PrivateJar<&'_ cookie::CookieJar> { + self.jar.private(&self.key) + } + + fn private_jar_mut(&mut self) -> cookie::PrivateJar<&'_ mut cookie::CookieJar> { + self.jar.private_mut(&self.key) + } + + pub fn set_cookies(self, headers: &mut HeaderMap) { + for cookie in self.jar.delta() { + if let Ok(header_value) = cookie.encoded().to_string().parse() { + headers.append(SET_COOKIE, header_value); + } + } + } + + pub fn headers(self) -> HeaderMap { + let mut headers = HeaderMap::new(); + self.set_cookies(&mut headers); + headers + } +} + +#[async_trait] +impl FromRequest for PrivateCookieJar +where + B: Send, + K: Into + Clone + Send + Sync + 'static, +{ + type Rejection = as FromRequest>::Rejection; + + async fn from_request(req: &mut RequestParts) -> Result { + let Extension(key): Extension = Extension::from_request(req).await?; + let key = key.into(); + + let mut jar = cookie::CookieJar::new(); + let mut private_jar = jar.private_mut(&key); + + // TODO: remove this when axum 0.5 gets released + // https://github.com/tokio-rs/axum/pull/698 + let empty_headers = HeaderMap::new(); + + let cookies = req + .headers() + .unwrap_or(&empty_headers) + .get_all(COOKIE) + .into_iter() + .filter_map(|value| value.to_str().ok()) + .flat_map(|value| value.split(';')) + .filter_map(|cookie| Cookie::parse_encoded(cookie.to_owned()).ok()); + + for cookie in cookies { + if let Some(cookie) = private_jar.decrypt(cookie) { + private_jar.add_original(cookie); + } + } + + Ok(Self { + jar, + key, + _marker: PhantomData, + }) + } +} + +#[derive(Debug, Error)] +#[error("could not decode cookie")] +pub enum CookieDecodeError { + Deserialize(#[from] bincode::Error), + Decode(#[from] data_encoding::DecodeError), +} + +pub trait CookieExt { + fn decode(&self) -> Result + where + T: DeserializeOwned; + + fn encode(self, t: &T) -> Self + where + T: Serialize; +} + +impl<'a> CookieExt for Cookie<'a> { + fn decode(&self) -> Result + where + T: DeserializeOwned, + { + let bytes = BASE64URL_NOPAD.decode(self.value().as_bytes())?; + + let decoded = bincode::deserialize(&bytes)?; + + Ok(decoded) + } + + fn encode(mut self, t: &T) -> Self + where + T: Serialize, + { + let bytes = bincode::serialize(t).unwrap(); + let encoded = BASE64URL_NOPAD.encode(&bytes); + self.set_value(encoded); + self + } +} diff --git a/crates/axum-utils/src/csrf.rs b/crates/axum-utils/src/csrf.rs new file mode 100644 index 00000000..3c843bfe --- /dev/null +++ b/crates/axum-utils/src/csrf.rs @@ -0,0 +1,111 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use chrono::{DateTime, Duration, Utc}; +use cookie::Cookie; +use data_encoding::{DecodeError, BASE64URL_NOPAD}; +use serde::{Deserialize, Serialize}; +use serde_with::{serde_as, TimestampSeconds}; +use thiserror::Error; + +use crate::{CookieExt, PrivateCookieJar}; + +/// Failed to validate CSRF token +#[derive(Debug, Error)] +pub enum CsrfError { + /// The token in the form did not match the token in the cookie + #[error("CSRF token mismatch")] + Mismatch, + + /// The token expired + #[error("CSRF token expired")] + Expired, + + /// Failed to decode the token + #[error("could not decode CSRF token")] + Decode(#[from] DecodeError), +} + +/// A CSRF token +#[serde_as] +#[derive(Serialize, Deserialize, Debug)] +pub struct CsrfToken { + #[serde_as(as = "TimestampSeconds")] + expiration: DateTime, + token: [u8; 32], +} + +impl CsrfToken { + /// Create a new token from a defined value valid for a specified duration + fn new(token: [u8; 32], ttl: Duration) -> Self { + let expiration = Utc::now() + ttl; + Self { expiration, token } + } + + /// Generate a new random token valid for a specified duration + fn generate(ttl: Duration) -> Self { + let token = rand::random(); + Self::new(token, ttl) + } + + /// Generate a new token with the same value but an up to date expiration + fn refresh(self, ttl: Duration) -> Self { + Self::new(self.token, ttl) + } + + /// Get the value to include in HTML forms + #[must_use] + pub fn form_value(&self) -> String { + BASE64URL_NOPAD.encode(&self.token[..]) + } + + /// Verifies that the value got from an HTML form matches this token + pub fn verify_form_value(&self, form_value: &str) -> Result<(), CsrfError> { + let form_value = BASE64URL_NOPAD.decode(form_value.as_bytes())?; + if self.token[..] == form_value { + Ok(()) + } else { + Err(CsrfError::Mismatch) + } + } + + fn verify_expiration(self) -> Result { + if Utc::now() < self.expiration { + Ok(self) + } else { + Err(CsrfError::Expired) + } + } +} + +pub trait CsrfExt { + fn csrf_token(self) -> (CsrfToken, Self); +} + +impl CsrfExt for PrivateCookieJar { + fn csrf_token(self) -> (CsrfToken, Self) { + let jar = self; + let cookie = jar.get("csrf").unwrap_or_else(|| Cookie::new("csrf", "")); + let new_token = cookie + .decode() + .ok() + .and_then(|token: CsrfToken| token.verify_expiration().ok()) + .unwrap_or_else(|| CsrfToken::generate(Duration::hours(1))) + .refresh(Duration::hours(1)); + + let cookie = cookie.encode(&new_token); + let jar = jar.add(cookie); + (new_token, jar) + } +} diff --git a/crates/axum-utils/src/fancy_error.rs b/crates/axum-utils/src/fancy_error.rs new file mode 100644 index 00000000..991da794 --- /dev/null +++ b/crates/axum-utils/src/fancy_error.rs @@ -0,0 +1,102 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{convert::Infallible, error::Error}; + +use async_trait::async_trait; +use axum::{ + body::{HttpBody, StreamBody}, + extract::{Extension, FromRequest, RequestParts}, + http::StatusCode, + response::{IntoResponse, Response}, +}; +use futures_util::FutureExt; +use headers::{ContentType, HeaderMapExt}; +use mas_templates::{ErrorContext, Templates}; +use sqlx::PgPool; + +struct DatabaseConnection(sqlx::pool::PoolConnection); + +#[async_trait] +impl FromRequest for DatabaseConnection +where + B: Send, +{ + type Rejection = FancyError; + + async fn from_request(req: &mut RequestParts) -> Result { + let Extension(templates) = Extension::::from_request(req) + .await + .map_err(internal_error)?; + + let Extension(pool) = Extension::::from_request(req) + .await + .map_err(fancy_error(templates))?; + + let conn = pool.acquire().await.map_err(internal_error)?; + + Ok(Self(conn)) + } +} + +pub fn fancy_error(templates: Templates) -> impl Fn(E) -> FancyError { + move |error: E| FancyError { + templates: Some(templates.clone()), + error: Box::new(error), + } +} + +pub fn internal_error(error: E) -> FancyError +where + E: Error, +{ + FancyError { + templates: None, + error: Box::new(error), + } +} + +pub struct FancyError { + templates: Option, + error: Box, +} + +impl IntoResponse for FancyError { + fn into_response(self) -> Response { + let error = format!("{}", self.error); + let context = ErrorContext::new().with_description(error.clone()); + let body = match self.templates { + Some(templates) => { + let stream = (async move { + Ok::<_, Infallible>(match templates.render_error(&context).await { + Ok(s) => s, + Err(_e) => "failed to render error template".to_string(), + }) + }) + .into_stream(); + + StreamBody::new(stream).boxed_unsync() + } + None => axum::body::Full::from(error) + .map_err(|_e| unreachable!()) + .boxed_unsync(), + }; + + let mut res = Response::new(body); + *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + res.headers_mut().typed_insert(ContentType::html()); + res + } +} + diff --git a/crates/axum-utils/src/lib.rs b/crates/axum-utils/src/lib.rs new file mode 100644 index 00000000..6b4d16a6 --- /dev/null +++ b/crates/axum-utils/src/lib.rs @@ -0,0 +1,26 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +pub mod cookies; +pub mod csrf; +pub mod fancy_error; +pub mod session; +pub mod url_builder; + +pub use self::{ + cookies::{Cookie, CookieExt, PrivateCookieJar}, + fancy_error::{fancy_error, internal_error, FancyError}, + session::{SessionInfo, SessionInfoExt}, + url_builder::UrlBuilder, +}; diff --git a/crates/axum-utils/src/session.rs b/crates/axum-utils/src/session.rs new file mode 100644 index 00000000..bac41b78 --- /dev/null +++ b/crates/axum-utils/src/session.rs @@ -0,0 +1,87 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use cookie::Cookie; +use mas_data_model::BrowserSession; +use mas_storage::{ + user::{lookup_active_session, ActiveSessionLookupError}, + PostgresqlBackend, +}; +use serde::{Deserialize, Serialize}; +use sqlx::{Executor, Postgres}; + +use crate::{CookieExt, PrivateCookieJar}; + +/// An encrypted cookie to save the session ID +#[derive(Serialize, Deserialize, Debug, Default)] +pub struct SessionInfo { + current: Option, +} + +impl SessionInfo { + /// Forge the cookie from a [`BrowserSession`] + #[must_use] + pub fn from_session(session: &BrowserSession) -> Self { + Self { + current: Some(session.data), + } + } + + /// Load the [`BrowserSession`] from database + pub async fn load_session( + &self, + executor: impl Executor<'_, Database = Postgres>, + ) -> Result>, ActiveSessionLookupError> { + let session_id = if let Some(id) = self.current { + id + } else { + return Ok(None); + }; + + let res = lookup_active_session(executor, session_id).await?; + Ok(Some(res)) + } +} + +pub trait SessionInfoExt { + fn session_info(self) -> (SessionInfo, Self); + fn update_session_info(self, info: &SessionInfo) -> Self; + fn set_session(self, session: &BrowserSession) -> Self + where + Self: Sized, + { + let session_info = SessionInfo::from_session(session); + self.update_session_info(&session_info) + } +} + +impl SessionInfoExt for PrivateCookieJar { + fn session_info(self) -> (SessionInfo, Self) { + let jar = self; + let cookie = jar + .get("session") + .unwrap_or_else(|| Cookie::new("session", "")); + let session_info = cookie.decode().unwrap_or_default(); + + let cookie = cookie.encode(&session_info); + let jar = jar.add(cookie); + (session_info, jar) + } + + fn update_session_info(self, info: &SessionInfo) -> Self { + let cookie = Cookie::new("session", ""); + let cookie = cookie.encode(&info); + self.add(cookie) + } +} diff --git a/crates/axum-utils/src/url_builder.rs b/crates/axum-utils/src/url_builder.rs new file mode 100644 index 00000000..56bd061e --- /dev/null +++ b/crates/axum-utils/src/url_builder.rs @@ -0,0 +1,100 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Utility to build URLs + +use url::Url; + +/// Helps building absolute URLs +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct UrlBuilder { + base: Url, +} + +impl UrlBuilder { + /// Create a new [`UrlBuilder`] from a base URL + #[must_use] + pub fn new(base: Url) -> Self { + Self { base } + } + + /// OIDC issuer + #[must_use] + pub fn oidc_issuer(&self) -> Url { + self.base.clone() + } + + /// OIDC dicovery document URL + #[must_use] + pub fn oidc_discovery(&self) -> Url { + self.base + .join(".well-known/openid-configuration") + .expect("build URL") + } + + /// OAuth 2.0 authorization endpoint + #[must_use] + pub fn oauth_authorization_endpoint(&self) -> Url { + self.base.join("oauth2/authorize").expect("build URL") + } + + /// OAuth 2.0 token endpoint + #[must_use] + pub fn oauth_token_endpoint(&self) -> Url { + self.base.join("oauth2/token").expect("build URL") + } + + /// OAuth 2.0 introspection endpoint + #[must_use] + pub fn oauth_introspection_endpoint(&self) -> Url { + self.base.join("oauth2/introspect").expect("build URL") + } + + /// OAuth 2.0 introspection endpoint + #[must_use] + pub fn oidc_userinfo_endpoint(&self) -> Url { + self.base.join("oauth2/userinfo").expect("build URL") + } + + /// JWKS URI + #[must_use] + pub fn jwks_uri(&self) -> Url { + self.base.join("oauth2/keys.json").expect("build URL") + } + + /// Email verification URL + #[must_use] + pub fn email_verification(&self, code: &str) -> Url { + self.base + .join("verify/") + .expect("build URL") + .join(code) + .expect("build URL") + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn build_email_verification_url() { + let base = Url::parse("https://example.com/").unwrap(); + let builder = UrlBuilder::new(base); + assert_eq!( + builder.email_verification("123456abcdef").as_str(), + "https://example.com/verify/123456abcdef" + ); + } +} diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index 920abb10..c9eb6c4d 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -43,6 +43,7 @@ mas-storage = { path = "../storage" } mas-tasks = { path = "../tasks" } mas-templates = { path = "../templates" } mas-warp-utils = { path = "../warp-utils" } +mas-axum-utils = { path = "../axum-utils" } [dev-dependencies] indoc = "1.0.4" diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index 27555c7d..df9ddf21 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -22,12 +22,12 @@ use anyhow::Context; use clap::Parser; use futures::{future::TryFutureExt, stream::TryStreamExt}; use hyper::Server; +use mas_axum_utils::UrlBuilder; use mas_config::RootConfig; use mas_email::{MailTransport, Mailer}; use mas_storage::MIGRATOR; use mas_tasks::TaskQueue; use mas_templates::Templates; -use tower::{make::Shared, Layer}; use tracing::{error, info}; #[derive(Parser, Debug, Default)] @@ -191,6 +191,11 @@ impl Options { &config.email.reply_to, ); + let url_builder = UrlBuilder::new(config.http.public_base.clone()); + + // Explicitely the config to properly zeroize secret keys + drop(config); + // Watch for changes in templates if the --watch flag is present if self.watch { let client = watchman_client::Connector::new() @@ -203,11 +208,14 @@ impl Options { .context("could not watch for templates changes")?; } - let router = - mas_handlers::router(&pool, &templates, &key_store, &encrypter, &mailer, &config); - - // Explicitely the config to properly zeroize secret keys - drop(config); + let router = mas_handlers::router( + &pool, + &templates, + &key_store, + &encrypter, + &mailer, + &url_builder, + ); info!("Listening on http://{}", listener.local_addr().unwrap()); diff --git a/crates/config/Cargo.toml b/crates/config/Cargo.toml index f5b9f787..a4eb0b94 100644 --- a/crates/config/Cargo.toml +++ b/crates/config/Cargo.toml @@ -31,6 +31,7 @@ pkcs8 = { version = "0.8.0", features = ["pem"] } chacha20poly1305 = { version = "0.9.0", features = ["std"] } elliptic-curve = { version = "0.11.12", features = ["pem", "pkcs8"] } pem-rfc7468 = "0.3.1" +cookie = { version = "0.16.0", features = ["private", "key-expansion"] } indoc = "1.0.4" diff --git a/crates/config/src/sections/secrets.rs b/crates/config/src/sections/secrets.rs index 706b6e5b..19d31992 100644 --- a/crates/config/src/sections/secrets.rs +++ b/crates/config/src/sections/secrets.rs @@ -20,6 +20,7 @@ use chacha20poly1305::{ aead::{generic_array::GenericArray, Aead, NewAead}, ChaCha20Poly1305, }; +use cookie::Key; use mas_jose::StaticKeystore; use pkcs8::DecodePrivateKey; use rsa::{ @@ -37,6 +38,7 @@ use super::ConfigurationSection; /// Helps encrypting and decrypting data #[derive(Clone)] pub struct Encrypter { + cookie_key: Arc, aead: Arc, } @@ -44,10 +46,12 @@ impl Encrypter { /// Creates an [`Encrypter`] out of an encryption key #[must_use] pub fn new(key: &[u8; 32]) -> Self { + let cookie_key = Key::derive_from(&key[..]); + let cookie_key = Arc::new(cookie_key); let key = GenericArray::from_slice(key); let aead = ChaCha20Poly1305::new(key); let aead = Arc::new(aead); - Self { aead } + Self { cookie_key, aead } } /// Encrypt a payload @@ -73,6 +77,12 @@ impl Encrypter { } } +impl From for Key { + fn from(e: Encrypter) -> Self { + e.cookie_key.as_ref().clone() + } +} + fn example_secret() -> &'static str { "0000111122223333444455556666777788889999aaaabbbbccccddddeeeeffff" } diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index 404f583f..1a3f3677 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -22,6 +22,7 @@ use std::sync::Arc; use axum::{extract::Extension, routing::get, Router}; +use mas_axum_utils::UrlBuilder; use mas_config::{Encrypter, RootConfig}; use mas_email::Mailer; use mas_jose::StaticKeystore; @@ -60,13 +61,14 @@ pub fn root( filter.with(warp::log(module_path!())).boxed() } +#[must_use] pub fn router( pool: &PgPool, templates: &Templates, key_store: &Arc, encrypter: &Encrypter, mailer: &Mailer, - config: &RootConfig, + url_builder: &UrlBuilder, ) -> Router { Router::new() .route("/", get(self::views::index::get)) @@ -75,5 +77,6 @@ pub fn router( .layer(Extension(templates.clone())) .layer(Extension(key_store.clone())) .layer(Extension(encrypter.clone())) + .layer(Extension(url_builder.clone())) .layer(Extension(mailer.clone())) } diff --git a/crates/handlers/src/views/index.rs b/crates/handlers/src/views/index.rs index d584d003..a616d6ec 100644 --- a/crates/handlers/src/views/index.rs +++ b/crates/handlers/src/views/index.rs @@ -12,79 +12,43 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::str::FromStr; - use axum::{ extract::Extension, response::{Html, IntoResponse}, }; -use mas_axum_utils::{fancy_error, FancyError}; -use mas_config::{CsrfConfig, Encrypter, HttpConfig}; -use mas_data_model::BrowserSession; -use mas_storage::PostgresqlBackend; -use mas_templates::{IndexContext, TemplateContext, Templates}; -use mas_warp_utils::filters::{ - self, - cookies::{encrypted_cookie_saver, EncryptedCookieSaver}, - csrf::updated_csrf_token, - session::optional_session, - url_builder::{url_builder, UrlBuilder}, - with_templates, CsrfToken, +use mas_axum_utils::{ + csrf::CsrfExt, fancy_error, FancyError, PrivateCookieJar, SessionInfoExt, UrlBuilder, }; +use mas_config::Encrypter; +use mas_templates::{IndexContext, TemplateContext, Templates}; use sqlx::PgPool; -use url::Url; -use warp::{filters::BoxedFilter, reply::html, Filter, Rejection, Reply}; - -/* -pub(super) fn filter( - pool: &PgPool, - templates: &Templates, - encrypter: &Encrypter, - http_config: &HttpConfig, - csrf_config: &CsrfConfig, -) -> BoxedFilter<(Box,)> { - warp::path::end() - .and(filters::trace::name("GET /")) - .and(warp::get()) - .and(url_builder(http_config)) - .and(with_templates(templates)) - .and(encrypted_cookie_saver(encrypter)) - .and(updated_csrf_token(encrypter, csrf_config)) - .and(optional_session(pool, encrypter)) - .and_then(get) - .boxed() -} - -async fn get( - url_builder: UrlBuilder, - templates: Templates, - cookie_saver: EncryptedCookieSaver, - csrf_token: CsrfToken, - maybe_session: Option>, -) -> Result, Rejection> { - let ctx = IndexContext::new(url_builder.oidc_discovery()) - .maybe_with_session(maybe_session) - .with_csrf(csrf_token.form_value()); - - let content = templates.render_index(&ctx).await?; - let reply = html(content); - let reply = cookie_saver.save_encrypted(&csrf_token, reply)?; - Ok(Box::new(reply)) -} -*/ pub async fn get( Extension(templates): Extension, + Extension(url_builder): Extension, + Extension(pool): Extension, + cookie_jar: PrivateCookieJar, ) -> Result { - let ctx = IndexContext::new( - Url::from_str("https://example.com/.well-known/openid-discovery").unwrap(), - ) - .maybe_with_session::(None) - .with_csrf("csrf_token".to_string()); + let mut conn = pool + .acquire() + .await + .map_err(fancy_error(templates.clone()))?; + + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(); + let (session_info, cookie_jar) = cookie_jar.session_info(); + let session = session_info + .load_session(&mut conn) + .await + .map_err(fancy_error(templates.clone()))?; + + let ctx = IndexContext::new(url_builder.oidc_discovery()) + .maybe_with_session(session) + .with_csrf(csrf_token.form_value()); let content = templates .render_index(&ctx) .await .map_err(fancy_error(templates))?; - Ok(Html(content)) + + Ok((cookie_jar.headers(), Html(content))) } diff --git a/crates/templates/src/context.rs b/crates/templates/src/context.rs index 45e1494c..2b7cf71f 100644 --- a/crates/templates/src/context.rs +++ b/crates/templates/src/context.rs @@ -56,13 +56,14 @@ pub trait TemplateContext: Serialize { } /// Attach a CSRF token to the template context - fn with_csrf(self, csrf_token: String) -> WithCsrf + fn with_csrf(self, csrf_token: C) -> WithCsrf where Self: Sized, + C: ToString, { // TODO: make this method use a CsrfToken again WithCsrf { - csrf_token, + csrf_token: csrf_token.to_string(), inner: self, } }