1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Rewrite the listeners crate

Now with a way better graceful shutdown! With proper handshakes!
This commit is contained in:
Quentin Gliech
2022-10-12 14:36:19 +02:00
parent 485778beb3
commit ee43f08cf7
19 changed files with 1092 additions and 1016 deletions

22
Cargo.lock generated
View File

@ -2425,6 +2425,7 @@ dependencies = [
"futures-util",
"hyper",
"indoc",
"itertools",
"listenfd",
"mas-config",
"mas-email",
@ -2677,15 +2678,21 @@ dependencies = [
name = "mas-listener"
version = "0.1.0"
dependencies = [
"anyhow",
"bytes 1.2.1",
"futures-util",
"http-body",
"hyper",
"libc",
"pin-project-lite",
"thiserror",
"tokio",
"tokio-rustls 0.23.4",
"tower",
"tokio-test",
"tower-http",
"tower-service",
"tracing",
"tracing-subscriber",
]
[[package]]
@ -4872,6 +4879,19 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-test"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53474327ae5e166530d17f2d956afcb4f8a004de581b3cae10f12006bc8163e3"
dependencies = [
"async-stream",
"bytes 1.2.1",
"futures-core",
"tokio",
"tokio-stream",
]
[[package]]
name = "tokio-util"
version = "0.6.10"

View File

@ -22,6 +22,7 @@ watchman_client = "0.8.0"
atty = "0.2.14"
listenfd = "1.0.0"
rustls = "0.20.6"
itertools = "0.10.5"
tracing = "0.1.36"
tracing-appender = "0.2.2"

View File

@ -16,26 +16,19 @@ use std::{sync::Arc, time::Duration};
use anyhow::Context;
use clap::Parser;
use futures_util::{
future::FutureExt,
stream::{StreamExt, TryStreamExt},
};
use hyper::Server;
use futures_util::stream::{StreamExt, TryStreamExt};
use itertools::Itertools;
use mas_config::RootConfig;
use mas_email::Mailer;
use mas_handlers::{AppState, MatrixHomeserver};
use mas_http::ServerLayer;
use mas_listener::{
info::{ConnectionInfoAcceptor, IntoMakeServiceWithConnection},
maybe_tls::MaybeTlsAcceptor,
proxy_protocol::MaybeProxyAcceptor,
};
use mas_listener::{server::Server, shutdown::ShutdownStream};
use mas_policy::PolicyFactory;
use mas_router::UrlBuilder;
use mas_storage::MIGRATOR;
use mas_tasks::TaskQueue;
use mas_templates::Templates;
use tokio::io::AsyncRead;
use tokio::{io::AsyncRead, signal::unix::SignalKind};
use tracing::{error, info, log::warn};
#[derive(Parser, Debug, Default)]
@ -49,32 +42,6 @@ pub(super) struct Options {
watch: bool,
}
#[cfg(not(unix))]
async fn shutdown_signal() {
// Wait for the CTRL+C signal
tokio::signal::ctrl_c()
.await
.expect("failed to install Ctrl+C signal handler");
tracing::info!("Got Ctrl+C, shutting down");
}
#[cfg(unix)]
async fn shutdown_signal() {
use tokio::signal::unix::{signal, SignalKind};
// Wait for SIGTERM and SIGINT signals
// This might panic but should be fine
let mut term =
signal(SignalKind::terminate()).expect("failed to install SIGTERM signal handler");
let mut int = signal(SignalKind::interrupt()).expect("failed to install SIGINT signal handler");
tokio::select! {
_ = term.recv() => tracing::info!("Got SIGTERM, shutting down"),
_ = int.recv() => tracing::info!("Got SIGINT, shutting down"),
};
}
/// Watch for changes in the templates folders
async fn watch_templates(
client: &watchman_client::Client,
@ -247,68 +214,75 @@ impl Options {
policy_factory,
});
let signal = shutdown_signal().shared();
let shutdown_signal = signal.clone();
let mut fd_manager = listenfd::ListenFd::from_env();
let listeners = listeners_config.into_iter().map(|listener_config| {
// Let's first grab all the listeners in a synchronous manner
let listeners = crate::server::build_listeners(&mut fd_manager, &listener_config.binds);
Ok((listener_config, listeners?))
});
let servers: Vec<Server<_>> = listeners_config
.into_iter()
.map(|config| {
// Let's first grab all the listeners
let listeners = crate::server::build_listeners(&mut fd_manager, &config.binds)?;
// Now that we have the listeners ready, we can do the rest concurrently
futures_util::stream::iter(listeners)
.try_for_each_concurrent(None, move |(config, listeners)| {
let signal = signal.clone();
// Load the TLS config
let tls_config = if let Some(tls_config) = config.tls.as_ref() {
let tls_config = crate::server::build_tls_server_config(tls_config)?;
Some(Arc::new(tls_config))
} else {
None
};
// and build the router
let router = crate::server::build_router(&state, &config.resources)
.layer(ServerLayer::new(config.name.clone()));
// Display some informations about where we'll be serving connections
let is_tls = config.tls.is_some();
let addresses: Vec<String> = listeners.iter().map(|listener| {
let addresses: Vec<String> = listeners
.iter()
.map(|listener| {
let addr = listener.local_addr();
let proto = if is_tls { "https" } else { "http" };
if let Ok(addr) = addr {
format!("{proto}://{addr:?}")
} else {
warn!("Could not get local address for listener, something might be wrong!");
warn!(
"Could not get local address for listener, something might be wrong!"
);
format!("{proto}://???")
}
}).collect();
let additional = if config.proxy_protocol { "(with Proxy Protocol)" } else { "" };
info!("Listening on {addresses:?} with resources {resources:?} {additional}", resources = &config.resources);
let router = crate::server::build_router(&state, &config.resources).layer(ServerLayer::new(config.name.clone()));
let make_service = IntoMakeServiceWithConnection::new(router);
async move {
let tls_config = if let Some(tls_config) = config.tls.as_ref() {
let tls_config = crate::server::build_tls_server_config(tls_config).await?;
Some(Arc::new(tls_config))
} else { None };
futures_util::stream::iter(listeners)
.map(Ok)
.try_for_each_concurrent(None, move |listener| {
let listener = MaybeTlsAcceptor::new(tls_config.clone(), listener);
let listener = MaybeProxyAcceptor::new(listener, config.proxy_protocol);
let listener = ConnectionInfoAcceptor::new(listener);
Server::builder(listener)
.serve(make_service.clone())
.with_graceful_shutdown(signal.clone())
})
.await?;
.collect();
anyhow::Ok(())
let additional = if config.proxy_protocol {
"(with Proxy Protocol)"
} else {
""
};
info!(
"Listening on {addresses:?} with resources {resources:?} {additional}",
resources = &config.resources
);
anyhow::Ok(listeners.into_iter().map(move |listener| {
let mut server = Server::new(listener, router.clone());
if let Some(tls_config) = &tls_config {
server = server.with_tls(tls_config.clone());
}
if config.proxy_protocol {
server = server.with_proxy();
}
server
}))
})
.await?;
.flatten_ok()
.collect::<Result<Vec<_>, _>>()?;
// This ensures we're running, even if no listener are setup
// This is useful for only running the task runner
shutdown_signal.await;
let shutdown = ShutdownStream::default()
.with_timeout(Duration::from_secs(60))
.with_signal(SignalKind::terminate())?
.with_signal(SignalKind::interrupt())?;
mas_listener::server::run_servers(servers, shutdown).await;
Ok(())
}

View File

@ -23,10 +23,9 @@ use axum::{body::HttpBody, Extension, Router};
use listenfd::ListenFd;
use mas_config::{HttpBindConfig, HttpResource, HttpTlsConfig, UnixOrTcp};
use mas_handlers::AppState;
use mas_listener::{info::Connection, unix_or_tcp::UnixOrTcpListener};
use mas_listener::{unix_or_tcp::UnixOrTcpListener, ConnectionInfo};
use mas_router::Route;
use rustls::ServerConfig;
use tokio::sync::OnceCell;
#[allow(clippy::trait_duplication_in_bounds)]
pub fn build_router<B>(state: &Arc<AppState>, resources: &[HttpResource]) -> Router<AppState, B>
@ -64,12 +63,9 @@ where
// TODO: do a better handler here
mas_config::HttpResource::ConnectionInfo => router.route(
"/connection-info",
axum::routing::get(
|connection: Extension<Arc<OnceCell<Connection>>>| async move {
let connection = connection.get().unwrap();
axum::routing::get(|connection: Extension<ConnectionInfo>| async move {
format!("{connection:?}")
},
),
}),
),
}
}
@ -77,10 +73,8 @@ where
router
}
pub async fn build_tls_server_config(
config: &HttpTlsConfig,
) -> Result<ServerConfig, anyhow::Error> {
let (key, chain) = config.load().await?;
pub fn build_tls_server_config(config: &HttpTlsConfig) -> Result<ServerConfig, anyhow::Error> {
let (key, chain) = config.load()?;
let key = rustls::PrivateKey(key);
let chain = chain.into_iter().map(rustls::Certificate).collect();

View File

@ -163,11 +163,11 @@ impl TlsConfig {
/// - a password was provided but the key was not encrypted
/// - decoding the certificate chain as PEM
/// - the certificate chain is empty
pub async fn load(&self) -> Result<(Vec<u8>, Vec<Vec<u8>>), anyhow::Error> {
pub fn load(&self) -> Result<(Vec<u8>, Vec<Vec<u8>>), anyhow::Error> {
let password = match &self.password {
Some(PasswordOrFile::Password(password)) => Some(Cow::Borrowed(password.as_str())),
Some(PasswordOrFile::PasswordFile(path)) => {
Some(Cow::Owned(tokio::fs::read_to_string(path).await?))
Some(Cow::Owned(std::fs::read_to_string(path)?))
}
None => None,
};
@ -185,7 +185,7 @@ impl TlsConfig {
KeyOrFile::KeyFile(path) => {
// When reading from disk, it might be either PEM or DER. `PrivateKey::load*`
// will try both.
let key = tokio::fs::read(path).await?;
let key = std::fs::read(path)?;
if let Some(password) = password {
PrivateKey::load_encrypted(&key, password.as_bytes())?
} else {
@ -202,9 +202,7 @@ impl TlsConfig {
let certificate_chain_pem = match &self.certificate {
CertificateOrFile::Certificate(pem) => Cow::Borrowed(pem.as_str()),
CertificateOrFile::CertificateFile(path) => {
Cow::Owned(tokio::fs::read_to_string(path).await?)
}
CertificateOrFile::CertificateFile(path) => Cow::Owned(std::fs::read_to_string(path)?),
};
let mut certificate_chain_reader = Cursor::new(certificate_chain_pem.as_bytes());

View File

@ -6,12 +6,25 @@ edition = "2021"
license = "Apache-2.0"
[dependencies]
bytes = "1.2.1"
futures-util = "0.3.24"
hyper = { version = "0.14.20", features = ["server"] }
http-body = "0.4.2"
hyper = { version = "0.14.20", features = ["server", "http1", "http2"] }
pin-project-lite = "0.2.9"
thiserror = "1.0.37"
tokio = { version = "1.21.2", features = ["net"] }
tokio = { version = "1.21.2", features = ["net", "rt", "macros", "signal", "time"] }
tokio-rustls = "0.23.4"
tower = "0.4.13"
tower-http = { version = "0.3.4", features = ["add-extension"] }
tracing = "0.1.36"
tower-service = "0.3.2"
tracing = "0.1.37"
libc = "0.2.135"
[dev-dependencies]
tokio-test = "0.4.2"
anyhow = "1.0.65"
tokio = { version = "1.21.2", features = ["net", "rt", "macros", "signal", "time", "rt-multi-thread"] }
tracing-subscriber = "0.3.16"
[[example]]
name = "demo"
path = "examples/demo/main.rs"

View File

@ -0,0 +1,49 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
convert::Infallible,
net::{Ipv4Addr, TcpListener},
time::Duration,
};
use hyper::{service::service_fn, Request, Response};
use tokio::signal::unix::SignalKind;
use tokio_streams_util::{server::Server, shutdown::ShutdownStream, ConnectionInfo};
async fn handler(req: Request<hyper::Body>) -> Result<Response<String>, Infallible> {
tracing::info!("Handling request");
tokio::time::sleep(Duration::from_secs(3)).await;
let info = req.extensions().get::<ConnectionInfo>().unwrap();
let body = format!("{info:?}");
Ok(Response::new(body))
}
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
tracing_subscriber::fmt::init();
let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 3000))?;
let service = service_fn(handler);
let server = Server::try_new(listener, service)?;
tracing::info!("Listening on 127.0.0.1:3000");
let shutdown = ShutdownStream::default()
.with_signal(SignalKind::interrupt())?
.with_signal(SignalKind::terminate())?;
server.run(shutdown).await;
Ok(())
}

View File

@ -1,323 +0,0 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
future::Ready,
net::SocketAddr,
ops::Deref,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use hyper::server::accept::Accept;
use thiserror::Error;
use tokio::{
io::{AsyncRead, AsyncWrite},
sync::OnceCell,
};
use tower::Service;
use tower_http::add_extension::AddExtension;
use crate::{
maybe_tls::{MaybeTlsAcceptor, MaybeTlsStream, TlsStreamInfo, TlsStreamInfoError},
proxy_protocol::{
MaybeProxyAcceptor, MaybeProxyStream, ProxyHandshakeNotDone, ProxyProtocolV1Info,
},
unix_or_tcp::{UnixOrTcpConnection, UnixOrTcpListener},
};
// TODO: this is a mess, clean that up
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum FromStreamError {
#[error(transparent)]
Proxy(#[from] ProxyHandshakeNotDone),
#[error(transparent)]
Tls(#[from] TlsStreamInfoError),
#[error("Could not grab a reference to the underlying stream")]
GetRef,
#[error("Could not get address info from underlying stream")]
IoError(#[from] std::io::Error),
}
#[derive(Debug, Clone)]
pub struct Connection {
proxy: Option<ProxyProtocolV1Info>,
tls: Option<TlsStreamInfo>,
// We're not saving the UNIX domain socket address here because it can't be cloned, which is
// required for injecting the connection information as an extension
local_tcp_addr: Option<SocketAddr>,
peer_tcp_addr: Option<SocketAddr>,
}
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum GrabAddressError {
#[error("Proxy protocol was initiated with an unknown protocol")]
ProxyUnknown,
#[error("Proxy protocol was initiated with UDP")]
ProxyUdp,
#[error("Underlying listener is a UNIX socket")]
UnixListener,
}
impl MaybeProxyAcceptor<MaybeTlsAcceptor<UnixOrTcpListener>> {
pub fn can_have_peer_address(&self) -> bool {
self.is_proxied() || self.is_tcp()
}
}
impl MaybeProxyStream<MaybeTlsStream<UnixOrTcpConnection>> {
/// Get informations about this connection
///
/// # Errors
///
/// Returns an error if the proxy protocol or the TLS handhakes are not done
/// yet
pub fn connection_info(&self) -> Result<Connection, FromStreamError> {
Connection::from_stream(self)
}
}
impl Connection {
/// Get informations about this connection
///
/// # Errors
///
/// Returns an error if the proxy protocol or the TLS handhakes are not done
/// yet
pub fn from_stream(
stream: &MaybeProxyStream<MaybeTlsStream<UnixOrTcpConnection>>,
) -> Result<Self, FromStreamError> {
let proxy = stream.proxy_info()?.cloned();
let tls = stream.tls_info()?;
let original = stream.get_ref().ok_or(FromStreamError::GetRef)?;
let local_tcp_addr = original.local_addr()?.into_net();
let peer_tcp_addr = original.peer_addr()?.into_net();
Ok(Self {
proxy,
tls,
local_tcp_addr,
peer_tcp_addr,
})
}
#[must_use]
pub const fn is_proxied(&self) -> bool {
self.proxy.is_some()
}
#[must_use]
pub const fn is_tls(&self) -> bool {
self.tls.is_some()
}
/// Get the outmost peer address, either from the TCP listener or from the
/// proxy protocol infos.
///
/// # Errors
///
/// Returns an error if the info from the proxy protocol was not for a TCP
/// connection, or if the proxy protocol is not being used, the underlying
/// listener was a UNIX domain socket
pub fn peer_addr(&self) -> Result<&SocketAddr, GrabAddressError> {
if let Some(proxy) = self.proxy.as_ref() {
if proxy.is_udp() {
return Err(GrabAddressError::ProxyUdp);
}
proxy.source().ok_or(GrabAddressError::ProxyUnknown)
} else {
self.peer_tcp_addr
.as_ref()
.ok_or(GrabAddressError::UnixListener)
}
}
/// Get the outmost local address, either from the TCP listener or from the
/// proxy protocol infos.
///
/// # Errors
///
/// Returns an error if the info from the proxy protocol was not for a TCP
/// connection, or if the proxy protocol is not being used, the underlying
/// listener was a UNIX domain socket
pub fn local_addr(&self) -> Result<&SocketAddr, GrabAddressError> {
if let Some(proxy) = self.proxy.as_ref() {
if proxy.is_udp() {
return Err(GrabAddressError::ProxyUdp);
}
proxy.destination().ok_or(GrabAddressError::ProxyUnknown)
} else {
self.local_tcp_addr
.as_ref()
.ok_or(GrabAddressError::UnixListener)
}
}
}
pin_project_lite::pin_project! {
pub struct ConnectionInfoAcceptor {
#[pin]
acceptor: MaybeProxyAcceptor<MaybeTlsAcceptor<UnixOrTcpListener>>,
}
}
impl ConnectionInfoAcceptor {
pub const fn new(acceptor: MaybeProxyAcceptor<MaybeTlsAcceptor<UnixOrTcpListener>>) -> Self {
Self { acceptor }
}
}
impl Accept for ConnectionInfoAcceptor {
type Conn = ConnectionInfoStream;
type Error = std::io::Error;
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
let proj = self.project();
let ret = match futures_util::ready!(proj.acceptor.poll_accept(cx)) {
Some(Ok(conn)) => Some(Ok(ConnectionInfoStream::new(conn))),
Some(Err(e)) => Some(Err(e)),
None => None,
};
Poll::Ready(ret)
}
}
pin_project_lite::pin_project! {
pub struct ConnectionInfoStream {
connection: Arc<OnceCell<Connection>>,
#[pin]
stream: MaybeProxyStream<MaybeTlsStream<UnixOrTcpConnection>>,
}
}
impl ConnectionInfoStream {
pub fn new(stream: MaybeProxyStream<MaybeTlsStream<UnixOrTcpConnection>>) -> Self {
Self {
connection: Arc::new(OnceCell::const_new()),
stream,
}
}
}
impl Deref for ConnectionInfoStream {
type Target = MaybeProxyStream<MaybeTlsStream<UnixOrTcpConnection>>;
fn deref(&self) -> &Self::Target {
&self.stream
}
}
impl AsyncRead for ConnectionInfoStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
let this = self.get_mut();
futures_util::ready!(Pin::new(&mut this.stream).poll_read(cx, buf))?;
if !this.stream.is_tls_handshaking()
&& !this.stream.is_proxy_handshaking()
&& !this.connection.initialized()
{
this.connection
.set(this.stream.connection_info().unwrap())
.unwrap();
}
Poll::Ready(Ok(()))
}
}
impl AsyncWrite for ConnectionInfoStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
let proj = self.project();
proj.stream.poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
let proj = self.project();
proj.stream.poll_flush(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let proj = self.project();
proj.stream.poll_shutdown(cx)
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<Result<usize, std::io::Error>> {
let proj = self.project();
proj.stream.poll_write_vectored(cx, bufs)
}
fn is_write_vectored(&self) -> bool {
self.stream.is_write_vectored()
}
}
#[derive(Debug, Clone)]
pub struct IntoMakeServiceWithConnection<S> {
svc: S,
}
impl<S> IntoMakeServiceWithConnection<S> {
pub const fn new(svc: S) -> Self {
Self { svc }
}
}
impl<S> Service<&ConnectionInfoStream> for IntoMakeServiceWithConnection<S>
where
S: Clone,
{
type Response = AddExtension<S, Arc<OnceCell<Connection>>>;
type Error = FromStreamError;
type Future = Ready<Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, target: &ConnectionInfoStream) -> Self::Future {
std::future::ready(Ok(AddExtension::new(
self.svc.clone(),
target.connection.clone(),
)))
}
}

View File

@ -22,7 +22,41 @@
#![warn(clippy::pedantic)]
#![allow(clippy::module_name_repetitions)]
pub mod info;
use self::{maybe_tls::TlsStreamInfo, proxy_protocol::ProxyProtocolV1Info};
pub mod maybe_tls;
pub mod proxy_protocol;
pub mod rewind;
pub mod server;
pub mod shutdown;
pub mod unix_or_tcp;
#[derive(Debug, Clone)]
pub struct ConnectionInfo {
tls: Option<TlsStreamInfo>,
proxy: Option<ProxyProtocolV1Info>,
net_peer_addr: Option<std::net::SocketAddr>,
}
impl ConnectionInfo {
/// Returns informations about the TLS connection. Returns [`None`] if the
/// connection was not TLS.
#[must_use]
pub fn get_tls_ref(&self) -> Option<&TlsStreamInfo> {
self.tls.as_ref()
}
/// Returns informations about the proxy protocol connection. Returns
/// [`None`] if the connection was not using the proxy protocol.
#[must_use]
pub fn get_proxy_ref(&self) -> Option<&ProxyProtocolV1Info> {
self.proxy.as_ref()
}
/// Returns the remote peer address. Returns [`None`] if the connection was
/// established via a UNIX domain socket.
#[must_use]
pub fn get_peer_addr(&self) -> Option<std::net::SocketAddr> {
self.net_peer_addr
}
}

View File

@ -13,18 +13,15 @@
// limitations under the License.
use std::{
ops::Deref,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use futures_util::{ready, Future};
use hyper::server::accept::Accept;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_rustls::rustls::{
Certificate, ProtocolVersion, ServerConfig, ServerConnection, SupportedCipherSuite,
use tokio_rustls::{
rustls::{Certificate, ProtocolVersion, ServerConfig, ServerConnection, SupportedCipherSuite},
TlsAcceptor,
};
#[derive(Debug, Clone)]
@ -33,105 +30,78 @@ pub struct TlsStreamInfo {
pub protocol_version: ProtocolVersion,
pub negotiated_cipher_suite: SupportedCipherSuite,
pub sni_hostname: Option<String>,
pub apln_protocol: Option<Vec<u8>>,
pub alpn_protocol: Option<Vec<u8>>,
pub peer_certificates: Option<Vec<Certificate>>,
}
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum TlsStreamInfoError {
#[error("TLS handshake is not done yet")]
HandshakingNotDone,
#[error("Some fields were not available in the TLS connection")]
FieldsNotAvailable,
impl TlsStreamInfo {
#[must_use]
pub fn is_alpn_h2(&self) -> bool {
matches!(self.alpn_protocol.as_deref(), Some(b"h2"))
}
}
pin_project_lite::pin_project! {
#[project = MaybeTlsStreamProj]
pub enum MaybeTlsStream<T> {
Handshaking(tokio_rustls::Accept<T>),
Streaming(tokio_rustls::server::TlsStream<T>),
Insecure(T),
Secure {
#[pin]
stream: tokio_rustls::server::TlsStream<T>
},
Insecure {
#[pin]
stream: T,
},
}
}
impl<T> MaybeTlsStream<T> {
pub fn new(stream: T, config: Option<Arc<ServerConfig>>) -> Self
where
T: AsyncRead + AsyncWrite + Unpin,
{
if let Some(config) = config {
let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream);
MaybeTlsStream::Handshaking(accept)
} else {
MaybeTlsStream::Insecure(stream)
}
}
/// Get a reference to the underlying IO stream
///
/// Returns [`None`] if the stream closed before the TLS handshake finished.
/// It is guaranteed to return [`Some`] value after the handshake finished,
/// or if it is a non-TLS connection.
pub fn get_ref(&self) -> Option<&T> {
pub fn get_ref(&self) -> &T {
match self {
Self::Handshaking(accept) => accept.get_ref(),
Self::Streaming(stream) => {
let (inner, _) = stream.get_ref();
Some(inner)
}
Self::Insecure(inner) => Some(inner),
Self::Secure { stream } => stream.get_ref().0,
Self::Insecure { stream } => stream,
}
}
/// Get a ref to the [`ServerConnection`] of the establish TLS stream.
///
/// Returns [`None`] if the connection is still handshaking and for non-TLS
/// connections.
/// Returns [`None`] for non-TLS connections.
pub fn get_tls_connection(&self) -> Option<&ServerConnection> {
match self {
Self::Streaming(stream) => {
let (_, conn) = stream.get_ref();
Some(conn)
}
Self::Handshaking(_) | Self::Insecure(_) => None,
Self::Secure { stream } => Some(stream.get_ref().1),
Self::Insecure { .. } => None,
}
}
/// Gather informations about the TLS connection. Returns `None` if the
/// stream is not a TLS stream.
///
/// # Errors
///
/// Returns an error if the TLS handshake is not yet done
pub fn tls_info(&self) -> Result<Option<TlsStreamInfo>, TlsStreamInfoError> {
let conn = match self {
Self::Streaming(stream) => stream.get_ref().1,
Self::Handshaking(_) => return Err(TlsStreamInfoError::HandshakingNotDone),
Self::Insecure(_) => return Ok(None),
};
pub fn tls_info(&self) -> Option<TlsStreamInfo> {
let conn = self.get_tls_connection()?;
// NOTE: we're getting the protocol version and cipher suite *after* the
// handshake, so this should never lead to an error
// SAFETY: we're getting the protocol version and cipher suite *after* the
// handshake, so this should never lead to a panic
let protocol_version = conn
.protocol_version()
.ok_or(TlsStreamInfoError::FieldsNotAvailable)?;
.expect("TLS handshake is not done yet");
let negotiated_cipher_suite = conn
.negotiated_cipher_suite()
.ok_or(TlsStreamInfoError::FieldsNotAvailable)?;
.expect("TLS handshake is not done yet");
let sni_hostname = conn.sni_hostname().map(ToOwned::to_owned);
let apln_protocol = conn.alpn_protocol().map(ToOwned::to_owned);
let alpn_protocol = conn.alpn_protocol().map(ToOwned::to_owned);
let peer_certificates = conn.peer_certificates().map(ToOwned::to_owned);
Ok(Some(TlsStreamInfo {
Some(TlsStreamInfo {
protocol_version,
negotiated_cipher_suite,
sni_hostname,
apln_protocol,
alpn_protocol,
peer_certificates,
}))
}
pub const fn is_tls_handshaking(&self) -> bool {
matches!(self, Self::Handshaking(_))
})
}
}
@ -144,20 +114,9 @@ where
cx: &mut Context,
buf: &mut ReadBuf,
) -> Poll<std::io::Result<()>> {
let pin = self.get_mut();
match pin {
MaybeTlsStream::Handshaking(ref mut accept) => {
match ready!(Pin::new(accept).poll(cx)) {
Ok(mut stream) => {
let result = Pin::new(&mut stream).poll_read(cx, buf);
*pin = MaybeTlsStream::Streaming(stream);
result
}
Err(err) => Poll::Ready(Err(err)),
}
}
MaybeTlsStream::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
MaybeTlsStream::Insecure(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
match self.project() {
MaybeTlsStreamProj::Secure { stream } => stream.poll_read(cx, buf),
MaybeTlsStreamProj::Insecure { stream } => stream.poll_read(cx, buf),
}
}
}
@ -171,104 +130,89 @@ where
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
let pin = self.get_mut();
match pin {
MaybeTlsStream::Handshaking(ref mut accept) => {
match ready!(Pin::new(accept).poll(cx)) {
Ok(mut stream) => {
let result = Pin::new(&mut stream).poll_write(cx, buf);
*pin = MaybeTlsStream::Streaming(stream);
result
}
Err(err) => Poll::Ready(Err(err)),
match self.project() {
MaybeTlsStreamProj::Secure { stream } => stream.poll_write(cx, buf),
MaybeTlsStreamProj::Insecure { stream } => stream.poll_write(cx, buf),
}
}
MaybeTlsStream::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
MaybeTlsStream::Insecure(ref mut fallback) => Pin::new(fallback).poll_write(cx, buf),
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<Result<usize, std::io::Error>> {
match self.project() {
MaybeTlsStreamProj::Secure { stream } => stream.poll_write_vectored(cx, bufs),
MaybeTlsStreamProj::Insecure { stream } => stream.poll_write_vectored(cx, bufs),
}
}
fn is_write_vectored(&self) -> bool {
match self {
Self::Secure { stream } => stream.is_write_vectored(),
Self::Insecure { stream } => stream.is_write_vectored(),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
match self.get_mut() {
MaybeTlsStream::Handshaking { .. } => Poll::Ready(Ok(())),
MaybeTlsStream::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx),
MaybeTlsStream::Insecure(ref mut stream) => Pin::new(stream).poll_flush(cx),
match self.project() {
MaybeTlsStreamProj::Secure { stream } => stream.poll_flush(cx),
MaybeTlsStreamProj::Insecure { stream } => stream.poll_flush(cx),
}
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
match self.get_mut() {
MaybeTlsStream::Handshaking { .. } => Poll::Ready(Ok(())),
MaybeTlsStream::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
MaybeTlsStream::Insecure(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
match self.project() {
MaybeTlsStreamProj::Secure { stream } => stream.poll_shutdown(cx),
MaybeTlsStreamProj::Insecure { stream } => stream.poll_shutdown(cx),
}
}
}
pub struct MaybeTlsAcceptor<T> {
#[derive(Clone)]
pub struct MaybeTlsAcceptor {
tls_config: Option<Arc<ServerConfig>>,
incoming: T,
}
impl<T> MaybeTlsAcceptor<T> {
pub fn new(tls_config: Option<Arc<ServerConfig>>, incoming: T) -> Self {
Self {
tls_config,
incoming,
}
impl MaybeTlsAcceptor {
#[must_use]
pub fn new(tls_config: Option<Arc<ServerConfig>>) -> Self {
Self { tls_config }
}
pub fn new_secure(tls_config: Arc<ServerConfig>, incoming: T) -> Self {
#[must_use]
pub fn new_secure(tls_config: Arc<ServerConfig>) -> Self {
Self {
tls_config: Some(tls_config),
incoming,
}
}
pub fn new_insecure(incoming: T) -> Self {
Self {
tls_config: None,
incoming,
}
#[must_use]
pub fn new_insecure() -> Self {
Self { tls_config: None }
}
#[must_use]
pub const fn is_secure(&self) -> bool {
self.tls_config.is_some()
}
}
impl<T> Accept for MaybeTlsAcceptor<T>
/// Accept a connection and do the TLS handshake
///
/// # Errors
///
/// Returns an error if the TLS handshake failed
pub async fn accept<T>(&self, stream: T) -> Result<MaybeTlsStream<T>, std::io::Error>
where
T: Accept + Unpin,
T::Conn: AsyncRead + AsyncWrite + Unpin,
T::Error: Into<std::io::Error>,
T: AsyncRead + AsyncWrite + Unpin,
{
type Conn = MaybeTlsStream<T::Conn>;
type Error = std::io::Error;
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
let pin = self.get_mut();
let ret = match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) {
Some(Ok(sock)) => {
let config = pin.tls_config.clone();
Some(Ok(MaybeTlsStream::new(sock, config)))
match &self.tls_config {
Some(config) => {
let acceptor = TlsAcceptor::from(config.clone());
let stream = acceptor.accept(stream).await?;
Ok(MaybeTlsStream::Secure { stream })
}
Some(Err(e)) => Some(Err(e.into())),
None => None,
};
Poll::Ready(ret)
None => Ok(MaybeTlsStream::Insecure { stream }),
}
}
impl<T> Deref for MaybeTlsAcceptor<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.incoming
}
}

View File

@ -12,41 +12,57 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use futures_util::ready;
use hyper::server::accept::Accept;
use bytes::BytesMut;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt};
use super::ProxyStream;
use super::ProxyProtocolV1Info;
use crate::rewind::Rewind;
pin_project_lite::pin_project! {
pub struct ProxyAcceptor<A> {
#[pin]
inner: A,
}
#[derive(Clone, Debug)]
pub struct ProxyAcceptor {
_private: (),
}
impl<A> ProxyAcceptor<A> {
pub const fn new(inner: A) -> Self {
Self { inner }
}
#[derive(Debug, Error)]
#[error(transparent)]
pub enum ProxyAcceptError {
Parse(#[from] super::v1::ParseError),
Read(#[from] std::io::Error),
}
impl<A> Accept for ProxyAcceptor<A>
impl ProxyAcceptor {
#[must_use]
pub const fn new() -> Self {
Self { _private: () }
}
/// Accept a proxy-protocol stream
///
/// # Errors
///
/// Returns an error on read error on the underlying stream, or when the
/// proxy protocol preamble couldn't be parsed
pub async fn accept<T>(
&self,
mut stream: T,
) -> Result<(ProxyProtocolV1Info, Rewind<T>), ProxyAcceptError>
where
A: Accept,
T: AsyncRead + Unpin,
{
type Conn = ProxyStream<A::Conn>;
type Error = A::Error;
let mut buf = BytesMut::new();
let info = loop {
stream.read_buf(&mut buf).await?;
fn poll_accept(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Result<Self::Conn, Self::Error>>> {
let res = match ready!(self.project().inner.poll_accept(cx)) {
Some(Ok(stream)) => Some(Ok(ProxyStream::new(stream))),
Some(Err(e)) => Some(Err(e)),
None => None,
match ProxyProtocolV1Info::parse(&mut buf) {
Ok(info) => break info,
Err(e) if e.not_enough_bytes() => {}
Err(e) => return Err(e.into()),
}
};
std::task::Poll::Ready(res)
let stream = Rewind::new_buffered(stream, buf.into());
Ok((info, stream))
}
}

View File

@ -12,179 +12,66 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
ops::Deref,
pin::Pin,
task::{Context, Poll},
};
use tokio::io::AsyncRead;
use futures_util::ready;
use hyper::server::accept::Accept;
use tokio::io::{AsyncRead, AsyncWrite};
use super::{acceptor::ProxyAcceptError, ProxyAcceptor, ProxyProtocolV1Info};
use crate::rewind::Rewind;
use super::{stream::HandshakeNotDone, ProxyProtocolV1Info, ProxyStream};
pin_project_lite::pin_project! {
pub struct MaybeProxyAcceptor<A> {
proxied: bool,
#[pin]
inner: A,
}
#[derive(Clone)]
pub struct MaybeProxyAcceptor {
acceptor: Option<ProxyAcceptor>,
}
impl<A> MaybeProxyAcceptor<A> {
impl MaybeProxyAcceptor {
#[must_use]
pub const fn new(inner: A, proxied: bool) -> Self {
Self { proxied, inner }
}
pub const fn is_proxied(&self) -> bool {
self.proxied
}
}
impl<A> Deref for MaybeProxyAcceptor<A> {
type Target = A;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl<A> Accept for MaybeProxyAcceptor<A>
where
A: Accept,
{
type Conn = MaybeProxyStream<A::Conn>;
type Error = A::Error;
fn poll_accept(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Result<Self::Conn, Self::Error>>> {
let proj = self.project();
let res = match ready!(proj.inner.poll_accept(cx)) {
Some(Ok(stream)) => Some(Ok(MaybeProxyStream::new(stream, *proj.proxied))),
Some(Err(e)) => Some(Err(e)),
None => None,
pub const fn new(proxied: bool) -> Self {
let acceptor = if proxied {
Some(ProxyAcceptor::new())
} else {
None
};
std::task::Poll::Ready(res)
Self { acceptor }
}
#[must_use]
pub const fn new_proxied(acceptor: ProxyAcceptor) -> Self {
Self {
acceptor: Some(acceptor),
}
}
pin_project_lite::pin_project! {
#[project = MaybeProxyStreamProj]
pub enum MaybeProxyStream<S> {
Proxied { #[pin] stream: ProxyStream<S> },
NotProxied { #[pin] stream: S },
}
#[must_use]
pub const fn new_unproxied() -> Self {
Self { acceptor: None }
}
impl<S> MaybeProxyStream<S> {
pub const fn new(stream: S, proxied: bool) -> Self {
if proxied {
Self::Proxied {
stream: ProxyStream::new(stream),
}
} else {
Self::NotProxied { stream }
}
#[must_use]
pub const fn is_proxied(&self) -> bool {
self.acceptor.is_some()
}
/// Get informations from the proxied connection, if it was procied
/// Accept a connection and do the proxy protocol handshake
///
/// # Errors
///
/// Returns an error if the stream did not complete the handshake yet
pub fn proxy_info(&self) -> Result<Option<&ProxyProtocolV1Info>, HandshakeNotDone> {
match self {
Self::Proxied { stream } => Ok(Some(stream.proxy_info()?)),
Self::NotProxied { .. } => Ok(None),
}
}
pub const fn is_proxy_handshaking(&self) -> bool {
match self {
Self::Proxied { stream } => stream.is_handshaking(),
Self::NotProxied { .. } => false,
}
}
}
impl<S> Deref for MaybeProxyStream<S> {
type Target = S;
fn deref(&self) -> &Self::Target {
match self {
Self::Proxied { stream } => &**stream,
Self::NotProxied { stream } => stream,
}
}
}
impl<S> AsyncRead for MaybeProxyStream<S>
/// Returns an error if the proxy protocol handshake failed
pub async fn accept<T>(
&self,
stream: T,
) -> Result<(Option<ProxyProtocolV1Info>, Rewind<T>), ProxyAcceptError>
where
S: AsyncRead,
T: AsyncRead + Unpin,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
match self.project() {
MaybeProxyStreamProj::Proxied { stream } => stream.poll_read(cx, buf),
MaybeProxyStreamProj::NotProxied { stream } => stream.poll_read(cx, buf),
match &self.acceptor {
Some(acceptor) => {
let (info, stream) = acceptor.accept(stream).await?;
Ok((Some(info), stream))
}
None => {
let stream = Rewind::new(stream);
Ok((None, stream))
}
}
}
impl<S> AsyncWrite for MaybeProxyStream<S>
where
S: AsyncWrite,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
match self.project() {
MaybeProxyStreamProj::Proxied { stream } => stream.poll_write(cx, buf),
MaybeProxyStreamProj::NotProxied { stream } => stream.poll_write(cx, buf),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
match self.project() {
MaybeProxyStreamProj::Proxied { stream } => stream.poll_flush(cx),
MaybeProxyStreamProj::NotProxied { stream } => stream.poll_flush(cx),
}
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
match self.project() {
MaybeProxyStreamProj::Proxied { stream } => stream.poll_shutdown(cx),
MaybeProxyStreamProj::NotProxied { stream } => stream.poll_shutdown(cx),
}
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<Result<usize, std::io::Error>> {
match self.project() {
MaybeProxyStreamProj::Proxied { stream } => stream.poll_write_vectored(cx, bufs),
MaybeProxyStreamProj::NotProxied { stream } => stream.poll_write_vectored(cx, bufs),
}
}
fn is_write_vectored(&self) -> bool {
match self {
MaybeProxyStream::Proxied { stream } => stream.is_write_vectored(),
MaybeProxyStream::NotProxied { stream } => stream.is_write_vectored(),
}
}
}

View File

@ -14,12 +14,10 @@
mod acceptor;
mod maybe;
mod stream;
mod v1;
pub use self::{
acceptor::ProxyAcceptor,
maybe::{MaybeProxyAcceptor, MaybeProxyStream},
stream::{HandshakeNotDone as ProxyHandshakeNotDone, ProxyStream},
acceptor::{ProxyAcceptError, ProxyAcceptor},
maybe::MaybeProxyAcceptor,
v1::ProxyProtocolV1Info,
};

View File

@ -1,169 +0,0 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::ops::Deref;
use futures_util::ready;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use super::ProxyProtocolV1Info;
// Max theorical size we need is 108 for proxy protocol v1
const BUF_SIZE: usize = 256;
#[derive(Debug)]
enum ProxyStreamState {
Handshaking {
buffer: [u8; BUF_SIZE],
index: usize,
},
Established(ProxyProtocolV1Info),
}
impl ProxyStreamState {
pub const fn is_handshaking(&self) -> bool {
matches!(self, Self::Handshaking { .. })
}
}
pin_project_lite::pin_project! {
#[derive(Debug)]
pub struct ProxyStream<S> {
state: ProxyStreamState,
#[pin]
inner: S,
}
}
impl<S> ProxyStream<S> {
pub const fn new(inner: S) -> Self {
Self {
state: ProxyStreamState::Handshaking {
buffer: [0; BUF_SIZE],
index: 0,
},
inner,
}
}
}
#[derive(Debug, Error, Clone, Copy)]
#[error("Proxy protocol handshake is not complete")]
pub struct HandshakeNotDone;
impl<S> Deref for ProxyStream<S> {
type Target = S;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl<S> ProxyStream<S> {
/// Get informations from the proxied connection
///
/// # Errors
///
/// Returns an error if the stream did not complete the handshake yet
pub fn proxy_info(&self) -> Result<&ProxyProtocolV1Info, HandshakeNotDone> {
match &self.state {
ProxyStreamState::Handshaking { .. } => Err(HandshakeNotDone),
ProxyStreamState::Established(info) => Ok(info),
}
}
/// Returns `true` if the proxy protocol is still handshaking
pub const fn is_handshaking(&self) -> bool {
self.state.is_handshaking()
}
}
impl<S> AsyncRead for ProxyStream<S>
where
S: AsyncRead,
{
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
let proj = self.project();
match proj.state {
ProxyStreamState::Handshaking { buffer, index } => {
let mut buffer = ReadBuf::new(&mut buffer[..]);
buffer.advance(*index);
ready!(proj.inner.poll_read(cx, &mut buffer))?;
let filled = buffer.filled();
*index = filled.len();
match ProxyProtocolV1Info::parse(filled) {
Ok((info, rest)) => {
if buf.remaining() < rest.len() {
// This is highly unlikely, but is better than panicking later.
// If it ever happens, we could introduce a "buffer draining" state
// which drains the inner buffer repeatedly until it's empty
return std::task::Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::Other,
"underlying buffer is too small",
)));
}
buf.put_slice(rest);
*proj.state = ProxyStreamState::Established(info);
std::task::Poll::Ready(Ok(()))
}
Err(e) if e.not_enough_bytes() => std::task::Poll::Ready(Ok(())),
Err(e) => std::task::Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::Other,
e,
))),
}
}
ProxyStreamState::Established(_) => proj.inner.poll_read(cx, buf),
}
}
}
impl<S> AsyncWrite for ProxyStream<S>
where
S: AsyncWrite,
{
fn poll_write(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
let proj = self.project();
match proj.state {
// Hold off writes until the handshake is done
// XXX: is this the right way to do it?
ProxyStreamState::Handshaking { .. } => std::task::Poll::Pending,
ProxyStreamState::Established(_) => proj.inner.poll_write(cx, buf),
}
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
self.project().inner.poll_flush(cx)
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
self.project().inner.poll_shutdown(cx)
}
}

View File

@ -18,6 +18,7 @@ use std::{
str::Utf8Error,
};
use bytes::Buf;
use thiserror::Error;
#[derive(Debug, Clone)]
@ -35,7 +36,7 @@ pub enum ProxyProtocolV1Info {
#[derive(Error, Debug)]
#[error("Invalid proxy protocol header")]
pub(super) enum ParseError {
pub enum ParseError {
#[error("Not enough bytes provided")]
NotEnoughBytes,
NoCrLf,
@ -60,17 +61,21 @@ impl ParseError {
impl ProxyProtocolV1Info {
#[allow(clippy::too_many_lines)]
pub(super) fn parse(bytes: &[u8]) -> Result<(Self, &[u8]), ParseError> {
pub(super) fn parse<B>(mut buf: B) -> Result<Self, ParseError>
where
B: Buf + AsRef<[u8]>,
{
use ParseError as E;
// First, check if we *possibly* have enough bytes.
// Minimum is 15: "PROXY UNKNOWN\r\n"
if bytes.len() < 15 {
if buf.remaining() < 15 {
return Err(E::NotEnoughBytes);
}
// Let's check in the first 108 bytes if we find a CRLF
let crlf = if let Some(crlf) = bytes
let crlf = if let Some(crlf) = buf
.as_ref()
.windows(2)
.take(108)
.position(|needle| needle == [0x0D, 0x0A])
@ -78,7 +83,7 @@ impl ProxyProtocolV1Info {
crlf
} else {
// If not, it might be because we don't have enough bytes
return if bytes.len() < 108 {
return if buf.remaining() < 108 {
Err(E::NotEnoughBytes)
} else {
// Else it's just invalid
@ -86,10 +91,8 @@ impl ProxyProtocolV1Info {
};
};
// Keep the rest of the buffer to pass it to the underlying protocol
let rest = &bytes[crlf + 2..];
// Trim to everything before the CRLF
let bytes = &bytes[..crlf];
let bytes = &buf.as_ref()[..crlf];
let mut it = bytes.splitn(6, |c| c == &b' ');
// Check for the preamble
@ -187,7 +190,9 @@ impl ProxyProtocolV1Info {
None => return Err(E::NoProtocol),
};
Ok((result, rest))
buf.advance(crlf + 2);
Ok(result)
}
#[must_use]
@ -258,39 +263,42 @@ mod tests {
#[test]
fn test_parse() {
let (info, rest) = ProxyProtocolV1Info::parse(
b"PROXY TCP4 255.255.255.255 255.255.255.255 65535 65535\r\nhello world",
)
.unwrap();
assert_eq!(rest, b"hello world");
let mut buf =
b"PROXY TCP4 255.255.255.255 255.255.255.255 65535 65535\r\nhello world"
.as_slice();
let info = ProxyProtocolV1Info::parse(&mut buf).unwrap();
assert_eq!(buf, b"hello world");
assert!(info.is_tcp());
assert!(!info.is_udp());
assert!(!info.is_unknown());
assert!(info.is_ipv4());
assert!(!info.is_ipv6());
let (info, rest) = ProxyProtocolV1Info::parse(
let mut buf =
b"PROXY TCP6 ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\nhello world"
).unwrap();
assert_eq!(rest, b"hello world");
.as_slice();
let info = ProxyProtocolV1Info::parse(&mut buf).unwrap();
assert_eq!(buf, b"hello world");
assert!(info.is_tcp());
assert!(!info.is_udp());
assert!(!info.is_unknown());
assert!(!info.is_ipv4());
assert!(info.is_ipv6());
let (info, rest) = ProxyProtocolV1Info::parse(b"PROXY UNKNOWN\r\nhello world").unwrap();
assert_eq!(rest, b"hello world");
let mut buf = b"PROXY UNKNOWN\r\nhello world".as_slice();
let info = ProxyProtocolV1Info::parse(&mut buf).unwrap();
assert_eq!(buf, b"hello world");
assert!(!info.is_tcp());
assert!(!info.is_udp());
assert!(info.is_unknown());
assert!(!info.is_ipv4());
assert!(!info.is_ipv6());
let (info, rest) = ProxyProtocolV1Info::parse(
let mut buf =
b"PROXY UNKNOWN ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\nhello world"
).unwrap();
assert_eq!(rest, b"hello world");
.as_slice();
let info = ProxyProtocolV1Info::parse(&mut buf).unwrap();
assert_eq!(buf, b"hello world");
assert!(!info.is_tcp());
assert!(!info.is_udp());
assert!(info.is_unknown());

View File

@ -0,0 +1,151 @@
// Taken from hyper@0.14.20, src/common/io/rewind.rs
use std::{
cmp, io,
marker::Unpin,
pin::Pin,
task::{Context, Poll},
};
use bytes::{Buf, Bytes};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
/// Combine a buffer with an IO, rewinding reads to use the buffer.
#[derive(Debug)]
pub struct Rewind<T> {
pre: Option<Bytes>,
inner: T,
}
impl<T> Rewind<T> {
pub(crate) fn new(io: T) -> Self {
Rewind {
pre: None,
inner: io,
}
}
pub(crate) fn new_buffered(io: T, buf: Bytes) -> Self {
Rewind {
pre: Some(buf),
inner: io,
}
}
#[cfg(test)]
pub(crate) fn rewind(&mut self, bs: Bytes) {
debug_assert!(self.pre.is_none());
self.pre = Some(bs);
}
}
impl<T> AsyncRead for Rewind<T>
where
T: AsyncRead + Unpin,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
if let Some(mut prefix) = self.pre.take() {
// If there are no remaining bytes, let the bytes get dropped.
if !prefix.is_empty() {
let copy_len = cmp::min(prefix.len(), buf.remaining());
// TODO: There should be a way to do following two lines cleaner...
buf.put_slice(&prefix[..copy_len]);
prefix.advance(copy_len);
// Put back what's left
if !prefix.is_empty() {
self.pre = Some(prefix);
}
return Poll::Ready(Ok(()));
}
}
Pin::new(&mut self.inner).poll_read(cx, buf)
}
}
impl<T> AsyncWrite for Rewind<T>
where
T: AsyncWrite + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_write(cx, buf)
}
fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_flush(cx)
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_shutdown(cx)
}
fn is_write_vectored(&self) -> bool {
self.inner.is_write_vectored()
}
}
#[cfg(test)]
mod tests {
// FIXME: re-implement tests with `async/await`, this import should
// trigger a warning to remind us
use bytes::Bytes;
use tokio::io::AsyncReadExt;
use super::Rewind;
#[tokio::test]
async fn partial_rewind() {
let underlying = [104, 101, 108, 108, 111];
let mock = tokio_test::io::Builder::new().read(&underlying).build();
let mut stream = Rewind::new(mock);
// Read off some bytes, ensure we filled o1
let mut buf = [0; 2];
stream.read_exact(&mut buf).await.expect("read1");
// Rewind the stream so that it is as if we never read in the first place.
stream.rewind(Bytes::copy_from_slice(&buf[..]));
let mut buf = [0; 5];
stream.read_exact(&mut buf).await.expect("read1");
// At this point we should have read everything that was in the MockStream
assert_eq!(&buf, &underlying);
}
#[tokio::test]
async fn full_rewind() {
let underlying = [104, 101, 108, 108, 111];
let mock = tokio_test::io::Builder::new().read(&underlying).build();
let mut stream = Rewind::new(mock);
let mut buf = [0; 5];
stream.read_exact(&mut buf).await.expect("read1");
// Rewind the stream so that it is as if we never read in the first place.
stream.rewind(Bytes::copy_from_slice(&buf[..]));
let mut buf = [0; 5];
stream.read_exact(&mut buf).await.expect("read1");
}
}

View File

@ -0,0 +1,301 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use futures_util::{Stream, StreamExt};
use http_body::Body;
use hyper::{Request, Response};
use thiserror::Error;
use tokio_rustls::rustls::ServerConfig;
use tower_http::add_extension::AddExtension;
use tower_service::Service;
use crate::{
maybe_tls::{MaybeTlsAcceptor, TlsStreamInfo},
proxy_protocol::{MaybeProxyAcceptor, ProxyAcceptError},
unix_or_tcp::{SocketAddr, UnixOrTcpConnection, UnixOrTcpListener},
ConnectionInfo,
};
pub struct Server<S> {
tls: Option<Arc<ServerConfig>>,
proxy: bool,
listener: UnixOrTcpListener,
service: S,
}
impl<S> Server<S> {
/// # Errors
///
/// Returns an error if the listener couldn't be converted via [`TryInfo`]
pub fn try_new<L>(listener: L, service: S) -> Result<Self, L::Error>
where
L: TryInto<UnixOrTcpListener>,
{
Ok(Self {
tls: None,
proxy: false,
listener: listener.try_into()?,
service,
})
}
#[must_use]
pub fn new(listener: impl Into<UnixOrTcpListener>, service: S) -> Self {
Self {
tls: None,
proxy: false,
listener: listener.into(),
service,
}
}
#[must_use]
pub const fn with_proxy(mut self) -> Self {
self.proxy = true;
self
}
#[must_use]
pub fn with_tls(mut self, config: Arc<ServerConfig>) -> Self {
self.tls = Some(config);
self
}
/// Run a single server
pub async fn run<B, SD>(self, shutdown: SD)
where
S: Service<Request<hyper::Body>, Response = Response<B>> + Clone + Send + 'static,
S::Future: Send + 'static,
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
B: Body + Send + 'static,
B::Data: Send,
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
SD: Stream + Unpin,
SD::Item: std::fmt::Display,
{
run_servers(std::iter::once(self), shutdown).await;
}
}
#[derive(Debug, Error)]
#[non_exhaustive]
enum AcceptError {
#[error("failed to accept connection from the underlying socket")]
Socket {
#[source]
source: std::io::Error,
},
#[error("failed to complete the TLS handshake")]
TlsHandshake {
#[source]
source: std::io::Error,
},
#[error("failed to complete the proxy protocol handshake")]
ProxyHandshake {
#[source]
source: ProxyAcceptError,
},
#[error(transparent)]
Hyper(#[from] hyper::Error),
}
impl AcceptError {
fn socket(source: std::io::Error) -> Self {
Self::Socket { source }
}
fn tls_handshake(source: std::io::Error) -> Self {
Self::TlsHandshake { source }
}
fn proxy_handshake(source: ProxyAcceptError) -> Self {
Self::ProxyHandshake { source }
}
}
async fn accept<S, B>(
maybe_proxy_acceptor: &MaybeProxyAcceptor,
maybe_tls_acceptor: &MaybeTlsAcceptor,
peer_addr: SocketAddr,
stream: UnixOrTcpConnection,
service: S,
) -> Result<(), AcceptError>
where
S: Service<Request<hyper::Body>, Response = Response<B>>,
S::Future: Send + 'static,
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
B: Body + Send + 'static,
B::Data: Send,
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
let (proxy, stream) = maybe_proxy_acceptor
.accept(stream)
.await
.map_err(AcceptError::proxy_handshake)?;
let stream = maybe_tls_acceptor
.accept(stream)
.await
.map_err(AcceptError::tls_handshake)?;
let tls = stream.tls_info();
// Figure out if it's HTTP/2 based on the negociated ALPN info
let is_h2 = tls.as_ref().map_or(false, TlsStreamInfo::is_alpn_h2);
let info = ConnectionInfo {
tls,
proxy,
net_peer_addr: peer_addr.into_net(),
};
let service = AddExtension::new(service, info);
if is_h2 {
hyper::server::conn::Http::new()
.http2_only(true)
.serve_connection(stream, service)
.await?;
} else {
hyper::server::conn::Http::new()
.http1_only(true)
.http1_keep_alive(false)
.serve_connection(stream, service)
.await?;
};
Ok(())
}
pub async fn run_servers<S, B, SD>(listeners: impl IntoIterator<Item = Server<S>>, mut shutdown: SD)
where
S: Service<Request<hyper::Body>, Response = Response<B>> + Clone + Send + 'static,
S::Future: Send + 'static,
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
B: Body + Send + 'static,
B::Data: Send,
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
SD: Stream + Unpin,
SD::Item: std::fmt::Display,
{
let listeners: Vec<_> = listeners
.into_iter()
.map(|server| {
let maybe_proxy_acceptor = MaybeProxyAcceptor::new(server.proxy);
let maybe_tls_acceptor = MaybeTlsAcceptor::new(server.tls);
let service = server.service;
let listener = server.listener;
(maybe_proxy_acceptor, maybe_tls_acceptor, service, listener)
})
.collect();
let mut set = tokio::task::JoinSet::new();
loop {
let mut accept_all: futures_util::stream::FuturesUnordered<_> = listeners
.iter()
.map(
|(maybe_proxy_acceptor, maybe_tls_acceptor, service, listener)| async move {
listener
.accept()
.await
.map_err(AcceptError::socket)
.map(|(addr, conn)| {
(
maybe_proxy_acceptor.clone(),
maybe_tls_acceptor.clone(),
service.clone(),
addr,
conn,
)
})
},
)
.collect();
tokio::select! {
biased;
// First look for the shutdown signal
res = shutdown.next() => {
let why = res.map_or_else(|| String::from("???"), |why| format!("{why}"));
tracing::info!("Received shutdown signal ({why})");
break;
},
// Poll on the JoinSet, clearing finished task
res = set.join_next(), if !set.is_empty() => {
match res {
Some(Ok(Ok(()))) => tracing::trace!("Task was successful"),
Some(Ok(Err(e))) => tracing::error!("{e}"),
Some(Err(e)) => tracing::error!("Join error: {e}"),
None => tracing::error!("Join set was polled even though it was empty"),
}
},
// Then look for connections to accept
res = accept_all.next(), if !accept_all.is_empty() => {
// SAFETY: We shouldn't reach this branch if the unordered future set is empty
let res = if let Some(res) = res { res } else { unreachable!() };
// Spawn the connection in the set, so we don't have to wait for the handshake to
// accept the next connection. This allows us to keep track of active connections
// and waiting on them for a graceful shutdown
set.spawn(async move {
let (maybe_proxy_acceptor, maybe_tls_acceptor, service, peer_addr, stream) = res?;
accept(&maybe_proxy_acceptor, &maybe_tls_acceptor, peer_addr, stream, service).await
});
},
};
}
if !set.is_empty() {
tracing::info!(
"There are {active} active connections, performing a graceful shutdown. Send the shutdown signal again to force.",
active = set.len()
);
loop {
tokio::select! {
biased;
res = set.join_next() => {
match res {
Some(Ok(Ok(()))) => tracing::trace!("Task was successful"),
Some(Ok(Err(e))) => tracing::error!("{e}"),
Some(Err(e)) => tracing::error!("Join error: {e}"),
// No more tasks, going out
None => break,
}
},
res = shutdown.next() => {
let why = res.map_or_else(|| String::from("???"), |why| format!("{why}"));
tracing::warn!("Received shutdown signal again ({why}), forcing shutdown ({active} active connections)", active = set.len());
break;
},
}
}
}
set.shutdown().await;
tracing::info!("Shutdown complete");
}

View File

@ -0,0 +1,178 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{fmt::Display, pin::Pin, task::Poll, time::Duration};
use futures_util::{ready, Future, Stream};
use tokio::{
signal::unix::{signal, Signal, SignalKind},
time::Sleep,
};
#[derive(Debug, Clone, Copy)]
pub enum ShutdownReason {
Signal(SignalKind),
Timeout,
}
fn signal_to_str(kind: SignalKind) -> &'static str {
match kind.as_raw_value() {
libc::SIGALRM => "SIGALRM",
libc::SIGCHLD => "SIGCHLD",
libc::SIGHUP => "SIGHUP",
libc::SIGINT => "SIGINT",
libc::SIGIO => "SIGIO",
libc::SIGPIPE => "SIGPIPE",
libc::SIGQUIT => "SIGQUIT",
libc::SIGTERM => "SIGTERM",
libc::SIGUSR1 => "SIGUSR1",
libc::SIGUSR2 => "SIGUSR2",
_ => "SIG???",
}
}
impl Display for ShutdownReason {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Signal(s) => signal_to_str(*s).fmt(f),
Self::Timeout => "timeout".fmt(f),
}
}
}
#[derive(Default)]
pub enum ShutdownStreamState {
#[default]
Waiting,
Graceful {
sleep: Option<Pin<Box<Sleep>>>,
},
Done,
}
impl ShutdownStreamState {
fn is_graceful(&self) -> bool {
matches!(self, Self::Graceful { .. })
}
fn is_done(&self) -> bool {
matches!(self, Self::Done)
}
fn get_sleep_mut(&mut self) -> Option<&mut Pin<Box<Sleep>>> {
match self {
Self::Graceful { sleep } => sleep.as_mut(),
_ => None,
}
}
}
/// A stream which is used to drive a graceful shutdown.
///
/// It will emit 2 items: one when a first signal is caught, the other when
/// either another signal is caught, or after a timeout.
#[derive(Default)]
pub struct ShutdownStream {
state: ShutdownStreamState,
signals: Vec<(SignalKind, Signal)>,
timeout: Option<Duration>,
}
impl ShutdownStream {
/// Create a default shutdown stream, which listens on SIGINT and SIGTERM,
/// with a 60s timeout
///
/// # Errors
///
/// Returns an error if signal handlers could not be installed
pub fn new() -> Result<Self, std::io::Error> {
let ret = Self::default()
.with_timeout(Duration::from_secs(60))
.with_signal(SignalKind::interrupt())?
.with_signal(SignalKind::terminate())?;
Ok(ret)
}
/// Add a signal to register
///
/// # Errors
///
/// Returns an error if the signal handler could not be installed
pub fn with_signal(mut self, kind: SignalKind) -> Result<Self, std::io::Error> {
let signal = signal(kind)?;
self.signals.push((kind, signal));
Ok(self)
}
#[must_use]
pub fn with_timeout(mut self, timeout: Duration) -> Self {
self.timeout = Some(timeout);
self
}
}
impl Stream for ShutdownStream {
type Item = ShutdownReason;
fn size_hint(&self) -> (usize, Option<usize>) {
match self.state {
ShutdownStreamState::Waiting => (2, Some(2)),
ShutdownStreamState::Graceful { .. } => (1, Some(1)),
ShutdownStreamState::Done => (0, Some(0)),
}
}
fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
let this = self.get_mut();
if this.state.is_done() {
return Poll::Ready(None);
}
for (kind, signal) in &mut this.signals {
match signal.poll_recv(cx) {
Poll::Ready(_) => {
// We got a signal
if this.state.is_graceful() {
// If we was gracefully shutting down, mark it as done
this.state = ShutdownStreamState::Done;
} else {
// Else start the timeout
let sleep = this
.timeout
.map(|duration| Box::pin(tokio::time::sleep(duration)));
this.state = ShutdownStreamState::Graceful { sleep };
}
return Poll::Ready(Some(ShutdownReason::Signal(*kind)));
}
Poll::Pending => {}
}
}
if let Some(timeout) = this.state.get_sleep_mut() {
ready!(timeout.as_mut().poll(cx));
this.state = ShutdownStreamState::Done;
Poll::Ready(Some(ShutdownReason::Timeout))
} else {
Poll::Pending
}
}
}

View File

@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//! A listener which can listen on either TCP sockets or on UNIX domain sockets
// TODO: Unlink the UNIX socket on drop?
use std::{
@ -19,8 +21,6 @@ use std::{
task::{Context, Poll},
};
use futures_util::ready;
use hyper::server::accept::Accept;
use tokio::{
io::{AsyncRead, AsyncWrite},
net::{TcpListener, TcpStream, UnixListener, UnixStream},
@ -107,6 +107,7 @@ impl TryFrom<std::os::unix::net::UnixListener> for UnixOrTcpListener {
type Error = std::io::Error;
fn try_from(listener: std::os::unix::net::UnixListener) -> Result<Self, Self::Error> {
listener.set_nonblocking(true)?;
Ok(Self::Unix(UnixListener::from_std(listener)?))
}
}
@ -115,6 +116,7 @@ impl TryFrom<std::net::TcpListener> for UnixOrTcpListener {
type Error = std::io::Error;
fn try_from(listener: std::net::TcpListener) -> Result<Self, Self::Error> {
listener.set_nonblocking(true)?;
Ok(Self::Tcp(TcpListener::from_std(listener)?))
}
}
@ -140,6 +142,24 @@ impl UnixOrTcpListener {
pub const fn is_tcp(&self) -> bool {
matches!(self, Self::Tcp(_))
}
/// Accept an incoming connection
///
/// # Errors
///
/// Returns an error if the underlying socket couldn't accept the connection
pub async fn accept(&self) -> Result<(SocketAddr, UnixOrTcpConnection), std::io::Error> {
match self {
Self::Unix(listener) => {
let (stream, remote_addr) = listener.accept().await?;
Ok((remote_addr.into(), UnixOrTcpConnection::Unix { stream }))
}
Self::Tcp(listener) => {
let (stream, remote_addr) = listener.accept().await?;
Ok((remote_addr.into(), UnixOrTcpConnection::Tcp { stream }))
}
}
}
}
pin_project_lite::pin_project! {
@ -157,6 +177,12 @@ pin_project_lite::pin_project! {
}
}
impl From<TcpStream> for UnixOrTcpConnection {
fn from(stream: TcpStream) -> Self {
Self::Tcp { stream }
}
}
impl UnixOrTcpConnection {
/// Get the local address of the stream
///
@ -166,8 +192,8 @@ impl UnixOrTcpConnection {
/// [`UnixStream`] couldn't provide the local address
pub fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
match self {
Self::Unix { stream, .. } => stream.local_addr().map(SocketAddr::from),
Self::Tcp { stream, .. } => stream.local_addr().map(SocketAddr::from),
Self::Unix { stream } => stream.local_addr().map(SocketAddr::from),
Self::Tcp { stream } => stream.local_addr().map(SocketAddr::from),
}
}
@ -179,36 +205,12 @@ impl UnixOrTcpConnection {
/// [`UnixStream`] couldn't provide the remote address
pub fn peer_addr(&self) -> Result<SocketAddr, std::io::Error> {
match self {
Self::Unix { stream, .. } => stream.peer_addr().map(SocketAddr::from),
Self::Tcp { stream, .. } => stream.peer_addr().map(SocketAddr::from),
Self::Unix { stream } => stream.peer_addr().map(SocketAddr::from),
Self::Tcp { stream } => stream.peer_addr().map(SocketAddr::from),
}
}
}
impl Accept for UnixOrTcpListener {
type Error = std::io::Error;
type Conn = UnixOrTcpConnection;
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> std::task::Poll<Option<Result<Self::Conn, Self::Error>>> {
let conn = match &*self {
Self::Unix(listener) => {
let (stream, _remote_addr) = ready!(listener.poll_accept(cx))?;
UnixOrTcpConnection::Unix { stream }
}
Self::Tcp(listener) => {
let (stream, _remote_addr) = ready!(listener.poll_accept(cx))?;
UnixOrTcpConnection::Tcp { stream }
}
};
Poll::Ready(Some(Ok(conn)))
}
}
impl AsyncRead for UnixOrTcpConnection {
fn poll_read(
self: Pin<&mut Self>,
@ -234,23 +236,6 @@ impl AsyncWrite for UnixOrTcpConnection {
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
match self.project() {
UnixOrTcpConnectionProj::Unix { stream } => stream.poll_flush(cx),
UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_flush(cx),
}
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
match self.project() {
UnixOrTcpConnectionProj::Unix { stream } => stream.poll_shutdown(cx),
UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_shutdown(cx),
}
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
@ -268,4 +253,21 @@ impl AsyncWrite for UnixOrTcpConnection {
UnixOrTcpConnection::Tcp { stream } => stream.is_write_vectored(),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
match self.project() {
UnixOrTcpConnectionProj::Unix { stream } => stream.poll_flush(cx),
UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_flush(cx),
}
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
match self.project() {
UnixOrTcpConnectionProj::Unix { stream } => stream.poll_shutdown(cx),
UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_shutdown(cx),
}
}
}