You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
Rewrite the listeners crate
Now with a way better graceful shutdown! With proper handshakes!
This commit is contained in:
@ -22,6 +22,7 @@ watchman_client = "0.8.0"
|
||||
atty = "0.2.14"
|
||||
listenfd = "1.0.0"
|
||||
rustls = "0.20.6"
|
||||
itertools = "0.10.5"
|
||||
|
||||
tracing = "0.1.36"
|
||||
tracing-appender = "0.2.2"
|
||||
|
@ -16,26 +16,19 @@ use std::{sync::Arc, time::Duration};
|
||||
|
||||
use anyhow::Context;
|
||||
use clap::Parser;
|
||||
use futures_util::{
|
||||
future::FutureExt,
|
||||
stream::{StreamExt, TryStreamExt},
|
||||
};
|
||||
use hyper::Server;
|
||||
use futures_util::stream::{StreamExt, TryStreamExt};
|
||||
use itertools::Itertools;
|
||||
use mas_config::RootConfig;
|
||||
use mas_email::Mailer;
|
||||
use mas_handlers::{AppState, MatrixHomeserver};
|
||||
use mas_http::ServerLayer;
|
||||
use mas_listener::{
|
||||
info::{ConnectionInfoAcceptor, IntoMakeServiceWithConnection},
|
||||
maybe_tls::MaybeTlsAcceptor,
|
||||
proxy_protocol::MaybeProxyAcceptor,
|
||||
};
|
||||
use mas_listener::{server::Server, shutdown::ShutdownStream};
|
||||
use mas_policy::PolicyFactory;
|
||||
use mas_router::UrlBuilder;
|
||||
use mas_storage::MIGRATOR;
|
||||
use mas_tasks::TaskQueue;
|
||||
use mas_templates::Templates;
|
||||
use tokio::io::AsyncRead;
|
||||
use tokio::{io::AsyncRead, signal::unix::SignalKind};
|
||||
use tracing::{error, info, log::warn};
|
||||
|
||||
#[derive(Parser, Debug, Default)]
|
||||
@ -49,32 +42,6 @@ pub(super) struct Options {
|
||||
watch: bool,
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
async fn shutdown_signal() {
|
||||
// Wait for the CTRL+C signal
|
||||
tokio::signal::ctrl_c()
|
||||
.await
|
||||
.expect("failed to install Ctrl+C signal handler");
|
||||
|
||||
tracing::info!("Got Ctrl+C, shutting down");
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
async fn shutdown_signal() {
|
||||
use tokio::signal::unix::{signal, SignalKind};
|
||||
|
||||
// Wait for SIGTERM and SIGINT signals
|
||||
// This might panic but should be fine
|
||||
let mut term =
|
||||
signal(SignalKind::terminate()).expect("failed to install SIGTERM signal handler");
|
||||
let mut int = signal(SignalKind::interrupt()).expect("failed to install SIGINT signal handler");
|
||||
|
||||
tokio::select! {
|
||||
_ = term.recv() => tracing::info!("Got SIGTERM, shutting down"),
|
||||
_ = int.recv() => tracing::info!("Got SIGINT, shutting down"),
|
||||
};
|
||||
}
|
||||
|
||||
/// Watch for changes in the templates folders
|
||||
async fn watch_templates(
|
||||
client: &watchman_client::Client,
|
||||
@ -247,68 +214,75 @@ impl Options {
|
||||
policy_factory,
|
||||
});
|
||||
|
||||
let signal = shutdown_signal().shared();
|
||||
let shutdown_signal = signal.clone();
|
||||
|
||||
let mut fd_manager = listenfd::ListenFd::from_env();
|
||||
let listeners = listeners_config.into_iter().map(|listener_config| {
|
||||
// Let's first grab all the listeners in a synchronous manner
|
||||
let listeners = crate::server::build_listeners(&mut fd_manager, &listener_config.binds);
|
||||
|
||||
Ok((listener_config, listeners?))
|
||||
});
|
||||
let servers: Vec<Server<_>> = listeners_config
|
||||
.into_iter()
|
||||
.map(|config| {
|
||||
// Let's first grab all the listeners
|
||||
let listeners = crate::server::build_listeners(&mut fd_manager, &config.binds)?;
|
||||
|
||||
// Now that we have the listeners ready, we can do the rest concurrently
|
||||
futures_util::stream::iter(listeners)
|
||||
.try_for_each_concurrent(None, move |(config, listeners)| {
|
||||
let signal = signal.clone();
|
||||
// Load the TLS config
|
||||
let tls_config = if let Some(tls_config) = config.tls.as_ref() {
|
||||
let tls_config = crate::server::build_tls_server_config(tls_config)?;
|
||||
Some(Arc::new(tls_config))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// and build the router
|
||||
let router = crate::server::build_router(&state, &config.resources)
|
||||
.layer(ServerLayer::new(config.name.clone()));
|
||||
|
||||
// Display some informations about where we'll be serving connections
|
||||
let is_tls = config.tls.is_some();
|
||||
let addresses: Vec<String> = listeners.iter().map(|listener| {
|
||||
let addr = listener.local_addr();
|
||||
let proto = if is_tls { "https" } else { "http" };
|
||||
if let Ok(addr) = addr {
|
||||
format!("{proto}://{addr:?}")
|
||||
} else {
|
||||
warn!("Could not get local address for listener, something might be wrong!");
|
||||
format!("{proto}://???")
|
||||
let addresses: Vec<String> = listeners
|
||||
.iter()
|
||||
.map(|listener| {
|
||||
let addr = listener.local_addr();
|
||||
let proto = if is_tls { "https" } else { "http" };
|
||||
if let Ok(addr) = addr {
|
||||
format!("{proto}://{addr:?}")
|
||||
} else {
|
||||
warn!(
|
||||
"Could not get local address for listener, something might be wrong!"
|
||||
);
|
||||
format!("{proto}://???")
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let additional = if config.proxy_protocol {
|
||||
"(with Proxy Protocol)"
|
||||
} else {
|
||||
""
|
||||
};
|
||||
|
||||
info!(
|
||||
"Listening on {addresses:?} with resources {resources:?} {additional}",
|
||||
resources = &config.resources
|
||||
);
|
||||
|
||||
anyhow::Ok(listeners.into_iter().map(move |listener| {
|
||||
let mut server = Server::new(listener, router.clone());
|
||||
if let Some(tls_config) = &tls_config {
|
||||
server = server.with_tls(tls_config.clone());
|
||||
}
|
||||
}).collect();
|
||||
|
||||
let additional = if config.proxy_protocol { "(with Proxy Protocol)" } else { "" };
|
||||
|
||||
info!("Listening on {addresses:?} with resources {resources:?} {additional}", resources = &config.resources);
|
||||
|
||||
let router = crate::server::build_router(&state, &config.resources).layer(ServerLayer::new(config.name.clone()));
|
||||
let make_service = IntoMakeServiceWithConnection::new(router);
|
||||
|
||||
async move {
|
||||
let tls_config = if let Some(tls_config) = config.tls.as_ref() {
|
||||
let tls_config = crate::server::build_tls_server_config(tls_config).await?;
|
||||
Some(Arc::new(tls_config))
|
||||
} else { None };
|
||||
|
||||
futures_util::stream::iter(listeners)
|
||||
.map(Ok)
|
||||
.try_for_each_concurrent(None, move |listener| {
|
||||
let listener = MaybeTlsAcceptor::new(tls_config.clone(), listener);
|
||||
let listener = MaybeProxyAcceptor::new(listener, config.proxy_protocol);
|
||||
let listener = ConnectionInfoAcceptor::new(listener);
|
||||
|
||||
Server::builder(listener)
|
||||
.serve(make_service.clone())
|
||||
.with_graceful_shutdown(signal.clone())
|
||||
})
|
||||
.await?;
|
||||
|
||||
anyhow::Ok(())
|
||||
}
|
||||
if config.proxy_protocol {
|
||||
server = server.with_proxy();
|
||||
}
|
||||
server
|
||||
}))
|
||||
})
|
||||
.await?;
|
||||
.flatten_ok()
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
// This ensures we're running, even if no listener are setup
|
||||
// This is useful for only running the task runner
|
||||
shutdown_signal.await;
|
||||
let shutdown = ShutdownStream::default()
|
||||
.with_timeout(Duration::from_secs(60))
|
||||
.with_signal(SignalKind::terminate())?
|
||||
.with_signal(SignalKind::interrupt())?;
|
||||
|
||||
mas_listener::server::run_servers(servers, shutdown).await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -23,10 +23,9 @@ use axum::{body::HttpBody, Extension, Router};
|
||||
use listenfd::ListenFd;
|
||||
use mas_config::{HttpBindConfig, HttpResource, HttpTlsConfig, UnixOrTcp};
|
||||
use mas_handlers::AppState;
|
||||
use mas_listener::{info::Connection, unix_or_tcp::UnixOrTcpListener};
|
||||
use mas_listener::{unix_or_tcp::UnixOrTcpListener, ConnectionInfo};
|
||||
use mas_router::Route;
|
||||
use rustls::ServerConfig;
|
||||
use tokio::sync::OnceCell;
|
||||
|
||||
#[allow(clippy::trait_duplication_in_bounds)]
|
||||
pub fn build_router<B>(state: &Arc<AppState>, resources: &[HttpResource]) -> Router<AppState, B>
|
||||
@ -64,12 +63,9 @@ where
|
||||
// TODO: do a better handler here
|
||||
mas_config::HttpResource::ConnectionInfo => router.route(
|
||||
"/connection-info",
|
||||
axum::routing::get(
|
||||
|connection: Extension<Arc<OnceCell<Connection>>>| async move {
|
||||
let connection = connection.get().unwrap();
|
||||
format!("{connection:?}")
|
||||
},
|
||||
),
|
||||
axum::routing::get(|connection: Extension<ConnectionInfo>| async move {
|
||||
format!("{connection:?}")
|
||||
}),
|
||||
),
|
||||
}
|
||||
}
|
||||
@ -77,10 +73,8 @@ where
|
||||
router
|
||||
}
|
||||
|
||||
pub async fn build_tls_server_config(
|
||||
config: &HttpTlsConfig,
|
||||
) -> Result<ServerConfig, anyhow::Error> {
|
||||
let (key, chain) = config.load().await?;
|
||||
pub fn build_tls_server_config(config: &HttpTlsConfig) -> Result<ServerConfig, anyhow::Error> {
|
||||
let (key, chain) = config.load()?;
|
||||
let key = rustls::PrivateKey(key);
|
||||
let chain = chain.into_iter().map(rustls::Certificate).collect();
|
||||
|
||||
|
Reference in New Issue
Block a user