You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-11-20 12:02:22 +03:00
WIP: better listeners
- listen on UNIX domain sockets - handle TLS stuff - allow mounting only some resources
This commit is contained in:
@@ -13,22 +13,23 @@
|
||||
// limitations under the License.
|
||||
|
||||
use std::{
|
||||
net::{SocketAddr, TcpListener},
|
||||
borrow::Cow,
|
||||
io::Cursor,
|
||||
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, TcpListener, ToSocketAddrs},
|
||||
ops::Deref,
|
||||
os::unix::net::UnixListener,
|
||||
path::PathBuf,
|
||||
};
|
||||
|
||||
use anyhow::Context;
|
||||
use anyhow::{bail, Context};
|
||||
use async_trait::async_trait;
|
||||
use listenfd::ListenFd;
|
||||
use mas_keystore::PrivateKey;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use url::Url;
|
||||
|
||||
use super::ConfigurationSection;
|
||||
|
||||
fn default_http_address() -> String {
|
||||
"[::]:8080".into()
|
||||
}
|
||||
use super::{secrets::PasswordOrFile, ConfigurationSection};
|
||||
|
||||
fn default_public_base() -> Url {
|
||||
"http://[::]:8080".parse().unwrap()
|
||||
@@ -47,9 +48,31 @@ fn http_address_example_4() -> &'static str {
|
||||
"0.0.0.0:8080"
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum UnixOrTcp {
|
||||
Unix,
|
||||
Tcp,
|
||||
}
|
||||
|
||||
impl UnixOrTcp {
|
||||
pub const fn unix() -> Self {
|
||||
Self::Unix
|
||||
}
|
||||
|
||||
pub const fn tcp() -> Self {
|
||||
Self::Tcp
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)]
|
||||
#[serde(untagged)]
|
||||
pub enum BindConfig {
|
||||
Listen {
|
||||
host: Option<String>,
|
||||
port: u16,
|
||||
},
|
||||
|
||||
Address {
|
||||
#[schemars(
|
||||
example = "http_address_example_1",
|
||||
@@ -59,39 +82,202 @@ pub enum BindConfig {
|
||||
)]
|
||||
address: String,
|
||||
},
|
||||
|
||||
Unix {
|
||||
socket: PathBuf,
|
||||
},
|
||||
|
||||
FileDescriptor {
|
||||
fd: usize,
|
||||
|
||||
#[serde(default = "UnixOrTcp::tcp")]
|
||||
kind: UnixOrTcp,
|
||||
},
|
||||
}
|
||||
|
||||
impl BindConfig {
|
||||
pub fn listener(self, fd_manager: &mut ListenFd) -> Result<TcpListener, anyhow::Error> {
|
||||
// TODO: move this somewhere else
|
||||
pub fn listener<T>(&self, fd_manager: &mut ListenFd) -> Result<T, anyhow::Error>
|
||||
where
|
||||
T: TryFrom<TcpListener> + TryFrom<UnixListener>,
|
||||
<T as TryFrom<TcpListener>>::Error: std::error::Error + Sync + Send + 'static,
|
||||
<T as TryFrom<UnixListener>>::Error: std::error::Error + Sync + Send + 'static,
|
||||
{
|
||||
match self {
|
||||
BindConfig::Listen { host, port } => {
|
||||
let addrs = match host.as_deref() {
|
||||
Some(host) => (host, *port)
|
||||
.to_socket_addrs()
|
||||
.context("could not parse listener host")?
|
||||
.collect(),
|
||||
|
||||
None => vec![
|
||||
SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), *port),
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), *port),
|
||||
],
|
||||
};
|
||||
|
||||
let listener = TcpListener::bind(&addrs[..]).context("could not bind address")?;
|
||||
listener.set_nonblocking(true)?;
|
||||
Ok(listener.try_into()?)
|
||||
}
|
||||
|
||||
BindConfig::Address { address } => {
|
||||
let addr: SocketAddr = address
|
||||
.parse()
|
||||
.context("could not parse listener address")?;
|
||||
let listener = TcpListener::bind(addr).context("could not bind address")?;
|
||||
Ok(listener)
|
||||
listener.set_nonblocking(true)?;
|
||||
Ok(listener.try_into()?)
|
||||
}
|
||||
|
||||
BindConfig::FileDescriptor { fd } => {
|
||||
let listener = fd_manager
|
||||
.take_tcp_listener(fd)?
|
||||
.context("no listener found on file descriptor")?;
|
||||
// XXX: Do I need that?
|
||||
BindConfig::Unix { socket } => {
|
||||
let listener = UnixListener::bind(socket).context("could not bind socket")?;
|
||||
listener.set_nonblocking(true)?;
|
||||
Ok(listener)
|
||||
Ok(listener.try_into()?)
|
||||
}
|
||||
|
||||
BindConfig::FileDescriptor {
|
||||
fd,
|
||||
kind: UnixOrTcp::Tcp,
|
||||
} => {
|
||||
let listener = fd_manager
|
||||
.take_tcp_listener(*fd)?
|
||||
.context("no listener found on file descriptor")?;
|
||||
listener.set_nonblocking(true)?;
|
||||
Ok(listener.try_into()?)
|
||||
}
|
||||
|
||||
BindConfig::FileDescriptor {
|
||||
fd,
|
||||
kind: UnixOrTcp::Unix,
|
||||
} => {
|
||||
let listener = fd_manager
|
||||
.take_unix_listener(*fd)?
|
||||
.context("no unix socket found on file descriptor")?;
|
||||
listener.set_nonblocking(true)?;
|
||||
Ok(listener.try_into()?)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(JsonSchema, Serialize, Deserialize, Clone, Debug)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum KeyOrFile {
|
||||
Key(String),
|
||||
KeyFile(PathBuf),
|
||||
}
|
||||
|
||||
#[derive(JsonSchema, Serialize, Deserialize, Clone, Debug)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum CertificateOrFile {
|
||||
Certificate(String),
|
||||
CertificateFile(PathBuf),
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)]
|
||||
pub struct TlsConfig {
|
||||
#[serde(flatten)]
|
||||
pub certificate: CertificateOrFile,
|
||||
|
||||
#[serde(flatten)]
|
||||
pub key: KeyOrFile,
|
||||
|
||||
#[serde(flatten)]
|
||||
pub password: Option<PasswordOrFile>,
|
||||
}
|
||||
|
||||
impl TlsConfig {
|
||||
pub async fn load(&self) -> Result<(Vec<u8>, Vec<Vec<u8>>), anyhow::Error> {
|
||||
let password = match &self.password {
|
||||
Some(PasswordOrFile::Password(password)) => Some(Cow::Borrowed(password.as_str())),
|
||||
Some(PasswordOrFile::PasswordFile(path)) => {
|
||||
Some(Cow::Owned(tokio::fs::read_to_string(path).await?))
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
// Read the key either embedded in the config file or on disk
|
||||
let key = match &self.key {
|
||||
KeyOrFile::Key(key) => {
|
||||
// If the key was embedded in the config file, assume it is formatted as PEM
|
||||
if let Some(password) = password {
|
||||
PrivateKey::load_encrypted_pem(key, password.as_bytes())?
|
||||
} else {
|
||||
PrivateKey::load_pem(key)?
|
||||
}
|
||||
}
|
||||
KeyOrFile::KeyFile(path) => {
|
||||
// When reading from disk, it might be either PEM or DER. `PrivateKey::load*`
|
||||
// will try both.
|
||||
let key = tokio::fs::read(path).await?;
|
||||
if let Some(password) = password {
|
||||
PrivateKey::load_encrypted(&key, password.as_bytes())?
|
||||
} else {
|
||||
PrivateKey::load(&key)?
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Re-serialize the key to PKCS#8 DER, so rustls can consume it
|
||||
let key = key.to_pkcs8_der()?;
|
||||
// This extracts the Vec out of the Zeroizing by copying it
|
||||
// XXX: maybe we should keep that zeroizing?
|
||||
let key = key.deref().clone();
|
||||
|
||||
let certificate_chain_pem = match &self.certificate {
|
||||
CertificateOrFile::Certificate(pem) => Cow::Borrowed(pem.as_str()),
|
||||
CertificateOrFile::CertificateFile(path) => {
|
||||
Cow::Owned(tokio::fs::read_to_string(path).await?)
|
||||
}
|
||||
};
|
||||
|
||||
let mut certificate_chain_reader = Cursor::new(certificate_chain_pem.as_bytes());
|
||||
let certificate_chain = rustls_pemfile::certs(&mut certificate_chain_reader)?;
|
||||
|
||||
if certificate_chain.is_empty() {
|
||||
bail!("TLS certificate chain is empty (or invalid)")
|
||||
}
|
||||
|
||||
Ok((key, certificate_chain))
|
||||
}
|
||||
}
|
||||
|
||||
/// HTTP resources to mount
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)]
|
||||
#[serde(tag = "name", rename_all = "lowercase")]
|
||||
pub enum Resource {
|
||||
/// Healthcheck endpoint (/health)
|
||||
Health,
|
||||
|
||||
/// OIDC discovery endpoints
|
||||
Discovery,
|
||||
|
||||
/// Pages destined to be viewed by humans
|
||||
Human,
|
||||
|
||||
/// OAuth-related APIs
|
||||
OAuth,
|
||||
|
||||
/// Matrix compatibility API
|
||||
Compat,
|
||||
|
||||
/// Static files
|
||||
Static,
|
||||
}
|
||||
|
||||
/// Configuration of a listener
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)]
|
||||
pub struct ListenerConfig {
|
||||
pub name: Option<String>,
|
||||
/// List of resources to mount
|
||||
pub resources: Vec<Resource>,
|
||||
|
||||
/// List of sockets to bind
|
||||
pub binds: Vec<BindConfig>,
|
||||
|
||||
/// If set, makes the listener use TLS with the provided certificate and key
|
||||
pub tls: Option<TlsConfig>,
|
||||
}
|
||||
|
||||
/// Configuration related to the web server
|
||||
@@ -114,12 +300,28 @@ impl Default for HttpConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
web_root: None,
|
||||
listeners: vec![ListenerConfig {
|
||||
name: None,
|
||||
binds: vec![BindConfig::Address {
|
||||
address: default_http_address(),
|
||||
}],
|
||||
}],
|
||||
listeners: vec![
|
||||
ListenerConfig {
|
||||
resources: vec![
|
||||
Resource::Discovery,
|
||||
Resource::Human,
|
||||
Resource::OAuth,
|
||||
Resource::Compat,
|
||||
Resource::Static,
|
||||
],
|
||||
tls: None,
|
||||
binds: vec![BindConfig::Address {
|
||||
address: "[::]:8080".into(),
|
||||
}],
|
||||
},
|
||||
ListenerConfig {
|
||||
resources: vec![Resource::Health],
|
||||
tls: None,
|
||||
binds: vec![BindConfig::Address {
|
||||
address: "localhost:8081".into(),
|
||||
}],
|
||||
},
|
||||
],
|
||||
public_base: default_public_base(),
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user