You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
HAProxy's Proxy Protocol acceptor
This commit is contained in:
@ -25,7 +25,7 @@ use mas_config::RootConfig;
|
||||
use mas_email::Mailer;
|
||||
use mas_handlers::{AppState, MatrixHomeserver};
|
||||
use mas_http::ServerLayer;
|
||||
use mas_listener::maybe_tls::MaybeTlsAcceptor;
|
||||
use mas_listener::{maybe_tls::MaybeTlsAcceptor, proxy_protocol::MaybeProxyAcceptor};
|
||||
use mas_policy::PolicyFactory;
|
||||
use mas_router::UrlBuilder;
|
||||
use mas_storage::MIGRATOR;
|
||||
@ -271,7 +271,9 @@ impl Options {
|
||||
}
|
||||
}).collect();
|
||||
|
||||
info!("Listening on {addresses:?} with resources {resources:?}", resources = &config.resources);
|
||||
let additional = if config.proxy_protocol { "(with Proxy Protocol)" } else { "" };
|
||||
|
||||
info!("Listening on {addresses:?} with resources {resources:?} {additional}", resources = &config.resources);
|
||||
|
||||
let router = crate::server::build_router(&state, &config.resources).layer(ServerLayer::new(config.name.clone()));
|
||||
|
||||
@ -285,6 +287,7 @@ impl Options {
|
||||
.map(Ok)
|
||||
.try_for_each_concurrent(None, move |listener| {
|
||||
let listener = MaybeTlsAcceptor::new(tls_config.clone(), listener);
|
||||
let listener = MaybeProxyAcceptor::new(listener, config.proxy_protocol);
|
||||
|
||||
Server::builder(listener)
|
||||
.serve(router.clone().into_make_service())
|
||||
|
@ -19,6 +19,7 @@ use async_trait::async_trait;
|
||||
use mas_keystore::PrivateKey;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::skip_serializing_none;
|
||||
use url::Url;
|
||||
|
||||
use super::{secrets::PasswordOrFile, ConfigurationSection};
|
||||
@ -66,6 +67,7 @@ impl UnixOrTcp {
|
||||
}
|
||||
|
||||
/// Configuration of a single listener
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)]
|
||||
#[serde(untagged)]
|
||||
pub enum BindConfig {
|
||||
@ -74,6 +76,7 @@ pub enum BindConfig {
|
||||
/// Host on which to listen.
|
||||
///
|
||||
/// Defaults to listening on all addresses
|
||||
#[serde(default)]
|
||||
host: Option<String>,
|
||||
|
||||
/// Port on which to listen.
|
||||
@ -107,6 +110,7 @@ pub enum BindConfig {
|
||||
/// Index of the file descriptor. Note that this is offseted by 3
|
||||
/// because of the standard input/output sockets, so setting
|
||||
/// here a value of `0` will grab the file descriptor `3`
|
||||
#[serde(default)]
|
||||
fd: usize,
|
||||
|
||||
/// Whether the socket is a TCP socket or a UNIX domain socket. Defaults
|
||||
@ -131,6 +135,7 @@ pub enum CertificateOrFile {
|
||||
}
|
||||
|
||||
/// Configuration related to TLS on a listener
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)]
|
||||
pub struct TlsConfig {
|
||||
/// PEM-encoded X509 certificate chain
|
||||
@ -214,6 +219,7 @@ impl TlsConfig {
|
||||
}
|
||||
|
||||
/// HTTP resources to mount
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)]
|
||||
#[serde(tag = "name", rename_all = "lowercase")]
|
||||
pub enum Resource {
|
||||
@ -245,10 +251,12 @@ pub enum Resource {
|
||||
}
|
||||
|
||||
/// Configuration of a listener
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)]
|
||||
pub struct ListenerConfig {
|
||||
/// A unique name for this listener which will be shown in traces and in
|
||||
/// metrics labels
|
||||
#[serde(default)]
|
||||
pub name: Option<String>,
|
||||
|
||||
/// List of resources to mount
|
||||
@ -257,7 +265,12 @@ pub struct ListenerConfig {
|
||||
/// List of sockets to bind
|
||||
pub binds: Vec<BindConfig>,
|
||||
|
||||
/// Accept HAProxy's Proxy Protocol V1
|
||||
#[serde(default)]
|
||||
pub proxy_protocol: bool,
|
||||
|
||||
/// If set, makes the listener use TLS with the provided certificate and key
|
||||
#[serde(default)]
|
||||
pub tls: Option<TlsConfig>,
|
||||
}
|
||||
|
||||
@ -286,6 +299,7 @@ impl Default for HttpConfig {
|
||||
Resource::Static { web_root: None },
|
||||
],
|
||||
tls: None,
|
||||
proxy_protocol: false,
|
||||
binds: vec![BindConfig::Address {
|
||||
address: "[::]:8080".into(),
|
||||
}],
|
||||
@ -294,6 +308,7 @@ impl Default for HttpConfig {
|
||||
name: Some("internal".to_owned()),
|
||||
resources: vec![Resource::Health],
|
||||
tls: None,
|
||||
proxy_protocol: false,
|
||||
binds: vec![BindConfig::Address {
|
||||
address: "localhost:8081".into(),
|
||||
}],
|
||||
|
@ -6,6 +6,7 @@ edition = "2021"
|
||||
license = "Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
thiserror = "1.0.37"
|
||||
tokio = { version = "1.21.2", features = ["net"] }
|
||||
pin-project-lite = "0.2.9"
|
||||
hyper = { version = "0.14.20", features = ["server"] }
|
||||
|
@ -12,5 +12,16 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#![forbid(unsafe_code)]
|
||||
#![deny(
|
||||
clippy::all,
|
||||
clippy::str_to_string,
|
||||
rustdoc::missing_crate_level_docs,
|
||||
rustdoc::broken_intra_doc_links
|
||||
)]
|
||||
#![warn(clippy::pedantic)]
|
||||
#![allow(clippy::module_name_repetitions)]
|
||||
|
||||
pub mod maybe_tls;
|
||||
pub mod proxy_protocol;
|
||||
pub mod unix_or_tcp;
|
||||
|
52
crates/listener/src/proxy_protocol/acceptor.rs
Normal file
52
crates/listener/src/proxy_protocol/acceptor.rs
Normal file
@ -0,0 +1,52 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use futures_util::ready;
|
||||
use hyper::server::accept::Accept;
|
||||
|
||||
use super::ProxyStream;
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
pub struct ProxyAcceptor<A> {
|
||||
#[pin]
|
||||
inner: A,
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> ProxyAcceptor<A> {
|
||||
pub const fn new(inner: A) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> Accept for ProxyAcceptor<A>
|
||||
where
|
||||
A: Accept,
|
||||
{
|
||||
type Conn = ProxyStream<A::Conn>;
|
||||
type Error = A::Error;
|
||||
|
||||
fn poll_accept(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Option<Result<Self::Conn, Self::Error>>> {
|
||||
let res = match ready!(self.project().inner.poll_accept(cx)) {
|
||||
Some(Ok(stream)) => Some(Ok(ProxyStream::new(stream))),
|
||||
Some(Err(e)) => Some(Err(e)),
|
||||
None => None,
|
||||
};
|
||||
|
||||
std::task::Poll::Ready(res)
|
||||
}
|
||||
}
|
149
crates/listener/src/proxy_protocol/maybe.rs
Normal file
149
crates/listener/src/proxy_protocol/maybe.rs
Normal file
@ -0,0 +1,149 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except proxied: streamliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use futures_util::ready;
|
||||
use hyper::server::accept::Accept;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use super::ProxyStream;
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
pub struct MaybeProxyAcceptor<A> {
|
||||
proxied: bool,
|
||||
|
||||
#[pin]
|
||||
inner: A,
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> MaybeProxyAcceptor<A> {
|
||||
#[must_use]
|
||||
pub const fn new(inner: A, proxied: bool) -> Self {
|
||||
Self { proxied, inner }
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> Accept for MaybeProxyAcceptor<A>
|
||||
where
|
||||
A: Accept,
|
||||
{
|
||||
type Conn = MaybeProxyStream<A::Conn>;
|
||||
type Error = A::Error;
|
||||
|
||||
fn poll_accept(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Option<Result<Self::Conn, Self::Error>>> {
|
||||
let proj = self.project();
|
||||
let res = match ready!(proj.inner.poll_accept(cx)) {
|
||||
Some(Ok(stream)) => Some(Ok(MaybeProxyStream::new(stream, *proj.proxied))),
|
||||
Some(Err(e)) => Some(Err(e)),
|
||||
None => None,
|
||||
};
|
||||
|
||||
std::task::Poll::Ready(res)
|
||||
}
|
||||
}
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
#[project = MaybeProxyStreamProj]
|
||||
pub enum MaybeProxyStream<S> {
|
||||
Proxied { #[pin] stream: ProxyStream<S> },
|
||||
NotProxied { #[pin] stream: S },
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> MaybeProxyStream<S> {
|
||||
pub const fn new(stream: S, proxied: bool) -> Self {
|
||||
if proxied {
|
||||
Self::Proxied {
|
||||
stream: ProxyStream::new(stream),
|
||||
}
|
||||
} else {
|
||||
Self::NotProxied { stream }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> AsyncRead for MaybeProxyStream<S>
|
||||
where
|
||||
S: AsyncRead,
|
||||
{
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
match self.project() {
|
||||
MaybeProxyStreamProj::Proxied { stream } => stream.poll_read(cx, buf),
|
||||
MaybeProxyStreamProj::NotProxied { stream } => stream.poll_read(cx, buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> AsyncWrite for MaybeProxyStream<S>
|
||||
where
|
||||
S: AsyncWrite,
|
||||
{
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, std::io::Error>> {
|
||||
match self.project() {
|
||||
MaybeProxyStreamProj::Proxied { stream } => stream.poll_write(cx, buf),
|
||||
MaybeProxyStreamProj::NotProxied { stream } => stream.poll_write(cx, buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
|
||||
match self.project() {
|
||||
MaybeProxyStreamProj::Proxied { stream } => stream.poll_flush(cx),
|
||||
MaybeProxyStreamProj::NotProxied { stream } => stream.poll_flush(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
match self.project() {
|
||||
MaybeProxyStreamProj::Proxied { stream } => stream.poll_shutdown(cx),
|
||||
MaybeProxyStreamProj::NotProxied { stream } => stream.poll_shutdown(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_write_vectored(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
bufs: &[std::io::IoSlice<'_>],
|
||||
) -> Poll<Result<usize, std::io::Error>> {
|
||||
match self.project() {
|
||||
MaybeProxyStreamProj::Proxied { stream } => stream.poll_write_vectored(cx, bufs),
|
||||
MaybeProxyStreamProj::NotProxied { stream } => stream.poll_write_vectored(cx, bufs),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_write_vectored(&self) -> bool {
|
||||
match self {
|
||||
MaybeProxyStream::Proxied { stream } => stream.is_write_vectored(),
|
||||
MaybeProxyStream::NotProxied { stream } => stream.is_write_vectored(),
|
||||
}
|
||||
}
|
||||
}
|
25
crates/listener/src/proxy_protocol/mod.rs
Normal file
25
crates/listener/src/proxy_protocol/mod.rs
Normal file
@ -0,0 +1,25 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
mod acceptor;
|
||||
mod maybe;
|
||||
mod stream;
|
||||
mod v1;
|
||||
|
||||
pub use self::{
|
||||
acceptor::ProxyAcceptor,
|
||||
maybe::{MaybeProxyAcceptor, MaybeProxyStream},
|
||||
stream::ProxyStream,
|
||||
v1::ProxyProtocolV1Info,
|
||||
};
|
148
crates/listener/src/proxy_protocol/stream.rs
Normal file
148
crates/listener/src/proxy_protocol/stream.rs
Normal file
@ -0,0 +1,148 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::ops::Deref;
|
||||
|
||||
use futures_util::ready;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
|
||||
use super::ProxyProtocolV1Info;
|
||||
|
||||
// Max theorical size we need is 108 for proxy protocol v1
|
||||
const BUF_SIZE: usize = 256;
|
||||
|
||||
#[derive(Debug)]
|
||||
enum ProxyStreamState {
|
||||
Handshaking {
|
||||
buffer: [u8; BUF_SIZE],
|
||||
index: usize,
|
||||
},
|
||||
Established(ProxyProtocolV1Info),
|
||||
}
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
#[derive(Debug)]
|
||||
pub struct ProxyStream<S> {
|
||||
state: ProxyStreamState,
|
||||
|
||||
#[pin]
|
||||
inner: S,
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> ProxyStream<S> {
|
||||
pub const fn new(inner: S) -> Self {
|
||||
Self {
|
||||
state: ProxyStreamState::Handshaking {
|
||||
buffer: [0; BUF_SIZE],
|
||||
index: 0,
|
||||
},
|
||||
inner,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Deref for ProxyStream<S> {
|
||||
type Target = S;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> ProxyStream<S> {
|
||||
pub fn proxy_info(&self) -> Option<&ProxyProtocolV1Info> {
|
||||
match &self.state {
|
||||
ProxyStreamState::Handshaking { .. } => None,
|
||||
ProxyStreamState::Established(info) => Some(info),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> AsyncRead for ProxyStream<S>
|
||||
where
|
||||
S: AsyncRead,
|
||||
{
|
||||
fn poll_read(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> std::task::Poll<std::io::Result<()>> {
|
||||
let proj = self.project();
|
||||
match proj.state {
|
||||
ProxyStreamState::Handshaking { buffer, index } => {
|
||||
let mut buffer = ReadBuf::new(&mut buffer[..]);
|
||||
buffer.advance(*index);
|
||||
ready!(proj.inner.poll_read(cx, &mut buffer))?;
|
||||
let filled = buffer.filled();
|
||||
*index = filled.len();
|
||||
|
||||
match ProxyProtocolV1Info::parse(filled) {
|
||||
Ok((info, rest)) => {
|
||||
if buf.remaining() < rest.len() {
|
||||
// This is highly unlikely, but is better than panicking later.
|
||||
// If it ever happens, we could introduce a "buffer draining" state
|
||||
// which drains the inner buffer repeatedly until it's empty
|
||||
return std::task::Poll::Ready(Err(std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
"underlying buffer is too small",
|
||||
)));
|
||||
}
|
||||
buf.put_slice(rest);
|
||||
*proj.state = ProxyStreamState::Established(info);
|
||||
std::task::Poll::Ready(Ok(()))
|
||||
}
|
||||
Err(e) if e.not_enough_bytes() => std::task::Poll::Ready(Ok(())),
|
||||
Err(e) => std::task::Poll::Ready(Err(std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
e,
|
||||
))),
|
||||
}
|
||||
}
|
||||
ProxyStreamState::Established(_) => proj.inner.poll_read(cx, buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> AsyncWrite for ProxyStream<S>
|
||||
where
|
||||
S: AsyncWrite,
|
||||
{
|
||||
fn poll_write(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> std::task::Poll<Result<usize, std::io::Error>> {
|
||||
let proj = self.project();
|
||||
match proj.state {
|
||||
// Hold off writes until the handshake is done
|
||||
// XXX: is this the right way to do it?
|
||||
ProxyStreamState::Handshaking { .. } => std::task::Poll::Pending,
|
||||
ProxyStreamState::Established(_) => proj.inner.poll_write(cx, buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), std::io::Error>> {
|
||||
self.project().inner.poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), std::io::Error>> {
|
||||
self.project().inner.poll_shutdown(cx)
|
||||
}
|
||||
}
|
300
crates/listener/src/proxy_protocol/v1.rs
Normal file
300
crates/listener/src/proxy_protocol/v1.rs
Normal file
@ -0,0 +1,300 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{
|
||||
net::{AddrParseError, Ipv4Addr, Ipv6Addr, SocketAddr},
|
||||
num::ParseIntError,
|
||||
str::Utf8Error,
|
||||
};
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ProxyProtocolV1Info {
|
||||
Tcp {
|
||||
source: SocketAddr,
|
||||
destination: SocketAddr,
|
||||
},
|
||||
Udp {
|
||||
source: SocketAddr,
|
||||
destination: SocketAddr,
|
||||
},
|
||||
Unknown,
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
#[error("Invalid proxy protocol header")]
|
||||
pub(super) enum ParseError {
|
||||
#[error("Not enough bytes provided")]
|
||||
NotEnoughBytes,
|
||||
NoCrLf,
|
||||
NoProxyPreamble,
|
||||
NoProtocol,
|
||||
InvalidProtocol,
|
||||
NoSourceAddress,
|
||||
NoDestinationAddress,
|
||||
NoSourcePort,
|
||||
NoDestinationPort,
|
||||
TooManyFields,
|
||||
InvalidUtf8(#[from] Utf8Error),
|
||||
InvalidAddress(#[from] AddrParseError),
|
||||
InvalidPort(#[from] ParseIntError),
|
||||
}
|
||||
|
||||
impl ParseError {
|
||||
pub const fn not_enough_bytes(&self) -> bool {
|
||||
matches!(self, &Self::NotEnoughBytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl ProxyProtocolV1Info {
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub(super) fn parse(bytes: &[u8]) -> Result<(Self, &[u8]), ParseError> {
|
||||
use ParseError as E;
|
||||
// First, check if we *possibly* have enough bytes.
|
||||
// Minimum is 15: "PROXY UNKNOWN\r\n"
|
||||
|
||||
if bytes.len() < 15 {
|
||||
return Err(E::NotEnoughBytes);
|
||||
}
|
||||
|
||||
// Let's check in the first 108 bytes if we find a CRLF
|
||||
let crlf = if let Some(crlf) = bytes
|
||||
.windows(2)
|
||||
.take(108)
|
||||
.position(|needle| needle == [0x0D, 0x0A])
|
||||
{
|
||||
crlf
|
||||
} else {
|
||||
// If not, it might be because we don't have enough bytes
|
||||
return if bytes.len() < 108 {
|
||||
Err(E::NotEnoughBytes)
|
||||
} else {
|
||||
// Else it's just invalid
|
||||
Err(E::NoCrLf)
|
||||
};
|
||||
};
|
||||
|
||||
// Keep the rest of the buffer to pass it to the underlying protocol
|
||||
let rest = &bytes[crlf + 2..];
|
||||
// Trim to everything before the CRLF
|
||||
let bytes = &bytes[..crlf];
|
||||
|
||||
let mut it = bytes.splitn(6, |c| c == &b' ');
|
||||
// Check for the preamble
|
||||
if it.next() != Some(b"PROXY") {
|
||||
return Err(E::NoProxyPreamble);
|
||||
}
|
||||
|
||||
let result = match it.next() {
|
||||
Some(b"TCP4") => {
|
||||
let source_address: Ipv4Addr =
|
||||
std::str::from_utf8(it.next().ok_or(E::NoSourceAddress)?)?.parse()?;
|
||||
let destination_address: Ipv4Addr =
|
||||
std::str::from_utf8(it.next().ok_or(E::NoDestinationAddress)?)?.parse()?;
|
||||
let source_port: u16 =
|
||||
std::str::from_utf8(it.next().ok_or(E::NoSourcePort)?)?.parse()?;
|
||||
let destination_port: u16 =
|
||||
std::str::from_utf8(it.next().ok_or(E::NoDestinationPort)?)?.parse()?;
|
||||
if it.next().is_some() {
|
||||
return Err(E::TooManyFields);
|
||||
}
|
||||
|
||||
let source = (source_address, source_port).into();
|
||||
let destination = (destination_address, destination_port).into();
|
||||
|
||||
Self::Tcp {
|
||||
source,
|
||||
destination,
|
||||
}
|
||||
}
|
||||
Some(b"TCP6") => {
|
||||
let source_address: Ipv6Addr =
|
||||
std::str::from_utf8(it.next().ok_or(E::NoSourceAddress)?)?.parse()?;
|
||||
let destination_address: Ipv6Addr =
|
||||
std::str::from_utf8(it.next().ok_or(E::NoDestinationAddress)?)?.parse()?;
|
||||
let source_port: u16 =
|
||||
std::str::from_utf8(it.next().ok_or(E::NoSourcePort)?)?.parse()?;
|
||||
let destination_port: u16 =
|
||||
std::str::from_utf8(it.next().ok_or(E::NoDestinationPort)?)?.parse()?;
|
||||
if it.next().is_some() {
|
||||
return Err(E::TooManyFields);
|
||||
}
|
||||
|
||||
let source = (source_address, source_port).into();
|
||||
let destination = (destination_address, destination_port).into();
|
||||
|
||||
Self::Tcp {
|
||||
source,
|
||||
destination,
|
||||
}
|
||||
}
|
||||
Some(b"UDP4") => {
|
||||
let source_address: Ipv4Addr =
|
||||
std::str::from_utf8(it.next().ok_or(E::NoSourceAddress)?)?.parse()?;
|
||||
let destination_address: Ipv4Addr =
|
||||
std::str::from_utf8(it.next().ok_or(E::NoDestinationAddress)?)?.parse()?;
|
||||
let source_port: u16 =
|
||||
std::str::from_utf8(it.next().ok_or(E::NoSourcePort)?)?.parse()?;
|
||||
let destination_port: u16 =
|
||||
std::str::from_utf8(it.next().ok_or(E::NoDestinationPort)?)?.parse()?;
|
||||
if it.next().is_some() {
|
||||
return Err(E::TooManyFields);
|
||||
}
|
||||
|
||||
let source = (source_address, source_port).into();
|
||||
let destination = (destination_address, destination_port).into();
|
||||
|
||||
Self::Udp {
|
||||
source,
|
||||
destination,
|
||||
}
|
||||
}
|
||||
Some(b"UDP6") => {
|
||||
let source_address: Ipv6Addr =
|
||||
std::str::from_utf8(it.next().ok_or(E::NoSourceAddress)?)?.parse()?;
|
||||
let destination_address: Ipv6Addr =
|
||||
std::str::from_utf8(it.next().ok_or(E::NoDestinationAddress)?)?.parse()?;
|
||||
let source_port: u16 =
|
||||
std::str::from_utf8(it.next().ok_or(E::NoSourcePort)?)?.parse()?;
|
||||
let destination_port: u16 =
|
||||
std::str::from_utf8(it.next().ok_or(E::NoDestinationPort)?)?.parse()?;
|
||||
if it.next().is_some() {
|
||||
return Err(E::TooManyFields);
|
||||
}
|
||||
|
||||
let source = (source_address, source_port).into();
|
||||
let destination = (destination_address, destination_port).into();
|
||||
|
||||
Self::Udp {
|
||||
source,
|
||||
destination,
|
||||
}
|
||||
}
|
||||
Some(b"UNKNOWN") => Self::Unknown,
|
||||
Some(_) => return Err(E::InvalidProtocol),
|
||||
None => return Err(E::NoProtocol),
|
||||
};
|
||||
|
||||
Ok((result, rest))
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn is_ipv4(&self) -> bool {
|
||||
match self {
|
||||
Self::Udp {
|
||||
source,
|
||||
destination,
|
||||
}
|
||||
| Self::Tcp {
|
||||
source,
|
||||
destination,
|
||||
} => source.is_ipv4() && destination.is_ipv4(),
|
||||
Self::Unknown => false,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn is_ipv6(&self) -> bool {
|
||||
match self {
|
||||
Self::Udp {
|
||||
source,
|
||||
destination,
|
||||
}
|
||||
| Self::Tcp {
|
||||
source,
|
||||
destination,
|
||||
} => source.is_ipv6() && destination.is_ipv6(),
|
||||
Self::Unknown => false,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub const fn is_tcp(&self) -> bool {
|
||||
matches!(self, Self::Tcp { .. })
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub const fn is_udp(&self) -> bool {
|
||||
matches!(self, Self::Udp { .. })
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub const fn is_unknown(&self) -> bool {
|
||||
matches!(self, Self::Unknown)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub const fn source(&self) -> Option<&SocketAddr> {
|
||||
match self {
|
||||
Self::Udp { source, .. } | Self::Tcp { source, .. } => Some(source),
|
||||
Self::Unknown => None,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub const fn destination(&self) -> Option<&SocketAddr> {
|
||||
match self {
|
||||
Self::Udp { destination, .. } | Self::Tcp { destination, .. } => Some(destination),
|
||||
Self::Unknown => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse() {
|
||||
let (info, rest) = ProxyProtocolV1Info::parse(
|
||||
b"PROXY TCP4 255.255.255.255 255.255.255.255 65535 65535\r\nhello world",
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(rest, b"hello world");
|
||||
assert!(info.is_tcp());
|
||||
assert!(!info.is_udp());
|
||||
assert!(!info.is_unknown());
|
||||
assert!(info.is_ipv4());
|
||||
assert!(!info.is_ipv6());
|
||||
|
||||
let (info, rest) = ProxyProtocolV1Info::parse(
|
||||
b"PROXY TCP6 ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\nhello world"
|
||||
).unwrap();
|
||||
assert_eq!(rest, b"hello world");
|
||||
assert!(info.is_tcp());
|
||||
assert!(!info.is_udp());
|
||||
assert!(!info.is_unknown());
|
||||
assert!(!info.is_ipv4());
|
||||
assert!(info.is_ipv6());
|
||||
|
||||
let (info, rest) = ProxyProtocolV1Info::parse(b"PROXY UNKNOWN\r\nhello world").unwrap();
|
||||
assert_eq!(rest, b"hello world");
|
||||
assert!(!info.is_tcp());
|
||||
assert!(!info.is_udp());
|
||||
assert!(info.is_unknown());
|
||||
assert!(!info.is_ipv4());
|
||||
assert!(!info.is_ipv6());
|
||||
|
||||
let (info, rest) = ProxyProtocolV1Info::parse(
|
||||
b"PROXY UNKNOWN ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\nhello world"
|
||||
).unwrap();
|
||||
assert_eq!(rest, b"hello world");
|
||||
assert!(!info.is_tcp());
|
||||
assert!(!info.is_udp());
|
||||
assert!(info.is_unknown());
|
||||
assert!(!info.is_ipv4());
|
||||
assert!(!info.is_ipv6());
|
||||
}
|
||||
}
|
@ -13,7 +13,6 @@
|
||||
// limitations under the License.
|
||||
|
||||
// TODO: Unlink the UNIX socket on drop?
|
||||
// TODO: Proxy protocol
|
||||
|
||||
use std::{
|
||||
pin::Pin,
|
||||
@ -87,6 +86,12 @@ impl TryFrom<std::net::TcpListener> for UnixOrTcpListener {
|
||||
}
|
||||
|
||||
impl UnixOrTcpListener {
|
||||
/// Get the local address of the listener
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error on rare cases where the underlying [`TcpListener`] or
|
||||
/// [`UnixListener`] couldn't provide the local address
|
||||
pub fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
|
||||
match self {
|
||||
Self::Unix(listener) => listener.local_addr().map(SocketAddr::from),
|
||||
@ -111,6 +116,12 @@ pin_project_lite::pin_project! {
|
||||
}
|
||||
|
||||
impl UnixOrTcpConnection {
|
||||
/// Get the local address of the stream
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error on rare cases where the underlying [`TcpStream`] or
|
||||
/// [`UnixStream`] couldn't provide the local address
|
||||
pub fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
|
||||
match self {
|
||||
Self::Unix { stream, .. } => stream.local_addr().map(SocketAddr::from),
|
||||
@ -118,6 +129,12 @@ impl UnixOrTcpConnection {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the remote address of the stream
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error on rare cases where the underlying [`TcpStream`] or
|
||||
/// [`UnixStream`] couldn't provide the remote address
|
||||
pub fn peer_addr(&self) -> Result<SocketAddr, std::io::Error> {
|
||||
match self {
|
||||
Self::Unix { stream, .. } => stream.peer_addr().map(SocketAddr::from),
|
||||
|
Reference in New Issue
Block a user