1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

WIP: better HTTP listeners

This commit is contained in:
Quentin Gliech
2022-09-30 17:49:52 +02:00
parent 93ce5c797c
commit 7fbfb74a5e
6 changed files with 127 additions and 37 deletions

19
Cargo.lock generated
View File

@ -2269,6 +2269,17 @@ version = "0.0.46"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d4d2456c373231a208ad294c33dc5bff30051eafd954cd4caae83a712b12854d" checksum = "d4d2456c373231a208ad294c33dc5bff30051eafd954cd4caae83a712b12854d"
[[package]]
name = "listenfd"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "14e4fcc00ff6731d94b70e16e71f43bda62883461f31230742e3bc6dddf12988"
dependencies = [
"libc",
"uuid",
"winapi",
]
[[package]] [[package]]
name = "lock_api" name = "lock_api"
version = "0.4.9" version = "0.4.9"
@ -2350,6 +2361,7 @@ dependencies = [
"futures-util", "futures-util",
"hyper", "hyper",
"indoc", "indoc",
"listenfd",
"mas-config", "mas-config",
"mas-email", "mas-email",
"mas-handlers", "mas-handlers",
@ -2390,6 +2402,7 @@ dependencies = [
"figment", "figment",
"indoc", "indoc",
"lettre", "lettre",
"listenfd",
"mas-email", "mas-email",
"mas-iana", "mas-iana",
"mas-jose", "mas-jose",
@ -5182,6 +5195,12 @@ version = "2.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8db7427f936968176eaa7cdf81b7f98b980b18495ec28f1b5791ac3bfe3eea9" checksum = "e8db7427f936968176eaa7cdf81b7f98b980b18495ec28f1b5791ac3bfe3eea9"
[[package]]
name = "uuid"
version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd6469f4314d5f1ffec476e05f17cc9a78bc7a27a6a857842170bdf8d6f98d2f"
[[package]] [[package]]
name = "valuable" name = "valuable"
version = "0.1.0" version = "0.1.0"

View File

@ -19,6 +19,7 @@ url = "2.3.1"
argon2 = { version = "0.4.1", features = ["password-hash"] } argon2 = { version = "0.4.1", features = ["password-hash"] }
watchman_client = "0.8.0" watchman_client = "0.8.0"
atty = "0.2.14" atty = "0.2.14"
listenfd = "1.0.0"
tracing = "0.1.36" tracing = "0.1.36"
tracing-appender = "0.2.2" tracing-appender = "0.2.2"

View File

@ -12,15 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::{ use std::{net::SocketAddr, sync::Arc, time::Duration};
net::{SocketAddr, TcpListener},
sync::Arc,
time::Duration,
};
use anyhow::Context; use anyhow::Context;
use clap::Parser; use clap::Parser;
use futures_util::stream::{StreamExt, TryStreamExt}; use futures_util::{
future::FutureExt,
stream::{StreamExt, TryStreamExt},
};
use hyper::Server; use hyper::Server;
use mas_config::RootConfig; use mas_config::RootConfig;
use mas_email::Mailer; use mas_email::Mailer;
@ -138,16 +137,10 @@ async fn watch_templates(
} }
impl Options { impl Options {
#[allow(clippy::too_many_lines)]
pub async fn run(&self, root: &super::Options) -> anyhow::Result<()> { pub async fn run(&self, root: &super::Options) -> anyhow::Result<()> {
let config: RootConfig = root.load_config()?; let config: RootConfig = root.load_config()?;
let addr: SocketAddr = config
.http
.address
.parse()
.context("could not parse listener address")?;
let listener = TcpListener::bind(addr).context("could not bind address")?;
// Connect to the mail server // Connect to the mail server
let mail_transport = config.email.transport.to_transport().await?; let mail_transport = config.email.transport.to_transport().await?;
mail_transport.test_connection().await?; mail_transport.test_connection().await?;
@ -223,6 +216,8 @@ impl Options {
let homeserver = MatrixHomeserver::new(config.matrix.homeserver.clone()); let homeserver = MatrixHomeserver::new(config.matrix.homeserver.clone());
let listeners_config = config.http.listeners.clone();
// Explicitely the config to properly zeroize secret keys // Explicitely the config to properly zeroize secret keys
drop(config); drop(config);
@ -238,7 +233,7 @@ impl Options {
.context("could not watch for templates changes")?; .context("could not watch for templates changes")?;
} }
let state = AppState { let state = Arc::new(AppState {
pool, pool,
templates, templates,
key_store, key_store,
@ -247,18 +242,42 @@ impl Options {
mailer, mailer,
homeserver, homeserver,
policy_factory, policy_factory,
}; });
let router = mas_handlers::router(state) let signal = shutdown_signal().shared();
.nest(mas_router::StaticAsset::route(), static_files) let mut fd_manager = listenfd::ListenFd::from_env();
let futs = listeners_config
.into_iter()
.map(|listener_config| {
let signal = signal.clone();
let router = mas_handlers::router(state.clone())
.nest(mas_router::StaticAsset::route(), static_files.clone())
.layer(ServerLayer::default()); .layer(ServerLayer::default());
info!("Listening on http://{}", listener.local_addr().unwrap()); let mut futs: Vec<_> = Vec::with_capacity(listener_config.binds.len());
for bind in listener_config.binds {
let listener = bind.listener(&mut fd_manager)?;
let router = router.clone();
Server::from_tcp(listener)? let addr = listener.local_addr()?;
info!("Listening on http://{addr}");
let fut = Server::from_tcp(listener)?
.serve(router.into_make_service_with_connect_info::<SocketAddr>()) .serve(router.into_make_service_with_connect_info::<SocketAddr>())
.with_graceful_shutdown(shutdown_signal()) .with_graceful_shutdown(signal.clone());
.await?; futs.push(fut);
}
anyhow::Ok(futures_util::future::try_join_all(futs))
})
.collect::<Result<Vec<_>, _>>()?;
futures_util::future::try_join_all(futs).await?;
// This ensures we're running, even if no listener are setup
// This is useful for only running the task runner
signal.await;
Ok(()) Ok(())
} }

View File

@ -24,6 +24,8 @@ serde_json = "1.0.85"
sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres"] } sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres"] }
lettre = { version = "0.10.1", default-features = false, features = ["serde", "builder"] } lettre = { version = "0.10.1", default-features = false, features = ["serde", "builder"] }
listenfd = "1.0.0"
pem-rfc7468 = "0.6.0" pem-rfc7468 = "0.6.0"
rand = "0.8.5" rand = "0.8.5"

View File

@ -12,9 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::path::PathBuf; use std::{
net::{SocketAddr, TcpListener},
path::PathBuf,
};
use anyhow::Context;
use async_trait::async_trait; use async_trait::async_trait;
use listenfd::ListenFd;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use url::Url; use url::Url;
@ -42,18 +47,59 @@ fn http_address_example_4() -> &'static str {
"0.0.0.0:8080" "0.0.0.0:8080"
} }
/// Configuration related to the web server #[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)]
#[derive(Debug, Serialize, Deserialize, JsonSchema)] #[serde(untagged)]
pub struct HttpConfig { pub enum BindConfig {
/// IP and port the server should listen to Address {
#[schemars( #[schemars(
example = "http_address_example_1", example = "http_address_example_1",
example = "http_address_example_2", example = "http_address_example_2",
example = "http_address_example_3", example = "http_address_example_3",
example = "http_address_example_4" example = "http_address_example_4"
)] )]
#[serde(default = "default_http_address")] address: String,
pub address: String, },
FileDescriptor {
fd: usize,
},
}
impl BindConfig {
pub fn listener(self, fd_manager: &mut ListenFd) -> Result<TcpListener, anyhow::Error> {
match self {
BindConfig::Address { address } => {
let addr: SocketAddr = address
.parse()
.context("could not parse listener address")?;
let listener = TcpListener::bind(addr).context("could not bind address")?;
Ok(listener)
}
BindConfig::FileDescriptor { fd } => {
let listener = fd_manager
.take_tcp_listener(fd)?
.context("no listener found on file descriptor")?;
// XXX: Do I need that?
listener.set_nonblocking(true)?;
Ok(listener)
}
}
}
}
#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)]
pub struct ListenerConfig {
pub name: Option<String>,
pub binds: Vec<BindConfig>,
}
/// Configuration related to the web server
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct HttpConfig {
/// List of listeners to run
#[serde(default)]
pub listeners: Vec<ListenerConfig>,
/// Path from which to serve static files. If not specified, it will serve /// Path from which to serve static files. If not specified, it will serve
/// the static files embedded in the server binary /// the static files embedded in the server binary
@ -67,8 +113,13 @@ pub struct HttpConfig {
impl Default for HttpConfig { impl Default for HttpConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
address: default_http_address(),
web_root: None, web_root: None,
listeners: vec![ListenerConfig {
name: None,
binds: vec![BindConfig::Address {
address: default_http_address(),
}],
}],
public_base: default_public_base(), public_base: default_public_base(),
} }
} }

View File

@ -252,7 +252,7 @@ where
#[must_use] #[must_use]
#[allow(clippy::trait_duplication_in_bounds)] #[allow(clippy::trait_duplication_in_bounds)]
pub fn router<S, B>(state: S) -> Router<S, B> pub fn router<S, B>(state: Arc<S>) -> Router<S, B>
where where
B: HttpBody + Send + 'static, B: HttpBody + Send + 'static,
<B as HttpBody>::Data: Send, <B as HttpBody>::Data: Send,
@ -267,8 +267,6 @@ where
Mailer: FromRef<S>, Mailer: FromRef<S>,
MatrixHomeserver: FromRef<S>, MatrixHomeserver: FromRef<S>,
{ {
let state = Arc::new(state);
let api_router = api_router(state.clone()); let api_router = api_router(state.clone());
let compat_router = compat_router(state.clone()); let compat_router = compat_router(state.clone());
let human_router = human_router(state); let human_router = human_router(state);