From 485778beb3b14fc3691ee9e58f7c003f37548404 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 6 Oct 2022 16:30:24 +0200 Subject: [PATCH] Inject connection informations in the request extension --- Cargo.lock | 2 + crates/cli/src/commands/server.rs | 10 +- crates/cli/src/server.rs | 15 +- crates/config/src/sections/http.rs | 5 + crates/listener/Cargo.toml | 10 +- crates/listener/src/info.rs | 323 +++++++++++++++++++ crates/listener/src/lib.rs | 1 + crates/listener/src/maybe_tls.rs | 63 +++- crates/listener/src/proxy_protocol/maybe.rs | 43 ++- crates/listener/src/proxy_protocol/mod.rs | 2 +- crates/listener/src/proxy_protocol/stream.rs | 27 +- crates/listener/src/proxy_protocol/v1.rs | 2 +- crates/listener/src/unix_or_tcp.rs | 42 +++ 13 files changed, 530 insertions(+), 15 deletions(-) create mode 100644 crates/listener/src/info.rs diff --git a/Cargo.lock b/Cargo.lock index 02567d82..04f88ab9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2683,6 +2683,8 @@ dependencies = [ "thiserror", "tokio", "tokio-rustls 0.23.4", + "tower", + "tower-http", "tracing", ] diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index 09703ac9..80048829 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -25,7 +25,11 @@ use mas_config::RootConfig; use mas_email::Mailer; use mas_handlers::{AppState, MatrixHomeserver}; use mas_http::ServerLayer; -use mas_listener::{maybe_tls::MaybeTlsAcceptor, proxy_protocol::MaybeProxyAcceptor}; +use mas_listener::{ + info::{ConnectionInfoAcceptor, IntoMakeServiceWithConnection}, + maybe_tls::MaybeTlsAcceptor, + proxy_protocol::MaybeProxyAcceptor, +}; use mas_policy::PolicyFactory; use mas_router::UrlBuilder; use mas_storage::MIGRATOR; @@ -276,6 +280,7 @@ impl Options { info!("Listening on {addresses:?} with resources {resources:?} {additional}", resources = &config.resources); let router = crate::server::build_router(&state, &config.resources).layer(ServerLayer::new(config.name.clone())); + let make_service = IntoMakeServiceWithConnection::new(router); async move { let tls_config = if let Some(tls_config) = config.tls.as_ref() { @@ -288,9 +293,10 @@ impl Options { .try_for_each_concurrent(None, move |listener| { let listener = MaybeTlsAcceptor::new(tls_config.clone(), listener); let listener = MaybeProxyAcceptor::new(listener, config.proxy_protocol); + let listener = ConnectionInfoAcceptor::new(listener); Server::builder(listener) - .serve(router.clone().into_make_service()) + .serve(make_service.clone()) .with_graceful_shutdown(signal.clone()) }) .await?; diff --git a/crates/cli/src/server.rs b/crates/cli/src/server.rs index f9f6f402..7ca43d0b 100644 --- a/crates/cli/src/server.rs +++ b/crates/cli/src/server.rs @@ -19,13 +19,14 @@ use std::{ }; use anyhow::Context; -use axum::{body::HttpBody, Router}; +use axum::{body::HttpBody, Extension, Router}; use listenfd::ListenFd; use mas_config::{HttpBindConfig, HttpResource, HttpTlsConfig, UnixOrTcp}; use mas_handlers::AppState; -use mas_listener::unix_or_tcp::UnixOrTcpListener; +use mas_listener::{info::Connection, unix_or_tcp::UnixOrTcpListener}; use mas_router::Route; use rustls::ServerConfig; +use tokio::sync::OnceCell; #[allow(clippy::trait_duplication_in_bounds)] pub fn build_router(state: &Arc, resources: &[HttpResource]) -> Router @@ -60,6 +61,16 @@ where mas_config::HttpResource::Compat => { router.merge(mas_handlers::compat_router(state.clone())) } + // TODO: do a better handler here + mas_config::HttpResource::ConnectionInfo => router.route( + "/connection-info", + axum::routing::get( + |connection: Extension>>| async move { + let connection = connection.get().unwrap(); + format!("{connection:?}") + }, + ), + ), } } diff --git a/crates/config/src/sections/http.rs b/crates/config/src/sections/http.rs index 4db7ad6b..c5c971bf 100644 --- a/crates/config/src/sections/http.rs +++ b/crates/config/src/sections/http.rs @@ -248,6 +248,11 @@ pub enum Resource { #[serde(default)] web_root: Option, }, + + /// Mount a "/connection-info" handler which helps debugging informations on + /// the upstream connection + #[serde(rename = "connection-info")] + ConnectionInfo, } /// Configuration of a listener diff --git a/crates/listener/Cargo.toml b/crates/listener/Cargo.toml index ff677159..ee1bb069 100644 --- a/crates/listener/Cargo.toml +++ b/crates/listener/Cargo.toml @@ -6,10 +6,12 @@ edition = "2021" license = "Apache-2.0" [dependencies] +futures-util = "0.3.24" +hyper = { version = "0.14.20", features = ["server"] } +pin-project-lite = "0.2.9" thiserror = "1.0.37" tokio = { version = "1.21.2", features = ["net"] } -pin-project-lite = "0.2.9" -hyper = { version = "0.14.20", features = ["server"] } -futures-util = "0.3.24" -tracing = "0.1.36" tokio-rustls = "0.23.4" +tower = "0.4.13" +tower-http = { version = "0.3.4", features = ["add-extension"] } +tracing = "0.1.36" diff --git a/crates/listener/src/info.rs b/crates/listener/src/info.rs new file mode 100644 index 00000000..4757f62e --- /dev/null +++ b/crates/listener/src/info.rs @@ -0,0 +1,323 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{ + future::Ready, + net::SocketAddr, + ops::Deref, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +use hyper::server::accept::Accept; +use thiserror::Error; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::OnceCell, +}; +use tower::Service; +use tower_http::add_extension::AddExtension; + +use crate::{ + maybe_tls::{MaybeTlsAcceptor, MaybeTlsStream, TlsStreamInfo, TlsStreamInfoError}, + proxy_protocol::{ + MaybeProxyAcceptor, MaybeProxyStream, ProxyHandshakeNotDone, ProxyProtocolV1Info, + }, + unix_or_tcp::{UnixOrTcpConnection, UnixOrTcpListener}, +}; + +// TODO: this is a mess, clean that up + +#[derive(Error, Debug)] +#[non_exhaustive] +pub enum FromStreamError { + #[error(transparent)] + Proxy(#[from] ProxyHandshakeNotDone), + + #[error(transparent)] + Tls(#[from] TlsStreamInfoError), + + #[error("Could not grab a reference to the underlying stream")] + GetRef, + + #[error("Could not get address info from underlying stream")] + IoError(#[from] std::io::Error), +} + +#[derive(Debug, Clone)] +pub struct Connection { + proxy: Option, + tls: Option, + + // We're not saving the UNIX domain socket address here because it can't be cloned, which is + // required for injecting the connection information as an extension + local_tcp_addr: Option, + peer_tcp_addr: Option, +} + +#[derive(Error, Debug)] +#[non_exhaustive] +pub enum GrabAddressError { + #[error("Proxy protocol was initiated with an unknown protocol")] + ProxyUnknown, + + #[error("Proxy protocol was initiated with UDP")] + ProxyUdp, + + #[error("Underlying listener is a UNIX socket")] + UnixListener, +} + +impl MaybeProxyAcceptor> { + pub fn can_have_peer_address(&self) -> bool { + self.is_proxied() || self.is_tcp() + } +} + +impl MaybeProxyStream> { + /// Get informations about this connection + /// + /// # Errors + /// + /// Returns an error if the proxy protocol or the TLS handhakes are not done + /// yet + pub fn connection_info(&self) -> Result { + Connection::from_stream(self) + } +} + +impl Connection { + /// Get informations about this connection + /// + /// # Errors + /// + /// Returns an error if the proxy protocol or the TLS handhakes are not done + /// yet + pub fn from_stream( + stream: &MaybeProxyStream>, + ) -> Result { + let proxy = stream.proxy_info()?.cloned(); + let tls = stream.tls_info()?; + let original = stream.get_ref().ok_or(FromStreamError::GetRef)?; + let local_tcp_addr = original.local_addr()?.into_net(); + let peer_tcp_addr = original.peer_addr()?.into_net(); + + Ok(Self { + proxy, + tls, + local_tcp_addr, + peer_tcp_addr, + }) + } + + #[must_use] + pub const fn is_proxied(&self) -> bool { + self.proxy.is_some() + } + + #[must_use] + pub const fn is_tls(&self) -> bool { + self.tls.is_some() + } + + /// Get the outmost peer address, either from the TCP listener or from the + /// proxy protocol infos. + /// + /// # Errors + /// + /// Returns an error if the info from the proxy protocol was not for a TCP + /// connection, or if the proxy protocol is not being used, the underlying + /// listener was a UNIX domain socket + pub fn peer_addr(&self) -> Result<&SocketAddr, GrabAddressError> { + if let Some(proxy) = self.proxy.as_ref() { + if proxy.is_udp() { + return Err(GrabAddressError::ProxyUdp); + } + + proxy.source().ok_or(GrabAddressError::ProxyUnknown) + } else { + self.peer_tcp_addr + .as_ref() + .ok_or(GrabAddressError::UnixListener) + } + } + + /// Get the outmost local address, either from the TCP listener or from the + /// proxy protocol infos. + /// + /// # Errors + /// + /// Returns an error if the info from the proxy protocol was not for a TCP + /// connection, or if the proxy protocol is not being used, the underlying + /// listener was a UNIX domain socket + pub fn local_addr(&self) -> Result<&SocketAddr, GrabAddressError> { + if let Some(proxy) = self.proxy.as_ref() { + if proxy.is_udp() { + return Err(GrabAddressError::ProxyUdp); + } + + proxy.destination().ok_or(GrabAddressError::ProxyUnknown) + } else { + self.local_tcp_addr + .as_ref() + .ok_or(GrabAddressError::UnixListener) + } + } +} + +pin_project_lite::pin_project! { + pub struct ConnectionInfoAcceptor { + #[pin] + acceptor: MaybeProxyAcceptor>, + } +} + +impl ConnectionInfoAcceptor { + pub const fn new(acceptor: MaybeProxyAcceptor>) -> Self { + Self { acceptor } + } +} + +impl Accept for ConnectionInfoAcceptor { + type Conn = ConnectionInfoStream; + type Error = std::io::Error; + + fn poll_accept( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll>> { + let proj = self.project(); + let ret = match futures_util::ready!(proj.acceptor.poll_accept(cx)) { + Some(Ok(conn)) => Some(Ok(ConnectionInfoStream::new(conn))), + Some(Err(e)) => Some(Err(e)), + None => None, + }; + Poll::Ready(ret) + } +} + +pin_project_lite::pin_project! { + pub struct ConnectionInfoStream { + connection: Arc>, + #[pin] + stream: MaybeProxyStream>, + } +} + +impl ConnectionInfoStream { + pub fn new(stream: MaybeProxyStream>) -> Self { + Self { + connection: Arc::new(OnceCell::const_new()), + stream, + } + } +} + +impl Deref for ConnectionInfoStream { + type Target = MaybeProxyStream>; + fn deref(&self) -> &Self::Target { + &self.stream + } +} + +impl AsyncRead for ConnectionInfoStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + let this = self.get_mut(); + futures_util::ready!(Pin::new(&mut this.stream).poll_read(cx, buf))?; + + if !this.stream.is_tls_handshaking() + && !this.stream.is_proxy_handshaking() + && !this.connection.initialized() + { + this.connection + .set(this.stream.connection_info().unwrap()) + .unwrap(); + } + + Poll::Ready(Ok(())) + } +} + +impl AsyncWrite for ConnectionInfoStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let proj = self.project(); + proj.stream.poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let proj = self.project(); + proj.stream.poll_flush(cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let proj = self.project(); + proj.stream.poll_shutdown(cx) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + let proj = self.project(); + proj.stream.poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.stream.is_write_vectored() + } +} + +#[derive(Debug, Clone)] +pub struct IntoMakeServiceWithConnection { + svc: S, +} + +impl IntoMakeServiceWithConnection { + pub const fn new(svc: S) -> Self { + Self { svc } + } +} + +impl Service<&ConnectionInfoStream> for IntoMakeServiceWithConnection +where + S: Clone, +{ + type Response = AddExtension>>; + type Error = FromStreamError; + type Future = Ready>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, target: &ConnectionInfoStream) -> Self::Future { + std::future::ready(Ok(AddExtension::new( + self.svc.clone(), + target.connection.clone(), + ))) + } +} diff --git a/crates/listener/src/lib.rs b/crates/listener/src/lib.rs index a5ae777d..f4dbe4d7 100644 --- a/crates/listener/src/lib.rs +++ b/crates/listener/src/lib.rs @@ -22,6 +22,7 @@ #![warn(clippy::pedantic)] #![allow(clippy::module_name_repetitions)] +pub mod info; pub mod maybe_tls; pub mod proxy_protocol; pub mod unix_or_tcp; diff --git a/crates/listener/src/maybe_tls.rs b/crates/listener/src/maybe_tls.rs index ff685099..55a1cca6 100644 --- a/crates/listener/src/maybe_tls.rs +++ b/crates/listener/src/maybe_tls.rs @@ -21,8 +21,31 @@ use std::{ use futures_util::{ready, Future}; use hyper::server::accept::Accept; +use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -use tokio_rustls::rustls::{ServerConfig, ServerConnection}; +use tokio_rustls::rustls::{ + Certificate, ProtocolVersion, ServerConfig, ServerConnection, SupportedCipherSuite, +}; + +#[derive(Debug, Clone)] +#[non_exhaustive] +pub struct TlsStreamInfo { + pub protocol_version: ProtocolVersion, + pub negotiated_cipher_suite: SupportedCipherSuite, + pub sni_hostname: Option, + pub apln_protocol: Option>, + pub peer_certificates: Option>, +} + +#[derive(Debug, Error)] +#[non_exhaustive] +pub enum TlsStreamInfoError { + #[error("TLS handshake is not done yet")] + HandshakingNotDone, + + #[error("Some fields were not available in the TLS connection")] + FieldsNotAvailable, +} pub enum MaybeTlsStream { Handshaking(tokio_rustls::Accept), @@ -72,6 +95,44 @@ impl MaybeTlsStream { Self::Handshaking(_) | Self::Insecure(_) => None, } } + + /// Gather informations about the TLS connection. Returns `None` if the + /// stream is not a TLS stream. + /// + /// # Errors + /// + /// Returns an error if the TLS handshake is not yet done + pub fn tls_info(&self) -> Result, TlsStreamInfoError> { + let conn = match self { + Self::Streaming(stream) => stream.get_ref().1, + Self::Handshaking(_) => return Err(TlsStreamInfoError::HandshakingNotDone), + Self::Insecure(_) => return Ok(None), + }; + + // NOTE: we're getting the protocol version and cipher suite *after* the + // handshake, so this should never lead to an error + let protocol_version = conn + .protocol_version() + .ok_or(TlsStreamInfoError::FieldsNotAvailable)?; + let negotiated_cipher_suite = conn + .negotiated_cipher_suite() + .ok_or(TlsStreamInfoError::FieldsNotAvailable)?; + + let sni_hostname = conn.sni_hostname().map(ToOwned::to_owned); + let apln_protocol = conn.alpn_protocol().map(ToOwned::to_owned); + let peer_certificates = conn.peer_certificates().map(ToOwned::to_owned); + Ok(Some(TlsStreamInfo { + protocol_version, + negotiated_cipher_suite, + sni_hostname, + apln_protocol, + peer_certificates, + })) + } + + pub const fn is_tls_handshaking(&self) -> bool { + matches!(self, Self::Handshaking(_)) + } } impl AsyncRead for MaybeTlsStream diff --git a/crates/listener/src/proxy_protocol/maybe.rs b/crates/listener/src/proxy_protocol/maybe.rs index 7e261009..b6e74377 100644 --- a/crates/listener/src/proxy_protocol/maybe.rs +++ b/crates/listener/src/proxy_protocol/maybe.rs @@ -13,6 +13,7 @@ // limitations under the License. use std::{ + ops::Deref, pin::Pin, task::{Context, Poll}, }; @@ -21,7 +22,7 @@ use futures_util::ready; use hyper::server::accept::Accept; use tokio::io::{AsyncRead, AsyncWrite}; -use super::ProxyStream; +use super::{stream::HandshakeNotDone, ProxyProtocolV1Info, ProxyStream}; pin_project_lite::pin_project! { pub struct MaybeProxyAcceptor { @@ -37,6 +38,17 @@ impl MaybeProxyAcceptor { pub const fn new(inner: A, proxied: bool) -> Self { Self { proxied, inner } } + + pub const fn is_proxied(&self) -> bool { + self.proxied + } +} + +impl Deref for MaybeProxyAcceptor { + type Target = A; + fn deref(&self) -> &Self::Target { + &self.inner + } } impl Accept for MaybeProxyAcceptor @@ -79,6 +91,35 @@ impl MaybeProxyStream { Self::NotProxied { stream } } } + + /// Get informations from the proxied connection, if it was procied + /// + /// # Errors + /// + /// Returns an error if the stream did not complete the handshake yet + pub fn proxy_info(&self) -> Result, HandshakeNotDone> { + match self { + Self::Proxied { stream } => Ok(Some(stream.proxy_info()?)), + Self::NotProxied { .. } => Ok(None), + } + } + + pub const fn is_proxy_handshaking(&self) -> bool { + match self { + Self::Proxied { stream } => stream.is_handshaking(), + Self::NotProxied { .. } => false, + } + } +} + +impl Deref for MaybeProxyStream { + type Target = S; + fn deref(&self) -> &Self::Target { + match self { + Self::Proxied { stream } => &**stream, + Self::NotProxied { stream } => stream, + } + } } impl AsyncRead for MaybeProxyStream diff --git a/crates/listener/src/proxy_protocol/mod.rs b/crates/listener/src/proxy_protocol/mod.rs index 39ee2d0f..7549e70d 100644 --- a/crates/listener/src/proxy_protocol/mod.rs +++ b/crates/listener/src/proxy_protocol/mod.rs @@ -20,6 +20,6 @@ mod v1; pub use self::{ acceptor::ProxyAcceptor, maybe::{MaybeProxyAcceptor, MaybeProxyStream}, - stream::ProxyStream, + stream::{HandshakeNotDone as ProxyHandshakeNotDone, ProxyStream}, v1::ProxyProtocolV1Info, }; diff --git a/crates/listener/src/proxy_protocol/stream.rs b/crates/listener/src/proxy_protocol/stream.rs index a59b1829..9126883d 100644 --- a/crates/listener/src/proxy_protocol/stream.rs +++ b/crates/listener/src/proxy_protocol/stream.rs @@ -15,6 +15,7 @@ use std::ops::Deref; use futures_util::ready; +use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use super::ProxyProtocolV1Info; @@ -31,6 +32,12 @@ enum ProxyStreamState { Established(ProxyProtocolV1Info), } +impl ProxyStreamState { + pub const fn is_handshaking(&self) -> bool { + matches!(self, Self::Handshaking { .. }) + } +} + pin_project_lite::pin_project! { #[derive(Debug)] pub struct ProxyStream { @@ -53,6 +60,10 @@ impl ProxyStream { } } +#[derive(Debug, Error, Clone, Copy)] +#[error("Proxy protocol handshake is not complete")] +pub struct HandshakeNotDone; + impl Deref for ProxyStream { type Target = S; fn deref(&self) -> &Self::Target { @@ -61,12 +72,22 @@ impl Deref for ProxyStream { } impl ProxyStream { - pub fn proxy_info(&self) -> Option<&ProxyProtocolV1Info> { + /// Get informations from the proxied connection + /// + /// # Errors + /// + /// Returns an error if the stream did not complete the handshake yet + pub fn proxy_info(&self) -> Result<&ProxyProtocolV1Info, HandshakeNotDone> { match &self.state { - ProxyStreamState::Handshaking { .. } => None, - ProxyStreamState::Established(info) => Some(info), + ProxyStreamState::Handshaking { .. } => Err(HandshakeNotDone), + ProxyStreamState::Established(info) => Ok(info), } } + + /// Returns `true` if the proxy protocol is still handshaking + pub const fn is_handshaking(&self) -> bool { + self.state.is_handshaking() + } } impl AsyncRead for ProxyStream diff --git a/crates/listener/src/proxy_protocol/v1.rs b/crates/listener/src/proxy_protocol/v1.rs index 3789a13c..32412c01 100644 --- a/crates/listener/src/proxy_protocol/v1.rs +++ b/crates/listener/src/proxy_protocol/v1.rs @@ -20,7 +20,7 @@ use std::{ use thiserror::Error; -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum ProxyProtocolV1Info { Tcp { source: SocketAddr, diff --git a/crates/listener/src/unix_or_tcp.rs b/crates/listener/src/unix_or_tcp.rs index 1357bb7d..e66a1ad0 100644 --- a/crates/listener/src/unix_or_tcp.rs +++ b/crates/listener/src/unix_or_tcp.rs @@ -52,6 +52,40 @@ impl std::fmt::Debug for SocketAddr { } } +impl SocketAddr { + #[must_use] + pub fn into_net(self) -> Option { + match self { + Self::Net(socket) => Some(socket), + Self::Unix(_) => None, + } + } + + #[must_use] + pub fn into_unix(self) -> Option { + match self { + Self::Net(_) => None, + Self::Unix(socket) => Some(socket), + } + } + + #[must_use] + pub const fn as_net(&self) -> Option<&std::net::SocketAddr> { + match self { + Self::Net(socket) => Some(socket), + Self::Unix(_) => None, + } + } + + #[must_use] + pub const fn as_unix(&self) -> Option<&tokio::net::unix::SocketAddr> { + match self { + Self::Net(_) => None, + Self::Unix(socket) => Some(socket), + } + } +} + pub enum UnixOrTcpListener { Unix(UnixListener), Tcp(TcpListener), @@ -98,6 +132,14 @@ impl UnixOrTcpListener { Self::Tcp(listener) => listener.local_addr().map(SocketAddr::from), } } + + pub const fn is_unix(&self) -> bool { + matches!(self, Self::Unix(_)) + } + + pub const fn is_tcp(&self) -> bool { + matches!(self, Self::Tcp(_)) + } } pin_project_lite::pin_project! {