You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-11-20 12:02:22 +03:00
WIP: better listeners
- listen on UNIX domain sockets - handle TLS stuff - allow mounting only some resources
This commit is contained in:
16
crates/listener/src/lib.rs
Normal file
16
crates/listener/src/lib.rs
Normal file
@@ -0,0 +1,16 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
pub mod maybe_tls;
|
||||
pub mod unix_or_tcp;
|
||||
213
crates/listener/src/maybe_tls.rs
Normal file
213
crates/listener/src/maybe_tls.rs
Normal file
@@ -0,0 +1,213 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{
|
||||
ops::Deref,
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
task::{ready, Context, Poll},
|
||||
};
|
||||
|
||||
use futures_util::Future;
|
||||
use hyper::server::accept::Accept;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio_rustls::rustls::{ServerConfig, ServerConnection};
|
||||
|
||||
pub enum MaybeTlsStream<T> {
|
||||
Handshaking(tokio_rustls::Accept<T>),
|
||||
Streaming(tokio_rustls::server::TlsStream<T>),
|
||||
Insecure(T),
|
||||
}
|
||||
|
||||
impl<T> MaybeTlsStream<T> {
|
||||
pub fn new(stream: T, config: Option<Arc<ServerConfig>>) -> Self
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
if let Some(config) = config {
|
||||
let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream);
|
||||
MaybeTlsStream::Handshaking(accept)
|
||||
} else {
|
||||
MaybeTlsStream::Insecure(stream)
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a reference to the underlying IO stream
|
||||
///
|
||||
/// Returns [`None`] if the stream closed before the TLS handshake finished.
|
||||
/// It is guaranteed to return [`Some`] value after the handshake finished,
|
||||
/// or if it is a non-TLS connection.
|
||||
pub fn get_ref(&self) -> Option<&T> {
|
||||
match self {
|
||||
Self::Handshaking(accept) => accept.get_ref(),
|
||||
Self::Streaming(stream) => {
|
||||
let (inner, _) = stream.get_ref();
|
||||
Some(inner)
|
||||
}
|
||||
Self::Insecure(inner) => Some(inner),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a ref to the [`ServerConnection`] of the establish TLS stream.
|
||||
///
|
||||
/// Returns [`None`] if the connection is still handshaking and for non-TLS
|
||||
/// connections.
|
||||
pub fn get_tls_connection(&self) -> Option<&ServerConnection> {
|
||||
match self {
|
||||
Self::Streaming(stream) => {
|
||||
let (_, conn) = stream.get_ref();
|
||||
Some(conn)
|
||||
}
|
||||
Self::Handshaking(_) | Self::Insecure(_) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> AsyncRead for MaybeTlsStream<T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context,
|
||||
buf: &mut ReadBuf,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
let pin = self.get_mut();
|
||||
match pin {
|
||||
MaybeTlsStream::Handshaking(ref mut accept) => {
|
||||
match ready!(Pin::new(accept).poll(cx)) {
|
||||
Ok(mut stream) => {
|
||||
let result = Pin::new(&mut stream).poll_read(cx, buf);
|
||||
*pin = MaybeTlsStream::Streaming(stream);
|
||||
result
|
||||
}
|
||||
Err(err) => Poll::Ready(Err(err)),
|
||||
}
|
||||
}
|
||||
MaybeTlsStream::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
|
||||
MaybeTlsStream::Insecure(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> AsyncWrite for MaybeTlsStream<T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<std::io::Result<usize>> {
|
||||
let pin = self.get_mut();
|
||||
match pin {
|
||||
MaybeTlsStream::Handshaking(ref mut accept) => {
|
||||
match ready!(Pin::new(accept).poll(cx)) {
|
||||
Ok(mut stream) => {
|
||||
let result = Pin::new(&mut stream).poll_write(cx, buf);
|
||||
*pin = MaybeTlsStream::Streaming(stream);
|
||||
result
|
||||
}
|
||||
Err(err) => Poll::Ready(Err(err)),
|
||||
}
|
||||
}
|
||||
MaybeTlsStream::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
|
||||
MaybeTlsStream::Insecure(ref mut fallback) => Pin::new(fallback).poll_write(cx, buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
match self.get_mut() {
|
||||
MaybeTlsStream::Handshaking { .. } => Poll::Ready(Ok(())),
|
||||
MaybeTlsStream::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx),
|
||||
MaybeTlsStream::Insecure(ref mut stream) => Pin::new(stream).poll_flush(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
match self.get_mut() {
|
||||
MaybeTlsStream::Handshaking { .. } => Poll::Ready(Ok(())),
|
||||
MaybeTlsStream::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
|
||||
MaybeTlsStream::Insecure(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MaybeTlsAcceptor<T> {
|
||||
tls_config: Option<Arc<ServerConfig>>,
|
||||
incoming: T,
|
||||
}
|
||||
|
||||
impl<T> MaybeTlsAcceptor<T> {
|
||||
pub fn new(tls_config: Option<Arc<ServerConfig>>, incoming: T) -> Self {
|
||||
Self {
|
||||
tls_config,
|
||||
incoming,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_secure(tls_config: Arc<ServerConfig>, incoming: T) -> Self {
|
||||
Self {
|
||||
tls_config: Some(tls_config),
|
||||
incoming,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_insecure(incoming: T) -> Self {
|
||||
Self {
|
||||
tls_config: None,
|
||||
incoming,
|
||||
}
|
||||
}
|
||||
|
||||
pub const fn is_secure(&self) -> bool {
|
||||
self.tls_config.is_some()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Accept for MaybeTlsAcceptor<T>
|
||||
where
|
||||
T: Accept + Unpin,
|
||||
T::Conn: AsyncRead + AsyncWrite + Unpin,
|
||||
T::Error: Into<std::io::Error>,
|
||||
{
|
||||
type Conn = MaybeTlsStream<T::Conn>;
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn poll_accept(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
|
||||
let pin = self.get_mut();
|
||||
|
||||
let ret = match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) {
|
||||
Some(Ok(sock)) => {
|
||||
let config = pin.tls_config.clone();
|
||||
Some(Ok(MaybeTlsStream::new(sock, config)))
|
||||
}
|
||||
|
||||
Some(Err(e)) => Some(Err(e.into())),
|
||||
None => None,
|
||||
};
|
||||
|
||||
Poll::Ready(ret)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Deref for MaybeTlsAcceptor<T> {
|
||||
type Target = T;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.incoming
|
||||
}
|
||||
}
|
||||
212
crates/listener/src/unix_or_tcp.rs
Normal file
212
crates/listener/src/unix_or_tcp.rs
Normal file
@@ -0,0 +1,212 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// TODO: Unlink the UNIX socket on drop?
|
||||
// TODO: Proxy protocol
|
||||
|
||||
use std::{
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use futures_util::ready;
|
||||
use hyper::server::accept::Accept;
|
||||
use tokio::{
|
||||
io::{AsyncRead, AsyncWrite},
|
||||
net::{TcpListener, TcpStream, UnixListener, UnixStream},
|
||||
};
|
||||
|
||||
pub enum SocketAddr {
|
||||
Unix(tokio::net::unix::SocketAddr),
|
||||
Net(std::net::SocketAddr),
|
||||
}
|
||||
|
||||
impl From<tokio::net::unix::SocketAddr> for SocketAddr {
|
||||
fn from(value: tokio::net::unix::SocketAddr) -> Self {
|
||||
Self::Unix(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::net::SocketAddr> for SocketAddr {
|
||||
fn from(value: std::net::SocketAddr) -> Self {
|
||||
Self::Net(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for SocketAddr {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Unix(l) => std::fmt::Debug::fmt(l, f),
|
||||
Self::Net(l) => std::fmt::Debug::fmt(l, f),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum UnixOrTcpListener {
|
||||
Unix(UnixListener),
|
||||
Tcp(TcpListener),
|
||||
}
|
||||
|
||||
impl From<UnixListener> for UnixOrTcpListener {
|
||||
fn from(listener: UnixListener) -> Self {
|
||||
Self::Unix(listener)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<TcpListener> for UnixOrTcpListener {
|
||||
fn from(listener: TcpListener) -> Self {
|
||||
Self::Tcp(listener)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<std::os::unix::net::UnixListener> for UnixOrTcpListener {
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn try_from(listener: std::os::unix::net::UnixListener) -> Result<Self, Self::Error> {
|
||||
Ok(Self::Unix(UnixListener::from_std(listener)?))
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<std::net::TcpListener> for UnixOrTcpListener {
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn try_from(listener: std::net::TcpListener) -> Result<Self, Self::Error> {
|
||||
Ok(Self::Tcp(TcpListener::from_std(listener)?))
|
||||
}
|
||||
}
|
||||
|
||||
impl UnixOrTcpListener {
|
||||
pub fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
|
||||
match self {
|
||||
Self::Unix(listener) => listener.local_addr().map(SocketAddr::from),
|
||||
Self::Tcp(listener) => listener.local_addr().map(SocketAddr::from),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
#[project = UnixOrTcpConnectionProj]
|
||||
pub enum UnixOrTcpConnection {
|
||||
Unix {
|
||||
#[pin]
|
||||
stream: UnixStream,
|
||||
},
|
||||
|
||||
Tcp {
|
||||
#[pin]
|
||||
stream: TcpStream,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
impl UnixOrTcpConnection {
|
||||
pub fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
|
||||
match self {
|
||||
Self::Unix { stream, .. } => stream.local_addr().map(SocketAddr::from),
|
||||
Self::Tcp { stream, .. } => stream.local_addr().map(SocketAddr::from),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn peer_addr(&self) -> Result<SocketAddr, std::io::Error> {
|
||||
match self {
|
||||
Self::Unix { stream, .. } => stream.peer_addr().map(SocketAddr::from),
|
||||
Self::Tcp { stream, .. } => stream.peer_addr().map(SocketAddr::from),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Accept for UnixOrTcpListener {
|
||||
type Error = std::io::Error;
|
||||
type Conn = UnixOrTcpConnection;
|
||||
|
||||
fn poll_accept(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> std::task::Poll<Option<Result<Self::Conn, Self::Error>>> {
|
||||
let conn = match &*self {
|
||||
Self::Unix(listener) => {
|
||||
let (stream, _remote_addr) = ready!(listener.poll_accept(cx))?;
|
||||
UnixOrTcpConnection::Unix { stream }
|
||||
}
|
||||
|
||||
Self::Tcp(listener) => {
|
||||
let (stream, _remote_addr) = ready!(listener.poll_accept(cx))?;
|
||||
UnixOrTcpConnection::Tcp { stream }
|
||||
}
|
||||
};
|
||||
|
||||
Poll::Ready(Some(Ok(conn)))
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for UnixOrTcpConnection {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
match self.project() {
|
||||
UnixOrTcpConnectionProj::Unix { stream } => stream.poll_read(cx, buf),
|
||||
UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_read(cx, buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for UnixOrTcpConnection {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, std::io::Error>> {
|
||||
match self.project() {
|
||||
UnixOrTcpConnectionProj::Unix { stream } => stream.poll_write(cx, buf),
|
||||
UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_write(cx, buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
|
||||
match self.project() {
|
||||
UnixOrTcpConnectionProj::Unix { stream } => stream.poll_flush(cx),
|
||||
UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_flush(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
match self.project() {
|
||||
UnixOrTcpConnectionProj::Unix { stream } => stream.poll_shutdown(cx),
|
||||
UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_shutdown(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_write_vectored(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
bufs: &[std::io::IoSlice<'_>],
|
||||
) -> Poll<Result<usize, std::io::Error>> {
|
||||
match self.project() {
|
||||
UnixOrTcpConnectionProj::Unix { stream } => stream.poll_write_vectored(cx, bufs),
|
||||
UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_write_vectored(cx, bufs),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_write_vectored(&self) -> bool {
|
||||
match self {
|
||||
UnixOrTcpConnection::Unix { stream } => stream.is_write_vectored(),
|
||||
UnixOrTcpConnection::Tcp { stream } => stream.is_write_vectored(),
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user