1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-09 04:22:45 +03:00

Inject connection informations in the request extension

This commit is contained in:
Quentin Gliech
2022-10-06 16:30:24 +02:00
parent fc5c8314b5
commit 485778beb3
13 changed files with 530 additions and 15 deletions

2
Cargo.lock generated
View File

@@ -2683,6 +2683,8 @@ dependencies = [
"thiserror",
"tokio",
"tokio-rustls 0.23.4",
"tower",
"tower-http",
"tracing",
]

View File

@@ -25,7 +25,11 @@ use mas_config::RootConfig;
use mas_email::Mailer;
use mas_handlers::{AppState, MatrixHomeserver};
use mas_http::ServerLayer;
use mas_listener::{maybe_tls::MaybeTlsAcceptor, proxy_protocol::MaybeProxyAcceptor};
use mas_listener::{
info::{ConnectionInfoAcceptor, IntoMakeServiceWithConnection},
maybe_tls::MaybeTlsAcceptor,
proxy_protocol::MaybeProxyAcceptor,
};
use mas_policy::PolicyFactory;
use mas_router::UrlBuilder;
use mas_storage::MIGRATOR;
@@ -276,6 +280,7 @@ impl Options {
info!("Listening on {addresses:?} with resources {resources:?} {additional}", resources = &config.resources);
let router = crate::server::build_router(&state, &config.resources).layer(ServerLayer::new(config.name.clone()));
let make_service = IntoMakeServiceWithConnection::new(router);
async move {
let tls_config = if let Some(tls_config) = config.tls.as_ref() {
@@ -288,9 +293,10 @@ impl Options {
.try_for_each_concurrent(None, move |listener| {
let listener = MaybeTlsAcceptor::new(tls_config.clone(), listener);
let listener = MaybeProxyAcceptor::new(listener, config.proxy_protocol);
let listener = ConnectionInfoAcceptor::new(listener);
Server::builder(listener)
.serve(router.clone().into_make_service())
.serve(make_service.clone())
.with_graceful_shutdown(signal.clone())
})
.await?;

View File

@@ -19,13 +19,14 @@ use std::{
};
use anyhow::Context;
use axum::{body::HttpBody, Router};
use axum::{body::HttpBody, Extension, Router};
use listenfd::ListenFd;
use mas_config::{HttpBindConfig, HttpResource, HttpTlsConfig, UnixOrTcp};
use mas_handlers::AppState;
use mas_listener::unix_or_tcp::UnixOrTcpListener;
use mas_listener::{info::Connection, unix_or_tcp::UnixOrTcpListener};
use mas_router::Route;
use rustls::ServerConfig;
use tokio::sync::OnceCell;
#[allow(clippy::trait_duplication_in_bounds)]
pub fn build_router<B>(state: &Arc<AppState>, resources: &[HttpResource]) -> Router<AppState, B>
@@ -60,6 +61,16 @@ where
mas_config::HttpResource::Compat => {
router.merge(mas_handlers::compat_router(state.clone()))
}
// TODO: do a better handler here
mas_config::HttpResource::ConnectionInfo => router.route(
"/connection-info",
axum::routing::get(
|connection: Extension<Arc<OnceCell<Connection>>>| async move {
let connection = connection.get().unwrap();
format!("{connection:?}")
},
),
),
}
}

View File

@@ -248,6 +248,11 @@ pub enum Resource {
#[serde(default)]
web_root: Option<PathBuf>,
},
/// Mount a "/connection-info" handler which helps debugging informations on
/// the upstream connection
#[serde(rename = "connection-info")]
ConnectionInfo,
}
/// Configuration of a listener

View File

@@ -6,10 +6,12 @@ edition = "2021"
license = "Apache-2.0"
[dependencies]
futures-util = "0.3.24"
hyper = { version = "0.14.20", features = ["server"] }
pin-project-lite = "0.2.9"
thiserror = "1.0.37"
tokio = { version = "1.21.2", features = ["net"] }
pin-project-lite = "0.2.9"
hyper = { version = "0.14.20", features = ["server"] }
futures-util = "0.3.24"
tracing = "0.1.36"
tokio-rustls = "0.23.4"
tower = "0.4.13"
tower-http = { version = "0.3.4", features = ["add-extension"] }
tracing = "0.1.36"

323
crates/listener/src/info.rs Normal file
View File

@@ -0,0 +1,323 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
future::Ready,
net::SocketAddr,
ops::Deref,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use hyper::server::accept::Accept;
use thiserror::Error;
use tokio::{
io::{AsyncRead, AsyncWrite},
sync::OnceCell,
};
use tower::Service;
use tower_http::add_extension::AddExtension;
use crate::{
maybe_tls::{MaybeTlsAcceptor, MaybeTlsStream, TlsStreamInfo, TlsStreamInfoError},
proxy_protocol::{
MaybeProxyAcceptor, MaybeProxyStream, ProxyHandshakeNotDone, ProxyProtocolV1Info,
},
unix_or_tcp::{UnixOrTcpConnection, UnixOrTcpListener},
};
// TODO: this is a mess, clean that up
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum FromStreamError {
#[error(transparent)]
Proxy(#[from] ProxyHandshakeNotDone),
#[error(transparent)]
Tls(#[from] TlsStreamInfoError),
#[error("Could not grab a reference to the underlying stream")]
GetRef,
#[error("Could not get address info from underlying stream")]
IoError(#[from] std::io::Error),
}
#[derive(Debug, Clone)]
pub struct Connection {
proxy: Option<ProxyProtocolV1Info>,
tls: Option<TlsStreamInfo>,
// We're not saving the UNIX domain socket address here because it can't be cloned, which is
// required for injecting the connection information as an extension
local_tcp_addr: Option<SocketAddr>,
peer_tcp_addr: Option<SocketAddr>,
}
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum GrabAddressError {
#[error("Proxy protocol was initiated with an unknown protocol")]
ProxyUnknown,
#[error("Proxy protocol was initiated with UDP")]
ProxyUdp,
#[error("Underlying listener is a UNIX socket")]
UnixListener,
}
impl MaybeProxyAcceptor<MaybeTlsAcceptor<UnixOrTcpListener>> {
pub fn can_have_peer_address(&self) -> bool {
self.is_proxied() || self.is_tcp()
}
}
impl MaybeProxyStream<MaybeTlsStream<UnixOrTcpConnection>> {
/// Get informations about this connection
///
/// # Errors
///
/// Returns an error if the proxy protocol or the TLS handhakes are not done
/// yet
pub fn connection_info(&self) -> Result<Connection, FromStreamError> {
Connection::from_stream(self)
}
}
impl Connection {
/// Get informations about this connection
///
/// # Errors
///
/// Returns an error if the proxy protocol or the TLS handhakes are not done
/// yet
pub fn from_stream(
stream: &MaybeProxyStream<MaybeTlsStream<UnixOrTcpConnection>>,
) -> Result<Self, FromStreamError> {
let proxy = stream.proxy_info()?.cloned();
let tls = stream.tls_info()?;
let original = stream.get_ref().ok_or(FromStreamError::GetRef)?;
let local_tcp_addr = original.local_addr()?.into_net();
let peer_tcp_addr = original.peer_addr()?.into_net();
Ok(Self {
proxy,
tls,
local_tcp_addr,
peer_tcp_addr,
})
}
#[must_use]
pub const fn is_proxied(&self) -> bool {
self.proxy.is_some()
}
#[must_use]
pub const fn is_tls(&self) -> bool {
self.tls.is_some()
}
/// Get the outmost peer address, either from the TCP listener or from the
/// proxy protocol infos.
///
/// # Errors
///
/// Returns an error if the info from the proxy protocol was not for a TCP
/// connection, or if the proxy protocol is not being used, the underlying
/// listener was a UNIX domain socket
pub fn peer_addr(&self) -> Result<&SocketAddr, GrabAddressError> {
if let Some(proxy) = self.proxy.as_ref() {
if proxy.is_udp() {
return Err(GrabAddressError::ProxyUdp);
}
proxy.source().ok_or(GrabAddressError::ProxyUnknown)
} else {
self.peer_tcp_addr
.as_ref()
.ok_or(GrabAddressError::UnixListener)
}
}
/// Get the outmost local address, either from the TCP listener or from the
/// proxy protocol infos.
///
/// # Errors
///
/// Returns an error if the info from the proxy protocol was not for a TCP
/// connection, or if the proxy protocol is not being used, the underlying
/// listener was a UNIX domain socket
pub fn local_addr(&self) -> Result<&SocketAddr, GrabAddressError> {
if let Some(proxy) = self.proxy.as_ref() {
if proxy.is_udp() {
return Err(GrabAddressError::ProxyUdp);
}
proxy.destination().ok_or(GrabAddressError::ProxyUnknown)
} else {
self.local_tcp_addr
.as_ref()
.ok_or(GrabAddressError::UnixListener)
}
}
}
pin_project_lite::pin_project! {
pub struct ConnectionInfoAcceptor {
#[pin]
acceptor: MaybeProxyAcceptor<MaybeTlsAcceptor<UnixOrTcpListener>>,
}
}
impl ConnectionInfoAcceptor {
pub const fn new(acceptor: MaybeProxyAcceptor<MaybeTlsAcceptor<UnixOrTcpListener>>) -> Self {
Self { acceptor }
}
}
impl Accept for ConnectionInfoAcceptor {
type Conn = ConnectionInfoStream;
type Error = std::io::Error;
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
let proj = self.project();
let ret = match futures_util::ready!(proj.acceptor.poll_accept(cx)) {
Some(Ok(conn)) => Some(Ok(ConnectionInfoStream::new(conn))),
Some(Err(e)) => Some(Err(e)),
None => None,
};
Poll::Ready(ret)
}
}
pin_project_lite::pin_project! {
pub struct ConnectionInfoStream {
connection: Arc<OnceCell<Connection>>,
#[pin]
stream: MaybeProxyStream<MaybeTlsStream<UnixOrTcpConnection>>,
}
}
impl ConnectionInfoStream {
pub fn new(stream: MaybeProxyStream<MaybeTlsStream<UnixOrTcpConnection>>) -> Self {
Self {
connection: Arc::new(OnceCell::const_new()),
stream,
}
}
}
impl Deref for ConnectionInfoStream {
type Target = MaybeProxyStream<MaybeTlsStream<UnixOrTcpConnection>>;
fn deref(&self) -> &Self::Target {
&self.stream
}
}
impl AsyncRead for ConnectionInfoStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
let this = self.get_mut();
futures_util::ready!(Pin::new(&mut this.stream).poll_read(cx, buf))?;
if !this.stream.is_tls_handshaking()
&& !this.stream.is_proxy_handshaking()
&& !this.connection.initialized()
{
this.connection
.set(this.stream.connection_info().unwrap())
.unwrap();
}
Poll::Ready(Ok(()))
}
}
impl AsyncWrite for ConnectionInfoStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
let proj = self.project();
proj.stream.poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
let proj = self.project();
proj.stream.poll_flush(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let proj = self.project();
proj.stream.poll_shutdown(cx)
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<Result<usize, std::io::Error>> {
let proj = self.project();
proj.stream.poll_write_vectored(cx, bufs)
}
fn is_write_vectored(&self) -> bool {
self.stream.is_write_vectored()
}
}
#[derive(Debug, Clone)]
pub struct IntoMakeServiceWithConnection<S> {
svc: S,
}
impl<S> IntoMakeServiceWithConnection<S> {
pub const fn new(svc: S) -> Self {
Self { svc }
}
}
impl<S> Service<&ConnectionInfoStream> for IntoMakeServiceWithConnection<S>
where
S: Clone,
{
type Response = AddExtension<S, Arc<OnceCell<Connection>>>;
type Error = FromStreamError;
type Future = Ready<Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, target: &ConnectionInfoStream) -> Self::Future {
std::future::ready(Ok(AddExtension::new(
self.svc.clone(),
target.connection.clone(),
)))
}
}

View File

@@ -22,6 +22,7 @@
#![warn(clippy::pedantic)]
#![allow(clippy::module_name_repetitions)]
pub mod info;
pub mod maybe_tls;
pub mod proxy_protocol;
pub mod unix_or_tcp;

View File

@@ -21,8 +21,31 @@ use std::{
use futures_util::{ready, Future};
use hyper::server::accept::Accept;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_rustls::rustls::{ServerConfig, ServerConnection};
use tokio_rustls::rustls::{
Certificate, ProtocolVersion, ServerConfig, ServerConnection, SupportedCipherSuite,
};
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct TlsStreamInfo {
pub protocol_version: ProtocolVersion,
pub negotiated_cipher_suite: SupportedCipherSuite,
pub sni_hostname: Option<String>,
pub apln_protocol: Option<Vec<u8>>,
pub peer_certificates: Option<Vec<Certificate>>,
}
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum TlsStreamInfoError {
#[error("TLS handshake is not done yet")]
HandshakingNotDone,
#[error("Some fields were not available in the TLS connection")]
FieldsNotAvailable,
}
pub enum MaybeTlsStream<T> {
Handshaking(tokio_rustls::Accept<T>),
@@ -72,6 +95,44 @@ impl<T> MaybeTlsStream<T> {
Self::Handshaking(_) | Self::Insecure(_) => None,
}
}
/// Gather informations about the TLS connection. Returns `None` if the
/// stream is not a TLS stream.
///
/// # Errors
///
/// Returns an error if the TLS handshake is not yet done
pub fn tls_info(&self) -> Result<Option<TlsStreamInfo>, TlsStreamInfoError> {
let conn = match self {
Self::Streaming(stream) => stream.get_ref().1,
Self::Handshaking(_) => return Err(TlsStreamInfoError::HandshakingNotDone),
Self::Insecure(_) => return Ok(None),
};
// NOTE: we're getting the protocol version and cipher suite *after* the
// handshake, so this should never lead to an error
let protocol_version = conn
.protocol_version()
.ok_or(TlsStreamInfoError::FieldsNotAvailable)?;
let negotiated_cipher_suite = conn
.negotiated_cipher_suite()
.ok_or(TlsStreamInfoError::FieldsNotAvailable)?;
let sni_hostname = conn.sni_hostname().map(ToOwned::to_owned);
let apln_protocol = conn.alpn_protocol().map(ToOwned::to_owned);
let peer_certificates = conn.peer_certificates().map(ToOwned::to_owned);
Ok(Some(TlsStreamInfo {
protocol_version,
negotiated_cipher_suite,
sni_hostname,
apln_protocol,
peer_certificates,
}))
}
pub const fn is_tls_handshaking(&self) -> bool {
matches!(self, Self::Handshaking(_))
}
}
impl<T> AsyncRead for MaybeTlsStream<T>

View File

@@ -13,6 +13,7 @@
// limitations under the License.
use std::{
ops::Deref,
pin::Pin,
task::{Context, Poll},
};
@@ -21,7 +22,7 @@ use futures_util::ready;
use hyper::server::accept::Accept;
use tokio::io::{AsyncRead, AsyncWrite};
use super::ProxyStream;
use super::{stream::HandshakeNotDone, ProxyProtocolV1Info, ProxyStream};
pin_project_lite::pin_project! {
pub struct MaybeProxyAcceptor<A> {
@@ -37,6 +38,17 @@ impl<A> MaybeProxyAcceptor<A> {
pub const fn new(inner: A, proxied: bool) -> Self {
Self { proxied, inner }
}
pub const fn is_proxied(&self) -> bool {
self.proxied
}
}
impl<A> Deref for MaybeProxyAcceptor<A> {
type Target = A;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl<A> Accept for MaybeProxyAcceptor<A>
@@ -79,6 +91,35 @@ impl<S> MaybeProxyStream<S> {
Self::NotProxied { stream }
}
}
/// Get informations from the proxied connection, if it was procied
///
/// # Errors
///
/// Returns an error if the stream did not complete the handshake yet
pub fn proxy_info(&self) -> Result<Option<&ProxyProtocolV1Info>, HandshakeNotDone> {
match self {
Self::Proxied { stream } => Ok(Some(stream.proxy_info()?)),
Self::NotProxied { .. } => Ok(None),
}
}
pub const fn is_proxy_handshaking(&self) -> bool {
match self {
Self::Proxied { stream } => stream.is_handshaking(),
Self::NotProxied { .. } => false,
}
}
}
impl<S> Deref for MaybeProxyStream<S> {
type Target = S;
fn deref(&self) -> &Self::Target {
match self {
Self::Proxied { stream } => &**stream,
Self::NotProxied { stream } => stream,
}
}
}
impl<S> AsyncRead for MaybeProxyStream<S>

View File

@@ -20,6 +20,6 @@ mod v1;
pub use self::{
acceptor::ProxyAcceptor,
maybe::{MaybeProxyAcceptor, MaybeProxyStream},
stream::ProxyStream,
stream::{HandshakeNotDone as ProxyHandshakeNotDone, ProxyStream},
v1::ProxyProtocolV1Info,
};

View File

@@ -15,6 +15,7 @@
use std::ops::Deref;
use futures_util::ready;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use super::ProxyProtocolV1Info;
@@ -31,6 +32,12 @@ enum ProxyStreamState {
Established(ProxyProtocolV1Info),
}
impl ProxyStreamState {
pub const fn is_handshaking(&self) -> bool {
matches!(self, Self::Handshaking { .. })
}
}
pin_project_lite::pin_project! {
#[derive(Debug)]
pub struct ProxyStream<S> {
@@ -53,6 +60,10 @@ impl<S> ProxyStream<S> {
}
}
#[derive(Debug, Error, Clone, Copy)]
#[error("Proxy protocol handshake is not complete")]
pub struct HandshakeNotDone;
impl<S> Deref for ProxyStream<S> {
type Target = S;
fn deref(&self) -> &Self::Target {
@@ -61,12 +72,22 @@ impl<S> Deref for ProxyStream<S> {
}
impl<S> ProxyStream<S> {
pub fn proxy_info(&self) -> Option<&ProxyProtocolV1Info> {
/// Get informations from the proxied connection
///
/// # Errors
///
/// Returns an error if the stream did not complete the handshake yet
pub fn proxy_info(&self) -> Result<&ProxyProtocolV1Info, HandshakeNotDone> {
match &self.state {
ProxyStreamState::Handshaking { .. } => None,
ProxyStreamState::Established(info) => Some(info),
ProxyStreamState::Handshaking { .. } => Err(HandshakeNotDone),
ProxyStreamState::Established(info) => Ok(info),
}
}
/// Returns `true` if the proxy protocol is still handshaking
pub const fn is_handshaking(&self) -> bool {
self.state.is_handshaking()
}
}
impl<S> AsyncRead for ProxyStream<S>

View File

@@ -20,7 +20,7 @@ use std::{
use thiserror::Error;
#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum ProxyProtocolV1Info {
Tcp {
source: SocketAddr,

View File

@@ -52,6 +52,40 @@ impl std::fmt::Debug for SocketAddr {
}
}
impl SocketAddr {
#[must_use]
pub fn into_net(self) -> Option<std::net::SocketAddr> {
match self {
Self::Net(socket) => Some(socket),
Self::Unix(_) => None,
}
}
#[must_use]
pub fn into_unix(self) -> Option<tokio::net::unix::SocketAddr> {
match self {
Self::Net(_) => None,
Self::Unix(socket) => Some(socket),
}
}
#[must_use]
pub const fn as_net(&self) -> Option<&std::net::SocketAddr> {
match self {
Self::Net(socket) => Some(socket),
Self::Unix(_) => None,
}
}
#[must_use]
pub const fn as_unix(&self) -> Option<&tokio::net::unix::SocketAddr> {
match self {
Self::Net(_) => None,
Self::Unix(socket) => Some(socket),
}
}
}
pub enum UnixOrTcpListener {
Unix(UnixListener),
Tcp(TcpListener),
@@ -98,6 +132,14 @@ impl UnixOrTcpListener {
Self::Tcp(listener) => listener.local_addr().map(SocketAddr::from),
}
}
pub const fn is_unix(&self) -> bool {
matches!(self, Self::Unix(_))
}
pub const fn is_tcp(&self) -> bool {
matches!(self, Self::Tcp(_))
}
}
pin_project_lite::pin_project! {