diff --git a/Cargo.lock b/Cargo.lock index c630c2d5..02567d82 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2680,6 +2680,7 @@ dependencies = [ "futures-util", "hyper", "pin-project-lite", + "thiserror", "tokio", "tokio-rustls 0.23.4", "tracing", diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index df820063..09703ac9 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -25,7 +25,7 @@ use mas_config::RootConfig; use mas_email::Mailer; use mas_handlers::{AppState, MatrixHomeserver}; use mas_http::ServerLayer; -use mas_listener::maybe_tls::MaybeTlsAcceptor; +use mas_listener::{maybe_tls::MaybeTlsAcceptor, proxy_protocol::MaybeProxyAcceptor}; use mas_policy::PolicyFactory; use mas_router::UrlBuilder; use mas_storage::MIGRATOR; @@ -271,7 +271,9 @@ impl Options { } }).collect(); - info!("Listening on {addresses:?} with resources {resources:?}", resources = &config.resources); + let additional = if config.proxy_protocol { "(with Proxy Protocol)" } else { "" }; + + info!("Listening on {addresses:?} with resources {resources:?} {additional}", resources = &config.resources); let router = crate::server::build_router(&state, &config.resources).layer(ServerLayer::new(config.name.clone())); @@ -285,6 +287,7 @@ impl Options { .map(Ok) .try_for_each_concurrent(None, move |listener| { let listener = MaybeTlsAcceptor::new(tls_config.clone(), listener); + let listener = MaybeProxyAcceptor::new(listener, config.proxy_protocol); Server::builder(listener) .serve(router.clone().into_make_service()) diff --git a/crates/config/src/sections/http.rs b/crates/config/src/sections/http.rs index 80a8d772..4db7ad6b 100644 --- a/crates/config/src/sections/http.rs +++ b/crates/config/src/sections/http.rs @@ -19,6 +19,7 @@ use async_trait::async_trait; use mas_keystore::PrivateKey; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; +use serde_with::skip_serializing_none; use url::Url; use super::{secrets::PasswordOrFile, ConfigurationSection}; @@ -66,6 +67,7 @@ impl UnixOrTcp { } /// Configuration of a single listener +#[skip_serializing_none] #[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)] #[serde(untagged)] pub enum BindConfig { @@ -74,6 +76,7 @@ pub enum BindConfig { /// Host on which to listen. /// /// Defaults to listening on all addresses + #[serde(default)] host: Option, /// Port on which to listen. @@ -107,6 +110,7 @@ pub enum BindConfig { /// Index of the file descriptor. Note that this is offseted by 3 /// because of the standard input/output sockets, so setting /// here a value of `0` will grab the file descriptor `3` + #[serde(default)] fd: usize, /// Whether the socket is a TCP socket or a UNIX domain socket. Defaults @@ -131,6 +135,7 @@ pub enum CertificateOrFile { } /// Configuration related to TLS on a listener +#[skip_serializing_none] #[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)] pub struct TlsConfig { /// PEM-encoded X509 certificate chain @@ -214,6 +219,7 @@ impl TlsConfig { } /// HTTP resources to mount +#[skip_serializing_none] #[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)] #[serde(tag = "name", rename_all = "lowercase")] pub enum Resource { @@ -245,10 +251,12 @@ pub enum Resource { } /// Configuration of a listener +#[skip_serializing_none] #[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)] pub struct ListenerConfig { /// A unique name for this listener which will be shown in traces and in /// metrics labels + #[serde(default)] pub name: Option, /// List of resources to mount @@ -257,7 +265,12 @@ pub struct ListenerConfig { /// List of sockets to bind pub binds: Vec, + /// Accept HAProxy's Proxy Protocol V1 + #[serde(default)] + pub proxy_protocol: bool, + /// If set, makes the listener use TLS with the provided certificate and key + #[serde(default)] pub tls: Option, } @@ -286,6 +299,7 @@ impl Default for HttpConfig { Resource::Static { web_root: None }, ], tls: None, + proxy_protocol: false, binds: vec![BindConfig::Address { address: "[::]:8080".into(), }], @@ -294,6 +308,7 @@ impl Default for HttpConfig { name: Some("internal".to_owned()), resources: vec![Resource::Health], tls: None, + proxy_protocol: false, binds: vec![BindConfig::Address { address: "localhost:8081".into(), }], diff --git a/crates/listener/Cargo.toml b/crates/listener/Cargo.toml index 17092d84..ff677159 100644 --- a/crates/listener/Cargo.toml +++ b/crates/listener/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" license = "Apache-2.0" [dependencies] +thiserror = "1.0.37" tokio = { version = "1.21.2", features = ["net"] } pin-project-lite = "0.2.9" hyper = { version = "0.14.20", features = ["server"] } diff --git a/crates/listener/src/lib.rs b/crates/listener/src/lib.rs index ba110718..a5ae777d 100644 --- a/crates/listener/src/lib.rs +++ b/crates/listener/src/lib.rs @@ -12,5 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. +#![forbid(unsafe_code)] +#![deny( + clippy::all, + clippy::str_to_string, + rustdoc::missing_crate_level_docs, + rustdoc::broken_intra_doc_links +)] +#![warn(clippy::pedantic)] +#![allow(clippy::module_name_repetitions)] + pub mod maybe_tls; +pub mod proxy_protocol; pub mod unix_or_tcp; diff --git a/crates/listener/src/proxy_protocol/acceptor.rs b/crates/listener/src/proxy_protocol/acceptor.rs new file mode 100644 index 00000000..d10bd392 --- /dev/null +++ b/crates/listener/src/proxy_protocol/acceptor.rs @@ -0,0 +1,52 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use futures_util::ready; +use hyper::server::accept::Accept; + +use super::ProxyStream; + +pin_project_lite::pin_project! { + pub struct ProxyAcceptor { + #[pin] + inner: A, + } +} + +impl ProxyAcceptor { + pub const fn new(inner: A) -> Self { + Self { inner } + } +} + +impl Accept for ProxyAcceptor +where + A: Accept, +{ + type Conn = ProxyStream; + type Error = A::Error; + + fn poll_accept( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll>> { + let res = match ready!(self.project().inner.poll_accept(cx)) { + Some(Ok(stream)) => Some(Ok(ProxyStream::new(stream))), + Some(Err(e)) => Some(Err(e)), + None => None, + }; + + std::task::Poll::Ready(res) + } +} diff --git a/crates/listener/src/proxy_protocol/maybe.rs b/crates/listener/src/proxy_protocol/maybe.rs new file mode 100644 index 00000000..7e261009 --- /dev/null +++ b/crates/listener/src/proxy_protocol/maybe.rs @@ -0,0 +1,149 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except proxied: streamliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use futures_util::ready; +use hyper::server::accept::Accept; +use tokio::io::{AsyncRead, AsyncWrite}; + +use super::ProxyStream; + +pin_project_lite::pin_project! { + pub struct MaybeProxyAcceptor { + proxied: bool, + + #[pin] + inner: A, + } +} + +impl MaybeProxyAcceptor { + #[must_use] + pub const fn new(inner: A, proxied: bool) -> Self { + Self { proxied, inner } + } +} + +impl Accept for MaybeProxyAcceptor +where + A: Accept, +{ + type Conn = MaybeProxyStream; + type Error = A::Error; + + fn poll_accept( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll>> { + let proj = self.project(); + let res = match ready!(proj.inner.poll_accept(cx)) { + Some(Ok(stream)) => Some(Ok(MaybeProxyStream::new(stream, *proj.proxied))), + Some(Err(e)) => Some(Err(e)), + None => None, + }; + + std::task::Poll::Ready(res) + } +} + +pin_project_lite::pin_project! { + #[project = MaybeProxyStreamProj] + pub enum MaybeProxyStream { + Proxied { #[pin] stream: ProxyStream }, + NotProxied { #[pin] stream: S }, + } +} + +impl MaybeProxyStream { + pub const fn new(stream: S, proxied: bool) -> Self { + if proxied { + Self::Proxied { + stream: ProxyStream::new(stream), + } + } else { + Self::NotProxied { stream } + } + } +} + +impl AsyncRead for MaybeProxyStream +where + S: AsyncRead, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + match self.project() { + MaybeProxyStreamProj::Proxied { stream } => stream.poll_read(cx, buf), + MaybeProxyStreamProj::NotProxied { stream } => stream.poll_read(cx, buf), + } + } +} + +impl AsyncWrite for MaybeProxyStream +where + S: AsyncWrite, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.project() { + MaybeProxyStreamProj::Proxied { stream } => stream.poll_write(cx, buf), + MaybeProxyStreamProj::NotProxied { stream } => stream.poll_write(cx, buf), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project() { + MaybeProxyStreamProj::Proxied { stream } => stream.poll_flush(cx), + MaybeProxyStreamProj::NotProxied { stream } => stream.poll_flush(cx), + } + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + match self.project() { + MaybeProxyStreamProj::Proxied { stream } => stream.poll_shutdown(cx), + MaybeProxyStreamProj::NotProxied { stream } => stream.poll_shutdown(cx), + } + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + match self.project() { + MaybeProxyStreamProj::Proxied { stream } => stream.poll_write_vectored(cx, bufs), + MaybeProxyStreamProj::NotProxied { stream } => stream.poll_write_vectored(cx, bufs), + } + } + + fn is_write_vectored(&self) -> bool { + match self { + MaybeProxyStream::Proxied { stream } => stream.is_write_vectored(), + MaybeProxyStream::NotProxied { stream } => stream.is_write_vectored(), + } + } +} diff --git a/crates/listener/src/proxy_protocol/mod.rs b/crates/listener/src/proxy_protocol/mod.rs new file mode 100644 index 00000000..39ee2d0f --- /dev/null +++ b/crates/listener/src/proxy_protocol/mod.rs @@ -0,0 +1,25 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +mod acceptor; +mod maybe; +mod stream; +mod v1; + +pub use self::{ + acceptor::ProxyAcceptor, + maybe::{MaybeProxyAcceptor, MaybeProxyStream}, + stream::ProxyStream, + v1::ProxyProtocolV1Info, +}; diff --git a/crates/listener/src/proxy_protocol/stream.rs b/crates/listener/src/proxy_protocol/stream.rs new file mode 100644 index 00000000..a59b1829 --- /dev/null +++ b/crates/listener/src/proxy_protocol/stream.rs @@ -0,0 +1,148 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::ops::Deref; + +use futures_util::ready; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +use super::ProxyProtocolV1Info; + +// Max theorical size we need is 108 for proxy protocol v1 +const BUF_SIZE: usize = 256; + +#[derive(Debug)] +enum ProxyStreamState { + Handshaking { + buffer: [u8; BUF_SIZE], + index: usize, + }, + Established(ProxyProtocolV1Info), +} + +pin_project_lite::pin_project! { + #[derive(Debug)] + pub struct ProxyStream { + state: ProxyStreamState, + + #[pin] + inner: S, + } +} + +impl ProxyStream { + pub const fn new(inner: S) -> Self { + Self { + state: ProxyStreamState::Handshaking { + buffer: [0; BUF_SIZE], + index: 0, + }, + inner, + } + } +} + +impl Deref for ProxyStream { + type Target = S; + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl ProxyStream { + pub fn proxy_info(&self) -> Option<&ProxyProtocolV1Info> { + match &self.state { + ProxyStreamState::Handshaking { .. } => None, + ProxyStreamState::Established(info) => Some(info), + } + } +} + +impl AsyncRead for ProxyStream +where + S: AsyncRead, +{ + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + let proj = self.project(); + match proj.state { + ProxyStreamState::Handshaking { buffer, index } => { + let mut buffer = ReadBuf::new(&mut buffer[..]); + buffer.advance(*index); + ready!(proj.inner.poll_read(cx, &mut buffer))?; + let filled = buffer.filled(); + *index = filled.len(); + + match ProxyProtocolV1Info::parse(filled) { + Ok((info, rest)) => { + if buf.remaining() < rest.len() { + // This is highly unlikely, but is better than panicking later. + // If it ever happens, we could introduce a "buffer draining" state + // which drains the inner buffer repeatedly until it's empty + return std::task::Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::Other, + "underlying buffer is too small", + ))); + } + buf.put_slice(rest); + *proj.state = ProxyStreamState::Established(info); + std::task::Poll::Ready(Ok(())) + } + Err(e) if e.not_enough_bytes() => std::task::Poll::Ready(Ok(())), + Err(e) => std::task::Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::Other, + e, + ))), + } + } + ProxyStreamState::Established(_) => proj.inner.poll_read(cx, buf), + } + } +} + +impl AsyncWrite for ProxyStream +where + S: AsyncWrite, +{ + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + let proj = self.project(); + match proj.state { + // Hold off writes until the handshake is done + // XXX: is this the right way to do it? + ProxyStreamState::Handshaking { .. } => std::task::Poll::Pending, + ProxyStreamState::Established(_) => proj.inner.poll_write(cx, buf), + } + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.project().inner.poll_flush(cx) + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.project().inner.poll_shutdown(cx) + } +} diff --git a/crates/listener/src/proxy_protocol/v1.rs b/crates/listener/src/proxy_protocol/v1.rs new file mode 100644 index 00000000..3789a13c --- /dev/null +++ b/crates/listener/src/proxy_protocol/v1.rs @@ -0,0 +1,300 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{ + net::{AddrParseError, Ipv4Addr, Ipv6Addr, SocketAddr}, + num::ParseIntError, + str::Utf8Error, +}; + +use thiserror::Error; + +#[derive(Debug)] +pub enum ProxyProtocolV1Info { + Tcp { + source: SocketAddr, + destination: SocketAddr, + }, + Udp { + source: SocketAddr, + destination: SocketAddr, + }, + Unknown, +} + +#[derive(Error, Debug)] +#[error("Invalid proxy protocol header")] +pub(super) enum ParseError { + #[error("Not enough bytes provided")] + NotEnoughBytes, + NoCrLf, + NoProxyPreamble, + NoProtocol, + InvalidProtocol, + NoSourceAddress, + NoDestinationAddress, + NoSourcePort, + NoDestinationPort, + TooManyFields, + InvalidUtf8(#[from] Utf8Error), + InvalidAddress(#[from] AddrParseError), + InvalidPort(#[from] ParseIntError), +} + +impl ParseError { + pub const fn not_enough_bytes(&self) -> bool { + matches!(self, &Self::NotEnoughBytes) + } +} + +impl ProxyProtocolV1Info { + #[allow(clippy::too_many_lines)] + pub(super) fn parse(bytes: &[u8]) -> Result<(Self, &[u8]), ParseError> { + use ParseError as E; + // First, check if we *possibly* have enough bytes. + // Minimum is 15: "PROXY UNKNOWN\r\n" + + if bytes.len() < 15 { + return Err(E::NotEnoughBytes); + } + + // Let's check in the first 108 bytes if we find a CRLF + let crlf = if let Some(crlf) = bytes + .windows(2) + .take(108) + .position(|needle| needle == [0x0D, 0x0A]) + { + crlf + } else { + // If not, it might be because we don't have enough bytes + return if bytes.len() < 108 { + Err(E::NotEnoughBytes) + } else { + // Else it's just invalid + Err(E::NoCrLf) + }; + }; + + // Keep the rest of the buffer to pass it to the underlying protocol + let rest = &bytes[crlf + 2..]; + // Trim to everything before the CRLF + let bytes = &bytes[..crlf]; + + let mut it = bytes.splitn(6, |c| c == &b' '); + // Check for the preamble + if it.next() != Some(b"PROXY") { + return Err(E::NoProxyPreamble); + } + + let result = match it.next() { + Some(b"TCP4") => { + let source_address: Ipv4Addr = + std::str::from_utf8(it.next().ok_or(E::NoSourceAddress)?)?.parse()?; + let destination_address: Ipv4Addr = + std::str::from_utf8(it.next().ok_or(E::NoDestinationAddress)?)?.parse()?; + let source_port: u16 = + std::str::from_utf8(it.next().ok_or(E::NoSourcePort)?)?.parse()?; + let destination_port: u16 = + std::str::from_utf8(it.next().ok_or(E::NoDestinationPort)?)?.parse()?; + if it.next().is_some() { + return Err(E::TooManyFields); + } + + let source = (source_address, source_port).into(); + let destination = (destination_address, destination_port).into(); + + Self::Tcp { + source, + destination, + } + } + Some(b"TCP6") => { + let source_address: Ipv6Addr = + std::str::from_utf8(it.next().ok_or(E::NoSourceAddress)?)?.parse()?; + let destination_address: Ipv6Addr = + std::str::from_utf8(it.next().ok_or(E::NoDestinationAddress)?)?.parse()?; + let source_port: u16 = + std::str::from_utf8(it.next().ok_or(E::NoSourcePort)?)?.parse()?; + let destination_port: u16 = + std::str::from_utf8(it.next().ok_or(E::NoDestinationPort)?)?.parse()?; + if it.next().is_some() { + return Err(E::TooManyFields); + } + + let source = (source_address, source_port).into(); + let destination = (destination_address, destination_port).into(); + + Self::Tcp { + source, + destination, + } + } + Some(b"UDP4") => { + let source_address: Ipv4Addr = + std::str::from_utf8(it.next().ok_or(E::NoSourceAddress)?)?.parse()?; + let destination_address: Ipv4Addr = + std::str::from_utf8(it.next().ok_or(E::NoDestinationAddress)?)?.parse()?; + let source_port: u16 = + std::str::from_utf8(it.next().ok_or(E::NoSourcePort)?)?.parse()?; + let destination_port: u16 = + std::str::from_utf8(it.next().ok_or(E::NoDestinationPort)?)?.parse()?; + if it.next().is_some() { + return Err(E::TooManyFields); + } + + let source = (source_address, source_port).into(); + let destination = (destination_address, destination_port).into(); + + Self::Udp { + source, + destination, + } + } + Some(b"UDP6") => { + let source_address: Ipv6Addr = + std::str::from_utf8(it.next().ok_or(E::NoSourceAddress)?)?.parse()?; + let destination_address: Ipv6Addr = + std::str::from_utf8(it.next().ok_or(E::NoDestinationAddress)?)?.parse()?; + let source_port: u16 = + std::str::from_utf8(it.next().ok_or(E::NoSourcePort)?)?.parse()?; + let destination_port: u16 = + std::str::from_utf8(it.next().ok_or(E::NoDestinationPort)?)?.parse()?; + if it.next().is_some() { + return Err(E::TooManyFields); + } + + let source = (source_address, source_port).into(); + let destination = (destination_address, destination_port).into(); + + Self::Udp { + source, + destination, + } + } + Some(b"UNKNOWN") => Self::Unknown, + Some(_) => return Err(E::InvalidProtocol), + None => return Err(E::NoProtocol), + }; + + Ok((result, rest)) + } + + #[must_use] + pub fn is_ipv4(&self) -> bool { + match self { + Self::Udp { + source, + destination, + } + | Self::Tcp { + source, + destination, + } => source.is_ipv4() && destination.is_ipv4(), + Self::Unknown => false, + } + } + + #[must_use] + pub fn is_ipv6(&self) -> bool { + match self { + Self::Udp { + source, + destination, + } + | Self::Tcp { + source, + destination, + } => source.is_ipv6() && destination.is_ipv6(), + Self::Unknown => false, + } + } + + #[must_use] + pub const fn is_tcp(&self) -> bool { + matches!(self, Self::Tcp { .. }) + } + + #[must_use] + pub const fn is_udp(&self) -> bool { + matches!(self, Self::Udp { .. }) + } + + #[must_use] + pub const fn is_unknown(&self) -> bool { + matches!(self, Self::Unknown) + } + + #[must_use] + pub const fn source(&self) -> Option<&SocketAddr> { + match self { + Self::Udp { source, .. } | Self::Tcp { source, .. } => Some(source), + Self::Unknown => None, + } + } + + #[must_use] + pub const fn destination(&self) -> Option<&SocketAddr> { + match self { + Self::Udp { destination, .. } | Self::Tcp { destination, .. } => Some(destination), + Self::Unknown => None, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse() { + let (info, rest) = ProxyProtocolV1Info::parse( + b"PROXY TCP4 255.255.255.255 255.255.255.255 65535 65535\r\nhello world", + ) + .unwrap(); + assert_eq!(rest, b"hello world"); + assert!(info.is_tcp()); + assert!(!info.is_udp()); + assert!(!info.is_unknown()); + assert!(info.is_ipv4()); + assert!(!info.is_ipv6()); + + let (info, rest) = ProxyProtocolV1Info::parse( + b"PROXY TCP6 ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\nhello world" + ).unwrap(); + assert_eq!(rest, b"hello world"); + assert!(info.is_tcp()); + assert!(!info.is_udp()); + assert!(!info.is_unknown()); + assert!(!info.is_ipv4()); + assert!(info.is_ipv6()); + + let (info, rest) = ProxyProtocolV1Info::parse(b"PROXY UNKNOWN\r\nhello world").unwrap(); + assert_eq!(rest, b"hello world"); + assert!(!info.is_tcp()); + assert!(!info.is_udp()); + assert!(info.is_unknown()); + assert!(!info.is_ipv4()); + assert!(!info.is_ipv6()); + + let (info, rest) = ProxyProtocolV1Info::parse( + b"PROXY UNKNOWN ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\nhello world" + ).unwrap(); + assert_eq!(rest, b"hello world"); + assert!(!info.is_tcp()); + assert!(!info.is_udp()); + assert!(info.is_unknown()); + assert!(!info.is_ipv4()); + assert!(!info.is_ipv6()); + } +} diff --git a/crates/listener/src/unix_or_tcp.rs b/crates/listener/src/unix_or_tcp.rs index 86f8e73c..1357bb7d 100644 --- a/crates/listener/src/unix_or_tcp.rs +++ b/crates/listener/src/unix_or_tcp.rs @@ -13,7 +13,6 @@ // limitations under the License. // TODO: Unlink the UNIX socket on drop? -// TODO: Proxy protocol use std::{ pin::Pin, @@ -87,6 +86,12 @@ impl TryFrom for UnixOrTcpListener { } impl UnixOrTcpListener { + /// Get the local address of the listener + /// + /// # Errors + /// + /// Returns an error on rare cases where the underlying [`TcpListener`] or + /// [`UnixListener`] couldn't provide the local address pub fn local_addr(&self) -> Result { match self { Self::Unix(listener) => listener.local_addr().map(SocketAddr::from), @@ -111,6 +116,12 @@ pin_project_lite::pin_project! { } impl UnixOrTcpConnection { + /// Get the local address of the stream + /// + /// # Errors + /// + /// Returns an error on rare cases where the underlying [`TcpStream`] or + /// [`UnixStream`] couldn't provide the local address pub fn local_addr(&self) -> Result { match self { Self::Unix { stream, .. } => stream.local_addr().map(SocketAddr::from), @@ -118,6 +129,12 @@ impl UnixOrTcpConnection { } } + /// Get the remote address of the stream + /// + /// # Errors + /// + /// Returns an error on rare cases where the underlying [`TcpStream`] or + /// [`UnixStream`] couldn't provide the remote address pub fn peer_addr(&self) -> Result { match self { Self::Unix { stream, .. } => stream.peer_addr().map(SocketAddr::from),