diff --git a/Cargo.lock b/Cargo.lock index c658fbc5..ac9f70f9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -956,6 +956,23 @@ dependencies = [ "version_check", ] +[[package]] +name = "cookie_store" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "387461abbc748185c3a6e1673d826918b450b87ff22639429c694619a83b6cf6" +dependencies = [ + "cookie", + "idna 0.3.0", + "log", + "publicsuffix", + "serde", + "serde_derive", + "serde_json", + "time", + "url", +] + [[package]] name = "core-foundation" version = "0.9.3" @@ -2770,6 +2787,7 @@ dependencies = [ "bcrypt", "camino", "chrono", + "cookie_store", "futures-util", "headers", "hyper", @@ -4106,6 +4124,12 @@ version = "2.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "106dd99e98437432fed6519dedecfade6a06a73bb7b2a1e019fdd2bee5778d94" +[[package]] +name = "psl-types" +version = "2.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33cb294fe86a74cbcf50d4445b37da762029549ebeea341421c7c70370f86cac" + [[package]] name = "psm" version = "0.1.21" @@ -4115,6 +4139,16 @@ dependencies = [ "cc", ] +[[package]] +name = "publicsuffix" +version = "2.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96a8c1bda5ae1af7f99a2962e49df150414a43d62404644d98dd5c3a93d07457" +dependencies = [ + "idna 0.3.0", + "psl-types", +] + [[package]] name = "pulldown-cmark" version = "0.9.3" diff --git a/crates/handlers/Cargo.toml b/crates/handlers/Cargo.toml index c5f2f561..c5c41f9f 100644 --- a/crates/handlers/Cargo.toml +++ b/crates/handlers/Cargo.toml @@ -79,6 +79,7 @@ oauth2-types = { path = "../oauth2-types" } [dev-dependencies] insta = "1.31.0" tracing-subscriber = "0.3.17" +cookie_store = "0.20.0" [features] default = ["webpki-roots"] diff --git a/crates/handlers/src/test_utils.rs b/crates/handlers/src/test_utils.rs index 3575d81f..7ae54316 100644 --- a/crates/handlers/src/test_utils.rs +++ b/crates/handlers/src/test_utils.rs @@ -14,7 +14,8 @@ use std::{ convert::Infallible, - sync::{Arc, Mutex}, + sync::{Arc, Mutex, RwLock}, + task::{Context, Poll}, }; use axum::{ @@ -22,8 +23,13 @@ use axum::{ body::{Bytes, HttpBody}, extract::{FromRef, FromRequestParts}, }; -use headers::{Authorization, ContentType, HeaderMapExt, HeaderName}; -use hyper::{header::CONTENT_TYPE, Request, Response, StatusCode}; +use cookie_store::{CookieStore, RawCookie}; +use futures_util::future::BoxFuture; +use headers::{Authorization, ContentType, HeaderMapExt, HeaderName, HeaderValue}; +use hyper::{ + header::{CONTENT_TYPE, COOKIE, SET_COOKIE}, + Request, Response, StatusCode, +}; use mas_axum_utils::{cookies::CookieManager, http_client_factory::HttpClientFactory}; use mas_keystore::{Encrypter, JsonWebKey, JsonWebKeySet, Keystore, PrivateKey}; use mas_matrix::{HomeserverConnection, MockHomeserverConnection}; @@ -36,7 +42,8 @@ use rand::SeedableRng; use rand_chacha::ChaChaRng; use serde::{de::DeserializeOwned, Serialize}; use sqlx::PgPool; -use tower::{Service, ServiceExt}; +use tower::{Layer, Service, ServiceExt}; +use url::Url; use crate::{ app_state::RepositoryError, @@ -483,3 +490,100 @@ impl ResponseExt for Response { serde_json::from_str(self.body()).expect("JSON deserialization failed") } } + +/// A helper for storing and retrieving cookies in tests. +#[derive(Clone, Debug, Default)] +pub struct CookieHelper { + store: Arc>, +} + +impl CookieHelper { + pub fn new() -> Self { + Self::default() + } + + /// Inject the cookies from the store into the request. + pub fn with_cookies(&self, mut request: Request) -> Request { + let url = Url::options() + .base_url(Some(&"https://example.com/".parse().unwrap())) + .parse(&request.uri().to_string()) + .expect("Failed to parse URL"); + + let store = self.store.read().unwrap(); + let value = store + .get_request_values(&url) + .map(|(name, value)| format!("{name}={value}")) + .collect::>() + .join("; "); + + request.headers_mut().insert( + COOKIE, + HeaderValue::from_str(&value).expect("Invalid cookie value"), + ); + request + } + + /// Save the cookies from the response into the store. + pub fn save_cookies(&self, response: &Response) { + let url = "https://example.com/".parse().unwrap(); + let mut store = self.store.write().unwrap(); + store.store_response_cookies( + response + .headers() + .get_all(SET_COOKIE) + .iter() + .map(|set_cookie| { + RawCookie::parse( + set_cookie + .to_str() + .expect("Invalid set-cookie header") + .to_owned(), + ) + .expect("Invalid set-cookie header") + }), + &url, + ); + } +} + +impl Layer for CookieHelper { + type Service = CookieStoreService; + + fn layer(&self, inner: S) -> Self::Service { + CookieStoreService { + helper: self.clone(), + inner, + } + } +} + +/// A middleware that stores and retrieves cookies. +pub struct CookieStoreService { + helper: CookieHelper, + inner: S, +} + +impl Service> for CookieStoreService +where + S: Service, Response = Response> + Send, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, request: Request) -> Self::Future { + let req = self.helper.with_cookies(request); + let inner = self.inner.call(req); + let helper = self.helper.clone(); + Box::pin(async move { + let response: Response<_> = inner.await?; + helper.save_cookies(&response); + Ok(response) + }) + } +} diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index 57dc46da..44552955 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -290,10 +290,11 @@ mod test { use mas_templates::escape_html; use oauth2_types::scope::OPENID; use sqlx::PgPool; + use zeroize::Zeroizing; use crate::{ passwords::PasswordManager, - test_utils::{init_tracing, RequestBuilderExt, ResponseExt, TestState}, + test_utils::{init_tracing, CookieHelper, RequestBuilderExt, ResponseExt, TestState}, }; #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] @@ -375,4 +376,67 @@ mod test { .body() .contains(&escape_html(&second_provider_login.relative_url()))); } + + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_password_login(pool: PgPool) { + init_tracing(); + let state = TestState::from_pool(pool).await.unwrap(); + let mut rng = state.rng(); + let cookies = CookieHelper::new(); + + // Provision a user with a password + let mut repo = state.repository().await.unwrap(); + let user = repo + .user() + .add(&mut rng, &state.clock, "john".to_owned()) + .await + .unwrap(); + let (version, hash) = state + .password_manager + .hash(&mut rng, Zeroizing::new("hunter2".as_bytes().to_vec())) + .await + .unwrap(); + repo.user_password() + .add(&mut rng, &state.clock, &user, version, hash, None) + .await + .unwrap(); + repo.save().await.unwrap(); + + // Render the login page to get a CSRF token + let request = Request::get("/login").empty(); + let request = cookies.with_cookies(request); + let response = state.request(request).await; + cookies.save_cookies(&response); + response.assert_status(StatusCode::OK); + response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8"); + // Extract the CSRF token from the response body + let csrf_token = response + .body() + .split("name=\"csrf\" value=\"") + .nth(1) + .unwrap() + .split('\"') + .next() + .unwrap(); + + // Submit the login form + let request = Request::post("/login").form(serde_json::json!({ + "csrf": csrf_token, + "username": "john", + "password": "hunter2", + })); + let request = cookies.with_cookies(request); + let response = state.request(request).await; + cookies.save_cookies(&response); + response.assert_status(StatusCode::SEE_OTHER); + + // Now if we get to the home page, we should see the user's username + let request = Request::get("/").empty(); + let request = cookies.with_cookies(request); + let response = state.request(request).await; + cookies.save_cookies(&response); + response.assert_status(StatusCode::OK); + response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8"); + assert!(response.body().contains("john")); + } }