diff --git a/Cargo.lock b/Cargo.lock index a1df67e1..85ab3859 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2269,6 +2269,17 @@ version = "0.0.46" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d4d2456c373231a208ad294c33dc5bff30051eafd954cd4caae83a712b12854d" +[[package]] +name = "listenfd" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14e4fcc00ff6731d94b70e16e71f43bda62883461f31230742e3bc6dddf12988" +dependencies = [ + "libc", + "uuid", + "winapi", +] + [[package]] name = "lock_api" version = "0.4.9" @@ -2350,6 +2361,7 @@ dependencies = [ "futures-util", "hyper", "indoc", + "listenfd", "mas-config", "mas-email", "mas-handlers", @@ -2390,6 +2402,7 @@ dependencies = [ "figment", "indoc", "lettre", + "listenfd", "mas-email", "mas-iana", "mas-jose", @@ -5182,6 +5195,12 @@ version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e8db7427f936968176eaa7cdf81b7f98b980b18495ec28f1b5791ac3bfe3eea9" +[[package]] +name = "uuid" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd6469f4314d5f1ffec476e05f17cc9a78bc7a27a6a857842170bdf8d6f98d2f" + [[package]] name = "valuable" version = "0.1.0" diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index 500c9a47..6d24c971 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -19,6 +19,7 @@ url = "2.3.1" argon2 = { version = "0.4.1", features = ["password-hash"] } watchman_client = "0.8.0" atty = "0.2.14" +listenfd = "1.0.0" tracing = "0.1.36" tracing-appender = "0.2.2" diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index 65ce9ecc..c259d45f 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -12,15 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{ - net::{SocketAddr, TcpListener}, - sync::Arc, - time::Duration, -}; +use std::{net::SocketAddr, sync::Arc, time::Duration}; use anyhow::Context; use clap::Parser; -use futures_util::stream::{StreamExt, TryStreamExt}; +use futures_util::{ + future::FutureExt, + stream::{StreamExt, TryStreamExt}, +}; use hyper::Server; use mas_config::RootConfig; use mas_email::Mailer; @@ -138,16 +137,10 @@ async fn watch_templates( } impl Options { + #[allow(clippy::too_many_lines)] pub async fn run(&self, root: &super::Options) -> anyhow::Result<()> { let config: RootConfig = root.load_config()?; - let addr: SocketAddr = config - .http - .address - .parse() - .context("could not parse listener address")?; - let listener = TcpListener::bind(addr).context("could not bind address")?; - // Connect to the mail server let mail_transport = config.email.transport.to_transport().await?; mail_transport.test_connection().await?; @@ -223,6 +216,8 @@ impl Options { let homeserver = MatrixHomeserver::new(config.matrix.homeserver.clone()); + let listeners_config = config.http.listeners.clone(); + // Explicitely the config to properly zeroize secret keys drop(config); @@ -238,7 +233,7 @@ impl Options { .context("could not watch for templates changes")?; } - let state = AppState { + let state = Arc::new(AppState { pool, templates, key_store, @@ -247,18 +242,42 @@ impl Options { mailer, homeserver, policy_factory, - }; + }); - let router = mas_handlers::router(state) - .nest(mas_router::StaticAsset::route(), static_files) - .layer(ServerLayer::default()); + let signal = shutdown_signal().shared(); + let mut fd_manager = listenfd::ListenFd::from_env(); - info!("Listening on http://{}", listener.local_addr().unwrap()); + let futs = listeners_config + .into_iter() + .map(|listener_config| { + let signal = signal.clone(); + let router = mas_handlers::router(state.clone()) + .nest(mas_router::StaticAsset::route(), static_files.clone()) + .layer(ServerLayer::default()); - Server::from_tcp(listener)? - .serve(router.into_make_service_with_connect_info::()) - .with_graceful_shutdown(shutdown_signal()) - .await?; + let mut futs: Vec<_> = Vec::with_capacity(listener_config.binds.len()); + for bind in listener_config.binds { + let listener = bind.listener(&mut fd_manager)?; + let router = router.clone(); + + let addr = listener.local_addr()?; + info!("Listening on http://{addr}"); + + let fut = Server::from_tcp(listener)? + .serve(router.into_make_service_with_connect_info::()) + .with_graceful_shutdown(signal.clone()); + futs.push(fut); + } + + anyhow::Ok(futures_util::future::try_join_all(futs)) + }) + .collect::, _>>()?; + + futures_util::future::try_join_all(futs).await?; + + // This ensures we're running, even if no listener are setup + // This is useful for only running the task runner + signal.await; Ok(()) } diff --git a/crates/config/Cargo.toml b/crates/config/Cargo.toml index 65859cad..a9bad83b 100644 --- a/crates/config/Cargo.toml +++ b/crates/config/Cargo.toml @@ -24,6 +24,8 @@ serde_json = "1.0.85" sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres"] } lettre = { version = "0.10.1", default-features = false, features = ["serde", "builder"] } +listenfd = "1.0.0" + pem-rfc7468 = "0.6.0" rand = "0.8.5" diff --git a/crates/config/src/sections/http.rs b/crates/config/src/sections/http.rs index 36f9155f..da3b486c 100644 --- a/crates/config/src/sections/http.rs +++ b/crates/config/src/sections/http.rs @@ -12,9 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::path::PathBuf; +use std::{ + net::{SocketAddr, TcpListener}, + path::PathBuf, +}; +use anyhow::Context; use async_trait::async_trait; +use listenfd::ListenFd; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use url::Url; @@ -42,18 +47,59 @@ fn http_address_example_4() -> &'static str { "0.0.0.0:8080" } +#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)] +#[serde(untagged)] +pub enum BindConfig { + Address { + #[schemars( + example = "http_address_example_1", + example = "http_address_example_2", + example = "http_address_example_3", + example = "http_address_example_4" + )] + address: String, + }, + FileDescriptor { + fd: usize, + }, +} + +impl BindConfig { + pub fn listener(self, fd_manager: &mut ListenFd) -> Result { + match self { + BindConfig::Address { address } => { + let addr: SocketAddr = address + .parse() + .context("could not parse listener address")?; + let listener = TcpListener::bind(addr).context("could not bind address")?; + Ok(listener) + } + + BindConfig::FileDescriptor { fd } => { + let listener = fd_manager + .take_tcp_listener(fd)? + .context("no listener found on file descriptor")?; + // XXX: Do I need that? + listener.set_nonblocking(true)?; + Ok(listener) + } + } + } +} + +#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)] +pub struct ListenerConfig { + pub name: Option, + + pub binds: Vec, +} + /// Configuration related to the web server #[derive(Debug, Serialize, Deserialize, JsonSchema)] pub struct HttpConfig { - /// IP and port the server should listen to - #[schemars( - example = "http_address_example_1", - example = "http_address_example_2", - example = "http_address_example_3", - example = "http_address_example_4" - )] - #[serde(default = "default_http_address")] - pub address: String, + /// List of listeners to run + #[serde(default)] + pub listeners: Vec, /// Path from which to serve static files. If not specified, it will serve /// the static files embedded in the server binary @@ -67,8 +113,13 @@ pub struct HttpConfig { impl Default for HttpConfig { fn default() -> Self { Self { - address: default_http_address(), web_root: None, + listeners: vec![ListenerConfig { + name: None, + binds: vec![BindConfig::Address { + address: default_http_address(), + }], + }], public_base: default_public_base(), } } diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index 0a6967b6..577124ef 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -252,7 +252,7 @@ where #[must_use] #[allow(clippy::trait_duplication_in_bounds)] -pub fn router(state: S) -> Router +pub fn router(state: Arc) -> Router where B: HttpBody + Send + 'static, ::Data: Send, @@ -267,8 +267,6 @@ where Mailer: FromRef, MatrixHomeserver: FromRef, { - let state = Arc::new(state); - let api_router = api_router(state.clone()); let compat_router = compat_router(state.clone()); let human_router = human_router(state);