From ee43f08cf7b400a6694f9e1a179789bb81eb1f7d Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 12 Oct 2022 14:36:19 +0200 Subject: [PATCH] Rewrite the listeners crate Now with a way better graceful shutdown! With proper handshakes! --- Cargo.lock | 22 +- crates/cli/Cargo.toml | 1 + crates/cli/src/commands/server.rs | 156 ++++----- crates/cli/src/server.rs | 18 +- crates/config/src/sections/http.rs | 10 +- crates/listener/Cargo.toml | 21 +- crates/listener/examples/demo/main.rs | 49 +++ crates/listener/src/info.rs | 323 ------------------ crates/listener/src/lib.rs | 36 +- crates/listener/src/maybe_tls.rs | 246 ++++++------- .../listener/src/proxy_protocol/acceptor.rs | 70 ++-- crates/listener/src/proxy_protocol/maybe.rs | 203 +++-------- crates/listener/src/proxy_protocol/mod.rs | 6 +- crates/listener/src/proxy_protocol/stream.rs | 169 --------- crates/listener/src/proxy_protocol/v1.rs | 52 +-- crates/listener/src/rewind.rs | 151 ++++++++ crates/listener/src/server.rs | 301 ++++++++++++++++ crates/listener/src/shutdown.rs | 178 ++++++++++ crates/listener/src/unix_or_tcp.rs | 96 +++--- 19 files changed, 1092 insertions(+), 1016 deletions(-) create mode 100644 crates/listener/examples/demo/main.rs delete mode 100644 crates/listener/src/info.rs delete mode 100644 crates/listener/src/proxy_protocol/stream.rs create mode 100644 crates/listener/src/rewind.rs create mode 100644 crates/listener/src/server.rs create mode 100644 crates/listener/src/shutdown.rs diff --git a/Cargo.lock b/Cargo.lock index 04f88ab9..af9ee3f4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2425,6 +2425,7 @@ dependencies = [ "futures-util", "hyper", "indoc", + "itertools", "listenfd", "mas-config", "mas-email", @@ -2677,15 +2678,21 @@ dependencies = [ name = "mas-listener" version = "0.1.0" dependencies = [ + "anyhow", + "bytes 1.2.1", "futures-util", + "http-body", "hyper", + "libc", "pin-project-lite", "thiserror", "tokio", "tokio-rustls 0.23.4", - "tower", + "tokio-test", "tower-http", + "tower-service", "tracing", + "tracing-subscriber", ] [[package]] @@ -4872,6 +4879,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-test" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53474327ae5e166530d17f2d956afcb4f8a004de581b3cae10f12006bc8163e3" +dependencies = [ + "async-stream", + "bytes 1.2.1", + "futures-core", + "tokio", + "tokio-stream", +] + [[package]] name = "tokio-util" version = "0.6.10" diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index a488c9a6..40fd40b6 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -22,6 +22,7 @@ watchman_client = "0.8.0" atty = "0.2.14" listenfd = "1.0.0" rustls = "0.20.6" +itertools = "0.10.5" tracing = "0.1.36" tracing-appender = "0.2.2" diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index 80048829..ac82a37a 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -16,26 +16,19 @@ use std::{sync::Arc, time::Duration}; use anyhow::Context; use clap::Parser; -use futures_util::{ - future::FutureExt, - stream::{StreamExt, TryStreamExt}, -}; -use hyper::Server; +use futures_util::stream::{StreamExt, TryStreamExt}; +use itertools::Itertools; use mas_config::RootConfig; use mas_email::Mailer; use mas_handlers::{AppState, MatrixHomeserver}; use mas_http::ServerLayer; -use mas_listener::{ - info::{ConnectionInfoAcceptor, IntoMakeServiceWithConnection}, - maybe_tls::MaybeTlsAcceptor, - proxy_protocol::MaybeProxyAcceptor, -}; +use mas_listener::{server::Server, shutdown::ShutdownStream}; use mas_policy::PolicyFactory; use mas_router::UrlBuilder; use mas_storage::MIGRATOR; use mas_tasks::TaskQueue; use mas_templates::Templates; -use tokio::io::AsyncRead; +use tokio::{io::AsyncRead, signal::unix::SignalKind}; use tracing::{error, info, log::warn}; #[derive(Parser, Debug, Default)] @@ -49,32 +42,6 @@ pub(super) struct Options { watch: bool, } -#[cfg(not(unix))] -async fn shutdown_signal() { - // Wait for the CTRL+C signal - tokio::signal::ctrl_c() - .await - .expect("failed to install Ctrl+C signal handler"); - - tracing::info!("Got Ctrl+C, shutting down"); -} - -#[cfg(unix)] -async fn shutdown_signal() { - use tokio::signal::unix::{signal, SignalKind}; - - // Wait for SIGTERM and SIGINT signals - // This might panic but should be fine - let mut term = - signal(SignalKind::terminate()).expect("failed to install SIGTERM signal handler"); - let mut int = signal(SignalKind::interrupt()).expect("failed to install SIGINT signal handler"); - - tokio::select! { - _ = term.recv() => tracing::info!("Got SIGTERM, shutting down"), - _ = int.recv() => tracing::info!("Got SIGINT, shutting down"), - }; -} - /// Watch for changes in the templates folders async fn watch_templates( client: &watchman_client::Client, @@ -247,68 +214,75 @@ impl Options { policy_factory, }); - let signal = shutdown_signal().shared(); - let shutdown_signal = signal.clone(); - let mut fd_manager = listenfd::ListenFd::from_env(); - let listeners = listeners_config.into_iter().map(|listener_config| { - // Let's first grab all the listeners in a synchronous manner - let listeners = crate::server::build_listeners(&mut fd_manager, &listener_config.binds); - Ok((listener_config, listeners?)) - }); + let servers: Vec> = listeners_config + .into_iter() + .map(|config| { + // Let's first grab all the listeners + let listeners = crate::server::build_listeners(&mut fd_manager, &config.binds)?; - // Now that we have the listeners ready, we can do the rest concurrently - futures_util::stream::iter(listeners) - .try_for_each_concurrent(None, move |(config, listeners)| { - let signal = signal.clone(); + // Load the TLS config + let tls_config = if let Some(tls_config) = config.tls.as_ref() { + let tls_config = crate::server::build_tls_server_config(tls_config)?; + Some(Arc::new(tls_config)) + } else { + None + }; + // and build the router + let router = crate::server::build_router(&state, &config.resources) + .layer(ServerLayer::new(config.name.clone())); + + // Display some informations about where we'll be serving connections let is_tls = config.tls.is_some(); - let addresses: Vec = listeners.iter().map(|listener| { - let addr = listener.local_addr(); - let proto = if is_tls { "https" } else { "http" }; - if let Ok(addr) = addr { - format!("{proto}://{addr:?}") - } else { - warn!("Could not get local address for listener, something might be wrong!"); - format!("{proto}://???") + let addresses: Vec = listeners + .iter() + .map(|listener| { + let addr = listener.local_addr(); + let proto = if is_tls { "https" } else { "http" }; + if let Ok(addr) = addr { + format!("{proto}://{addr:?}") + } else { + warn!( + "Could not get local address for listener, something might be wrong!" + ); + format!("{proto}://???") + } + }) + .collect(); + + let additional = if config.proxy_protocol { + "(with Proxy Protocol)" + } else { + "" + }; + + info!( + "Listening on {addresses:?} with resources {resources:?} {additional}", + resources = &config.resources + ); + + anyhow::Ok(listeners.into_iter().map(move |listener| { + let mut server = Server::new(listener, router.clone()); + if let Some(tls_config) = &tls_config { + server = server.with_tls(tls_config.clone()); } - }).collect(); - - let additional = if config.proxy_protocol { "(with Proxy Protocol)" } else { "" }; - - info!("Listening on {addresses:?} with resources {resources:?} {additional}", resources = &config.resources); - - let router = crate::server::build_router(&state, &config.resources).layer(ServerLayer::new(config.name.clone())); - let make_service = IntoMakeServiceWithConnection::new(router); - - async move { - let tls_config = if let Some(tls_config) = config.tls.as_ref() { - let tls_config = crate::server::build_tls_server_config(tls_config).await?; - Some(Arc::new(tls_config)) - } else { None }; - - futures_util::stream::iter(listeners) - .map(Ok) - .try_for_each_concurrent(None, move |listener| { - let listener = MaybeTlsAcceptor::new(tls_config.clone(), listener); - let listener = MaybeProxyAcceptor::new(listener, config.proxy_protocol); - let listener = ConnectionInfoAcceptor::new(listener); - - Server::builder(listener) - .serve(make_service.clone()) - .with_graceful_shutdown(signal.clone()) - }) - .await?; - - anyhow::Ok(()) - } + if config.proxy_protocol { + server = server.with_proxy(); + } + server + })) }) - .await?; + .flatten_ok() + .collect::, _>>()?; - // This ensures we're running, even if no listener are setup - // This is useful for only running the task runner - shutdown_signal.await; + let shutdown = ShutdownStream::default() + .with_timeout(Duration::from_secs(60)) + .with_signal(SignalKind::terminate())? + .with_signal(SignalKind::interrupt())?; + + mas_listener::server::run_servers(servers, shutdown).await; Ok(()) } diff --git a/crates/cli/src/server.rs b/crates/cli/src/server.rs index 7ca43d0b..e7c37b7d 100644 --- a/crates/cli/src/server.rs +++ b/crates/cli/src/server.rs @@ -23,10 +23,9 @@ use axum::{body::HttpBody, Extension, Router}; use listenfd::ListenFd; use mas_config::{HttpBindConfig, HttpResource, HttpTlsConfig, UnixOrTcp}; use mas_handlers::AppState; -use mas_listener::{info::Connection, unix_or_tcp::UnixOrTcpListener}; +use mas_listener::{unix_or_tcp::UnixOrTcpListener, ConnectionInfo}; use mas_router::Route; use rustls::ServerConfig; -use tokio::sync::OnceCell; #[allow(clippy::trait_duplication_in_bounds)] pub fn build_router(state: &Arc, resources: &[HttpResource]) -> Router @@ -64,12 +63,9 @@ where // TODO: do a better handler here mas_config::HttpResource::ConnectionInfo => router.route( "/connection-info", - axum::routing::get( - |connection: Extension>>| async move { - let connection = connection.get().unwrap(); - format!("{connection:?}") - }, - ), + axum::routing::get(|connection: Extension| async move { + format!("{connection:?}") + }), ), } } @@ -77,10 +73,8 @@ where router } -pub async fn build_tls_server_config( - config: &HttpTlsConfig, -) -> Result { - let (key, chain) = config.load().await?; +pub fn build_tls_server_config(config: &HttpTlsConfig) -> Result { + let (key, chain) = config.load()?; let key = rustls::PrivateKey(key); let chain = chain.into_iter().map(rustls::Certificate).collect(); diff --git a/crates/config/src/sections/http.rs b/crates/config/src/sections/http.rs index c5c971bf..812bc571 100644 --- a/crates/config/src/sections/http.rs +++ b/crates/config/src/sections/http.rs @@ -163,11 +163,11 @@ impl TlsConfig { /// - a password was provided but the key was not encrypted /// - decoding the certificate chain as PEM /// - the certificate chain is empty - pub async fn load(&self) -> Result<(Vec, Vec>), anyhow::Error> { + pub fn load(&self) -> Result<(Vec, Vec>), anyhow::Error> { let password = match &self.password { Some(PasswordOrFile::Password(password)) => Some(Cow::Borrowed(password.as_str())), Some(PasswordOrFile::PasswordFile(path)) => { - Some(Cow::Owned(tokio::fs::read_to_string(path).await?)) + Some(Cow::Owned(std::fs::read_to_string(path)?)) } None => None, }; @@ -185,7 +185,7 @@ impl TlsConfig { KeyOrFile::KeyFile(path) => { // When reading from disk, it might be either PEM or DER. `PrivateKey::load*` // will try both. - let key = tokio::fs::read(path).await?; + let key = std::fs::read(path)?; if let Some(password) = password { PrivateKey::load_encrypted(&key, password.as_bytes())? } else { @@ -202,9 +202,7 @@ impl TlsConfig { let certificate_chain_pem = match &self.certificate { CertificateOrFile::Certificate(pem) => Cow::Borrowed(pem.as_str()), - CertificateOrFile::CertificateFile(path) => { - Cow::Owned(tokio::fs::read_to_string(path).await?) - } + CertificateOrFile::CertificateFile(path) => Cow::Owned(std::fs::read_to_string(path)?), }; let mut certificate_chain_reader = Cursor::new(certificate_chain_pem.as_bytes()); diff --git a/crates/listener/Cargo.toml b/crates/listener/Cargo.toml index ee1bb069..93301c61 100644 --- a/crates/listener/Cargo.toml +++ b/crates/listener/Cargo.toml @@ -6,12 +6,25 @@ edition = "2021" license = "Apache-2.0" [dependencies] +bytes = "1.2.1" futures-util = "0.3.24" -hyper = { version = "0.14.20", features = ["server"] } +http-body = "0.4.2" +hyper = { version = "0.14.20", features = ["server", "http1", "http2"] } pin-project-lite = "0.2.9" thiserror = "1.0.37" -tokio = { version = "1.21.2", features = ["net"] } +tokio = { version = "1.21.2", features = ["net", "rt", "macros", "signal", "time"] } tokio-rustls = "0.23.4" -tower = "0.4.13" tower-http = { version = "0.3.4", features = ["add-extension"] } -tracing = "0.1.36" +tower-service = "0.3.2" +tracing = "0.1.37" +libc = "0.2.135" + +[dev-dependencies] +tokio-test = "0.4.2" +anyhow = "1.0.65" +tokio = { version = "1.21.2", features = ["net", "rt", "macros", "signal", "time", "rt-multi-thread"] } +tracing-subscriber = "0.3.16" + +[[example]] +name = "demo" +path = "examples/demo/main.rs" diff --git a/crates/listener/examples/demo/main.rs b/crates/listener/examples/demo/main.rs new file mode 100644 index 00000000..6e503507 --- /dev/null +++ b/crates/listener/examples/demo/main.rs @@ -0,0 +1,49 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{ + convert::Infallible, + net::{Ipv4Addr, TcpListener}, + time::Duration, +}; + +use hyper::{service::service_fn, Request, Response}; +use tokio::signal::unix::SignalKind; +use tokio_streams_util::{server::Server, shutdown::ShutdownStream, ConnectionInfo}; + +async fn handler(req: Request) -> Result, Infallible> { + tracing::info!("Handling request"); + tokio::time::sleep(Duration::from_secs(3)).await; + let info = req.extensions().get::().unwrap(); + let body = format!("{info:?}"); + Ok(Response::new(body)) +} + +#[tokio::main] +async fn main() -> Result<(), anyhow::Error> { + tracing_subscriber::fmt::init(); + + let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 3000))?; + let service = service_fn(handler); + let server = Server::try_new(listener, service)?; + + tracing::info!("Listening on 127.0.0.1:3000"); + + let shutdown = ShutdownStream::default() + .with_signal(SignalKind::interrupt())? + .with_signal(SignalKind::terminate())?; + server.run(shutdown).await; + + Ok(()) +} diff --git a/crates/listener/src/info.rs b/crates/listener/src/info.rs deleted file mode 100644 index 4757f62e..00000000 --- a/crates/listener/src/info.rs +++ /dev/null @@ -1,323 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::{ - future::Ready, - net::SocketAddr, - ops::Deref, - pin::Pin, - sync::Arc, - task::{Context, Poll}, -}; - -use hyper::server::accept::Accept; -use thiserror::Error; -use tokio::{ - io::{AsyncRead, AsyncWrite}, - sync::OnceCell, -}; -use tower::Service; -use tower_http::add_extension::AddExtension; - -use crate::{ - maybe_tls::{MaybeTlsAcceptor, MaybeTlsStream, TlsStreamInfo, TlsStreamInfoError}, - proxy_protocol::{ - MaybeProxyAcceptor, MaybeProxyStream, ProxyHandshakeNotDone, ProxyProtocolV1Info, - }, - unix_or_tcp::{UnixOrTcpConnection, UnixOrTcpListener}, -}; - -// TODO: this is a mess, clean that up - -#[derive(Error, Debug)] -#[non_exhaustive] -pub enum FromStreamError { - #[error(transparent)] - Proxy(#[from] ProxyHandshakeNotDone), - - #[error(transparent)] - Tls(#[from] TlsStreamInfoError), - - #[error("Could not grab a reference to the underlying stream")] - GetRef, - - #[error("Could not get address info from underlying stream")] - IoError(#[from] std::io::Error), -} - -#[derive(Debug, Clone)] -pub struct Connection { - proxy: Option, - tls: Option, - - // We're not saving the UNIX domain socket address here because it can't be cloned, which is - // required for injecting the connection information as an extension - local_tcp_addr: Option, - peer_tcp_addr: Option, -} - -#[derive(Error, Debug)] -#[non_exhaustive] -pub enum GrabAddressError { - #[error("Proxy protocol was initiated with an unknown protocol")] - ProxyUnknown, - - #[error("Proxy protocol was initiated with UDP")] - ProxyUdp, - - #[error("Underlying listener is a UNIX socket")] - UnixListener, -} - -impl MaybeProxyAcceptor> { - pub fn can_have_peer_address(&self) -> bool { - self.is_proxied() || self.is_tcp() - } -} - -impl MaybeProxyStream> { - /// Get informations about this connection - /// - /// # Errors - /// - /// Returns an error if the proxy protocol or the TLS handhakes are not done - /// yet - pub fn connection_info(&self) -> Result { - Connection::from_stream(self) - } -} - -impl Connection { - /// Get informations about this connection - /// - /// # Errors - /// - /// Returns an error if the proxy protocol or the TLS handhakes are not done - /// yet - pub fn from_stream( - stream: &MaybeProxyStream>, - ) -> Result { - let proxy = stream.proxy_info()?.cloned(); - let tls = stream.tls_info()?; - let original = stream.get_ref().ok_or(FromStreamError::GetRef)?; - let local_tcp_addr = original.local_addr()?.into_net(); - let peer_tcp_addr = original.peer_addr()?.into_net(); - - Ok(Self { - proxy, - tls, - local_tcp_addr, - peer_tcp_addr, - }) - } - - #[must_use] - pub const fn is_proxied(&self) -> bool { - self.proxy.is_some() - } - - #[must_use] - pub const fn is_tls(&self) -> bool { - self.tls.is_some() - } - - /// Get the outmost peer address, either from the TCP listener or from the - /// proxy protocol infos. - /// - /// # Errors - /// - /// Returns an error if the info from the proxy protocol was not for a TCP - /// connection, or if the proxy protocol is not being used, the underlying - /// listener was a UNIX domain socket - pub fn peer_addr(&self) -> Result<&SocketAddr, GrabAddressError> { - if let Some(proxy) = self.proxy.as_ref() { - if proxy.is_udp() { - return Err(GrabAddressError::ProxyUdp); - } - - proxy.source().ok_or(GrabAddressError::ProxyUnknown) - } else { - self.peer_tcp_addr - .as_ref() - .ok_or(GrabAddressError::UnixListener) - } - } - - /// Get the outmost local address, either from the TCP listener or from the - /// proxy protocol infos. - /// - /// # Errors - /// - /// Returns an error if the info from the proxy protocol was not for a TCP - /// connection, or if the proxy protocol is not being used, the underlying - /// listener was a UNIX domain socket - pub fn local_addr(&self) -> Result<&SocketAddr, GrabAddressError> { - if let Some(proxy) = self.proxy.as_ref() { - if proxy.is_udp() { - return Err(GrabAddressError::ProxyUdp); - } - - proxy.destination().ok_or(GrabAddressError::ProxyUnknown) - } else { - self.local_tcp_addr - .as_ref() - .ok_or(GrabAddressError::UnixListener) - } - } -} - -pin_project_lite::pin_project! { - pub struct ConnectionInfoAcceptor { - #[pin] - acceptor: MaybeProxyAcceptor>, - } -} - -impl ConnectionInfoAcceptor { - pub const fn new(acceptor: MaybeProxyAcceptor>) -> Self { - Self { acceptor } - } -} - -impl Accept for ConnectionInfoAcceptor { - type Conn = ConnectionInfoStream; - type Error = std::io::Error; - - fn poll_accept( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll>> { - let proj = self.project(); - let ret = match futures_util::ready!(proj.acceptor.poll_accept(cx)) { - Some(Ok(conn)) => Some(Ok(ConnectionInfoStream::new(conn))), - Some(Err(e)) => Some(Err(e)), - None => None, - }; - Poll::Ready(ret) - } -} - -pin_project_lite::pin_project! { - pub struct ConnectionInfoStream { - connection: Arc>, - #[pin] - stream: MaybeProxyStream>, - } -} - -impl ConnectionInfoStream { - pub fn new(stream: MaybeProxyStream>) -> Self { - Self { - connection: Arc::new(OnceCell::const_new()), - stream, - } - } -} - -impl Deref for ConnectionInfoStream { - type Target = MaybeProxyStream>; - fn deref(&self) -> &Self::Target { - &self.stream - } -} - -impl AsyncRead for ConnectionInfoStream { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> Poll> { - let this = self.get_mut(); - futures_util::ready!(Pin::new(&mut this.stream).poll_read(cx, buf))?; - - if !this.stream.is_tls_handshaking() - && !this.stream.is_proxy_handshaking() - && !this.connection.initialized() - { - this.connection - .set(this.stream.connection_info().unwrap()) - .unwrap(); - } - - Poll::Ready(Ok(())) - } -} - -impl AsyncWrite for ConnectionInfoStream { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - let proj = self.project(); - proj.stream.poll_write(cx, buf) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let proj = self.project(); - proj.stream.poll_flush(cx) - } - - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - let proj = self.project(); - proj.stream.poll_shutdown(cx) - } - - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[std::io::IoSlice<'_>], - ) -> Poll> { - let proj = self.project(); - proj.stream.poll_write_vectored(cx, bufs) - } - - fn is_write_vectored(&self) -> bool { - self.stream.is_write_vectored() - } -} - -#[derive(Debug, Clone)] -pub struct IntoMakeServiceWithConnection { - svc: S, -} - -impl IntoMakeServiceWithConnection { - pub const fn new(svc: S) -> Self { - Self { svc } - } -} - -impl Service<&ConnectionInfoStream> for IntoMakeServiceWithConnection -where - S: Clone, -{ - type Response = AddExtension>>; - type Error = FromStreamError; - type Future = Ready>; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, target: &ConnectionInfoStream) -> Self::Future { - std::future::ready(Ok(AddExtension::new( - self.svc.clone(), - target.connection.clone(), - ))) - } -} diff --git a/crates/listener/src/lib.rs b/crates/listener/src/lib.rs index f4dbe4d7..091d5809 100644 --- a/crates/listener/src/lib.rs +++ b/crates/listener/src/lib.rs @@ -22,7 +22,41 @@ #![warn(clippy::pedantic)] #![allow(clippy::module_name_repetitions)] -pub mod info; +use self::{maybe_tls::TlsStreamInfo, proxy_protocol::ProxyProtocolV1Info}; + pub mod maybe_tls; pub mod proxy_protocol; +pub mod rewind; +pub mod server; +pub mod shutdown; pub mod unix_or_tcp; + +#[derive(Debug, Clone)] +pub struct ConnectionInfo { + tls: Option, + proxy: Option, + net_peer_addr: Option, +} + +impl ConnectionInfo { + /// Returns informations about the TLS connection. Returns [`None`] if the + /// connection was not TLS. + #[must_use] + pub fn get_tls_ref(&self) -> Option<&TlsStreamInfo> { + self.tls.as_ref() + } + + /// Returns informations about the proxy protocol connection. Returns + /// [`None`] if the connection was not using the proxy protocol. + #[must_use] + pub fn get_proxy_ref(&self) -> Option<&ProxyProtocolV1Info> { + self.proxy.as_ref() + } + + /// Returns the remote peer address. Returns [`None`] if the connection was + /// established via a UNIX domain socket. + #[must_use] + pub fn get_peer_addr(&self) -> Option { + self.net_peer_addr + } +} diff --git a/crates/listener/src/maybe_tls.rs b/crates/listener/src/maybe_tls.rs index 55a1cca6..2d646558 100644 --- a/crates/listener/src/maybe_tls.rs +++ b/crates/listener/src/maybe_tls.rs @@ -13,18 +13,15 @@ // limitations under the License. use std::{ - ops::Deref, pin::Pin, sync::Arc, task::{Context, Poll}, }; -use futures_util::{ready, Future}; -use hyper::server::accept::Accept; -use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -use tokio_rustls::rustls::{ - Certificate, ProtocolVersion, ServerConfig, ServerConnection, SupportedCipherSuite, +use tokio_rustls::{ + rustls::{Certificate, ProtocolVersion, ServerConfig, ServerConnection, SupportedCipherSuite}, + TlsAcceptor, }; #[derive(Debug, Clone)] @@ -33,105 +30,78 @@ pub struct TlsStreamInfo { pub protocol_version: ProtocolVersion, pub negotiated_cipher_suite: SupportedCipherSuite, pub sni_hostname: Option, - pub apln_protocol: Option>, + pub alpn_protocol: Option>, pub peer_certificates: Option>, } -#[derive(Debug, Error)] -#[non_exhaustive] -pub enum TlsStreamInfoError { - #[error("TLS handshake is not done yet")] - HandshakingNotDone, - - #[error("Some fields were not available in the TLS connection")] - FieldsNotAvailable, +impl TlsStreamInfo { + #[must_use] + pub fn is_alpn_h2(&self) -> bool { + matches!(self.alpn_protocol.as_deref(), Some(b"h2")) + } } -pub enum MaybeTlsStream { - Handshaking(tokio_rustls::Accept), - Streaming(tokio_rustls::server::TlsStream), - Insecure(T), +pin_project_lite::pin_project! { + #[project = MaybeTlsStreamProj] + pub enum MaybeTlsStream { + Secure { + #[pin] + stream: tokio_rustls::server::TlsStream + }, + Insecure { + #[pin] + stream: T, + }, + } } impl MaybeTlsStream { - pub fn new(stream: T, config: Option>) -> Self - where - T: AsyncRead + AsyncWrite + Unpin, - { - if let Some(config) = config { - let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream); - MaybeTlsStream::Handshaking(accept) - } else { - MaybeTlsStream::Insecure(stream) - } - } - /// Get a reference to the underlying IO stream /// /// Returns [`None`] if the stream closed before the TLS handshake finished. /// It is guaranteed to return [`Some`] value after the handshake finished, /// or if it is a non-TLS connection. - pub fn get_ref(&self) -> Option<&T> { + pub fn get_ref(&self) -> &T { match self { - Self::Handshaking(accept) => accept.get_ref(), - Self::Streaming(stream) => { - let (inner, _) = stream.get_ref(); - Some(inner) - } - Self::Insecure(inner) => Some(inner), + Self::Secure { stream } => stream.get_ref().0, + Self::Insecure { stream } => stream, } } /// Get a ref to the [`ServerConnection`] of the establish TLS stream. /// - /// Returns [`None`] if the connection is still handshaking and for non-TLS - /// connections. + /// Returns [`None`] for non-TLS connections. pub fn get_tls_connection(&self) -> Option<&ServerConnection> { match self { - Self::Streaming(stream) => { - let (_, conn) = stream.get_ref(); - Some(conn) - } - Self::Handshaking(_) | Self::Insecure(_) => None, + Self::Secure { stream } => Some(stream.get_ref().1), + Self::Insecure { .. } => None, } } /// Gather informations about the TLS connection. Returns `None` if the /// stream is not a TLS stream. - /// - /// # Errors - /// - /// Returns an error if the TLS handshake is not yet done - pub fn tls_info(&self) -> Result, TlsStreamInfoError> { - let conn = match self { - Self::Streaming(stream) => stream.get_ref().1, - Self::Handshaking(_) => return Err(TlsStreamInfoError::HandshakingNotDone), - Self::Insecure(_) => return Ok(None), - }; + pub fn tls_info(&self) -> Option { + let conn = self.get_tls_connection()?; - // NOTE: we're getting the protocol version and cipher suite *after* the - // handshake, so this should never lead to an error + // SAFETY: we're getting the protocol version and cipher suite *after* the + // handshake, so this should never lead to a panic let protocol_version = conn .protocol_version() - .ok_or(TlsStreamInfoError::FieldsNotAvailable)?; + .expect("TLS handshake is not done yet"); let negotiated_cipher_suite = conn .negotiated_cipher_suite() - .ok_or(TlsStreamInfoError::FieldsNotAvailable)?; + .expect("TLS handshake is not done yet"); let sni_hostname = conn.sni_hostname().map(ToOwned::to_owned); - let apln_protocol = conn.alpn_protocol().map(ToOwned::to_owned); + let alpn_protocol = conn.alpn_protocol().map(ToOwned::to_owned); let peer_certificates = conn.peer_certificates().map(ToOwned::to_owned); - Ok(Some(TlsStreamInfo { + Some(TlsStreamInfo { protocol_version, negotiated_cipher_suite, sni_hostname, - apln_protocol, + alpn_protocol, peer_certificates, - })) - } - - pub const fn is_tls_handshaking(&self) -> bool { - matches!(self, Self::Handshaking(_)) + }) } } @@ -144,20 +114,9 @@ where cx: &mut Context, buf: &mut ReadBuf, ) -> Poll> { - let pin = self.get_mut(); - match pin { - MaybeTlsStream::Handshaking(ref mut accept) => { - match ready!(Pin::new(accept).poll(cx)) { - Ok(mut stream) => { - let result = Pin::new(&mut stream).poll_read(cx, buf); - *pin = MaybeTlsStream::Streaming(stream); - result - } - Err(err) => Poll::Ready(Err(err)), - } - } - MaybeTlsStream::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf), - MaybeTlsStream::Insecure(ref mut stream) => Pin::new(stream).poll_read(cx, buf), + match self.project() { + MaybeTlsStreamProj::Secure { stream } => stream.poll_read(cx, buf), + MaybeTlsStreamProj::Insecure { stream } => stream.poll_read(cx, buf), } } } @@ -171,104 +130,89 @@ where cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - let pin = self.get_mut(); - match pin { - MaybeTlsStream::Handshaking(ref mut accept) => { - match ready!(Pin::new(accept).poll(cx)) { - Ok(mut stream) => { - let result = Pin::new(&mut stream).poll_write(cx, buf); - *pin = MaybeTlsStream::Streaming(stream); - result - } - Err(err) => Poll::Ready(Err(err)), - } - } - MaybeTlsStream::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf), - MaybeTlsStream::Insecure(ref mut fallback) => Pin::new(fallback).poll_write(cx, buf), + match self.project() { + MaybeTlsStreamProj::Secure { stream } => stream.poll_write(cx, buf), + MaybeTlsStreamProj::Insecure { stream } => stream.poll_write(cx, buf), + } + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + match self.project() { + MaybeTlsStreamProj::Secure { stream } => stream.poll_write_vectored(cx, bufs), + MaybeTlsStreamProj::Insecure { stream } => stream.poll_write_vectored(cx, bufs), + } + } + + fn is_write_vectored(&self) -> bool { + match self { + Self::Secure { stream } => stream.is_write_vectored(), + Self::Insecure { stream } => stream.is_write_vectored(), } } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.get_mut() { - MaybeTlsStream::Handshaking { .. } => Poll::Ready(Ok(())), - MaybeTlsStream::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx), - MaybeTlsStream::Insecure(ref mut stream) => Pin::new(stream).poll_flush(cx), + match self.project() { + MaybeTlsStreamProj::Secure { stream } => stream.poll_flush(cx), + MaybeTlsStreamProj::Insecure { stream } => stream.poll_flush(cx), } } fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.get_mut() { - MaybeTlsStream::Handshaking { .. } => Poll::Ready(Ok(())), - MaybeTlsStream::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx), - MaybeTlsStream::Insecure(ref mut stream) => Pin::new(stream).poll_shutdown(cx), + match self.project() { + MaybeTlsStreamProj::Secure { stream } => stream.poll_shutdown(cx), + MaybeTlsStreamProj::Insecure { stream } => stream.poll_shutdown(cx), } } } -pub struct MaybeTlsAcceptor { +#[derive(Clone)] +pub struct MaybeTlsAcceptor { tls_config: Option>, - incoming: T, } -impl MaybeTlsAcceptor { - pub fn new(tls_config: Option>, incoming: T) -> Self { - Self { - tls_config, - incoming, - } +impl MaybeTlsAcceptor { + #[must_use] + pub fn new(tls_config: Option>) -> Self { + Self { tls_config } } - pub fn new_secure(tls_config: Arc, incoming: T) -> Self { + #[must_use] + pub fn new_secure(tls_config: Arc) -> Self { Self { tls_config: Some(tls_config), - incoming, } } - pub fn new_insecure(incoming: T) -> Self { - Self { - tls_config: None, - incoming, - } + #[must_use] + pub fn new_insecure() -> Self { + Self { tls_config: None } } + #[must_use] pub const fn is_secure(&self) -> bool { self.tls_config.is_some() } -} -impl Accept for MaybeTlsAcceptor -where - T: Accept + Unpin, - T::Conn: AsyncRead + AsyncWrite + Unpin, - T::Error: Into, -{ - type Conn = MaybeTlsStream; - type Error = std::io::Error; - - fn poll_accept( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - let pin = self.get_mut(); - - let ret = match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) { - Some(Ok(sock)) => { - let config = pin.tls_config.clone(); - Some(Ok(MaybeTlsStream::new(sock, config))) + /// Accept a connection and do the TLS handshake + /// + /// # Errors + /// + /// Returns an error if the TLS handshake failed + pub async fn accept(&self, stream: T) -> Result, std::io::Error> + where + T: AsyncRead + AsyncWrite + Unpin, + { + match &self.tls_config { + Some(config) => { + let acceptor = TlsAcceptor::from(config.clone()); + let stream = acceptor.accept(stream).await?; + Ok(MaybeTlsStream::Secure { stream }) } - - Some(Err(e)) => Some(Err(e.into())), - None => None, - }; - - Poll::Ready(ret) - } -} - -impl Deref for MaybeTlsAcceptor { - type Target = T; - fn deref(&self) -> &Self::Target { - &self.incoming + None => Ok(MaybeTlsStream::Insecure { stream }), + } } } diff --git a/crates/listener/src/proxy_protocol/acceptor.rs b/crates/listener/src/proxy_protocol/acceptor.rs index d10bd392..e6a6da9a 100644 --- a/crates/listener/src/proxy_protocol/acceptor.rs +++ b/crates/listener/src/proxy_protocol/acceptor.rs @@ -12,41 +12,57 @@ // See the License for the specific language governing permissions and // limitations under the License. -use futures_util::ready; -use hyper::server::accept::Accept; +use bytes::BytesMut; +use thiserror::Error; +use tokio::io::{AsyncRead, AsyncReadExt}; -use super::ProxyStream; +use super::ProxyProtocolV1Info; +use crate::rewind::Rewind; -pin_project_lite::pin_project! { - pub struct ProxyAcceptor { - #[pin] - inner: A, - } +#[derive(Clone, Debug)] +pub struct ProxyAcceptor { + _private: (), } -impl ProxyAcceptor { - pub const fn new(inner: A) -> Self { - Self { inner } - } +#[derive(Debug, Error)] +#[error(transparent)] +pub enum ProxyAcceptError { + Parse(#[from] super::v1::ParseError), + Read(#[from] std::io::Error), } -impl Accept for ProxyAcceptor -where - A: Accept, -{ - type Conn = ProxyStream; - type Error = A::Error; +impl ProxyAcceptor { + #[must_use] + pub const fn new() -> Self { + Self { _private: () } + } - fn poll_accept( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll>> { - let res = match ready!(self.project().inner.poll_accept(cx)) { - Some(Ok(stream)) => Some(Ok(ProxyStream::new(stream))), - Some(Err(e)) => Some(Err(e)), - None => None, + /// Accept a proxy-protocol stream + /// + /// # Errors + /// + /// Returns an error on read error on the underlying stream, or when the + /// proxy protocol preamble couldn't be parsed + pub async fn accept( + &self, + mut stream: T, + ) -> Result<(ProxyProtocolV1Info, Rewind), ProxyAcceptError> + where + T: AsyncRead + Unpin, + { + let mut buf = BytesMut::new(); + let info = loop { + stream.read_buf(&mut buf).await?; + + match ProxyProtocolV1Info::parse(&mut buf) { + Ok(info) => break info, + Err(e) if e.not_enough_bytes() => {} + Err(e) => return Err(e.into()), + } }; - std::task::Poll::Ready(res) + let stream = Rewind::new_buffered(stream, buf.into()); + + Ok((info, stream)) } } diff --git a/crates/listener/src/proxy_protocol/maybe.rs b/crates/listener/src/proxy_protocol/maybe.rs index b6e74377..67699f5e 100644 --- a/crates/listener/src/proxy_protocol/maybe.rs +++ b/crates/listener/src/proxy_protocol/maybe.rs @@ -12,179 +12,66 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{ - ops::Deref, - pin::Pin, - task::{Context, Poll}, -}; +use tokio::io::AsyncRead; -use futures_util::ready; -use hyper::server::accept::Accept; -use tokio::io::{AsyncRead, AsyncWrite}; +use super::{acceptor::ProxyAcceptError, ProxyAcceptor, ProxyProtocolV1Info}; +use crate::rewind::Rewind; -use super::{stream::HandshakeNotDone, ProxyProtocolV1Info, ProxyStream}; - -pin_project_lite::pin_project! { - pub struct MaybeProxyAcceptor { - proxied: bool, - - #[pin] - inner: A, - } +#[derive(Clone)] +pub struct MaybeProxyAcceptor { + acceptor: Option, } -impl MaybeProxyAcceptor { +impl MaybeProxyAcceptor { #[must_use] - pub const fn new(inner: A, proxied: bool) -> Self { - Self { proxied, inner } - } - - pub const fn is_proxied(&self) -> bool { - self.proxied - } -} - -impl Deref for MaybeProxyAcceptor { - type Target = A; - fn deref(&self) -> &Self::Target { - &self.inner - } -} - -impl Accept for MaybeProxyAcceptor -where - A: Accept, -{ - type Conn = MaybeProxyStream; - type Error = A::Error; - - fn poll_accept( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll>> { - let proj = self.project(); - let res = match ready!(proj.inner.poll_accept(cx)) { - Some(Ok(stream)) => Some(Ok(MaybeProxyStream::new(stream, *proj.proxied))), - Some(Err(e)) => Some(Err(e)), - None => None, + pub const fn new(proxied: bool) -> Self { + let acceptor = if proxied { + Some(ProxyAcceptor::new()) + } else { + None }; - std::task::Poll::Ready(res) + Self { acceptor } } -} -pin_project_lite::pin_project! { - #[project = MaybeProxyStreamProj] - pub enum MaybeProxyStream { - Proxied { #[pin] stream: ProxyStream }, - NotProxied { #[pin] stream: S }, - } -} - -impl MaybeProxyStream { - pub const fn new(stream: S, proxied: bool) -> Self { - if proxied { - Self::Proxied { - stream: ProxyStream::new(stream), - } - } else { - Self::NotProxied { stream } + #[must_use] + pub const fn new_proxied(acceptor: ProxyAcceptor) -> Self { + Self { + acceptor: Some(acceptor), } } - /// Get informations from the proxied connection, if it was procied + #[must_use] + pub const fn new_unproxied() -> Self { + Self { acceptor: None } + } + + #[must_use] + pub const fn is_proxied(&self) -> bool { + self.acceptor.is_some() + } + + /// Accept a connection and do the proxy protocol handshake /// /// # Errors /// - /// Returns an error if the stream did not complete the handshake yet - pub fn proxy_info(&self) -> Result, HandshakeNotDone> { - match self { - Self::Proxied { stream } => Ok(Some(stream.proxy_info()?)), - Self::NotProxied { .. } => Ok(None), - } - } - - pub const fn is_proxy_handshaking(&self) -> bool { - match self { - Self::Proxied { stream } => stream.is_handshaking(), - Self::NotProxied { .. } => false, - } - } -} - -impl Deref for MaybeProxyStream { - type Target = S; - fn deref(&self) -> &Self::Target { - match self { - Self::Proxied { stream } => &**stream, - Self::NotProxied { stream } => stream, - } - } -} - -impl AsyncRead for MaybeProxyStream -where - S: AsyncRead, -{ - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> Poll> { - match self.project() { - MaybeProxyStreamProj::Proxied { stream } => stream.poll_read(cx, buf), - MaybeProxyStreamProj::NotProxied { stream } => stream.poll_read(cx, buf), - } - } -} - -impl AsyncWrite for MaybeProxyStream -where - S: AsyncWrite, -{ - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - match self.project() { - MaybeProxyStreamProj::Proxied { stream } => stream.poll_write(cx, buf), - MaybeProxyStreamProj::NotProxied { stream } => stream.poll_write(cx, buf), - } - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.project() { - MaybeProxyStreamProj::Proxied { stream } => stream.poll_flush(cx), - MaybeProxyStreamProj::NotProxied { stream } => stream.poll_flush(cx), - } - } - - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - match self.project() { - MaybeProxyStreamProj::Proxied { stream } => stream.poll_shutdown(cx), - MaybeProxyStreamProj::NotProxied { stream } => stream.poll_shutdown(cx), - } - } - - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[std::io::IoSlice<'_>], - ) -> Poll> { - match self.project() { - MaybeProxyStreamProj::Proxied { stream } => stream.poll_write_vectored(cx, bufs), - MaybeProxyStreamProj::NotProxied { stream } => stream.poll_write_vectored(cx, bufs), - } - } - - fn is_write_vectored(&self) -> bool { - match self { - MaybeProxyStream::Proxied { stream } => stream.is_write_vectored(), - MaybeProxyStream::NotProxied { stream } => stream.is_write_vectored(), + /// Returns an error if the proxy protocol handshake failed + pub async fn accept( + &self, + stream: T, + ) -> Result<(Option, Rewind), ProxyAcceptError> + where + T: AsyncRead + Unpin, + { + match &self.acceptor { + Some(acceptor) => { + let (info, stream) = acceptor.accept(stream).await?; + Ok((Some(info), stream)) + } + None => { + let stream = Rewind::new(stream); + Ok((None, stream)) + } } } } diff --git a/crates/listener/src/proxy_protocol/mod.rs b/crates/listener/src/proxy_protocol/mod.rs index 7549e70d..4f97a9be 100644 --- a/crates/listener/src/proxy_protocol/mod.rs +++ b/crates/listener/src/proxy_protocol/mod.rs @@ -14,12 +14,10 @@ mod acceptor; mod maybe; -mod stream; mod v1; pub use self::{ - acceptor::ProxyAcceptor, - maybe::{MaybeProxyAcceptor, MaybeProxyStream}, - stream::{HandshakeNotDone as ProxyHandshakeNotDone, ProxyStream}, + acceptor::{ProxyAcceptError, ProxyAcceptor}, + maybe::MaybeProxyAcceptor, v1::ProxyProtocolV1Info, }; diff --git a/crates/listener/src/proxy_protocol/stream.rs b/crates/listener/src/proxy_protocol/stream.rs deleted file mode 100644 index 9126883d..00000000 --- a/crates/listener/src/proxy_protocol/stream.rs +++ /dev/null @@ -1,169 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::ops::Deref; - -use futures_util::ready; -use thiserror::Error; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; - -use super::ProxyProtocolV1Info; - -// Max theorical size we need is 108 for proxy protocol v1 -const BUF_SIZE: usize = 256; - -#[derive(Debug)] -enum ProxyStreamState { - Handshaking { - buffer: [u8; BUF_SIZE], - index: usize, - }, - Established(ProxyProtocolV1Info), -} - -impl ProxyStreamState { - pub const fn is_handshaking(&self) -> bool { - matches!(self, Self::Handshaking { .. }) - } -} - -pin_project_lite::pin_project! { - #[derive(Debug)] - pub struct ProxyStream { - state: ProxyStreamState, - - #[pin] - inner: S, - } -} - -impl ProxyStream { - pub const fn new(inner: S) -> Self { - Self { - state: ProxyStreamState::Handshaking { - buffer: [0; BUF_SIZE], - index: 0, - }, - inner, - } - } -} - -#[derive(Debug, Error, Clone, Copy)] -#[error("Proxy protocol handshake is not complete")] -pub struct HandshakeNotDone; - -impl Deref for ProxyStream { - type Target = S; - fn deref(&self) -> &Self::Target { - &self.inner - } -} - -impl ProxyStream { - /// Get informations from the proxied connection - /// - /// # Errors - /// - /// Returns an error if the stream did not complete the handshake yet - pub fn proxy_info(&self) -> Result<&ProxyProtocolV1Info, HandshakeNotDone> { - match &self.state { - ProxyStreamState::Handshaking { .. } => Err(HandshakeNotDone), - ProxyStreamState::Established(info) => Ok(info), - } - } - - /// Returns `true` if the proxy protocol is still handshaking - pub const fn is_handshaking(&self) -> bool { - self.state.is_handshaking() - } -} - -impl AsyncRead for ProxyStream -where - S: AsyncRead, -{ - fn poll_read( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> std::task::Poll> { - let proj = self.project(); - match proj.state { - ProxyStreamState::Handshaking { buffer, index } => { - let mut buffer = ReadBuf::new(&mut buffer[..]); - buffer.advance(*index); - ready!(proj.inner.poll_read(cx, &mut buffer))?; - let filled = buffer.filled(); - *index = filled.len(); - - match ProxyProtocolV1Info::parse(filled) { - Ok((info, rest)) => { - if buf.remaining() < rest.len() { - // This is highly unlikely, but is better than panicking later. - // If it ever happens, we could introduce a "buffer draining" state - // which drains the inner buffer repeatedly until it's empty - return std::task::Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::Other, - "underlying buffer is too small", - ))); - } - buf.put_slice(rest); - *proj.state = ProxyStreamState::Established(info); - std::task::Poll::Ready(Ok(())) - } - Err(e) if e.not_enough_bytes() => std::task::Poll::Ready(Ok(())), - Err(e) => std::task::Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::Other, - e, - ))), - } - } - ProxyStreamState::Established(_) => proj.inner.poll_read(cx, buf), - } - } -} - -impl AsyncWrite for ProxyStream -where - S: AsyncWrite, -{ - fn poll_write( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> std::task::Poll> { - let proj = self.project(); - match proj.state { - // Hold off writes until the handshake is done - // XXX: is this the right way to do it? - ProxyStreamState::Handshaking { .. } => std::task::Poll::Pending, - ProxyStreamState::Established(_) => proj.inner.poll_write(cx, buf), - } - } - - fn poll_flush( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.project().inner.poll_flush(cx) - } - - fn poll_shutdown( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.project().inner.poll_shutdown(cx) - } -} diff --git a/crates/listener/src/proxy_protocol/v1.rs b/crates/listener/src/proxy_protocol/v1.rs index 32412c01..78654124 100644 --- a/crates/listener/src/proxy_protocol/v1.rs +++ b/crates/listener/src/proxy_protocol/v1.rs @@ -18,6 +18,7 @@ use std::{ str::Utf8Error, }; +use bytes::Buf; use thiserror::Error; #[derive(Debug, Clone)] @@ -35,7 +36,7 @@ pub enum ProxyProtocolV1Info { #[derive(Error, Debug)] #[error("Invalid proxy protocol header")] -pub(super) enum ParseError { +pub enum ParseError { #[error("Not enough bytes provided")] NotEnoughBytes, NoCrLf, @@ -60,17 +61,21 @@ impl ParseError { impl ProxyProtocolV1Info { #[allow(clippy::too_many_lines)] - pub(super) fn parse(bytes: &[u8]) -> Result<(Self, &[u8]), ParseError> { + pub(super) fn parse(mut buf: B) -> Result + where + B: Buf + AsRef<[u8]>, + { use ParseError as E; // First, check if we *possibly* have enough bytes. // Minimum is 15: "PROXY UNKNOWN\r\n" - if bytes.len() < 15 { + if buf.remaining() < 15 { return Err(E::NotEnoughBytes); } // Let's check in the first 108 bytes if we find a CRLF - let crlf = if let Some(crlf) = bytes + let crlf = if let Some(crlf) = buf + .as_ref() .windows(2) .take(108) .position(|needle| needle == [0x0D, 0x0A]) @@ -78,7 +83,7 @@ impl ProxyProtocolV1Info { crlf } else { // If not, it might be because we don't have enough bytes - return if bytes.len() < 108 { + return if buf.remaining() < 108 { Err(E::NotEnoughBytes) } else { // Else it's just invalid @@ -86,10 +91,8 @@ impl ProxyProtocolV1Info { }; }; - // Keep the rest of the buffer to pass it to the underlying protocol - let rest = &bytes[crlf + 2..]; // Trim to everything before the CRLF - let bytes = &bytes[..crlf]; + let bytes = &buf.as_ref()[..crlf]; let mut it = bytes.splitn(6, |c| c == &b' '); // Check for the preamble @@ -187,7 +190,9 @@ impl ProxyProtocolV1Info { None => return Err(E::NoProtocol), }; - Ok((result, rest)) + buf.advance(crlf + 2); + + Ok(result) } #[must_use] @@ -258,39 +263,42 @@ mod tests { #[test] fn test_parse() { - let (info, rest) = ProxyProtocolV1Info::parse( - b"PROXY TCP4 255.255.255.255 255.255.255.255 65535 65535\r\nhello world", - ) - .unwrap(); - assert_eq!(rest, b"hello world"); + let mut buf = + b"PROXY TCP4 255.255.255.255 255.255.255.255 65535 65535\r\nhello world" + .as_slice(); + let info = ProxyProtocolV1Info::parse(&mut buf).unwrap(); + assert_eq!(buf, b"hello world"); assert!(info.is_tcp()); assert!(!info.is_udp()); assert!(!info.is_unknown()); assert!(info.is_ipv4()); assert!(!info.is_ipv6()); - let (info, rest) = ProxyProtocolV1Info::parse( + let mut buf = b"PROXY TCP6 ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\nhello world" - ).unwrap(); - assert_eq!(rest, b"hello world"); + .as_slice(); + let info = ProxyProtocolV1Info::parse(&mut buf).unwrap(); + assert_eq!(buf, b"hello world"); assert!(info.is_tcp()); assert!(!info.is_udp()); assert!(!info.is_unknown()); assert!(!info.is_ipv4()); assert!(info.is_ipv6()); - let (info, rest) = ProxyProtocolV1Info::parse(b"PROXY UNKNOWN\r\nhello world").unwrap(); - assert_eq!(rest, b"hello world"); + let mut buf = b"PROXY UNKNOWN\r\nhello world".as_slice(); + let info = ProxyProtocolV1Info::parse(&mut buf).unwrap(); + assert_eq!(buf, b"hello world"); assert!(!info.is_tcp()); assert!(!info.is_udp()); assert!(info.is_unknown()); assert!(!info.is_ipv4()); assert!(!info.is_ipv6()); - let (info, rest) = ProxyProtocolV1Info::parse( + let mut buf = b"PROXY UNKNOWN ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\nhello world" - ).unwrap(); - assert_eq!(rest, b"hello world"); + .as_slice(); + let info = ProxyProtocolV1Info::parse(&mut buf).unwrap(); + assert_eq!(buf, b"hello world"); assert!(!info.is_tcp()); assert!(!info.is_udp()); assert!(info.is_unknown()); diff --git a/crates/listener/src/rewind.rs b/crates/listener/src/rewind.rs new file mode 100644 index 00000000..72d661df --- /dev/null +++ b/crates/listener/src/rewind.rs @@ -0,0 +1,151 @@ +// Taken from hyper@0.14.20, src/common/io/rewind.rs + +use std::{ + cmp, io, + marker::Unpin, + pin::Pin, + task::{Context, Poll}, +}; + +use bytes::{Buf, Bytes}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +/// Combine a buffer with an IO, rewinding reads to use the buffer. +#[derive(Debug)] +pub struct Rewind { + pre: Option, + inner: T, +} + +impl Rewind { + pub(crate) fn new(io: T) -> Self { + Rewind { + pre: None, + inner: io, + } + } + + pub(crate) fn new_buffered(io: T, buf: Bytes) -> Self { + Rewind { + pre: Some(buf), + inner: io, + } + } + + #[cfg(test)] + pub(crate) fn rewind(&mut self, bs: Bytes) { + debug_assert!(self.pre.is_none()); + self.pre = Some(bs); + } +} + +impl AsyncRead for Rewind +where + T: AsyncRead + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + if let Some(mut prefix) = self.pre.take() { + // If there are no remaining bytes, let the bytes get dropped. + if !prefix.is_empty() { + let copy_len = cmp::min(prefix.len(), buf.remaining()); + // TODO: There should be a way to do following two lines cleaner... + buf.put_slice(&prefix[..copy_len]); + prefix.advance(copy_len); + // Put back what's left + if !prefix.is_empty() { + self.pre = Some(prefix); + } + + return Poll::Ready(Ok(())); + } + } + Pin::new(&mut self.inner).poll_read(cx, buf) + } +} + +impl AsyncWrite for Rewind +where + T: AsyncWrite + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.inner).poll_write(cx, buf) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + Pin::new(&mut self.inner).poll_write_vectored(cx, bufs) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_shutdown(cx) + } + + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } +} + +#[cfg(test)] +mod tests { + // FIXME: re-implement tests with `async/await`, this import should + // trigger a warning to remind us + use bytes::Bytes; + use tokio::io::AsyncReadExt; + + use super::Rewind; + + #[tokio::test] + async fn partial_rewind() { + let underlying = [104, 101, 108, 108, 111]; + + let mock = tokio_test::io::Builder::new().read(&underlying).build(); + + let mut stream = Rewind::new(mock); + + // Read off some bytes, ensure we filled o1 + let mut buf = [0; 2]; + stream.read_exact(&mut buf).await.expect("read1"); + + // Rewind the stream so that it is as if we never read in the first place. + stream.rewind(Bytes::copy_from_slice(&buf[..])); + + let mut buf = [0; 5]; + stream.read_exact(&mut buf).await.expect("read1"); + + // At this point we should have read everything that was in the MockStream + assert_eq!(&buf, &underlying); + } + + #[tokio::test] + async fn full_rewind() { + let underlying = [104, 101, 108, 108, 111]; + + let mock = tokio_test::io::Builder::new().read(&underlying).build(); + + let mut stream = Rewind::new(mock); + + let mut buf = [0; 5]; + stream.read_exact(&mut buf).await.expect("read1"); + + // Rewind the stream so that it is as if we never read in the first place. + stream.rewind(Bytes::copy_from_slice(&buf[..])); + + let mut buf = [0; 5]; + stream.read_exact(&mut buf).await.expect("read1"); + } +} diff --git a/crates/listener/src/server.rs b/crates/listener/src/server.rs new file mode 100644 index 00000000..0c34e711 --- /dev/null +++ b/crates/listener/src/server.rs @@ -0,0 +1,301 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use futures_util::{Stream, StreamExt}; +use http_body::Body; +use hyper::{Request, Response}; +use thiserror::Error; +use tokio_rustls::rustls::ServerConfig; +use tower_http::add_extension::AddExtension; +use tower_service::Service; + +use crate::{ + maybe_tls::{MaybeTlsAcceptor, TlsStreamInfo}, + proxy_protocol::{MaybeProxyAcceptor, ProxyAcceptError}, + unix_or_tcp::{SocketAddr, UnixOrTcpConnection, UnixOrTcpListener}, + ConnectionInfo, +}; + +pub struct Server { + tls: Option>, + proxy: bool, + listener: UnixOrTcpListener, + service: S, +} + +impl Server { + /// # Errors + /// + /// Returns an error if the listener couldn't be converted via [`TryInfo`] + pub fn try_new(listener: L, service: S) -> Result + where + L: TryInto, + { + Ok(Self { + tls: None, + proxy: false, + listener: listener.try_into()?, + service, + }) + } + + #[must_use] + pub fn new(listener: impl Into, service: S) -> Self { + Self { + tls: None, + proxy: false, + listener: listener.into(), + service, + } + } + + #[must_use] + pub const fn with_proxy(mut self) -> Self { + self.proxy = true; + self + } + + #[must_use] + pub fn with_tls(mut self, config: Arc) -> Self { + self.tls = Some(config); + self + } + + /// Run a single server + pub async fn run(self, shutdown: SD) + where + S: Service, Response = Response> + Clone + Send + 'static, + S::Future: Send + 'static, + S::Error: Into>, + B: Body + Send + 'static, + B::Data: Send, + B::Error: Into>, + SD: Stream + Unpin, + SD::Item: std::fmt::Display, + { + run_servers(std::iter::once(self), shutdown).await; + } +} + +#[derive(Debug, Error)] +#[non_exhaustive] +enum AcceptError { + #[error("failed to accept connection from the underlying socket")] + Socket { + #[source] + source: std::io::Error, + }, + + #[error("failed to complete the TLS handshake")] + TlsHandshake { + #[source] + source: std::io::Error, + }, + + #[error("failed to complete the proxy protocol handshake")] + ProxyHandshake { + #[source] + source: ProxyAcceptError, + }, + + #[error(transparent)] + Hyper(#[from] hyper::Error), +} + +impl AcceptError { + fn socket(source: std::io::Error) -> Self { + Self::Socket { source } + } + + fn tls_handshake(source: std::io::Error) -> Self { + Self::TlsHandshake { source } + } + + fn proxy_handshake(source: ProxyAcceptError) -> Self { + Self::ProxyHandshake { source } + } +} + +async fn accept( + maybe_proxy_acceptor: &MaybeProxyAcceptor, + maybe_tls_acceptor: &MaybeTlsAcceptor, + peer_addr: SocketAddr, + stream: UnixOrTcpConnection, + service: S, +) -> Result<(), AcceptError> +where + S: Service, Response = Response>, + S::Future: Send + 'static, + S::Error: Into>, + B: Body + Send + 'static, + B::Data: Send, + B::Error: Into>, +{ + let (proxy, stream) = maybe_proxy_acceptor + .accept(stream) + .await + .map_err(AcceptError::proxy_handshake)?; + + let stream = maybe_tls_acceptor + .accept(stream) + .await + .map_err(AcceptError::tls_handshake)?; + + let tls = stream.tls_info(); + + // Figure out if it's HTTP/2 based on the negociated ALPN info + let is_h2 = tls.as_ref().map_or(false, TlsStreamInfo::is_alpn_h2); + + let info = ConnectionInfo { + tls, + proxy, + net_peer_addr: peer_addr.into_net(), + }; + + let service = AddExtension::new(service, info); + + if is_h2 { + hyper::server::conn::Http::new() + .http2_only(true) + .serve_connection(stream, service) + .await?; + } else { + hyper::server::conn::Http::new() + .http1_only(true) + .http1_keep_alive(false) + .serve_connection(stream, service) + .await?; + }; + + Ok(()) +} + +pub async fn run_servers(listeners: impl IntoIterator>, mut shutdown: SD) +where + S: Service, Response = Response> + Clone + Send + 'static, + S::Future: Send + 'static, + S::Error: Into>, + B: Body + Send + 'static, + B::Data: Send, + B::Error: Into>, + SD: Stream + Unpin, + SD::Item: std::fmt::Display, +{ + let listeners: Vec<_> = listeners + .into_iter() + .map(|server| { + let maybe_proxy_acceptor = MaybeProxyAcceptor::new(server.proxy); + let maybe_tls_acceptor = MaybeTlsAcceptor::new(server.tls); + let service = server.service; + let listener = server.listener; + (maybe_proxy_acceptor, maybe_tls_acceptor, service, listener) + }) + .collect(); + + let mut set = tokio::task::JoinSet::new(); + + loop { + let mut accept_all: futures_util::stream::FuturesUnordered<_> = listeners + .iter() + .map( + |(maybe_proxy_acceptor, maybe_tls_acceptor, service, listener)| async move { + listener + .accept() + .await + .map_err(AcceptError::socket) + .map(|(addr, conn)| { + ( + maybe_proxy_acceptor.clone(), + maybe_tls_acceptor.clone(), + service.clone(), + addr, + conn, + ) + }) + }, + ) + .collect(); + + tokio::select! { + biased; + + // First look for the shutdown signal + res = shutdown.next() => { + let why = res.map_or_else(|| String::from("???"), |why| format!("{why}")); + tracing::info!("Received shutdown signal ({why})"); + + break; + }, + + // Poll on the JoinSet, clearing finished task + res = set.join_next(), if !set.is_empty() => { + match res { + Some(Ok(Ok(()))) => tracing::trace!("Task was successful"), + Some(Ok(Err(e))) => tracing::error!("{e}"), + Some(Err(e)) => tracing::error!("Join error: {e}"), + None => tracing::error!("Join set was polled even though it was empty"), + } + }, + + // Then look for connections to accept + res = accept_all.next(), if !accept_all.is_empty() => { + // SAFETY: We shouldn't reach this branch if the unordered future set is empty + let res = if let Some(res) = res { res } else { unreachable!() }; + + // Spawn the connection in the set, so we don't have to wait for the handshake to + // accept the next connection. This allows us to keep track of active connections + // and waiting on them for a graceful shutdown + set.spawn(async move { + let (maybe_proxy_acceptor, maybe_tls_acceptor, service, peer_addr, stream) = res?; + accept(&maybe_proxy_acceptor, &maybe_tls_acceptor, peer_addr, stream, service).await + }); + }, + }; + } + + if !set.is_empty() { + tracing::info!( + "There are {active} active connections, performing a graceful shutdown. Send the shutdown signal again to force.", + active = set.len() + ); + + loop { + tokio::select! { + biased; + + res = set.join_next() => { + match res { + Some(Ok(Ok(()))) => tracing::trace!("Task was successful"), + Some(Ok(Err(e))) => tracing::error!("{e}"), + Some(Err(e)) => tracing::error!("Join error: {e}"), + // No more tasks, going out + None => break, + } + }, + + + res = shutdown.next() => { + let why = res.map_or_else(|| String::from("???"), |why| format!("{why}")); + tracing::warn!("Received shutdown signal again ({why}), forcing shutdown ({active} active connections)", active = set.len()); + break; + }, + } + } + } + + set.shutdown().await; + tracing::info!("Shutdown complete"); +} diff --git a/crates/listener/src/shutdown.rs b/crates/listener/src/shutdown.rs new file mode 100644 index 00000000..cf7eb227 --- /dev/null +++ b/crates/listener/src/shutdown.rs @@ -0,0 +1,178 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{fmt::Display, pin::Pin, task::Poll, time::Duration}; + +use futures_util::{ready, Future, Stream}; +use tokio::{ + signal::unix::{signal, Signal, SignalKind}, + time::Sleep, +}; + +#[derive(Debug, Clone, Copy)] +pub enum ShutdownReason { + Signal(SignalKind), + Timeout, +} + +fn signal_to_str(kind: SignalKind) -> &'static str { + match kind.as_raw_value() { + libc::SIGALRM => "SIGALRM", + libc::SIGCHLD => "SIGCHLD", + libc::SIGHUP => "SIGHUP", + libc::SIGINT => "SIGINT", + libc::SIGIO => "SIGIO", + libc::SIGPIPE => "SIGPIPE", + libc::SIGQUIT => "SIGQUIT", + libc::SIGTERM => "SIGTERM", + libc::SIGUSR1 => "SIGUSR1", + libc::SIGUSR2 => "SIGUSR2", + _ => "SIG???", + } +} + +impl Display for ShutdownReason { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Signal(s) => signal_to_str(*s).fmt(f), + Self::Timeout => "timeout".fmt(f), + } + } +} + +#[derive(Default)] +pub enum ShutdownStreamState { + #[default] + Waiting, + + Graceful { + sleep: Option>>, + }, + + Done, +} + +impl ShutdownStreamState { + fn is_graceful(&self) -> bool { + matches!(self, Self::Graceful { .. }) + } + + fn is_done(&self) -> bool { + matches!(self, Self::Done) + } + + fn get_sleep_mut(&mut self) -> Option<&mut Pin>> { + match self { + Self::Graceful { sleep } => sleep.as_mut(), + _ => None, + } + } +} + +/// A stream which is used to drive a graceful shutdown. +/// +/// It will emit 2 items: one when a first signal is caught, the other when +/// either another signal is caught, or after a timeout. +#[derive(Default)] +pub struct ShutdownStream { + state: ShutdownStreamState, + signals: Vec<(SignalKind, Signal)>, + timeout: Option, +} + +impl ShutdownStream { + /// Create a default shutdown stream, which listens on SIGINT and SIGTERM, + /// with a 60s timeout + /// + /// # Errors + /// + /// Returns an error if signal handlers could not be installed + pub fn new() -> Result { + let ret = Self::default() + .with_timeout(Duration::from_secs(60)) + .with_signal(SignalKind::interrupt())? + .with_signal(SignalKind::terminate())?; + + Ok(ret) + } + + /// Add a signal to register + /// + /// # Errors + /// + /// Returns an error if the signal handler could not be installed + pub fn with_signal(mut self, kind: SignalKind) -> Result { + let signal = signal(kind)?; + self.signals.push((kind, signal)); + Ok(self) + } + + #[must_use] + pub fn with_timeout(mut self, timeout: Duration) -> Self { + self.timeout = Some(timeout); + self + } +} + +impl Stream for ShutdownStream { + type Item = ShutdownReason; + + fn size_hint(&self) -> (usize, Option) { + match self.state { + ShutdownStreamState::Waiting => (2, Some(2)), + ShutdownStreamState::Graceful { .. } => (1, Some(1)), + ShutdownStreamState::Done => (0, Some(0)), + } + } + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let this = self.get_mut(); + + if this.state.is_done() { + return Poll::Ready(None); + } + + for (kind, signal) in &mut this.signals { + match signal.poll_recv(cx) { + Poll::Ready(_) => { + // We got a signal + if this.state.is_graceful() { + // If we was gracefully shutting down, mark it as done + this.state = ShutdownStreamState::Done; + } else { + // Else start the timeout + let sleep = this + .timeout + .map(|duration| Box::pin(tokio::time::sleep(duration))); + this.state = ShutdownStreamState::Graceful { sleep }; + } + + return Poll::Ready(Some(ShutdownReason::Signal(*kind))); + } + Poll::Pending => {} + } + } + + if let Some(timeout) = this.state.get_sleep_mut() { + ready!(timeout.as_mut().poll(cx)); + this.state = ShutdownStreamState::Done; + Poll::Ready(Some(ShutdownReason::Timeout)) + } else { + Poll::Pending + } + } +} diff --git a/crates/listener/src/unix_or_tcp.rs b/crates/listener/src/unix_or_tcp.rs index e66a1ad0..3469ce53 100644 --- a/crates/listener/src/unix_or_tcp.rs +++ b/crates/listener/src/unix_or_tcp.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! A listener which can listen on either TCP sockets or on UNIX domain sockets + // TODO: Unlink the UNIX socket on drop? use std::{ @@ -19,8 +21,6 @@ use std::{ task::{Context, Poll}, }; -use futures_util::ready; -use hyper::server::accept::Accept; use tokio::{ io::{AsyncRead, AsyncWrite}, net::{TcpListener, TcpStream, UnixListener, UnixStream}, @@ -107,6 +107,7 @@ impl TryFrom for UnixOrTcpListener { type Error = std::io::Error; fn try_from(listener: std::os::unix::net::UnixListener) -> Result { + listener.set_nonblocking(true)?; Ok(Self::Unix(UnixListener::from_std(listener)?)) } } @@ -115,6 +116,7 @@ impl TryFrom for UnixOrTcpListener { type Error = std::io::Error; fn try_from(listener: std::net::TcpListener) -> Result { + listener.set_nonblocking(true)?; Ok(Self::Tcp(TcpListener::from_std(listener)?)) } } @@ -140,6 +142,24 @@ impl UnixOrTcpListener { pub const fn is_tcp(&self) -> bool { matches!(self, Self::Tcp(_)) } + + /// Accept an incoming connection + /// + /// # Errors + /// + /// Returns an error if the underlying socket couldn't accept the connection + pub async fn accept(&self) -> Result<(SocketAddr, UnixOrTcpConnection), std::io::Error> { + match self { + Self::Unix(listener) => { + let (stream, remote_addr) = listener.accept().await?; + Ok((remote_addr.into(), UnixOrTcpConnection::Unix { stream })) + } + Self::Tcp(listener) => { + let (stream, remote_addr) = listener.accept().await?; + Ok((remote_addr.into(), UnixOrTcpConnection::Tcp { stream })) + } + } + } } pin_project_lite::pin_project! { @@ -157,6 +177,12 @@ pin_project_lite::pin_project! { } } +impl From for UnixOrTcpConnection { + fn from(stream: TcpStream) -> Self { + Self::Tcp { stream } + } +} + impl UnixOrTcpConnection { /// Get the local address of the stream /// @@ -166,8 +192,8 @@ impl UnixOrTcpConnection { /// [`UnixStream`] couldn't provide the local address pub fn local_addr(&self) -> Result { match self { - Self::Unix { stream, .. } => stream.local_addr().map(SocketAddr::from), - Self::Tcp { stream, .. } => stream.local_addr().map(SocketAddr::from), + Self::Unix { stream } => stream.local_addr().map(SocketAddr::from), + Self::Tcp { stream } => stream.local_addr().map(SocketAddr::from), } } @@ -179,36 +205,12 @@ impl UnixOrTcpConnection { /// [`UnixStream`] couldn't provide the remote address pub fn peer_addr(&self) -> Result { match self { - Self::Unix { stream, .. } => stream.peer_addr().map(SocketAddr::from), - Self::Tcp { stream, .. } => stream.peer_addr().map(SocketAddr::from), + Self::Unix { stream } => stream.peer_addr().map(SocketAddr::from), + Self::Tcp { stream } => stream.peer_addr().map(SocketAddr::from), } } } -impl Accept for UnixOrTcpListener { - type Error = std::io::Error; - type Conn = UnixOrTcpConnection; - - fn poll_accept( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> std::task::Poll>> { - let conn = match &*self { - Self::Unix(listener) => { - let (stream, _remote_addr) = ready!(listener.poll_accept(cx))?; - UnixOrTcpConnection::Unix { stream } - } - - Self::Tcp(listener) => { - let (stream, _remote_addr) = ready!(listener.poll_accept(cx))?; - UnixOrTcpConnection::Tcp { stream } - } - }; - - Poll::Ready(Some(Ok(conn))) - } -} - impl AsyncRead for UnixOrTcpConnection { fn poll_read( self: Pin<&mut Self>, @@ -234,23 +236,6 @@ impl AsyncWrite for UnixOrTcpConnection { } } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.project() { - UnixOrTcpConnectionProj::Unix { stream } => stream.poll_flush(cx), - UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_flush(cx), - } - } - - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - match self.project() { - UnixOrTcpConnectionProj::Unix { stream } => stream.poll_shutdown(cx), - UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_shutdown(cx), - } - } - fn poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -268,4 +253,21 @@ impl AsyncWrite for UnixOrTcpConnection { UnixOrTcpConnection::Tcp { stream } => stream.is_write_vectored(), } } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project() { + UnixOrTcpConnectionProj::Unix { stream } => stream.poll_flush(cx), + UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_flush(cx), + } + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + match self.project() { + UnixOrTcpConnectionProj::Unix { stream } => stream.poll_shutdown(cx), + UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_shutdown(cx), + } + } }