You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-31 09:24:31 +03:00
Rewrite the listeners crate
Now with a way better graceful shutdown! With proper handshakes!
This commit is contained in:
22
Cargo.lock
generated
22
Cargo.lock
generated
@ -2425,6 +2425,7 @@ dependencies = [
|
|||||||
"futures-util",
|
"futures-util",
|
||||||
"hyper",
|
"hyper",
|
||||||
"indoc",
|
"indoc",
|
||||||
|
"itertools",
|
||||||
"listenfd",
|
"listenfd",
|
||||||
"mas-config",
|
"mas-config",
|
||||||
"mas-email",
|
"mas-email",
|
||||||
@ -2677,15 +2678,21 @@ dependencies = [
|
|||||||
name = "mas-listener"
|
name = "mas-listener"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
|
"bytes 1.2.1",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
|
"http-body",
|
||||||
"hyper",
|
"hyper",
|
||||||
|
"libc",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-rustls 0.23.4",
|
"tokio-rustls 0.23.4",
|
||||||
"tower",
|
"tokio-test",
|
||||||
"tower-http",
|
"tower-http",
|
||||||
|
"tower-service",
|
||||||
"tracing",
|
"tracing",
|
||||||
|
"tracing-subscriber",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -4872,6 +4879,19 @@ dependencies = [
|
|||||||
"tokio",
|
"tokio",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tokio-test"
|
||||||
|
version = "0.4.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "53474327ae5e166530d17f2d956afcb4f8a004de581b3cae10f12006bc8163e3"
|
||||||
|
dependencies = [
|
||||||
|
"async-stream",
|
||||||
|
"bytes 1.2.1",
|
||||||
|
"futures-core",
|
||||||
|
"tokio",
|
||||||
|
"tokio-stream",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tokio-util"
|
name = "tokio-util"
|
||||||
version = "0.6.10"
|
version = "0.6.10"
|
||||||
|
@ -22,6 +22,7 @@ watchman_client = "0.8.0"
|
|||||||
atty = "0.2.14"
|
atty = "0.2.14"
|
||||||
listenfd = "1.0.0"
|
listenfd = "1.0.0"
|
||||||
rustls = "0.20.6"
|
rustls = "0.20.6"
|
||||||
|
itertools = "0.10.5"
|
||||||
|
|
||||||
tracing = "0.1.36"
|
tracing = "0.1.36"
|
||||||
tracing-appender = "0.2.2"
|
tracing-appender = "0.2.2"
|
||||||
|
@ -16,26 +16,19 @@ use std::{sync::Arc, time::Duration};
|
|||||||
|
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use futures_util::{
|
use futures_util::stream::{StreamExt, TryStreamExt};
|
||||||
future::FutureExt,
|
use itertools::Itertools;
|
||||||
stream::{StreamExt, TryStreamExt},
|
|
||||||
};
|
|
||||||
use hyper::Server;
|
|
||||||
use mas_config::RootConfig;
|
use mas_config::RootConfig;
|
||||||
use mas_email::Mailer;
|
use mas_email::Mailer;
|
||||||
use mas_handlers::{AppState, MatrixHomeserver};
|
use mas_handlers::{AppState, MatrixHomeserver};
|
||||||
use mas_http::ServerLayer;
|
use mas_http::ServerLayer;
|
||||||
use mas_listener::{
|
use mas_listener::{server::Server, shutdown::ShutdownStream};
|
||||||
info::{ConnectionInfoAcceptor, IntoMakeServiceWithConnection},
|
|
||||||
maybe_tls::MaybeTlsAcceptor,
|
|
||||||
proxy_protocol::MaybeProxyAcceptor,
|
|
||||||
};
|
|
||||||
use mas_policy::PolicyFactory;
|
use mas_policy::PolicyFactory;
|
||||||
use mas_router::UrlBuilder;
|
use mas_router::UrlBuilder;
|
||||||
use mas_storage::MIGRATOR;
|
use mas_storage::MIGRATOR;
|
||||||
use mas_tasks::TaskQueue;
|
use mas_tasks::TaskQueue;
|
||||||
use mas_templates::Templates;
|
use mas_templates::Templates;
|
||||||
use tokio::io::AsyncRead;
|
use tokio::{io::AsyncRead, signal::unix::SignalKind};
|
||||||
use tracing::{error, info, log::warn};
|
use tracing::{error, info, log::warn};
|
||||||
|
|
||||||
#[derive(Parser, Debug, Default)]
|
#[derive(Parser, Debug, Default)]
|
||||||
@ -49,32 +42,6 @@ pub(super) struct Options {
|
|||||||
watch: bool,
|
watch: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(unix))]
|
|
||||||
async fn shutdown_signal() {
|
|
||||||
// Wait for the CTRL+C signal
|
|
||||||
tokio::signal::ctrl_c()
|
|
||||||
.await
|
|
||||||
.expect("failed to install Ctrl+C signal handler");
|
|
||||||
|
|
||||||
tracing::info!("Got Ctrl+C, shutting down");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(unix)]
|
|
||||||
async fn shutdown_signal() {
|
|
||||||
use tokio::signal::unix::{signal, SignalKind};
|
|
||||||
|
|
||||||
// Wait for SIGTERM and SIGINT signals
|
|
||||||
// This might panic but should be fine
|
|
||||||
let mut term =
|
|
||||||
signal(SignalKind::terminate()).expect("failed to install SIGTERM signal handler");
|
|
||||||
let mut int = signal(SignalKind::interrupt()).expect("failed to install SIGINT signal handler");
|
|
||||||
|
|
||||||
tokio::select! {
|
|
||||||
_ = term.recv() => tracing::info!("Got SIGTERM, shutting down"),
|
|
||||||
_ = int.recv() => tracing::info!("Got SIGINT, shutting down"),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Watch for changes in the templates folders
|
/// Watch for changes in the templates folders
|
||||||
async fn watch_templates(
|
async fn watch_templates(
|
||||||
client: &watchman_client::Client,
|
client: &watchman_client::Client,
|
||||||
@ -247,68 +214,75 @@ impl Options {
|
|||||||
policy_factory,
|
policy_factory,
|
||||||
});
|
});
|
||||||
|
|
||||||
let signal = shutdown_signal().shared();
|
|
||||||
let shutdown_signal = signal.clone();
|
|
||||||
|
|
||||||
let mut fd_manager = listenfd::ListenFd::from_env();
|
let mut fd_manager = listenfd::ListenFd::from_env();
|
||||||
let listeners = listeners_config.into_iter().map(|listener_config| {
|
|
||||||
// Let's first grab all the listeners in a synchronous manner
|
|
||||||
let listeners = crate::server::build_listeners(&mut fd_manager, &listener_config.binds);
|
|
||||||
|
|
||||||
Ok((listener_config, listeners?))
|
let servers: Vec<Server<_>> = listeners_config
|
||||||
});
|
.into_iter()
|
||||||
|
.map(|config| {
|
||||||
|
// Let's first grab all the listeners
|
||||||
|
let listeners = crate::server::build_listeners(&mut fd_manager, &config.binds)?;
|
||||||
|
|
||||||
// Now that we have the listeners ready, we can do the rest concurrently
|
// Load the TLS config
|
||||||
futures_util::stream::iter(listeners)
|
let tls_config = if let Some(tls_config) = config.tls.as_ref() {
|
||||||
.try_for_each_concurrent(None, move |(config, listeners)| {
|
let tls_config = crate::server::build_tls_server_config(tls_config)?;
|
||||||
let signal = signal.clone();
|
Some(Arc::new(tls_config))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
// and build the router
|
||||||
|
let router = crate::server::build_router(&state, &config.resources)
|
||||||
|
.layer(ServerLayer::new(config.name.clone()));
|
||||||
|
|
||||||
|
// Display some informations about where we'll be serving connections
|
||||||
let is_tls = config.tls.is_some();
|
let is_tls = config.tls.is_some();
|
||||||
let addresses: Vec<String> = listeners.iter().map(|listener| {
|
let addresses: Vec<String> = listeners
|
||||||
|
.iter()
|
||||||
|
.map(|listener| {
|
||||||
let addr = listener.local_addr();
|
let addr = listener.local_addr();
|
||||||
let proto = if is_tls { "https" } else { "http" };
|
let proto = if is_tls { "https" } else { "http" };
|
||||||
if let Ok(addr) = addr {
|
if let Ok(addr) = addr {
|
||||||
format!("{proto}://{addr:?}")
|
format!("{proto}://{addr:?}")
|
||||||
} else {
|
} else {
|
||||||
warn!("Could not get local address for listener, something might be wrong!");
|
warn!(
|
||||||
|
"Could not get local address for listener, something might be wrong!"
|
||||||
|
);
|
||||||
format!("{proto}://???")
|
format!("{proto}://???")
|
||||||
}
|
}
|
||||||
}).collect();
|
|
||||||
|
|
||||||
let additional = if config.proxy_protocol { "(with Proxy Protocol)" } else { "" };
|
|
||||||
|
|
||||||
info!("Listening on {addresses:?} with resources {resources:?} {additional}", resources = &config.resources);
|
|
||||||
|
|
||||||
let router = crate::server::build_router(&state, &config.resources).layer(ServerLayer::new(config.name.clone()));
|
|
||||||
let make_service = IntoMakeServiceWithConnection::new(router);
|
|
||||||
|
|
||||||
async move {
|
|
||||||
let tls_config = if let Some(tls_config) = config.tls.as_ref() {
|
|
||||||
let tls_config = crate::server::build_tls_server_config(tls_config).await?;
|
|
||||||
Some(Arc::new(tls_config))
|
|
||||||
} else { None };
|
|
||||||
|
|
||||||
futures_util::stream::iter(listeners)
|
|
||||||
.map(Ok)
|
|
||||||
.try_for_each_concurrent(None, move |listener| {
|
|
||||||
let listener = MaybeTlsAcceptor::new(tls_config.clone(), listener);
|
|
||||||
let listener = MaybeProxyAcceptor::new(listener, config.proxy_protocol);
|
|
||||||
let listener = ConnectionInfoAcceptor::new(listener);
|
|
||||||
|
|
||||||
Server::builder(listener)
|
|
||||||
.serve(make_service.clone())
|
|
||||||
.with_graceful_shutdown(signal.clone())
|
|
||||||
})
|
})
|
||||||
.await?;
|
.collect();
|
||||||
|
|
||||||
anyhow::Ok(())
|
let additional = if config.proxy_protocol {
|
||||||
|
"(with Proxy Protocol)"
|
||||||
|
} else {
|
||||||
|
""
|
||||||
|
};
|
||||||
|
|
||||||
|
info!(
|
||||||
|
"Listening on {addresses:?} with resources {resources:?} {additional}",
|
||||||
|
resources = &config.resources
|
||||||
|
);
|
||||||
|
|
||||||
|
anyhow::Ok(listeners.into_iter().map(move |listener| {
|
||||||
|
let mut server = Server::new(listener, router.clone());
|
||||||
|
if let Some(tls_config) = &tls_config {
|
||||||
|
server = server.with_tls(tls_config.clone());
|
||||||
}
|
}
|
||||||
|
if config.proxy_protocol {
|
||||||
|
server = server.with_proxy();
|
||||||
|
}
|
||||||
|
server
|
||||||
|
}))
|
||||||
})
|
})
|
||||||
.await?;
|
.flatten_ok()
|
||||||
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
|
||||||
// This ensures we're running, even if no listener are setup
|
let shutdown = ShutdownStream::default()
|
||||||
// This is useful for only running the task runner
|
.with_timeout(Duration::from_secs(60))
|
||||||
shutdown_signal.await;
|
.with_signal(SignalKind::terminate())?
|
||||||
|
.with_signal(SignalKind::interrupt())?;
|
||||||
|
|
||||||
|
mas_listener::server::run_servers(servers, shutdown).await;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -23,10 +23,9 @@ use axum::{body::HttpBody, Extension, Router};
|
|||||||
use listenfd::ListenFd;
|
use listenfd::ListenFd;
|
||||||
use mas_config::{HttpBindConfig, HttpResource, HttpTlsConfig, UnixOrTcp};
|
use mas_config::{HttpBindConfig, HttpResource, HttpTlsConfig, UnixOrTcp};
|
||||||
use mas_handlers::AppState;
|
use mas_handlers::AppState;
|
||||||
use mas_listener::{info::Connection, unix_or_tcp::UnixOrTcpListener};
|
use mas_listener::{unix_or_tcp::UnixOrTcpListener, ConnectionInfo};
|
||||||
use mas_router::Route;
|
use mas_router::Route;
|
||||||
use rustls::ServerConfig;
|
use rustls::ServerConfig;
|
||||||
use tokio::sync::OnceCell;
|
|
||||||
|
|
||||||
#[allow(clippy::trait_duplication_in_bounds)]
|
#[allow(clippy::trait_duplication_in_bounds)]
|
||||||
pub fn build_router<B>(state: &Arc<AppState>, resources: &[HttpResource]) -> Router<AppState, B>
|
pub fn build_router<B>(state: &Arc<AppState>, resources: &[HttpResource]) -> Router<AppState, B>
|
||||||
@ -64,12 +63,9 @@ where
|
|||||||
// TODO: do a better handler here
|
// TODO: do a better handler here
|
||||||
mas_config::HttpResource::ConnectionInfo => router.route(
|
mas_config::HttpResource::ConnectionInfo => router.route(
|
||||||
"/connection-info",
|
"/connection-info",
|
||||||
axum::routing::get(
|
axum::routing::get(|connection: Extension<ConnectionInfo>| async move {
|
||||||
|connection: Extension<Arc<OnceCell<Connection>>>| async move {
|
|
||||||
let connection = connection.get().unwrap();
|
|
||||||
format!("{connection:?}")
|
format!("{connection:?}")
|
||||||
},
|
}),
|
||||||
),
|
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -77,10 +73,8 @@ where
|
|||||||
router
|
router
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn build_tls_server_config(
|
pub fn build_tls_server_config(config: &HttpTlsConfig) -> Result<ServerConfig, anyhow::Error> {
|
||||||
config: &HttpTlsConfig,
|
let (key, chain) = config.load()?;
|
||||||
) -> Result<ServerConfig, anyhow::Error> {
|
|
||||||
let (key, chain) = config.load().await?;
|
|
||||||
let key = rustls::PrivateKey(key);
|
let key = rustls::PrivateKey(key);
|
||||||
let chain = chain.into_iter().map(rustls::Certificate).collect();
|
let chain = chain.into_iter().map(rustls::Certificate).collect();
|
||||||
|
|
||||||
|
@ -163,11 +163,11 @@ impl TlsConfig {
|
|||||||
/// - a password was provided but the key was not encrypted
|
/// - a password was provided but the key was not encrypted
|
||||||
/// - decoding the certificate chain as PEM
|
/// - decoding the certificate chain as PEM
|
||||||
/// - the certificate chain is empty
|
/// - the certificate chain is empty
|
||||||
pub async fn load(&self) -> Result<(Vec<u8>, Vec<Vec<u8>>), anyhow::Error> {
|
pub fn load(&self) -> Result<(Vec<u8>, Vec<Vec<u8>>), anyhow::Error> {
|
||||||
let password = match &self.password {
|
let password = match &self.password {
|
||||||
Some(PasswordOrFile::Password(password)) => Some(Cow::Borrowed(password.as_str())),
|
Some(PasswordOrFile::Password(password)) => Some(Cow::Borrowed(password.as_str())),
|
||||||
Some(PasswordOrFile::PasswordFile(path)) => {
|
Some(PasswordOrFile::PasswordFile(path)) => {
|
||||||
Some(Cow::Owned(tokio::fs::read_to_string(path).await?))
|
Some(Cow::Owned(std::fs::read_to_string(path)?))
|
||||||
}
|
}
|
||||||
None => None,
|
None => None,
|
||||||
};
|
};
|
||||||
@ -185,7 +185,7 @@ impl TlsConfig {
|
|||||||
KeyOrFile::KeyFile(path) => {
|
KeyOrFile::KeyFile(path) => {
|
||||||
// When reading from disk, it might be either PEM or DER. `PrivateKey::load*`
|
// When reading from disk, it might be either PEM or DER. `PrivateKey::load*`
|
||||||
// will try both.
|
// will try both.
|
||||||
let key = tokio::fs::read(path).await?;
|
let key = std::fs::read(path)?;
|
||||||
if let Some(password) = password {
|
if let Some(password) = password {
|
||||||
PrivateKey::load_encrypted(&key, password.as_bytes())?
|
PrivateKey::load_encrypted(&key, password.as_bytes())?
|
||||||
} else {
|
} else {
|
||||||
@ -202,9 +202,7 @@ impl TlsConfig {
|
|||||||
|
|
||||||
let certificate_chain_pem = match &self.certificate {
|
let certificate_chain_pem = match &self.certificate {
|
||||||
CertificateOrFile::Certificate(pem) => Cow::Borrowed(pem.as_str()),
|
CertificateOrFile::Certificate(pem) => Cow::Borrowed(pem.as_str()),
|
||||||
CertificateOrFile::CertificateFile(path) => {
|
CertificateOrFile::CertificateFile(path) => Cow::Owned(std::fs::read_to_string(path)?),
|
||||||
Cow::Owned(tokio::fs::read_to_string(path).await?)
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut certificate_chain_reader = Cursor::new(certificate_chain_pem.as_bytes());
|
let mut certificate_chain_reader = Cursor::new(certificate_chain_pem.as_bytes());
|
||||||
|
@ -6,12 +6,25 @@ edition = "2021"
|
|||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
bytes = "1.2.1"
|
||||||
futures-util = "0.3.24"
|
futures-util = "0.3.24"
|
||||||
hyper = { version = "0.14.20", features = ["server"] }
|
http-body = "0.4.2"
|
||||||
|
hyper = { version = "0.14.20", features = ["server", "http1", "http2"] }
|
||||||
pin-project-lite = "0.2.9"
|
pin-project-lite = "0.2.9"
|
||||||
thiserror = "1.0.37"
|
thiserror = "1.0.37"
|
||||||
tokio = { version = "1.21.2", features = ["net"] }
|
tokio = { version = "1.21.2", features = ["net", "rt", "macros", "signal", "time"] }
|
||||||
tokio-rustls = "0.23.4"
|
tokio-rustls = "0.23.4"
|
||||||
tower = "0.4.13"
|
|
||||||
tower-http = { version = "0.3.4", features = ["add-extension"] }
|
tower-http = { version = "0.3.4", features = ["add-extension"] }
|
||||||
tracing = "0.1.36"
|
tower-service = "0.3.2"
|
||||||
|
tracing = "0.1.37"
|
||||||
|
libc = "0.2.135"
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
tokio-test = "0.4.2"
|
||||||
|
anyhow = "1.0.65"
|
||||||
|
tokio = { version = "1.21.2", features = ["net", "rt", "macros", "signal", "time", "rt-multi-thread"] }
|
||||||
|
tracing-subscriber = "0.3.16"
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "demo"
|
||||||
|
path = "examples/demo/main.rs"
|
||||||
|
49
crates/listener/examples/demo/main.rs
Normal file
49
crates/listener/examples/demo/main.rs
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
use std::{
|
||||||
|
convert::Infallible,
|
||||||
|
net::{Ipv4Addr, TcpListener},
|
||||||
|
time::Duration,
|
||||||
|
};
|
||||||
|
|
||||||
|
use hyper::{service::service_fn, Request, Response};
|
||||||
|
use tokio::signal::unix::SignalKind;
|
||||||
|
use tokio_streams_util::{server::Server, shutdown::ShutdownStream, ConnectionInfo};
|
||||||
|
|
||||||
|
async fn handler(req: Request<hyper::Body>) -> Result<Response<String>, Infallible> {
|
||||||
|
tracing::info!("Handling request");
|
||||||
|
tokio::time::sleep(Duration::from_secs(3)).await;
|
||||||
|
let info = req.extensions().get::<ConnectionInfo>().unwrap();
|
||||||
|
let body = format!("{info:?}");
|
||||||
|
Ok(Response::new(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<(), anyhow::Error> {
|
||||||
|
tracing_subscriber::fmt::init();
|
||||||
|
|
||||||
|
let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 3000))?;
|
||||||
|
let service = service_fn(handler);
|
||||||
|
let server = Server::try_new(listener, service)?;
|
||||||
|
|
||||||
|
tracing::info!("Listening on 127.0.0.1:3000");
|
||||||
|
|
||||||
|
let shutdown = ShutdownStream::default()
|
||||||
|
.with_signal(SignalKind::interrupt())?
|
||||||
|
.with_signal(SignalKind::terminate())?;
|
||||||
|
server.run(shutdown).await;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
@ -1,323 +0,0 @@
|
|||||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
use std::{
|
|
||||||
future::Ready,
|
|
||||||
net::SocketAddr,
|
|
||||||
ops::Deref,
|
|
||||||
pin::Pin,
|
|
||||||
sync::Arc,
|
|
||||||
task::{Context, Poll},
|
|
||||||
};
|
|
||||||
|
|
||||||
use hyper::server::accept::Accept;
|
|
||||||
use thiserror::Error;
|
|
||||||
use tokio::{
|
|
||||||
io::{AsyncRead, AsyncWrite},
|
|
||||||
sync::OnceCell,
|
|
||||||
};
|
|
||||||
use tower::Service;
|
|
||||||
use tower_http::add_extension::AddExtension;
|
|
||||||
|
|
||||||
use crate::{
|
|
||||||
maybe_tls::{MaybeTlsAcceptor, MaybeTlsStream, TlsStreamInfo, TlsStreamInfoError},
|
|
||||||
proxy_protocol::{
|
|
||||||
MaybeProxyAcceptor, MaybeProxyStream, ProxyHandshakeNotDone, ProxyProtocolV1Info,
|
|
||||||
},
|
|
||||||
unix_or_tcp::{UnixOrTcpConnection, UnixOrTcpListener},
|
|
||||||
};
|
|
||||||
|
|
||||||
// TODO: this is a mess, clean that up
|
|
||||||
|
|
||||||
#[derive(Error, Debug)]
|
|
||||||
#[non_exhaustive]
|
|
||||||
pub enum FromStreamError {
|
|
||||||
#[error(transparent)]
|
|
||||||
Proxy(#[from] ProxyHandshakeNotDone),
|
|
||||||
|
|
||||||
#[error(transparent)]
|
|
||||||
Tls(#[from] TlsStreamInfoError),
|
|
||||||
|
|
||||||
#[error("Could not grab a reference to the underlying stream")]
|
|
||||||
GetRef,
|
|
||||||
|
|
||||||
#[error("Could not get address info from underlying stream")]
|
|
||||||
IoError(#[from] std::io::Error),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct Connection {
|
|
||||||
proxy: Option<ProxyProtocolV1Info>,
|
|
||||||
tls: Option<TlsStreamInfo>,
|
|
||||||
|
|
||||||
// We're not saving the UNIX domain socket address here because it can't be cloned, which is
|
|
||||||
// required for injecting the connection information as an extension
|
|
||||||
local_tcp_addr: Option<SocketAddr>,
|
|
||||||
peer_tcp_addr: Option<SocketAddr>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Error, Debug)]
|
|
||||||
#[non_exhaustive]
|
|
||||||
pub enum GrabAddressError {
|
|
||||||
#[error("Proxy protocol was initiated with an unknown protocol")]
|
|
||||||
ProxyUnknown,
|
|
||||||
|
|
||||||
#[error("Proxy protocol was initiated with UDP")]
|
|
||||||
ProxyUdp,
|
|
||||||
|
|
||||||
#[error("Underlying listener is a UNIX socket")]
|
|
||||||
UnixListener,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl MaybeProxyAcceptor<MaybeTlsAcceptor<UnixOrTcpListener>> {
|
|
||||||
pub fn can_have_peer_address(&self) -> bool {
|
|
||||||
self.is_proxied() || self.is_tcp()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl MaybeProxyStream<MaybeTlsStream<UnixOrTcpConnection>> {
|
|
||||||
/// Get informations about this connection
|
|
||||||
///
|
|
||||||
/// # Errors
|
|
||||||
///
|
|
||||||
/// Returns an error if the proxy protocol or the TLS handhakes are not done
|
|
||||||
/// yet
|
|
||||||
pub fn connection_info(&self) -> Result<Connection, FromStreamError> {
|
|
||||||
Connection::from_stream(self)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Connection {
|
|
||||||
/// Get informations about this connection
|
|
||||||
///
|
|
||||||
/// # Errors
|
|
||||||
///
|
|
||||||
/// Returns an error if the proxy protocol or the TLS handhakes are not done
|
|
||||||
/// yet
|
|
||||||
pub fn from_stream(
|
|
||||||
stream: &MaybeProxyStream<MaybeTlsStream<UnixOrTcpConnection>>,
|
|
||||||
) -> Result<Self, FromStreamError> {
|
|
||||||
let proxy = stream.proxy_info()?.cloned();
|
|
||||||
let tls = stream.tls_info()?;
|
|
||||||
let original = stream.get_ref().ok_or(FromStreamError::GetRef)?;
|
|
||||||
let local_tcp_addr = original.local_addr()?.into_net();
|
|
||||||
let peer_tcp_addr = original.peer_addr()?.into_net();
|
|
||||||
|
|
||||||
Ok(Self {
|
|
||||||
proxy,
|
|
||||||
tls,
|
|
||||||
local_tcp_addr,
|
|
||||||
peer_tcp_addr,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub const fn is_proxied(&self) -> bool {
|
|
||||||
self.proxy.is_some()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub const fn is_tls(&self) -> bool {
|
|
||||||
self.tls.is_some()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the outmost peer address, either from the TCP listener or from the
|
|
||||||
/// proxy protocol infos.
|
|
||||||
///
|
|
||||||
/// # Errors
|
|
||||||
///
|
|
||||||
/// Returns an error if the info from the proxy protocol was not for a TCP
|
|
||||||
/// connection, or if the proxy protocol is not being used, the underlying
|
|
||||||
/// listener was a UNIX domain socket
|
|
||||||
pub fn peer_addr(&self) -> Result<&SocketAddr, GrabAddressError> {
|
|
||||||
if let Some(proxy) = self.proxy.as_ref() {
|
|
||||||
if proxy.is_udp() {
|
|
||||||
return Err(GrabAddressError::ProxyUdp);
|
|
||||||
}
|
|
||||||
|
|
||||||
proxy.source().ok_or(GrabAddressError::ProxyUnknown)
|
|
||||||
} else {
|
|
||||||
self.peer_tcp_addr
|
|
||||||
.as_ref()
|
|
||||||
.ok_or(GrabAddressError::UnixListener)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the outmost local address, either from the TCP listener or from the
|
|
||||||
/// proxy protocol infos.
|
|
||||||
///
|
|
||||||
/// # Errors
|
|
||||||
///
|
|
||||||
/// Returns an error if the info from the proxy protocol was not for a TCP
|
|
||||||
/// connection, or if the proxy protocol is not being used, the underlying
|
|
||||||
/// listener was a UNIX domain socket
|
|
||||||
pub fn local_addr(&self) -> Result<&SocketAddr, GrabAddressError> {
|
|
||||||
if let Some(proxy) = self.proxy.as_ref() {
|
|
||||||
if proxy.is_udp() {
|
|
||||||
return Err(GrabAddressError::ProxyUdp);
|
|
||||||
}
|
|
||||||
|
|
||||||
proxy.destination().ok_or(GrabAddressError::ProxyUnknown)
|
|
||||||
} else {
|
|
||||||
self.local_tcp_addr
|
|
||||||
.as_ref()
|
|
||||||
.ok_or(GrabAddressError::UnixListener)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pin_project_lite::pin_project! {
|
|
||||||
pub struct ConnectionInfoAcceptor {
|
|
||||||
#[pin]
|
|
||||||
acceptor: MaybeProxyAcceptor<MaybeTlsAcceptor<UnixOrTcpListener>>,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ConnectionInfoAcceptor {
|
|
||||||
pub const fn new(acceptor: MaybeProxyAcceptor<MaybeTlsAcceptor<UnixOrTcpListener>>) -> Self {
|
|
||||||
Self { acceptor }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Accept for ConnectionInfoAcceptor {
|
|
||||||
type Conn = ConnectionInfoStream;
|
|
||||||
type Error = std::io::Error;
|
|
||||||
|
|
||||||
fn poll_accept(
|
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut std::task::Context<'_>,
|
|
||||||
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
|
|
||||||
let proj = self.project();
|
|
||||||
let ret = match futures_util::ready!(proj.acceptor.poll_accept(cx)) {
|
|
||||||
Some(Ok(conn)) => Some(Ok(ConnectionInfoStream::new(conn))),
|
|
||||||
Some(Err(e)) => Some(Err(e)),
|
|
||||||
None => None,
|
|
||||||
};
|
|
||||||
Poll::Ready(ret)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pin_project_lite::pin_project! {
|
|
||||||
pub struct ConnectionInfoStream {
|
|
||||||
connection: Arc<OnceCell<Connection>>,
|
|
||||||
#[pin]
|
|
||||||
stream: MaybeProxyStream<MaybeTlsStream<UnixOrTcpConnection>>,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ConnectionInfoStream {
|
|
||||||
pub fn new(stream: MaybeProxyStream<MaybeTlsStream<UnixOrTcpConnection>>) -> Self {
|
|
||||||
Self {
|
|
||||||
connection: Arc::new(OnceCell::const_new()),
|
|
||||||
stream,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Deref for ConnectionInfoStream {
|
|
||||||
type Target = MaybeProxyStream<MaybeTlsStream<UnixOrTcpConnection>>;
|
|
||||||
fn deref(&self) -> &Self::Target {
|
|
||||||
&self.stream
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl AsyncRead for ConnectionInfoStream {
|
|
||||||
fn poll_read(
|
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
buf: &mut tokio::io::ReadBuf<'_>,
|
|
||||||
) -> Poll<std::io::Result<()>> {
|
|
||||||
let this = self.get_mut();
|
|
||||||
futures_util::ready!(Pin::new(&mut this.stream).poll_read(cx, buf))?;
|
|
||||||
|
|
||||||
if !this.stream.is_tls_handshaking()
|
|
||||||
&& !this.stream.is_proxy_handshaking()
|
|
||||||
&& !this.connection.initialized()
|
|
||||||
{
|
|
||||||
this.connection
|
|
||||||
.set(this.stream.connection_info().unwrap())
|
|
||||||
.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
Poll::Ready(Ok(()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl AsyncWrite for ConnectionInfoStream {
|
|
||||||
fn poll_write(
|
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
buf: &[u8],
|
|
||||||
) -> Poll<Result<usize, std::io::Error>> {
|
|
||||||
let proj = self.project();
|
|
||||||
proj.stream.poll_write(cx, buf)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
|
|
||||||
let proj = self.project();
|
|
||||||
proj.stream.poll_flush(cx)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_shutdown(
|
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
) -> Poll<Result<(), std::io::Error>> {
|
|
||||||
let proj = self.project();
|
|
||||||
proj.stream.poll_shutdown(cx)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_write_vectored(
|
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
bufs: &[std::io::IoSlice<'_>],
|
|
||||||
) -> Poll<Result<usize, std::io::Error>> {
|
|
||||||
let proj = self.project();
|
|
||||||
proj.stream.poll_write_vectored(cx, bufs)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_write_vectored(&self) -> bool {
|
|
||||||
self.stream.is_write_vectored()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct IntoMakeServiceWithConnection<S> {
|
|
||||||
svc: S,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S> IntoMakeServiceWithConnection<S> {
|
|
||||||
pub const fn new(svc: S) -> Self {
|
|
||||||
Self { svc }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S> Service<&ConnectionInfoStream> for IntoMakeServiceWithConnection<S>
|
|
||||||
where
|
|
||||||
S: Clone,
|
|
||||||
{
|
|
||||||
type Response = AddExtension<S, Arc<OnceCell<Connection>>>;
|
|
||||||
type Error = FromStreamError;
|
|
||||||
type Future = Ready<Result<Self::Response, Self::Error>>;
|
|
||||||
|
|
||||||
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
|
||||||
Poll::Ready(Ok(()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn call(&mut self, target: &ConnectionInfoStream) -> Self::Future {
|
|
||||||
std::future::ready(Ok(AddExtension::new(
|
|
||||||
self.svc.clone(),
|
|
||||||
target.connection.clone(),
|
|
||||||
)))
|
|
||||||
}
|
|
||||||
}
|
|
@ -22,7 +22,41 @@
|
|||||||
#![warn(clippy::pedantic)]
|
#![warn(clippy::pedantic)]
|
||||||
#![allow(clippy::module_name_repetitions)]
|
#![allow(clippy::module_name_repetitions)]
|
||||||
|
|
||||||
pub mod info;
|
use self::{maybe_tls::TlsStreamInfo, proxy_protocol::ProxyProtocolV1Info};
|
||||||
|
|
||||||
pub mod maybe_tls;
|
pub mod maybe_tls;
|
||||||
pub mod proxy_protocol;
|
pub mod proxy_protocol;
|
||||||
|
pub mod rewind;
|
||||||
|
pub mod server;
|
||||||
|
pub mod shutdown;
|
||||||
pub mod unix_or_tcp;
|
pub mod unix_or_tcp;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ConnectionInfo {
|
||||||
|
tls: Option<TlsStreamInfo>,
|
||||||
|
proxy: Option<ProxyProtocolV1Info>,
|
||||||
|
net_peer_addr: Option<std::net::SocketAddr>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ConnectionInfo {
|
||||||
|
/// Returns informations about the TLS connection. Returns [`None`] if the
|
||||||
|
/// connection was not TLS.
|
||||||
|
#[must_use]
|
||||||
|
pub fn get_tls_ref(&self) -> Option<&TlsStreamInfo> {
|
||||||
|
self.tls.as_ref()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns informations about the proxy protocol connection. Returns
|
||||||
|
/// [`None`] if the connection was not using the proxy protocol.
|
||||||
|
#[must_use]
|
||||||
|
pub fn get_proxy_ref(&self) -> Option<&ProxyProtocolV1Info> {
|
||||||
|
self.proxy.as_ref()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the remote peer address. Returns [`None`] if the connection was
|
||||||
|
/// established via a UNIX domain socket.
|
||||||
|
#[must_use]
|
||||||
|
pub fn get_peer_addr(&self) -> Option<std::net::SocketAddr> {
|
||||||
|
self.net_peer_addr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -13,18 +13,15 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use std::{
|
use std::{
|
||||||
ops::Deref,
|
|
||||||
pin::Pin,
|
pin::Pin,
|
||||||
sync::Arc,
|
sync::Arc,
|
||||||
task::{Context, Poll},
|
task::{Context, Poll},
|
||||||
};
|
};
|
||||||
|
|
||||||
use futures_util::{ready, Future};
|
|
||||||
use hyper::server::accept::Accept;
|
|
||||||
use thiserror::Error;
|
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||||
use tokio_rustls::rustls::{
|
use tokio_rustls::{
|
||||||
Certificate, ProtocolVersion, ServerConfig, ServerConnection, SupportedCipherSuite,
|
rustls::{Certificate, ProtocolVersion, ServerConfig, ServerConnection, SupportedCipherSuite},
|
||||||
|
TlsAcceptor,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@ -33,105 +30,78 @@ pub struct TlsStreamInfo {
|
|||||||
pub protocol_version: ProtocolVersion,
|
pub protocol_version: ProtocolVersion,
|
||||||
pub negotiated_cipher_suite: SupportedCipherSuite,
|
pub negotiated_cipher_suite: SupportedCipherSuite,
|
||||||
pub sni_hostname: Option<String>,
|
pub sni_hostname: Option<String>,
|
||||||
pub apln_protocol: Option<Vec<u8>>,
|
pub alpn_protocol: Option<Vec<u8>>,
|
||||||
pub peer_certificates: Option<Vec<Certificate>>,
|
pub peer_certificates: Option<Vec<Certificate>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
impl TlsStreamInfo {
|
||||||
#[non_exhaustive]
|
#[must_use]
|
||||||
pub enum TlsStreamInfoError {
|
pub fn is_alpn_h2(&self) -> bool {
|
||||||
#[error("TLS handshake is not done yet")]
|
matches!(self.alpn_protocol.as_deref(), Some(b"h2"))
|
||||||
HandshakingNotDone,
|
}
|
||||||
|
|
||||||
#[error("Some fields were not available in the TLS connection")]
|
|
||||||
FieldsNotAvailable,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub enum MaybeTlsStream<T> {
|
pin_project_lite::pin_project! {
|
||||||
Handshaking(tokio_rustls::Accept<T>),
|
#[project = MaybeTlsStreamProj]
|
||||||
Streaming(tokio_rustls::server::TlsStream<T>),
|
pub enum MaybeTlsStream<T> {
|
||||||
Insecure(T),
|
Secure {
|
||||||
|
#[pin]
|
||||||
|
stream: tokio_rustls::server::TlsStream<T>
|
||||||
|
},
|
||||||
|
Insecure {
|
||||||
|
#[pin]
|
||||||
|
stream: T,
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> MaybeTlsStream<T> {
|
impl<T> MaybeTlsStream<T> {
|
||||||
pub fn new(stream: T, config: Option<Arc<ServerConfig>>) -> Self
|
|
||||||
where
|
|
||||||
T: AsyncRead + AsyncWrite + Unpin,
|
|
||||||
{
|
|
||||||
if let Some(config) = config {
|
|
||||||
let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream);
|
|
||||||
MaybeTlsStream::Handshaking(accept)
|
|
||||||
} else {
|
|
||||||
MaybeTlsStream::Insecure(stream)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get a reference to the underlying IO stream
|
/// Get a reference to the underlying IO stream
|
||||||
///
|
///
|
||||||
/// Returns [`None`] if the stream closed before the TLS handshake finished.
|
/// Returns [`None`] if the stream closed before the TLS handshake finished.
|
||||||
/// It is guaranteed to return [`Some`] value after the handshake finished,
|
/// It is guaranteed to return [`Some`] value after the handshake finished,
|
||||||
/// or if it is a non-TLS connection.
|
/// or if it is a non-TLS connection.
|
||||||
pub fn get_ref(&self) -> Option<&T> {
|
pub fn get_ref(&self) -> &T {
|
||||||
match self {
|
match self {
|
||||||
Self::Handshaking(accept) => accept.get_ref(),
|
Self::Secure { stream } => stream.get_ref().0,
|
||||||
Self::Streaming(stream) => {
|
Self::Insecure { stream } => stream,
|
||||||
let (inner, _) = stream.get_ref();
|
|
||||||
Some(inner)
|
|
||||||
}
|
|
||||||
Self::Insecure(inner) => Some(inner),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get a ref to the [`ServerConnection`] of the establish TLS stream.
|
/// Get a ref to the [`ServerConnection`] of the establish TLS stream.
|
||||||
///
|
///
|
||||||
/// Returns [`None`] if the connection is still handshaking and for non-TLS
|
/// Returns [`None`] for non-TLS connections.
|
||||||
/// connections.
|
|
||||||
pub fn get_tls_connection(&self) -> Option<&ServerConnection> {
|
pub fn get_tls_connection(&self) -> Option<&ServerConnection> {
|
||||||
match self {
|
match self {
|
||||||
Self::Streaming(stream) => {
|
Self::Secure { stream } => Some(stream.get_ref().1),
|
||||||
let (_, conn) = stream.get_ref();
|
Self::Insecure { .. } => None,
|
||||||
Some(conn)
|
|
||||||
}
|
|
||||||
Self::Handshaking(_) | Self::Insecure(_) => None,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Gather informations about the TLS connection. Returns `None` if the
|
/// Gather informations about the TLS connection. Returns `None` if the
|
||||||
/// stream is not a TLS stream.
|
/// stream is not a TLS stream.
|
||||||
///
|
pub fn tls_info(&self) -> Option<TlsStreamInfo> {
|
||||||
/// # Errors
|
let conn = self.get_tls_connection()?;
|
||||||
///
|
|
||||||
/// Returns an error if the TLS handshake is not yet done
|
|
||||||
pub fn tls_info(&self) -> Result<Option<TlsStreamInfo>, TlsStreamInfoError> {
|
|
||||||
let conn = match self {
|
|
||||||
Self::Streaming(stream) => stream.get_ref().1,
|
|
||||||
Self::Handshaking(_) => return Err(TlsStreamInfoError::HandshakingNotDone),
|
|
||||||
Self::Insecure(_) => return Ok(None),
|
|
||||||
};
|
|
||||||
|
|
||||||
// NOTE: we're getting the protocol version and cipher suite *after* the
|
// SAFETY: we're getting the protocol version and cipher suite *after* the
|
||||||
// handshake, so this should never lead to an error
|
// handshake, so this should never lead to a panic
|
||||||
let protocol_version = conn
|
let protocol_version = conn
|
||||||
.protocol_version()
|
.protocol_version()
|
||||||
.ok_or(TlsStreamInfoError::FieldsNotAvailable)?;
|
.expect("TLS handshake is not done yet");
|
||||||
let negotiated_cipher_suite = conn
|
let negotiated_cipher_suite = conn
|
||||||
.negotiated_cipher_suite()
|
.negotiated_cipher_suite()
|
||||||
.ok_or(TlsStreamInfoError::FieldsNotAvailable)?;
|
.expect("TLS handshake is not done yet");
|
||||||
|
|
||||||
let sni_hostname = conn.sni_hostname().map(ToOwned::to_owned);
|
let sni_hostname = conn.sni_hostname().map(ToOwned::to_owned);
|
||||||
let apln_protocol = conn.alpn_protocol().map(ToOwned::to_owned);
|
let alpn_protocol = conn.alpn_protocol().map(ToOwned::to_owned);
|
||||||
let peer_certificates = conn.peer_certificates().map(ToOwned::to_owned);
|
let peer_certificates = conn.peer_certificates().map(ToOwned::to_owned);
|
||||||
Ok(Some(TlsStreamInfo {
|
Some(TlsStreamInfo {
|
||||||
protocol_version,
|
protocol_version,
|
||||||
negotiated_cipher_suite,
|
negotiated_cipher_suite,
|
||||||
sni_hostname,
|
sni_hostname,
|
||||||
apln_protocol,
|
alpn_protocol,
|
||||||
peer_certificates,
|
peer_certificates,
|
||||||
}))
|
})
|
||||||
}
|
|
||||||
|
|
||||||
pub const fn is_tls_handshaking(&self) -> bool {
|
|
||||||
matches!(self, Self::Handshaking(_))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -144,20 +114,9 @@ where
|
|||||||
cx: &mut Context,
|
cx: &mut Context,
|
||||||
buf: &mut ReadBuf,
|
buf: &mut ReadBuf,
|
||||||
) -> Poll<std::io::Result<()>> {
|
) -> Poll<std::io::Result<()>> {
|
||||||
let pin = self.get_mut();
|
match self.project() {
|
||||||
match pin {
|
MaybeTlsStreamProj::Secure { stream } => stream.poll_read(cx, buf),
|
||||||
MaybeTlsStream::Handshaking(ref mut accept) => {
|
MaybeTlsStreamProj::Insecure { stream } => stream.poll_read(cx, buf),
|
||||||
match ready!(Pin::new(accept).poll(cx)) {
|
|
||||||
Ok(mut stream) => {
|
|
||||||
let result = Pin::new(&mut stream).poll_read(cx, buf);
|
|
||||||
*pin = MaybeTlsStream::Streaming(stream);
|
|
||||||
result
|
|
||||||
}
|
|
||||||
Err(err) => Poll::Ready(Err(err)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
MaybeTlsStream::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
|
|
||||||
MaybeTlsStream::Insecure(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -171,104 +130,89 @@ where
|
|||||||
cx: &mut Context<'_>,
|
cx: &mut Context<'_>,
|
||||||
buf: &[u8],
|
buf: &[u8],
|
||||||
) -> Poll<std::io::Result<usize>> {
|
) -> Poll<std::io::Result<usize>> {
|
||||||
let pin = self.get_mut();
|
match self.project() {
|
||||||
match pin {
|
MaybeTlsStreamProj::Secure { stream } => stream.poll_write(cx, buf),
|
||||||
MaybeTlsStream::Handshaking(ref mut accept) => {
|
MaybeTlsStreamProj::Insecure { stream } => stream.poll_write(cx, buf),
|
||||||
match ready!(Pin::new(accept).poll(cx)) {
|
|
||||||
Ok(mut stream) => {
|
|
||||||
let result = Pin::new(&mut stream).poll_write(cx, buf);
|
|
||||||
*pin = MaybeTlsStream::Streaming(stream);
|
|
||||||
result
|
|
||||||
}
|
|
||||||
Err(err) => Poll::Ready(Err(err)),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
MaybeTlsStream::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
|
|
||||||
MaybeTlsStream::Insecure(ref mut fallback) => Pin::new(fallback).poll_write(cx, buf),
|
fn poll_write_vectored(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
bufs: &[std::io::IoSlice<'_>],
|
||||||
|
) -> Poll<Result<usize, std::io::Error>> {
|
||||||
|
match self.project() {
|
||||||
|
MaybeTlsStreamProj::Secure { stream } => stream.poll_write_vectored(cx, bufs),
|
||||||
|
MaybeTlsStreamProj::Insecure { stream } => stream.poll_write_vectored(cx, bufs),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_write_vectored(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
Self::Secure { stream } => stream.is_write_vectored(),
|
||||||
|
Self::Insecure { stream } => stream.is_write_vectored(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||||
match self.get_mut() {
|
match self.project() {
|
||||||
MaybeTlsStream::Handshaking { .. } => Poll::Ready(Ok(())),
|
MaybeTlsStreamProj::Secure { stream } => stream.poll_flush(cx),
|
||||||
MaybeTlsStream::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx),
|
MaybeTlsStreamProj::Insecure { stream } => stream.poll_flush(cx),
|
||||||
MaybeTlsStream::Insecure(ref mut stream) => Pin::new(stream).poll_flush(cx),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||||
match self.get_mut() {
|
match self.project() {
|
||||||
MaybeTlsStream::Handshaking { .. } => Poll::Ready(Ok(())),
|
MaybeTlsStreamProj::Secure { stream } => stream.poll_shutdown(cx),
|
||||||
MaybeTlsStream::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
|
MaybeTlsStreamProj::Insecure { stream } => stream.poll_shutdown(cx),
|
||||||
MaybeTlsStream::Insecure(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct MaybeTlsAcceptor<T> {
|
#[derive(Clone)]
|
||||||
|
pub struct MaybeTlsAcceptor {
|
||||||
tls_config: Option<Arc<ServerConfig>>,
|
tls_config: Option<Arc<ServerConfig>>,
|
||||||
incoming: T,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> MaybeTlsAcceptor<T> {
|
impl MaybeTlsAcceptor {
|
||||||
pub fn new(tls_config: Option<Arc<ServerConfig>>, incoming: T) -> Self {
|
#[must_use]
|
||||||
Self {
|
pub fn new(tls_config: Option<Arc<ServerConfig>>) -> Self {
|
||||||
tls_config,
|
Self { tls_config }
|
||||||
incoming,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new_secure(tls_config: Arc<ServerConfig>, incoming: T) -> Self {
|
#[must_use]
|
||||||
|
pub fn new_secure(tls_config: Arc<ServerConfig>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
tls_config: Some(tls_config),
|
tls_config: Some(tls_config),
|
||||||
incoming,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new_insecure(incoming: T) -> Self {
|
#[must_use]
|
||||||
Self {
|
pub fn new_insecure() -> Self {
|
||||||
tls_config: None,
|
Self { tls_config: None }
|
||||||
incoming,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
pub const fn is_secure(&self) -> bool {
|
pub const fn is_secure(&self) -> bool {
|
||||||
self.tls_config.is_some()
|
self.tls_config.is_some()
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
impl<T> Accept for MaybeTlsAcceptor<T>
|
/// Accept a connection and do the TLS handshake
|
||||||
where
|
///
|
||||||
T: Accept + Unpin,
|
/// # Errors
|
||||||
T::Conn: AsyncRead + AsyncWrite + Unpin,
|
///
|
||||||
T::Error: Into<std::io::Error>,
|
/// Returns an error if the TLS handshake failed
|
||||||
{
|
pub async fn accept<T>(&self, stream: T) -> Result<MaybeTlsStream<T>, std::io::Error>
|
||||||
type Conn = MaybeTlsStream<T::Conn>;
|
where
|
||||||
type Error = std::io::Error;
|
T: AsyncRead + AsyncWrite + Unpin,
|
||||||
|
{
|
||||||
fn poll_accept(
|
match &self.tls_config {
|
||||||
self: Pin<&mut Self>,
|
Some(config) => {
|
||||||
cx: &mut Context<'_>,
|
let acceptor = TlsAcceptor::from(config.clone());
|
||||||
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
|
let stream = acceptor.accept(stream).await?;
|
||||||
let pin = self.get_mut();
|
Ok(MaybeTlsStream::Secure { stream })
|
||||||
|
}
|
||||||
let ret = match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) {
|
None => Ok(MaybeTlsStream::Insecure { stream }),
|
||||||
Some(Ok(sock)) => {
|
|
||||||
let config = pin.tls_config.clone();
|
|
||||||
Some(Ok(MaybeTlsStream::new(sock, config)))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Some(Err(e)) => Some(Err(e.into())),
|
|
||||||
None => None,
|
|
||||||
};
|
|
||||||
|
|
||||||
Poll::Ready(ret)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T> Deref for MaybeTlsAcceptor<T> {
|
|
||||||
type Target = T;
|
|
||||||
fn deref(&self) -> &Self::Target {
|
|
||||||
&self.incoming
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -12,41 +12,57 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use futures_util::ready;
|
use bytes::BytesMut;
|
||||||
use hyper::server::accept::Accept;
|
use thiserror::Error;
|
||||||
|
use tokio::io::{AsyncRead, AsyncReadExt};
|
||||||
|
|
||||||
use super::ProxyStream;
|
use super::ProxyProtocolV1Info;
|
||||||
|
use crate::rewind::Rewind;
|
||||||
|
|
||||||
pin_project_lite::pin_project! {
|
#[derive(Clone, Debug)]
|
||||||
pub struct ProxyAcceptor<A> {
|
pub struct ProxyAcceptor {
|
||||||
#[pin]
|
_private: (),
|
||||||
inner: A,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<A> ProxyAcceptor<A> {
|
#[derive(Debug, Error)]
|
||||||
pub const fn new(inner: A) -> Self {
|
#[error(transparent)]
|
||||||
Self { inner }
|
pub enum ProxyAcceptError {
|
||||||
}
|
Parse(#[from] super::v1::ParseError),
|
||||||
|
Read(#[from] std::io::Error),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<A> Accept for ProxyAcceptor<A>
|
impl ProxyAcceptor {
|
||||||
where
|
#[must_use]
|
||||||
A: Accept,
|
pub const fn new() -> Self {
|
||||||
{
|
Self { _private: () }
|
||||||
type Conn = ProxyStream<A::Conn>;
|
}
|
||||||
type Error = A::Error;
|
|
||||||
|
|
||||||
fn poll_accept(
|
/// Accept a proxy-protocol stream
|
||||||
self: std::pin::Pin<&mut Self>,
|
///
|
||||||
cx: &mut std::task::Context<'_>,
|
/// # Errors
|
||||||
) -> std::task::Poll<Option<Result<Self::Conn, Self::Error>>> {
|
///
|
||||||
let res = match ready!(self.project().inner.poll_accept(cx)) {
|
/// Returns an error on read error on the underlying stream, or when the
|
||||||
Some(Ok(stream)) => Some(Ok(ProxyStream::new(stream))),
|
/// proxy protocol preamble couldn't be parsed
|
||||||
Some(Err(e)) => Some(Err(e)),
|
pub async fn accept<T>(
|
||||||
None => None,
|
&self,
|
||||||
|
mut stream: T,
|
||||||
|
) -> Result<(ProxyProtocolV1Info, Rewind<T>), ProxyAcceptError>
|
||||||
|
where
|
||||||
|
T: AsyncRead + Unpin,
|
||||||
|
{
|
||||||
|
let mut buf = BytesMut::new();
|
||||||
|
let info = loop {
|
||||||
|
stream.read_buf(&mut buf).await?;
|
||||||
|
|
||||||
|
match ProxyProtocolV1Info::parse(&mut buf) {
|
||||||
|
Ok(info) => break info,
|
||||||
|
Err(e) if e.not_enough_bytes() => {}
|
||||||
|
Err(e) => return Err(e.into()),
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
std::task::Poll::Ready(res)
|
let stream = Rewind::new_buffered(stream, buf.into());
|
||||||
|
|
||||||
|
Ok((info, stream))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -12,179 +12,66 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use std::{
|
use tokio::io::AsyncRead;
|
||||||
ops::Deref,
|
|
||||||
pin::Pin,
|
|
||||||
task::{Context, Poll},
|
|
||||||
};
|
|
||||||
|
|
||||||
use futures_util::ready;
|
use super::{acceptor::ProxyAcceptError, ProxyAcceptor, ProxyProtocolV1Info};
|
||||||
use hyper::server::accept::Accept;
|
use crate::rewind::Rewind;
|
||||||
use tokio::io::{AsyncRead, AsyncWrite};
|
|
||||||
|
|
||||||
use super::{stream::HandshakeNotDone, ProxyProtocolV1Info, ProxyStream};
|
#[derive(Clone)]
|
||||||
|
pub struct MaybeProxyAcceptor {
|
||||||
pin_project_lite::pin_project! {
|
acceptor: Option<ProxyAcceptor>,
|
||||||
pub struct MaybeProxyAcceptor<A> {
|
|
||||||
proxied: bool,
|
|
||||||
|
|
||||||
#[pin]
|
|
||||||
inner: A,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<A> MaybeProxyAcceptor<A> {
|
impl MaybeProxyAcceptor {
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub const fn new(inner: A, proxied: bool) -> Self {
|
pub const fn new(proxied: bool) -> Self {
|
||||||
Self { proxied, inner }
|
let acceptor = if proxied {
|
||||||
}
|
Some(ProxyAcceptor::new())
|
||||||
|
} else {
|
||||||
pub const fn is_proxied(&self) -> bool {
|
None
|
||||||
self.proxied
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<A> Deref for MaybeProxyAcceptor<A> {
|
|
||||||
type Target = A;
|
|
||||||
fn deref(&self) -> &Self::Target {
|
|
||||||
&self.inner
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<A> Accept for MaybeProxyAcceptor<A>
|
|
||||||
where
|
|
||||||
A: Accept,
|
|
||||||
{
|
|
||||||
type Conn = MaybeProxyStream<A::Conn>;
|
|
||||||
type Error = A::Error;
|
|
||||||
|
|
||||||
fn poll_accept(
|
|
||||||
self: std::pin::Pin<&mut Self>,
|
|
||||||
cx: &mut std::task::Context<'_>,
|
|
||||||
) -> std::task::Poll<Option<Result<Self::Conn, Self::Error>>> {
|
|
||||||
let proj = self.project();
|
|
||||||
let res = match ready!(proj.inner.poll_accept(cx)) {
|
|
||||||
Some(Ok(stream)) => Some(Ok(MaybeProxyStream::new(stream, *proj.proxied))),
|
|
||||||
Some(Err(e)) => Some(Err(e)),
|
|
||||||
None => None,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
std::task::Poll::Ready(res)
|
Self { acceptor }
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
pin_project_lite::pin_project! {
|
#[must_use]
|
||||||
#[project = MaybeProxyStreamProj]
|
pub const fn new_proxied(acceptor: ProxyAcceptor) -> Self {
|
||||||
pub enum MaybeProxyStream<S> {
|
Self {
|
||||||
Proxied { #[pin] stream: ProxyStream<S> },
|
acceptor: Some(acceptor),
|
||||||
NotProxied { #[pin] stream: S },
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S> MaybeProxyStream<S> {
|
|
||||||
pub const fn new(stream: S, proxied: bool) -> Self {
|
|
||||||
if proxied {
|
|
||||||
Self::Proxied {
|
|
||||||
stream: ProxyStream::new(stream),
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
Self::NotProxied { stream }
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get informations from the proxied connection, if it was procied
|
#[must_use]
|
||||||
|
pub const fn new_unproxied() -> Self {
|
||||||
|
Self { acceptor: None }
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub const fn is_proxied(&self) -> bool {
|
||||||
|
self.acceptor.is_some()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Accept a connection and do the proxy protocol handshake
|
||||||
///
|
///
|
||||||
/// # Errors
|
/// # Errors
|
||||||
///
|
///
|
||||||
/// Returns an error if the stream did not complete the handshake yet
|
/// Returns an error if the proxy protocol handshake failed
|
||||||
pub fn proxy_info(&self) -> Result<Option<&ProxyProtocolV1Info>, HandshakeNotDone> {
|
pub async fn accept<T>(
|
||||||
match self {
|
&self,
|
||||||
Self::Proxied { stream } => Ok(Some(stream.proxy_info()?)),
|
stream: T,
|
||||||
Self::NotProxied { .. } => Ok(None),
|
) -> Result<(Option<ProxyProtocolV1Info>, Rewind<T>), ProxyAcceptError>
|
||||||
|
where
|
||||||
|
T: AsyncRead + Unpin,
|
||||||
|
{
|
||||||
|
match &self.acceptor {
|
||||||
|
Some(acceptor) => {
|
||||||
|
let (info, stream) = acceptor.accept(stream).await?;
|
||||||
|
Ok((Some(info), stream))
|
||||||
}
|
}
|
||||||
|
None => {
|
||||||
|
let stream = Rewind::new(stream);
|
||||||
|
Ok((None, stream))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub const fn is_proxy_handshaking(&self) -> bool {
|
|
||||||
match self {
|
|
||||||
Self::Proxied { stream } => stream.is_handshaking(),
|
|
||||||
Self::NotProxied { .. } => false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S> Deref for MaybeProxyStream<S> {
|
|
||||||
type Target = S;
|
|
||||||
fn deref(&self) -> &Self::Target {
|
|
||||||
match self {
|
|
||||||
Self::Proxied { stream } => &**stream,
|
|
||||||
Self::NotProxied { stream } => stream,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S> AsyncRead for MaybeProxyStream<S>
|
|
||||||
where
|
|
||||||
S: AsyncRead,
|
|
||||||
{
|
|
||||||
fn poll_read(
|
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
buf: &mut tokio::io::ReadBuf<'_>,
|
|
||||||
) -> Poll<std::io::Result<()>> {
|
|
||||||
match self.project() {
|
|
||||||
MaybeProxyStreamProj::Proxied { stream } => stream.poll_read(cx, buf),
|
|
||||||
MaybeProxyStreamProj::NotProxied { stream } => stream.poll_read(cx, buf),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S> AsyncWrite for MaybeProxyStream<S>
|
|
||||||
where
|
|
||||||
S: AsyncWrite,
|
|
||||||
{
|
|
||||||
fn poll_write(
|
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
buf: &[u8],
|
|
||||||
) -> Poll<Result<usize, std::io::Error>> {
|
|
||||||
match self.project() {
|
|
||||||
MaybeProxyStreamProj::Proxied { stream } => stream.poll_write(cx, buf),
|
|
||||||
MaybeProxyStreamProj::NotProxied { stream } => stream.poll_write(cx, buf),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
|
|
||||||
match self.project() {
|
|
||||||
MaybeProxyStreamProj::Proxied { stream } => stream.poll_flush(cx),
|
|
||||||
MaybeProxyStreamProj::NotProxied { stream } => stream.poll_flush(cx),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_shutdown(
|
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
) -> Poll<Result<(), std::io::Error>> {
|
|
||||||
match self.project() {
|
|
||||||
MaybeProxyStreamProj::Proxied { stream } => stream.poll_shutdown(cx),
|
|
||||||
MaybeProxyStreamProj::NotProxied { stream } => stream.poll_shutdown(cx),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_write_vectored(
|
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
bufs: &[std::io::IoSlice<'_>],
|
|
||||||
) -> Poll<Result<usize, std::io::Error>> {
|
|
||||||
match self.project() {
|
|
||||||
MaybeProxyStreamProj::Proxied { stream } => stream.poll_write_vectored(cx, bufs),
|
|
||||||
MaybeProxyStreamProj::NotProxied { stream } => stream.poll_write_vectored(cx, bufs),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_write_vectored(&self) -> bool {
|
|
||||||
match self {
|
|
||||||
MaybeProxyStream::Proxied { stream } => stream.is_write_vectored(),
|
|
||||||
MaybeProxyStream::NotProxied { stream } => stream.is_write_vectored(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -14,12 +14,10 @@
|
|||||||
|
|
||||||
mod acceptor;
|
mod acceptor;
|
||||||
mod maybe;
|
mod maybe;
|
||||||
mod stream;
|
|
||||||
mod v1;
|
mod v1;
|
||||||
|
|
||||||
pub use self::{
|
pub use self::{
|
||||||
acceptor::ProxyAcceptor,
|
acceptor::{ProxyAcceptError, ProxyAcceptor},
|
||||||
maybe::{MaybeProxyAcceptor, MaybeProxyStream},
|
maybe::MaybeProxyAcceptor,
|
||||||
stream::{HandshakeNotDone as ProxyHandshakeNotDone, ProxyStream},
|
|
||||||
v1::ProxyProtocolV1Info,
|
v1::ProxyProtocolV1Info,
|
||||||
};
|
};
|
||||||
|
@ -1,169 +0,0 @@
|
|||||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
use std::ops::Deref;
|
|
||||||
|
|
||||||
use futures_util::ready;
|
|
||||||
use thiserror::Error;
|
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
|
||||||
|
|
||||||
use super::ProxyProtocolV1Info;
|
|
||||||
|
|
||||||
// Max theorical size we need is 108 for proxy protocol v1
|
|
||||||
const BUF_SIZE: usize = 256;
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
enum ProxyStreamState {
|
|
||||||
Handshaking {
|
|
||||||
buffer: [u8; BUF_SIZE],
|
|
||||||
index: usize,
|
|
||||||
},
|
|
||||||
Established(ProxyProtocolV1Info),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ProxyStreamState {
|
|
||||||
pub const fn is_handshaking(&self) -> bool {
|
|
||||||
matches!(self, Self::Handshaking { .. })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pin_project_lite::pin_project! {
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct ProxyStream<S> {
|
|
||||||
state: ProxyStreamState,
|
|
||||||
|
|
||||||
#[pin]
|
|
||||||
inner: S,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S> ProxyStream<S> {
|
|
||||||
pub const fn new(inner: S) -> Self {
|
|
||||||
Self {
|
|
||||||
state: ProxyStreamState::Handshaking {
|
|
||||||
buffer: [0; BUF_SIZE],
|
|
||||||
index: 0,
|
|
||||||
},
|
|
||||||
inner,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Error, Clone, Copy)]
|
|
||||||
#[error("Proxy protocol handshake is not complete")]
|
|
||||||
pub struct HandshakeNotDone;
|
|
||||||
|
|
||||||
impl<S> Deref for ProxyStream<S> {
|
|
||||||
type Target = S;
|
|
||||||
fn deref(&self) -> &Self::Target {
|
|
||||||
&self.inner
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S> ProxyStream<S> {
|
|
||||||
/// Get informations from the proxied connection
|
|
||||||
///
|
|
||||||
/// # Errors
|
|
||||||
///
|
|
||||||
/// Returns an error if the stream did not complete the handshake yet
|
|
||||||
pub fn proxy_info(&self) -> Result<&ProxyProtocolV1Info, HandshakeNotDone> {
|
|
||||||
match &self.state {
|
|
||||||
ProxyStreamState::Handshaking { .. } => Err(HandshakeNotDone),
|
|
||||||
ProxyStreamState::Established(info) => Ok(info),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns `true` if the proxy protocol is still handshaking
|
|
||||||
pub const fn is_handshaking(&self) -> bool {
|
|
||||||
self.state.is_handshaking()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S> AsyncRead for ProxyStream<S>
|
|
||||||
where
|
|
||||||
S: AsyncRead,
|
|
||||||
{
|
|
||||||
fn poll_read(
|
|
||||||
self: std::pin::Pin<&mut Self>,
|
|
||||||
cx: &mut std::task::Context<'_>,
|
|
||||||
buf: &mut tokio::io::ReadBuf<'_>,
|
|
||||||
) -> std::task::Poll<std::io::Result<()>> {
|
|
||||||
let proj = self.project();
|
|
||||||
match proj.state {
|
|
||||||
ProxyStreamState::Handshaking { buffer, index } => {
|
|
||||||
let mut buffer = ReadBuf::new(&mut buffer[..]);
|
|
||||||
buffer.advance(*index);
|
|
||||||
ready!(proj.inner.poll_read(cx, &mut buffer))?;
|
|
||||||
let filled = buffer.filled();
|
|
||||||
*index = filled.len();
|
|
||||||
|
|
||||||
match ProxyProtocolV1Info::parse(filled) {
|
|
||||||
Ok((info, rest)) => {
|
|
||||||
if buf.remaining() < rest.len() {
|
|
||||||
// This is highly unlikely, but is better than panicking later.
|
|
||||||
// If it ever happens, we could introduce a "buffer draining" state
|
|
||||||
// which drains the inner buffer repeatedly until it's empty
|
|
||||||
return std::task::Poll::Ready(Err(std::io::Error::new(
|
|
||||||
std::io::ErrorKind::Other,
|
|
||||||
"underlying buffer is too small",
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
buf.put_slice(rest);
|
|
||||||
*proj.state = ProxyStreamState::Established(info);
|
|
||||||
std::task::Poll::Ready(Ok(()))
|
|
||||||
}
|
|
||||||
Err(e) if e.not_enough_bytes() => std::task::Poll::Ready(Ok(())),
|
|
||||||
Err(e) => std::task::Poll::Ready(Err(std::io::Error::new(
|
|
||||||
std::io::ErrorKind::Other,
|
|
||||||
e,
|
|
||||||
))),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ProxyStreamState::Established(_) => proj.inner.poll_read(cx, buf),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S> AsyncWrite for ProxyStream<S>
|
|
||||||
where
|
|
||||||
S: AsyncWrite,
|
|
||||||
{
|
|
||||||
fn poll_write(
|
|
||||||
self: std::pin::Pin<&mut Self>,
|
|
||||||
cx: &mut std::task::Context<'_>,
|
|
||||||
buf: &[u8],
|
|
||||||
) -> std::task::Poll<Result<usize, std::io::Error>> {
|
|
||||||
let proj = self.project();
|
|
||||||
match proj.state {
|
|
||||||
// Hold off writes until the handshake is done
|
|
||||||
// XXX: is this the right way to do it?
|
|
||||||
ProxyStreamState::Handshaking { .. } => std::task::Poll::Pending,
|
|
||||||
ProxyStreamState::Established(_) => proj.inner.poll_write(cx, buf),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_flush(
|
|
||||||
self: std::pin::Pin<&mut Self>,
|
|
||||||
cx: &mut std::task::Context<'_>,
|
|
||||||
) -> std::task::Poll<Result<(), std::io::Error>> {
|
|
||||||
self.project().inner.poll_flush(cx)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_shutdown(
|
|
||||||
self: std::pin::Pin<&mut Self>,
|
|
||||||
cx: &mut std::task::Context<'_>,
|
|
||||||
) -> std::task::Poll<Result<(), std::io::Error>> {
|
|
||||||
self.project().inner.poll_shutdown(cx)
|
|
||||||
}
|
|
||||||
}
|
|
@ -18,6 +18,7 @@ use std::{
|
|||||||
str::Utf8Error,
|
str::Utf8Error,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use bytes::Buf;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@ -35,7 +36,7 @@ pub enum ProxyProtocolV1Info {
|
|||||||
|
|
||||||
#[derive(Error, Debug)]
|
#[derive(Error, Debug)]
|
||||||
#[error("Invalid proxy protocol header")]
|
#[error("Invalid proxy protocol header")]
|
||||||
pub(super) enum ParseError {
|
pub enum ParseError {
|
||||||
#[error("Not enough bytes provided")]
|
#[error("Not enough bytes provided")]
|
||||||
NotEnoughBytes,
|
NotEnoughBytes,
|
||||||
NoCrLf,
|
NoCrLf,
|
||||||
@ -60,17 +61,21 @@ impl ParseError {
|
|||||||
|
|
||||||
impl ProxyProtocolV1Info {
|
impl ProxyProtocolV1Info {
|
||||||
#[allow(clippy::too_many_lines)]
|
#[allow(clippy::too_many_lines)]
|
||||||
pub(super) fn parse(bytes: &[u8]) -> Result<(Self, &[u8]), ParseError> {
|
pub(super) fn parse<B>(mut buf: B) -> Result<Self, ParseError>
|
||||||
|
where
|
||||||
|
B: Buf + AsRef<[u8]>,
|
||||||
|
{
|
||||||
use ParseError as E;
|
use ParseError as E;
|
||||||
// First, check if we *possibly* have enough bytes.
|
// First, check if we *possibly* have enough bytes.
|
||||||
// Minimum is 15: "PROXY UNKNOWN\r\n"
|
// Minimum is 15: "PROXY UNKNOWN\r\n"
|
||||||
|
|
||||||
if bytes.len() < 15 {
|
if buf.remaining() < 15 {
|
||||||
return Err(E::NotEnoughBytes);
|
return Err(E::NotEnoughBytes);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Let's check in the first 108 bytes if we find a CRLF
|
// Let's check in the first 108 bytes if we find a CRLF
|
||||||
let crlf = if let Some(crlf) = bytes
|
let crlf = if let Some(crlf) = buf
|
||||||
|
.as_ref()
|
||||||
.windows(2)
|
.windows(2)
|
||||||
.take(108)
|
.take(108)
|
||||||
.position(|needle| needle == [0x0D, 0x0A])
|
.position(|needle| needle == [0x0D, 0x0A])
|
||||||
@ -78,7 +83,7 @@ impl ProxyProtocolV1Info {
|
|||||||
crlf
|
crlf
|
||||||
} else {
|
} else {
|
||||||
// If not, it might be because we don't have enough bytes
|
// If not, it might be because we don't have enough bytes
|
||||||
return if bytes.len() < 108 {
|
return if buf.remaining() < 108 {
|
||||||
Err(E::NotEnoughBytes)
|
Err(E::NotEnoughBytes)
|
||||||
} else {
|
} else {
|
||||||
// Else it's just invalid
|
// Else it's just invalid
|
||||||
@ -86,10 +91,8 @@ impl ProxyProtocolV1Info {
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
// Keep the rest of the buffer to pass it to the underlying protocol
|
|
||||||
let rest = &bytes[crlf + 2..];
|
|
||||||
// Trim to everything before the CRLF
|
// Trim to everything before the CRLF
|
||||||
let bytes = &bytes[..crlf];
|
let bytes = &buf.as_ref()[..crlf];
|
||||||
|
|
||||||
let mut it = bytes.splitn(6, |c| c == &b' ');
|
let mut it = bytes.splitn(6, |c| c == &b' ');
|
||||||
// Check for the preamble
|
// Check for the preamble
|
||||||
@ -187,7 +190,9 @@ impl ProxyProtocolV1Info {
|
|||||||
None => return Err(E::NoProtocol),
|
None => return Err(E::NoProtocol),
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok((result, rest))
|
buf.advance(crlf + 2);
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
@ -258,39 +263,42 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_parse() {
|
fn test_parse() {
|
||||||
let (info, rest) = ProxyProtocolV1Info::parse(
|
let mut buf =
|
||||||
b"PROXY TCP4 255.255.255.255 255.255.255.255 65535 65535\r\nhello world",
|
b"PROXY TCP4 255.255.255.255 255.255.255.255 65535 65535\r\nhello world"
|
||||||
)
|
.as_slice();
|
||||||
.unwrap();
|
let info = ProxyProtocolV1Info::parse(&mut buf).unwrap();
|
||||||
assert_eq!(rest, b"hello world");
|
assert_eq!(buf, b"hello world");
|
||||||
assert!(info.is_tcp());
|
assert!(info.is_tcp());
|
||||||
assert!(!info.is_udp());
|
assert!(!info.is_udp());
|
||||||
assert!(!info.is_unknown());
|
assert!(!info.is_unknown());
|
||||||
assert!(info.is_ipv4());
|
assert!(info.is_ipv4());
|
||||||
assert!(!info.is_ipv6());
|
assert!(!info.is_ipv6());
|
||||||
|
|
||||||
let (info, rest) = ProxyProtocolV1Info::parse(
|
let mut buf =
|
||||||
b"PROXY TCP6 ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\nhello world"
|
b"PROXY TCP6 ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\nhello world"
|
||||||
).unwrap();
|
.as_slice();
|
||||||
assert_eq!(rest, b"hello world");
|
let info = ProxyProtocolV1Info::parse(&mut buf).unwrap();
|
||||||
|
assert_eq!(buf, b"hello world");
|
||||||
assert!(info.is_tcp());
|
assert!(info.is_tcp());
|
||||||
assert!(!info.is_udp());
|
assert!(!info.is_udp());
|
||||||
assert!(!info.is_unknown());
|
assert!(!info.is_unknown());
|
||||||
assert!(!info.is_ipv4());
|
assert!(!info.is_ipv4());
|
||||||
assert!(info.is_ipv6());
|
assert!(info.is_ipv6());
|
||||||
|
|
||||||
let (info, rest) = ProxyProtocolV1Info::parse(b"PROXY UNKNOWN\r\nhello world").unwrap();
|
let mut buf = b"PROXY UNKNOWN\r\nhello world".as_slice();
|
||||||
assert_eq!(rest, b"hello world");
|
let info = ProxyProtocolV1Info::parse(&mut buf).unwrap();
|
||||||
|
assert_eq!(buf, b"hello world");
|
||||||
assert!(!info.is_tcp());
|
assert!(!info.is_tcp());
|
||||||
assert!(!info.is_udp());
|
assert!(!info.is_udp());
|
||||||
assert!(info.is_unknown());
|
assert!(info.is_unknown());
|
||||||
assert!(!info.is_ipv4());
|
assert!(!info.is_ipv4());
|
||||||
assert!(!info.is_ipv6());
|
assert!(!info.is_ipv6());
|
||||||
|
|
||||||
let (info, rest) = ProxyProtocolV1Info::parse(
|
let mut buf =
|
||||||
b"PROXY UNKNOWN ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\nhello world"
|
b"PROXY UNKNOWN ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\nhello world"
|
||||||
).unwrap();
|
.as_slice();
|
||||||
assert_eq!(rest, b"hello world");
|
let info = ProxyProtocolV1Info::parse(&mut buf).unwrap();
|
||||||
|
assert_eq!(buf, b"hello world");
|
||||||
assert!(!info.is_tcp());
|
assert!(!info.is_tcp());
|
||||||
assert!(!info.is_udp());
|
assert!(!info.is_udp());
|
||||||
assert!(info.is_unknown());
|
assert!(info.is_unknown());
|
||||||
|
151
crates/listener/src/rewind.rs
Normal file
151
crates/listener/src/rewind.rs
Normal file
@ -0,0 +1,151 @@
|
|||||||
|
// Taken from hyper@0.14.20, src/common/io/rewind.rs
|
||||||
|
|
||||||
|
use std::{
|
||||||
|
cmp, io,
|
||||||
|
marker::Unpin,
|
||||||
|
pin::Pin,
|
||||||
|
task::{Context, Poll},
|
||||||
|
};
|
||||||
|
|
||||||
|
use bytes::{Buf, Bytes};
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||||
|
|
||||||
|
/// Combine a buffer with an IO, rewinding reads to use the buffer.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Rewind<T> {
|
||||||
|
pre: Option<Bytes>,
|
||||||
|
inner: T,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> Rewind<T> {
|
||||||
|
pub(crate) fn new(io: T) -> Self {
|
||||||
|
Rewind {
|
||||||
|
pre: None,
|
||||||
|
inner: io,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn new_buffered(io: T, buf: Bytes) -> Self {
|
||||||
|
Rewind {
|
||||||
|
pre: Some(buf),
|
||||||
|
inner: io,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
pub(crate) fn rewind(&mut self, bs: Bytes) {
|
||||||
|
debug_assert!(self.pre.is_none());
|
||||||
|
self.pre = Some(bs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> AsyncRead for Rewind<T>
|
||||||
|
where
|
||||||
|
T: AsyncRead + Unpin,
|
||||||
|
{
|
||||||
|
fn poll_read(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &mut ReadBuf<'_>,
|
||||||
|
) -> Poll<io::Result<()>> {
|
||||||
|
if let Some(mut prefix) = self.pre.take() {
|
||||||
|
// If there are no remaining bytes, let the bytes get dropped.
|
||||||
|
if !prefix.is_empty() {
|
||||||
|
let copy_len = cmp::min(prefix.len(), buf.remaining());
|
||||||
|
// TODO: There should be a way to do following two lines cleaner...
|
||||||
|
buf.put_slice(&prefix[..copy_len]);
|
||||||
|
prefix.advance(copy_len);
|
||||||
|
// Put back what's left
|
||||||
|
if !prefix.is_empty() {
|
||||||
|
self.pre = Some(prefix);
|
||||||
|
}
|
||||||
|
|
||||||
|
return Poll::Ready(Ok(()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Pin::new(&mut self.inner).poll_read(cx, buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> AsyncWrite for Rewind<T>
|
||||||
|
where
|
||||||
|
T: AsyncWrite + Unpin,
|
||||||
|
{
|
||||||
|
fn poll_write(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
|
Pin::new(&mut self.inner).poll_write(cx, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_write_vectored(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
bufs: &[io::IoSlice<'_>],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
|
Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||||
|
Pin::new(&mut self.inner).poll_flush(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||||
|
Pin::new(&mut self.inner).poll_shutdown(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_write_vectored(&self) -> bool {
|
||||||
|
self.inner.is_write_vectored()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
// FIXME: re-implement tests with `async/await`, this import should
|
||||||
|
// trigger a warning to remind us
|
||||||
|
use bytes::Bytes;
|
||||||
|
use tokio::io::AsyncReadExt;
|
||||||
|
|
||||||
|
use super::Rewind;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn partial_rewind() {
|
||||||
|
let underlying = [104, 101, 108, 108, 111];
|
||||||
|
|
||||||
|
let mock = tokio_test::io::Builder::new().read(&underlying).build();
|
||||||
|
|
||||||
|
let mut stream = Rewind::new(mock);
|
||||||
|
|
||||||
|
// Read off some bytes, ensure we filled o1
|
||||||
|
let mut buf = [0; 2];
|
||||||
|
stream.read_exact(&mut buf).await.expect("read1");
|
||||||
|
|
||||||
|
// Rewind the stream so that it is as if we never read in the first place.
|
||||||
|
stream.rewind(Bytes::copy_from_slice(&buf[..]));
|
||||||
|
|
||||||
|
let mut buf = [0; 5];
|
||||||
|
stream.read_exact(&mut buf).await.expect("read1");
|
||||||
|
|
||||||
|
// At this point we should have read everything that was in the MockStream
|
||||||
|
assert_eq!(&buf, &underlying);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn full_rewind() {
|
||||||
|
let underlying = [104, 101, 108, 108, 111];
|
||||||
|
|
||||||
|
let mock = tokio_test::io::Builder::new().read(&underlying).build();
|
||||||
|
|
||||||
|
let mut stream = Rewind::new(mock);
|
||||||
|
|
||||||
|
let mut buf = [0; 5];
|
||||||
|
stream.read_exact(&mut buf).await.expect("read1");
|
||||||
|
|
||||||
|
// Rewind the stream so that it is as if we never read in the first place.
|
||||||
|
stream.rewind(Bytes::copy_from_slice(&buf[..]));
|
||||||
|
|
||||||
|
let mut buf = [0; 5];
|
||||||
|
stream.read_exact(&mut buf).await.expect("read1");
|
||||||
|
}
|
||||||
|
}
|
301
crates/listener/src/server.rs
Normal file
301
crates/listener/src/server.rs
Normal file
@ -0,0 +1,301 @@
|
|||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use futures_util::{Stream, StreamExt};
|
||||||
|
use http_body::Body;
|
||||||
|
use hyper::{Request, Response};
|
||||||
|
use thiserror::Error;
|
||||||
|
use tokio_rustls::rustls::ServerConfig;
|
||||||
|
use tower_http::add_extension::AddExtension;
|
||||||
|
use tower_service::Service;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
maybe_tls::{MaybeTlsAcceptor, TlsStreamInfo},
|
||||||
|
proxy_protocol::{MaybeProxyAcceptor, ProxyAcceptError},
|
||||||
|
unix_or_tcp::{SocketAddr, UnixOrTcpConnection, UnixOrTcpListener},
|
||||||
|
ConnectionInfo,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub struct Server<S> {
|
||||||
|
tls: Option<Arc<ServerConfig>>,
|
||||||
|
proxy: bool,
|
||||||
|
listener: UnixOrTcpListener,
|
||||||
|
service: S,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> Server<S> {
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if the listener couldn't be converted via [`TryInfo`]
|
||||||
|
pub fn try_new<L>(listener: L, service: S) -> Result<Self, L::Error>
|
||||||
|
where
|
||||||
|
L: TryInto<UnixOrTcpListener>,
|
||||||
|
{
|
||||||
|
Ok(Self {
|
||||||
|
tls: None,
|
||||||
|
proxy: false,
|
||||||
|
listener: listener.try_into()?,
|
||||||
|
service,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn new(listener: impl Into<UnixOrTcpListener>, service: S) -> Self {
|
||||||
|
Self {
|
||||||
|
tls: None,
|
||||||
|
proxy: false,
|
||||||
|
listener: listener.into(),
|
||||||
|
service,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub const fn with_proxy(mut self) -> Self {
|
||||||
|
self.proxy = true;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn with_tls(mut self, config: Arc<ServerConfig>) -> Self {
|
||||||
|
self.tls = Some(config);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run a single server
|
||||||
|
pub async fn run<B, SD>(self, shutdown: SD)
|
||||||
|
where
|
||||||
|
S: Service<Request<hyper::Body>, Response = Response<B>> + Clone + Send + 'static,
|
||||||
|
S::Future: Send + 'static,
|
||||||
|
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
|
||||||
|
B: Body + Send + 'static,
|
||||||
|
B::Data: Send,
|
||||||
|
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
|
||||||
|
SD: Stream + Unpin,
|
||||||
|
SD::Item: std::fmt::Display,
|
||||||
|
{
|
||||||
|
run_servers(std::iter::once(self), shutdown).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
#[non_exhaustive]
|
||||||
|
enum AcceptError {
|
||||||
|
#[error("failed to accept connection from the underlying socket")]
|
||||||
|
Socket {
|
||||||
|
#[source]
|
||||||
|
source: std::io::Error,
|
||||||
|
},
|
||||||
|
|
||||||
|
#[error("failed to complete the TLS handshake")]
|
||||||
|
TlsHandshake {
|
||||||
|
#[source]
|
||||||
|
source: std::io::Error,
|
||||||
|
},
|
||||||
|
|
||||||
|
#[error("failed to complete the proxy protocol handshake")]
|
||||||
|
ProxyHandshake {
|
||||||
|
#[source]
|
||||||
|
source: ProxyAcceptError,
|
||||||
|
},
|
||||||
|
|
||||||
|
#[error(transparent)]
|
||||||
|
Hyper(#[from] hyper::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AcceptError {
|
||||||
|
fn socket(source: std::io::Error) -> Self {
|
||||||
|
Self::Socket { source }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn tls_handshake(source: std::io::Error) -> Self {
|
||||||
|
Self::TlsHandshake { source }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn proxy_handshake(source: ProxyAcceptError) -> Self {
|
||||||
|
Self::ProxyHandshake { source }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn accept<S, B>(
|
||||||
|
maybe_proxy_acceptor: &MaybeProxyAcceptor,
|
||||||
|
maybe_tls_acceptor: &MaybeTlsAcceptor,
|
||||||
|
peer_addr: SocketAddr,
|
||||||
|
stream: UnixOrTcpConnection,
|
||||||
|
service: S,
|
||||||
|
) -> Result<(), AcceptError>
|
||||||
|
where
|
||||||
|
S: Service<Request<hyper::Body>, Response = Response<B>>,
|
||||||
|
S::Future: Send + 'static,
|
||||||
|
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
|
||||||
|
B: Body + Send + 'static,
|
||||||
|
B::Data: Send,
|
||||||
|
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
|
||||||
|
{
|
||||||
|
let (proxy, stream) = maybe_proxy_acceptor
|
||||||
|
.accept(stream)
|
||||||
|
.await
|
||||||
|
.map_err(AcceptError::proxy_handshake)?;
|
||||||
|
|
||||||
|
let stream = maybe_tls_acceptor
|
||||||
|
.accept(stream)
|
||||||
|
.await
|
||||||
|
.map_err(AcceptError::tls_handshake)?;
|
||||||
|
|
||||||
|
let tls = stream.tls_info();
|
||||||
|
|
||||||
|
// Figure out if it's HTTP/2 based on the negociated ALPN info
|
||||||
|
let is_h2 = tls.as_ref().map_or(false, TlsStreamInfo::is_alpn_h2);
|
||||||
|
|
||||||
|
let info = ConnectionInfo {
|
||||||
|
tls,
|
||||||
|
proxy,
|
||||||
|
net_peer_addr: peer_addr.into_net(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let service = AddExtension::new(service, info);
|
||||||
|
|
||||||
|
if is_h2 {
|
||||||
|
hyper::server::conn::Http::new()
|
||||||
|
.http2_only(true)
|
||||||
|
.serve_connection(stream, service)
|
||||||
|
.await?;
|
||||||
|
} else {
|
||||||
|
hyper::server::conn::Http::new()
|
||||||
|
.http1_only(true)
|
||||||
|
.http1_keep_alive(false)
|
||||||
|
.serve_connection(stream, service)
|
||||||
|
.await?;
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run_servers<S, B, SD>(listeners: impl IntoIterator<Item = Server<S>>, mut shutdown: SD)
|
||||||
|
where
|
||||||
|
S: Service<Request<hyper::Body>, Response = Response<B>> + Clone + Send + 'static,
|
||||||
|
S::Future: Send + 'static,
|
||||||
|
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
|
||||||
|
B: Body + Send + 'static,
|
||||||
|
B::Data: Send,
|
||||||
|
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
|
||||||
|
SD: Stream + Unpin,
|
||||||
|
SD::Item: std::fmt::Display,
|
||||||
|
{
|
||||||
|
let listeners: Vec<_> = listeners
|
||||||
|
.into_iter()
|
||||||
|
.map(|server| {
|
||||||
|
let maybe_proxy_acceptor = MaybeProxyAcceptor::new(server.proxy);
|
||||||
|
let maybe_tls_acceptor = MaybeTlsAcceptor::new(server.tls);
|
||||||
|
let service = server.service;
|
||||||
|
let listener = server.listener;
|
||||||
|
(maybe_proxy_acceptor, maybe_tls_acceptor, service, listener)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let mut set = tokio::task::JoinSet::new();
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let mut accept_all: futures_util::stream::FuturesUnordered<_> = listeners
|
||||||
|
.iter()
|
||||||
|
.map(
|
||||||
|
|(maybe_proxy_acceptor, maybe_tls_acceptor, service, listener)| async move {
|
||||||
|
listener
|
||||||
|
.accept()
|
||||||
|
.await
|
||||||
|
.map_err(AcceptError::socket)
|
||||||
|
.map(|(addr, conn)| {
|
||||||
|
(
|
||||||
|
maybe_proxy_acceptor.clone(),
|
||||||
|
maybe_tls_acceptor.clone(),
|
||||||
|
service.clone(),
|
||||||
|
addr,
|
||||||
|
conn,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
tokio::select! {
|
||||||
|
biased;
|
||||||
|
|
||||||
|
// First look for the shutdown signal
|
||||||
|
res = shutdown.next() => {
|
||||||
|
let why = res.map_or_else(|| String::from("???"), |why| format!("{why}"));
|
||||||
|
tracing::info!("Received shutdown signal ({why})");
|
||||||
|
|
||||||
|
break;
|
||||||
|
},
|
||||||
|
|
||||||
|
// Poll on the JoinSet, clearing finished task
|
||||||
|
res = set.join_next(), if !set.is_empty() => {
|
||||||
|
match res {
|
||||||
|
Some(Ok(Ok(()))) => tracing::trace!("Task was successful"),
|
||||||
|
Some(Ok(Err(e))) => tracing::error!("{e}"),
|
||||||
|
Some(Err(e)) => tracing::error!("Join error: {e}"),
|
||||||
|
None => tracing::error!("Join set was polled even though it was empty"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
// Then look for connections to accept
|
||||||
|
res = accept_all.next(), if !accept_all.is_empty() => {
|
||||||
|
// SAFETY: We shouldn't reach this branch if the unordered future set is empty
|
||||||
|
let res = if let Some(res) = res { res } else { unreachable!() };
|
||||||
|
|
||||||
|
// Spawn the connection in the set, so we don't have to wait for the handshake to
|
||||||
|
// accept the next connection. This allows us to keep track of active connections
|
||||||
|
// and waiting on them for a graceful shutdown
|
||||||
|
set.spawn(async move {
|
||||||
|
let (maybe_proxy_acceptor, maybe_tls_acceptor, service, peer_addr, stream) = res?;
|
||||||
|
accept(&maybe_proxy_acceptor, &maybe_tls_acceptor, peer_addr, stream, service).await
|
||||||
|
});
|
||||||
|
},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if !set.is_empty() {
|
||||||
|
tracing::info!(
|
||||||
|
"There are {active} active connections, performing a graceful shutdown. Send the shutdown signal again to force.",
|
||||||
|
active = set.len()
|
||||||
|
);
|
||||||
|
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
biased;
|
||||||
|
|
||||||
|
res = set.join_next() => {
|
||||||
|
match res {
|
||||||
|
Some(Ok(Ok(()))) => tracing::trace!("Task was successful"),
|
||||||
|
Some(Ok(Err(e))) => tracing::error!("{e}"),
|
||||||
|
Some(Err(e)) => tracing::error!("Join error: {e}"),
|
||||||
|
// No more tasks, going out
|
||||||
|
None => break,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
|
||||||
|
res = shutdown.next() => {
|
||||||
|
let why = res.map_or_else(|| String::from("???"), |why| format!("{why}"));
|
||||||
|
tracing::warn!("Received shutdown signal again ({why}), forcing shutdown ({active} active connections)", active = set.len());
|
||||||
|
break;
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
set.shutdown().await;
|
||||||
|
tracing::info!("Shutdown complete");
|
||||||
|
}
|
178
crates/listener/src/shutdown.rs
Normal file
178
crates/listener/src/shutdown.rs
Normal file
@ -0,0 +1,178 @@
|
|||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
use std::{fmt::Display, pin::Pin, task::Poll, time::Duration};
|
||||||
|
|
||||||
|
use futures_util::{ready, Future, Stream};
|
||||||
|
use tokio::{
|
||||||
|
signal::unix::{signal, Signal, SignalKind},
|
||||||
|
time::Sleep,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub enum ShutdownReason {
|
||||||
|
Signal(SignalKind),
|
||||||
|
Timeout,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn signal_to_str(kind: SignalKind) -> &'static str {
|
||||||
|
match kind.as_raw_value() {
|
||||||
|
libc::SIGALRM => "SIGALRM",
|
||||||
|
libc::SIGCHLD => "SIGCHLD",
|
||||||
|
libc::SIGHUP => "SIGHUP",
|
||||||
|
libc::SIGINT => "SIGINT",
|
||||||
|
libc::SIGIO => "SIGIO",
|
||||||
|
libc::SIGPIPE => "SIGPIPE",
|
||||||
|
libc::SIGQUIT => "SIGQUIT",
|
||||||
|
libc::SIGTERM => "SIGTERM",
|
||||||
|
libc::SIGUSR1 => "SIGUSR1",
|
||||||
|
libc::SIGUSR2 => "SIGUSR2",
|
||||||
|
_ => "SIG???",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for ShutdownReason {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
Self::Signal(s) => signal_to_str(*s).fmt(f),
|
||||||
|
Self::Timeout => "timeout".fmt(f),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
pub enum ShutdownStreamState {
|
||||||
|
#[default]
|
||||||
|
Waiting,
|
||||||
|
|
||||||
|
Graceful {
|
||||||
|
sleep: Option<Pin<Box<Sleep>>>,
|
||||||
|
},
|
||||||
|
|
||||||
|
Done,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ShutdownStreamState {
|
||||||
|
fn is_graceful(&self) -> bool {
|
||||||
|
matches!(self, Self::Graceful { .. })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_done(&self) -> bool {
|
||||||
|
matches!(self, Self::Done)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_sleep_mut(&mut self) -> Option<&mut Pin<Box<Sleep>>> {
|
||||||
|
match self {
|
||||||
|
Self::Graceful { sleep } => sleep.as_mut(),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A stream which is used to drive a graceful shutdown.
|
||||||
|
///
|
||||||
|
/// It will emit 2 items: one when a first signal is caught, the other when
|
||||||
|
/// either another signal is caught, or after a timeout.
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct ShutdownStream {
|
||||||
|
state: ShutdownStreamState,
|
||||||
|
signals: Vec<(SignalKind, Signal)>,
|
||||||
|
timeout: Option<Duration>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ShutdownStream {
|
||||||
|
/// Create a default shutdown stream, which listens on SIGINT and SIGTERM,
|
||||||
|
/// with a 60s timeout
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if signal handlers could not be installed
|
||||||
|
pub fn new() -> Result<Self, std::io::Error> {
|
||||||
|
let ret = Self::default()
|
||||||
|
.with_timeout(Duration::from_secs(60))
|
||||||
|
.with_signal(SignalKind::interrupt())?
|
||||||
|
.with_signal(SignalKind::terminate())?;
|
||||||
|
|
||||||
|
Ok(ret)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add a signal to register
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if the signal handler could not be installed
|
||||||
|
pub fn with_signal(mut self, kind: SignalKind) -> Result<Self, std::io::Error> {
|
||||||
|
let signal = signal(kind)?;
|
||||||
|
self.signals.push((kind, signal));
|
||||||
|
Ok(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn with_timeout(mut self, timeout: Duration) -> Self {
|
||||||
|
self.timeout = Some(timeout);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Stream for ShutdownStream {
|
||||||
|
type Item = ShutdownReason;
|
||||||
|
|
||||||
|
fn size_hint(&self) -> (usize, Option<usize>) {
|
||||||
|
match self.state {
|
||||||
|
ShutdownStreamState::Waiting => (2, Some(2)),
|
||||||
|
ShutdownStreamState::Graceful { .. } => (1, Some(1)),
|
||||||
|
ShutdownStreamState::Done => (0, Some(0)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_next(
|
||||||
|
self: std::pin::Pin<&mut Self>,
|
||||||
|
cx: &mut std::task::Context<'_>,
|
||||||
|
) -> std::task::Poll<Option<Self::Item>> {
|
||||||
|
let this = self.get_mut();
|
||||||
|
|
||||||
|
if this.state.is_done() {
|
||||||
|
return Poll::Ready(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (kind, signal) in &mut this.signals {
|
||||||
|
match signal.poll_recv(cx) {
|
||||||
|
Poll::Ready(_) => {
|
||||||
|
// We got a signal
|
||||||
|
if this.state.is_graceful() {
|
||||||
|
// If we was gracefully shutting down, mark it as done
|
||||||
|
this.state = ShutdownStreamState::Done;
|
||||||
|
} else {
|
||||||
|
// Else start the timeout
|
||||||
|
let sleep = this
|
||||||
|
.timeout
|
||||||
|
.map(|duration| Box::pin(tokio::time::sleep(duration)));
|
||||||
|
this.state = ShutdownStreamState::Graceful { sleep };
|
||||||
|
}
|
||||||
|
|
||||||
|
return Poll::Ready(Some(ShutdownReason::Signal(*kind)));
|
||||||
|
}
|
||||||
|
Poll::Pending => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(timeout) = this.state.get_sleep_mut() {
|
||||||
|
ready!(timeout.as_mut().poll(cx));
|
||||||
|
this.state = ShutdownStreamState::Done;
|
||||||
|
Poll::Ready(Some(ShutdownReason::Timeout))
|
||||||
|
} else {
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -12,6 +12,8 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
//! A listener which can listen on either TCP sockets or on UNIX domain sockets
|
||||||
|
|
||||||
// TODO: Unlink the UNIX socket on drop?
|
// TODO: Unlink the UNIX socket on drop?
|
||||||
|
|
||||||
use std::{
|
use std::{
|
||||||
@ -19,8 +21,6 @@ use std::{
|
|||||||
task::{Context, Poll},
|
task::{Context, Poll},
|
||||||
};
|
};
|
||||||
|
|
||||||
use futures_util::ready;
|
|
||||||
use hyper::server::accept::Accept;
|
|
||||||
use tokio::{
|
use tokio::{
|
||||||
io::{AsyncRead, AsyncWrite},
|
io::{AsyncRead, AsyncWrite},
|
||||||
net::{TcpListener, TcpStream, UnixListener, UnixStream},
|
net::{TcpListener, TcpStream, UnixListener, UnixStream},
|
||||||
@ -107,6 +107,7 @@ impl TryFrom<std::os::unix::net::UnixListener> for UnixOrTcpListener {
|
|||||||
type Error = std::io::Error;
|
type Error = std::io::Error;
|
||||||
|
|
||||||
fn try_from(listener: std::os::unix::net::UnixListener) -> Result<Self, Self::Error> {
|
fn try_from(listener: std::os::unix::net::UnixListener) -> Result<Self, Self::Error> {
|
||||||
|
listener.set_nonblocking(true)?;
|
||||||
Ok(Self::Unix(UnixListener::from_std(listener)?))
|
Ok(Self::Unix(UnixListener::from_std(listener)?))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -115,6 +116,7 @@ impl TryFrom<std::net::TcpListener> for UnixOrTcpListener {
|
|||||||
type Error = std::io::Error;
|
type Error = std::io::Error;
|
||||||
|
|
||||||
fn try_from(listener: std::net::TcpListener) -> Result<Self, Self::Error> {
|
fn try_from(listener: std::net::TcpListener) -> Result<Self, Self::Error> {
|
||||||
|
listener.set_nonblocking(true)?;
|
||||||
Ok(Self::Tcp(TcpListener::from_std(listener)?))
|
Ok(Self::Tcp(TcpListener::from_std(listener)?))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -140,6 +142,24 @@ impl UnixOrTcpListener {
|
|||||||
pub const fn is_tcp(&self) -> bool {
|
pub const fn is_tcp(&self) -> bool {
|
||||||
matches!(self, Self::Tcp(_))
|
matches!(self, Self::Tcp(_))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Accept an incoming connection
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if the underlying socket couldn't accept the connection
|
||||||
|
pub async fn accept(&self) -> Result<(SocketAddr, UnixOrTcpConnection), std::io::Error> {
|
||||||
|
match self {
|
||||||
|
Self::Unix(listener) => {
|
||||||
|
let (stream, remote_addr) = listener.accept().await?;
|
||||||
|
Ok((remote_addr.into(), UnixOrTcpConnection::Unix { stream }))
|
||||||
|
}
|
||||||
|
Self::Tcp(listener) => {
|
||||||
|
let (stream, remote_addr) = listener.accept().await?;
|
||||||
|
Ok((remote_addr.into(), UnixOrTcpConnection::Tcp { stream }))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pin_project_lite::pin_project! {
|
pin_project_lite::pin_project! {
|
||||||
@ -157,6 +177,12 @@ pin_project_lite::pin_project! {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<TcpStream> for UnixOrTcpConnection {
|
||||||
|
fn from(stream: TcpStream) -> Self {
|
||||||
|
Self::Tcp { stream }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl UnixOrTcpConnection {
|
impl UnixOrTcpConnection {
|
||||||
/// Get the local address of the stream
|
/// Get the local address of the stream
|
||||||
///
|
///
|
||||||
@ -166,8 +192,8 @@ impl UnixOrTcpConnection {
|
|||||||
/// [`UnixStream`] couldn't provide the local address
|
/// [`UnixStream`] couldn't provide the local address
|
||||||
pub fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
|
pub fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
|
||||||
match self {
|
match self {
|
||||||
Self::Unix { stream, .. } => stream.local_addr().map(SocketAddr::from),
|
Self::Unix { stream } => stream.local_addr().map(SocketAddr::from),
|
||||||
Self::Tcp { stream, .. } => stream.local_addr().map(SocketAddr::from),
|
Self::Tcp { stream } => stream.local_addr().map(SocketAddr::from),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -179,36 +205,12 @@ impl UnixOrTcpConnection {
|
|||||||
/// [`UnixStream`] couldn't provide the remote address
|
/// [`UnixStream`] couldn't provide the remote address
|
||||||
pub fn peer_addr(&self) -> Result<SocketAddr, std::io::Error> {
|
pub fn peer_addr(&self) -> Result<SocketAddr, std::io::Error> {
|
||||||
match self {
|
match self {
|
||||||
Self::Unix { stream, .. } => stream.peer_addr().map(SocketAddr::from),
|
Self::Unix { stream } => stream.peer_addr().map(SocketAddr::from),
|
||||||
Self::Tcp { stream, .. } => stream.peer_addr().map(SocketAddr::from),
|
Self::Tcp { stream } => stream.peer_addr().map(SocketAddr::from),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Accept for UnixOrTcpListener {
|
|
||||||
type Error = std::io::Error;
|
|
||||||
type Conn = UnixOrTcpConnection;
|
|
||||||
|
|
||||||
fn poll_accept(
|
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
) -> std::task::Poll<Option<Result<Self::Conn, Self::Error>>> {
|
|
||||||
let conn = match &*self {
|
|
||||||
Self::Unix(listener) => {
|
|
||||||
let (stream, _remote_addr) = ready!(listener.poll_accept(cx))?;
|
|
||||||
UnixOrTcpConnection::Unix { stream }
|
|
||||||
}
|
|
||||||
|
|
||||||
Self::Tcp(listener) => {
|
|
||||||
let (stream, _remote_addr) = ready!(listener.poll_accept(cx))?;
|
|
||||||
UnixOrTcpConnection::Tcp { stream }
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
Poll::Ready(Some(Ok(conn)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl AsyncRead for UnixOrTcpConnection {
|
impl AsyncRead for UnixOrTcpConnection {
|
||||||
fn poll_read(
|
fn poll_read(
|
||||||
self: Pin<&mut Self>,
|
self: Pin<&mut Self>,
|
||||||
@ -234,23 +236,6 @@ impl AsyncWrite for UnixOrTcpConnection {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
|
|
||||||
match self.project() {
|
|
||||||
UnixOrTcpConnectionProj::Unix { stream } => stream.poll_flush(cx),
|
|
||||||
UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_flush(cx),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_shutdown(
|
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
) -> Poll<Result<(), std::io::Error>> {
|
|
||||||
match self.project() {
|
|
||||||
UnixOrTcpConnectionProj::Unix { stream } => stream.poll_shutdown(cx),
|
|
||||||
UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_shutdown(cx),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_write_vectored(
|
fn poll_write_vectored(
|
||||||
self: Pin<&mut Self>,
|
self: Pin<&mut Self>,
|
||||||
cx: &mut Context<'_>,
|
cx: &mut Context<'_>,
|
||||||
@ -268,4 +253,21 @@ impl AsyncWrite for UnixOrTcpConnection {
|
|||||||
UnixOrTcpConnection::Tcp { stream } => stream.is_write_vectored(),
|
UnixOrTcpConnection::Tcp { stream } => stream.is_write_vectored(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
|
||||||
|
match self.project() {
|
||||||
|
UnixOrTcpConnectionProj::Unix { stream } => stream.poll_flush(cx),
|
||||||
|
UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_flush(cx),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_shutdown(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
) -> Poll<Result<(), std::io::Error>> {
|
||||||
|
match self.project() {
|
||||||
|
UnixOrTcpConnectionProj::Unix { stream } => stream.poll_shutdown(cx),
|
||||||
|
UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_shutdown(cx),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user