1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

WIP: better listeners

- listen on UNIX domain sockets
- handle TLS stuff
- allow mounting only some resources
This commit is contained in:
Quentin Gliech
2022-10-03 22:19:08 +02:00
parent 7fbfb74a5e
commit 84ac87f551
12 changed files with 1063 additions and 170 deletions

View File

@ -20,6 +20,7 @@ argon2 = { version = "0.4.1", features = ["password-hash"] }
watchman_client = "0.8.0"
atty = "0.2.14"
listenfd = "1.0.0"
rustls = "0.20.6"
tracing = "0.1.36"
tracing-appender = "0.2.2"
@ -44,6 +45,7 @@ mas-static-files = { path = "../static-files" }
mas-storage = { path = "../storage" }
mas-tasks = { path = "../tasks" }
mas-templates = { path = "../templates" }
mas-listener = { path = "../listener" }
[dev-dependencies]
indoc = "1.0.7"

View File

@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{net::SocketAddr, sync::Arc, time::Duration};
use std::{sync::Arc, time::Duration};
use anyhow::Context;
use clap::Parser;
use futures_util::{
future::FutureExt,
future::{FutureExt, OptionFuture},
stream::{StreamExt, TryStreamExt},
};
use hyper::Server;
@ -25,6 +25,7 @@ use mas_config::RootConfig;
use mas_email::Mailer;
use mas_handlers::{AppState, MatrixHomeserver};
use mas_http::ServerLayer;
use mas_listener::{maybe_tls::MaybeTlsAcceptor, unix_or_tcp::UnixOrTcpListener};
use mas_policy::PolicyFactory;
use mas_router::{Route, UrlBuilder};
use mas_storage::MIGRATOR;
@ -245,39 +246,103 @@ impl Options {
});
let signal = shutdown_signal().shared();
let shutdown_signal = signal.clone();
let mut fd_manager = listenfd::ListenFd::from_env();
let futs = listeners_config
.into_iter()
.map(|listener_config| {
let listeners = listeners_config.into_iter().map(|listener_config| {
// We have to borrow it here, not in the nested closure
let fd_manager = &mut fd_manager;
// Let's first grab all the listeners in a synchronous manner
// This helps with the fd_manager mutable borrow
let listeners: Result<Vec<UnixOrTcpListener>, _> = listener_config
.binds
.iter()
.map(move |bind_config| bind_config.listener(fd_manager))
.collect();
Ok((listener_config, listeners?))
});
// Now that we have the listeners ready, we can do the rest concurrently
futures_util::stream::iter(listeners)
.try_for_each_concurrent(None, move |(config, listeners)| {
let signal = signal.clone();
let router = mas_handlers::router(state.clone())
.nest(mas_router::StaticAsset::route(), static_files.clone())
.layer(ServerLayer::default());
let mut futs: Vec<_> = Vec::with_capacity(listener_config.binds.len());
for bind in listener_config.binds {
let listener = bind.listener(&mut fd_manager)?;
let router = router.clone();
let mut router = mas_handlers::empty_router(state.clone());
let addr = listener.local_addr()?;
info!("Listening on http://{addr}");
let fut = Server::from_tcp(listener)?
.serve(router.into_make_service_with_connect_info::<SocketAddr>())
.with_graceful_shutdown(signal.clone());
futs.push(fut);
for resource in config.resources {
router = match resource {
mas_config::HttpResource::Health => {
router.merge(mas_handlers::healthcheck_router(state.clone()))
}
mas_config::HttpResource::Discovery => {
router.merge(mas_handlers::discovery_router(state.clone()))
}
mas_config::HttpResource::Human => {
router.merge(mas_handlers::human_router(state.clone()))
}
mas_config::HttpResource::Static => {
router.nest(mas_router::StaticAsset::route(), static_files.clone())
}
mas_config::HttpResource::OAuth => {
router.merge(mas_handlers::api_router(state.clone()))
}
mas_config::HttpResource::Compat => {
router.merge(mas_handlers::compat_router(state.clone()))
}
}
}
anyhow::Ok(futures_util::future::try_join_all(futs))
})
.collect::<Result<Vec<_>, _>>()?;
let router = router.layer(ServerLayer::default());
futures_util::future::try_join_all(futs).await?;
async move {
let tls_config: OptionFuture<_> = config
.tls
.map(|tls_config| async move {
let (key, chain) = tls_config.load().await?;
let key = rustls::PrivateKey(key);
let chain = chain.into_iter().map(rustls::Certificate).collect();
let mut config = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(chain, key)
.context("failed to build TLS server config")?;
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
anyhow::Ok(Arc::new(config))
})
.into();
let tls_config = tls_config.await.transpose()?;
futures_util::stream::iter(listeners)
.map(Ok)
.try_for_each_concurrent(None, move |listener| {
let listener = MaybeTlsAcceptor::new(tls_config.clone(), listener);
// Unless there is something really bad happening, we should be able to
// grab the local_addr here. Panicking here if it is not the case is
// probably fine.
let addr = listener.local_addr().unwrap();
if listener.is_secure() {
info!("Listening on https://{addr:?}");
} else {
info!("Listening on http://{addr:?}");
}
Server::builder(listener)
.serve(router.clone().into_make_service())
.with_graceful_shutdown(signal.clone())
})
.await?;
anyhow::Ok(())
}
})
.await?;
// This ensures we're running, even if no listener are setup
// This is useful for only running the task runner
signal.await;
shutdown_signal.await;
Ok(())
}

View File

@ -27,6 +27,7 @@ lettre = { version = "0.10.1", default-features = false, features = ["serde", "b
listenfd = "1.0.0"
pem-rfc7468 = "0.6.0"
rustls-pemfile = "1.0.1"
rand = "0.8.5"
indoc = "1.0.7"

View File

@ -13,22 +13,23 @@
// limitations under the License.
use std::{
net::{SocketAddr, TcpListener},
borrow::Cow,
io::Cursor,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, TcpListener, ToSocketAddrs},
ops::Deref,
os::unix::net::UnixListener,
path::PathBuf,
};
use anyhow::Context;
use anyhow::{bail, Context};
use async_trait::async_trait;
use listenfd::ListenFd;
use mas_keystore::PrivateKey;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use url::Url;
use super::ConfigurationSection;
fn default_http_address() -> String {
"[::]:8080".into()
}
use super::{secrets::PasswordOrFile, ConfigurationSection};
fn default_public_base() -> Url {
"http://[::]:8080".parse().unwrap()
@ -47,9 +48,31 @@ fn http_address_example_4() -> &'static str {
"0.0.0.0:8080"
}
#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)]
#[serde(rename_all = "lowercase")]
pub enum UnixOrTcp {
Unix,
Tcp,
}
impl UnixOrTcp {
pub const fn unix() -> Self {
Self::Unix
}
pub const fn tcp() -> Self {
Self::Tcp
}
}
#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)]
#[serde(untagged)]
pub enum BindConfig {
Listen {
host: Option<String>,
port: u16,
},
Address {
#[schemars(
example = "http_address_example_1",
@ -59,39 +82,202 @@ pub enum BindConfig {
)]
address: String,
},
Unix {
socket: PathBuf,
},
FileDescriptor {
fd: usize,
#[serde(default = "UnixOrTcp::tcp")]
kind: UnixOrTcp,
},
}
impl BindConfig {
pub fn listener(self, fd_manager: &mut ListenFd) -> Result<TcpListener, anyhow::Error> {
// TODO: move this somewhere else
pub fn listener<T>(&self, fd_manager: &mut ListenFd) -> Result<T, anyhow::Error>
where
T: TryFrom<TcpListener> + TryFrom<UnixListener>,
<T as TryFrom<TcpListener>>::Error: std::error::Error + Sync + Send + 'static,
<T as TryFrom<UnixListener>>::Error: std::error::Error + Sync + Send + 'static,
{
match self {
BindConfig::Listen { host, port } => {
let addrs = match host.as_deref() {
Some(host) => (host, *port)
.to_socket_addrs()
.context("could not parse listener host")?
.collect(),
None => vec![
SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), *port),
SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), *port),
],
};
let listener = TcpListener::bind(&addrs[..]).context("could not bind address")?;
listener.set_nonblocking(true)?;
Ok(listener.try_into()?)
}
BindConfig::Address { address } => {
let addr: SocketAddr = address
.parse()
.context("could not parse listener address")?;
let listener = TcpListener::bind(addr).context("could not bind address")?;
Ok(listener)
listener.set_nonblocking(true)?;
Ok(listener.try_into()?)
}
BindConfig::FileDescriptor { fd } => {
let listener = fd_manager
.take_tcp_listener(fd)?
.context("no listener found on file descriptor")?;
// XXX: Do I need that?
BindConfig::Unix { socket } => {
let listener = UnixListener::bind(socket).context("could not bind socket")?;
listener.set_nonblocking(true)?;
Ok(listener)
Ok(listener.try_into()?)
}
BindConfig::FileDescriptor {
fd,
kind: UnixOrTcp::Tcp,
} => {
let listener = fd_manager
.take_tcp_listener(*fd)?
.context("no listener found on file descriptor")?;
listener.set_nonblocking(true)?;
Ok(listener.try_into()?)
}
BindConfig::FileDescriptor {
fd,
kind: UnixOrTcp::Unix,
} => {
let listener = fd_manager
.take_unix_listener(*fd)?
.context("no unix socket found on file descriptor")?;
listener.set_nonblocking(true)?;
Ok(listener.try_into()?)
}
}
}
}
#[derive(JsonSchema, Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum KeyOrFile {
Key(String),
KeyFile(PathBuf),
}
#[derive(JsonSchema, Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum CertificateOrFile {
Certificate(String),
CertificateFile(PathBuf),
}
#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)]
pub struct TlsConfig {
#[serde(flatten)]
pub certificate: CertificateOrFile,
#[serde(flatten)]
pub key: KeyOrFile,
#[serde(flatten)]
pub password: Option<PasswordOrFile>,
}
impl TlsConfig {
pub async fn load(&self) -> Result<(Vec<u8>, Vec<Vec<u8>>), anyhow::Error> {
let password = match &self.password {
Some(PasswordOrFile::Password(password)) => Some(Cow::Borrowed(password.as_str())),
Some(PasswordOrFile::PasswordFile(path)) => {
Some(Cow::Owned(tokio::fs::read_to_string(path).await?))
}
None => None,
};
// Read the key either embedded in the config file or on disk
let key = match &self.key {
KeyOrFile::Key(key) => {
// If the key was embedded in the config file, assume it is formatted as PEM
if let Some(password) = password {
PrivateKey::load_encrypted_pem(key, password.as_bytes())?
} else {
PrivateKey::load_pem(key)?
}
}
KeyOrFile::KeyFile(path) => {
// When reading from disk, it might be either PEM or DER. `PrivateKey::load*`
// will try both.
let key = tokio::fs::read(path).await?;
if let Some(password) = password {
PrivateKey::load_encrypted(&key, password.as_bytes())?
} else {
PrivateKey::load(&key)?
}
}
};
// Re-serialize the key to PKCS#8 DER, so rustls can consume it
let key = key.to_pkcs8_der()?;
// This extracts the Vec out of the Zeroizing by copying it
// XXX: maybe we should keep that zeroizing?
let key = key.deref().clone();
let certificate_chain_pem = match &self.certificate {
CertificateOrFile::Certificate(pem) => Cow::Borrowed(pem.as_str()),
CertificateOrFile::CertificateFile(path) => {
Cow::Owned(tokio::fs::read_to_string(path).await?)
}
};
let mut certificate_chain_reader = Cursor::new(certificate_chain_pem.as_bytes());
let certificate_chain = rustls_pemfile::certs(&mut certificate_chain_reader)?;
if certificate_chain.is_empty() {
bail!("TLS certificate chain is empty (or invalid)")
}
Ok((key, certificate_chain))
}
}
/// HTTP resources to mount
#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)]
#[serde(tag = "name", rename_all = "lowercase")]
pub enum Resource {
/// Healthcheck endpoint (/health)
Health,
/// OIDC discovery endpoints
Discovery,
/// Pages destined to be viewed by humans
Human,
/// OAuth-related APIs
OAuth,
/// Matrix compatibility API
Compat,
/// Static files
Static,
}
/// Configuration of a listener
#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)]
pub struct ListenerConfig {
pub name: Option<String>,
/// List of resources to mount
pub resources: Vec<Resource>,
/// List of sockets to bind
pub binds: Vec<BindConfig>,
/// If set, makes the listener use TLS with the provided certificate and key
pub tls: Option<TlsConfig>,
}
/// Configuration related to the web server
@ -114,12 +300,28 @@ impl Default for HttpConfig {
fn default() -> Self {
Self {
web_root: None,
listeners: vec![ListenerConfig {
name: None,
binds: vec![BindConfig::Address {
address: default_http_address(),
}],
}],
listeners: vec![
ListenerConfig {
resources: vec![
Resource::Discovery,
Resource::Human,
Resource::OAuth,
Resource::Compat,
Resource::Static,
],
tls: None,
binds: vec![BindConfig::Address {
address: "[::]:8080".into(),
}],
},
ListenerConfig {
resources: vec![Resource::Health],
tls: None,
binds: vec![BindConfig::Address {
address: "localhost:8081".into(),
}],
},
],
public_base: default_public_base(),
}
}

View File

@ -32,7 +32,7 @@ pub use self::{
csrf::CsrfConfig,
database::DatabaseConfig,
email::{EmailConfig, EmailSmtpMode, EmailTransportConfig},
http::HttpConfig,
http::{HttpConfig, Resource as HttpResource},
matrix::MatrixConfig,
policy::PolicyConfig,
secrets::SecretsConfig,

View File

@ -50,6 +50,57 @@ pub use compat::MatrixHomeserver;
pub use self::app_state::AppState;
#[must_use]
pub fn empty_router<S, B>(state: Arc<S>) -> Router<S, B>
where
B: HttpBody + Send + 'static,
S: Send + Sync + 'static,
{
Router::with_state_arc(state)
}
#[must_use]
pub fn healthcheck_router<S, B>(state: Arc<S>) -> Router<S, B>
where
B: HttpBody + Send + 'static,
S: Send + Sync + 'static,
PgPool: FromRef<S>,
{
Router::with_state_arc(state).route(mas_router::Healthcheck::route(), get(self::health::get))
}
#[must_use]
pub fn discovery_router<S, B>(state: Arc<S>) -> Router<S, B>
where
B: HttpBody + Send + 'static,
S: Send + Sync + 'static,
Keystore: FromRef<S>,
UrlBuilder: FromRef<S>,
{
Router::with_state_arc(state)
.route(
mas_router::OidcConfiguration::route(),
get(self::oauth2::discovery::get),
)
.route(
mas_router::Webfinger::route(),
get(self::oauth2::webfinger::get),
)
.layer(
CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_otel_headers([
AUTHORIZATION,
ACCEPT,
ACCEPT_LANGUAGE,
CONTENT_LANGUAGE,
CONTENT_TYPE,
])
.max_age(Duration::from_secs(60 * 60)),
)
}
#[must_use]
#[allow(clippy::trait_duplication_in_bounds)]
pub fn api_router<S, B>(state: Arc<S>) -> Router<S, B>
@ -66,19 +117,6 @@ where
{
// All those routes are API-like, with a common CORS layer
Router::with_state_arc(state)
.route(mas_router::Healthcheck::route(), get(self::health::get))
.route(
mas_router::ChangePasswordDiscovery::route(),
get(|| async { mas_router::AccountPassword.go() }),
)
.route(
mas_router::OidcConfiguration::route(),
get(self::oauth2::discovery::get),
)
.route(
mas_router::Webfinger::route(),
get(self::oauth2::webfinger::get),
)
.route(
mas_router::OAuth2Keys::route(),
get(self::oauth2::keys::get),
@ -116,6 +154,7 @@ where
.max_age(Duration::from_secs(60 * 60)),
)
}
#[must_use]
#[allow(clippy::trait_duplication_in_bounds)]
pub fn compat_router<S, B>(state: Arc<S>) -> Router<S, B>
@ -174,6 +213,10 @@ where
{
let templates = Templates::from_ref(&state);
Router::with_state_arc(state)
.route(
mas_router::ChangePasswordDiscovery::route(),
get(|| async { mas_router::AccountPassword.go() }),
)
.route(mas_router::Index::route(), get(self::views::index::get))
.route(
mas_router::Login::route(),
@ -267,11 +310,18 @@ where
Mailer: FromRef<S>,
MatrixHomeserver: FromRef<S>,
{
let healthcheck_router = healthcheck_router(state.clone());
let discovery_router = discovery_router(state.clone());
let api_router = api_router(state.clone());
let compat_router = compat_router(state.clone());
let human_router = human_router(state);
let human_router = human_router(state.clone());
human_router.merge(api_router).merge(compat_router)
Router::with_state_arc(state)
.merge(healthcheck_router)
.merge(discovery_router)
.merge(human_router)
.merge(api_router)
.merge(compat_router)
}
#[cfg(test)]

View File

@ -26,6 +26,7 @@
use std::{ops::Deref, sync::Arc};
use der::{zeroize::Zeroizing, Decode};
use elliptic_curve::pkcs8::EncodePrivateKey;
use mas_iana::jose::{JsonWebKeyType, JsonWebSignatureAlg};
pub use mas_jose::jwk::{JsonWebKey, JsonWebKeySet};
use mas_jose::{
@ -213,6 +214,22 @@ impl PrivateKey {
Ok(der)
}
/// Serialize the key as a PKCS8 DER document
///
/// # Errors
///
/// Returns an error if the encoding failed
pub fn to_pkcs8_der(&self) -> Result<Zeroizing<Vec<u8>>, anyhow::Error> {
let der = match self {
PrivateKey::Rsa(key) => key.to_pkcs8_der()?,
PrivateKey::EcP256(key) => key.to_pkcs8_der()?,
PrivateKey::EcP384(key) => key.to_pkcs8_der()?,
PrivateKey::EcK256(key) => key.to_pkcs8_der()?,
};
Ok(der.to_bytes())
}
/// Serialize the key as a PEM document
///
/// It will use the most common format depending on the key type: PKCS1 for

View File

@ -0,0 +1,14 @@
[package]
name = "mas-listener"
version = "0.1.0"
authors = ["Quentin Gliech <quenting@element.io>"]
edition = "2021"
license = "Apache-2.0"
[dependencies]
tokio = { version = "1.21.2", features = ["net"] }
pin-project-lite = "0.2.9"
hyper = { version = "0.14.20", features = ["server"] }
futures-util = "0.3.24"
tracing = "0.1.36"
tokio-rustls = "0.23.4"

View File

@ -0,0 +1,16 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod maybe_tls;
pub mod unix_or_tcp;

View File

@ -0,0 +1,213 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
ops::Deref,
pin::Pin,
sync::Arc,
task::{ready, Context, Poll},
};
use futures_util::Future;
use hyper::server::accept::Accept;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_rustls::rustls::{ServerConfig, ServerConnection};
pub enum MaybeTlsStream<T> {
Handshaking(tokio_rustls::Accept<T>),
Streaming(tokio_rustls::server::TlsStream<T>),
Insecure(T),
}
impl<T> MaybeTlsStream<T> {
pub fn new(stream: T, config: Option<Arc<ServerConfig>>) -> Self
where
T: AsyncRead + AsyncWrite + Unpin,
{
if let Some(config) = config {
let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream);
MaybeTlsStream::Handshaking(accept)
} else {
MaybeTlsStream::Insecure(stream)
}
}
/// Get a reference to the underlying IO stream
///
/// Returns [`None`] if the stream closed before the TLS handshake finished.
/// It is guaranteed to return [`Some`] value after the handshake finished,
/// or if it is a non-TLS connection.
pub fn get_ref(&self) -> Option<&T> {
match self {
Self::Handshaking(accept) => accept.get_ref(),
Self::Streaming(stream) => {
let (inner, _) = stream.get_ref();
Some(inner)
}
Self::Insecure(inner) => Some(inner),
}
}
/// Get a ref to the [`ServerConnection`] of the establish TLS stream.
///
/// Returns [`None`] if the connection is still handshaking and for non-TLS
/// connections.
pub fn get_tls_connection(&self) -> Option<&ServerConnection> {
match self {
Self::Streaming(stream) => {
let (_, conn) = stream.get_ref();
Some(conn)
}
Self::Handshaking(_) | Self::Insecure(_) => None,
}
}
}
impl<T> AsyncRead for MaybeTlsStream<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut ReadBuf,
) -> Poll<std::io::Result<()>> {
let pin = self.get_mut();
match pin {
MaybeTlsStream::Handshaking(ref mut accept) => {
match ready!(Pin::new(accept).poll(cx)) {
Ok(mut stream) => {
let result = Pin::new(&mut stream).poll_read(cx, buf);
*pin = MaybeTlsStream::Streaming(stream);
result
}
Err(err) => Poll::Ready(Err(err)),
}
}
MaybeTlsStream::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
MaybeTlsStream::Insecure(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
}
}
}
impl<T> AsyncWrite for MaybeTlsStream<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
let pin = self.get_mut();
match pin {
MaybeTlsStream::Handshaking(ref mut accept) => {
match ready!(Pin::new(accept).poll(cx)) {
Ok(mut stream) => {
let result = Pin::new(&mut stream).poll_write(cx, buf);
*pin = MaybeTlsStream::Streaming(stream);
result
}
Err(err) => Poll::Ready(Err(err)),
}
}
MaybeTlsStream::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
MaybeTlsStream::Insecure(ref mut fallback) => Pin::new(fallback).poll_write(cx, buf),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
match self.get_mut() {
MaybeTlsStream::Handshaking { .. } => Poll::Ready(Ok(())),
MaybeTlsStream::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx),
MaybeTlsStream::Insecure(ref mut stream) => Pin::new(stream).poll_flush(cx),
}
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
match self.get_mut() {
MaybeTlsStream::Handshaking { .. } => Poll::Ready(Ok(())),
MaybeTlsStream::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
MaybeTlsStream::Insecure(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
}
}
}
pub struct MaybeTlsAcceptor<T> {
tls_config: Option<Arc<ServerConfig>>,
incoming: T,
}
impl<T> MaybeTlsAcceptor<T> {
pub fn new(tls_config: Option<Arc<ServerConfig>>, incoming: T) -> Self {
Self {
tls_config,
incoming,
}
}
pub fn new_secure(tls_config: Arc<ServerConfig>, incoming: T) -> Self {
Self {
tls_config: Some(tls_config),
incoming,
}
}
pub fn new_insecure(incoming: T) -> Self {
Self {
tls_config: None,
incoming,
}
}
pub const fn is_secure(&self) -> bool {
self.tls_config.is_some()
}
}
impl<T> Accept for MaybeTlsAcceptor<T>
where
T: Accept + Unpin,
T::Conn: AsyncRead + AsyncWrite + Unpin,
T::Error: Into<std::io::Error>,
{
type Conn = MaybeTlsStream<T::Conn>;
type Error = std::io::Error;
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
let pin = self.get_mut();
let ret = match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) {
Some(Ok(sock)) => {
let config = pin.tls_config.clone();
Some(Ok(MaybeTlsStream::new(sock, config)))
}
Some(Err(e)) => Some(Err(e.into())),
None => None,
};
Poll::Ready(ret)
}
}
impl<T> Deref for MaybeTlsAcceptor<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.incoming
}
}

View File

@ -0,0 +1,212 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// TODO: Unlink the UNIX socket on drop?
// TODO: Proxy protocol
use std::{
pin::Pin,
task::{Context, Poll},
};
use futures_util::ready;
use hyper::server::accept::Accept;
use tokio::{
io::{AsyncRead, AsyncWrite},
net::{TcpListener, TcpStream, UnixListener, UnixStream},
};
pub enum SocketAddr {
Unix(tokio::net::unix::SocketAddr),
Net(std::net::SocketAddr),
}
impl From<tokio::net::unix::SocketAddr> for SocketAddr {
fn from(value: tokio::net::unix::SocketAddr) -> Self {
Self::Unix(value)
}
}
impl From<std::net::SocketAddr> for SocketAddr {
fn from(value: std::net::SocketAddr) -> Self {
Self::Net(value)
}
}
impl std::fmt::Debug for SocketAddr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Unix(l) => std::fmt::Debug::fmt(l, f),
Self::Net(l) => std::fmt::Debug::fmt(l, f),
}
}
}
pub enum UnixOrTcpListener {
Unix(UnixListener),
Tcp(TcpListener),
}
impl From<UnixListener> for UnixOrTcpListener {
fn from(listener: UnixListener) -> Self {
Self::Unix(listener)
}
}
impl From<TcpListener> for UnixOrTcpListener {
fn from(listener: TcpListener) -> Self {
Self::Tcp(listener)
}
}
impl TryFrom<std::os::unix::net::UnixListener> for UnixOrTcpListener {
type Error = std::io::Error;
fn try_from(listener: std::os::unix::net::UnixListener) -> Result<Self, Self::Error> {
Ok(Self::Unix(UnixListener::from_std(listener)?))
}
}
impl TryFrom<std::net::TcpListener> for UnixOrTcpListener {
type Error = std::io::Error;
fn try_from(listener: std::net::TcpListener) -> Result<Self, Self::Error> {
Ok(Self::Tcp(TcpListener::from_std(listener)?))
}
}
impl UnixOrTcpListener {
pub fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
match self {
Self::Unix(listener) => listener.local_addr().map(SocketAddr::from),
Self::Tcp(listener) => listener.local_addr().map(SocketAddr::from),
}
}
}
pin_project_lite::pin_project! {
#[project = UnixOrTcpConnectionProj]
pub enum UnixOrTcpConnection {
Unix {
#[pin]
stream: UnixStream,
},
Tcp {
#[pin]
stream: TcpStream,
},
}
}
impl UnixOrTcpConnection {
pub fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
match self {
Self::Unix { stream, .. } => stream.local_addr().map(SocketAddr::from),
Self::Tcp { stream, .. } => stream.local_addr().map(SocketAddr::from),
}
}
pub fn peer_addr(&self) -> Result<SocketAddr, std::io::Error> {
match self {
Self::Unix { stream, .. } => stream.peer_addr().map(SocketAddr::from),
Self::Tcp { stream, .. } => stream.peer_addr().map(SocketAddr::from),
}
}
}
impl Accept for UnixOrTcpListener {
type Error = std::io::Error;
type Conn = UnixOrTcpConnection;
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> std::task::Poll<Option<Result<Self::Conn, Self::Error>>> {
let conn = match &*self {
Self::Unix(listener) => {
let (stream, _remote_addr) = ready!(listener.poll_accept(cx))?;
UnixOrTcpConnection::Unix { stream }
}
Self::Tcp(listener) => {
let (stream, _remote_addr) = ready!(listener.poll_accept(cx))?;
UnixOrTcpConnection::Tcp { stream }
}
};
Poll::Ready(Some(Ok(conn)))
}
}
impl AsyncRead for UnixOrTcpConnection {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
match self.project() {
UnixOrTcpConnectionProj::Unix { stream } => stream.poll_read(cx, buf),
UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_read(cx, buf),
}
}
}
impl AsyncWrite for UnixOrTcpConnection {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
match self.project() {
UnixOrTcpConnectionProj::Unix { stream } => stream.poll_write(cx, buf),
UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_write(cx, buf),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
match self.project() {
UnixOrTcpConnectionProj::Unix { stream } => stream.poll_flush(cx),
UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_flush(cx),
}
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
match self.project() {
UnixOrTcpConnectionProj::Unix { stream } => stream.poll_shutdown(cx),
UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_shutdown(cx),
}
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<Result<usize, std::io::Error>> {
match self.project() {
UnixOrTcpConnectionProj::Unix { stream } => stream.poll_write_vectored(cx, bufs),
UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_write_vectored(cx, bufs),
}
}
fn is_write_vectored(&self) -> bool {
match self {
UnixOrTcpConnection::Unix { stream } => stream.is_write_vectored(),
UnixOrTcpConnection::Tcp { stream } => stream.is_write_vectored(),
}
}
}