You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-08-09 04:22:45 +03:00
Inject connection informations in the request extension
This commit is contained in:
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -2683,6 +2683,8 @@ dependencies = [
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-rustls 0.23.4",
|
||||
"tower",
|
||||
"tower-http",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
|
@@ -25,7 +25,11 @@ use mas_config::RootConfig;
|
||||
use mas_email::Mailer;
|
||||
use mas_handlers::{AppState, MatrixHomeserver};
|
||||
use mas_http::ServerLayer;
|
||||
use mas_listener::{maybe_tls::MaybeTlsAcceptor, proxy_protocol::MaybeProxyAcceptor};
|
||||
use mas_listener::{
|
||||
info::{ConnectionInfoAcceptor, IntoMakeServiceWithConnection},
|
||||
maybe_tls::MaybeTlsAcceptor,
|
||||
proxy_protocol::MaybeProxyAcceptor,
|
||||
};
|
||||
use mas_policy::PolicyFactory;
|
||||
use mas_router::UrlBuilder;
|
||||
use mas_storage::MIGRATOR;
|
||||
@@ -276,6 +280,7 @@ impl Options {
|
||||
info!("Listening on {addresses:?} with resources {resources:?} {additional}", resources = &config.resources);
|
||||
|
||||
let router = crate::server::build_router(&state, &config.resources).layer(ServerLayer::new(config.name.clone()));
|
||||
let make_service = IntoMakeServiceWithConnection::new(router);
|
||||
|
||||
async move {
|
||||
let tls_config = if let Some(tls_config) = config.tls.as_ref() {
|
||||
@@ -288,9 +293,10 @@ impl Options {
|
||||
.try_for_each_concurrent(None, move |listener| {
|
||||
let listener = MaybeTlsAcceptor::new(tls_config.clone(), listener);
|
||||
let listener = MaybeProxyAcceptor::new(listener, config.proxy_protocol);
|
||||
let listener = ConnectionInfoAcceptor::new(listener);
|
||||
|
||||
Server::builder(listener)
|
||||
.serve(router.clone().into_make_service())
|
||||
.serve(make_service.clone())
|
||||
.with_graceful_shutdown(signal.clone())
|
||||
})
|
||||
.await?;
|
||||
|
@@ -19,13 +19,14 @@ use std::{
|
||||
};
|
||||
|
||||
use anyhow::Context;
|
||||
use axum::{body::HttpBody, Router};
|
||||
use axum::{body::HttpBody, Extension, Router};
|
||||
use listenfd::ListenFd;
|
||||
use mas_config::{HttpBindConfig, HttpResource, HttpTlsConfig, UnixOrTcp};
|
||||
use mas_handlers::AppState;
|
||||
use mas_listener::unix_or_tcp::UnixOrTcpListener;
|
||||
use mas_listener::{info::Connection, unix_or_tcp::UnixOrTcpListener};
|
||||
use mas_router::Route;
|
||||
use rustls::ServerConfig;
|
||||
use tokio::sync::OnceCell;
|
||||
|
||||
#[allow(clippy::trait_duplication_in_bounds)]
|
||||
pub fn build_router<B>(state: &Arc<AppState>, resources: &[HttpResource]) -> Router<AppState, B>
|
||||
@@ -60,6 +61,16 @@ where
|
||||
mas_config::HttpResource::Compat => {
|
||||
router.merge(mas_handlers::compat_router(state.clone()))
|
||||
}
|
||||
// TODO: do a better handler here
|
||||
mas_config::HttpResource::ConnectionInfo => router.route(
|
||||
"/connection-info",
|
||||
axum::routing::get(
|
||||
|connection: Extension<Arc<OnceCell<Connection>>>| async move {
|
||||
let connection = connection.get().unwrap();
|
||||
format!("{connection:?}")
|
||||
},
|
||||
),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -248,6 +248,11 @@ pub enum Resource {
|
||||
#[serde(default)]
|
||||
web_root: Option<PathBuf>,
|
||||
},
|
||||
|
||||
/// Mount a "/connection-info" handler which helps debugging informations on
|
||||
/// the upstream connection
|
||||
#[serde(rename = "connection-info")]
|
||||
ConnectionInfo,
|
||||
}
|
||||
|
||||
/// Configuration of a listener
|
||||
|
@@ -6,10 +6,12 @@ edition = "2021"
|
||||
license = "Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
futures-util = "0.3.24"
|
||||
hyper = { version = "0.14.20", features = ["server"] }
|
||||
pin-project-lite = "0.2.9"
|
||||
thiserror = "1.0.37"
|
||||
tokio = { version = "1.21.2", features = ["net"] }
|
||||
pin-project-lite = "0.2.9"
|
||||
hyper = { version = "0.14.20", features = ["server"] }
|
||||
futures-util = "0.3.24"
|
||||
tracing = "0.1.36"
|
||||
tokio-rustls = "0.23.4"
|
||||
tower = "0.4.13"
|
||||
tower-http = { version = "0.3.4", features = ["add-extension"] }
|
||||
tracing = "0.1.36"
|
||||
|
323
crates/listener/src/info.rs
Normal file
323
crates/listener/src/info.rs
Normal file
@@ -0,0 +1,323 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{
|
||||
future::Ready,
|
||||
net::SocketAddr,
|
||||
ops::Deref,
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use hyper::server::accept::Accept;
|
||||
use thiserror::Error;
|
||||
use tokio::{
|
||||
io::{AsyncRead, AsyncWrite},
|
||||
sync::OnceCell,
|
||||
};
|
||||
use tower::Service;
|
||||
use tower_http::add_extension::AddExtension;
|
||||
|
||||
use crate::{
|
||||
maybe_tls::{MaybeTlsAcceptor, MaybeTlsStream, TlsStreamInfo, TlsStreamInfoError},
|
||||
proxy_protocol::{
|
||||
MaybeProxyAcceptor, MaybeProxyStream, ProxyHandshakeNotDone, ProxyProtocolV1Info,
|
||||
},
|
||||
unix_or_tcp::{UnixOrTcpConnection, UnixOrTcpListener},
|
||||
};
|
||||
|
||||
// TODO: this is a mess, clean that up
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
#[non_exhaustive]
|
||||
pub enum FromStreamError {
|
||||
#[error(transparent)]
|
||||
Proxy(#[from] ProxyHandshakeNotDone),
|
||||
|
||||
#[error(transparent)]
|
||||
Tls(#[from] TlsStreamInfoError),
|
||||
|
||||
#[error("Could not grab a reference to the underlying stream")]
|
||||
GetRef,
|
||||
|
||||
#[error("Could not get address info from underlying stream")]
|
||||
IoError(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Connection {
|
||||
proxy: Option<ProxyProtocolV1Info>,
|
||||
tls: Option<TlsStreamInfo>,
|
||||
|
||||
// We're not saving the UNIX domain socket address here because it can't be cloned, which is
|
||||
// required for injecting the connection information as an extension
|
||||
local_tcp_addr: Option<SocketAddr>,
|
||||
peer_tcp_addr: Option<SocketAddr>,
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
#[non_exhaustive]
|
||||
pub enum GrabAddressError {
|
||||
#[error("Proxy protocol was initiated with an unknown protocol")]
|
||||
ProxyUnknown,
|
||||
|
||||
#[error("Proxy protocol was initiated with UDP")]
|
||||
ProxyUdp,
|
||||
|
||||
#[error("Underlying listener is a UNIX socket")]
|
||||
UnixListener,
|
||||
}
|
||||
|
||||
impl MaybeProxyAcceptor<MaybeTlsAcceptor<UnixOrTcpListener>> {
|
||||
pub fn can_have_peer_address(&self) -> bool {
|
||||
self.is_proxied() || self.is_tcp()
|
||||
}
|
||||
}
|
||||
|
||||
impl MaybeProxyStream<MaybeTlsStream<UnixOrTcpConnection>> {
|
||||
/// Get informations about this connection
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the proxy protocol or the TLS handhakes are not done
|
||||
/// yet
|
||||
pub fn connection_info(&self) -> Result<Connection, FromStreamError> {
|
||||
Connection::from_stream(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl Connection {
|
||||
/// Get informations about this connection
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the proxy protocol or the TLS handhakes are not done
|
||||
/// yet
|
||||
pub fn from_stream(
|
||||
stream: &MaybeProxyStream<MaybeTlsStream<UnixOrTcpConnection>>,
|
||||
) -> Result<Self, FromStreamError> {
|
||||
let proxy = stream.proxy_info()?.cloned();
|
||||
let tls = stream.tls_info()?;
|
||||
let original = stream.get_ref().ok_or(FromStreamError::GetRef)?;
|
||||
let local_tcp_addr = original.local_addr()?.into_net();
|
||||
let peer_tcp_addr = original.peer_addr()?.into_net();
|
||||
|
||||
Ok(Self {
|
||||
proxy,
|
||||
tls,
|
||||
local_tcp_addr,
|
||||
peer_tcp_addr,
|
||||
})
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub const fn is_proxied(&self) -> bool {
|
||||
self.proxy.is_some()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub const fn is_tls(&self) -> bool {
|
||||
self.tls.is_some()
|
||||
}
|
||||
|
||||
/// Get the outmost peer address, either from the TCP listener or from the
|
||||
/// proxy protocol infos.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the info from the proxy protocol was not for a TCP
|
||||
/// connection, or if the proxy protocol is not being used, the underlying
|
||||
/// listener was a UNIX domain socket
|
||||
pub fn peer_addr(&self) -> Result<&SocketAddr, GrabAddressError> {
|
||||
if let Some(proxy) = self.proxy.as_ref() {
|
||||
if proxy.is_udp() {
|
||||
return Err(GrabAddressError::ProxyUdp);
|
||||
}
|
||||
|
||||
proxy.source().ok_or(GrabAddressError::ProxyUnknown)
|
||||
} else {
|
||||
self.peer_tcp_addr
|
||||
.as_ref()
|
||||
.ok_or(GrabAddressError::UnixListener)
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the outmost local address, either from the TCP listener or from the
|
||||
/// proxy protocol infos.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the info from the proxy protocol was not for a TCP
|
||||
/// connection, or if the proxy protocol is not being used, the underlying
|
||||
/// listener was a UNIX domain socket
|
||||
pub fn local_addr(&self) -> Result<&SocketAddr, GrabAddressError> {
|
||||
if let Some(proxy) = self.proxy.as_ref() {
|
||||
if proxy.is_udp() {
|
||||
return Err(GrabAddressError::ProxyUdp);
|
||||
}
|
||||
|
||||
proxy.destination().ok_or(GrabAddressError::ProxyUnknown)
|
||||
} else {
|
||||
self.local_tcp_addr
|
||||
.as_ref()
|
||||
.ok_or(GrabAddressError::UnixListener)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
pub struct ConnectionInfoAcceptor {
|
||||
#[pin]
|
||||
acceptor: MaybeProxyAcceptor<MaybeTlsAcceptor<UnixOrTcpListener>>,
|
||||
}
|
||||
}
|
||||
|
||||
impl ConnectionInfoAcceptor {
|
||||
pub const fn new(acceptor: MaybeProxyAcceptor<MaybeTlsAcceptor<UnixOrTcpListener>>) -> Self {
|
||||
Self { acceptor }
|
||||
}
|
||||
}
|
||||
|
||||
impl Accept for ConnectionInfoAcceptor {
|
||||
type Conn = ConnectionInfoStream;
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn poll_accept(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
|
||||
let proj = self.project();
|
||||
let ret = match futures_util::ready!(proj.acceptor.poll_accept(cx)) {
|
||||
Some(Ok(conn)) => Some(Ok(ConnectionInfoStream::new(conn))),
|
||||
Some(Err(e)) => Some(Err(e)),
|
||||
None => None,
|
||||
};
|
||||
Poll::Ready(ret)
|
||||
}
|
||||
}
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
pub struct ConnectionInfoStream {
|
||||
connection: Arc<OnceCell<Connection>>,
|
||||
#[pin]
|
||||
stream: MaybeProxyStream<MaybeTlsStream<UnixOrTcpConnection>>,
|
||||
}
|
||||
}
|
||||
|
||||
impl ConnectionInfoStream {
|
||||
pub fn new(stream: MaybeProxyStream<MaybeTlsStream<UnixOrTcpConnection>>) -> Self {
|
||||
Self {
|
||||
connection: Arc::new(OnceCell::const_new()),
|
||||
stream,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for ConnectionInfoStream {
|
||||
type Target = MaybeProxyStream<MaybeTlsStream<UnixOrTcpConnection>>;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.stream
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for ConnectionInfoStream {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
let this = self.get_mut();
|
||||
futures_util::ready!(Pin::new(&mut this.stream).poll_read(cx, buf))?;
|
||||
|
||||
if !this.stream.is_tls_handshaking()
|
||||
&& !this.stream.is_proxy_handshaking()
|
||||
&& !this.connection.initialized()
|
||||
{
|
||||
this.connection
|
||||
.set(this.stream.connection_info().unwrap())
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for ConnectionInfoStream {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, std::io::Error>> {
|
||||
let proj = self.project();
|
||||
proj.stream.poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
|
||||
let proj = self.project();
|
||||
proj.stream.poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
let proj = self.project();
|
||||
proj.stream.poll_shutdown(cx)
|
||||
}
|
||||
|
||||
fn poll_write_vectored(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
bufs: &[std::io::IoSlice<'_>],
|
||||
) -> Poll<Result<usize, std::io::Error>> {
|
||||
let proj = self.project();
|
||||
proj.stream.poll_write_vectored(cx, bufs)
|
||||
}
|
||||
|
||||
fn is_write_vectored(&self) -> bool {
|
||||
self.stream.is_write_vectored()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct IntoMakeServiceWithConnection<S> {
|
||||
svc: S,
|
||||
}
|
||||
|
||||
impl<S> IntoMakeServiceWithConnection<S> {
|
||||
pub const fn new(svc: S) -> Self {
|
||||
Self { svc }
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Service<&ConnectionInfoStream> for IntoMakeServiceWithConnection<S>
|
||||
where
|
||||
S: Clone,
|
||||
{
|
||||
type Response = AddExtension<S, Arc<OnceCell<Connection>>>;
|
||||
type Error = FromStreamError;
|
||||
type Future = Ready<Result<Self::Response, Self::Error>>;
|
||||
|
||||
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, target: &ConnectionInfoStream) -> Self::Future {
|
||||
std::future::ready(Ok(AddExtension::new(
|
||||
self.svc.clone(),
|
||||
target.connection.clone(),
|
||||
)))
|
||||
}
|
||||
}
|
@@ -22,6 +22,7 @@
|
||||
#![warn(clippy::pedantic)]
|
||||
#![allow(clippy::module_name_repetitions)]
|
||||
|
||||
pub mod info;
|
||||
pub mod maybe_tls;
|
||||
pub mod proxy_protocol;
|
||||
pub mod unix_or_tcp;
|
||||
|
@@ -21,8 +21,31 @@ use std::{
|
||||
|
||||
use futures_util::{ready, Future};
|
||||
use hyper::server::accept::Accept;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio_rustls::rustls::{ServerConfig, ServerConnection};
|
||||
use tokio_rustls::rustls::{
|
||||
Certificate, ProtocolVersion, ServerConfig, ServerConnection, SupportedCipherSuite,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[non_exhaustive]
|
||||
pub struct TlsStreamInfo {
|
||||
pub protocol_version: ProtocolVersion,
|
||||
pub negotiated_cipher_suite: SupportedCipherSuite,
|
||||
pub sni_hostname: Option<String>,
|
||||
pub apln_protocol: Option<Vec<u8>>,
|
||||
pub peer_certificates: Option<Vec<Certificate>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[non_exhaustive]
|
||||
pub enum TlsStreamInfoError {
|
||||
#[error("TLS handshake is not done yet")]
|
||||
HandshakingNotDone,
|
||||
|
||||
#[error("Some fields were not available in the TLS connection")]
|
||||
FieldsNotAvailable,
|
||||
}
|
||||
|
||||
pub enum MaybeTlsStream<T> {
|
||||
Handshaking(tokio_rustls::Accept<T>),
|
||||
@@ -72,6 +95,44 @@ impl<T> MaybeTlsStream<T> {
|
||||
Self::Handshaking(_) | Self::Insecure(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Gather informations about the TLS connection. Returns `None` if the
|
||||
/// stream is not a TLS stream.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the TLS handshake is not yet done
|
||||
pub fn tls_info(&self) -> Result<Option<TlsStreamInfo>, TlsStreamInfoError> {
|
||||
let conn = match self {
|
||||
Self::Streaming(stream) => stream.get_ref().1,
|
||||
Self::Handshaking(_) => return Err(TlsStreamInfoError::HandshakingNotDone),
|
||||
Self::Insecure(_) => return Ok(None),
|
||||
};
|
||||
|
||||
// NOTE: we're getting the protocol version and cipher suite *after* the
|
||||
// handshake, so this should never lead to an error
|
||||
let protocol_version = conn
|
||||
.protocol_version()
|
||||
.ok_or(TlsStreamInfoError::FieldsNotAvailable)?;
|
||||
let negotiated_cipher_suite = conn
|
||||
.negotiated_cipher_suite()
|
||||
.ok_or(TlsStreamInfoError::FieldsNotAvailable)?;
|
||||
|
||||
let sni_hostname = conn.sni_hostname().map(ToOwned::to_owned);
|
||||
let apln_protocol = conn.alpn_protocol().map(ToOwned::to_owned);
|
||||
let peer_certificates = conn.peer_certificates().map(ToOwned::to_owned);
|
||||
Ok(Some(TlsStreamInfo {
|
||||
protocol_version,
|
||||
negotiated_cipher_suite,
|
||||
sni_hostname,
|
||||
apln_protocol,
|
||||
peer_certificates,
|
||||
}))
|
||||
}
|
||||
|
||||
pub const fn is_tls_handshaking(&self) -> bool {
|
||||
matches!(self, Self::Handshaking(_))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> AsyncRead for MaybeTlsStream<T>
|
||||
|
@@ -13,6 +13,7 @@
|
||||
// limitations under the License.
|
||||
|
||||
use std::{
|
||||
ops::Deref,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
@@ -21,7 +22,7 @@ use futures_util::ready;
|
||||
use hyper::server::accept::Accept;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use super::ProxyStream;
|
||||
use super::{stream::HandshakeNotDone, ProxyProtocolV1Info, ProxyStream};
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
pub struct MaybeProxyAcceptor<A> {
|
||||
@@ -37,6 +38,17 @@ impl<A> MaybeProxyAcceptor<A> {
|
||||
pub const fn new(inner: A, proxied: bool) -> Self {
|
||||
Self { proxied, inner }
|
||||
}
|
||||
|
||||
pub const fn is_proxied(&self) -> bool {
|
||||
self.proxied
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> Deref for MaybeProxyAcceptor<A> {
|
||||
type Target = A;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> Accept for MaybeProxyAcceptor<A>
|
||||
@@ -79,6 +91,35 @@ impl<S> MaybeProxyStream<S> {
|
||||
Self::NotProxied { stream }
|
||||
}
|
||||
}
|
||||
|
||||
/// Get informations from the proxied connection, if it was procied
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the stream did not complete the handshake yet
|
||||
pub fn proxy_info(&self) -> Result<Option<&ProxyProtocolV1Info>, HandshakeNotDone> {
|
||||
match self {
|
||||
Self::Proxied { stream } => Ok(Some(stream.proxy_info()?)),
|
||||
Self::NotProxied { .. } => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
pub const fn is_proxy_handshaking(&self) -> bool {
|
||||
match self {
|
||||
Self::Proxied { stream } => stream.is_handshaking(),
|
||||
Self::NotProxied { .. } => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Deref for MaybeProxyStream<S> {
|
||||
type Target = S;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
match self {
|
||||
Self::Proxied { stream } => &**stream,
|
||||
Self::NotProxied { stream } => stream,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> AsyncRead for MaybeProxyStream<S>
|
||||
|
@@ -20,6 +20,6 @@ mod v1;
|
||||
pub use self::{
|
||||
acceptor::ProxyAcceptor,
|
||||
maybe::{MaybeProxyAcceptor, MaybeProxyStream},
|
||||
stream::ProxyStream,
|
||||
stream::{HandshakeNotDone as ProxyHandshakeNotDone, ProxyStream},
|
||||
v1::ProxyProtocolV1Info,
|
||||
};
|
||||
|
@@ -15,6 +15,7 @@
|
||||
use std::ops::Deref;
|
||||
|
||||
use futures_util::ready;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
|
||||
use super::ProxyProtocolV1Info;
|
||||
@@ -31,6 +32,12 @@ enum ProxyStreamState {
|
||||
Established(ProxyProtocolV1Info),
|
||||
}
|
||||
|
||||
impl ProxyStreamState {
|
||||
pub const fn is_handshaking(&self) -> bool {
|
||||
matches!(self, Self::Handshaking { .. })
|
||||
}
|
||||
}
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
#[derive(Debug)]
|
||||
pub struct ProxyStream<S> {
|
||||
@@ -53,6 +60,10 @@ impl<S> ProxyStream<S> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error, Clone, Copy)]
|
||||
#[error("Proxy protocol handshake is not complete")]
|
||||
pub struct HandshakeNotDone;
|
||||
|
||||
impl<S> Deref for ProxyStream<S> {
|
||||
type Target = S;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
@@ -61,12 +72,22 @@ impl<S> Deref for ProxyStream<S> {
|
||||
}
|
||||
|
||||
impl<S> ProxyStream<S> {
|
||||
pub fn proxy_info(&self) -> Option<&ProxyProtocolV1Info> {
|
||||
/// Get informations from the proxied connection
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the stream did not complete the handshake yet
|
||||
pub fn proxy_info(&self) -> Result<&ProxyProtocolV1Info, HandshakeNotDone> {
|
||||
match &self.state {
|
||||
ProxyStreamState::Handshaking { .. } => None,
|
||||
ProxyStreamState::Established(info) => Some(info),
|
||||
ProxyStreamState::Handshaking { .. } => Err(HandshakeNotDone),
|
||||
ProxyStreamState::Established(info) => Ok(info),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `true` if the proxy protocol is still handshaking
|
||||
pub const fn is_handshaking(&self) -> bool {
|
||||
self.state.is_handshaking()
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> AsyncRead for ProxyStream<S>
|
||||
|
@@ -20,7 +20,7 @@ use std::{
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ProxyProtocolV1Info {
|
||||
Tcp {
|
||||
source: SocketAddr,
|
||||
|
@@ -52,6 +52,40 @@ impl std::fmt::Debug for SocketAddr {
|
||||
}
|
||||
}
|
||||
|
||||
impl SocketAddr {
|
||||
#[must_use]
|
||||
pub fn into_net(self) -> Option<std::net::SocketAddr> {
|
||||
match self {
|
||||
Self::Net(socket) => Some(socket),
|
||||
Self::Unix(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn into_unix(self) -> Option<tokio::net::unix::SocketAddr> {
|
||||
match self {
|
||||
Self::Net(_) => None,
|
||||
Self::Unix(socket) => Some(socket),
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub const fn as_net(&self) -> Option<&std::net::SocketAddr> {
|
||||
match self {
|
||||
Self::Net(socket) => Some(socket),
|
||||
Self::Unix(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub const fn as_unix(&self) -> Option<&tokio::net::unix::SocketAddr> {
|
||||
match self {
|
||||
Self::Net(_) => None,
|
||||
Self::Unix(socket) => Some(socket),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum UnixOrTcpListener {
|
||||
Unix(UnixListener),
|
||||
Tcp(TcpListener),
|
||||
@@ -98,6 +132,14 @@ impl UnixOrTcpListener {
|
||||
Self::Tcp(listener) => listener.local_addr().map(SocketAddr::from),
|
||||
}
|
||||
}
|
||||
|
||||
pub const fn is_unix(&self) -> bool {
|
||||
matches!(self, Self::Unix(_))
|
||||
}
|
||||
|
||||
pub const fn is_tcp(&self) -> bool {
|
||||
matches!(self, Self::Tcp(_))
|
||||
}
|
||||
}
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
|
Reference in New Issue
Block a user