You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-11-20 12:02:22 +03:00
HAProxy's Proxy Protocol acceptor
This commit is contained in:
@@ -12,5 +12,16 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#![forbid(unsafe_code)]
|
||||
#![deny(
|
||||
clippy::all,
|
||||
clippy::str_to_string,
|
||||
rustdoc::missing_crate_level_docs,
|
||||
rustdoc::broken_intra_doc_links
|
||||
)]
|
||||
#![warn(clippy::pedantic)]
|
||||
#![allow(clippy::module_name_repetitions)]
|
||||
|
||||
pub mod maybe_tls;
|
||||
pub mod proxy_protocol;
|
||||
pub mod unix_or_tcp;
|
||||
|
||||
52
crates/listener/src/proxy_protocol/acceptor.rs
Normal file
52
crates/listener/src/proxy_protocol/acceptor.rs
Normal file
@@ -0,0 +1,52 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use futures_util::ready;
|
||||
use hyper::server::accept::Accept;
|
||||
|
||||
use super::ProxyStream;
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
pub struct ProxyAcceptor<A> {
|
||||
#[pin]
|
||||
inner: A,
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> ProxyAcceptor<A> {
|
||||
pub const fn new(inner: A) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> Accept for ProxyAcceptor<A>
|
||||
where
|
||||
A: Accept,
|
||||
{
|
||||
type Conn = ProxyStream<A::Conn>;
|
||||
type Error = A::Error;
|
||||
|
||||
fn poll_accept(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Option<Result<Self::Conn, Self::Error>>> {
|
||||
let res = match ready!(self.project().inner.poll_accept(cx)) {
|
||||
Some(Ok(stream)) => Some(Ok(ProxyStream::new(stream))),
|
||||
Some(Err(e)) => Some(Err(e)),
|
||||
None => None,
|
||||
};
|
||||
|
||||
std::task::Poll::Ready(res)
|
||||
}
|
||||
}
|
||||
149
crates/listener/src/proxy_protocol/maybe.rs
Normal file
149
crates/listener/src/proxy_protocol/maybe.rs
Normal file
@@ -0,0 +1,149 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except proxied: streamliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use futures_util::ready;
|
||||
use hyper::server::accept::Accept;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use super::ProxyStream;
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
pub struct MaybeProxyAcceptor<A> {
|
||||
proxied: bool,
|
||||
|
||||
#[pin]
|
||||
inner: A,
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> MaybeProxyAcceptor<A> {
|
||||
#[must_use]
|
||||
pub const fn new(inner: A, proxied: bool) -> Self {
|
||||
Self { proxied, inner }
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> Accept for MaybeProxyAcceptor<A>
|
||||
where
|
||||
A: Accept,
|
||||
{
|
||||
type Conn = MaybeProxyStream<A::Conn>;
|
||||
type Error = A::Error;
|
||||
|
||||
fn poll_accept(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Option<Result<Self::Conn, Self::Error>>> {
|
||||
let proj = self.project();
|
||||
let res = match ready!(proj.inner.poll_accept(cx)) {
|
||||
Some(Ok(stream)) => Some(Ok(MaybeProxyStream::new(stream, *proj.proxied))),
|
||||
Some(Err(e)) => Some(Err(e)),
|
||||
None => None,
|
||||
};
|
||||
|
||||
std::task::Poll::Ready(res)
|
||||
}
|
||||
}
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
#[project = MaybeProxyStreamProj]
|
||||
pub enum MaybeProxyStream<S> {
|
||||
Proxied { #[pin] stream: ProxyStream<S> },
|
||||
NotProxied { #[pin] stream: S },
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> MaybeProxyStream<S> {
|
||||
pub const fn new(stream: S, proxied: bool) -> Self {
|
||||
if proxied {
|
||||
Self::Proxied {
|
||||
stream: ProxyStream::new(stream),
|
||||
}
|
||||
} else {
|
||||
Self::NotProxied { stream }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> AsyncRead for MaybeProxyStream<S>
|
||||
where
|
||||
S: AsyncRead,
|
||||
{
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
match self.project() {
|
||||
MaybeProxyStreamProj::Proxied { stream } => stream.poll_read(cx, buf),
|
||||
MaybeProxyStreamProj::NotProxied { stream } => stream.poll_read(cx, buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> AsyncWrite for MaybeProxyStream<S>
|
||||
where
|
||||
S: AsyncWrite,
|
||||
{
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, std::io::Error>> {
|
||||
match self.project() {
|
||||
MaybeProxyStreamProj::Proxied { stream } => stream.poll_write(cx, buf),
|
||||
MaybeProxyStreamProj::NotProxied { stream } => stream.poll_write(cx, buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
|
||||
match self.project() {
|
||||
MaybeProxyStreamProj::Proxied { stream } => stream.poll_flush(cx),
|
||||
MaybeProxyStreamProj::NotProxied { stream } => stream.poll_flush(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
match self.project() {
|
||||
MaybeProxyStreamProj::Proxied { stream } => stream.poll_shutdown(cx),
|
||||
MaybeProxyStreamProj::NotProxied { stream } => stream.poll_shutdown(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_write_vectored(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
bufs: &[std::io::IoSlice<'_>],
|
||||
) -> Poll<Result<usize, std::io::Error>> {
|
||||
match self.project() {
|
||||
MaybeProxyStreamProj::Proxied { stream } => stream.poll_write_vectored(cx, bufs),
|
||||
MaybeProxyStreamProj::NotProxied { stream } => stream.poll_write_vectored(cx, bufs),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_write_vectored(&self) -> bool {
|
||||
match self {
|
||||
MaybeProxyStream::Proxied { stream } => stream.is_write_vectored(),
|
||||
MaybeProxyStream::NotProxied { stream } => stream.is_write_vectored(),
|
||||
}
|
||||
}
|
||||
}
|
||||
25
crates/listener/src/proxy_protocol/mod.rs
Normal file
25
crates/listener/src/proxy_protocol/mod.rs
Normal file
@@ -0,0 +1,25 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
mod acceptor;
|
||||
mod maybe;
|
||||
mod stream;
|
||||
mod v1;
|
||||
|
||||
pub use self::{
|
||||
acceptor::ProxyAcceptor,
|
||||
maybe::{MaybeProxyAcceptor, MaybeProxyStream},
|
||||
stream::ProxyStream,
|
||||
v1::ProxyProtocolV1Info,
|
||||
};
|
||||
148
crates/listener/src/proxy_protocol/stream.rs
Normal file
148
crates/listener/src/proxy_protocol/stream.rs
Normal file
@@ -0,0 +1,148 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::ops::Deref;
|
||||
|
||||
use futures_util::ready;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
|
||||
use super::ProxyProtocolV1Info;
|
||||
|
||||
// Max theorical size we need is 108 for proxy protocol v1
|
||||
const BUF_SIZE: usize = 256;
|
||||
|
||||
#[derive(Debug)]
|
||||
enum ProxyStreamState {
|
||||
Handshaking {
|
||||
buffer: [u8; BUF_SIZE],
|
||||
index: usize,
|
||||
},
|
||||
Established(ProxyProtocolV1Info),
|
||||
}
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
#[derive(Debug)]
|
||||
pub struct ProxyStream<S> {
|
||||
state: ProxyStreamState,
|
||||
|
||||
#[pin]
|
||||
inner: S,
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> ProxyStream<S> {
|
||||
pub const fn new(inner: S) -> Self {
|
||||
Self {
|
||||
state: ProxyStreamState::Handshaking {
|
||||
buffer: [0; BUF_SIZE],
|
||||
index: 0,
|
||||
},
|
||||
inner,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Deref for ProxyStream<S> {
|
||||
type Target = S;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> ProxyStream<S> {
|
||||
pub fn proxy_info(&self) -> Option<&ProxyProtocolV1Info> {
|
||||
match &self.state {
|
||||
ProxyStreamState::Handshaking { .. } => None,
|
||||
ProxyStreamState::Established(info) => Some(info),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> AsyncRead for ProxyStream<S>
|
||||
where
|
||||
S: AsyncRead,
|
||||
{
|
||||
fn poll_read(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> std::task::Poll<std::io::Result<()>> {
|
||||
let proj = self.project();
|
||||
match proj.state {
|
||||
ProxyStreamState::Handshaking { buffer, index } => {
|
||||
let mut buffer = ReadBuf::new(&mut buffer[..]);
|
||||
buffer.advance(*index);
|
||||
ready!(proj.inner.poll_read(cx, &mut buffer))?;
|
||||
let filled = buffer.filled();
|
||||
*index = filled.len();
|
||||
|
||||
match ProxyProtocolV1Info::parse(filled) {
|
||||
Ok((info, rest)) => {
|
||||
if buf.remaining() < rest.len() {
|
||||
// This is highly unlikely, but is better than panicking later.
|
||||
// If it ever happens, we could introduce a "buffer draining" state
|
||||
// which drains the inner buffer repeatedly until it's empty
|
||||
return std::task::Poll::Ready(Err(std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
"underlying buffer is too small",
|
||||
)));
|
||||
}
|
||||
buf.put_slice(rest);
|
||||
*proj.state = ProxyStreamState::Established(info);
|
||||
std::task::Poll::Ready(Ok(()))
|
||||
}
|
||||
Err(e) if e.not_enough_bytes() => std::task::Poll::Ready(Ok(())),
|
||||
Err(e) => std::task::Poll::Ready(Err(std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
e,
|
||||
))),
|
||||
}
|
||||
}
|
||||
ProxyStreamState::Established(_) => proj.inner.poll_read(cx, buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> AsyncWrite for ProxyStream<S>
|
||||
where
|
||||
S: AsyncWrite,
|
||||
{
|
||||
fn poll_write(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> std::task::Poll<Result<usize, std::io::Error>> {
|
||||
let proj = self.project();
|
||||
match proj.state {
|
||||
// Hold off writes until the handshake is done
|
||||
// XXX: is this the right way to do it?
|
||||
ProxyStreamState::Handshaking { .. } => std::task::Poll::Pending,
|
||||
ProxyStreamState::Established(_) => proj.inner.poll_write(cx, buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), std::io::Error>> {
|
||||
self.project().inner.poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), std::io::Error>> {
|
||||
self.project().inner.poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
300
crates/listener/src/proxy_protocol/v1.rs
Normal file
300
crates/listener/src/proxy_protocol/v1.rs
Normal file
@@ -0,0 +1,300 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{
|
||||
net::{AddrParseError, Ipv4Addr, Ipv6Addr, SocketAddr},
|
||||
num::ParseIntError,
|
||||
str::Utf8Error,
|
||||
};
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ProxyProtocolV1Info {
|
||||
Tcp {
|
||||
source: SocketAddr,
|
||||
destination: SocketAddr,
|
||||
},
|
||||
Udp {
|
||||
source: SocketAddr,
|
||||
destination: SocketAddr,
|
||||
},
|
||||
Unknown,
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
#[error("Invalid proxy protocol header")]
|
||||
pub(super) enum ParseError {
|
||||
#[error("Not enough bytes provided")]
|
||||
NotEnoughBytes,
|
||||
NoCrLf,
|
||||
NoProxyPreamble,
|
||||
NoProtocol,
|
||||
InvalidProtocol,
|
||||
NoSourceAddress,
|
||||
NoDestinationAddress,
|
||||
NoSourcePort,
|
||||
NoDestinationPort,
|
||||
TooManyFields,
|
||||
InvalidUtf8(#[from] Utf8Error),
|
||||
InvalidAddress(#[from] AddrParseError),
|
||||
InvalidPort(#[from] ParseIntError),
|
||||
}
|
||||
|
||||
impl ParseError {
|
||||
pub const fn not_enough_bytes(&self) -> bool {
|
||||
matches!(self, &Self::NotEnoughBytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl ProxyProtocolV1Info {
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub(super) fn parse(bytes: &[u8]) -> Result<(Self, &[u8]), ParseError> {
|
||||
use ParseError as E;
|
||||
// First, check if we *possibly* have enough bytes.
|
||||
// Minimum is 15: "PROXY UNKNOWN\r\n"
|
||||
|
||||
if bytes.len() < 15 {
|
||||
return Err(E::NotEnoughBytes);
|
||||
}
|
||||
|
||||
// Let's check in the first 108 bytes if we find a CRLF
|
||||
let crlf = if let Some(crlf) = bytes
|
||||
.windows(2)
|
||||
.take(108)
|
||||
.position(|needle| needle == [0x0D, 0x0A])
|
||||
{
|
||||
crlf
|
||||
} else {
|
||||
// If not, it might be because we don't have enough bytes
|
||||
return if bytes.len() < 108 {
|
||||
Err(E::NotEnoughBytes)
|
||||
} else {
|
||||
// Else it's just invalid
|
||||
Err(E::NoCrLf)
|
||||
};
|
||||
};
|
||||
|
||||
// Keep the rest of the buffer to pass it to the underlying protocol
|
||||
let rest = &bytes[crlf + 2..];
|
||||
// Trim to everything before the CRLF
|
||||
let bytes = &bytes[..crlf];
|
||||
|
||||
let mut it = bytes.splitn(6, |c| c == &b' ');
|
||||
// Check for the preamble
|
||||
if it.next() != Some(b"PROXY") {
|
||||
return Err(E::NoProxyPreamble);
|
||||
}
|
||||
|
||||
let result = match it.next() {
|
||||
Some(b"TCP4") => {
|
||||
let source_address: Ipv4Addr =
|
||||
std::str::from_utf8(it.next().ok_or(E::NoSourceAddress)?)?.parse()?;
|
||||
let destination_address: Ipv4Addr =
|
||||
std::str::from_utf8(it.next().ok_or(E::NoDestinationAddress)?)?.parse()?;
|
||||
let source_port: u16 =
|
||||
std::str::from_utf8(it.next().ok_or(E::NoSourcePort)?)?.parse()?;
|
||||
let destination_port: u16 =
|
||||
std::str::from_utf8(it.next().ok_or(E::NoDestinationPort)?)?.parse()?;
|
||||
if it.next().is_some() {
|
||||
return Err(E::TooManyFields);
|
||||
}
|
||||
|
||||
let source = (source_address, source_port).into();
|
||||
let destination = (destination_address, destination_port).into();
|
||||
|
||||
Self::Tcp {
|
||||
source,
|
||||
destination,
|
||||
}
|
||||
}
|
||||
Some(b"TCP6") => {
|
||||
let source_address: Ipv6Addr =
|
||||
std::str::from_utf8(it.next().ok_or(E::NoSourceAddress)?)?.parse()?;
|
||||
let destination_address: Ipv6Addr =
|
||||
std::str::from_utf8(it.next().ok_or(E::NoDestinationAddress)?)?.parse()?;
|
||||
let source_port: u16 =
|
||||
std::str::from_utf8(it.next().ok_or(E::NoSourcePort)?)?.parse()?;
|
||||
let destination_port: u16 =
|
||||
std::str::from_utf8(it.next().ok_or(E::NoDestinationPort)?)?.parse()?;
|
||||
if it.next().is_some() {
|
||||
return Err(E::TooManyFields);
|
||||
}
|
||||
|
||||
let source = (source_address, source_port).into();
|
||||
let destination = (destination_address, destination_port).into();
|
||||
|
||||
Self::Tcp {
|
||||
source,
|
||||
destination,
|
||||
}
|
||||
}
|
||||
Some(b"UDP4") => {
|
||||
let source_address: Ipv4Addr =
|
||||
std::str::from_utf8(it.next().ok_or(E::NoSourceAddress)?)?.parse()?;
|
||||
let destination_address: Ipv4Addr =
|
||||
std::str::from_utf8(it.next().ok_or(E::NoDestinationAddress)?)?.parse()?;
|
||||
let source_port: u16 =
|
||||
std::str::from_utf8(it.next().ok_or(E::NoSourcePort)?)?.parse()?;
|
||||
let destination_port: u16 =
|
||||
std::str::from_utf8(it.next().ok_or(E::NoDestinationPort)?)?.parse()?;
|
||||
if it.next().is_some() {
|
||||
return Err(E::TooManyFields);
|
||||
}
|
||||
|
||||
let source = (source_address, source_port).into();
|
||||
let destination = (destination_address, destination_port).into();
|
||||
|
||||
Self::Udp {
|
||||
source,
|
||||
destination,
|
||||
}
|
||||
}
|
||||
Some(b"UDP6") => {
|
||||
let source_address: Ipv6Addr =
|
||||
std::str::from_utf8(it.next().ok_or(E::NoSourceAddress)?)?.parse()?;
|
||||
let destination_address: Ipv6Addr =
|
||||
std::str::from_utf8(it.next().ok_or(E::NoDestinationAddress)?)?.parse()?;
|
||||
let source_port: u16 =
|
||||
std::str::from_utf8(it.next().ok_or(E::NoSourcePort)?)?.parse()?;
|
||||
let destination_port: u16 =
|
||||
std::str::from_utf8(it.next().ok_or(E::NoDestinationPort)?)?.parse()?;
|
||||
if it.next().is_some() {
|
||||
return Err(E::TooManyFields);
|
||||
}
|
||||
|
||||
let source = (source_address, source_port).into();
|
||||
let destination = (destination_address, destination_port).into();
|
||||
|
||||
Self::Udp {
|
||||
source,
|
||||
destination,
|
||||
}
|
||||
}
|
||||
Some(b"UNKNOWN") => Self::Unknown,
|
||||
Some(_) => return Err(E::InvalidProtocol),
|
||||
None => return Err(E::NoProtocol),
|
||||
};
|
||||
|
||||
Ok((result, rest))
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn is_ipv4(&self) -> bool {
|
||||
match self {
|
||||
Self::Udp {
|
||||
source,
|
||||
destination,
|
||||
}
|
||||
| Self::Tcp {
|
||||
source,
|
||||
destination,
|
||||
} => source.is_ipv4() && destination.is_ipv4(),
|
||||
Self::Unknown => false,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn is_ipv6(&self) -> bool {
|
||||
match self {
|
||||
Self::Udp {
|
||||
source,
|
||||
destination,
|
||||
}
|
||||
| Self::Tcp {
|
||||
source,
|
||||
destination,
|
||||
} => source.is_ipv6() && destination.is_ipv6(),
|
||||
Self::Unknown => false,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub const fn is_tcp(&self) -> bool {
|
||||
matches!(self, Self::Tcp { .. })
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub const fn is_udp(&self) -> bool {
|
||||
matches!(self, Self::Udp { .. })
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub const fn is_unknown(&self) -> bool {
|
||||
matches!(self, Self::Unknown)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub const fn source(&self) -> Option<&SocketAddr> {
|
||||
match self {
|
||||
Self::Udp { source, .. } | Self::Tcp { source, .. } => Some(source),
|
||||
Self::Unknown => None,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub const fn destination(&self) -> Option<&SocketAddr> {
|
||||
match self {
|
||||
Self::Udp { destination, .. } | Self::Tcp { destination, .. } => Some(destination),
|
||||
Self::Unknown => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse() {
|
||||
let (info, rest) = ProxyProtocolV1Info::parse(
|
||||
b"PROXY TCP4 255.255.255.255 255.255.255.255 65535 65535\r\nhello world",
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(rest, b"hello world");
|
||||
assert!(info.is_tcp());
|
||||
assert!(!info.is_udp());
|
||||
assert!(!info.is_unknown());
|
||||
assert!(info.is_ipv4());
|
||||
assert!(!info.is_ipv6());
|
||||
|
||||
let (info, rest) = ProxyProtocolV1Info::parse(
|
||||
b"PROXY TCP6 ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\nhello world"
|
||||
).unwrap();
|
||||
assert_eq!(rest, b"hello world");
|
||||
assert!(info.is_tcp());
|
||||
assert!(!info.is_udp());
|
||||
assert!(!info.is_unknown());
|
||||
assert!(!info.is_ipv4());
|
||||
assert!(info.is_ipv6());
|
||||
|
||||
let (info, rest) = ProxyProtocolV1Info::parse(b"PROXY UNKNOWN\r\nhello world").unwrap();
|
||||
assert_eq!(rest, b"hello world");
|
||||
assert!(!info.is_tcp());
|
||||
assert!(!info.is_udp());
|
||||
assert!(info.is_unknown());
|
||||
assert!(!info.is_ipv4());
|
||||
assert!(!info.is_ipv6());
|
||||
|
||||
let (info, rest) = ProxyProtocolV1Info::parse(
|
||||
b"PROXY UNKNOWN ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\nhello world"
|
||||
).unwrap();
|
||||
assert_eq!(rest, b"hello world");
|
||||
assert!(!info.is_tcp());
|
||||
assert!(!info.is_udp());
|
||||
assert!(info.is_unknown());
|
||||
assert!(!info.is_ipv4());
|
||||
assert!(!info.is_ipv6());
|
||||
}
|
||||
}
|
||||
@@ -13,7 +13,6 @@
|
||||
// limitations under the License.
|
||||
|
||||
// TODO: Unlink the UNIX socket on drop?
|
||||
// TODO: Proxy protocol
|
||||
|
||||
use std::{
|
||||
pin::Pin,
|
||||
@@ -87,6 +86,12 @@ impl TryFrom<std::net::TcpListener> for UnixOrTcpListener {
|
||||
}
|
||||
|
||||
impl UnixOrTcpListener {
|
||||
/// Get the local address of the listener
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error on rare cases where the underlying [`TcpListener`] or
|
||||
/// [`UnixListener`] couldn't provide the local address
|
||||
pub fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
|
||||
match self {
|
||||
Self::Unix(listener) => listener.local_addr().map(SocketAddr::from),
|
||||
@@ -111,6 +116,12 @@ pin_project_lite::pin_project! {
|
||||
}
|
||||
|
||||
impl UnixOrTcpConnection {
|
||||
/// Get the local address of the stream
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error on rare cases where the underlying [`TcpStream`] or
|
||||
/// [`UnixStream`] couldn't provide the local address
|
||||
pub fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
|
||||
match self {
|
||||
Self::Unix { stream, .. } => stream.local_addr().map(SocketAddr::from),
|
||||
@@ -118,6 +129,12 @@ impl UnixOrTcpConnection {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the remote address of the stream
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error on rare cases where the underlying [`TcpStream`] or
|
||||
/// [`UnixStream`] couldn't provide the remote address
|
||||
pub fn peer_addr(&self) -> Result<SocketAddr, std::io::Error> {
|
||||
match self {
|
||||
Self::Unix { stream, .. } => stream.peer_addr().map(SocketAddr::from),
|
||||
|
||||
Reference in New Issue
Block a user