diff --git a/Cargo.lock b/Cargo.lock index 828fa9fe..c630c2d5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2419,6 +2419,7 @@ dependencies = [ "anyhow", "argon2", "atty", + "axum 0.6.0-rc.2", "clap", "dotenv", "futures-util", @@ -2467,7 +2468,6 @@ dependencies = [ "figment", "indoc", "lettre", - "listenfd", "mas-email", "mas-iana", "mas-jose", diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index 5e798f5f..a488c9a6 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" license = "Apache-2.0" [dependencies] +axum = "0.6.0-rc.2" tokio = { version = "1.21.2", features = ["full"] } futures-util = "0.3.24" anyhow = "1.0.65" diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index 8d62dc7c..df820063 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -17,7 +17,7 @@ use std::{sync::Arc, time::Duration}; use anyhow::Context; use clap::Parser; use futures_util::{ - future::{FutureExt, OptionFuture}, + future::FutureExt, stream::{StreamExt, TryStreamExt}, }; use hyper::Server; @@ -25,9 +25,9 @@ use mas_config::RootConfig; use mas_email::Mailer; use mas_handlers::{AppState, MatrixHomeserver}; use mas_http::ServerLayer; -use mas_listener::{maybe_tls::MaybeTlsAcceptor, unix_or_tcp::UnixOrTcpListener}; +use mas_listener::maybe_tls::MaybeTlsAcceptor; use mas_policy::PolicyFactory; -use mas_router::{Route, UrlBuilder}; +use mas_router::UrlBuilder; use mas_storage::MIGRATOR; use mas_tasks::TaskQueue; use mas_templates::Templates; @@ -213,8 +213,6 @@ impl Options { &config.email.reply_to, ); - let static_files = mas_static_files::service(&config.http.web_root); - let homeserver = MatrixHomeserver::new(config.matrix.homeserver.clone()); let listeners_config = config.http.listeners.clone(); @@ -247,19 +245,11 @@ impl Options { let signal = shutdown_signal().shared(); let shutdown_signal = signal.clone(); + let mut fd_manager = listenfd::ListenFd::from_env(); - let listeners = listeners_config.into_iter().map(|listener_config| { - // We have to borrow it here, not in the nested closure - let fd_manager = &mut fd_manager; - // Let's first grab all the listeners in a synchronous manner - // This helps with the fd_manager mutable borrow - let listeners: Result, _> = listener_config - .binds - .iter() - .map(move |bind_config| bind_config.listener(fd_manager)) - .collect(); + let listeners = crate::server::build_listeners(&mut fd_manager, &listener_config.binds); Ok((listener_config, listeners?)) }); @@ -269,10 +259,8 @@ impl Options { .try_for_each_concurrent(None, move |(config, listeners)| { let signal = signal.clone(); - let mut router = mas_handlers::empty_router(state.clone()); - let is_tls = config.tls.is_some(); - let adresses: Vec = listeners.iter().map(|listener| { + let addresses: Vec = listeners.iter().map(|listener| { let addr = listener.local_addr(); let proto = if is_tls { "https" } else { "http" }; if let Ok(addr) = addr { @@ -283,53 +271,15 @@ impl Options { } }).collect(); - info!("Listening on {adresses:?} with resources {resources:?}", resources = &config.resources); + info!("Listening on {addresses:?} with resources {resources:?}", resources = &config.resources); - for resource in &config.resources { - router = match resource { - mas_config::HttpResource::Health => { - router.merge(mas_handlers::healthcheck_router(state.clone())) - } - mas_config::HttpResource::Prometheus => { - router.route_service("/metrics", crate::telemetry::prometheus_service()) - } - mas_config::HttpResource::Discovery => { - router.merge(mas_handlers::discovery_router(state.clone())) - } - mas_config::HttpResource::Human => { - router.merge(mas_handlers::human_router(state.clone())) - } - mas_config::HttpResource::Static => { - router.nest(mas_router::StaticAsset::route(), static_files.clone()) - } - mas_config::HttpResource::OAuth => { - router.merge(mas_handlers::api_router(state.clone())) - } - mas_config::HttpResource::Compat => { - router.merge(mas_handlers::compat_router(state.clone())) - } - } - } - - let router = router.layer(ServerLayer::default()); + let router = crate::server::build_router(&state, &config.resources).layer(ServerLayer::new(config.name.clone())); async move { - let tls_config: OptionFuture<_> = config - .tls - .map(|tls_config| async move { - let (key, chain) = tls_config.load().await?; - let key = rustls::PrivateKey(key); - let chain = chain.into_iter().map(rustls::Certificate).collect(); - let mut config = rustls::ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth() - .with_single_cert(chain, key) - .context("failed to build TLS server config")?; - config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; - anyhow::Ok(Arc::new(config)) - }) - .into(); - let tls_config = tls_config.await.transpose()?; + let tls_config = if let Some(tls_config) = config.tls.as_ref() { + let tls_config = crate::server::build_tls_server_config(tls_config).await?; + Some(Arc::new(tls_config)) + } else { None }; futures_util::stream::iter(listeners) .map(Ok) diff --git a/crates/cli/src/main.rs b/crates/cli/src/main.rs index 44f4340c..926f2de8 100644 --- a/crates/cli/src/main.rs +++ b/crates/cli/src/main.rs @@ -28,6 +28,7 @@ use tracing_subscriber::{ }; mod commands; +mod server; mod telemetry; #[tokio::main] diff --git a/crates/cli/src/server.rs b/crates/cli/src/server.rs new file mode 100644 index 00000000..f9f6f402 --- /dev/null +++ b/crates/cli/src/server.rs @@ -0,0 +1,153 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{ + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, TcpListener, ToSocketAddrs}, + os::unix::net::UnixListener, + sync::Arc, +}; + +use anyhow::Context; +use axum::{body::HttpBody, Router}; +use listenfd::ListenFd; +use mas_config::{HttpBindConfig, HttpResource, HttpTlsConfig, UnixOrTcp}; +use mas_handlers::AppState; +use mas_listener::unix_or_tcp::UnixOrTcpListener; +use mas_router::Route; +use rustls::ServerConfig; + +#[allow(clippy::trait_duplication_in_bounds)] +pub fn build_router(state: &Arc, resources: &[HttpResource]) -> Router +where + B: HttpBody + Send + 'static, + ::Data: Send, + ::Error: std::error::Error + Send + Sync, +{ + let mut router = Router::with_state_arc(state.clone()); + + for resource in resources { + router = match resource { + mas_config::HttpResource::Health => { + router.merge(mas_handlers::healthcheck_router(state.clone())) + } + mas_config::HttpResource::Prometheus => { + router.route_service("/metrics", crate::telemetry::prometheus_service()) + } + mas_config::HttpResource::Discovery => { + router.merge(mas_handlers::discovery_router(state.clone())) + } + mas_config::HttpResource::Human => { + router.merge(mas_handlers::human_router(state.clone())) + } + mas_config::HttpResource::Static { web_root } => { + let handler = mas_static_files::service(web_root); + router.nest(mas_router::StaticAsset::route(), handler) + } + mas_config::HttpResource::OAuth => { + router.merge(mas_handlers::api_router(state.clone())) + } + mas_config::HttpResource::Compat => { + router.merge(mas_handlers::compat_router(state.clone())) + } + } + } + + router +} + +pub async fn build_tls_server_config( + config: &HttpTlsConfig, +) -> Result { + let (key, chain) = config.load().await?; + let key = rustls::PrivateKey(key); + let chain = chain.into_iter().map(rustls::Certificate).collect(); + + let mut config = rustls::ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(chain, key) + .context("failed to build TLS server config")?; + config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + + Ok(config) +} + +pub fn build_listeners( + fd_manager: &mut ListenFd, + configs: &[HttpBindConfig], +) -> Result, anyhow::Error> { + let mut listeners = Vec::with_capacity(configs.len()); + + for bind in configs { + let listener = match bind { + HttpBindConfig::Listen { host, port } => { + let addrs = match host.as_deref() { + Some(host) => (host, *port) + .to_socket_addrs() + .context("could not parse listener host")? + .collect(), + + None => vec![ + SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), *port), + SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), *port), + ], + }; + + let listener = TcpListener::bind(&addrs[..]).context("could not bind address")?; + listener.set_nonblocking(true)?; + listener.try_into()? + } + + HttpBindConfig::Address { address } => { + let addr: SocketAddr = address + .parse() + .context("could not parse listener address")?; + let listener = TcpListener::bind(addr).context("could not bind address")?; + listener.set_nonblocking(true)?; + listener.try_into()? + } + + HttpBindConfig::Unix { socket } => { + let listener = UnixListener::bind(socket).context("could not bind socket")?; + listener.try_into()? + } + + HttpBindConfig::FileDescriptor { + fd, + kind: UnixOrTcp::Tcp, + } => { + let listener = fd_manager + .take_tcp_listener(*fd)? + .context("no listener found on file descriptor")?; + listener.set_nonblocking(true)?; + listener.try_into()? + } + + HttpBindConfig::FileDescriptor { + fd, + kind: UnixOrTcp::Unix, + } => { + let listener = fd_manager + .take_unix_listener(*fd)? + .context("no unix socket found on file descriptor")?; + listener.set_nonblocking(true)?; + listener.try_into()? + } + }; + + listeners.push(listener); + } + + Ok(listeners) +} diff --git a/crates/config/Cargo.toml b/crates/config/Cargo.toml index 496631d6..d37eee6e 100644 --- a/crates/config/Cargo.toml +++ b/crates/config/Cargo.toml @@ -24,8 +24,6 @@ serde_json = "1.0.85" sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres"] } lettre = { version = "0.10.1", default-features = false, features = ["serde", "builder"] } -listenfd = "1.0.0" - pem-rfc7468 = "0.6.0" rustls-pemfile = "1.0.1" rand = "0.8.5" diff --git a/crates/config/src/sections/http.rs b/crates/config/src/sections/http.rs index d3a2e2e9..80a8d772 100644 --- a/crates/config/src/sections/http.rs +++ b/crates/config/src/sections/http.rs @@ -12,18 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{ - borrow::Cow, - io::Cursor, - net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, TcpListener, ToSocketAddrs}, - ops::Deref, - os::unix::net::UnixListener, - path::PathBuf, -}; +use std::{borrow::Cow, io::Cursor, ops::Deref, path::PathBuf}; -use anyhow::{bail, Context}; +use anyhow::bail; use async_trait::async_trait; -use listenfd::ListenFd; use mas_keystore::PrivateKey; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -48,32 +40,49 @@ fn http_address_example_4() -> &'static str { "0.0.0.0:8080" } -#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)] +/// Kind of socket +#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone, Copy)] #[serde(rename_all = "lowercase")] pub enum UnixOrTcp { + /// UNIX domain socket Unix, + + /// TCP socket Tcp, } impl UnixOrTcp { + /// UNIX domain socket + #[must_use] pub const fn unix() -> Self { Self::Unix } + /// TCP socket + #[must_use] pub const fn tcp() -> Self { Self::Tcp } } +/// Configuration of a single listener #[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)] #[serde(untagged)] pub enum BindConfig { + /// Listen on the specified host and port Listen { + /// Host on which to listen. + /// + /// Defaults to listening on all addresses host: Option, + + /// Port on which to listen. port: u16, }, + /// Listen on the specified address Address { + /// Host and port on which to listen #[schemars( example = "http_address_example_1", example = "http_address_example_2", @@ -83,85 +92,30 @@ pub enum BindConfig { address: String, }, + /// Listen on a UNIX domain socket Unix { + /// Path to the socket socket: PathBuf, }, + /// Accept connections on file descriptors passed by the parent process. + /// + /// This is useful for grabbing sockets passed by systemd. + /// + /// See FileDescriptor { + /// Index of the file descriptor. Note that this is offseted by 3 + /// because of the standard input/output sockets, so setting + /// here a value of `0` will grab the file descriptor `3` fd: usize, + /// Whether the socket is a TCP socket or a UNIX domain socket. Defaults + /// to TCP. #[serde(default = "UnixOrTcp::tcp")] kind: UnixOrTcp, }, } -impl BindConfig { - // TODO: move this somewhere else - pub fn listener(&self, fd_manager: &mut ListenFd) -> Result - where - T: TryFrom + TryFrom, - >::Error: std::error::Error + Sync + Send + 'static, - >::Error: std::error::Error + Sync + Send + 'static, - { - match self { - BindConfig::Listen { host, port } => { - let addrs = match host.as_deref() { - Some(host) => (host, *port) - .to_socket_addrs() - .context("could not parse listener host")? - .collect(), - - None => vec![ - SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), *port), - SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), *port), - ], - }; - - let listener = TcpListener::bind(&addrs[..]).context("could not bind address")?; - listener.set_nonblocking(true)?; - Ok(listener.try_into()?) - } - - BindConfig::Address { address } => { - let addr: SocketAddr = address - .parse() - .context("could not parse listener address")?; - let listener = TcpListener::bind(addr).context("could not bind address")?; - listener.set_nonblocking(true)?; - Ok(listener.try_into()?) - } - - BindConfig::Unix { socket } => { - let listener = UnixListener::bind(socket).context("could not bind socket")?; - listener.set_nonblocking(true)?; - Ok(listener.try_into()?) - } - - BindConfig::FileDescriptor { - fd, - kind: UnixOrTcp::Tcp, - } => { - let listener = fd_manager - .take_tcp_listener(*fd)? - .context("no listener found on file descriptor")?; - listener.set_nonblocking(true)?; - Ok(listener.try_into()?) - } - - BindConfig::FileDescriptor { - fd, - kind: UnixOrTcp::Unix, - } => { - let listener = fd_manager - .take_unix_listener(*fd)? - .context("no unix socket found on file descriptor")?; - listener.set_nonblocking(true)?; - Ok(listener.try_into()?) - } - } - } -} - #[derive(JsonSchema, Serialize, Deserialize, Clone, Debug)] #[serde(rename_all = "snake_case")] pub enum KeyOrFile { @@ -176,19 +130,34 @@ pub enum CertificateOrFile { CertificateFile(PathBuf), } +/// Configuration related to TLS on a listener #[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)] pub struct TlsConfig { + /// PEM-encoded X509 certificate chain #[serde(flatten)] pub certificate: CertificateOrFile, + /// Private key #[serde(flatten)] pub key: KeyOrFile, + /// Password used to decode the private key #[serde(flatten)] pub password: Option, } impl TlsConfig { + /// Load the TLS certificate chain and key file from disk + /// + /// # Errors + /// + /// Returns an error if an error was encountered either while: + /// - reading the certificate, key or password files + /// - decoding the key as PEM or DER + /// - decrypting the key if encrypted + /// - a password was provided but the key was not encrypted + /// - decoding the certificate chain as PEM + /// - the certificate chain is empty pub async fn load(&self) -> Result<(Vec, Vec>), anyhow::Error> { let password = match &self.password { Some(PasswordOrFile::Password(password)) => Some(Cow::Borrowed(password.as_str())), @@ -267,12 +236,21 @@ pub enum Resource { Compat, /// Static files - Static, + Static { + /// Path from which to serve static files. If not specified, it will + /// serve the static files embedded in the server binary + #[serde(default)] + web_root: Option, + }, } /// Configuration of a listener #[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)] pub struct ListenerConfig { + /// A unique name for this listener which will be shown in traces and in + /// metrics labels + pub name: Option, + /// List of resources to mount pub resources: Vec, @@ -290,11 +268,6 @@ pub struct HttpConfig { #[serde(default)] pub listeners: Vec, - /// Path from which to serve static files. If not specified, it will serve - /// the static files embedded in the server binary - #[serde(default)] - pub web_root: Option, - /// Public URL base from where the authentication service is reachable pub public_base: Url, } @@ -302,15 +275,15 @@ pub struct HttpConfig { impl Default for HttpConfig { fn default() -> Self { Self { - web_root: None, listeners: vec![ ListenerConfig { + name: Some("web".to_owned()), resources: vec![ Resource::Discovery, Resource::Human, Resource::OAuth, Resource::Compat, - Resource::Static, + Resource::Static { web_root: None }, ], tls: None, binds: vec![BindConfig::Address { @@ -318,6 +291,7 @@ impl Default for HttpConfig { }], }, ListenerConfig { + name: Some("internal".to_owned()), resources: vec![Resource::Health], tls: None, binds: vec![BindConfig::Address { diff --git a/crates/config/src/sections/mod.rs b/crates/config/src/sections/mod.rs index 41bd8193..44a13483 100644 --- a/crates/config/src/sections/mod.rs +++ b/crates/config/src/sections/mod.rs @@ -32,7 +32,10 @@ pub use self::{ csrf::CsrfConfig, database::DatabaseConfig, email::{EmailConfig, EmailSmtpMode, EmailTransportConfig}, - http::{HttpConfig, Resource as HttpResource}, + http::{ + BindConfig as HttpBindConfig, HttpConfig, ListenerConfig as HttpListenerConfig, + Resource as HttpResource, TlsConfig as HttpTlsConfig, UnixOrTcp, + }, matrix::MatrixConfig, policy::PolicyConfig, secrets::SecretsConfig, diff --git a/crates/http/src/layers/server.rs b/crates/http/src/layers/server.rs index 636cbb1e..8dc45171 100644 --- a/crates/http/src/layers/server.rs +++ b/crates/http/src/layers/server.rs @@ -22,9 +22,20 @@ use super::otel::TraceLayer; #[derive(Debug, Default)] pub struct ServerLayer { + listener_name: Option, _t: PhantomData, } +impl ServerLayer { + #[must_use] + pub fn new(listener_name: Option) -> Self { + Self { + listener_name, + _t: PhantomData, + } + } +} + impl Layer for ServerLayer where S: Service, Response = Response> + Clone + Send + 'static, diff --git a/crates/static-files/src/lib.rs b/crates/static-files/src/lib.rs index b6e60eee..4f798da0 100644 --- a/crates/static-files/src/lib.rs +++ b/crates/static-files/src/lib.rs @@ -155,14 +155,11 @@ use tower_http::services::ServeDir; pub fn service( path: &Option, ) -> BoxCloneService, Response, Infallible> { - let builtin = self::builtin::service(); - let svc = if let Some(path) = path { - let handler = ServeDir::new(path) - .append_index_html_on_directories(false) - .fallback(builtin); + let handler = ServeDir::new(path).append_index_html_on_directories(false); on_service(MethodFilter::HEAD | MethodFilter::GET, handler) } else { + let builtin = self::builtin::service(); on_service(MethodFilter::HEAD | MethodFilter::GET, builtin) };