You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
Rewrite the listeners crate
Now with a way better graceful shutdown! With proper handshakes!
This commit is contained in:
22
Cargo.lock
generated
22
Cargo.lock
generated
@ -2425,6 +2425,7 @@ dependencies = [
|
||||
"futures-util",
|
||||
"hyper",
|
||||
"indoc",
|
||||
"itertools",
|
||||
"listenfd",
|
||||
"mas-config",
|
||||
"mas-email",
|
||||
@ -2677,15 +2678,21 @@ dependencies = [
|
||||
name = "mas-listener"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bytes 1.2.1",
|
||||
"futures-util",
|
||||
"http-body",
|
||||
"hyper",
|
||||
"libc",
|
||||
"pin-project-lite",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-rustls 0.23.4",
|
||||
"tower",
|
||||
"tokio-test",
|
||||
"tower-http",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -4872,6 +4879,19 @@ dependencies = [
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-test"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "53474327ae5e166530d17f2d956afcb4f8a004de581b3cae10f12006bc8163e3"
|
||||
dependencies = [
|
||||
"async-stream",
|
||||
"bytes 1.2.1",
|
||||
"futures-core",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-util"
|
||||
version = "0.6.10"
|
||||
|
@ -22,6 +22,7 @@ watchman_client = "0.8.0"
|
||||
atty = "0.2.14"
|
||||
listenfd = "1.0.0"
|
||||
rustls = "0.20.6"
|
||||
itertools = "0.10.5"
|
||||
|
||||
tracing = "0.1.36"
|
||||
tracing-appender = "0.2.2"
|
||||
|
@ -16,26 +16,19 @@ use std::{sync::Arc, time::Duration};
|
||||
|
||||
use anyhow::Context;
|
||||
use clap::Parser;
|
||||
use futures_util::{
|
||||
future::FutureExt,
|
||||
stream::{StreamExt, TryStreamExt},
|
||||
};
|
||||
use hyper::Server;
|
||||
use futures_util::stream::{StreamExt, TryStreamExt};
|
||||
use itertools::Itertools;
|
||||
use mas_config::RootConfig;
|
||||
use mas_email::Mailer;
|
||||
use mas_handlers::{AppState, MatrixHomeserver};
|
||||
use mas_http::ServerLayer;
|
||||
use mas_listener::{
|
||||
info::{ConnectionInfoAcceptor, IntoMakeServiceWithConnection},
|
||||
maybe_tls::MaybeTlsAcceptor,
|
||||
proxy_protocol::MaybeProxyAcceptor,
|
||||
};
|
||||
use mas_listener::{server::Server, shutdown::ShutdownStream};
|
||||
use mas_policy::PolicyFactory;
|
||||
use mas_router::UrlBuilder;
|
||||
use mas_storage::MIGRATOR;
|
||||
use mas_tasks::TaskQueue;
|
||||
use mas_templates::Templates;
|
||||
use tokio::io::AsyncRead;
|
||||
use tokio::{io::AsyncRead, signal::unix::SignalKind};
|
||||
use tracing::{error, info, log::warn};
|
||||
|
||||
#[derive(Parser, Debug, Default)]
|
||||
@ -49,32 +42,6 @@ pub(super) struct Options {
|
||||
watch: bool,
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
async fn shutdown_signal() {
|
||||
// Wait for the CTRL+C signal
|
||||
tokio::signal::ctrl_c()
|
||||
.await
|
||||
.expect("failed to install Ctrl+C signal handler");
|
||||
|
||||
tracing::info!("Got Ctrl+C, shutting down");
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
async fn shutdown_signal() {
|
||||
use tokio::signal::unix::{signal, SignalKind};
|
||||
|
||||
// Wait for SIGTERM and SIGINT signals
|
||||
// This might panic but should be fine
|
||||
let mut term =
|
||||
signal(SignalKind::terminate()).expect("failed to install SIGTERM signal handler");
|
||||
let mut int = signal(SignalKind::interrupt()).expect("failed to install SIGINT signal handler");
|
||||
|
||||
tokio::select! {
|
||||
_ = term.recv() => tracing::info!("Got SIGTERM, shutting down"),
|
||||
_ = int.recv() => tracing::info!("Got SIGINT, shutting down"),
|
||||
};
|
||||
}
|
||||
|
||||
/// Watch for changes in the templates folders
|
||||
async fn watch_templates(
|
||||
client: &watchman_client::Client,
|
||||
@ -247,68 +214,75 @@ impl Options {
|
||||
policy_factory,
|
||||
});
|
||||
|
||||
let signal = shutdown_signal().shared();
|
||||
let shutdown_signal = signal.clone();
|
||||
|
||||
let mut fd_manager = listenfd::ListenFd::from_env();
|
||||
let listeners = listeners_config.into_iter().map(|listener_config| {
|
||||
// Let's first grab all the listeners in a synchronous manner
|
||||
let listeners = crate::server::build_listeners(&mut fd_manager, &listener_config.binds);
|
||||
|
||||
Ok((listener_config, listeners?))
|
||||
});
|
||||
let servers: Vec<Server<_>> = listeners_config
|
||||
.into_iter()
|
||||
.map(|config| {
|
||||
// Let's first grab all the listeners
|
||||
let listeners = crate::server::build_listeners(&mut fd_manager, &config.binds)?;
|
||||
|
||||
// Now that we have the listeners ready, we can do the rest concurrently
|
||||
futures_util::stream::iter(listeners)
|
||||
.try_for_each_concurrent(None, move |(config, listeners)| {
|
||||
let signal = signal.clone();
|
||||
// Load the TLS config
|
||||
let tls_config = if let Some(tls_config) = config.tls.as_ref() {
|
||||
let tls_config = crate::server::build_tls_server_config(tls_config)?;
|
||||
Some(Arc::new(tls_config))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// and build the router
|
||||
let router = crate::server::build_router(&state, &config.resources)
|
||||
.layer(ServerLayer::new(config.name.clone()));
|
||||
|
||||
// Display some informations about where we'll be serving connections
|
||||
let is_tls = config.tls.is_some();
|
||||
let addresses: Vec<String> = listeners.iter().map(|listener| {
|
||||
let addresses: Vec<String> = listeners
|
||||
.iter()
|
||||
.map(|listener| {
|
||||
let addr = listener.local_addr();
|
||||
let proto = if is_tls { "https" } else { "http" };
|
||||
if let Ok(addr) = addr {
|
||||
format!("{proto}://{addr:?}")
|
||||
} else {
|
||||
warn!("Could not get local address for listener, something might be wrong!");
|
||||
warn!(
|
||||
"Could not get local address for listener, something might be wrong!"
|
||||
);
|
||||
format!("{proto}://???")
|
||||
}
|
||||
}).collect();
|
||||
|
||||
let additional = if config.proxy_protocol { "(with Proxy Protocol)" } else { "" };
|
||||
|
||||
info!("Listening on {addresses:?} with resources {resources:?} {additional}", resources = &config.resources);
|
||||
|
||||
let router = crate::server::build_router(&state, &config.resources).layer(ServerLayer::new(config.name.clone()));
|
||||
let make_service = IntoMakeServiceWithConnection::new(router);
|
||||
|
||||
async move {
|
||||
let tls_config = if let Some(tls_config) = config.tls.as_ref() {
|
||||
let tls_config = crate::server::build_tls_server_config(tls_config).await?;
|
||||
Some(Arc::new(tls_config))
|
||||
} else { None };
|
||||
|
||||
futures_util::stream::iter(listeners)
|
||||
.map(Ok)
|
||||
.try_for_each_concurrent(None, move |listener| {
|
||||
let listener = MaybeTlsAcceptor::new(tls_config.clone(), listener);
|
||||
let listener = MaybeProxyAcceptor::new(listener, config.proxy_protocol);
|
||||
let listener = ConnectionInfoAcceptor::new(listener);
|
||||
|
||||
Server::builder(listener)
|
||||
.serve(make_service.clone())
|
||||
.with_graceful_shutdown(signal.clone())
|
||||
})
|
||||
.await?;
|
||||
.collect();
|
||||
|
||||
anyhow::Ok(())
|
||||
let additional = if config.proxy_protocol {
|
||||
"(with Proxy Protocol)"
|
||||
} else {
|
||||
""
|
||||
};
|
||||
|
||||
info!(
|
||||
"Listening on {addresses:?} with resources {resources:?} {additional}",
|
||||
resources = &config.resources
|
||||
);
|
||||
|
||||
anyhow::Ok(listeners.into_iter().map(move |listener| {
|
||||
let mut server = Server::new(listener, router.clone());
|
||||
if let Some(tls_config) = &tls_config {
|
||||
server = server.with_tls(tls_config.clone());
|
||||
}
|
||||
if config.proxy_protocol {
|
||||
server = server.with_proxy();
|
||||
}
|
||||
server
|
||||
}))
|
||||
})
|
||||
.await?;
|
||||
.flatten_ok()
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
// This ensures we're running, even if no listener are setup
|
||||
// This is useful for only running the task runner
|
||||
shutdown_signal.await;
|
||||
let shutdown = ShutdownStream::default()
|
||||
.with_timeout(Duration::from_secs(60))
|
||||
.with_signal(SignalKind::terminate())?
|
||||
.with_signal(SignalKind::interrupt())?;
|
||||
|
||||
mas_listener::server::run_servers(servers, shutdown).await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -23,10 +23,9 @@ use axum::{body::HttpBody, Extension, Router};
|
||||
use listenfd::ListenFd;
|
||||
use mas_config::{HttpBindConfig, HttpResource, HttpTlsConfig, UnixOrTcp};
|
||||
use mas_handlers::AppState;
|
||||
use mas_listener::{info::Connection, unix_or_tcp::UnixOrTcpListener};
|
||||
use mas_listener::{unix_or_tcp::UnixOrTcpListener, ConnectionInfo};
|
||||
use mas_router::Route;
|
||||
use rustls::ServerConfig;
|
||||
use tokio::sync::OnceCell;
|
||||
|
||||
#[allow(clippy::trait_duplication_in_bounds)]
|
||||
pub fn build_router<B>(state: &Arc<AppState>, resources: &[HttpResource]) -> Router<AppState, B>
|
||||
@ -64,12 +63,9 @@ where
|
||||
// TODO: do a better handler here
|
||||
mas_config::HttpResource::ConnectionInfo => router.route(
|
||||
"/connection-info",
|
||||
axum::routing::get(
|
||||
|connection: Extension<Arc<OnceCell<Connection>>>| async move {
|
||||
let connection = connection.get().unwrap();
|
||||
axum::routing::get(|connection: Extension<ConnectionInfo>| async move {
|
||||
format!("{connection:?}")
|
||||
},
|
||||
),
|
||||
}),
|
||||
),
|
||||
}
|
||||
}
|
||||
@ -77,10 +73,8 @@ where
|
||||
router
|
||||
}
|
||||
|
||||
pub async fn build_tls_server_config(
|
||||
config: &HttpTlsConfig,
|
||||
) -> Result<ServerConfig, anyhow::Error> {
|
||||
let (key, chain) = config.load().await?;
|
||||
pub fn build_tls_server_config(config: &HttpTlsConfig) -> Result<ServerConfig, anyhow::Error> {
|
||||
let (key, chain) = config.load()?;
|
||||
let key = rustls::PrivateKey(key);
|
||||
let chain = chain.into_iter().map(rustls::Certificate).collect();
|
||||
|
||||
|
@ -163,11 +163,11 @@ impl TlsConfig {
|
||||
/// - a password was provided but the key was not encrypted
|
||||
/// - decoding the certificate chain as PEM
|
||||
/// - the certificate chain is empty
|
||||
pub async fn load(&self) -> Result<(Vec<u8>, Vec<Vec<u8>>), anyhow::Error> {
|
||||
pub fn load(&self) -> Result<(Vec<u8>, Vec<Vec<u8>>), anyhow::Error> {
|
||||
let password = match &self.password {
|
||||
Some(PasswordOrFile::Password(password)) => Some(Cow::Borrowed(password.as_str())),
|
||||
Some(PasswordOrFile::PasswordFile(path)) => {
|
||||
Some(Cow::Owned(tokio::fs::read_to_string(path).await?))
|
||||
Some(Cow::Owned(std::fs::read_to_string(path)?))
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
@ -185,7 +185,7 @@ impl TlsConfig {
|
||||
KeyOrFile::KeyFile(path) => {
|
||||
// When reading from disk, it might be either PEM or DER. `PrivateKey::load*`
|
||||
// will try both.
|
||||
let key = tokio::fs::read(path).await?;
|
||||
let key = std::fs::read(path)?;
|
||||
if let Some(password) = password {
|
||||
PrivateKey::load_encrypted(&key, password.as_bytes())?
|
||||
} else {
|
||||
@ -202,9 +202,7 @@ impl TlsConfig {
|
||||
|
||||
let certificate_chain_pem = match &self.certificate {
|
||||
CertificateOrFile::Certificate(pem) => Cow::Borrowed(pem.as_str()),
|
||||
CertificateOrFile::CertificateFile(path) => {
|
||||
Cow::Owned(tokio::fs::read_to_string(path).await?)
|
||||
}
|
||||
CertificateOrFile::CertificateFile(path) => Cow::Owned(std::fs::read_to_string(path)?),
|
||||
};
|
||||
|
||||
let mut certificate_chain_reader = Cursor::new(certificate_chain_pem.as_bytes());
|
||||
|
@ -6,12 +6,25 @@ edition = "2021"
|
||||
license = "Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
bytes = "1.2.1"
|
||||
futures-util = "0.3.24"
|
||||
hyper = { version = "0.14.20", features = ["server"] }
|
||||
http-body = "0.4.2"
|
||||
hyper = { version = "0.14.20", features = ["server", "http1", "http2"] }
|
||||
pin-project-lite = "0.2.9"
|
||||
thiserror = "1.0.37"
|
||||
tokio = { version = "1.21.2", features = ["net"] }
|
||||
tokio = { version = "1.21.2", features = ["net", "rt", "macros", "signal", "time"] }
|
||||
tokio-rustls = "0.23.4"
|
||||
tower = "0.4.13"
|
||||
tower-http = { version = "0.3.4", features = ["add-extension"] }
|
||||
tracing = "0.1.36"
|
||||
tower-service = "0.3.2"
|
||||
tracing = "0.1.37"
|
||||
libc = "0.2.135"
|
||||
|
||||
[dev-dependencies]
|
||||
tokio-test = "0.4.2"
|
||||
anyhow = "1.0.65"
|
||||
tokio = { version = "1.21.2", features = ["net", "rt", "macros", "signal", "time", "rt-multi-thread"] }
|
||||
tracing-subscriber = "0.3.16"
|
||||
|
||||
[[example]]
|
||||
name = "demo"
|
||||
path = "examples/demo/main.rs"
|
||||
|
49
crates/listener/examples/demo/main.rs
Normal file
49
crates/listener/examples/demo/main.rs
Normal file
@ -0,0 +1,49 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
net::{Ipv4Addr, TcpListener},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use hyper::{service::service_fn, Request, Response};
|
||||
use tokio::signal::unix::SignalKind;
|
||||
use tokio_streams_util::{server::Server, shutdown::ShutdownStream, ConnectionInfo};
|
||||
|
||||
async fn handler(req: Request<hyper::Body>) -> Result<Response<String>, Infallible> {
|
||||
tracing::info!("Handling request");
|
||||
tokio::time::sleep(Duration::from_secs(3)).await;
|
||||
let info = req.extensions().get::<ConnectionInfo>().unwrap();
|
||||
let body = format!("{info:?}");
|
||||
Ok(Response::new(body))
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), anyhow::Error> {
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 3000))?;
|
||||
let service = service_fn(handler);
|
||||
let server = Server::try_new(listener, service)?;
|
||||
|
||||
tracing::info!("Listening on 127.0.0.1:3000");
|
||||
|
||||
let shutdown = ShutdownStream::default()
|
||||
.with_signal(SignalKind::interrupt())?
|
||||
.with_signal(SignalKind::terminate())?;
|
||||
server.run(shutdown).await;
|
||||
|
||||
Ok(())
|
||||
}
|
@ -1,323 +0,0 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{
|
||||
future::Ready,
|
||||
net::SocketAddr,
|
||||
ops::Deref,
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use hyper::server::accept::Accept;
|
||||
use thiserror::Error;
|
||||
use tokio::{
|
||||
io::{AsyncRead, AsyncWrite},
|
||||
sync::OnceCell,
|
||||
};
|
||||
use tower::Service;
|
||||
use tower_http::add_extension::AddExtension;
|
||||
|
||||
use crate::{
|
||||
maybe_tls::{MaybeTlsAcceptor, MaybeTlsStream, TlsStreamInfo, TlsStreamInfoError},
|
||||
proxy_protocol::{
|
||||
MaybeProxyAcceptor, MaybeProxyStream, ProxyHandshakeNotDone, ProxyProtocolV1Info,
|
||||
},
|
||||
unix_or_tcp::{UnixOrTcpConnection, UnixOrTcpListener},
|
||||
};
|
||||
|
||||
// TODO: this is a mess, clean that up
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
#[non_exhaustive]
|
||||
pub enum FromStreamError {
|
||||
#[error(transparent)]
|
||||
Proxy(#[from] ProxyHandshakeNotDone),
|
||||
|
||||
#[error(transparent)]
|
||||
Tls(#[from] TlsStreamInfoError),
|
||||
|
||||
#[error("Could not grab a reference to the underlying stream")]
|
||||
GetRef,
|
||||
|
||||
#[error("Could not get address info from underlying stream")]
|
||||
IoError(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Connection {
|
||||
proxy: Option<ProxyProtocolV1Info>,
|
||||
tls: Option<TlsStreamInfo>,
|
||||
|
||||
// We're not saving the UNIX domain socket address here because it can't be cloned, which is
|
||||
// required for injecting the connection information as an extension
|
||||
local_tcp_addr: Option<SocketAddr>,
|
||||
peer_tcp_addr: Option<SocketAddr>,
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
#[non_exhaustive]
|
||||
pub enum GrabAddressError {
|
||||
#[error("Proxy protocol was initiated with an unknown protocol")]
|
||||
ProxyUnknown,
|
||||
|
||||
#[error("Proxy protocol was initiated with UDP")]
|
||||
ProxyUdp,
|
||||
|
||||
#[error("Underlying listener is a UNIX socket")]
|
||||
UnixListener,
|
||||
}
|
||||
|
||||
impl MaybeProxyAcceptor<MaybeTlsAcceptor<UnixOrTcpListener>> {
|
||||
pub fn can_have_peer_address(&self) -> bool {
|
||||
self.is_proxied() || self.is_tcp()
|
||||
}
|
||||
}
|
||||
|
||||
impl MaybeProxyStream<MaybeTlsStream<UnixOrTcpConnection>> {
|
||||
/// Get informations about this connection
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the proxy protocol or the TLS handhakes are not done
|
||||
/// yet
|
||||
pub fn connection_info(&self) -> Result<Connection, FromStreamError> {
|
||||
Connection::from_stream(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl Connection {
|
||||
/// Get informations about this connection
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the proxy protocol or the TLS handhakes are not done
|
||||
/// yet
|
||||
pub fn from_stream(
|
||||
stream: &MaybeProxyStream<MaybeTlsStream<UnixOrTcpConnection>>,
|
||||
) -> Result<Self, FromStreamError> {
|
||||
let proxy = stream.proxy_info()?.cloned();
|
||||
let tls = stream.tls_info()?;
|
||||
let original = stream.get_ref().ok_or(FromStreamError::GetRef)?;
|
||||
let local_tcp_addr = original.local_addr()?.into_net();
|
||||
let peer_tcp_addr = original.peer_addr()?.into_net();
|
||||
|
||||
Ok(Self {
|
||||
proxy,
|
||||
tls,
|
||||
local_tcp_addr,
|
||||
peer_tcp_addr,
|
||||
})
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub const fn is_proxied(&self) -> bool {
|
||||
self.proxy.is_some()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub const fn is_tls(&self) -> bool {
|
||||
self.tls.is_some()
|
||||
}
|
||||
|
||||
/// Get the outmost peer address, either from the TCP listener or from the
|
||||
/// proxy protocol infos.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the info from the proxy protocol was not for a TCP
|
||||
/// connection, or if the proxy protocol is not being used, the underlying
|
||||
/// listener was a UNIX domain socket
|
||||
pub fn peer_addr(&self) -> Result<&SocketAddr, GrabAddressError> {
|
||||
if let Some(proxy) = self.proxy.as_ref() {
|
||||
if proxy.is_udp() {
|
||||
return Err(GrabAddressError::ProxyUdp);
|
||||
}
|
||||
|
||||
proxy.source().ok_or(GrabAddressError::ProxyUnknown)
|
||||
} else {
|
||||
self.peer_tcp_addr
|
||||
.as_ref()
|
||||
.ok_or(GrabAddressError::UnixListener)
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the outmost local address, either from the TCP listener or from the
|
||||
/// proxy protocol infos.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the info from the proxy protocol was not for a TCP
|
||||
/// connection, or if the proxy protocol is not being used, the underlying
|
||||
/// listener was a UNIX domain socket
|
||||
pub fn local_addr(&self) -> Result<&SocketAddr, GrabAddressError> {
|
||||
if let Some(proxy) = self.proxy.as_ref() {
|
||||
if proxy.is_udp() {
|
||||
return Err(GrabAddressError::ProxyUdp);
|
||||
}
|
||||
|
||||
proxy.destination().ok_or(GrabAddressError::ProxyUnknown)
|
||||
} else {
|
||||
self.local_tcp_addr
|
||||
.as_ref()
|
||||
.ok_or(GrabAddressError::UnixListener)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
pub struct ConnectionInfoAcceptor {
|
||||
#[pin]
|
||||
acceptor: MaybeProxyAcceptor<MaybeTlsAcceptor<UnixOrTcpListener>>,
|
||||
}
|
||||
}
|
||||
|
||||
impl ConnectionInfoAcceptor {
|
||||
pub const fn new(acceptor: MaybeProxyAcceptor<MaybeTlsAcceptor<UnixOrTcpListener>>) -> Self {
|
||||
Self { acceptor }
|
||||
}
|
||||
}
|
||||
|
||||
impl Accept for ConnectionInfoAcceptor {
|
||||
type Conn = ConnectionInfoStream;
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn poll_accept(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
|
||||
let proj = self.project();
|
||||
let ret = match futures_util::ready!(proj.acceptor.poll_accept(cx)) {
|
||||
Some(Ok(conn)) => Some(Ok(ConnectionInfoStream::new(conn))),
|
||||
Some(Err(e)) => Some(Err(e)),
|
||||
None => None,
|
||||
};
|
||||
Poll::Ready(ret)
|
||||
}
|
||||
}
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
pub struct ConnectionInfoStream {
|
||||
connection: Arc<OnceCell<Connection>>,
|
||||
#[pin]
|
||||
stream: MaybeProxyStream<MaybeTlsStream<UnixOrTcpConnection>>,
|
||||
}
|
||||
}
|
||||
|
||||
impl ConnectionInfoStream {
|
||||
pub fn new(stream: MaybeProxyStream<MaybeTlsStream<UnixOrTcpConnection>>) -> Self {
|
||||
Self {
|
||||
connection: Arc::new(OnceCell::const_new()),
|
||||
stream,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for ConnectionInfoStream {
|
||||
type Target = MaybeProxyStream<MaybeTlsStream<UnixOrTcpConnection>>;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.stream
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for ConnectionInfoStream {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
let this = self.get_mut();
|
||||
futures_util::ready!(Pin::new(&mut this.stream).poll_read(cx, buf))?;
|
||||
|
||||
if !this.stream.is_tls_handshaking()
|
||||
&& !this.stream.is_proxy_handshaking()
|
||||
&& !this.connection.initialized()
|
||||
{
|
||||
this.connection
|
||||
.set(this.stream.connection_info().unwrap())
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for ConnectionInfoStream {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, std::io::Error>> {
|
||||
let proj = self.project();
|
||||
proj.stream.poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
|
||||
let proj = self.project();
|
||||
proj.stream.poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
let proj = self.project();
|
||||
proj.stream.poll_shutdown(cx)
|
||||
}
|
||||
|
||||
fn poll_write_vectored(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
bufs: &[std::io::IoSlice<'_>],
|
||||
) -> Poll<Result<usize, std::io::Error>> {
|
||||
let proj = self.project();
|
||||
proj.stream.poll_write_vectored(cx, bufs)
|
||||
}
|
||||
|
||||
fn is_write_vectored(&self) -> bool {
|
||||
self.stream.is_write_vectored()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct IntoMakeServiceWithConnection<S> {
|
||||
svc: S,
|
||||
}
|
||||
|
||||
impl<S> IntoMakeServiceWithConnection<S> {
|
||||
pub const fn new(svc: S) -> Self {
|
||||
Self { svc }
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Service<&ConnectionInfoStream> for IntoMakeServiceWithConnection<S>
|
||||
where
|
||||
S: Clone,
|
||||
{
|
||||
type Response = AddExtension<S, Arc<OnceCell<Connection>>>;
|
||||
type Error = FromStreamError;
|
||||
type Future = Ready<Result<Self::Response, Self::Error>>;
|
||||
|
||||
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, target: &ConnectionInfoStream) -> Self::Future {
|
||||
std::future::ready(Ok(AddExtension::new(
|
||||
self.svc.clone(),
|
||||
target.connection.clone(),
|
||||
)))
|
||||
}
|
||||
}
|
@ -22,7 +22,41 @@
|
||||
#![warn(clippy::pedantic)]
|
||||
#![allow(clippy::module_name_repetitions)]
|
||||
|
||||
pub mod info;
|
||||
use self::{maybe_tls::TlsStreamInfo, proxy_protocol::ProxyProtocolV1Info};
|
||||
|
||||
pub mod maybe_tls;
|
||||
pub mod proxy_protocol;
|
||||
pub mod rewind;
|
||||
pub mod server;
|
||||
pub mod shutdown;
|
||||
pub mod unix_or_tcp;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ConnectionInfo {
|
||||
tls: Option<TlsStreamInfo>,
|
||||
proxy: Option<ProxyProtocolV1Info>,
|
||||
net_peer_addr: Option<std::net::SocketAddr>,
|
||||
}
|
||||
|
||||
impl ConnectionInfo {
|
||||
/// Returns informations about the TLS connection. Returns [`None`] if the
|
||||
/// connection was not TLS.
|
||||
#[must_use]
|
||||
pub fn get_tls_ref(&self) -> Option<&TlsStreamInfo> {
|
||||
self.tls.as_ref()
|
||||
}
|
||||
|
||||
/// Returns informations about the proxy protocol connection. Returns
|
||||
/// [`None`] if the connection was not using the proxy protocol.
|
||||
#[must_use]
|
||||
pub fn get_proxy_ref(&self) -> Option<&ProxyProtocolV1Info> {
|
||||
self.proxy.as_ref()
|
||||
}
|
||||
|
||||
/// Returns the remote peer address. Returns [`None`] if the connection was
|
||||
/// established via a UNIX domain socket.
|
||||
#[must_use]
|
||||
pub fn get_peer_addr(&self) -> Option<std::net::SocketAddr> {
|
||||
self.net_peer_addr
|
||||
}
|
||||
}
|
||||
|
@ -13,18 +13,15 @@
|
||||
// limitations under the License.
|
||||
|
||||
use std::{
|
||||
ops::Deref,
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use futures_util::{ready, Future};
|
||||
use hyper::server::accept::Accept;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio_rustls::rustls::{
|
||||
Certificate, ProtocolVersion, ServerConfig, ServerConnection, SupportedCipherSuite,
|
||||
use tokio_rustls::{
|
||||
rustls::{Certificate, ProtocolVersion, ServerConfig, ServerConnection, SupportedCipherSuite},
|
||||
TlsAcceptor,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -33,105 +30,78 @@ pub struct TlsStreamInfo {
|
||||
pub protocol_version: ProtocolVersion,
|
||||
pub negotiated_cipher_suite: SupportedCipherSuite,
|
||||
pub sni_hostname: Option<String>,
|
||||
pub apln_protocol: Option<Vec<u8>>,
|
||||
pub alpn_protocol: Option<Vec<u8>>,
|
||||
pub peer_certificates: Option<Vec<Certificate>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[non_exhaustive]
|
||||
pub enum TlsStreamInfoError {
|
||||
#[error("TLS handshake is not done yet")]
|
||||
HandshakingNotDone,
|
||||
|
||||
#[error("Some fields were not available in the TLS connection")]
|
||||
FieldsNotAvailable,
|
||||
impl TlsStreamInfo {
|
||||
#[must_use]
|
||||
pub fn is_alpn_h2(&self) -> bool {
|
||||
matches!(self.alpn_protocol.as_deref(), Some(b"h2"))
|
||||
}
|
||||
}
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
#[project = MaybeTlsStreamProj]
|
||||
pub enum MaybeTlsStream<T> {
|
||||
Handshaking(tokio_rustls::Accept<T>),
|
||||
Streaming(tokio_rustls::server::TlsStream<T>),
|
||||
Insecure(T),
|
||||
Secure {
|
||||
#[pin]
|
||||
stream: tokio_rustls::server::TlsStream<T>
|
||||
},
|
||||
Insecure {
|
||||
#[pin]
|
||||
stream: T,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> MaybeTlsStream<T> {
|
||||
pub fn new(stream: T, config: Option<Arc<ServerConfig>>) -> Self
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
if let Some(config) = config {
|
||||
let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream);
|
||||
MaybeTlsStream::Handshaking(accept)
|
||||
} else {
|
||||
MaybeTlsStream::Insecure(stream)
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a reference to the underlying IO stream
|
||||
///
|
||||
/// Returns [`None`] if the stream closed before the TLS handshake finished.
|
||||
/// It is guaranteed to return [`Some`] value after the handshake finished,
|
||||
/// or if it is a non-TLS connection.
|
||||
pub fn get_ref(&self) -> Option<&T> {
|
||||
pub fn get_ref(&self) -> &T {
|
||||
match self {
|
||||
Self::Handshaking(accept) => accept.get_ref(),
|
||||
Self::Streaming(stream) => {
|
||||
let (inner, _) = stream.get_ref();
|
||||
Some(inner)
|
||||
}
|
||||
Self::Insecure(inner) => Some(inner),
|
||||
Self::Secure { stream } => stream.get_ref().0,
|
||||
Self::Insecure { stream } => stream,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a ref to the [`ServerConnection`] of the establish TLS stream.
|
||||
///
|
||||
/// Returns [`None`] if the connection is still handshaking and for non-TLS
|
||||
/// connections.
|
||||
/// Returns [`None`] for non-TLS connections.
|
||||
pub fn get_tls_connection(&self) -> Option<&ServerConnection> {
|
||||
match self {
|
||||
Self::Streaming(stream) => {
|
||||
let (_, conn) = stream.get_ref();
|
||||
Some(conn)
|
||||
}
|
||||
Self::Handshaking(_) | Self::Insecure(_) => None,
|
||||
Self::Secure { stream } => Some(stream.get_ref().1),
|
||||
Self::Insecure { .. } => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Gather informations about the TLS connection. Returns `None` if the
|
||||
/// stream is not a TLS stream.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the TLS handshake is not yet done
|
||||
pub fn tls_info(&self) -> Result<Option<TlsStreamInfo>, TlsStreamInfoError> {
|
||||
let conn = match self {
|
||||
Self::Streaming(stream) => stream.get_ref().1,
|
||||
Self::Handshaking(_) => return Err(TlsStreamInfoError::HandshakingNotDone),
|
||||
Self::Insecure(_) => return Ok(None),
|
||||
};
|
||||
pub fn tls_info(&self) -> Option<TlsStreamInfo> {
|
||||
let conn = self.get_tls_connection()?;
|
||||
|
||||
// NOTE: we're getting the protocol version and cipher suite *after* the
|
||||
// handshake, so this should never lead to an error
|
||||
// SAFETY: we're getting the protocol version and cipher suite *after* the
|
||||
// handshake, so this should never lead to a panic
|
||||
let protocol_version = conn
|
||||
.protocol_version()
|
||||
.ok_or(TlsStreamInfoError::FieldsNotAvailable)?;
|
||||
.expect("TLS handshake is not done yet");
|
||||
let negotiated_cipher_suite = conn
|
||||
.negotiated_cipher_suite()
|
||||
.ok_or(TlsStreamInfoError::FieldsNotAvailable)?;
|
||||
.expect("TLS handshake is not done yet");
|
||||
|
||||
let sni_hostname = conn.sni_hostname().map(ToOwned::to_owned);
|
||||
let apln_protocol = conn.alpn_protocol().map(ToOwned::to_owned);
|
||||
let alpn_protocol = conn.alpn_protocol().map(ToOwned::to_owned);
|
||||
let peer_certificates = conn.peer_certificates().map(ToOwned::to_owned);
|
||||
Ok(Some(TlsStreamInfo {
|
||||
Some(TlsStreamInfo {
|
||||
protocol_version,
|
||||
negotiated_cipher_suite,
|
||||
sni_hostname,
|
||||
apln_protocol,
|
||||
alpn_protocol,
|
||||
peer_certificates,
|
||||
}))
|
||||
}
|
||||
|
||||
pub const fn is_tls_handshaking(&self) -> bool {
|
||||
matches!(self, Self::Handshaking(_))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -144,20 +114,9 @@ where
|
||||
cx: &mut Context,
|
||||
buf: &mut ReadBuf,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
let pin = self.get_mut();
|
||||
match pin {
|
||||
MaybeTlsStream::Handshaking(ref mut accept) => {
|
||||
match ready!(Pin::new(accept).poll(cx)) {
|
||||
Ok(mut stream) => {
|
||||
let result = Pin::new(&mut stream).poll_read(cx, buf);
|
||||
*pin = MaybeTlsStream::Streaming(stream);
|
||||
result
|
||||
}
|
||||
Err(err) => Poll::Ready(Err(err)),
|
||||
}
|
||||
}
|
||||
MaybeTlsStream::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
|
||||
MaybeTlsStream::Insecure(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
|
||||
match self.project() {
|
||||
MaybeTlsStreamProj::Secure { stream } => stream.poll_read(cx, buf),
|
||||
MaybeTlsStreamProj::Insecure { stream } => stream.poll_read(cx, buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -171,104 +130,89 @@ where
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<std::io::Result<usize>> {
|
||||
let pin = self.get_mut();
|
||||
match pin {
|
||||
MaybeTlsStream::Handshaking(ref mut accept) => {
|
||||
match ready!(Pin::new(accept).poll(cx)) {
|
||||
Ok(mut stream) => {
|
||||
let result = Pin::new(&mut stream).poll_write(cx, buf);
|
||||
*pin = MaybeTlsStream::Streaming(stream);
|
||||
result
|
||||
}
|
||||
Err(err) => Poll::Ready(Err(err)),
|
||||
match self.project() {
|
||||
MaybeTlsStreamProj::Secure { stream } => stream.poll_write(cx, buf),
|
||||
MaybeTlsStreamProj::Insecure { stream } => stream.poll_write(cx, buf),
|
||||
}
|
||||
}
|
||||
MaybeTlsStream::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
|
||||
MaybeTlsStream::Insecure(ref mut fallback) => Pin::new(fallback).poll_write(cx, buf),
|
||||
|
||||
fn poll_write_vectored(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
bufs: &[std::io::IoSlice<'_>],
|
||||
) -> Poll<Result<usize, std::io::Error>> {
|
||||
match self.project() {
|
||||
MaybeTlsStreamProj::Secure { stream } => stream.poll_write_vectored(cx, bufs),
|
||||
MaybeTlsStreamProj::Insecure { stream } => stream.poll_write_vectored(cx, bufs),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_write_vectored(&self) -> bool {
|
||||
match self {
|
||||
Self::Secure { stream } => stream.is_write_vectored(),
|
||||
Self::Insecure { stream } => stream.is_write_vectored(),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
match self.get_mut() {
|
||||
MaybeTlsStream::Handshaking { .. } => Poll::Ready(Ok(())),
|
||||
MaybeTlsStream::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx),
|
||||
MaybeTlsStream::Insecure(ref mut stream) => Pin::new(stream).poll_flush(cx),
|
||||
match self.project() {
|
||||
MaybeTlsStreamProj::Secure { stream } => stream.poll_flush(cx),
|
||||
MaybeTlsStreamProj::Insecure { stream } => stream.poll_flush(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
match self.get_mut() {
|
||||
MaybeTlsStream::Handshaking { .. } => Poll::Ready(Ok(())),
|
||||
MaybeTlsStream::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
|
||||
MaybeTlsStream::Insecure(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
|
||||
match self.project() {
|
||||
MaybeTlsStreamProj::Secure { stream } => stream.poll_shutdown(cx),
|
||||
MaybeTlsStreamProj::Insecure { stream } => stream.poll_shutdown(cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MaybeTlsAcceptor<T> {
|
||||
#[derive(Clone)]
|
||||
pub struct MaybeTlsAcceptor {
|
||||
tls_config: Option<Arc<ServerConfig>>,
|
||||
incoming: T,
|
||||
}
|
||||
|
||||
impl<T> MaybeTlsAcceptor<T> {
|
||||
pub fn new(tls_config: Option<Arc<ServerConfig>>, incoming: T) -> Self {
|
||||
Self {
|
||||
tls_config,
|
||||
incoming,
|
||||
}
|
||||
impl MaybeTlsAcceptor {
|
||||
#[must_use]
|
||||
pub fn new(tls_config: Option<Arc<ServerConfig>>) -> Self {
|
||||
Self { tls_config }
|
||||
}
|
||||
|
||||
pub fn new_secure(tls_config: Arc<ServerConfig>, incoming: T) -> Self {
|
||||
#[must_use]
|
||||
pub fn new_secure(tls_config: Arc<ServerConfig>) -> Self {
|
||||
Self {
|
||||
tls_config: Some(tls_config),
|
||||
incoming,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_insecure(incoming: T) -> Self {
|
||||
Self {
|
||||
tls_config: None,
|
||||
incoming,
|
||||
}
|
||||
#[must_use]
|
||||
pub fn new_insecure() -> Self {
|
||||
Self { tls_config: None }
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub const fn is_secure(&self) -> bool {
|
||||
self.tls_config.is_some()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Accept for MaybeTlsAcceptor<T>
|
||||
/// Accept a connection and do the TLS handshake
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the TLS handshake failed
|
||||
pub async fn accept<T>(&self, stream: T) -> Result<MaybeTlsStream<T>, std::io::Error>
|
||||
where
|
||||
T: Accept + Unpin,
|
||||
T::Conn: AsyncRead + AsyncWrite + Unpin,
|
||||
T::Error: Into<std::io::Error>,
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
type Conn = MaybeTlsStream<T::Conn>;
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn poll_accept(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
|
||||
let pin = self.get_mut();
|
||||
|
||||
let ret = match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) {
|
||||
Some(Ok(sock)) => {
|
||||
let config = pin.tls_config.clone();
|
||||
Some(Ok(MaybeTlsStream::new(sock, config)))
|
||||
match &self.tls_config {
|
||||
Some(config) => {
|
||||
let acceptor = TlsAcceptor::from(config.clone());
|
||||
let stream = acceptor.accept(stream).await?;
|
||||
Ok(MaybeTlsStream::Secure { stream })
|
||||
}
|
||||
|
||||
Some(Err(e)) => Some(Err(e.into())),
|
||||
None => None,
|
||||
};
|
||||
|
||||
Poll::Ready(ret)
|
||||
None => Ok(MaybeTlsStream::Insecure { stream }),
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Deref for MaybeTlsAcceptor<T> {
|
||||
type Target = T;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.incoming
|
||||
}
|
||||
}
|
||||
|
@ -12,41 +12,57 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use futures_util::ready;
|
||||
use hyper::server::accept::Accept;
|
||||
use bytes::BytesMut;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt};
|
||||
|
||||
use super::ProxyStream;
|
||||
use super::ProxyProtocolV1Info;
|
||||
use crate::rewind::Rewind;
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
pub struct ProxyAcceptor<A> {
|
||||
#[pin]
|
||||
inner: A,
|
||||
}
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ProxyAcceptor {
|
||||
_private: (),
|
||||
}
|
||||
|
||||
impl<A> ProxyAcceptor<A> {
|
||||
pub const fn new(inner: A) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
#[derive(Debug, Error)]
|
||||
#[error(transparent)]
|
||||
pub enum ProxyAcceptError {
|
||||
Parse(#[from] super::v1::ParseError),
|
||||
Read(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
impl<A> Accept for ProxyAcceptor<A>
|
||||
impl ProxyAcceptor {
|
||||
#[must_use]
|
||||
pub const fn new() -> Self {
|
||||
Self { _private: () }
|
||||
}
|
||||
|
||||
/// Accept a proxy-protocol stream
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error on read error on the underlying stream, or when the
|
||||
/// proxy protocol preamble couldn't be parsed
|
||||
pub async fn accept<T>(
|
||||
&self,
|
||||
mut stream: T,
|
||||
) -> Result<(ProxyProtocolV1Info, Rewind<T>), ProxyAcceptError>
|
||||
where
|
||||
A: Accept,
|
||||
T: AsyncRead + Unpin,
|
||||
{
|
||||
type Conn = ProxyStream<A::Conn>;
|
||||
type Error = A::Error;
|
||||
let mut buf = BytesMut::new();
|
||||
let info = loop {
|
||||
stream.read_buf(&mut buf).await?;
|
||||
|
||||
fn poll_accept(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Option<Result<Self::Conn, Self::Error>>> {
|
||||
let res = match ready!(self.project().inner.poll_accept(cx)) {
|
||||
Some(Ok(stream)) => Some(Ok(ProxyStream::new(stream))),
|
||||
Some(Err(e)) => Some(Err(e)),
|
||||
None => None,
|
||||
match ProxyProtocolV1Info::parse(&mut buf) {
|
||||
Ok(info) => break info,
|
||||
Err(e) if e.not_enough_bytes() => {}
|
||||
Err(e) => return Err(e.into()),
|
||||
}
|
||||
};
|
||||
|
||||
std::task::Poll::Ready(res)
|
||||
let stream = Rewind::new_buffered(stream, buf.into());
|
||||
|
||||
Ok((info, stream))
|
||||
}
|
||||
}
|
||||
|
@ -12,179 +12,66 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{
|
||||
ops::Deref,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tokio::io::AsyncRead;
|
||||
|
||||
use futures_util::ready;
|
||||
use hyper::server::accept::Accept;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use super::{acceptor::ProxyAcceptError, ProxyAcceptor, ProxyProtocolV1Info};
|
||||
use crate::rewind::Rewind;
|
||||
|
||||
use super::{stream::HandshakeNotDone, ProxyProtocolV1Info, ProxyStream};
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
pub struct MaybeProxyAcceptor<A> {
|
||||
proxied: bool,
|
||||
|
||||
#[pin]
|
||||
inner: A,
|
||||
}
|
||||
#[derive(Clone)]
|
||||
pub struct MaybeProxyAcceptor {
|
||||
acceptor: Option<ProxyAcceptor>,
|
||||
}
|
||||
|
||||
impl<A> MaybeProxyAcceptor<A> {
|
||||
impl MaybeProxyAcceptor {
|
||||
#[must_use]
|
||||
pub const fn new(inner: A, proxied: bool) -> Self {
|
||||
Self { proxied, inner }
|
||||
}
|
||||
|
||||
pub const fn is_proxied(&self) -> bool {
|
||||
self.proxied
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> Deref for MaybeProxyAcceptor<A> {
|
||||
type Target = A;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> Accept for MaybeProxyAcceptor<A>
|
||||
where
|
||||
A: Accept,
|
||||
{
|
||||
type Conn = MaybeProxyStream<A::Conn>;
|
||||
type Error = A::Error;
|
||||
|
||||
fn poll_accept(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Option<Result<Self::Conn, Self::Error>>> {
|
||||
let proj = self.project();
|
||||
let res = match ready!(proj.inner.poll_accept(cx)) {
|
||||
Some(Ok(stream)) => Some(Ok(MaybeProxyStream::new(stream, *proj.proxied))),
|
||||
Some(Err(e)) => Some(Err(e)),
|
||||
None => None,
|
||||
pub const fn new(proxied: bool) -> Self {
|
||||
let acceptor = if proxied {
|
||||
Some(ProxyAcceptor::new())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
std::task::Poll::Ready(res)
|
||||
Self { acceptor }
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub const fn new_proxied(acceptor: ProxyAcceptor) -> Self {
|
||||
Self {
|
||||
acceptor: Some(acceptor),
|
||||
}
|
||||
}
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
#[project = MaybeProxyStreamProj]
|
||||
pub enum MaybeProxyStream<S> {
|
||||
Proxied { #[pin] stream: ProxyStream<S> },
|
||||
NotProxied { #[pin] stream: S },
|
||||
}
|
||||
#[must_use]
|
||||
pub const fn new_unproxied() -> Self {
|
||||
Self { acceptor: None }
|
||||
}
|
||||
|
||||
impl<S> MaybeProxyStream<S> {
|
||||
pub const fn new(stream: S, proxied: bool) -> Self {
|
||||
if proxied {
|
||||
Self::Proxied {
|
||||
stream: ProxyStream::new(stream),
|
||||
}
|
||||
} else {
|
||||
Self::NotProxied { stream }
|
||||
}
|
||||
#[must_use]
|
||||
pub const fn is_proxied(&self) -> bool {
|
||||
self.acceptor.is_some()
|
||||
}
|
||||
|
||||
/// Get informations from the proxied connection, if it was procied
|
||||
/// Accept a connection and do the proxy protocol handshake
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the stream did not complete the handshake yet
|
||||
pub fn proxy_info(&self) -> Result<Option<&ProxyProtocolV1Info>, HandshakeNotDone> {
|
||||
match self {
|
||||
Self::Proxied { stream } => Ok(Some(stream.proxy_info()?)),
|
||||
Self::NotProxied { .. } => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
pub const fn is_proxy_handshaking(&self) -> bool {
|
||||
match self {
|
||||
Self::Proxied { stream } => stream.is_handshaking(),
|
||||
Self::NotProxied { .. } => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Deref for MaybeProxyStream<S> {
|
||||
type Target = S;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
match self {
|
||||
Self::Proxied { stream } => &**stream,
|
||||
Self::NotProxied { stream } => stream,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> AsyncRead for MaybeProxyStream<S>
|
||||
/// Returns an error if the proxy protocol handshake failed
|
||||
pub async fn accept<T>(
|
||||
&self,
|
||||
stream: T,
|
||||
) -> Result<(Option<ProxyProtocolV1Info>, Rewind<T>), ProxyAcceptError>
|
||||
where
|
||||
S: AsyncRead,
|
||||
T: AsyncRead + Unpin,
|
||||
{
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
match self.project() {
|
||||
MaybeProxyStreamProj::Proxied { stream } => stream.poll_read(cx, buf),
|
||||
MaybeProxyStreamProj::NotProxied { stream } => stream.poll_read(cx, buf),
|
||||
match &self.acceptor {
|
||||
Some(acceptor) => {
|
||||
let (info, stream) = acceptor.accept(stream).await?;
|
||||
Ok((Some(info), stream))
|
||||
}
|
||||
None => {
|
||||
let stream = Rewind::new(stream);
|
||||
Ok((None, stream))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> AsyncWrite for MaybeProxyStream<S>
|
||||
where
|
||||
S: AsyncWrite,
|
||||
{
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, std::io::Error>> {
|
||||
match self.project() {
|
||||
MaybeProxyStreamProj::Proxied { stream } => stream.poll_write(cx, buf),
|
||||
MaybeProxyStreamProj::NotProxied { stream } => stream.poll_write(cx, buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
|
||||
match self.project() {
|
||||
MaybeProxyStreamProj::Proxied { stream } => stream.poll_flush(cx),
|
||||
MaybeProxyStreamProj::NotProxied { stream } => stream.poll_flush(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
match self.project() {
|
||||
MaybeProxyStreamProj::Proxied { stream } => stream.poll_shutdown(cx),
|
||||
MaybeProxyStreamProj::NotProxied { stream } => stream.poll_shutdown(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_write_vectored(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
bufs: &[std::io::IoSlice<'_>],
|
||||
) -> Poll<Result<usize, std::io::Error>> {
|
||||
match self.project() {
|
||||
MaybeProxyStreamProj::Proxied { stream } => stream.poll_write_vectored(cx, bufs),
|
||||
MaybeProxyStreamProj::NotProxied { stream } => stream.poll_write_vectored(cx, bufs),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_write_vectored(&self) -> bool {
|
||||
match self {
|
||||
MaybeProxyStream::Proxied { stream } => stream.is_write_vectored(),
|
||||
MaybeProxyStream::NotProxied { stream } => stream.is_write_vectored(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -14,12 +14,10 @@
|
||||
|
||||
mod acceptor;
|
||||
mod maybe;
|
||||
mod stream;
|
||||
mod v1;
|
||||
|
||||
pub use self::{
|
||||
acceptor::ProxyAcceptor,
|
||||
maybe::{MaybeProxyAcceptor, MaybeProxyStream},
|
||||
stream::{HandshakeNotDone as ProxyHandshakeNotDone, ProxyStream},
|
||||
acceptor::{ProxyAcceptError, ProxyAcceptor},
|
||||
maybe::MaybeProxyAcceptor,
|
||||
v1::ProxyProtocolV1Info,
|
||||
};
|
||||
|
@ -1,169 +0,0 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::ops::Deref;
|
||||
|
||||
use futures_util::ready;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
|
||||
use super::ProxyProtocolV1Info;
|
||||
|
||||
// Max theorical size we need is 108 for proxy protocol v1
|
||||
const BUF_SIZE: usize = 256;
|
||||
|
||||
#[derive(Debug)]
|
||||
enum ProxyStreamState {
|
||||
Handshaking {
|
||||
buffer: [u8; BUF_SIZE],
|
||||
index: usize,
|
||||
},
|
||||
Established(ProxyProtocolV1Info),
|
||||
}
|
||||
|
||||
impl ProxyStreamState {
|
||||
pub const fn is_handshaking(&self) -> bool {
|
||||
matches!(self, Self::Handshaking { .. })
|
||||
}
|
||||
}
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
#[derive(Debug)]
|
||||
pub struct ProxyStream<S> {
|
||||
state: ProxyStreamState,
|
||||
|
||||
#[pin]
|
||||
inner: S,
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> ProxyStream<S> {
|
||||
pub const fn new(inner: S) -> Self {
|
||||
Self {
|
||||
state: ProxyStreamState::Handshaking {
|
||||
buffer: [0; BUF_SIZE],
|
||||
index: 0,
|
||||
},
|
||||
inner,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error, Clone, Copy)]
|
||||
#[error("Proxy protocol handshake is not complete")]
|
||||
pub struct HandshakeNotDone;
|
||||
|
||||
impl<S> Deref for ProxyStream<S> {
|
||||
type Target = S;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> ProxyStream<S> {
|
||||
/// Get informations from the proxied connection
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the stream did not complete the handshake yet
|
||||
pub fn proxy_info(&self) -> Result<&ProxyProtocolV1Info, HandshakeNotDone> {
|
||||
match &self.state {
|
||||
ProxyStreamState::Handshaking { .. } => Err(HandshakeNotDone),
|
||||
ProxyStreamState::Established(info) => Ok(info),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `true` if the proxy protocol is still handshaking
|
||||
pub const fn is_handshaking(&self) -> bool {
|
||||
self.state.is_handshaking()
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> AsyncRead for ProxyStream<S>
|
||||
where
|
||||
S: AsyncRead,
|
||||
{
|
||||
fn poll_read(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> std::task::Poll<std::io::Result<()>> {
|
||||
let proj = self.project();
|
||||
match proj.state {
|
||||
ProxyStreamState::Handshaking { buffer, index } => {
|
||||
let mut buffer = ReadBuf::new(&mut buffer[..]);
|
||||
buffer.advance(*index);
|
||||
ready!(proj.inner.poll_read(cx, &mut buffer))?;
|
||||
let filled = buffer.filled();
|
||||
*index = filled.len();
|
||||
|
||||
match ProxyProtocolV1Info::parse(filled) {
|
||||
Ok((info, rest)) => {
|
||||
if buf.remaining() < rest.len() {
|
||||
// This is highly unlikely, but is better than panicking later.
|
||||
// If it ever happens, we could introduce a "buffer draining" state
|
||||
// which drains the inner buffer repeatedly until it's empty
|
||||
return std::task::Poll::Ready(Err(std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
"underlying buffer is too small",
|
||||
)));
|
||||
}
|
||||
buf.put_slice(rest);
|
||||
*proj.state = ProxyStreamState::Established(info);
|
||||
std::task::Poll::Ready(Ok(()))
|
||||
}
|
||||
Err(e) if e.not_enough_bytes() => std::task::Poll::Ready(Ok(())),
|
||||
Err(e) => std::task::Poll::Ready(Err(std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
e,
|
||||
))),
|
||||
}
|
||||
}
|
||||
ProxyStreamState::Established(_) => proj.inner.poll_read(cx, buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> AsyncWrite for ProxyStream<S>
|
||||
where
|
||||
S: AsyncWrite,
|
||||
{
|
||||
fn poll_write(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> std::task::Poll<Result<usize, std::io::Error>> {
|
||||
let proj = self.project();
|
||||
match proj.state {
|
||||
// Hold off writes until the handshake is done
|
||||
// XXX: is this the right way to do it?
|
||||
ProxyStreamState::Handshaking { .. } => std::task::Poll::Pending,
|
||||
ProxyStreamState::Established(_) => proj.inner.poll_write(cx, buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), std::io::Error>> {
|
||||
self.project().inner.poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), std::io::Error>> {
|
||||
self.project().inner.poll_shutdown(cx)
|
||||
}
|
||||
}
|
@ -18,6 +18,7 @@ use std::{
|
||||
str::Utf8Error,
|
||||
};
|
||||
|
||||
use bytes::Buf;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -35,7 +36,7 @@ pub enum ProxyProtocolV1Info {
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
#[error("Invalid proxy protocol header")]
|
||||
pub(super) enum ParseError {
|
||||
pub enum ParseError {
|
||||
#[error("Not enough bytes provided")]
|
||||
NotEnoughBytes,
|
||||
NoCrLf,
|
||||
@ -60,17 +61,21 @@ impl ParseError {
|
||||
|
||||
impl ProxyProtocolV1Info {
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub(super) fn parse(bytes: &[u8]) -> Result<(Self, &[u8]), ParseError> {
|
||||
pub(super) fn parse<B>(mut buf: B) -> Result<Self, ParseError>
|
||||
where
|
||||
B: Buf + AsRef<[u8]>,
|
||||
{
|
||||
use ParseError as E;
|
||||
// First, check if we *possibly* have enough bytes.
|
||||
// Minimum is 15: "PROXY UNKNOWN\r\n"
|
||||
|
||||
if bytes.len() < 15 {
|
||||
if buf.remaining() < 15 {
|
||||
return Err(E::NotEnoughBytes);
|
||||
}
|
||||
|
||||
// Let's check in the first 108 bytes if we find a CRLF
|
||||
let crlf = if let Some(crlf) = bytes
|
||||
let crlf = if let Some(crlf) = buf
|
||||
.as_ref()
|
||||
.windows(2)
|
||||
.take(108)
|
||||
.position(|needle| needle == [0x0D, 0x0A])
|
||||
@ -78,7 +83,7 @@ impl ProxyProtocolV1Info {
|
||||
crlf
|
||||
} else {
|
||||
// If not, it might be because we don't have enough bytes
|
||||
return if bytes.len() < 108 {
|
||||
return if buf.remaining() < 108 {
|
||||
Err(E::NotEnoughBytes)
|
||||
} else {
|
||||
// Else it's just invalid
|
||||
@ -86,10 +91,8 @@ impl ProxyProtocolV1Info {
|
||||
};
|
||||
};
|
||||
|
||||
// Keep the rest of the buffer to pass it to the underlying protocol
|
||||
let rest = &bytes[crlf + 2..];
|
||||
// Trim to everything before the CRLF
|
||||
let bytes = &bytes[..crlf];
|
||||
let bytes = &buf.as_ref()[..crlf];
|
||||
|
||||
let mut it = bytes.splitn(6, |c| c == &b' ');
|
||||
// Check for the preamble
|
||||
@ -187,7 +190,9 @@ impl ProxyProtocolV1Info {
|
||||
None => return Err(E::NoProtocol),
|
||||
};
|
||||
|
||||
Ok((result, rest))
|
||||
buf.advance(crlf + 2);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
@ -258,39 +263,42 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_parse() {
|
||||
let (info, rest) = ProxyProtocolV1Info::parse(
|
||||
b"PROXY TCP4 255.255.255.255 255.255.255.255 65535 65535\r\nhello world",
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(rest, b"hello world");
|
||||
let mut buf =
|
||||
b"PROXY TCP4 255.255.255.255 255.255.255.255 65535 65535\r\nhello world"
|
||||
.as_slice();
|
||||
let info = ProxyProtocolV1Info::parse(&mut buf).unwrap();
|
||||
assert_eq!(buf, b"hello world");
|
||||
assert!(info.is_tcp());
|
||||
assert!(!info.is_udp());
|
||||
assert!(!info.is_unknown());
|
||||
assert!(info.is_ipv4());
|
||||
assert!(!info.is_ipv6());
|
||||
|
||||
let (info, rest) = ProxyProtocolV1Info::parse(
|
||||
let mut buf =
|
||||
b"PROXY TCP6 ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\nhello world"
|
||||
).unwrap();
|
||||
assert_eq!(rest, b"hello world");
|
||||
.as_slice();
|
||||
let info = ProxyProtocolV1Info::parse(&mut buf).unwrap();
|
||||
assert_eq!(buf, b"hello world");
|
||||
assert!(info.is_tcp());
|
||||
assert!(!info.is_udp());
|
||||
assert!(!info.is_unknown());
|
||||
assert!(!info.is_ipv4());
|
||||
assert!(info.is_ipv6());
|
||||
|
||||
let (info, rest) = ProxyProtocolV1Info::parse(b"PROXY UNKNOWN\r\nhello world").unwrap();
|
||||
assert_eq!(rest, b"hello world");
|
||||
let mut buf = b"PROXY UNKNOWN\r\nhello world".as_slice();
|
||||
let info = ProxyProtocolV1Info::parse(&mut buf).unwrap();
|
||||
assert_eq!(buf, b"hello world");
|
||||
assert!(!info.is_tcp());
|
||||
assert!(!info.is_udp());
|
||||
assert!(info.is_unknown());
|
||||
assert!(!info.is_ipv4());
|
||||
assert!(!info.is_ipv6());
|
||||
|
||||
let (info, rest) = ProxyProtocolV1Info::parse(
|
||||
let mut buf =
|
||||
b"PROXY UNKNOWN ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\nhello world"
|
||||
).unwrap();
|
||||
assert_eq!(rest, b"hello world");
|
||||
.as_slice();
|
||||
let info = ProxyProtocolV1Info::parse(&mut buf).unwrap();
|
||||
assert_eq!(buf, b"hello world");
|
||||
assert!(!info.is_tcp());
|
||||
assert!(!info.is_udp());
|
||||
assert!(info.is_unknown());
|
||||
|
151
crates/listener/src/rewind.rs
Normal file
151
crates/listener/src/rewind.rs
Normal file
@ -0,0 +1,151 @@
|
||||
// Taken from hyper@0.14.20, src/common/io/rewind.rs
|
||||
|
||||
use std::{
|
||||
cmp, io,
|
||||
marker::Unpin,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use bytes::{Buf, Bytes};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
|
||||
/// Combine a buffer with an IO, rewinding reads to use the buffer.
|
||||
#[derive(Debug)]
|
||||
pub struct Rewind<T> {
|
||||
pre: Option<Bytes>,
|
||||
inner: T,
|
||||
}
|
||||
|
||||
impl<T> Rewind<T> {
|
||||
pub(crate) fn new(io: T) -> Self {
|
||||
Rewind {
|
||||
pre: None,
|
||||
inner: io,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn new_buffered(io: T, buf: Bytes) -> Self {
|
||||
Rewind {
|
||||
pre: Some(buf),
|
||||
inner: io,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn rewind(&mut self, bs: Bytes) {
|
||||
debug_assert!(self.pre.is_none());
|
||||
self.pre = Some(bs);
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> AsyncRead for Rewind<T>
|
||||
where
|
||||
T: AsyncRead + Unpin,
|
||||
{
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
if let Some(mut prefix) = self.pre.take() {
|
||||
// If there are no remaining bytes, let the bytes get dropped.
|
||||
if !prefix.is_empty() {
|
||||
let copy_len = cmp::min(prefix.len(), buf.remaining());
|
||||
// TODO: There should be a way to do following two lines cleaner...
|
||||
buf.put_slice(&prefix[..copy_len]);
|
||||
prefix.advance(copy_len);
|
||||
// Put back what's left
|
||||
if !prefix.is_empty() {
|
||||
self.pre = Some(prefix);
|
||||
}
|
||||
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
}
|
||||
Pin::new(&mut self.inner).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> AsyncWrite for Rewind<T>
|
||||
where
|
||||
T: AsyncWrite + Unpin,
|
||||
{
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
Pin::new(&mut self.inner).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_write_vectored(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
bufs: &[io::IoSlice<'_>],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_shutdown(cx)
|
||||
}
|
||||
|
||||
fn is_write_vectored(&self) -> bool {
|
||||
self.inner.is_write_vectored()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
// FIXME: re-implement tests with `async/await`, this import should
|
||||
// trigger a warning to remind us
|
||||
use bytes::Bytes;
|
||||
use tokio::io::AsyncReadExt;
|
||||
|
||||
use super::Rewind;
|
||||
|
||||
#[tokio::test]
|
||||
async fn partial_rewind() {
|
||||
let underlying = [104, 101, 108, 108, 111];
|
||||
|
||||
let mock = tokio_test::io::Builder::new().read(&underlying).build();
|
||||
|
||||
let mut stream = Rewind::new(mock);
|
||||
|
||||
// Read off some bytes, ensure we filled o1
|
||||
let mut buf = [0; 2];
|
||||
stream.read_exact(&mut buf).await.expect("read1");
|
||||
|
||||
// Rewind the stream so that it is as if we never read in the first place.
|
||||
stream.rewind(Bytes::copy_from_slice(&buf[..]));
|
||||
|
||||
let mut buf = [0; 5];
|
||||
stream.read_exact(&mut buf).await.expect("read1");
|
||||
|
||||
// At this point we should have read everything that was in the MockStream
|
||||
assert_eq!(&buf, &underlying);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn full_rewind() {
|
||||
let underlying = [104, 101, 108, 108, 111];
|
||||
|
||||
let mock = tokio_test::io::Builder::new().read(&underlying).build();
|
||||
|
||||
let mut stream = Rewind::new(mock);
|
||||
|
||||
let mut buf = [0; 5];
|
||||
stream.read_exact(&mut buf).await.expect("read1");
|
||||
|
||||
// Rewind the stream so that it is as if we never read in the first place.
|
||||
stream.rewind(Bytes::copy_from_slice(&buf[..]));
|
||||
|
||||
let mut buf = [0; 5];
|
||||
stream.read_exact(&mut buf).await.expect("read1");
|
||||
}
|
||||
}
|
301
crates/listener/src/server.rs
Normal file
301
crates/listener/src/server.rs
Normal file
@ -0,0 +1,301 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures_util::{Stream, StreamExt};
|
||||
use http_body::Body;
|
||||
use hyper::{Request, Response};
|
||||
use thiserror::Error;
|
||||
use tokio_rustls::rustls::ServerConfig;
|
||||
use tower_http::add_extension::AddExtension;
|
||||
use tower_service::Service;
|
||||
|
||||
use crate::{
|
||||
maybe_tls::{MaybeTlsAcceptor, TlsStreamInfo},
|
||||
proxy_protocol::{MaybeProxyAcceptor, ProxyAcceptError},
|
||||
unix_or_tcp::{SocketAddr, UnixOrTcpConnection, UnixOrTcpListener},
|
||||
ConnectionInfo,
|
||||
};
|
||||
|
||||
pub struct Server<S> {
|
||||
tls: Option<Arc<ServerConfig>>,
|
||||
proxy: bool,
|
||||
listener: UnixOrTcpListener,
|
||||
service: S,
|
||||
}
|
||||
|
||||
impl<S> Server<S> {
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the listener couldn't be converted via [`TryInfo`]
|
||||
pub fn try_new<L>(listener: L, service: S) -> Result<Self, L::Error>
|
||||
where
|
||||
L: TryInto<UnixOrTcpListener>,
|
||||
{
|
||||
Ok(Self {
|
||||
tls: None,
|
||||
proxy: false,
|
||||
listener: listener.try_into()?,
|
||||
service,
|
||||
})
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn new(listener: impl Into<UnixOrTcpListener>, service: S) -> Self {
|
||||
Self {
|
||||
tls: None,
|
||||
proxy: false,
|
||||
listener: listener.into(),
|
||||
service,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub const fn with_proxy(mut self) -> Self {
|
||||
self.proxy = true;
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_tls(mut self, config: Arc<ServerConfig>) -> Self {
|
||||
self.tls = Some(config);
|
||||
self
|
||||
}
|
||||
|
||||
/// Run a single server
|
||||
pub async fn run<B, SD>(self, shutdown: SD)
|
||||
where
|
||||
S: Service<Request<hyper::Body>, Response = Response<B>> + Clone + Send + 'static,
|
||||
S::Future: Send + 'static,
|
||||
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
|
||||
B: Body + Send + 'static,
|
||||
B::Data: Send,
|
||||
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
|
||||
SD: Stream + Unpin,
|
||||
SD::Item: std::fmt::Display,
|
||||
{
|
||||
run_servers(std::iter::once(self), shutdown).await;
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[non_exhaustive]
|
||||
enum AcceptError {
|
||||
#[error("failed to accept connection from the underlying socket")]
|
||||
Socket {
|
||||
#[source]
|
||||
source: std::io::Error,
|
||||
},
|
||||
|
||||
#[error("failed to complete the TLS handshake")]
|
||||
TlsHandshake {
|
||||
#[source]
|
||||
source: std::io::Error,
|
||||
},
|
||||
|
||||
#[error("failed to complete the proxy protocol handshake")]
|
||||
ProxyHandshake {
|
||||
#[source]
|
||||
source: ProxyAcceptError,
|
||||
},
|
||||
|
||||
#[error(transparent)]
|
||||
Hyper(#[from] hyper::Error),
|
||||
}
|
||||
|
||||
impl AcceptError {
|
||||
fn socket(source: std::io::Error) -> Self {
|
||||
Self::Socket { source }
|
||||
}
|
||||
|
||||
fn tls_handshake(source: std::io::Error) -> Self {
|
||||
Self::TlsHandshake { source }
|
||||
}
|
||||
|
||||
fn proxy_handshake(source: ProxyAcceptError) -> Self {
|
||||
Self::ProxyHandshake { source }
|
||||
}
|
||||
}
|
||||
|
||||
async fn accept<S, B>(
|
||||
maybe_proxy_acceptor: &MaybeProxyAcceptor,
|
||||
maybe_tls_acceptor: &MaybeTlsAcceptor,
|
||||
peer_addr: SocketAddr,
|
||||
stream: UnixOrTcpConnection,
|
||||
service: S,
|
||||
) -> Result<(), AcceptError>
|
||||
where
|
||||
S: Service<Request<hyper::Body>, Response = Response<B>>,
|
||||
S::Future: Send + 'static,
|
||||
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
|
||||
B: Body + Send + 'static,
|
||||
B::Data: Send,
|
||||
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
|
||||
{
|
||||
let (proxy, stream) = maybe_proxy_acceptor
|
||||
.accept(stream)
|
||||
.await
|
||||
.map_err(AcceptError::proxy_handshake)?;
|
||||
|
||||
let stream = maybe_tls_acceptor
|
||||
.accept(stream)
|
||||
.await
|
||||
.map_err(AcceptError::tls_handshake)?;
|
||||
|
||||
let tls = stream.tls_info();
|
||||
|
||||
// Figure out if it's HTTP/2 based on the negociated ALPN info
|
||||
let is_h2 = tls.as_ref().map_or(false, TlsStreamInfo::is_alpn_h2);
|
||||
|
||||
let info = ConnectionInfo {
|
||||
tls,
|
||||
proxy,
|
||||
net_peer_addr: peer_addr.into_net(),
|
||||
};
|
||||
|
||||
let service = AddExtension::new(service, info);
|
||||
|
||||
if is_h2 {
|
||||
hyper::server::conn::Http::new()
|
||||
.http2_only(true)
|
||||
.serve_connection(stream, service)
|
||||
.await?;
|
||||
} else {
|
||||
hyper::server::conn::Http::new()
|
||||
.http1_only(true)
|
||||
.http1_keep_alive(false)
|
||||
.serve_connection(stream, service)
|
||||
.await?;
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run_servers<S, B, SD>(listeners: impl IntoIterator<Item = Server<S>>, mut shutdown: SD)
|
||||
where
|
||||
S: Service<Request<hyper::Body>, Response = Response<B>> + Clone + Send + 'static,
|
||||
S::Future: Send + 'static,
|
||||
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
|
||||
B: Body + Send + 'static,
|
||||
B::Data: Send,
|
||||
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
|
||||
SD: Stream + Unpin,
|
||||
SD::Item: std::fmt::Display,
|
||||
{
|
||||
let listeners: Vec<_> = listeners
|
||||
.into_iter()
|
||||
.map(|server| {
|
||||
let maybe_proxy_acceptor = MaybeProxyAcceptor::new(server.proxy);
|
||||
let maybe_tls_acceptor = MaybeTlsAcceptor::new(server.tls);
|
||||
let service = server.service;
|
||||
let listener = server.listener;
|
||||
(maybe_proxy_acceptor, maybe_tls_acceptor, service, listener)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut set = tokio::task::JoinSet::new();
|
||||
|
||||
loop {
|
||||
let mut accept_all: futures_util::stream::FuturesUnordered<_> = listeners
|
||||
.iter()
|
||||
.map(
|
||||
|(maybe_proxy_acceptor, maybe_tls_acceptor, service, listener)| async move {
|
||||
listener
|
||||
.accept()
|
||||
.await
|
||||
.map_err(AcceptError::socket)
|
||||
.map(|(addr, conn)| {
|
||||
(
|
||||
maybe_proxy_acceptor.clone(),
|
||||
maybe_tls_acceptor.clone(),
|
||||
service.clone(),
|
||||
addr,
|
||||
conn,
|
||||
)
|
||||
})
|
||||
},
|
||||
)
|
||||
.collect();
|
||||
|
||||
tokio::select! {
|
||||
biased;
|
||||
|
||||
// First look for the shutdown signal
|
||||
res = shutdown.next() => {
|
||||
let why = res.map_or_else(|| String::from("???"), |why| format!("{why}"));
|
||||
tracing::info!("Received shutdown signal ({why})");
|
||||
|
||||
break;
|
||||
},
|
||||
|
||||
// Poll on the JoinSet, clearing finished task
|
||||
res = set.join_next(), if !set.is_empty() => {
|
||||
match res {
|
||||
Some(Ok(Ok(()))) => tracing::trace!("Task was successful"),
|
||||
Some(Ok(Err(e))) => tracing::error!("{e}"),
|
||||
Some(Err(e)) => tracing::error!("Join error: {e}"),
|
||||
None => tracing::error!("Join set was polled even though it was empty"),
|
||||
}
|
||||
},
|
||||
|
||||
// Then look for connections to accept
|
||||
res = accept_all.next(), if !accept_all.is_empty() => {
|
||||
// SAFETY: We shouldn't reach this branch if the unordered future set is empty
|
||||
let res = if let Some(res) = res { res } else { unreachable!() };
|
||||
|
||||
// Spawn the connection in the set, so we don't have to wait for the handshake to
|
||||
// accept the next connection. This allows us to keep track of active connections
|
||||
// and waiting on them for a graceful shutdown
|
||||
set.spawn(async move {
|
||||
let (maybe_proxy_acceptor, maybe_tls_acceptor, service, peer_addr, stream) = res?;
|
||||
accept(&maybe_proxy_acceptor, &maybe_tls_acceptor, peer_addr, stream, service).await
|
||||
});
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
if !set.is_empty() {
|
||||
tracing::info!(
|
||||
"There are {active} active connections, performing a graceful shutdown. Send the shutdown signal again to force.",
|
||||
active = set.len()
|
||||
);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
biased;
|
||||
|
||||
res = set.join_next() => {
|
||||
match res {
|
||||
Some(Ok(Ok(()))) => tracing::trace!("Task was successful"),
|
||||
Some(Ok(Err(e))) => tracing::error!("{e}"),
|
||||
Some(Err(e)) => tracing::error!("Join error: {e}"),
|
||||
// No more tasks, going out
|
||||
None => break,
|
||||
}
|
||||
},
|
||||
|
||||
|
||||
res = shutdown.next() => {
|
||||
let why = res.map_or_else(|| String::from("???"), |why| format!("{why}"));
|
||||
tracing::warn!("Received shutdown signal again ({why}), forcing shutdown ({active} active connections)", active = set.len());
|
||||
break;
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
set.shutdown().await;
|
||||
tracing::info!("Shutdown complete");
|
||||
}
|
178
crates/listener/src/shutdown.rs
Normal file
178
crates/listener/src/shutdown.rs
Normal file
@ -0,0 +1,178 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{fmt::Display, pin::Pin, task::Poll, time::Duration};
|
||||
|
||||
use futures_util::{ready, Future, Stream};
|
||||
use tokio::{
|
||||
signal::unix::{signal, Signal, SignalKind},
|
||||
time::Sleep,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum ShutdownReason {
|
||||
Signal(SignalKind),
|
||||
Timeout,
|
||||
}
|
||||
|
||||
fn signal_to_str(kind: SignalKind) -> &'static str {
|
||||
match kind.as_raw_value() {
|
||||
libc::SIGALRM => "SIGALRM",
|
||||
libc::SIGCHLD => "SIGCHLD",
|
||||
libc::SIGHUP => "SIGHUP",
|
||||
libc::SIGINT => "SIGINT",
|
||||
libc::SIGIO => "SIGIO",
|
||||
libc::SIGPIPE => "SIGPIPE",
|
||||
libc::SIGQUIT => "SIGQUIT",
|
||||
libc::SIGTERM => "SIGTERM",
|
||||
libc::SIGUSR1 => "SIGUSR1",
|
||||
libc::SIGUSR2 => "SIGUSR2",
|
||||
_ => "SIG???",
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for ShutdownReason {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Signal(s) => signal_to_str(*s).fmt(f),
|
||||
Self::Timeout => "timeout".fmt(f),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub enum ShutdownStreamState {
|
||||
#[default]
|
||||
Waiting,
|
||||
|
||||
Graceful {
|
||||
sleep: Option<Pin<Box<Sleep>>>,
|
||||
},
|
||||
|
||||
Done,
|
||||
}
|
||||
|
||||
impl ShutdownStreamState {
|
||||
fn is_graceful(&self) -> bool {
|
||||
matches!(self, Self::Graceful { .. })
|
||||
}
|
||||
|
||||
fn is_done(&self) -> bool {
|
||||
matches!(self, Self::Done)
|
||||
}
|
||||
|
||||
fn get_sleep_mut(&mut self) -> Option<&mut Pin<Box<Sleep>>> {
|
||||
match self {
|
||||
Self::Graceful { sleep } => sleep.as_mut(),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A stream which is used to drive a graceful shutdown.
|
||||
///
|
||||
/// It will emit 2 items: one when a first signal is caught, the other when
|
||||
/// either another signal is caught, or after a timeout.
|
||||
#[derive(Default)]
|
||||
pub struct ShutdownStream {
|
||||
state: ShutdownStreamState,
|
||||
signals: Vec<(SignalKind, Signal)>,
|
||||
timeout: Option<Duration>,
|
||||
}
|
||||
|
||||
impl ShutdownStream {
|
||||
/// Create a default shutdown stream, which listens on SIGINT and SIGTERM,
|
||||
/// with a 60s timeout
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if signal handlers could not be installed
|
||||
pub fn new() -> Result<Self, std::io::Error> {
|
||||
let ret = Self::default()
|
||||
.with_timeout(Duration::from_secs(60))
|
||||
.with_signal(SignalKind::interrupt())?
|
||||
.with_signal(SignalKind::terminate())?;
|
||||
|
||||
Ok(ret)
|
||||
}
|
||||
|
||||
/// Add a signal to register
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the signal handler could not be installed
|
||||
pub fn with_signal(mut self, kind: SignalKind) -> Result<Self, std::io::Error> {
|
||||
let signal = signal(kind)?;
|
||||
self.signals.push((kind, signal));
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_timeout(mut self, timeout: Duration) -> Self {
|
||||
self.timeout = Some(timeout);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for ShutdownStream {
|
||||
type Item = ShutdownReason;
|
||||
|
||||
fn size_hint(&self) -> (usize, Option<usize>) {
|
||||
match self.state {
|
||||
ShutdownStreamState::Waiting => (2, Some(2)),
|
||||
ShutdownStreamState::Graceful { .. } => (1, Some(1)),
|
||||
ShutdownStreamState::Done => (0, Some(0)),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_next(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Option<Self::Item>> {
|
||||
let this = self.get_mut();
|
||||
|
||||
if this.state.is_done() {
|
||||
return Poll::Ready(None);
|
||||
}
|
||||
|
||||
for (kind, signal) in &mut this.signals {
|
||||
match signal.poll_recv(cx) {
|
||||
Poll::Ready(_) => {
|
||||
// We got a signal
|
||||
if this.state.is_graceful() {
|
||||
// If we was gracefully shutting down, mark it as done
|
||||
this.state = ShutdownStreamState::Done;
|
||||
} else {
|
||||
// Else start the timeout
|
||||
let sleep = this
|
||||
.timeout
|
||||
.map(|duration| Box::pin(tokio::time::sleep(duration)));
|
||||
this.state = ShutdownStreamState::Graceful { sleep };
|
||||
}
|
||||
|
||||
return Poll::Ready(Some(ShutdownReason::Signal(*kind)));
|
||||
}
|
||||
Poll::Pending => {}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(timeout) = this.state.get_sleep_mut() {
|
||||
ready!(timeout.as_mut().poll(cx));
|
||||
this.state = ShutdownStreamState::Done;
|
||||
Poll::Ready(Some(ShutdownReason::Timeout))
|
||||
} else {
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
@ -12,6 +12,8 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! A listener which can listen on either TCP sockets or on UNIX domain sockets
|
||||
|
||||
// TODO: Unlink the UNIX socket on drop?
|
||||
|
||||
use std::{
|
||||
@ -19,8 +21,6 @@ use std::{
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use futures_util::ready;
|
||||
use hyper::server::accept::Accept;
|
||||
use tokio::{
|
||||
io::{AsyncRead, AsyncWrite},
|
||||
net::{TcpListener, TcpStream, UnixListener, UnixStream},
|
||||
@ -107,6 +107,7 @@ impl TryFrom<std::os::unix::net::UnixListener> for UnixOrTcpListener {
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn try_from(listener: std::os::unix::net::UnixListener) -> Result<Self, Self::Error> {
|
||||
listener.set_nonblocking(true)?;
|
||||
Ok(Self::Unix(UnixListener::from_std(listener)?))
|
||||
}
|
||||
}
|
||||
@ -115,6 +116,7 @@ impl TryFrom<std::net::TcpListener> for UnixOrTcpListener {
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn try_from(listener: std::net::TcpListener) -> Result<Self, Self::Error> {
|
||||
listener.set_nonblocking(true)?;
|
||||
Ok(Self::Tcp(TcpListener::from_std(listener)?))
|
||||
}
|
||||
}
|
||||
@ -140,6 +142,24 @@ impl UnixOrTcpListener {
|
||||
pub const fn is_tcp(&self) -> bool {
|
||||
matches!(self, Self::Tcp(_))
|
||||
}
|
||||
|
||||
/// Accept an incoming connection
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the underlying socket couldn't accept the connection
|
||||
pub async fn accept(&self) -> Result<(SocketAddr, UnixOrTcpConnection), std::io::Error> {
|
||||
match self {
|
||||
Self::Unix(listener) => {
|
||||
let (stream, remote_addr) = listener.accept().await?;
|
||||
Ok((remote_addr.into(), UnixOrTcpConnection::Unix { stream }))
|
||||
}
|
||||
Self::Tcp(listener) => {
|
||||
let (stream, remote_addr) = listener.accept().await?;
|
||||
Ok((remote_addr.into(), UnixOrTcpConnection::Tcp { stream }))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
@ -157,6 +177,12 @@ pin_project_lite::pin_project! {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<TcpStream> for UnixOrTcpConnection {
|
||||
fn from(stream: TcpStream) -> Self {
|
||||
Self::Tcp { stream }
|
||||
}
|
||||
}
|
||||
|
||||
impl UnixOrTcpConnection {
|
||||
/// Get the local address of the stream
|
||||
///
|
||||
@ -166,8 +192,8 @@ impl UnixOrTcpConnection {
|
||||
/// [`UnixStream`] couldn't provide the local address
|
||||
pub fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
|
||||
match self {
|
||||
Self::Unix { stream, .. } => stream.local_addr().map(SocketAddr::from),
|
||||
Self::Tcp { stream, .. } => stream.local_addr().map(SocketAddr::from),
|
||||
Self::Unix { stream } => stream.local_addr().map(SocketAddr::from),
|
||||
Self::Tcp { stream } => stream.local_addr().map(SocketAddr::from),
|
||||
}
|
||||
}
|
||||
|
||||
@ -179,36 +205,12 @@ impl UnixOrTcpConnection {
|
||||
/// [`UnixStream`] couldn't provide the remote address
|
||||
pub fn peer_addr(&self) -> Result<SocketAddr, std::io::Error> {
|
||||
match self {
|
||||
Self::Unix { stream, .. } => stream.peer_addr().map(SocketAddr::from),
|
||||
Self::Tcp { stream, .. } => stream.peer_addr().map(SocketAddr::from),
|
||||
Self::Unix { stream } => stream.peer_addr().map(SocketAddr::from),
|
||||
Self::Tcp { stream } => stream.peer_addr().map(SocketAddr::from),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Accept for UnixOrTcpListener {
|
||||
type Error = std::io::Error;
|
||||
type Conn = UnixOrTcpConnection;
|
||||
|
||||
fn poll_accept(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> std::task::Poll<Option<Result<Self::Conn, Self::Error>>> {
|
||||
let conn = match &*self {
|
||||
Self::Unix(listener) => {
|
||||
let (stream, _remote_addr) = ready!(listener.poll_accept(cx))?;
|
||||
UnixOrTcpConnection::Unix { stream }
|
||||
}
|
||||
|
||||
Self::Tcp(listener) => {
|
||||
let (stream, _remote_addr) = ready!(listener.poll_accept(cx))?;
|
||||
UnixOrTcpConnection::Tcp { stream }
|
||||
}
|
||||
};
|
||||
|
||||
Poll::Ready(Some(Ok(conn)))
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for UnixOrTcpConnection {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
@ -234,23 +236,6 @@ impl AsyncWrite for UnixOrTcpConnection {
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
|
||||
match self.project() {
|
||||
UnixOrTcpConnectionProj::Unix { stream } => stream.poll_flush(cx),
|
||||
UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_flush(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
match self.project() {
|
||||
UnixOrTcpConnectionProj::Unix { stream } => stream.poll_shutdown(cx),
|
||||
UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_shutdown(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_write_vectored(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
@ -268,4 +253,21 @@ impl AsyncWrite for UnixOrTcpConnection {
|
||||
UnixOrTcpConnection::Tcp { stream } => stream.is_write_vectored(),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
|
||||
match self.project() {
|
||||
UnixOrTcpConnectionProj::Unix { stream } => stream.poll_flush(cx),
|
||||
UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_flush(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
match self.project() {
|
||||
UnixOrTcpConnectionProj::Unix { stream } => stream.poll_shutdown(cx),
|
||||
UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_shutdown(cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user