1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-07 17:03:01 +03:00

Allow disabling registrations (#2553)

This commit is contained in:
Quentin Gliech
2024-04-03 09:27:14 +02:00
committed by GitHub
parent e3944d1f34
commit 58fd6ab4c1
21 changed files with 308 additions and 164 deletions

View File

@@ -19,7 +19,7 @@ use clap::Parser;
use figment::Figment; use figment::Figment;
use itertools::Itertools; use itertools::Itertools;
use mas_config::{AppConfig, ClientsConfig, ConfigurationSection, UpstreamOAuth2Config}; use mas_config::{AppConfig, ClientsConfig, ConfigurationSection, UpstreamOAuth2Config};
use mas_handlers::{ActivityTracker, CookieManager, HttpClientFactory, MetadataCache, SiteConfig}; use mas_handlers::{ActivityTracker, CookieManager, HttpClientFactory, MetadataCache};
use mas_listener::{server::Server, shutdown::ShutdownStream}; use mas_listener::{server::Server, shutdown::ShutdownStream};
use mas_matrix_synapse::SynapseConnection; use mas_matrix_synapse::SynapseConnection;
use mas_router::UrlBuilder; use mas_router::UrlBuilder;
@@ -37,7 +37,8 @@ use crate::{
app_state::AppState, app_state::AppState,
util::{ util::{
database_pool_from_config, mailer_from_config, password_manager_from_config, database_pool_from_config, mailer_from_config, password_manager_from_config,
policy_factory_from_config, register_sighup, templates_from_config, policy_factory_from_config, register_sighup, site_config_from_config,
templates_from_config,
}, },
}; };
@@ -138,14 +139,17 @@ impl Options {
None, None,
); );
// Load and compile the templates // Load the site configuration
let templates = templates_from_config( let site_config = site_config_from_config(
&config.templates,
&config.branding, &config.branding,
&url_builder, &config.matrix,
&config.matrix.homeserver, &config.experimental,
) &config.passwords,
.await?; );
// Load and compile the templates
let templates =
templates_from_config(&config.templates, &site_config, &url_builder).await?;
let http_client_factory = HttpClientFactory::new(); let http_client_factory = HttpClientFactory::new();
@@ -179,12 +183,6 @@ impl Options {
// The upstream OIDC metadata cache // The upstream OIDC metadata cache
let metadata_cache = MetadataCache::new(); let metadata_cache = MetadataCache::new();
let site_config = SiteConfig {
tos_uri: config.branding.tos_uri.clone(),
access_token_ttl: config.experimental.access_token_ttl,
compat_token_ttl: config.experimental.compat_token_ttl,
};
// Initialize the activity tracker // Initialize the activity tracker
// Activity is flushed every minute // Activity is flushed every minute
let activity_tracker = ActivityTracker::new(pool.clone(), Duration::from_secs(60)); let activity_tracker = ActivityTracker::new(pool.clone(), Duration::from_secs(60));

View File

@@ -14,12 +14,15 @@
use clap::Parser; use clap::Parser;
use figment::Figment; use figment::Figment;
use mas_config::{BrandingConfig, ConfigurationSection, MatrixConfig, TemplatesConfig}; use mas_config::{
BrandingConfig, ConfigurationSection, ExperimentalConfig, MatrixConfig, PasswordsConfig,
TemplatesConfig,
};
use mas_storage::{Clock, SystemClock}; use mas_storage::{Clock, SystemClock};
use rand::SeedableRng; use rand::SeedableRng;
use tracing::info_span; use tracing::info_span;
use crate::util::templates_from_config; use crate::util::{site_config_from_config, templates_from_config};
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
pub(super) struct Options { pub(super) struct Options {
@@ -43,19 +46,22 @@ impl Options {
let template_config = TemplatesConfig::extract(figment)?; let template_config = TemplatesConfig::extract(figment)?;
let branding_config = BrandingConfig::extract(figment)?; let branding_config = BrandingConfig::extract(figment)?;
let matrix_config = MatrixConfig::extract(figment)?; let matrix_config = MatrixConfig::extract(figment)?;
let experimental_config = ExperimentalConfig::extract(figment)?;
let password_config = PasswordsConfig::extract(figment)?;
let clock = SystemClock::default(); let clock = SystemClock::default();
// XXX: we should disallow SeedableRng::from_entropy // XXX: we should disallow SeedableRng::from_entropy
let mut rng = rand_chacha::ChaChaRng::from_entropy(); let mut rng = rand_chacha::ChaChaRng::from_entropy();
let url_builder = let url_builder =
mas_router::UrlBuilder::new("https://example.com/".parse()?, None, None); mas_router::UrlBuilder::new("https://example.com/".parse()?, None, None);
let templates = templates_from_config( let site_config = site_config_from_config(
&template_config,
&branding_config, &branding_config,
&url_builder, &matrix_config,
&matrix_config.homeserver, &experimental_config,
) &password_config,
.await?; );
let templates =
templates_from_config(&template_config, &site_config, &url_builder).await?;
templates.check_render(clock.now(), &mut rng)?; templates.check_render(clock.now(), &mut rng)?;
Ok(()) Ok(())

View File

@@ -24,7 +24,9 @@ use rand::{
}; };
use tracing::{info, info_span}; use tracing::{info, info_span};
use crate::util::{database_pool_from_config, mailer_from_config, templates_from_config}; use crate::util::{
database_pool_from_config, mailer_from_config, site_config_from_config, templates_from_config,
};
#[derive(Parser, Debug, Default)] #[derive(Parser, Debug, Default)]
pub(super) struct Options {} pub(super) struct Options {}
@@ -44,14 +46,17 @@ impl Options {
None, None,
); );
// Load and compile the templates // Load the site configuration
let templates = templates_from_config( let site_config = site_config_from_config(
&config.templates,
&config.branding, &config.branding,
&url_builder, &config.matrix,
&config.matrix.homeserver, &config.experimental,
) &config.passwords,
.await?; );
// Load and compile the templates
let templates =
templates_from_config(&config.templates, &site_config, &url_builder).await?;
let mailer = mailer_from_config(&config.email, &templates)?; let mailer = mailer_from_config(&config.email, &templates)?;
mailer.test_connection().await?; mailer.test_connection().await?;

View File

@@ -17,13 +17,13 @@ use std::time::Duration;
use anyhow::Context; use anyhow::Context;
use mas_config::{ use mas_config::{
BrandingConfig, DatabaseConfig, EmailConfig, EmailSmtpMode, EmailTransportKind, BrandingConfig, DatabaseConfig, EmailConfig, EmailSmtpMode, EmailTransportKind,
PasswordsConfig, PolicyConfig, TemplatesConfig, ExperimentalConfig, MatrixConfig, PasswordsConfig, PolicyConfig, TemplatesConfig,
}; };
use mas_email::{MailTransport, Mailer}; use mas_email::{MailTransport, Mailer};
use mas_handlers::{passwords::PasswordManager, ActivityTracker}; use mas_handlers::{passwords::PasswordManager, ActivityTracker, SiteConfig};
use mas_policy::PolicyFactory; use mas_policy::PolicyFactory;
use mas_router::UrlBuilder; use mas_router::UrlBuilder;
use mas_templates::{SiteBranding, TemplateLoadingError, Templates}; use mas_templates::{TemplateLoadingError, Templates};
use sqlx::{ use sqlx::{
postgres::{PgConnectOptions, PgPoolOptions}, postgres::{PgConnectOptions, PgPoolOptions},
ConnectOptions, PgConnection, PgPool, ConnectOptions, PgConnection, PgPool,
@@ -119,36 +119,37 @@ pub async fn policy_factory_from_config(
.context("failed to load the policy") .context("failed to load the policy")
} }
pub fn site_config_from_config(
branding_config: &BrandingConfig,
matrix_config: &MatrixConfig,
experimental_config: &ExperimentalConfig,
password_config: &PasswordsConfig,
) -> SiteConfig {
SiteConfig {
access_token_ttl: experimental_config.access_token_ttl,
compat_token_ttl: experimental_config.compat_token_ttl,
server_name: matrix_config.homeserver.clone(),
policy_uri: branding_config.policy_uri.clone(),
tos_uri: branding_config.tos_uri.clone(),
imprint: branding_config.imprint.clone(),
password_login_enabled: password_config.enabled(),
password_registration_enabled: password_config.enabled()
&& experimental_config.password_registration_enabled,
}
}
pub async fn templates_from_config( pub async fn templates_from_config(
config: &TemplatesConfig, config: &TemplatesConfig,
branding: &BrandingConfig, site_config: &SiteConfig,
url_builder: &UrlBuilder, url_builder: &UrlBuilder,
server_name: &str,
) -> Result<Templates, TemplateLoadingError> { ) -> Result<Templates, TemplateLoadingError> {
let mut site_branding = SiteBranding::new(server_name);
if let Some(service_name) = branding.service_name.as_deref() {
site_branding = site_branding.with_service_name(service_name);
}
if let Some(policy_uri) = &branding.policy_uri {
site_branding = site_branding.with_policy_uri(policy_uri.as_str());
}
if let Some(tos_uri) = &branding.tos_uri {
site_branding = site_branding.with_tos_uri(tos_uri.as_str());
}
if let Some(imprint) = branding.imprint.as_deref() {
site_branding = site_branding.with_imprint(imprint);
}
Templates::load( Templates::load(
config.path.clone(), config.path.clone(),
url_builder.clone(), url_builder.clone(),
config.assets_manifest.clone(), config.assets_manifest.clone(),
config.translations_path.clone(), config.translations_path.clone(),
site_branding, site_config.templates_branding(),
site_config.templates_features(),
) )
.await .await
} }

View File

@@ -27,6 +27,15 @@ fn is_default_token_ttl(value: &Duration) -> bool {
*value == default_token_ttl() *value == default_token_ttl()
} }
const fn default_true() -> bool {
true
}
#[allow(clippy::trivially_copy_pass_by_ref)]
const fn is_default_true(value: &bool) -> bool {
*value == default_true()
}
/// Configuration sections for experimental options /// Configuration sections for experimental options
/// ///
/// Do not change these options unless you know what you are doing. /// Do not change these options unless you know what you are doing.
@@ -51,6 +60,11 @@ pub struct ExperimentalConfig {
)] )]
#[serde_as(as = "serde_with::DurationSeconds<i64>")] #[serde_as(as = "serde_with::DurationSeconds<i64>")]
pub compat_token_ttl: Duration, pub compat_token_ttl: Duration,
/// Whether to enable self-service password registration. Defaults to `true`
/// if password authentication is enabled.
#[serde(default = "default_true", skip_serializing_if = "is_default_true")]
pub password_registration_enabled: bool,
} }
impl Default for ExperimentalConfig { impl Default for ExperimentalConfig {
@@ -58,13 +72,16 @@ impl Default for ExperimentalConfig {
Self { Self {
access_token_ttl: default_token_ttl(), access_token_ttl: default_token_ttl(),
compat_token_ttl: default_token_ttl(), compat_token_ttl: default_token_ttl(),
password_registration_enabled: default_true(),
} }
} }
} }
impl ExperimentalConfig { impl ExperimentalConfig {
pub(crate) fn is_default(&self) -> bool { pub(crate) fn is_default(&self) -> bool {
is_default_token_ttl(&self.access_token_ttl) && is_default_token_ttl(&self.compat_token_ttl) is_default_token_ttl(&self.access_token_ttl)
&& is_default_token_ttl(&self.compat_token_ttl)
&& is_default_true(&self.password_registration_enabled)
} }
} }

View File

@@ -150,6 +150,7 @@ where
B: HttpBody + Send + 'static, B: HttpBody + Send + 'static,
S: Clone + Send + Sync + 'static, S: Clone + Send + Sync + 'static,
Keystore: FromRef<S>, Keystore: FromRef<S>,
SiteConfig: FromRef<S>,
UrlBuilder: FromRef<S>, UrlBuilder: FromRef<S>,
BoxClock: FromRequestParts<S>, BoxClock: FromRequestParts<S>,
BoxRng: FromRequestParts<S>, BoxRng: FromRequestParts<S>,

View File

@@ -27,6 +27,8 @@ use oauth2_types::{
}; };
use serde::Serialize; use serde::Serialize;
use crate::SiteConfig;
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
struct DiscoveryResponse { struct DiscoveryResponse {
#[serde(flatten)] #[serde(flatten)]
@@ -45,6 +47,7 @@ struct DiscoveryResponse {
pub(crate) async fn get( pub(crate) async fn get(
State(key_store): State<Keystore>, State(key_store): State<Keystore>,
State(url_builder): State<UrlBuilder>, State(url_builder): State<UrlBuilder>,
State(site_config): State<SiteConfig>,
) -> impl IntoResponse { ) -> impl IntoResponse {
// This is how clients can authenticate // This is how clients can authenticate
let client_auth_methods_supported = Some(vec![ let client_auth_methods_supported = Some(vec![
@@ -136,7 +139,16 @@ pub(crate) async fn get(
let request_parameter_supported = Some(false); let request_parameter_supported = Some(false);
let request_uri_parameter_supported = Some(false); let request_uri_parameter_supported = Some(false);
let prompt_values_supported = Some(vec![Prompt::None, Prompt::Login, Prompt::Create]); let prompt_values_supported = Some({
let mut v = vec![Prompt::None, Prompt::Login];
// Advertise for prompt=create if password registration is enabled
// TODO: we may want to be able to forward that to upstream providers if they
// support it
if site_config.password_registration_enabled {
v.push(Prompt::Create);
}
v
});
let standard = ProviderMetadata { let standard = ProviderMetadata {
issuer, issuer,

View File

@@ -1,4 +1,4 @@
// Copyright 2023 The Matrix.org Foundation C.I.C. // Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
use chrono::Duration; use chrono::Duration;
use mas_templates::{SiteBranding, SiteFeatures};
use url::Url; use url::Url;
/// Random site configuration we don't now where to put yet. /// Random site configuration we don't now where to put yet.
@@ -20,15 +21,39 @@ use url::Url;
pub struct SiteConfig { pub struct SiteConfig {
pub access_token_ttl: Duration, pub access_token_ttl: Duration,
pub compat_token_ttl: Duration, pub compat_token_ttl: Duration,
pub server_name: String,
pub policy_uri: Option<Url>,
pub tos_uri: Option<Url>, pub tos_uri: Option<Url>,
pub imprint: Option<String>,
pub password_login_enabled: bool,
pub password_registration_enabled: bool,
} }
impl Default for SiteConfig { impl SiteConfig {
fn default() -> Self { #[must_use]
Self { pub fn templates_branding(&self) -> SiteBranding {
access_token_ttl: Duration::microseconds(5 * 60 * 1000 * 1000), let mut branding = SiteBranding::new(self.server_name.clone());
compat_token_ttl: Duration::microseconds(5 * 60 * 1000 * 1000),
tos_uri: None, if let Some(policy_uri) = &self.policy_uri {
branding = branding.with_policy_uri(policy_uri.as_str());
}
if let Some(tos_uri) = &self.tos_uri {
branding = branding.with_tos_uri(tos_uri.as_str());
}
if let Some(imprint) = &self.imprint {
branding = branding.with_imprint(imprint.as_str());
}
branding
}
#[must_use]
pub fn templates_features(&self) -> SiteFeatures {
SiteFeatures {
password_registration: self.password_registration_enabled,
password_login: self.password_login_enabled,
} }
} }
} }

View File

@@ -24,6 +24,7 @@ use axum::{
extract::{FromRef, FromRequestParts}, extract::{FromRef, FromRequestParts},
response::{IntoResponse, IntoResponseParts}, response::{IntoResponse, IntoResponseParts},
}; };
use chrono::Duration;
use cookie_store::{CookieStore, RawCookie}; use cookie_store::{CookieStore, RawCookie};
use futures_util::future::BoxFuture; use futures_util::future::BoxFuture;
use headers::{Authorization, ContentType, HeaderMapExt, HeaderName, HeaderValue}; use headers::{Authorization, ContentType, HeaderMapExt, HeaderName, HeaderValue};
@@ -43,7 +44,7 @@ use mas_policy::{InstantiateError, Policy, PolicyFactory};
use mas_router::{SimpleRoute, UrlBuilder}; use mas_router::{SimpleRoute, UrlBuilder};
use mas_storage::{clock::MockClock, BoxClock, BoxRepository, BoxRng, Repository}; use mas_storage::{clock::MockClock, BoxClock, BoxRepository, BoxRng, Repository};
use mas_storage_pg::{DatabaseError, PgRepository}; use mas_storage_pg::{DatabaseError, PgRepository};
use mas_templates::{SiteBranding, Templates}; use mas_templates::Templates;
use rand::SeedableRng; use rand::SeedableRng;
use rand_chacha::ChaChaRng; use rand_chacha::ChaChaRng;
use serde::{de::DeserializeOwned, Serialize}; use serde::{de::DeserializeOwned, Serialize};
@@ -110,25 +111,49 @@ pub(crate) struct TestState {
pub rng: Arc<Mutex<ChaChaRng>>, pub rng: Arc<Mutex<ChaChaRng>>,
} }
fn workspace_root() -> camino::Utf8PathBuf {
camino::Utf8Path::new(env!("CARGO_MANIFEST_DIR"))
.join("..")
.join("..")
.canonicalize_utf8()
.unwrap()
}
pub fn test_site_config() -> SiteConfig {
SiteConfig {
access_token_ttl: Duration::try_minutes(5).unwrap(),
compat_token_ttl: Duration::try_minutes(5).unwrap(),
server_name: "example.com".to_owned(),
policy_uri: Some("https://example.com/policy".parse().unwrap()),
tos_uri: Some("https://example.com/tos".parse().unwrap()),
imprint: None,
password_login_enabled: true,
password_registration_enabled: true,
}
}
impl TestState { impl TestState {
/// Create a new test state from the given database pool /// Create a new test state from the given database pool
pub async fn from_pool(pool: PgPool) -> Result<Self, anyhow::Error> { pub async fn from_pool(pool: PgPool) -> Result<Self, anyhow::Error> {
let workspace_root = camino::Utf8Path::new(env!("CARGO_MANIFEST_DIR")) Self::from_pool_with_site_config(pool, test_site_config()).await
.join("..") }
.join("..");
/// Create a new test state from the given database pool and site config
pub async fn from_pool_with_site_config(
pool: PgPool,
site_config: SiteConfig,
) -> Result<Self, anyhow::Error> {
let workspace_root = workspace_root();
let url_builder = UrlBuilder::new("https://example.com/".parse()?, None, None); let url_builder = UrlBuilder::new("https://example.com/".parse()?, None, None);
let site_branding = SiteBranding::new("example.com")
.with_service_name("Example")
.with_tos_uri("https://example.com/tos");
let templates = Templates::load( let templates = Templates::load(
workspace_root.join("templates"), workspace_root.join("templates"),
url_builder.clone(), url_builder.clone(),
workspace_root.join("frontend/dist/manifest.json"), workspace_root.join("frontend/dist/manifest.json"),
workspace_root.join("translations"), workspace_root.join("translations"),
site_branding, site_config.templates_branding(),
site_config.templates_features(),
) )
.await?; .await?;
@@ -141,24 +166,23 @@ impl TestState {
let key_store = Keystore::new(jwks); let key_store = Keystore::new(jwks);
let encrypter = Encrypter::new(&[0x42; 32]); let encrypter = Encrypter::new(&[0x42; 32]);
let cookie_manager = let cookie_manager = CookieManager::derive_from(url_builder.http_base(), &[0x42; 32]);
CookieManager::derive_from("https://example.com".parse()?, &[0x42; 32]);
let metadata_cache = MetadataCache::new(); let metadata_cache = MetadataCache::new();
let password_manager = PasswordManager::new([(1, Hasher::argon2id(None))])?; let password_manager = if site_config.password_login_enabled {
PasswordManager::new([(1, Hasher::argon2id(None))])?
} else {
PasswordManager::disabled()
};
let policy_factory = policy_factory(serde_json::json!({})).await?; let policy_factory = policy_factory(serde_json::json!({})).await?;
let homeserver_connection = Arc::new(MockHomeserverConnection::new("example.com")); let homeserver_connection =
Arc::new(MockHomeserverConnection::new(&site_config.server_name));
let http_client_factory = HttpClientFactory::new(); let http_client_factory = HttpClientFactory::new();
let site_config = SiteConfig {
tos_uri: Some("https://example.com/tos".parse().unwrap()),
..SiteConfig::default()
};
let clock = Arc::new(MockClock::default()); let clock = Arc::new(MockClock::default());
let rng = Arc::new(Mutex::new(ChaChaRng::seed_from_u64(42))); let rng = Arc::new(Mutex::new(ChaChaRng::seed_from_u64(42)));

View File

@@ -39,7 +39,7 @@ use serde::{Deserialize, Serialize};
use zeroize::Zeroizing; use zeroize::Zeroizing;
use super::shared::OptionalPostAuthAction; use super::shared::OptionalPostAuthAction;
use crate::{passwords::PasswordManager, BoundActivityTracker, PreferredLanguage}; use crate::{passwords::PasswordManager, BoundActivityTracker, PreferredLanguage, SiteConfig};
#[derive(Debug, Deserialize, Serialize)] #[derive(Debug, Deserialize, Serialize)]
pub(crate) struct LoginForm { pub(crate) struct LoginForm {
@@ -56,9 +56,9 @@ pub(crate) async fn get(
mut rng: BoxRng, mut rng: BoxRng,
clock: BoxClock, clock: BoxClock,
PreferredLanguage(locale): PreferredLanguage, PreferredLanguage(locale): PreferredLanguage,
State(password_manager): State<PasswordManager>,
State(templates): State<Templates>, State(templates): State<Templates>,
State(url_builder): State<UrlBuilder>, State(url_builder): State<UrlBuilder>,
State(site_config): State<SiteConfig>,
mut repo: BoxRepository, mut repo: BoxRepository,
activity_tracker: BoundActivityTracker, activity_tracker: BoundActivityTracker,
Query(query): Query<OptionalPostAuthAction>, Query(query): Query<OptionalPostAuthAction>,
@@ -82,7 +82,7 @@ pub(crate) async fn get(
// If password-based login is disabled, and there is only one upstream provider, // If password-based login is disabled, and there is only one upstream provider,
// we can directly start an authorization flow // we can directly start an authorization flow
if !password_manager.is_enabled() && providers.len() == 1 { if !site_config.password_login_enabled && providers.len() == 1 {
let provider = providers.into_iter().next().unwrap(); let provider = providers.into_iter().next().unwrap();
let mut destination = UpstreamOAuth2Authorize::new(provider.id); let mut destination = UpstreamOAuth2Authorize::new(provider.id);
@@ -96,10 +96,7 @@ pub(crate) async fn get(
let content = render( let content = render(
locale, locale,
LoginContext::default() LoginContext::default().with_upstream_providers(providers),
// XXX: we might want to have a site-wide config in the templates context instead?
.with_password_login(password_manager.is_enabled())
.with_upstream_providers(providers),
query, query,
csrf_token, csrf_token,
&mut repo, &mut repo,
@@ -116,6 +113,7 @@ pub(crate) async fn post(
clock: BoxClock, clock: BoxClock,
PreferredLanguage(locale): PreferredLanguage, PreferredLanguage(locale): PreferredLanguage,
State(password_manager): State<PasswordManager>, State(password_manager): State<PasswordManager>,
State(site_config): State<SiteConfig>,
State(templates): State<Templates>, State(templates): State<Templates>,
State(url_builder): State<UrlBuilder>, State(url_builder): State<UrlBuilder>,
mut repo: BoxRepository, mut repo: BoxRepository,
@@ -126,7 +124,7 @@ pub(crate) async fn post(
Form(form): Form<ProtectedForm<LoginForm>>, Form(form): Form<ProtectedForm<LoginForm>>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let user_agent = user_agent.map(|ua| UserAgent::parse(ua.as_str().to_owned())); let user_agent = user_agent.map(|ua| UserAgent::parse(ua.as_str().to_owned()));
if !password_manager.is_enabled() { if !site_config.password_login_enabled {
// XXX: is it necessary to have better errors here? // XXX: is it necessary to have better errors here?
return Ok(StatusCode::METHOD_NOT_ALLOWED.into_response()); return Ok(StatusCode::METHOD_NOT_ALLOWED.into_response());
} }
@@ -320,18 +318,25 @@ mod test {
use zeroize::Zeroizing; use zeroize::Zeroizing;
use crate::{ use crate::{
passwords::PasswordManager, test_utils::{
test_utils::{init_tracing, CookieHelper, RequestBuilderExt, ResponseExt, TestState}, init_tracing, test_site_config, CookieHelper, RequestBuilderExt, ResponseExt, TestState,
},
SiteConfig,
}; };
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_password_disabled(pool: PgPool) { async fn test_password_disabled(pool: PgPool) {
init_tracing(); init_tracing();
let state = { let state = TestState::from_pool_with_site_config(
let mut state = TestState::from_pool(pool).await.unwrap(); pool,
state.password_manager = PasswordManager::disabled(); SiteConfig {
state password_login_enabled: false,
}; ..test_site_config()
},
)
.await
.unwrap();
let mut rng = state.rng(); let mut rng = state.rng();
// Without password login and no upstream providers, we should get an error // Without password login and no upstream providers, we should get an error
@@ -339,7 +344,11 @@ mod test {
let response = state.request(Request::get("/login").empty()).await; let response = state.request(Request::get("/login").empty()).await;
response.assert_status(StatusCode::OK); response.assert_status(StatusCode::OK);
response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8"); response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8");
assert!(response.body().contains("No login methods available")); assert!(
response.body().contains("No login methods available"),
"Response body: {}",
response.body()
);
// Adding an upstream provider should redirect to it // Adding an upstream provider should redirect to it
let mut repo = state.repository().await.unwrap(); let mut repo = state.repository().await.unwrap();

View File

@@ -33,7 +33,7 @@ use serde::Deserialize;
use zeroize::Zeroizing; use zeroize::Zeroizing;
use super::shared::OptionalPostAuthAction; use super::shared::OptionalPostAuthAction;
use crate::{passwords::PasswordManager, BoundActivityTracker, PreferredLanguage}; use crate::{passwords::PasswordManager, BoundActivityTracker, PreferredLanguage, SiteConfig};
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]
pub(crate) struct ReauthForm { pub(crate) struct ReauthForm {
@@ -45,15 +45,15 @@ pub(crate) async fn get(
mut rng: BoxRng, mut rng: BoxRng,
clock: BoxClock, clock: BoxClock,
PreferredLanguage(locale): PreferredLanguage, PreferredLanguage(locale): PreferredLanguage,
State(password_manager): State<PasswordManager>,
State(templates): State<Templates>, State(templates): State<Templates>,
State(url_builder): State<UrlBuilder>, State(url_builder): State<UrlBuilder>,
State(site_config): State<SiteConfig>,
activity_tracker: BoundActivityTracker, activity_tracker: BoundActivityTracker,
mut repo: BoxRepository, mut repo: BoxRepository,
Query(query): Query<OptionalPostAuthAction>, Query(query): Query<OptionalPostAuthAction>,
cookie_jar: CookieJar, cookie_jar: CookieJar,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
if !password_manager.is_enabled() { if !site_config.password_login_enabled {
// XXX: do something better here // XXX: do something better here
return Ok(url_builder return Ok(url_builder
.redirect(&mas_router::Account::default()) .redirect(&mas_router::Account::default())
@@ -99,12 +99,13 @@ pub(crate) async fn post(
clock: BoxClock, clock: BoxClock,
State(password_manager): State<PasswordManager>, State(password_manager): State<PasswordManager>,
State(url_builder): State<UrlBuilder>, State(url_builder): State<UrlBuilder>,
State(site_config): State<SiteConfig>,
mut repo: BoxRepository, mut repo: BoxRepository,
Query(query): Query<OptionalPostAuthAction>, Query(query): Query<OptionalPostAuthAction>,
cookie_jar: CookieJar, cookie_jar: CookieJar,
Form(form): Form<ProtectedForm<ReauthForm>>, Form(form): Form<ProtectedForm<ReauthForm>>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
if !password_manager.is_enabled() { if !site_config.password_login_enabled {
// XXX: do something better here // XXX: do something better here
return Ok(StatusCode::METHOD_NOT_ALLOWED.into_response()); return Ok(StatusCode::METHOD_NOT_ALLOWED.into_response());
} }

View File

@@ -66,8 +66,8 @@ pub(crate) async fn get(
clock: BoxClock, clock: BoxClock,
PreferredLanguage(locale): PreferredLanguage, PreferredLanguage(locale): PreferredLanguage,
State(templates): State<Templates>, State(templates): State<Templates>,
State(password_manager): State<PasswordManager>,
State(url_builder): State<UrlBuilder>, State(url_builder): State<UrlBuilder>,
State(site_config): State<SiteConfig>,
mut repo: BoxRepository, mut repo: BoxRepository,
Query(query): Query<OptionalPostAuthAction>, Query(query): Query<OptionalPostAuthAction>,
cookie_jar: CookieJar, cookie_jar: CookieJar,
@@ -82,8 +82,8 @@ pub(crate) async fn get(
return Ok((cookie_jar, reply).into_response()); return Ok((cookie_jar, reply).into_response());
} }
if !password_manager.is_enabled() { if !site_config.password_registration_enabled {
// If password-based login is disabled, redirect to the login page here // If password-based registration is disabled, redirect to the login page here
return Ok(url_builder return Ok(url_builder
.redirect(&mas_router::Login::from(query.post_auth_action)) .redirect(&mas_router::Login::from(query.post_auth_action))
.into_response()); .into_response());
@@ -122,7 +122,7 @@ pub(crate) async fn post(
Form(form): Form<ProtectedForm<RegisterForm>>, Form(form): Form<ProtectedForm<RegisterForm>>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let user_agent = user_agent.map(|ua| UserAgent::parse(ua.as_str().to_owned())); let user_agent = user_agent.map(|ua| UserAgent::parse(ua.as_str().to_owned()));
if !password_manager.is_enabled() { if !site_config.password_registration_enabled {
return Ok(StatusCode::METHOD_NOT_ALLOWED.into_response()); return Ok(StatusCode::METHOD_NOT_ALLOWED.into_response());
} }
@@ -301,18 +301,25 @@ mod tests {
use sqlx::PgPool; use sqlx::PgPool;
use crate::{ use crate::{
passwords::PasswordManager, test_utils::{
test_utils::{init_tracing, CookieHelper, RequestBuilderExt, ResponseExt, TestState}, init_tracing, test_site_config, CookieHelper, RequestBuilderExt, ResponseExt, TestState,
},
SiteConfig,
}; };
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_password_disabled(pool: PgPool) { async fn test_password_disabled(pool: PgPool) {
init_tracing(); init_tracing();
let state = { let state = TestState::from_pool_with_site_config(
let mut state = TestState::from_pool(pool).await.unwrap(); pool,
state.password_manager = PasswordManager::disabled(); SiteConfig {
state password_login_enabled: false,
}; password_registration_enabled: false,
..test_site_config()
},
)
.await
.unwrap();
let request = Request::get(&*mas_router::Register::default().path_and_query()).empty(); let request = Request::get(&*mas_router::Register::default().path_and_query()).empty();
let response = state.request(request).await; let response = state.request(request).await;

View File

@@ -112,6 +112,12 @@ impl UrlBuilder {
} }
} }
/// HTTP base
#[must_use]
pub fn http_base(&self) -> Url {
self.http_base.clone()
}
/// OIDC issuer /// OIDC issuer
#[must_use] #[must_use]
pub fn oidc_issuer(&self) -> Url { pub fn oidc_issuer(&self) -> Url {

View File

@@ -15,6 +15,7 @@
//! Contexts used in templates //! Contexts used in templates
mod branding; mod branding;
mod features;
use std::{ use std::{
fmt::Formatter, fmt::Formatter,
@@ -39,7 +40,7 @@ use serde::{ser::SerializeStruct, Deserialize, Serialize};
use ulid::Ulid; use ulid::Ulid;
use url::Url; use url::Url;
pub use self::branding::SiteBranding; pub use self::{branding::SiteBranding, features::SiteFeatures};
use crate::{FieldError, FormField, FormState}; use crate::{FieldError, FormField, FormState};
/// Helper trait to construct context wrappers /// Helper trait to construct context wrappers
@@ -399,7 +400,6 @@ pub struct PostAuthContext {
pub struct LoginContext { pub struct LoginContext {
form: FormState<LoginFormField>, form: FormState<LoginFormField>,
next: Option<PostAuthContext>, next: Option<PostAuthContext>,
password_disabled: bool,
providers: Vec<UpstreamOAuthProvider>, providers: Vec<UpstreamOAuthProvider>,
} }
@@ -413,13 +413,11 @@ impl TemplateContext for LoginContext {
LoginContext { LoginContext {
form: FormState::default(), form: FormState::default(),
next: None, next: None,
password_disabled: true,
providers: Vec::new(), providers: Vec::new(),
}, },
LoginContext { LoginContext {
form: FormState::default(), form: FormState::default(),
next: None, next: None,
password_disabled: false,
providers: Vec::new(), providers: Vec::new(),
}, },
LoginContext { LoginContext {
@@ -432,14 +430,12 @@ impl TemplateContext for LoginContext {
}, },
), ),
next: None, next: None,
password_disabled: false,
providers: Vec::new(), providers: Vec::new(),
}, },
LoginContext { LoginContext {
form: FormState::default() form: FormState::default()
.with_error_on_field(LoginFormField::Username, FieldError::Exists), .with_error_on_field(LoginFormField::Username, FieldError::Exists),
next: None, next: None,
password_disabled: false,
providers: Vec::new(), providers: Vec::new(),
}, },
] ]
@@ -447,15 +443,6 @@ impl TemplateContext for LoginContext {
} }
impl LoginContext { impl LoginContext {
/// Set whether password login is enabled or not
#[must_use]
pub fn with_password_login(self, enabled: bool) -> Self {
Self {
password_disabled: !enabled,
..self
}
}
/// Set the form state /// Set the form state
#[must_use] #[must_use]
pub fn with_form_state(self, form: FormState<LoginFormField>) -> Self { pub fn with_form_state(self, form: FormState<LoginFormField>) -> Self {

View File

@@ -1,3 +1,17 @@
// Copyright 2024 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc; use std::sync::Arc;
use minijinja::{value::StructObject, Value}; use minijinja::{value::StructObject, Value};
@@ -6,11 +20,9 @@ use minijinja::{value::StructObject, Value};
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct SiteBranding { pub struct SiteBranding {
server_name: Arc<str>, server_name: Arc<str>,
service_name: Option<Arc<str>>,
policy_uri: Option<Arc<str>>, policy_uri: Option<Arc<str>>,
tos_uri: Option<Arc<str>>, tos_uri: Option<Arc<str>>,
imprint: Option<Arc<str>>, imprint: Option<Arc<str>>,
logo_uri: Option<Arc<str>>,
} }
impl SiteBranding { impl SiteBranding {
@@ -19,21 +31,12 @@ impl SiteBranding {
pub fn new(server_name: impl Into<Arc<str>>) -> Self { pub fn new(server_name: impl Into<Arc<str>>) -> Self {
Self { Self {
server_name: server_name.into(), server_name: server_name.into(),
service_name: None,
policy_uri: None, policy_uri: None,
tos_uri: None, tos_uri: None,
imprint: None, imprint: None,
logo_uri: None,
} }
} }
/// Set the service name.
#[must_use]
pub fn with_service_name(mut self, service_name: impl Into<Arc<str>>) -> Self {
self.service_name = Some(service_name.into());
self
}
/// Set the policy URI. /// Set the policy URI.
#[must_use] #[must_use]
pub fn with_policy_uri(mut self, policy_uri: impl Into<Arc<str>>) -> Self { pub fn with_policy_uri(mut self, policy_uri: impl Into<Arc<str>>) -> Self {
@@ -54,36 +57,20 @@ impl SiteBranding {
self.imprint = Some(imprint.into()); self.imprint = Some(imprint.into());
self self
} }
/// Set the logo URI.
#[must_use]
pub fn with_logo_uri(mut self, logo_uri: impl Into<Arc<str>>) -> Self {
self.logo_uri = Some(logo_uri.into());
self
}
} }
impl StructObject for SiteBranding { impl StructObject for SiteBranding {
fn get_field(&self, name: &str) -> Option<Value> { fn get_field(&self, name: &str) -> Option<Value> {
match name { match name {
"server_name" => Some(self.server_name.clone().into()), "server_name" => Some(self.server_name.clone().into()),
"service_name" => self.service_name.clone().map(Value::from),
"policy_uri" => self.policy_uri.clone().map(Value::from), "policy_uri" => self.policy_uri.clone().map(Value::from),
"tos_uri" => self.tos_uri.clone().map(Value::from), "tos_uri" => self.tos_uri.clone().map(Value::from),
"imprint" => self.imprint.clone().map(Value::from), "imprint" => self.imprint.clone().map(Value::from),
"logo_uri" => self.logo_uri.clone().map(Value::from),
_ => None, _ => None,
} }
} }
fn static_fields(&self) -> Option<&'static [&'static str]> { fn static_fields(&self) -> Option<&'static [&'static str]> {
Some(&[ Some(&["server_name", "policy_uri", "tos_uri", "imprint"])
"server_name",
"service_name",
"policy_uri",
"tos_uri",
"imprint",
"logo_uri",
])
} }
} }

View File

@@ -0,0 +1,39 @@
// Copyright 2024 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use minijinja::{value::StructObject, Value};
/// Site features information.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SiteFeatures {
/// Whether local password-based registration is enabled.
pub password_registration: bool,
/// Whether local password-based login is enabled.
pub password_login: bool,
}
impl StructObject for SiteFeatures {
fn get_field(&self, field: &str) -> Option<Value> {
match field {
"password_registration" => Some(Value::from(self.password_registration)),
"password_login" => Some(Value::from(self.password_login)),
_ => None,
}
}
fn static_fields(&self) -> Option<&'static [&'static str]> {
Some(&["password_registration", "password_login"])
}
}

View File

@@ -47,7 +47,7 @@ pub use self::{
EmailVerificationPageContext, EmptyContext, ErrorContext, FormPostContext, IndexContext, EmailVerificationPageContext, EmptyContext, ErrorContext, FormPostContext, IndexContext,
LoginContext, LoginFormField, NotFoundContext, PolicyViolationContext, PostAuthContext, LoginContext, LoginFormField, NotFoundContext, PolicyViolationContext, PostAuthContext,
PostAuthContextInner, ReauthContext, ReauthFormField, RegisterContext, RegisterFormField, PostAuthContextInner, ReauthContext, ReauthFormField, RegisterContext, RegisterFormField,
SiteBranding, TemplateContext, UpstreamExistingLinkContext, UpstreamRegister, SiteBranding, SiteFeatures, TemplateContext, UpstreamExistingLinkContext, UpstreamRegister,
UpstreamRegisterFormField, UpstreamSuggestLink, WithCsrf, WithLanguage, UpstreamRegisterFormField, UpstreamSuggestLink, WithCsrf, WithLanguage,
WithOptionalSession, WithSession, WithOptionalSession, WithSession,
}, },
@@ -70,6 +70,7 @@ pub struct Templates {
translator: Arc<ArcSwap<Translator>>, translator: Arc<ArcSwap<Translator>>,
url_builder: UrlBuilder, url_builder: UrlBuilder,
branding: SiteBranding, branding: SiteBranding,
features: SiteFeatures,
vite_manifest_path: Utf8PathBuf, vite_manifest_path: Utf8PathBuf,
translations_path: Utf8PathBuf, translations_path: Utf8PathBuf,
path: Utf8PathBuf, path: Utf8PathBuf,
@@ -149,6 +150,7 @@ impl Templates {
vite_manifest_path: Utf8PathBuf, vite_manifest_path: Utf8PathBuf,
translations_path: Utf8PathBuf, translations_path: Utf8PathBuf,
branding: SiteBranding, branding: SiteBranding,
features: SiteFeatures,
) -> Result<Self, TemplateLoadingError> { ) -> Result<Self, TemplateLoadingError> {
let (translator, environment) = Self::load_( let (translator, environment) = Self::load_(
&path, &path,
@@ -156,6 +158,7 @@ impl Templates {
&vite_manifest_path, &vite_manifest_path,
&translations_path, &translations_path,
branding.clone(), branding.clone(),
features,
) )
.await?; .await?;
Ok(Self { Ok(Self {
@@ -166,6 +169,7 @@ impl Templates {
vite_manifest_path, vite_manifest_path,
translations_path, translations_path,
branding, branding,
features,
}) })
} }
@@ -175,6 +179,7 @@ impl Templates {
vite_manifest_path: &Utf8Path, vite_manifest_path: &Utf8Path,
translations_path: &Utf8Path, translations_path: &Utf8Path,
branding: SiteBranding, branding: SiteBranding,
features: SiteFeatures,
) -> Result<(Arc<Translator>, Arc<minijinja::Environment<'static>>), TemplateLoadingError> { ) -> Result<(Arc<Translator>, Arc<minijinja::Environment<'static>>), TemplateLoadingError> {
let path = path.to_owned(); let path = path.to_owned();
let span = tracing::Span::current(); let span = tracing::Span::current();
@@ -230,6 +235,7 @@ impl Templates {
.await??; .await??;
env.add_global("branding", Value::from_struct_object(branding)); env.add_global("branding", Value::from_struct_object(branding));
env.add_global("features", Value::from_struct_object(features));
self::functions::register( self::functions::register(
&mut env, &mut env,
@@ -265,6 +271,7 @@ impl Templates {
&self.vite_manifest_path, &self.vite_manifest_path,
&self.translations_path, &self.translations_path,
self.branding.clone(), self.branding.clone(),
self.features,
) )
.await?; .await?;
@@ -425,7 +432,11 @@ mod tests {
let path = Utf8Path::new(env!("CARGO_MANIFEST_DIR")).join("../../templates/"); let path = Utf8Path::new(env!("CARGO_MANIFEST_DIR")).join("../../templates/");
let url_builder = UrlBuilder::new("https://example.com/".parse().unwrap(), None, None); let url_builder = UrlBuilder::new("https://example.com/".parse().unwrap(), None, None);
let branding = SiteBranding::new("example.com").with_service_name("Example"); let branding = SiteBranding::new("example.com");
let features = SiteFeatures {
password_login: true,
password_registration: true,
};
let vite_manifest_path = let vite_manifest_path =
Utf8Path::new(env!("CARGO_MANIFEST_DIR")).join("../../frontend/dist/manifest.json"); Utf8Path::new(env!("CARGO_MANIFEST_DIR")).join("../../frontend/dist/manifest.json");
let translations_path = let translations_path =
@@ -436,6 +447,7 @@ mod tests {
vite_manifest_path, vite_manifest_path,
translations_path, translations_path,
branding, branding,
features,
) )
.await .await
.unwrap(); .unwrap();

View File

@@ -1962,6 +1962,10 @@
"format": "uint64", "format": "uint64",
"maximum": 86400.0, "maximum": 86400.0,
"minimum": 60.0 "minimum": 60.0
},
"password_registration_enabled": {
"description": "Whether to enable self-service password registration. Defaults to `true` if password authentication is enabled.",
"type": "boolean"
} }
} }
} }

View File

@@ -36,7 +36,10 @@ limitations under the License.
{{ logout.button(text=_("action.sign_out"), csrf_token=csrf_token) }} {{ logout.button(text=_("action.sign_out"), csrf_token=csrf_token) }}
{% else %} {% else %}
{{ button.link(text=_("action.sign_in"), href="/login") }} {{ button.link(text=_("action.sign_in"), href="/login") }}
{% if features.password_registration %}
{{ button.link_outline(text=_("mas.navbar.register"), href="/register") }} {{ button.link_outline(text=_("mas.navbar.register"), href="/register") }}
{% endif %} {% endif %}
{% endif %}
</main> </main>
{% endblock content %} {% endblock content %}

View File

@@ -20,7 +20,7 @@ limitations under the License.
{% block content %} {% block content %}
<main class="flex flex-col gap-6"> <main class="flex flex-col gap-6">
{% if not password_disabled %} {% if features.password_login %}
<header class="page-heading"> <header class="page-heading">
<div class="icon"> <div class="icon">
{{ icon.user_profile_solid() }} {{ icon.user_profile_solid() }}
@@ -62,7 +62,7 @@ limitations under the License.
{{ button.button(text=_("action.continue")) }} {{ button.button(text=_("action.continue")) }}
</form> </form>
{% if not next or next.kind != "link_upstream" %} {% if (not next or next.kind != "link_upstream") and features.password_registration %}
<div class="flex gap-1 justify-center items-center cpd-text-body-md-regular"> <div class="flex gap-1 justify-center items-center cpd-text-body-md-regular">
<p class="cpd-text-secondary"> <p class="cpd-text-secondary">
{{ _("mas.login.call_to_register") }} {{ _("mas.login.call_to_register") }}
@@ -75,7 +75,7 @@ limitations under the License.
{% endif %} {% endif %}
{% if providers %} {% if providers %}
{% if not password_disabled %} {% if features.password_login %}
{{ field.separator() }} {{ field.separator() }}
{% endif %} {% endif %}
@@ -89,7 +89,7 @@ limitations under the License.
{% endfor %} {% endfor %}
{% endif %} {% endif %}
{% if not providers and password_disabled %} {% if not providers and not features.password_login %}
<div class="text-center"> <div class="text-center">
{{ _("mas.login.no_login_methods") }} {{ _("mas.login.no_login_methods") }}
</div> </div>

View File

@@ -226,7 +226,7 @@
}, },
"register": "Create an account", "register": "Create an account",
"@register": { "@register": {
"context": "pages/index.html:39:34-58" "context": "pages/index.html:41:36-60"
}, },
"signed_in_as": "Signed in as <span class=\"font-semibold\">%(username)s</span>.", "signed_in_as": "Signed in as <span class=\"font-semibold\">%(username)s</span>.",
"@signed_in_as": { "@signed_in_as": {