1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

HAProxy's Proxy Protocol acceptor

This commit is contained in:
Quentin Gliech
2022-10-06 11:00:55 +02:00
parent 9309f04880
commit f687ae4ac4
11 changed files with 725 additions and 3 deletions

View File

@ -25,7 +25,7 @@ use mas_config::RootConfig;
use mas_email::Mailer;
use mas_handlers::{AppState, MatrixHomeserver};
use mas_http::ServerLayer;
use mas_listener::maybe_tls::MaybeTlsAcceptor;
use mas_listener::{maybe_tls::MaybeTlsAcceptor, proxy_protocol::MaybeProxyAcceptor};
use mas_policy::PolicyFactory;
use mas_router::UrlBuilder;
use mas_storage::MIGRATOR;
@ -271,7 +271,9 @@ impl Options {
}
}).collect();
info!("Listening on {addresses:?} with resources {resources:?}", resources = &config.resources);
let additional = if config.proxy_protocol { "(with Proxy Protocol)" } else { "" };
info!("Listening on {addresses:?} with resources {resources:?} {additional}", resources = &config.resources);
let router = crate::server::build_router(&state, &config.resources).layer(ServerLayer::new(config.name.clone()));
@ -285,6 +287,7 @@ impl Options {
.map(Ok)
.try_for_each_concurrent(None, move |listener| {
let listener = MaybeTlsAcceptor::new(tls_config.clone(), listener);
let listener = MaybeProxyAcceptor::new(listener, config.proxy_protocol);
Server::builder(listener)
.serve(router.clone().into_make_service())

View File

@ -19,6 +19,7 @@ use async_trait::async_trait;
use mas_keystore::PrivateKey;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none;
use url::Url;
use super::{secrets::PasswordOrFile, ConfigurationSection};
@ -66,6 +67,7 @@ impl UnixOrTcp {
}
/// Configuration of a single listener
#[skip_serializing_none]
#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)]
#[serde(untagged)]
pub enum BindConfig {
@ -74,6 +76,7 @@ pub enum BindConfig {
/// Host on which to listen.
///
/// Defaults to listening on all addresses
#[serde(default)]
host: Option<String>,
/// Port on which to listen.
@ -107,6 +110,7 @@ pub enum BindConfig {
/// Index of the file descriptor. Note that this is offseted by 3
/// because of the standard input/output sockets, so setting
/// here a value of `0` will grab the file descriptor `3`
#[serde(default)]
fd: usize,
/// Whether the socket is a TCP socket or a UNIX domain socket. Defaults
@ -131,6 +135,7 @@ pub enum CertificateOrFile {
}
/// Configuration related to TLS on a listener
#[skip_serializing_none]
#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)]
pub struct TlsConfig {
/// PEM-encoded X509 certificate chain
@ -214,6 +219,7 @@ impl TlsConfig {
}
/// HTTP resources to mount
#[skip_serializing_none]
#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)]
#[serde(tag = "name", rename_all = "lowercase")]
pub enum Resource {
@ -245,10 +251,12 @@ pub enum Resource {
}
/// Configuration of a listener
#[skip_serializing_none]
#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)]
pub struct ListenerConfig {
/// A unique name for this listener which will be shown in traces and in
/// metrics labels
#[serde(default)]
pub name: Option<String>,
/// List of resources to mount
@ -257,7 +265,12 @@ pub struct ListenerConfig {
/// List of sockets to bind
pub binds: Vec<BindConfig>,
/// Accept HAProxy's Proxy Protocol V1
#[serde(default)]
pub proxy_protocol: bool,
/// If set, makes the listener use TLS with the provided certificate and key
#[serde(default)]
pub tls: Option<TlsConfig>,
}
@ -286,6 +299,7 @@ impl Default for HttpConfig {
Resource::Static { web_root: None },
],
tls: None,
proxy_protocol: false,
binds: vec![BindConfig::Address {
address: "[::]:8080".into(),
}],
@ -294,6 +308,7 @@ impl Default for HttpConfig {
name: Some("internal".to_owned()),
resources: vec![Resource::Health],
tls: None,
proxy_protocol: false,
binds: vec![BindConfig::Address {
address: "localhost:8081".into(),
}],

View File

@ -6,6 +6,7 @@ edition = "2021"
license = "Apache-2.0"
[dependencies]
thiserror = "1.0.37"
tokio = { version = "1.21.2", features = ["net"] }
pin-project-lite = "0.2.9"
hyper = { version = "0.14.20", features = ["server"] }

View File

@ -12,5 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#![forbid(unsafe_code)]
#![deny(
clippy::all,
clippy::str_to_string,
rustdoc::missing_crate_level_docs,
rustdoc::broken_intra_doc_links
)]
#![warn(clippy::pedantic)]
#![allow(clippy::module_name_repetitions)]
pub mod maybe_tls;
pub mod proxy_protocol;
pub mod unix_or_tcp;

View File

@ -0,0 +1,52 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use futures_util::ready;
use hyper::server::accept::Accept;
use super::ProxyStream;
pin_project_lite::pin_project! {
pub struct ProxyAcceptor<A> {
#[pin]
inner: A,
}
}
impl<A> ProxyAcceptor<A> {
pub const fn new(inner: A) -> Self {
Self { inner }
}
}
impl<A> Accept for ProxyAcceptor<A>
where
A: Accept,
{
type Conn = ProxyStream<A::Conn>;
type Error = A::Error;
fn poll_accept(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Result<Self::Conn, Self::Error>>> {
let res = match ready!(self.project().inner.poll_accept(cx)) {
Some(Ok(stream)) => Some(Ok(ProxyStream::new(stream))),
Some(Err(e)) => Some(Err(e)),
None => None,
};
std::task::Poll::Ready(res)
}
}

View File

@ -0,0 +1,149 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except proxied: streamliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
pin::Pin,
task::{Context, Poll},
};
use futures_util::ready;
use hyper::server::accept::Accept;
use tokio::io::{AsyncRead, AsyncWrite};
use super::ProxyStream;
pin_project_lite::pin_project! {
pub struct MaybeProxyAcceptor<A> {
proxied: bool,
#[pin]
inner: A,
}
}
impl<A> MaybeProxyAcceptor<A> {
#[must_use]
pub const fn new(inner: A, proxied: bool) -> Self {
Self { proxied, inner }
}
}
impl<A> Accept for MaybeProxyAcceptor<A>
where
A: Accept,
{
type Conn = MaybeProxyStream<A::Conn>;
type Error = A::Error;
fn poll_accept(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Result<Self::Conn, Self::Error>>> {
let proj = self.project();
let res = match ready!(proj.inner.poll_accept(cx)) {
Some(Ok(stream)) => Some(Ok(MaybeProxyStream::new(stream, *proj.proxied))),
Some(Err(e)) => Some(Err(e)),
None => None,
};
std::task::Poll::Ready(res)
}
}
pin_project_lite::pin_project! {
#[project = MaybeProxyStreamProj]
pub enum MaybeProxyStream<S> {
Proxied { #[pin] stream: ProxyStream<S> },
NotProxied { #[pin] stream: S },
}
}
impl<S> MaybeProxyStream<S> {
pub const fn new(stream: S, proxied: bool) -> Self {
if proxied {
Self::Proxied {
stream: ProxyStream::new(stream),
}
} else {
Self::NotProxied { stream }
}
}
}
impl<S> AsyncRead for MaybeProxyStream<S>
where
S: AsyncRead,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
match self.project() {
MaybeProxyStreamProj::Proxied { stream } => stream.poll_read(cx, buf),
MaybeProxyStreamProj::NotProxied { stream } => stream.poll_read(cx, buf),
}
}
}
impl<S> AsyncWrite for MaybeProxyStream<S>
where
S: AsyncWrite,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
match self.project() {
MaybeProxyStreamProj::Proxied { stream } => stream.poll_write(cx, buf),
MaybeProxyStreamProj::NotProxied { stream } => stream.poll_write(cx, buf),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
match self.project() {
MaybeProxyStreamProj::Proxied { stream } => stream.poll_flush(cx),
MaybeProxyStreamProj::NotProxied { stream } => stream.poll_flush(cx),
}
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
match self.project() {
MaybeProxyStreamProj::Proxied { stream } => stream.poll_shutdown(cx),
MaybeProxyStreamProj::NotProxied { stream } => stream.poll_shutdown(cx),
}
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<Result<usize, std::io::Error>> {
match self.project() {
MaybeProxyStreamProj::Proxied { stream } => stream.poll_write_vectored(cx, bufs),
MaybeProxyStreamProj::NotProxied { stream } => stream.poll_write_vectored(cx, bufs),
}
}
fn is_write_vectored(&self) -> bool {
match self {
MaybeProxyStream::Proxied { stream } => stream.is_write_vectored(),
MaybeProxyStream::NotProxied { stream } => stream.is_write_vectored(),
}
}
}

View File

@ -0,0 +1,25 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
mod acceptor;
mod maybe;
mod stream;
mod v1;
pub use self::{
acceptor::ProxyAcceptor,
maybe::{MaybeProxyAcceptor, MaybeProxyStream},
stream::ProxyStream,
v1::ProxyProtocolV1Info,
};

View File

@ -0,0 +1,148 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::ops::Deref;
use futures_util::ready;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use super::ProxyProtocolV1Info;
// Max theorical size we need is 108 for proxy protocol v1
const BUF_SIZE: usize = 256;
#[derive(Debug)]
enum ProxyStreamState {
Handshaking {
buffer: [u8; BUF_SIZE],
index: usize,
},
Established(ProxyProtocolV1Info),
}
pin_project_lite::pin_project! {
#[derive(Debug)]
pub struct ProxyStream<S> {
state: ProxyStreamState,
#[pin]
inner: S,
}
}
impl<S> ProxyStream<S> {
pub const fn new(inner: S) -> Self {
Self {
state: ProxyStreamState::Handshaking {
buffer: [0; BUF_SIZE],
index: 0,
},
inner,
}
}
}
impl<S> Deref for ProxyStream<S> {
type Target = S;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl<S> ProxyStream<S> {
pub fn proxy_info(&self) -> Option<&ProxyProtocolV1Info> {
match &self.state {
ProxyStreamState::Handshaking { .. } => None,
ProxyStreamState::Established(info) => Some(info),
}
}
}
impl<S> AsyncRead for ProxyStream<S>
where
S: AsyncRead,
{
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
let proj = self.project();
match proj.state {
ProxyStreamState::Handshaking { buffer, index } => {
let mut buffer = ReadBuf::new(&mut buffer[..]);
buffer.advance(*index);
ready!(proj.inner.poll_read(cx, &mut buffer))?;
let filled = buffer.filled();
*index = filled.len();
match ProxyProtocolV1Info::parse(filled) {
Ok((info, rest)) => {
if buf.remaining() < rest.len() {
// This is highly unlikely, but is better than panicking later.
// If it ever happens, we could introduce a "buffer draining" state
// which drains the inner buffer repeatedly until it's empty
return std::task::Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::Other,
"underlying buffer is too small",
)));
}
buf.put_slice(rest);
*proj.state = ProxyStreamState::Established(info);
std::task::Poll::Ready(Ok(()))
}
Err(e) if e.not_enough_bytes() => std::task::Poll::Ready(Ok(())),
Err(e) => std::task::Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::Other,
e,
))),
}
}
ProxyStreamState::Established(_) => proj.inner.poll_read(cx, buf),
}
}
}
impl<S> AsyncWrite for ProxyStream<S>
where
S: AsyncWrite,
{
fn poll_write(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
let proj = self.project();
match proj.state {
// Hold off writes until the handshake is done
// XXX: is this the right way to do it?
ProxyStreamState::Handshaking { .. } => std::task::Poll::Pending,
ProxyStreamState::Established(_) => proj.inner.poll_write(cx, buf),
}
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
self.project().inner.poll_flush(cx)
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
self.project().inner.poll_shutdown(cx)
}
}

View File

@ -0,0 +1,300 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
net::{AddrParseError, Ipv4Addr, Ipv6Addr, SocketAddr},
num::ParseIntError,
str::Utf8Error,
};
use thiserror::Error;
#[derive(Debug)]
pub enum ProxyProtocolV1Info {
Tcp {
source: SocketAddr,
destination: SocketAddr,
},
Udp {
source: SocketAddr,
destination: SocketAddr,
},
Unknown,
}
#[derive(Error, Debug)]
#[error("Invalid proxy protocol header")]
pub(super) enum ParseError {
#[error("Not enough bytes provided")]
NotEnoughBytes,
NoCrLf,
NoProxyPreamble,
NoProtocol,
InvalidProtocol,
NoSourceAddress,
NoDestinationAddress,
NoSourcePort,
NoDestinationPort,
TooManyFields,
InvalidUtf8(#[from] Utf8Error),
InvalidAddress(#[from] AddrParseError),
InvalidPort(#[from] ParseIntError),
}
impl ParseError {
pub const fn not_enough_bytes(&self) -> bool {
matches!(self, &Self::NotEnoughBytes)
}
}
impl ProxyProtocolV1Info {
#[allow(clippy::too_many_lines)]
pub(super) fn parse(bytes: &[u8]) -> Result<(Self, &[u8]), ParseError> {
use ParseError as E;
// First, check if we *possibly* have enough bytes.
// Minimum is 15: "PROXY UNKNOWN\r\n"
if bytes.len() < 15 {
return Err(E::NotEnoughBytes);
}
// Let's check in the first 108 bytes if we find a CRLF
let crlf = if let Some(crlf) = bytes
.windows(2)
.take(108)
.position(|needle| needle == [0x0D, 0x0A])
{
crlf
} else {
// If not, it might be because we don't have enough bytes
return if bytes.len() < 108 {
Err(E::NotEnoughBytes)
} else {
// Else it's just invalid
Err(E::NoCrLf)
};
};
// Keep the rest of the buffer to pass it to the underlying protocol
let rest = &bytes[crlf + 2..];
// Trim to everything before the CRLF
let bytes = &bytes[..crlf];
let mut it = bytes.splitn(6, |c| c == &b' ');
// Check for the preamble
if it.next() != Some(b"PROXY") {
return Err(E::NoProxyPreamble);
}
let result = match it.next() {
Some(b"TCP4") => {
let source_address: Ipv4Addr =
std::str::from_utf8(it.next().ok_or(E::NoSourceAddress)?)?.parse()?;
let destination_address: Ipv4Addr =
std::str::from_utf8(it.next().ok_or(E::NoDestinationAddress)?)?.parse()?;
let source_port: u16 =
std::str::from_utf8(it.next().ok_or(E::NoSourcePort)?)?.parse()?;
let destination_port: u16 =
std::str::from_utf8(it.next().ok_or(E::NoDestinationPort)?)?.parse()?;
if it.next().is_some() {
return Err(E::TooManyFields);
}
let source = (source_address, source_port).into();
let destination = (destination_address, destination_port).into();
Self::Tcp {
source,
destination,
}
}
Some(b"TCP6") => {
let source_address: Ipv6Addr =
std::str::from_utf8(it.next().ok_or(E::NoSourceAddress)?)?.parse()?;
let destination_address: Ipv6Addr =
std::str::from_utf8(it.next().ok_or(E::NoDestinationAddress)?)?.parse()?;
let source_port: u16 =
std::str::from_utf8(it.next().ok_or(E::NoSourcePort)?)?.parse()?;
let destination_port: u16 =
std::str::from_utf8(it.next().ok_or(E::NoDestinationPort)?)?.parse()?;
if it.next().is_some() {
return Err(E::TooManyFields);
}
let source = (source_address, source_port).into();
let destination = (destination_address, destination_port).into();
Self::Tcp {
source,
destination,
}
}
Some(b"UDP4") => {
let source_address: Ipv4Addr =
std::str::from_utf8(it.next().ok_or(E::NoSourceAddress)?)?.parse()?;
let destination_address: Ipv4Addr =
std::str::from_utf8(it.next().ok_or(E::NoDestinationAddress)?)?.parse()?;
let source_port: u16 =
std::str::from_utf8(it.next().ok_or(E::NoSourcePort)?)?.parse()?;
let destination_port: u16 =
std::str::from_utf8(it.next().ok_or(E::NoDestinationPort)?)?.parse()?;
if it.next().is_some() {
return Err(E::TooManyFields);
}
let source = (source_address, source_port).into();
let destination = (destination_address, destination_port).into();
Self::Udp {
source,
destination,
}
}
Some(b"UDP6") => {
let source_address: Ipv6Addr =
std::str::from_utf8(it.next().ok_or(E::NoSourceAddress)?)?.parse()?;
let destination_address: Ipv6Addr =
std::str::from_utf8(it.next().ok_or(E::NoDestinationAddress)?)?.parse()?;
let source_port: u16 =
std::str::from_utf8(it.next().ok_or(E::NoSourcePort)?)?.parse()?;
let destination_port: u16 =
std::str::from_utf8(it.next().ok_or(E::NoDestinationPort)?)?.parse()?;
if it.next().is_some() {
return Err(E::TooManyFields);
}
let source = (source_address, source_port).into();
let destination = (destination_address, destination_port).into();
Self::Udp {
source,
destination,
}
}
Some(b"UNKNOWN") => Self::Unknown,
Some(_) => return Err(E::InvalidProtocol),
None => return Err(E::NoProtocol),
};
Ok((result, rest))
}
#[must_use]
pub fn is_ipv4(&self) -> bool {
match self {
Self::Udp {
source,
destination,
}
| Self::Tcp {
source,
destination,
} => source.is_ipv4() && destination.is_ipv4(),
Self::Unknown => false,
}
}
#[must_use]
pub fn is_ipv6(&self) -> bool {
match self {
Self::Udp {
source,
destination,
}
| Self::Tcp {
source,
destination,
} => source.is_ipv6() && destination.is_ipv6(),
Self::Unknown => false,
}
}
#[must_use]
pub const fn is_tcp(&self) -> bool {
matches!(self, Self::Tcp { .. })
}
#[must_use]
pub const fn is_udp(&self) -> bool {
matches!(self, Self::Udp { .. })
}
#[must_use]
pub const fn is_unknown(&self) -> bool {
matches!(self, Self::Unknown)
}
#[must_use]
pub const fn source(&self) -> Option<&SocketAddr> {
match self {
Self::Udp { source, .. } | Self::Tcp { source, .. } => Some(source),
Self::Unknown => None,
}
}
#[must_use]
pub const fn destination(&self) -> Option<&SocketAddr> {
match self {
Self::Udp { destination, .. } | Self::Tcp { destination, .. } => Some(destination),
Self::Unknown => None,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse() {
let (info, rest) = ProxyProtocolV1Info::parse(
b"PROXY TCP4 255.255.255.255 255.255.255.255 65535 65535\r\nhello world",
)
.unwrap();
assert_eq!(rest, b"hello world");
assert!(info.is_tcp());
assert!(!info.is_udp());
assert!(!info.is_unknown());
assert!(info.is_ipv4());
assert!(!info.is_ipv6());
let (info, rest) = ProxyProtocolV1Info::parse(
b"PROXY TCP6 ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\nhello world"
).unwrap();
assert_eq!(rest, b"hello world");
assert!(info.is_tcp());
assert!(!info.is_udp());
assert!(!info.is_unknown());
assert!(!info.is_ipv4());
assert!(info.is_ipv6());
let (info, rest) = ProxyProtocolV1Info::parse(b"PROXY UNKNOWN\r\nhello world").unwrap();
assert_eq!(rest, b"hello world");
assert!(!info.is_tcp());
assert!(!info.is_udp());
assert!(info.is_unknown());
assert!(!info.is_ipv4());
assert!(!info.is_ipv6());
let (info, rest) = ProxyProtocolV1Info::parse(
b"PROXY UNKNOWN ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\nhello world"
).unwrap();
assert_eq!(rest, b"hello world");
assert!(!info.is_tcp());
assert!(!info.is_udp());
assert!(info.is_unknown());
assert!(!info.is_ipv4());
assert!(!info.is_ipv6());
}
}

View File

@ -13,7 +13,6 @@
// limitations under the License.
// TODO: Unlink the UNIX socket on drop?
// TODO: Proxy protocol
use std::{
pin::Pin,
@ -87,6 +86,12 @@ impl TryFrom<std::net::TcpListener> for UnixOrTcpListener {
}
impl UnixOrTcpListener {
/// Get the local address of the listener
///
/// # Errors
///
/// Returns an error on rare cases where the underlying [`TcpListener`] or
/// [`UnixListener`] couldn't provide the local address
pub fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
match self {
Self::Unix(listener) => listener.local_addr().map(SocketAddr::from),
@ -111,6 +116,12 @@ pin_project_lite::pin_project! {
}
impl UnixOrTcpConnection {
/// Get the local address of the stream
///
/// # Errors
///
/// Returns an error on rare cases where the underlying [`TcpStream`] or
/// [`UnixStream`] couldn't provide the local address
pub fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
match self {
Self::Unix { stream, .. } => stream.local_addr().map(SocketAddr::from),
@ -118,6 +129,12 @@ impl UnixOrTcpConnection {
}
}
/// Get the remote address of the stream
///
/// # Errors
///
/// Returns an error on rare cases where the underlying [`TcpStream`] or
/// [`UnixStream`] couldn't provide the remote address
pub fn peer_addr(&self) -> Result<SocketAddr, std::io::Error> {
match self {
Self::Unix { stream, .. } => stream.peer_addr().map(SocketAddr::from),