From d2d68e9a27526b34d77bac7e7c837a7e80c633b5 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 23 May 2023 14:20:27 +0200 Subject: [PATCH] Make password-based login optional --- Cargo.lock | 1 + Cargo.toml | 10 ++ crates/cli/Cargo.toml | 1 + crates/cli/src/commands/manage.rs | 4 +- crates/cli/src/util.rs | 77 +++++++++ crates/config/src/sections/passwords.rs | 15 ++ crates/handlers/src/compat/login.rs | 83 ++++++++-- crates/handlers/src/passwords.rs | 104 +++++++++--- crates/handlers/src/views/account/password.rs | 12 ++ crates/handlers/src/views/login.rs | 151 ++++++++++++++++-- crates/handlers/src/views/reauth.rs | 12 ++ crates/handlers/src/views/register.rs | 76 +++++++-- crates/templates/src/context.rs | 31 +++- crates/templates/src/lib.rs | 1 + docs/config.schema.json | 6 + templates/pages/login.html | 106 ++++++------ 16 files changed, 572 insertions(+), 118 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3012fbd7..535793d2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3178,6 +3178,7 @@ dependencies = [ "tracing-subscriber", "url", "watchman_client", + "zeroize", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 0492c568..0c17cf1a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,16 @@ opt-level = 3 [profile.dev.package.sqlx-macros] opt-level = 3 +[profile.dev.package.cranelift-codegen] +opt-level = 3 + +[profile.dev.package.regalloc2] +opt-level = 3 + +[profile.dev.package.argon2] +opt-level = 3 + + # Until https://github.com/dylanhart/ulid-rs/pull/56 gets released [patch.crates-io.ulid] git = "https://github.com/dylanhart/ulid-rs.git" diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index bd23364d..0c091c0b 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -27,6 +27,7 @@ tower = { version = "0.4.13", features = ["full"] } tower-http = { version = "0.4.0", features = ["fs", "compression-full"] } url = "2.3.1" watchman_client = "0.8.0" +zeroize = "1.6.0" tracing = "0.1.37" tracing-appender = "0.2.2" diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index b685a167..f78c9987 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -327,7 +327,7 @@ impl Options { let encrypter = config.secrets.encrypter(); let pool = database_from_config(&config.database).await?; let url_builder = UrlBuilder::new(config.http.public_base); - let mut repo = PgRepository::from_pool(&pool).await?; + let mut repo = PgRepository::from_pool(&pool).await?.boxed(); let requires_client_secret = token_endpoint_auth_method.requires_client_secret(); @@ -362,6 +362,8 @@ impl Options { ) .await?; + repo.save().await?; + let redirect_uri = url_builder.upstream_oauth_callback(provider.id); let auth_uri = url_builder.upstream_oauth_authorize(provider.id); tracing::info!( diff --git a/crates/cli/src/util.rs b/crates/cli/src/util.rs index b6485e95..703038c8 100644 --- a/crates/cli/src/util.rs +++ b/crates/cli/src/util.rs @@ -33,6 +33,10 @@ use tracing::{error, info, log::LevelFilter}; pub async fn password_manager_from_config( config: &PasswordsConfig, ) -> Result { + if !config.enabled() { + return Ok(PasswordManager::disabled()); + } + let schemes = config .load() .await? @@ -227,3 +231,76 @@ pub async fn watch_templates(templates: &Templates) -> anyhow::Result<()> { Ok(()) } + +#[cfg(test)] +mod tests { + use rand::SeedableRng; + use zeroize::Zeroizing; + + use super::*; + + #[tokio::test] + async fn test_password_manager_from_config() { + let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42); + let password = Zeroizing::new(b"hunter2".to_vec()); + + // Test a valid, enabled config + let config = serde_json::from_value(serde_json::json!({ + "schemes": [{ + "version": 42, + "algorithm": "argon2id" + }, { + "version": 10, + "algorithm": "bcrypt" + }] + })) + .unwrap(); + + let manager = password_manager_from_config(&config).await; + assert!(manager.is_ok()); + let manager = manager.unwrap(); + assert!(manager.is_enabled()); + let hashed = manager.hash(&mut rng, password.clone()).await; + assert!(hashed.is_ok()); + let (version, hashed) = hashed.unwrap(); + assert_eq!(version, 42); + assert!(hashed.starts_with("$argon2id$")); + + // Test a valid, disabled config + let config = serde_json::from_value(serde_json::json!({ + "enabled": false, + "schemes": [] + })) + .unwrap(); + + let manager = password_manager_from_config(&config).await; + assert!(manager.is_ok()); + let manager = manager.unwrap(); + assert!(!manager.is_enabled()); + let res = manager.hash(&mut rng, password.clone()).await; + assert!(res.is_err()); + + // Test an invalid config + // Repeat the same version twice + let config = serde_json::from_value(serde_json::json!({ + "schemes": [{ + "version": 42, + "algorithm": "argon2id" + }, { + "version": 42, + "algorithm": "bcrypt" + }] + })) + .unwrap(); + let manager = password_manager_from_config(&config).await; + assert!(manager.is_err()); + + // Empty schemes + let config = serde_json::from_value(serde_json::json!({ + "schemes": [] + })) + .unwrap(); + let manager = password_manager_from_config(&config).await; + assert!(manager.is_err()); + } +} diff --git a/crates/config/src/sections/passwords.rs b/crates/config/src/sections/passwords.rs index 19365534..f92db589 100644 --- a/crates/config/src/sections/passwords.rs +++ b/crates/config/src/sections/passwords.rs @@ -31,9 +31,17 @@ fn default_schemes() -> Vec { }] } +fn default_enabled() -> bool { + true +} + /// User password hashing config #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct PasswordsConfig { + /// Whether password-based authentication is enabled + #[serde(default = "default_enabled")] + enabled: bool, + #[serde(default = "default_schemes")] schemes: Vec, } @@ -41,6 +49,7 @@ pub struct PasswordsConfig { impl Default for PasswordsConfig { fn default() -> Self { Self { + enabled: default_enabled(), schemes: default_schemes(), } } @@ -65,6 +74,12 @@ impl ConfigurationSection<'_> for PasswordsConfig { } impl PasswordsConfig { + /// Whether password-based authentication is enabled + #[must_use] + pub fn enabled(&self) -> bool { + self.enabled + } + /// Load the password hashing schemes defined by the config /// /// # Errors diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index 8dd19307..96bfde15 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -66,18 +66,28 @@ struct LoginTypes { } #[tracing::instrument(name = "handlers.compat.login.get", skip_all)] -pub(crate) async fn get() -> impl IntoResponse { - let res = LoginTypes { - flows: vec![ +pub(crate) async fn get(State(password_manager): State) -> impl IntoResponse { + let flows = if password_manager.is_enabled() { + vec![ LoginType::Password, LoginType::Sso { identity_providers: vec![], delegated_oidc_compatibility: true, }, LoginType::Token, - ], + ] + } else { + vec![ + LoginType::Sso { + identity_providers: vec![], + delegated_oidc_compatibility: true, + }, + LoginType::Token, + ] }; + let res = LoginTypes { flows }; + Json(res) } @@ -202,11 +212,14 @@ pub(crate) async fn post( State(homeserver): State, Json(input): Json, ) -> Result { - let (session, user) = match input.credentials { - Credentials::Password { - identifier: Identifier::User { user }, - password, - } => { + let (session, user) = match (password_manager.is_enabled(), input.credentials) { + ( + true, + Credentials::Password { + identifier: Identifier::User { user }, + password, + }, + ) => { user_password_login( &mut rng, &clock, @@ -218,7 +231,7 @@ pub(crate) async fn post( .await? } - Credentials::Token { token } => token_login(&mut repo, &clock, &token).await?, + (_, Credentials::Token { token }) => token_login(&mut repo, &clock, &token).await?, _ => { return Err(RouteError::Unsupported); @@ -407,7 +420,7 @@ mod tests { init_tracing(); let state = TestState::from_pool(pool).await.unwrap(); - // Now let's try to login with the password, without asking for a refresh token. + // Now let's get the login flows let request = Request::get("/_matrix/client/v3/login").empty(); let response = state.request(request).await; response.assert_status(StatusCode::OK); @@ -432,6 +445,54 @@ mod tests { ); } + /// Test that the server doesn't allow login with a password if the password + /// manager is disabled + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_password_disabled(pool: PgPool) { + init_tracing(); + let state = { + let mut state = TestState::from_pool(pool).await.unwrap(); + state.password_manager = PasswordManager::disabled(); + state + }; + + // Now let's get the login flows + let request = Request::get("/_matrix/client/v3/login").empty(); + let response = state.request(request).await; + response.assert_status(StatusCode::OK); + let body: serde_json::Value = response.json(); + + assert_eq!( + body, + serde_json::json!({ + "flows": [ + { + "type": "m.login.sso", + "org.matrix.msc3824.delegated_oidc_compatibility": true, + }, + { + "type": "m.login.token", + } + ], + }) + ); + + // Try to login with a password, it should be rejected + let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({ + "type": "m.login.password", + "identifier": { + "type": "m.id.user", + "user": "alice", + }, + "password": "password", + })); + + let response = state.request(request).await; + response.assert_status(StatusCode::BAD_REQUEST); + let body: serde_json::Value = response.json(); + assert_eq!(body["errcode"], "M_UNRECOGNIZED"); + } + /// Test that a user can login with a password using the Matrix /// compatibility API. #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] diff --git a/crates/handlers/src/passwords.rs b/crates/handlers/src/passwords.rs index 4be71a45..9e29d9bb 100644 --- a/crates/handlers/src/passwords.rs +++ b/crates/handlers/src/passwords.rs @@ -19,14 +19,26 @@ use argon2::{password_hash::SaltString, Argon2, PasswordHash, PasswordHasher, Pa use futures_util::future::OptionFuture; use pbkdf2::Pbkdf2; use rand::{CryptoRng, Rng, RngCore, SeedableRng}; +use thiserror::Error; use zeroize::Zeroizing; pub type SchemeVersion = u16; +#[derive(Debug, Error)] +#[error("Password manager is disabled")] +pub struct PasswordManagerDisabledError; + #[derive(Clone)] pub struct PasswordManager { - hashers: Arc>, - default_hasher: SchemeVersion, + inner: Option>, +} + +struct InnerPasswordManager { + current_hasher: Hasher, + current_version: SchemeVersion, + + /// A map of "old" hashers used only for verification + other_hashers: HashMap, } impl PasswordManager { @@ -51,58 +63,87 @@ impl PasswordManager { pub fn new>( iter: I, ) -> Result { - let mut iter = iter.into_iter().peekable(); - let (default_hasher, _) = iter - .peek() - .context("Iterator must have at least one item")?; - let default_hasher = *default_hasher; + let mut iter = iter.into_iter(); - let hashers = iter.collect(); + // Take the first hasher as the current hasher + let (current_version, current_hasher) = iter + .next() + .context("Iterator must have at least one item")?; + + // Collect the other hashers in a map used only in verification + let other_hashers = iter.collect(); Ok(Self { - hashers: Arc::new(hashers), - default_hasher, + inner: Some(Arc::new(InnerPasswordManager { + current_hasher, + current_version, + other_hashers, + })), }) } + /// Creates a new disabled password manager + #[must_use] + pub const fn disabled() -> Self { + Self { inner: None } + } + + /// Checks if the password manager is enabled or not + #[must_use] + pub const fn is_enabled(&self) -> bool { + self.inner.is_some() + } + + /// Get the inner password manager + /// + /// # Errors + /// + /// Returns an error if the password manager is disabled + fn get_inner(&self) -> Result, PasswordManagerDisabledError> { + self.inner + .as_ref() + .map(Arc::clone) + .ok_or(PasswordManagerDisabledError) + } + /// Hash a password with the default hashing scheme. /// Returns the version of the hashing scheme used and the hashed password. /// /// # Errors /// - /// Returns an error if the hashing failed + /// Returns an error if the hashing failed or if the password manager is + /// disabled #[tracing::instrument(name = "passwords.hash", skip_all)] pub async fn hash( &self, rng: R, password: Zeroizing>, ) -> Result<(SchemeVersion, String), anyhow::Error> { + let inner = self.get_inner()?; + // Seed a future-local RNG so the RNG passed in parameters doesn't have to be // 'static let rng = rand_chacha::ChaChaRng::from_rng(rng)?; - let hashers = self.hashers.clone(); - let default_hasher_version = self.default_hasher; let span = tracing::Span::current(); - let hashed = tokio::task::spawn_blocking(move || { - span.in_scope(move || { - let default_hasher = hashers - .get(&default_hasher_version) - .context("Default hasher not found")?; + // `inner` is being moved in the blocking task, so we need to copy the version + // first + let version = inner.current_version; - default_hasher.hash_blocking(rng, &password) - }) + let hashed = tokio::task::spawn_blocking(move || { + span.in_scope(move || inner.current_hasher.hash_blocking(rng, &password)) }) .await??; - Ok((default_hasher_version, hashed)) + Ok((version, hashed)) } /// Verify a password hash for the given hashing scheme. /// /// # Errors /// - /// Returns an error if the password hash verification failed + /// Returns an error if the password hash verification failed or if the + /// password manager is disabled #[tracing::instrument(name = "passwords.verify", skip_all, fields(%scheme))] pub async fn verify( &self, @@ -110,12 +151,20 @@ impl PasswordManager { password: Zeroizing>, hashed_password: String, ) -> Result<(), anyhow::Error> { - let hashers = self.hashers.clone(); + let inner = self.get_inner()?; let span = tracing::Span::current(); tokio::task::spawn_blocking(move || { span.in_scope(move || { - let hasher = hashers.get(&scheme).context("Hashing scheme not found")?; + let hasher = if scheme == inner.current_version { + &inner.current_hasher + } else { + inner + .other_hashers + .get(&scheme) + .context("Hashing scheme not found")? + }; + hasher.verify_blocking(&hashed_password, &password) }) }) @@ -129,7 +178,8 @@ impl PasswordManager { /// /// # Errors /// - /// Returns an error if the password hash verification failed + /// Returns an error if the password hash verification failed or if the + /// password manager is disabled #[tracing::instrument(name = "passwords.verify_and_upgrade", skip_all, fields(%scheme))] pub async fn verify_and_upgrade( &self, @@ -138,9 +188,11 @@ impl PasswordManager { password: Zeroizing>, hashed_password: String, ) -> Result, anyhow::Error> { + let inner = self.get_inner()?; + // If the current scheme isn't the default one, we also hash with the default // one so that - let new_hash_fut: OptionFuture<_> = (scheme != self.default_hasher) + let new_hash_fut: OptionFuture<_> = (scheme != inner.current_version) .then(|| self.hash(rng, password.clone())) .into(); diff --git a/crates/handlers/src/views/account/password.rs b/crates/handlers/src/views/account/password.rs index 674e0f94..8de6d586 100644 --- a/crates/handlers/src/views/account/password.rs +++ b/crates/handlers/src/views/account/password.rs @@ -15,6 +15,7 @@ use anyhow::Context; use axum::{ extract::{Form, State}, + http::StatusCode, response::{Html, IntoResponse, Response}, }; use axum_extra::extract::PrivateCookieJar; @@ -48,9 +49,15 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, + State(password_manager): State, mut repo: BoxRepository, cookie_jar: PrivateCookieJar, ) -> Result { + // If the password manager is disabled, we can go back to the account page. + if !password_manager.is_enabled() { + return Ok(mas_router::Account.go().into_response()); + } + let (session_info, cookie_jar) = cookie_jar.session_info(); let maybe_session = session_info.load_session(&mut repo).await?; @@ -91,6 +98,11 @@ pub(crate) async fn post( cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { + if !password_manager.is_enabled() { + // XXX: do something better here + return Ok(StatusCode::METHOD_NOT_ALLOWED.into_response()); + } + let form = cookie_jar.verify_form(&clock, form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index 853e5e05..4b2d8a44 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -17,12 +17,14 @@ use axum::{ response::{Html, IntoResponse, Response}, }; use axum_extra::extract::PrivateCookieJar; +use hyper::StatusCode; use mas_axum_utils::{ csrf::{CsrfExt, CsrfToken, ProtectedForm}, FancyError, SessionInfoExt, }; use mas_data_model::BrowserSession; use mas_keystore::Encrypter; +use mas_router::{Route, UpstreamOAuth2Authorize}; use mas_storage::{ upstream_oauth2::UpstreamOAuthProviderRepository, user::{BrowserSessionRepository, UserPasswordRepository, UserRepository}, @@ -52,6 +54,7 @@ impl ToFormState for LoginForm { pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, + State(password_manager): State, State(templates): State, mut repo: BoxRepository, Query(query): Query, @@ -64,20 +67,38 @@ pub(crate) async fn get( if maybe_session.is_some() { let reply = query.go_next(); - Ok((cookie_jar, reply).into_response()) - } else { - let providers = repo.upstream_oauth_provider().all().await?; - let content = render( - LoginContext::default().with_upstrem_providers(providers), - query, - csrf_token, - &mut repo, - &templates, - ) - .await?; + return Ok((cookie_jar, reply).into_response()); + }; - Ok((cookie_jar, Html(content)).into_response()) - } + let providers = repo.upstream_oauth_provider().all().await?; + + // If password-based login is disabled, and there is only one upstream provider, + // we can directly start an authorization flow + if !password_manager.is_enabled() && providers.len() == 1 { + let provider = providers.into_iter().next().unwrap(); + + let mut destination = UpstreamOAuth2Authorize::new(provider.id); + + if let Some(action) = query.post_auth_action { + destination = destination.and_then(action); + }; + + return Ok((cookie_jar, destination.go()).into_response()); + }; + + let content = render( + LoginContext::default() + // XXX: we might want to have a site-wide config in the templates context instead? + .with_password_login(password_manager.is_enabled()) + .with_upstream_providers(providers), + query, + csrf_token, + &mut repo, + &templates, + ) + .await?; + + Ok((cookie_jar, Html(content)).into_response()) } #[tracing::instrument(name = "handlers.views.login.post", skip_all, err)] @@ -91,6 +112,11 @@ pub(crate) async fn post( cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { + if !password_manager.is_enabled() { + // XXX: is it necessary to have better errors here? + return Ok(StatusCode::METHOD_NOT_ALLOWED.into_response()); + } + let form = cookie_jar.verify_form(&clock, form)?; let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); @@ -115,7 +141,7 @@ pub(crate) async fn post( let content = render( LoginContext::default() .with_form_state(state) - .with_upstrem_providers(providers), + .with_upstream_providers(providers), query, csrf_token, &mut repo, @@ -251,3 +277,100 @@ async fn render( let content = templates.render_login(&ctx).await?; Ok(content) } + +#[cfg(test)] +mod test { + use hyper::{ + header::{CONTENT_TYPE, LOCATION}, + Request, StatusCode, + }; + use mas_iana::oauth::OAuthClientAuthenticationMethod; + use mas_router::Route; + use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, RepositoryAccess}; + use mas_templates::escape_html; + use oauth2_types::scope::OPENID; + use sqlx::PgPool; + + use crate::{ + passwords::PasswordManager, + test_utils::{init_tracing, RequestBuilderExt, ResponseExt, TestState}, + }; + + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_password_disabled(pool: PgPool) { + init_tracing(); + let state = { + let mut state = TestState::from_pool(pool).await.unwrap(); + state.password_manager = PasswordManager::disabled(); + state + }; + let mut rng = state.rng(); + + // Without password login and no upstream providers, we should get an error + // message + let response = state.request(Request::get("/login").empty()).await; + response.assert_status(StatusCode::OK); + response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8"); + assert!(response.body().contains("No login method available")); + + // Adding an upstream provider should redirect to it + let mut repo = state.repository().await.unwrap(); + let first_provider = repo + .upstream_oauth_provider() + .add( + &mut rng, + &state.clock, + "https://first.com/".into(), + [OPENID].into_iter().collect(), + OAuthClientAuthenticationMethod::None, + None, + "first_client".into(), + None, + ) + .await + .unwrap(); + repo.save().await.unwrap(); + + let first_provider_login = mas_router::UpstreamOAuth2Authorize::new(first_provider.id); + + let response = state.request(Request::get("/login").empty()).await; + response.assert_status(StatusCode::SEE_OTHER); + response.assert_header_value(LOCATION, &first_provider_login.relative_url()); + + // Adding a second provider should show a login page with both providers + let mut repo = state.repository().await.unwrap(); + let second_provider = repo + .upstream_oauth_provider() + .add( + &mut rng, + &state.clock, + "https://second.com/".into(), + [OPENID].into_iter().collect(), + OAuthClientAuthenticationMethod::None, + None, + "second_client".into(), + None, + ) + .await + .unwrap(); + repo.save().await.unwrap(); + + let second_provider_login = mas_router::UpstreamOAuth2Authorize::new(second_provider.id); + + let response = state.request(Request::get("/login").empty()).await; + response.assert_status(StatusCode::OK); + response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8"); + assert!(response + .body() + .contains(&escape_html(&first_provider.issuer))); + assert!(response + .body() + .contains(&escape_html(&first_provider_login.relative_url()))); + assert!(response + .body() + .contains(&escape_html(&second_provider.issuer))); + assert!(response + .body() + .contains(&escape_html(&second_provider_login.relative_url()))); + } +} diff --git a/crates/handlers/src/views/reauth.rs b/crates/handlers/src/views/reauth.rs index 549326bf..1c6243d5 100644 --- a/crates/handlers/src/views/reauth.rs +++ b/crates/handlers/src/views/reauth.rs @@ -18,6 +18,7 @@ use axum::{ response::{Html, IntoResponse, Response}, }; use axum_extra::extract::PrivateCookieJar; +use hyper::StatusCode; use mas_axum_utils::{ csrf::{CsrfExt, ProtectedForm}, FancyError, SessionInfoExt, @@ -44,11 +45,17 @@ pub(crate) struct ReauthForm { pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, + State(password_manager): State, State(templates): State, mut repo: BoxRepository, Query(query): Query, cookie_jar: PrivateCookieJar, ) -> Result { + if !password_manager.is_enabled() { + // XXX: do something better here + return Ok(mas_router::Account.go().into_response()); + } + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); @@ -85,6 +92,11 @@ pub(crate) async fn post( cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { + if !password_manager.is_enabled() { + // XXX: do something better here + return Ok(StatusCode::METHOD_NOT_ALLOWED.into_response()); + } + let form = cookie_jar.verify_form(&clock, form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); diff --git a/crates/handlers/src/views/register.rs b/crates/handlers/src/views/register.rs index 61b4c8bb..8d336af9 100644 --- a/crates/handlers/src/views/register.rs +++ b/crates/handlers/src/views/register.rs @@ -19,6 +19,7 @@ use axum::{ response::{Html, IntoResponse, Response}, }; use axum_extra::extract::PrivateCookieJar; +use hyper::StatusCode; use lettre::Address; use mas_axum_utils::{ csrf::{CsrfExt, CsrfToken, ProtectedForm}, @@ -59,6 +60,7 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, + State(password_manager): State, mut repo: BoxRepository, Query(query): Query, cookie_jar: PrivateCookieJar, @@ -70,19 +72,26 @@ pub(crate) async fn get( if maybe_session.is_some() { let reply = query.go_next(); - Ok((cookie_jar, reply).into_response()) - } else { - let content = render( - RegisterContext::default(), - query, - csrf_token, - &mut repo, - &templates, - ) - .await?; - - Ok((cookie_jar, Html(content)).into_response()) + return Ok((cookie_jar, reply).into_response()); } + + if !password_manager.is_enabled() { + // If password-based login is disabled, redirect to the login page here + return Ok(mas_router::Login::from(query.post_auth_action) + .go() + .into_response()); + } + + let content = render( + RegisterContext::default(), + query, + csrf_token, + &mut repo, + &templates, + ) + .await?; + + Ok((cookie_jar, Html(content)).into_response()) } #[tracing::instrument(name = "handlers.views.register.post", skip_all, err)] @@ -98,6 +107,10 @@ pub(crate) async fn post( cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { + if !password_manager.is_enabled() { + return Ok(StatusCode::METHOD_NOT_ALLOWED.into_response()); + } + let form = cookie_jar.verify_form(&clock, form)?; let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); @@ -233,3 +246,42 @@ async fn render( let content = templates.render_register(&ctx).await?; Ok(content) } + +#[cfg(test)] +mod tests { + use hyper::{header::LOCATION, Request, StatusCode}; + use mas_router::Route; + use sqlx::PgPool; + + use crate::{ + passwords::PasswordManager, + test_utils::{init_tracing, RequestBuilderExt, ResponseExt, TestState}, + }; + + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_password_disabled(pool: PgPool) { + init_tracing(); + let state = { + let mut state = TestState::from_pool(pool).await.unwrap(); + state.password_manager = PasswordManager::disabled(); + state + }; + + let request = Request::get(&*mas_router::Register::default().relative_url()).empty(); + let response = state.request(request).await; + response.assert_status(StatusCode::SEE_OTHER); + response.assert_header_value(LOCATION, "/login"); + + let request = Request::post(&*mas_router::Register::default().relative_url()).form( + serde_json::json!({ + "csrf": "abc", + "username": "john", + "email": "john@example.com", + "password": "hunter2", + "password_confirm": "hunter2", + }), + ); + let response = state.request(request).await; + response.assert_status(StatusCode::METHOD_NOT_ALLOWED); + } +} diff --git a/crates/templates/src/context.rs b/crates/templates/src/context.rs index f5edef39..8b453ec9 100644 --- a/crates/templates/src/context.rs +++ b/crates/templates/src/context.rs @@ -286,6 +286,7 @@ pub struct PostAuthContext { pub struct LoginContext { form: FormState, next: Option, + password_disabled: bool, providers: Vec, } @@ -295,15 +296,33 @@ impl TemplateContext for LoginContext { Self: Sized, { // TODO: samples with errors - vec![LoginContext { - form: FormState::default(), - next: None, - providers: Vec::new(), - }] + vec![ + LoginContext { + form: FormState::default(), + next: None, + password_disabled: true, + providers: Vec::new(), + }, + LoginContext { + form: FormState::default(), + next: None, + password_disabled: false, + providers: Vec::new(), + }, + ] } } impl LoginContext { + /// Set whether password login is enabled or not + #[must_use] + pub fn with_password_login(self, enabled: bool) -> Self { + Self { + password_disabled: !enabled, + ..self + } + } + /// Set the form state #[must_use] pub fn with_form_state(self, form: FormState) -> Self { @@ -312,7 +331,7 @@ impl LoginContext { /// Set the upstream OAuth 2.0 providers #[must_use] - pub fn with_upstrem_providers(self, providers: Vec) -> Self { + pub fn with_upstream_providers(self, providers: Vec) -> Self { Self { providers, ..self } } diff --git a/crates/templates/src/lib.rs b/crates/templates/src/lib.rs index 938a701a..95de35e2 100644 --- a/crates/templates/src/lib.rs +++ b/crates/templates/src/lib.rs @@ -31,6 +31,7 @@ use camino::{Utf8Path, Utf8PathBuf}; use mas_router::UrlBuilder; use rand::Rng; use serde::Serialize; +pub use tera::escape_html; use tera::{Context, Error as TeraError, Tera}; use thiserror::Error; use tokio::{sync::RwLock, task::JoinError}; diff --git a/docs/config.schema.json b/docs/config.schema.json index e8e7fb0f..d1d914a0 100644 --- a/docs/config.schema.json +++ b/docs/config.schema.json @@ -130,6 +130,7 @@ "passwords": { "description": "Configuration related to user passwords", "default": { + "enabled": true, "schemes": [ { "algorithm": "argon2id", @@ -1215,6 +1216,11 @@ "description": "User password hashing config", "type": "object", "properties": { + "enabled": { + "description": "Whether password-based authentication is enabled", + "default": true, + "type": "boolean" + }, "schemes": { "default": [ { diff --git a/templates/pages/login.html b/templates/pages/login.html index 1fbe9992..ac5a467e 100644 --- a/templates/pages/login.html +++ b/templates/pages/login.html @@ -19,66 +19,76 @@ limitations under the License. {% block content %}
- {% if next and next.kind == "link_upstream" %} -
-

Sign in to link

-

Linking your {{ next.provider.issuer }} account

-
- {% else %} -
-

Sign in

-

Please sign in to continue:

-
- {% endif %} - - {% if form.errors is not empty %} - {% for error in form.errors %} -
- {{ errors::form_error_message(error=error) }} + {% if not password_disabled %} + {% if next and next.kind == "link_upstream" %} +
+

Sign in to link

+

Linking your {{ next.provider.issuer }} account

- {% endfor %} - {% endif %} + {% else %} +
+

Sign in

+

Please sign in to continue:

+
+ {% endif %} - - {{ field::input(label="Username", name="username", form_state=form, autocomplete="username", autocorrect="off", autocapitalize="none") }} - {{ field::input(label="Password", name="password", type="password", form_state=form, autocomplete="password") }} - {% if next and next.kind == "continue_authorization_grant" %} -
- {{ back_to_client::link( - text="Cancel", - class=button::outline_error_class(), - uri=next.grant.redirect_uri, - mode=next.grant.response_mode, - params=dict(error="access_denied", state=next.grant.state) - ) }} - {{ button::button(text="Next") }} -
- {% else %} -
- {{ button::button(text="Next") }} -
- {% endif %} + {% if form.errors is not empty %} + {% for error in form.errors %} +
+ {{ errors::form_error_message(error=error) }} +
+ {% endfor %} + {% endif %} - {% if not next or next.kind != "link_upstream" %} -
- Don't have an account yet? - {% set params = next | safe_get(key="params") | to_params(prefix="?") %} - {{ button::link_text(text="Create an account", href="/register" ~ params) }} -
+ + {{ field::input(label="Username", name="username", form_state=form, autocomplete="username", autocorrect="off", autocapitalize="none") }} + {{ field::input(label="Password", name="password", type="password", form_state=form, autocomplete="password") }} + {% if next and next.kind == "continue_authorization_grant" %} +
+ {{ back_to_client::link( + text="Cancel", + class=button::outline_error_class(), + uri=next.grant.redirect_uri, + mode=next.grant.response_mode, + params=dict(error="access_denied", state=next.grant.state) + ) }} + {{ button::button(text="Next") }} +
+ {% else %} +
+ {{ button::button(text="Next") }} +
+ {% endif %} + + {% if not next or next.kind != "link_upstream" %} +
+ Don't have an account yet? + {% set params = next | safe_get(key="params") | to_params(prefix="?") %} + {{ button::link_text(text="Create an account", href="/register" ~ params) }} +
+ {% endif %} {% endif %} {% if providers %} -
-
-
Or
-
-
+ {% if not password_disabled %} +
+
+
Or
+
+
+ {% endif %} {% for provider in providers %} {% set params = next | safe_get(key="params") | to_params(prefix="?") %} {{ button::link(text="Continue with " ~ provider.issuer, href="/upstream/authorize/" ~ provider.id ~ params) }} {% endfor %} {% endif %} + + {% if not providers and password_disabled %} +
+ No login method available. +
+ {% endif %}
{% endblock content %}