1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Rewrite the listeners crate

Now with a way better graceful shutdown! With proper handshakes!
This commit is contained in:
Quentin Gliech
2022-10-12 14:36:19 +02:00
parent 485778beb3
commit ee43f08cf7
19 changed files with 1092 additions and 1016 deletions

View File

@ -22,6 +22,7 @@ watchman_client = "0.8.0"
atty = "0.2.14"
listenfd = "1.0.0"
rustls = "0.20.6"
itertools = "0.10.5"
tracing = "0.1.36"
tracing-appender = "0.2.2"

View File

@ -16,26 +16,19 @@ use std::{sync::Arc, time::Duration};
use anyhow::Context;
use clap::Parser;
use futures_util::{
future::FutureExt,
stream::{StreamExt, TryStreamExt},
};
use hyper::Server;
use futures_util::stream::{StreamExt, TryStreamExt};
use itertools::Itertools;
use mas_config::RootConfig;
use mas_email::Mailer;
use mas_handlers::{AppState, MatrixHomeserver};
use mas_http::ServerLayer;
use mas_listener::{
info::{ConnectionInfoAcceptor, IntoMakeServiceWithConnection},
maybe_tls::MaybeTlsAcceptor,
proxy_protocol::MaybeProxyAcceptor,
};
use mas_listener::{server::Server, shutdown::ShutdownStream};
use mas_policy::PolicyFactory;
use mas_router::UrlBuilder;
use mas_storage::MIGRATOR;
use mas_tasks::TaskQueue;
use mas_templates::Templates;
use tokio::io::AsyncRead;
use tokio::{io::AsyncRead, signal::unix::SignalKind};
use tracing::{error, info, log::warn};
#[derive(Parser, Debug, Default)]
@ -49,32 +42,6 @@ pub(super) struct Options {
watch: bool,
}
#[cfg(not(unix))]
async fn shutdown_signal() {
// Wait for the CTRL+C signal
tokio::signal::ctrl_c()
.await
.expect("failed to install Ctrl+C signal handler");
tracing::info!("Got Ctrl+C, shutting down");
}
#[cfg(unix)]
async fn shutdown_signal() {
use tokio::signal::unix::{signal, SignalKind};
// Wait for SIGTERM and SIGINT signals
// This might panic but should be fine
let mut term =
signal(SignalKind::terminate()).expect("failed to install SIGTERM signal handler");
let mut int = signal(SignalKind::interrupt()).expect("failed to install SIGINT signal handler");
tokio::select! {
_ = term.recv() => tracing::info!("Got SIGTERM, shutting down"),
_ = int.recv() => tracing::info!("Got SIGINT, shutting down"),
};
}
/// Watch for changes in the templates folders
async fn watch_templates(
client: &watchman_client::Client,
@ -247,68 +214,75 @@ impl Options {
policy_factory,
});
let signal = shutdown_signal().shared();
let shutdown_signal = signal.clone();
let mut fd_manager = listenfd::ListenFd::from_env();
let listeners = listeners_config.into_iter().map(|listener_config| {
// Let's first grab all the listeners in a synchronous manner
let listeners = crate::server::build_listeners(&mut fd_manager, &listener_config.binds);
Ok((listener_config, listeners?))
});
let servers: Vec<Server<_>> = listeners_config
.into_iter()
.map(|config| {
// Let's first grab all the listeners
let listeners = crate::server::build_listeners(&mut fd_manager, &config.binds)?;
// Now that we have the listeners ready, we can do the rest concurrently
futures_util::stream::iter(listeners)
.try_for_each_concurrent(None, move |(config, listeners)| {
let signal = signal.clone();
// Load the TLS config
let tls_config = if let Some(tls_config) = config.tls.as_ref() {
let tls_config = crate::server::build_tls_server_config(tls_config)?;
Some(Arc::new(tls_config))
} else {
None
};
// and build the router
let router = crate::server::build_router(&state, &config.resources)
.layer(ServerLayer::new(config.name.clone()));
// Display some informations about where we'll be serving connections
let is_tls = config.tls.is_some();
let addresses: Vec<String> = listeners.iter().map(|listener| {
let addr = listener.local_addr();
let proto = if is_tls { "https" } else { "http" };
if let Ok(addr) = addr {
format!("{proto}://{addr:?}")
} else {
warn!("Could not get local address for listener, something might be wrong!");
format!("{proto}://???")
let addresses: Vec<String> = listeners
.iter()
.map(|listener| {
let addr = listener.local_addr();
let proto = if is_tls { "https" } else { "http" };
if let Ok(addr) = addr {
format!("{proto}://{addr:?}")
} else {
warn!(
"Could not get local address for listener, something might be wrong!"
);
format!("{proto}://???")
}
})
.collect();
let additional = if config.proxy_protocol {
"(with Proxy Protocol)"
} else {
""
};
info!(
"Listening on {addresses:?} with resources {resources:?} {additional}",
resources = &config.resources
);
anyhow::Ok(listeners.into_iter().map(move |listener| {
let mut server = Server::new(listener, router.clone());
if let Some(tls_config) = &tls_config {
server = server.with_tls(tls_config.clone());
}
}).collect();
let additional = if config.proxy_protocol { "(with Proxy Protocol)" } else { "" };
info!("Listening on {addresses:?} with resources {resources:?} {additional}", resources = &config.resources);
let router = crate::server::build_router(&state, &config.resources).layer(ServerLayer::new(config.name.clone()));
let make_service = IntoMakeServiceWithConnection::new(router);
async move {
let tls_config = if let Some(tls_config) = config.tls.as_ref() {
let tls_config = crate::server::build_tls_server_config(tls_config).await?;
Some(Arc::new(tls_config))
} else { None };
futures_util::stream::iter(listeners)
.map(Ok)
.try_for_each_concurrent(None, move |listener| {
let listener = MaybeTlsAcceptor::new(tls_config.clone(), listener);
let listener = MaybeProxyAcceptor::new(listener, config.proxy_protocol);
let listener = ConnectionInfoAcceptor::new(listener);
Server::builder(listener)
.serve(make_service.clone())
.with_graceful_shutdown(signal.clone())
})
.await?;
anyhow::Ok(())
}
if config.proxy_protocol {
server = server.with_proxy();
}
server
}))
})
.await?;
.flatten_ok()
.collect::<Result<Vec<_>, _>>()?;
// This ensures we're running, even if no listener are setup
// This is useful for only running the task runner
shutdown_signal.await;
let shutdown = ShutdownStream::default()
.with_timeout(Duration::from_secs(60))
.with_signal(SignalKind::terminate())?
.with_signal(SignalKind::interrupt())?;
mas_listener::server::run_servers(servers, shutdown).await;
Ok(())
}

View File

@ -23,10 +23,9 @@ use axum::{body::HttpBody, Extension, Router};
use listenfd::ListenFd;
use mas_config::{HttpBindConfig, HttpResource, HttpTlsConfig, UnixOrTcp};
use mas_handlers::AppState;
use mas_listener::{info::Connection, unix_or_tcp::UnixOrTcpListener};
use mas_listener::{unix_or_tcp::UnixOrTcpListener, ConnectionInfo};
use mas_router::Route;
use rustls::ServerConfig;
use tokio::sync::OnceCell;
#[allow(clippy::trait_duplication_in_bounds)]
pub fn build_router<B>(state: &Arc<AppState>, resources: &[HttpResource]) -> Router<AppState, B>
@ -64,12 +63,9 @@ where
// TODO: do a better handler here
mas_config::HttpResource::ConnectionInfo => router.route(
"/connection-info",
axum::routing::get(
|connection: Extension<Arc<OnceCell<Connection>>>| async move {
let connection = connection.get().unwrap();
format!("{connection:?}")
},
),
axum::routing::get(|connection: Extension<ConnectionInfo>| async move {
format!("{connection:?}")
}),
),
}
}
@ -77,10 +73,8 @@ where
router
}
pub async fn build_tls_server_config(
config: &HttpTlsConfig,
) -> Result<ServerConfig, anyhow::Error> {
let (key, chain) = config.load().await?;
pub fn build_tls_server_config(config: &HttpTlsConfig) -> Result<ServerConfig, anyhow::Error> {
let (key, chain) = config.load()?;
let key = rustls::PrivateKey(key);
let chain = chain.into_iter().map(rustls::Certificate).collect();