You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-08-09 04:22:45 +03:00
Refactor listeners building
This commit is contained in:
@@ -17,7 +17,7 @@ use std::{sync::Arc, time::Duration};
|
||||
use anyhow::Context;
|
||||
use clap::Parser;
|
||||
use futures_util::{
|
||||
future::{FutureExt, OptionFuture},
|
||||
future::FutureExt,
|
||||
stream::{StreamExt, TryStreamExt},
|
||||
};
|
||||
use hyper::Server;
|
||||
@@ -25,9 +25,9 @@ use mas_config::RootConfig;
|
||||
use mas_email::Mailer;
|
||||
use mas_handlers::{AppState, MatrixHomeserver};
|
||||
use mas_http::ServerLayer;
|
||||
use mas_listener::{maybe_tls::MaybeTlsAcceptor, unix_or_tcp::UnixOrTcpListener};
|
||||
use mas_listener::maybe_tls::MaybeTlsAcceptor;
|
||||
use mas_policy::PolicyFactory;
|
||||
use mas_router::{Route, UrlBuilder};
|
||||
use mas_router::UrlBuilder;
|
||||
use mas_storage::MIGRATOR;
|
||||
use mas_tasks::TaskQueue;
|
||||
use mas_templates::Templates;
|
||||
@@ -213,8 +213,6 @@ impl Options {
|
||||
&config.email.reply_to,
|
||||
);
|
||||
|
||||
let static_files = mas_static_files::service(&config.http.web_root);
|
||||
|
||||
let homeserver = MatrixHomeserver::new(config.matrix.homeserver.clone());
|
||||
|
||||
let listeners_config = config.http.listeners.clone();
|
||||
@@ -247,19 +245,11 @@ impl Options {
|
||||
|
||||
let signal = shutdown_signal().shared();
|
||||
let shutdown_signal = signal.clone();
|
||||
|
||||
let mut fd_manager = listenfd::ListenFd::from_env();
|
||||
|
||||
let listeners = listeners_config.into_iter().map(|listener_config| {
|
||||
// We have to borrow it here, not in the nested closure
|
||||
let fd_manager = &mut fd_manager;
|
||||
|
||||
// Let's first grab all the listeners in a synchronous manner
|
||||
// This helps with the fd_manager mutable borrow
|
||||
let listeners: Result<Vec<UnixOrTcpListener>, _> = listener_config
|
||||
.binds
|
||||
.iter()
|
||||
.map(move |bind_config| bind_config.listener(fd_manager))
|
||||
.collect();
|
||||
let listeners = crate::server::build_listeners(&mut fd_manager, &listener_config.binds);
|
||||
|
||||
Ok((listener_config, listeners?))
|
||||
});
|
||||
@@ -269,10 +259,8 @@ impl Options {
|
||||
.try_for_each_concurrent(None, move |(config, listeners)| {
|
||||
let signal = signal.clone();
|
||||
|
||||
let mut router = mas_handlers::empty_router(state.clone());
|
||||
|
||||
let is_tls = config.tls.is_some();
|
||||
let adresses: Vec<String> = listeners.iter().map(|listener| {
|
||||
let addresses: Vec<String> = listeners.iter().map(|listener| {
|
||||
let addr = listener.local_addr();
|
||||
let proto = if is_tls { "https" } else { "http" };
|
||||
if let Ok(addr) = addr {
|
||||
@@ -283,53 +271,15 @@ impl Options {
|
||||
}
|
||||
}).collect();
|
||||
|
||||
info!("Listening on {adresses:?} with resources {resources:?}", resources = &config.resources);
|
||||
info!("Listening on {addresses:?} with resources {resources:?}", resources = &config.resources);
|
||||
|
||||
for resource in &config.resources {
|
||||
router = match resource {
|
||||
mas_config::HttpResource::Health => {
|
||||
router.merge(mas_handlers::healthcheck_router(state.clone()))
|
||||
}
|
||||
mas_config::HttpResource::Prometheus => {
|
||||
router.route_service("/metrics", crate::telemetry::prometheus_service())
|
||||
}
|
||||
mas_config::HttpResource::Discovery => {
|
||||
router.merge(mas_handlers::discovery_router(state.clone()))
|
||||
}
|
||||
mas_config::HttpResource::Human => {
|
||||
router.merge(mas_handlers::human_router(state.clone()))
|
||||
}
|
||||
mas_config::HttpResource::Static => {
|
||||
router.nest(mas_router::StaticAsset::route(), static_files.clone())
|
||||
}
|
||||
mas_config::HttpResource::OAuth => {
|
||||
router.merge(mas_handlers::api_router(state.clone()))
|
||||
}
|
||||
mas_config::HttpResource::Compat => {
|
||||
router.merge(mas_handlers::compat_router(state.clone()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let router = router.layer(ServerLayer::default());
|
||||
let router = crate::server::build_router(&state, &config.resources).layer(ServerLayer::new(config.name.clone()));
|
||||
|
||||
async move {
|
||||
let tls_config: OptionFuture<_> = config
|
||||
.tls
|
||||
.map(|tls_config| async move {
|
||||
let (key, chain) = tls_config.load().await?;
|
||||
let key = rustls::PrivateKey(key);
|
||||
let chain = chain.into_iter().map(rustls::Certificate).collect();
|
||||
let mut config = rustls::ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(chain, key)
|
||||
.context("failed to build TLS server config")?;
|
||||
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
|
||||
anyhow::Ok(Arc::new(config))
|
||||
})
|
||||
.into();
|
||||
let tls_config = tls_config.await.transpose()?;
|
||||
let tls_config = if let Some(tls_config) = config.tls.as_ref() {
|
||||
let tls_config = crate::server::build_tls_server_config(tls_config).await?;
|
||||
Some(Arc::new(tls_config))
|
||||
} else { None };
|
||||
|
||||
futures_util::stream::iter(listeners)
|
||||
.map(Ok)
|
||||
|
@@ -28,6 +28,7 @@ use tracing_subscriber::{
|
||||
};
|
||||
|
||||
mod commands;
|
||||
mod server;
|
||||
mod telemetry;
|
||||
|
||||
#[tokio::main]
|
||||
|
153
crates/cli/src/server.rs
Normal file
153
crates/cli/src/server.rs
Normal file
@@ -0,0 +1,153 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{
|
||||
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, TcpListener, ToSocketAddrs},
|
||||
os::unix::net::UnixListener,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use anyhow::Context;
|
||||
use axum::{body::HttpBody, Router};
|
||||
use listenfd::ListenFd;
|
||||
use mas_config::{HttpBindConfig, HttpResource, HttpTlsConfig, UnixOrTcp};
|
||||
use mas_handlers::AppState;
|
||||
use mas_listener::unix_or_tcp::UnixOrTcpListener;
|
||||
use mas_router::Route;
|
||||
use rustls::ServerConfig;
|
||||
|
||||
#[allow(clippy::trait_duplication_in_bounds)]
|
||||
pub fn build_router<B>(state: &Arc<AppState>, resources: &[HttpResource]) -> Router<AppState, B>
|
||||
where
|
||||
B: HttpBody + Send + 'static,
|
||||
<B as HttpBody>::Data: Send,
|
||||
<B as HttpBody>::Error: std::error::Error + Send + Sync,
|
||||
{
|
||||
let mut router = Router::with_state_arc(state.clone());
|
||||
|
||||
for resource in resources {
|
||||
router = match resource {
|
||||
mas_config::HttpResource::Health => {
|
||||
router.merge(mas_handlers::healthcheck_router(state.clone()))
|
||||
}
|
||||
mas_config::HttpResource::Prometheus => {
|
||||
router.route_service("/metrics", crate::telemetry::prometheus_service())
|
||||
}
|
||||
mas_config::HttpResource::Discovery => {
|
||||
router.merge(mas_handlers::discovery_router(state.clone()))
|
||||
}
|
||||
mas_config::HttpResource::Human => {
|
||||
router.merge(mas_handlers::human_router(state.clone()))
|
||||
}
|
||||
mas_config::HttpResource::Static { web_root } => {
|
||||
let handler = mas_static_files::service(web_root);
|
||||
router.nest(mas_router::StaticAsset::route(), handler)
|
||||
}
|
||||
mas_config::HttpResource::OAuth => {
|
||||
router.merge(mas_handlers::api_router(state.clone()))
|
||||
}
|
||||
mas_config::HttpResource::Compat => {
|
||||
router.merge(mas_handlers::compat_router(state.clone()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
router
|
||||
}
|
||||
|
||||
pub async fn build_tls_server_config(
|
||||
config: &HttpTlsConfig,
|
||||
) -> Result<ServerConfig, anyhow::Error> {
|
||||
let (key, chain) = config.load().await?;
|
||||
let key = rustls::PrivateKey(key);
|
||||
let chain = chain.into_iter().map(rustls::Certificate).collect();
|
||||
|
||||
let mut config = rustls::ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(chain, key)
|
||||
.context("failed to build TLS server config")?;
|
||||
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
pub fn build_listeners(
|
||||
fd_manager: &mut ListenFd,
|
||||
configs: &[HttpBindConfig],
|
||||
) -> Result<Vec<UnixOrTcpListener>, anyhow::Error> {
|
||||
let mut listeners = Vec::with_capacity(configs.len());
|
||||
|
||||
for bind in configs {
|
||||
let listener = match bind {
|
||||
HttpBindConfig::Listen { host, port } => {
|
||||
let addrs = match host.as_deref() {
|
||||
Some(host) => (host, *port)
|
||||
.to_socket_addrs()
|
||||
.context("could not parse listener host")?
|
||||
.collect(),
|
||||
|
||||
None => vec![
|
||||
SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), *port),
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), *port),
|
||||
],
|
||||
};
|
||||
|
||||
let listener = TcpListener::bind(&addrs[..]).context("could not bind address")?;
|
||||
listener.set_nonblocking(true)?;
|
||||
listener.try_into()?
|
||||
}
|
||||
|
||||
HttpBindConfig::Address { address } => {
|
||||
let addr: SocketAddr = address
|
||||
.parse()
|
||||
.context("could not parse listener address")?;
|
||||
let listener = TcpListener::bind(addr).context("could not bind address")?;
|
||||
listener.set_nonblocking(true)?;
|
||||
listener.try_into()?
|
||||
}
|
||||
|
||||
HttpBindConfig::Unix { socket } => {
|
||||
let listener = UnixListener::bind(socket).context("could not bind socket")?;
|
||||
listener.try_into()?
|
||||
}
|
||||
|
||||
HttpBindConfig::FileDescriptor {
|
||||
fd,
|
||||
kind: UnixOrTcp::Tcp,
|
||||
} => {
|
||||
let listener = fd_manager
|
||||
.take_tcp_listener(*fd)?
|
||||
.context("no listener found on file descriptor")?;
|
||||
listener.set_nonblocking(true)?;
|
||||
listener.try_into()?
|
||||
}
|
||||
|
||||
HttpBindConfig::FileDescriptor {
|
||||
fd,
|
||||
kind: UnixOrTcp::Unix,
|
||||
} => {
|
||||
let listener = fd_manager
|
||||
.take_unix_listener(*fd)?
|
||||
.context("no unix socket found on file descriptor")?;
|
||||
listener.set_nonblocking(true)?;
|
||||
listener.try_into()?
|
||||
}
|
||||
};
|
||||
|
||||
listeners.push(listener);
|
||||
}
|
||||
|
||||
Ok(listeners)
|
||||
}
|
Reference in New Issue
Block a user