diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index 7c5f9219..e45a5c0c 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -19,7 +19,7 @@ use clap::Parser; use figment::Figment; use itertools::Itertools; use mas_config::{AppConfig, ClientsConfig, ConfigurationSection, UpstreamOAuth2Config}; -use mas_handlers::{ActivityTracker, CookieManager, HttpClientFactory, MetadataCache, SiteConfig}; +use mas_handlers::{ActivityTracker, CookieManager, HttpClientFactory, MetadataCache}; use mas_listener::{server::Server, shutdown::ShutdownStream}; use mas_matrix_synapse::SynapseConnection; use mas_router::UrlBuilder; @@ -37,7 +37,8 @@ use crate::{ app_state::AppState, util::{ database_pool_from_config, mailer_from_config, password_manager_from_config, - policy_factory_from_config, register_sighup, templates_from_config, + policy_factory_from_config, register_sighup, site_config_from_config, + templates_from_config, }, }; @@ -138,14 +139,17 @@ impl Options { None, ); - // Load and compile the templates - let templates = templates_from_config( - &config.templates, + // Load the site configuration + let site_config = site_config_from_config( &config.branding, - &url_builder, - &config.matrix.homeserver, - ) - .await?; + &config.matrix, + &config.experimental, + &config.passwords, + ); + + // Load and compile the templates + let templates = + templates_from_config(&config.templates, &site_config, &url_builder).await?; let http_client_factory = HttpClientFactory::new(); @@ -179,12 +183,6 @@ impl Options { // The upstream OIDC metadata cache let metadata_cache = MetadataCache::new(); - let site_config = SiteConfig { - tos_uri: config.branding.tos_uri.clone(), - access_token_ttl: config.experimental.access_token_ttl, - compat_token_ttl: config.experimental.compat_token_ttl, - }; - // Initialize the activity tracker // Activity is flushed every minute let activity_tracker = ActivityTracker::new(pool.clone(), Duration::from_secs(60)); diff --git a/crates/cli/src/commands/templates.rs b/crates/cli/src/commands/templates.rs index 9a0ce28d..6c905cfd 100644 --- a/crates/cli/src/commands/templates.rs +++ b/crates/cli/src/commands/templates.rs @@ -14,12 +14,15 @@ use clap::Parser; use figment::Figment; -use mas_config::{BrandingConfig, ConfigurationSection, MatrixConfig, TemplatesConfig}; +use mas_config::{ + BrandingConfig, ConfigurationSection, ExperimentalConfig, MatrixConfig, PasswordsConfig, + TemplatesConfig, +}; use mas_storage::{Clock, SystemClock}; use rand::SeedableRng; use tracing::info_span; -use crate::util::templates_from_config; +use crate::util::{site_config_from_config, templates_from_config}; #[derive(Parser, Debug)] pub(super) struct Options { @@ -43,19 +46,22 @@ impl Options { let template_config = TemplatesConfig::extract(figment)?; let branding_config = BrandingConfig::extract(figment)?; let matrix_config = MatrixConfig::extract(figment)?; + let experimental_config = ExperimentalConfig::extract(figment)?; + let password_config = PasswordsConfig::extract(figment)?; let clock = SystemClock::default(); // XXX: we should disallow SeedableRng::from_entropy let mut rng = rand_chacha::ChaChaRng::from_entropy(); let url_builder = mas_router::UrlBuilder::new("https://example.com/".parse()?, None, None); - let templates = templates_from_config( - &template_config, + let site_config = site_config_from_config( &branding_config, - &url_builder, - &matrix_config.homeserver, - ) - .await?; + &matrix_config, + &experimental_config, + &password_config, + ); + let templates = + templates_from_config(&template_config, &site_config, &url_builder).await?; templates.check_render(clock.now(), &mut rng)?; Ok(()) diff --git a/crates/cli/src/commands/worker.rs b/crates/cli/src/commands/worker.rs index ecb236ac..79120f20 100644 --- a/crates/cli/src/commands/worker.rs +++ b/crates/cli/src/commands/worker.rs @@ -24,7 +24,9 @@ use rand::{ }; use tracing::{info, info_span}; -use crate::util::{database_pool_from_config, mailer_from_config, templates_from_config}; +use crate::util::{ + database_pool_from_config, mailer_from_config, site_config_from_config, templates_from_config, +}; #[derive(Parser, Debug, Default)] pub(super) struct Options {} @@ -44,14 +46,17 @@ impl Options { None, ); - // Load and compile the templates - let templates = templates_from_config( - &config.templates, + // Load the site configuration + let site_config = site_config_from_config( &config.branding, - &url_builder, - &config.matrix.homeserver, - ) - .await?; + &config.matrix, + &config.experimental, + &config.passwords, + ); + + // Load and compile the templates + let templates = + templates_from_config(&config.templates, &site_config, &url_builder).await?; let mailer = mailer_from_config(&config.email, &templates)?; mailer.test_connection().await?; diff --git a/crates/cli/src/util.rs b/crates/cli/src/util.rs index 001624a7..5467275e 100644 --- a/crates/cli/src/util.rs +++ b/crates/cli/src/util.rs @@ -17,13 +17,13 @@ use std::time::Duration; use anyhow::Context; use mas_config::{ BrandingConfig, DatabaseConfig, EmailConfig, EmailSmtpMode, EmailTransportKind, - PasswordsConfig, PolicyConfig, TemplatesConfig, + ExperimentalConfig, MatrixConfig, PasswordsConfig, PolicyConfig, TemplatesConfig, }; use mas_email::{MailTransport, Mailer}; -use mas_handlers::{passwords::PasswordManager, ActivityTracker}; +use mas_handlers::{passwords::PasswordManager, ActivityTracker, SiteConfig}; use mas_policy::PolicyFactory; use mas_router::UrlBuilder; -use mas_templates::{SiteBranding, TemplateLoadingError, Templates}; +use mas_templates::{TemplateLoadingError, Templates}; use sqlx::{ postgres::{PgConnectOptions, PgPoolOptions}, ConnectOptions, PgConnection, PgPool, @@ -119,36 +119,37 @@ pub async fn policy_factory_from_config( .context("failed to load the policy") } +pub fn site_config_from_config( + branding_config: &BrandingConfig, + matrix_config: &MatrixConfig, + experimental_config: &ExperimentalConfig, + password_config: &PasswordsConfig, +) -> SiteConfig { + SiteConfig { + access_token_ttl: experimental_config.access_token_ttl, + compat_token_ttl: experimental_config.compat_token_ttl, + server_name: matrix_config.homeserver.clone(), + policy_uri: branding_config.policy_uri.clone(), + tos_uri: branding_config.tos_uri.clone(), + imprint: branding_config.imprint.clone(), + password_login_enabled: password_config.enabled(), + password_registration_enabled: password_config.enabled() + && experimental_config.password_registration_enabled, + } +} + pub async fn templates_from_config( config: &TemplatesConfig, - branding: &BrandingConfig, + site_config: &SiteConfig, url_builder: &UrlBuilder, - server_name: &str, ) -> Result { - let mut site_branding = SiteBranding::new(server_name); - - if let Some(service_name) = branding.service_name.as_deref() { - site_branding = site_branding.with_service_name(service_name); - } - - if let Some(policy_uri) = &branding.policy_uri { - site_branding = site_branding.with_policy_uri(policy_uri.as_str()); - } - - if let Some(tos_uri) = &branding.tos_uri { - site_branding = site_branding.with_tos_uri(tos_uri.as_str()); - } - - if let Some(imprint) = branding.imprint.as_deref() { - site_branding = site_branding.with_imprint(imprint); - } - Templates::load( config.path.clone(), url_builder.clone(), config.assets_manifest.clone(), config.translations_path.clone(), - site_branding, + site_config.templates_branding(), + site_config.templates_features(), ) .await } diff --git a/crates/config/src/sections/experimental.rs b/crates/config/src/sections/experimental.rs index ad1b5eb4..97b362e6 100644 --- a/crates/config/src/sections/experimental.rs +++ b/crates/config/src/sections/experimental.rs @@ -27,6 +27,15 @@ fn is_default_token_ttl(value: &Duration) -> bool { *value == default_token_ttl() } +const fn default_true() -> bool { + true +} + +#[allow(clippy::trivially_copy_pass_by_ref)] +const fn is_default_true(value: &bool) -> bool { + *value == default_true() +} + /// Configuration sections for experimental options /// /// Do not change these options unless you know what you are doing. @@ -51,6 +60,11 @@ pub struct ExperimentalConfig { )] #[serde_as(as = "serde_with::DurationSeconds")] pub compat_token_ttl: Duration, + + /// Whether to enable self-service password registration. Defaults to `true` + /// if password authentication is enabled. + #[serde(default = "default_true", skip_serializing_if = "is_default_true")] + pub password_registration_enabled: bool, } impl Default for ExperimentalConfig { @@ -58,13 +72,16 @@ impl Default for ExperimentalConfig { Self { access_token_ttl: default_token_ttl(), compat_token_ttl: default_token_ttl(), + password_registration_enabled: default_true(), } } } impl ExperimentalConfig { pub(crate) fn is_default(&self) -> bool { - is_default_token_ttl(&self.access_token_ttl) && is_default_token_ttl(&self.compat_token_ttl) + is_default_token_ttl(&self.access_token_ttl) + && is_default_token_ttl(&self.compat_token_ttl) + && is_default_true(&self.password_registration_enabled) } } diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index 6588e5bd..2fbf39ed 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -150,6 +150,7 @@ where B: HttpBody + Send + 'static, S: Clone + Send + Sync + 'static, Keystore: FromRef, + SiteConfig: FromRef, UrlBuilder: FromRef, BoxClock: FromRequestParts, BoxRng: FromRequestParts, diff --git a/crates/handlers/src/oauth2/discovery.rs b/crates/handlers/src/oauth2/discovery.rs index 3c39a103..d1e09aaf 100644 --- a/crates/handlers/src/oauth2/discovery.rs +++ b/crates/handlers/src/oauth2/discovery.rs @@ -27,6 +27,8 @@ use oauth2_types::{ }; use serde::Serialize; +use crate::SiteConfig; + #[derive(Debug, Serialize)] struct DiscoveryResponse { #[serde(flatten)] @@ -45,6 +47,7 @@ struct DiscoveryResponse { pub(crate) async fn get( State(key_store): State, State(url_builder): State, + State(site_config): State, ) -> impl IntoResponse { // This is how clients can authenticate let client_auth_methods_supported = Some(vec![ @@ -136,7 +139,16 @@ pub(crate) async fn get( let request_parameter_supported = Some(false); let request_uri_parameter_supported = Some(false); - let prompt_values_supported = Some(vec![Prompt::None, Prompt::Login, Prompt::Create]); + let prompt_values_supported = Some({ + let mut v = vec![Prompt::None, Prompt::Login]; + // Advertise for prompt=create if password registration is enabled + // TODO: we may want to be able to forward that to upstream providers if they + // support it + if site_config.password_registration_enabled { + v.push(Prompt::Create); + } + v + }); let standard = ProviderMetadata { issuer, diff --git a/crates/handlers/src/site_config.rs b/crates/handlers/src/site_config.rs index 58c3668f..6ae76885 100644 --- a/crates/handlers/src/site_config.rs +++ b/crates/handlers/src/site_config.rs @@ -1,4 +1,4 @@ -// Copyright 2023 The Matrix.org Foundation C.I.C. +// Copyright 2023, 2024 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ // limitations under the License. use chrono::Duration; +use mas_templates::{SiteBranding, SiteFeatures}; use url::Url; /// Random site configuration we don't now where to put yet. @@ -20,15 +21,39 @@ use url::Url; pub struct SiteConfig { pub access_token_ttl: Duration, pub compat_token_ttl: Duration, + pub server_name: String, + pub policy_uri: Option, pub tos_uri: Option, + pub imprint: Option, + pub password_login_enabled: bool, + pub password_registration_enabled: bool, } -impl Default for SiteConfig { - fn default() -> Self { - Self { - access_token_ttl: Duration::microseconds(5 * 60 * 1000 * 1000), - compat_token_ttl: Duration::microseconds(5 * 60 * 1000 * 1000), - tos_uri: None, +impl SiteConfig { + #[must_use] + pub fn templates_branding(&self) -> SiteBranding { + let mut branding = SiteBranding::new(self.server_name.clone()); + + if let Some(policy_uri) = &self.policy_uri { + branding = branding.with_policy_uri(policy_uri.as_str()); + } + + if let Some(tos_uri) = &self.tos_uri { + branding = branding.with_tos_uri(tos_uri.as_str()); + } + + if let Some(imprint) = &self.imprint { + branding = branding.with_imprint(imprint.as_str()); + } + + branding + } + + #[must_use] + pub fn templates_features(&self) -> SiteFeatures { + SiteFeatures { + password_registration: self.password_registration_enabled, + password_login: self.password_login_enabled, } } } diff --git a/crates/handlers/src/test_utils.rs b/crates/handlers/src/test_utils.rs index 93c8f757..d82158cf 100644 --- a/crates/handlers/src/test_utils.rs +++ b/crates/handlers/src/test_utils.rs @@ -24,6 +24,7 @@ use axum::{ extract::{FromRef, FromRequestParts}, response::{IntoResponse, IntoResponseParts}, }; +use chrono::Duration; use cookie_store::{CookieStore, RawCookie}; use futures_util::future::BoxFuture; use headers::{Authorization, ContentType, HeaderMapExt, HeaderName, HeaderValue}; @@ -43,7 +44,7 @@ use mas_policy::{InstantiateError, Policy, PolicyFactory}; use mas_router::{SimpleRoute, UrlBuilder}; use mas_storage::{clock::MockClock, BoxClock, BoxRepository, BoxRng, Repository}; use mas_storage_pg::{DatabaseError, PgRepository}; -use mas_templates::{SiteBranding, Templates}; +use mas_templates::Templates; use rand::SeedableRng; use rand_chacha::ChaChaRng; use serde::{de::DeserializeOwned, Serialize}; @@ -110,25 +111,49 @@ pub(crate) struct TestState { pub rng: Arc>, } +fn workspace_root() -> camino::Utf8PathBuf { + camino::Utf8Path::new(env!("CARGO_MANIFEST_DIR")) + .join("..") + .join("..") + .canonicalize_utf8() + .unwrap() +} + +pub fn test_site_config() -> SiteConfig { + SiteConfig { + access_token_ttl: Duration::try_minutes(5).unwrap(), + compat_token_ttl: Duration::try_minutes(5).unwrap(), + server_name: "example.com".to_owned(), + policy_uri: Some("https://example.com/policy".parse().unwrap()), + tos_uri: Some("https://example.com/tos".parse().unwrap()), + imprint: None, + password_login_enabled: true, + password_registration_enabled: true, + } +} + impl TestState { /// Create a new test state from the given database pool pub async fn from_pool(pool: PgPool) -> Result { - let workspace_root = camino::Utf8Path::new(env!("CARGO_MANIFEST_DIR")) - .join("..") - .join(".."); + Self::from_pool_with_site_config(pool, test_site_config()).await + } + + /// Create a new test state from the given database pool and site config + pub async fn from_pool_with_site_config( + pool: PgPool, + site_config: SiteConfig, + ) -> Result { + let workspace_root = workspace_root(); let url_builder = UrlBuilder::new("https://example.com/".parse()?, None, None); - let site_branding = SiteBranding::new("example.com") - .with_service_name("Example") - .with_tos_uri("https://example.com/tos"); - let templates = Templates::load( workspace_root.join("templates"), url_builder.clone(), workspace_root.join("frontend/dist/manifest.json"), workspace_root.join("translations"), - site_branding, + site_config.templates_branding(), + site_config.templates_features(), ) .await?; @@ -141,24 +166,23 @@ impl TestState { let key_store = Keystore::new(jwks); let encrypter = Encrypter::new(&[0x42; 32]); - let cookie_manager = - CookieManager::derive_from("https://example.com".parse()?, &[0x42; 32]); + let cookie_manager = CookieManager::derive_from(url_builder.http_base(), &[0x42; 32]); let metadata_cache = MetadataCache::new(); - let password_manager = PasswordManager::new([(1, Hasher::argon2id(None))])?; + let password_manager = if site_config.password_login_enabled { + PasswordManager::new([(1, Hasher::argon2id(None))])? + } else { + PasswordManager::disabled() + }; let policy_factory = policy_factory(serde_json::json!({})).await?; - let homeserver_connection = Arc::new(MockHomeserverConnection::new("example.com")); + let homeserver_connection = + Arc::new(MockHomeserverConnection::new(&site_config.server_name)); let http_client_factory = HttpClientFactory::new(); - let site_config = SiteConfig { - tos_uri: Some("https://example.com/tos".parse().unwrap()), - ..SiteConfig::default() - }; - let clock = Arc::new(MockClock::default()); let rng = Arc::new(Mutex::new(ChaChaRng::seed_from_u64(42))); diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index 513a9327..90d0a497 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -39,7 +39,7 @@ use serde::{Deserialize, Serialize}; use zeroize::Zeroizing; use super::shared::OptionalPostAuthAction; -use crate::{passwords::PasswordManager, BoundActivityTracker, PreferredLanguage}; +use crate::{passwords::PasswordManager, BoundActivityTracker, PreferredLanguage, SiteConfig}; #[derive(Debug, Deserialize, Serialize)] pub(crate) struct LoginForm { @@ -56,9 +56,9 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, PreferredLanguage(locale): PreferredLanguage, - State(password_manager): State, State(templates): State, State(url_builder): State, + State(site_config): State, mut repo: BoxRepository, activity_tracker: BoundActivityTracker, Query(query): Query, @@ -82,7 +82,7 @@ pub(crate) async fn get( // If password-based login is disabled, and there is only one upstream provider, // we can directly start an authorization flow - if !password_manager.is_enabled() && providers.len() == 1 { + if !site_config.password_login_enabled && providers.len() == 1 { let provider = providers.into_iter().next().unwrap(); let mut destination = UpstreamOAuth2Authorize::new(provider.id); @@ -96,10 +96,7 @@ pub(crate) async fn get( let content = render( locale, - LoginContext::default() - // XXX: we might want to have a site-wide config in the templates context instead? - .with_password_login(password_manager.is_enabled()) - .with_upstream_providers(providers), + LoginContext::default().with_upstream_providers(providers), query, csrf_token, &mut repo, @@ -116,6 +113,7 @@ pub(crate) async fn post( clock: BoxClock, PreferredLanguage(locale): PreferredLanguage, State(password_manager): State, + State(site_config): State, State(templates): State, State(url_builder): State, mut repo: BoxRepository, @@ -126,7 +124,7 @@ pub(crate) async fn post( Form(form): Form>, ) -> Result { let user_agent = user_agent.map(|ua| UserAgent::parse(ua.as_str().to_owned())); - if !password_manager.is_enabled() { + if !site_config.password_login_enabled { // XXX: is it necessary to have better errors here? return Ok(StatusCode::METHOD_NOT_ALLOWED.into_response()); } @@ -320,18 +318,25 @@ mod test { use zeroize::Zeroizing; use crate::{ - passwords::PasswordManager, - test_utils::{init_tracing, CookieHelper, RequestBuilderExt, ResponseExt, TestState}, + test_utils::{ + init_tracing, test_site_config, CookieHelper, RequestBuilderExt, ResponseExt, TestState, + }, + SiteConfig, }; #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] async fn test_password_disabled(pool: PgPool) { init_tracing(); - let state = { - let mut state = TestState::from_pool(pool).await.unwrap(); - state.password_manager = PasswordManager::disabled(); - state - }; + let state = TestState::from_pool_with_site_config( + pool, + SiteConfig { + password_login_enabled: false, + ..test_site_config() + }, + ) + .await + .unwrap(); + let mut rng = state.rng(); // Without password login and no upstream providers, we should get an error @@ -339,7 +344,11 @@ mod test { let response = state.request(Request::get("/login").empty()).await; response.assert_status(StatusCode::OK); response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8"); - assert!(response.body().contains("No login methods available")); + assert!( + response.body().contains("No login methods available"), + "Response body: {}", + response.body() + ); // Adding an upstream provider should redirect to it let mut repo = state.repository().await.unwrap(); diff --git a/crates/handlers/src/views/reauth.rs b/crates/handlers/src/views/reauth.rs index bd1e01e4..3a07b42d 100644 --- a/crates/handlers/src/views/reauth.rs +++ b/crates/handlers/src/views/reauth.rs @@ -33,7 +33,7 @@ use serde::Deserialize; use zeroize::Zeroizing; use super::shared::OptionalPostAuthAction; -use crate::{passwords::PasswordManager, BoundActivityTracker, PreferredLanguage}; +use crate::{passwords::PasswordManager, BoundActivityTracker, PreferredLanguage, SiteConfig}; #[derive(Deserialize, Debug)] pub(crate) struct ReauthForm { @@ -45,15 +45,15 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, PreferredLanguage(locale): PreferredLanguage, - State(password_manager): State, State(templates): State, State(url_builder): State, + State(site_config): State, activity_tracker: BoundActivityTracker, mut repo: BoxRepository, Query(query): Query, cookie_jar: CookieJar, ) -> Result { - if !password_manager.is_enabled() { + if !site_config.password_login_enabled { // XXX: do something better here return Ok(url_builder .redirect(&mas_router::Account::default()) @@ -99,12 +99,13 @@ pub(crate) async fn post( clock: BoxClock, State(password_manager): State, State(url_builder): State, + State(site_config): State, mut repo: BoxRepository, Query(query): Query, cookie_jar: CookieJar, Form(form): Form>, ) -> Result { - if !password_manager.is_enabled() { + if !site_config.password_login_enabled { // XXX: do something better here return Ok(StatusCode::METHOD_NOT_ALLOWED.into_response()); } diff --git a/crates/handlers/src/views/register.rs b/crates/handlers/src/views/register.rs index 437567ef..0912a2a7 100644 --- a/crates/handlers/src/views/register.rs +++ b/crates/handlers/src/views/register.rs @@ -66,8 +66,8 @@ pub(crate) async fn get( clock: BoxClock, PreferredLanguage(locale): PreferredLanguage, State(templates): State, - State(password_manager): State, State(url_builder): State, + State(site_config): State, mut repo: BoxRepository, Query(query): Query, cookie_jar: CookieJar, @@ -82,8 +82,8 @@ pub(crate) async fn get( return Ok((cookie_jar, reply).into_response()); } - if !password_manager.is_enabled() { - // If password-based login is disabled, redirect to the login page here + if !site_config.password_registration_enabled { + // If password-based registration is disabled, redirect to the login page here return Ok(url_builder .redirect(&mas_router::Login::from(query.post_auth_action)) .into_response()); @@ -122,7 +122,7 @@ pub(crate) async fn post( Form(form): Form>, ) -> Result { let user_agent = user_agent.map(|ua| UserAgent::parse(ua.as_str().to_owned())); - if !password_manager.is_enabled() { + if !site_config.password_registration_enabled { return Ok(StatusCode::METHOD_NOT_ALLOWED.into_response()); } @@ -301,18 +301,25 @@ mod tests { use sqlx::PgPool; use crate::{ - passwords::PasswordManager, - test_utils::{init_tracing, CookieHelper, RequestBuilderExt, ResponseExt, TestState}, + test_utils::{ + init_tracing, test_site_config, CookieHelper, RequestBuilderExt, ResponseExt, TestState, + }, + SiteConfig, }; #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] async fn test_password_disabled(pool: PgPool) { init_tracing(); - let state = { - let mut state = TestState::from_pool(pool).await.unwrap(); - state.password_manager = PasswordManager::disabled(); - state - }; + let state = TestState::from_pool_with_site_config( + pool, + SiteConfig { + password_login_enabled: false, + password_registration_enabled: false, + ..test_site_config() + }, + ) + .await + .unwrap(); let request = Request::get(&*mas_router::Register::default().path_and_query()).empty(); let response = state.request(request).await; diff --git a/crates/router/src/url_builder.rs b/crates/router/src/url_builder.rs index b86505c5..a4c3020c 100644 --- a/crates/router/src/url_builder.rs +++ b/crates/router/src/url_builder.rs @@ -112,6 +112,12 @@ impl UrlBuilder { } } + /// HTTP base + #[must_use] + pub fn http_base(&self) -> Url { + self.http_base.clone() + } + /// OIDC issuer #[must_use] pub fn oidc_issuer(&self) -> Url { diff --git a/crates/templates/src/context.rs b/crates/templates/src/context.rs index d7788864..de1f7cc0 100644 --- a/crates/templates/src/context.rs +++ b/crates/templates/src/context.rs @@ -15,6 +15,7 @@ //! Contexts used in templates mod branding; +mod features; use std::{ fmt::Formatter, @@ -39,7 +40,7 @@ use serde::{ser::SerializeStruct, Deserialize, Serialize}; use ulid::Ulid; use url::Url; -pub use self::branding::SiteBranding; +pub use self::{branding::SiteBranding, features::SiteFeatures}; use crate::{FieldError, FormField, FormState}; /// Helper trait to construct context wrappers @@ -399,7 +400,6 @@ pub struct PostAuthContext { pub struct LoginContext { form: FormState, next: Option, - password_disabled: bool, providers: Vec, } @@ -413,13 +413,11 @@ impl TemplateContext for LoginContext { LoginContext { form: FormState::default(), next: None, - password_disabled: true, providers: Vec::new(), }, LoginContext { form: FormState::default(), next: None, - password_disabled: false, providers: Vec::new(), }, LoginContext { @@ -432,14 +430,12 @@ impl TemplateContext for LoginContext { }, ), next: None, - password_disabled: false, providers: Vec::new(), }, LoginContext { form: FormState::default() .with_error_on_field(LoginFormField::Username, FieldError::Exists), next: None, - password_disabled: false, providers: Vec::new(), }, ] @@ -447,15 +443,6 @@ impl TemplateContext for LoginContext { } impl LoginContext { - /// Set whether password login is enabled or not - #[must_use] - pub fn with_password_login(self, enabled: bool) -> Self { - Self { - password_disabled: !enabled, - ..self - } - } - /// Set the form state #[must_use] pub fn with_form_state(self, form: FormState) -> Self { diff --git a/crates/templates/src/context/branding.rs b/crates/templates/src/context/branding.rs index afa28925..5cdb7469 100644 --- a/crates/templates/src/context/branding.rs +++ b/crates/templates/src/context/branding.rs @@ -1,3 +1,17 @@ +// Copyright 2024 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + use std::sync::Arc; use minijinja::{value::StructObject, Value}; @@ -6,11 +20,9 @@ use minijinja::{value::StructObject, Value}; #[derive(Debug, Clone, PartialEq, Eq)] pub struct SiteBranding { server_name: Arc, - service_name: Option>, policy_uri: Option>, tos_uri: Option>, imprint: Option>, - logo_uri: Option>, } impl SiteBranding { @@ -19,21 +31,12 @@ impl SiteBranding { pub fn new(server_name: impl Into>) -> Self { Self { server_name: server_name.into(), - service_name: None, policy_uri: None, tos_uri: None, imprint: None, - logo_uri: None, } } - /// Set the service name. - #[must_use] - pub fn with_service_name(mut self, service_name: impl Into>) -> Self { - self.service_name = Some(service_name.into()); - self - } - /// Set the policy URI. #[must_use] pub fn with_policy_uri(mut self, policy_uri: impl Into>) -> Self { @@ -54,36 +57,20 @@ impl SiteBranding { self.imprint = Some(imprint.into()); self } - - /// Set the logo URI. - #[must_use] - pub fn with_logo_uri(mut self, logo_uri: impl Into>) -> Self { - self.logo_uri = Some(logo_uri.into()); - self - } } impl StructObject for SiteBranding { fn get_field(&self, name: &str) -> Option { match name { "server_name" => Some(self.server_name.clone().into()), - "service_name" => self.service_name.clone().map(Value::from), "policy_uri" => self.policy_uri.clone().map(Value::from), "tos_uri" => self.tos_uri.clone().map(Value::from), "imprint" => self.imprint.clone().map(Value::from), - "logo_uri" => self.logo_uri.clone().map(Value::from), _ => None, } } fn static_fields(&self) -> Option<&'static [&'static str]> { - Some(&[ - "server_name", - "service_name", - "policy_uri", - "tos_uri", - "imprint", - "logo_uri", - ]) + Some(&["server_name", "policy_uri", "tos_uri", "imprint"]) } } diff --git a/crates/templates/src/context/features.rs b/crates/templates/src/context/features.rs new file mode 100644 index 00000000..6433f8e8 --- /dev/null +++ b/crates/templates/src/context/features.rs @@ -0,0 +1,39 @@ +// Copyright 2024 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use minijinja::{value::StructObject, Value}; + +/// Site features information. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct SiteFeatures { + /// Whether local password-based registration is enabled. + pub password_registration: bool, + + /// Whether local password-based login is enabled. + pub password_login: bool, +} + +impl StructObject for SiteFeatures { + fn get_field(&self, field: &str) -> Option { + match field { + "password_registration" => Some(Value::from(self.password_registration)), + "password_login" => Some(Value::from(self.password_login)), + _ => None, + } + } + + fn static_fields(&self) -> Option<&'static [&'static str]> { + Some(&["password_registration", "password_login"]) + } +} diff --git a/crates/templates/src/lib.rs b/crates/templates/src/lib.rs index 38a390b5..d5dc67d8 100644 --- a/crates/templates/src/lib.rs +++ b/crates/templates/src/lib.rs @@ -47,7 +47,7 @@ pub use self::{ EmailVerificationPageContext, EmptyContext, ErrorContext, FormPostContext, IndexContext, LoginContext, LoginFormField, NotFoundContext, PolicyViolationContext, PostAuthContext, PostAuthContextInner, ReauthContext, ReauthFormField, RegisterContext, RegisterFormField, - SiteBranding, TemplateContext, UpstreamExistingLinkContext, UpstreamRegister, + SiteBranding, SiteFeatures, TemplateContext, UpstreamExistingLinkContext, UpstreamRegister, UpstreamRegisterFormField, UpstreamSuggestLink, WithCsrf, WithLanguage, WithOptionalSession, WithSession, }, @@ -70,6 +70,7 @@ pub struct Templates { translator: Arc>, url_builder: UrlBuilder, branding: SiteBranding, + features: SiteFeatures, vite_manifest_path: Utf8PathBuf, translations_path: Utf8PathBuf, path: Utf8PathBuf, @@ -149,6 +150,7 @@ impl Templates { vite_manifest_path: Utf8PathBuf, translations_path: Utf8PathBuf, branding: SiteBranding, + features: SiteFeatures, ) -> Result { let (translator, environment) = Self::load_( &path, @@ -156,6 +158,7 @@ impl Templates { &vite_manifest_path, &translations_path, branding.clone(), + features, ) .await?; Ok(Self { @@ -166,6 +169,7 @@ impl Templates { vite_manifest_path, translations_path, branding, + features, }) } @@ -175,6 +179,7 @@ impl Templates { vite_manifest_path: &Utf8Path, translations_path: &Utf8Path, branding: SiteBranding, + features: SiteFeatures, ) -> Result<(Arc, Arc>), TemplateLoadingError> { let path = path.to_owned(); let span = tracing::Span::current(); @@ -230,6 +235,7 @@ impl Templates { .await??; env.add_global("branding", Value::from_struct_object(branding)); + env.add_global("features", Value::from_struct_object(features)); self::functions::register( &mut env, @@ -265,6 +271,7 @@ impl Templates { &self.vite_manifest_path, &self.translations_path, self.branding.clone(), + self.features, ) .await?; @@ -425,7 +432,11 @@ mod tests { let path = Utf8Path::new(env!("CARGO_MANIFEST_DIR")).join("../../templates/"); let url_builder = UrlBuilder::new("https://example.com/".parse().unwrap(), None, None); - let branding = SiteBranding::new("example.com").with_service_name("Example"); + let branding = SiteBranding::new("example.com"); + let features = SiteFeatures { + password_login: true, + password_registration: true, + }; let vite_manifest_path = Utf8Path::new(env!("CARGO_MANIFEST_DIR")).join("../../frontend/dist/manifest.json"); let translations_path = @@ -436,6 +447,7 @@ mod tests { vite_manifest_path, translations_path, branding, + features, ) .await .unwrap(); diff --git a/docs/config.schema.json b/docs/config.schema.json index 57b1c935..0b607a71 100644 --- a/docs/config.schema.json +++ b/docs/config.schema.json @@ -1962,6 +1962,10 @@ "format": "uint64", "maximum": 86400.0, "minimum": 60.0 + }, + "password_registration_enabled": { + "description": "Whether to enable self-service password registration. Defaults to `true` if password authentication is enabled.", + "type": "boolean" } } } diff --git a/templates/pages/index.html b/templates/pages/index.html index 81ba6161..554efbaa 100644 --- a/templates/pages/index.html +++ b/templates/pages/index.html @@ -36,7 +36,10 @@ limitations under the License. {{ logout.button(text=_("action.sign_out"), csrf_token=csrf_token) }} {% else %} {{ button.link(text=_("action.sign_in"), href="/login") }} - {{ button.link_outline(text=_("mas.navbar.register"), href="/register") }} + + {% if features.password_registration %} + {{ button.link_outline(text=_("mas.navbar.register"), href="/register") }} + {% endif %} {% endif %} {% endblock content %} diff --git a/templates/pages/login.html b/templates/pages/login.html index 5cddbbbe..1568bab6 100644 --- a/templates/pages/login.html +++ b/templates/pages/login.html @@ -20,7 +20,7 @@ limitations under the License. {% block content %}
- {% if not password_disabled %} + {% if features.password_login %}
{{ icon.user_profile_solid() }} @@ -62,7 +62,7 @@ limitations under the License. {{ button.button(text=_("action.continue")) }} - {% if not next or next.kind != "link_upstream" %} + {% if (not next or next.kind != "link_upstream") and features.password_registration %}

{{ _("mas.login.call_to_register") }} @@ -75,7 +75,7 @@ limitations under the License. {% endif %} {% if providers %} - {% if not password_disabled %} + {% if features.password_login %} {{ field.separator() }} {% endif %} @@ -89,7 +89,7 @@ limitations under the License. {% endfor %} {% endif %} - {% if not providers and password_disabled %} + {% if not providers and not features.password_login %}

{{ _("mas.login.no_login_methods") }}
diff --git a/translations/en.json b/translations/en.json index 70d202bf..c6d5552b 100644 --- a/translations/en.json +++ b/translations/en.json @@ -226,7 +226,7 @@ }, "register": "Create an account", "@register": { - "context": "pages/index.html:39:34-58" + "context": "pages/index.html:41:36-60" }, "signed_in_as": "Signed in as %(username)s.", "@signed_in_as": {