1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-11-21 23:00:50 +03:00

Refactor listeners building

This commit is contained in:
Quentin Gliech
2022-10-05 13:19:02 +02:00
parent 014a8366ed
commit c548417752
10 changed files with 245 additions and 157 deletions

View File

@@ -12,18 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
borrow::Cow,
io::Cursor,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, TcpListener, ToSocketAddrs},
ops::Deref,
os::unix::net::UnixListener,
path::PathBuf,
};
use std::{borrow::Cow, io::Cursor, ops::Deref, path::PathBuf};
use anyhow::{bail, Context};
use anyhow::bail;
use async_trait::async_trait;
use listenfd::ListenFd;
use mas_keystore::PrivateKey;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@@ -48,32 +40,49 @@ fn http_address_example_4() -> &'static str {
"0.0.0.0:8080"
}
#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)]
/// Kind of socket
#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone, Copy)]
#[serde(rename_all = "lowercase")]
pub enum UnixOrTcp {
/// UNIX domain socket
Unix,
/// TCP socket
Tcp,
}
impl UnixOrTcp {
/// UNIX domain socket
#[must_use]
pub const fn unix() -> Self {
Self::Unix
}
/// TCP socket
#[must_use]
pub const fn tcp() -> Self {
Self::Tcp
}
}
/// Configuration of a single listener
#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)]
#[serde(untagged)]
pub enum BindConfig {
/// Listen on the specified host and port
Listen {
/// Host on which to listen.
///
/// Defaults to listening on all addresses
host: Option<String>,
/// Port on which to listen.
port: u16,
},
/// Listen on the specified address
Address {
/// Host and port on which to listen
#[schemars(
example = "http_address_example_1",
example = "http_address_example_2",
@@ -83,85 +92,30 @@ pub enum BindConfig {
address: String,
},
/// Listen on a UNIX domain socket
Unix {
/// Path to the socket
socket: PathBuf,
},
/// Accept connections on file descriptors passed by the parent process.
///
/// This is useful for grabbing sockets passed by systemd.
///
/// See <https://www.freedesktop.org/software/systemd/man/sd_listen_fds.html>
FileDescriptor {
/// Index of the file descriptor. Note that this is offseted by 3
/// because of the standard input/output sockets, so setting
/// here a value of `0` will grab the file descriptor `3`
fd: usize,
/// Whether the socket is a TCP socket or a UNIX domain socket. Defaults
/// to TCP.
#[serde(default = "UnixOrTcp::tcp")]
kind: UnixOrTcp,
},
}
impl BindConfig {
// TODO: move this somewhere else
pub fn listener<T>(&self, fd_manager: &mut ListenFd) -> Result<T, anyhow::Error>
where
T: TryFrom<TcpListener> + TryFrom<UnixListener>,
<T as TryFrom<TcpListener>>::Error: std::error::Error + Sync + Send + 'static,
<T as TryFrom<UnixListener>>::Error: std::error::Error + Sync + Send + 'static,
{
match self {
BindConfig::Listen { host, port } => {
let addrs = match host.as_deref() {
Some(host) => (host, *port)
.to_socket_addrs()
.context("could not parse listener host")?
.collect(),
None => vec![
SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), *port),
SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), *port),
],
};
let listener = TcpListener::bind(&addrs[..]).context("could not bind address")?;
listener.set_nonblocking(true)?;
Ok(listener.try_into()?)
}
BindConfig::Address { address } => {
let addr: SocketAddr = address
.parse()
.context("could not parse listener address")?;
let listener = TcpListener::bind(addr).context("could not bind address")?;
listener.set_nonblocking(true)?;
Ok(listener.try_into()?)
}
BindConfig::Unix { socket } => {
let listener = UnixListener::bind(socket).context("could not bind socket")?;
listener.set_nonblocking(true)?;
Ok(listener.try_into()?)
}
BindConfig::FileDescriptor {
fd,
kind: UnixOrTcp::Tcp,
} => {
let listener = fd_manager
.take_tcp_listener(*fd)?
.context("no listener found on file descriptor")?;
listener.set_nonblocking(true)?;
Ok(listener.try_into()?)
}
BindConfig::FileDescriptor {
fd,
kind: UnixOrTcp::Unix,
} => {
let listener = fd_manager
.take_unix_listener(*fd)?
.context("no unix socket found on file descriptor")?;
listener.set_nonblocking(true)?;
Ok(listener.try_into()?)
}
}
}
}
#[derive(JsonSchema, Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum KeyOrFile {
@@ -176,19 +130,34 @@ pub enum CertificateOrFile {
CertificateFile(PathBuf),
}
/// Configuration related to TLS on a listener
#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)]
pub struct TlsConfig {
/// PEM-encoded X509 certificate chain
#[serde(flatten)]
pub certificate: CertificateOrFile,
/// Private key
#[serde(flatten)]
pub key: KeyOrFile,
/// Password used to decode the private key
#[serde(flatten)]
pub password: Option<PasswordOrFile>,
}
impl TlsConfig {
/// Load the TLS certificate chain and key file from disk
///
/// # Errors
///
/// Returns an error if an error was encountered either while:
/// - reading the certificate, key or password files
/// - decoding the key as PEM or DER
/// - decrypting the key if encrypted
/// - a password was provided but the key was not encrypted
/// - decoding the certificate chain as PEM
/// - the certificate chain is empty
pub async fn load(&self) -> Result<(Vec<u8>, Vec<Vec<u8>>), anyhow::Error> {
let password = match &self.password {
Some(PasswordOrFile::Password(password)) => Some(Cow::Borrowed(password.as_str())),
@@ -267,12 +236,21 @@ pub enum Resource {
Compat,
/// Static files
Static,
Static {
/// Path from which to serve static files. If not specified, it will
/// serve the static files embedded in the server binary
#[serde(default)]
web_root: Option<PathBuf>,
},
}
/// Configuration of a listener
#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)]
pub struct ListenerConfig {
/// A unique name for this listener which will be shown in traces and in
/// metrics labels
pub name: Option<String>,
/// List of resources to mount
pub resources: Vec<Resource>,
@@ -290,11 +268,6 @@ pub struct HttpConfig {
#[serde(default)]
pub listeners: Vec<ListenerConfig>,
/// Path from which to serve static files. If not specified, it will serve
/// the static files embedded in the server binary
#[serde(default)]
pub web_root: Option<PathBuf>,
/// Public URL base from where the authentication service is reachable
pub public_base: Url,
}
@@ -302,15 +275,15 @@ pub struct HttpConfig {
impl Default for HttpConfig {
fn default() -> Self {
Self {
web_root: None,
listeners: vec![
ListenerConfig {
name: Some("web".to_owned()),
resources: vec![
Resource::Discovery,
Resource::Human,
Resource::OAuth,
Resource::Compat,
Resource::Static,
Resource::Static { web_root: None },
],
tls: None,
binds: vec![BindConfig::Address {
@@ -318,6 +291,7 @@ impl Default for HttpConfig {
}],
},
ListenerConfig {
name: Some("internal".to_owned()),
resources: vec![Resource::Health],
tls: None,
binds: vec![BindConfig::Address {