You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-11-20 12:02:22 +03:00
Split the service in multiple crates
This commit is contained in:
49
crates/config/src/cookies.rs
Normal file
49
crates/config/src/cookies.rs
Normal file
@@ -0,0 +1,49 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use schemars::{gen::SchemaGenerator, schema::Schema, JsonSchema};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::serde_as;
|
||||
|
||||
use super::ConfigurationSection;
|
||||
|
||||
fn secret_schema(gen: &mut SchemaGenerator) -> Schema {
|
||||
String::json_schema(gen)
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct CookiesConfig {
|
||||
#[schemars(schema_with = "secret_schema")]
|
||||
#[serde_as(as = "serde_with::hex::Hex")]
|
||||
pub secret: [u8; 32],
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ConfigurationSection<'_> for CookiesConfig {
|
||||
fn path() -> &'static str {
|
||||
"cookies"
|
||||
}
|
||||
|
||||
async fn generate() -> anyhow::Result<Self> {
|
||||
Ok(Self {
|
||||
secret: rand::random(),
|
||||
})
|
||||
}
|
||||
|
||||
fn test() -> Self {
|
||||
Self { secret: [0xEA; 32] }
|
||||
}
|
||||
}
|
||||
85
crates/config/src/csrf.rs
Normal file
85
crates/config/src/csrf.rs
Normal file
@@ -0,0 +1,85 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::Duration;
|
||||
use schemars::{gen::SchemaGenerator, schema::Schema, JsonSchema};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::serde_as;
|
||||
|
||||
use super::ConfigurationSection;
|
||||
|
||||
fn default_ttl() -> Duration {
|
||||
Duration::hours(1)
|
||||
}
|
||||
|
||||
fn ttl_schema(gen: &mut SchemaGenerator) -> Schema {
|
||||
u64::json_schema(gen)
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct CsrfConfig {
|
||||
#[schemars(schema_with = "ttl_schema")]
|
||||
#[serde(default = "default_ttl")]
|
||||
#[serde_as(as = "serde_with::DurationSeconds<i64>")]
|
||||
pub ttl: Duration,
|
||||
}
|
||||
|
||||
impl Default for CsrfConfig {
|
||||
fn default() -> Self {
|
||||
Self { ttl: default_ttl() }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ConfigurationSection<'_> for CsrfConfig {
|
||||
fn path() -> &'static str {
|
||||
"csrf"
|
||||
}
|
||||
|
||||
async fn generate() -> anyhow::Result<Self> {
|
||||
Ok(Self::default())
|
||||
}
|
||||
|
||||
fn test() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use figment::Jail;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn load_config() {
|
||||
Jail::expect_with(|jail| {
|
||||
jail.create_file(
|
||||
"config.yaml",
|
||||
r#"
|
||||
csrf:
|
||||
ttl: 1800
|
||||
"#,
|
||||
)?;
|
||||
|
||||
let config = CsrfConfig::load_from_file("config.yaml")?;
|
||||
|
||||
assert_eq!(config.ttl, Duration::minutes(30));
|
||||
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
}
|
||||
166
crates/config/src/database.rs
Normal file
166
crates/config/src/database.rs
Normal file
@@ -0,0 +1,166 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context;
|
||||
use async_trait::async_trait;
|
||||
use schemars::{gen::SchemaGenerator, schema::Schema, JsonSchema};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{serde_as, skip_serializing_none};
|
||||
use sqlx::postgres::{PgConnectOptions, PgPool, PgPoolOptions};
|
||||
|
||||
// FIXME
|
||||
// use sqlx::ConnectOptions
|
||||
// use tracing::log::LevelFilter;
|
||||
use super::ConfigurationSection;
|
||||
|
||||
fn default_uri() -> String {
|
||||
"postgresql://".to_string()
|
||||
}
|
||||
|
||||
fn default_max_connections() -> u32 {
|
||||
10
|
||||
}
|
||||
|
||||
fn default_connect_timeout() -> Duration {
|
||||
Duration::from_secs(30)
|
||||
}
|
||||
|
||||
#[allow(clippy::unnecessary_wraps)]
|
||||
fn default_idle_timeout() -> Option<Duration> {
|
||||
Some(Duration::from_secs(10 * 60))
|
||||
}
|
||||
|
||||
#[allow(clippy::unnecessary_wraps)]
|
||||
fn default_max_lifetime() -> Option<Duration> {
|
||||
Some(Duration::from_secs(30 * 60))
|
||||
}
|
||||
|
||||
impl Default for DatabaseConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
uri: default_uri(),
|
||||
max_connections: default_max_connections(),
|
||||
min_connections: Default::default(),
|
||||
connect_timeout: default_connect_timeout(),
|
||||
idle_timeout: default_idle_timeout(),
|
||||
max_lifetime: default_max_lifetime(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn duration_schema(gen: &mut SchemaGenerator) -> Schema {
|
||||
Option::<u64>::json_schema(gen)
|
||||
}
|
||||
|
||||
fn optional_duration_schema(gen: &mut SchemaGenerator) -> Schema {
|
||||
u64::json_schema(gen)
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct DatabaseConfig {
|
||||
#[serde(default = "default_uri")]
|
||||
uri: String,
|
||||
|
||||
#[serde(default = "default_max_connections")]
|
||||
max_connections: u32,
|
||||
|
||||
#[serde(default)]
|
||||
min_connections: u32,
|
||||
|
||||
#[schemars(schema_with = "duration_schema")]
|
||||
#[serde(default = "default_connect_timeout")]
|
||||
#[serde_as(as = "serde_with::DurationSeconds<u64>")]
|
||||
connect_timeout: Duration,
|
||||
|
||||
#[schemars(schema_with = "optional_duration_schema")]
|
||||
#[serde(default = "default_idle_timeout")]
|
||||
#[serde_as(as = "Option<serde_with::DurationSeconds<u64>>")]
|
||||
idle_timeout: Option<Duration>,
|
||||
|
||||
#[schemars(schema_with = "optional_duration_schema")]
|
||||
#[serde(default = "default_max_lifetime")]
|
||||
#[serde_as(as = "Option<serde_with::DurationSeconds<u64>>")]
|
||||
max_lifetime: Option<Duration>,
|
||||
}
|
||||
|
||||
impl DatabaseConfig {
|
||||
#[tracing::instrument(err)]
|
||||
pub async fn connect(&self) -> anyhow::Result<PgPool> {
|
||||
let options = self
|
||||
.uri
|
||||
.parse::<PgConnectOptions>()
|
||||
.context("invalid database URL")?
|
||||
.application_name("matrix-authentication-service");
|
||||
|
||||
// FIXME
|
||||
// options
|
||||
// .log_statements(LevelFilter::Debug)
|
||||
// .log_slow_statements(LevelFilter::Warn, Duration::from_millis(100));
|
||||
|
||||
PgPoolOptions::new()
|
||||
.max_connections(self.max_connections)
|
||||
.min_connections(self.min_connections)
|
||||
.connect_timeout(self.connect_timeout)
|
||||
.idle_timeout(self.idle_timeout)
|
||||
.max_lifetime(self.max_lifetime)
|
||||
.connect_with(options)
|
||||
.await
|
||||
.context("could not connect to the database")
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ConfigurationSection<'_> for DatabaseConfig {
|
||||
fn path() -> &'static str {
|
||||
"database"
|
||||
}
|
||||
|
||||
async fn generate() -> anyhow::Result<Self> {
|
||||
Ok(Self::default())
|
||||
}
|
||||
|
||||
fn test() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use figment::Jail;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn load_config() {
|
||||
Jail::expect_with(|jail| {
|
||||
jail.create_file(
|
||||
"config.yaml",
|
||||
r#"
|
||||
database:
|
||||
uri: postgresql://user:password@host/database
|
||||
"#,
|
||||
)?;
|
||||
|
||||
let config = DatabaseConfig::load_from_file("config.yaml")?;
|
||||
|
||||
assert_eq!(config.uri, "postgresql://user:password@host/database");
|
||||
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
}
|
||||
52
crates/config/src/http.rs
Normal file
52
crates/config/src/http.rs
Normal file
@@ -0,0 +1,52 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::ConfigurationSection;
|
||||
|
||||
fn default_http_address() -> String {
|
||||
"[::]:8080".into()
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct HttpConfig {
|
||||
#[serde(default = "default_http_address")]
|
||||
pub address: String,
|
||||
}
|
||||
|
||||
impl Default for HttpConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
address: default_http_address(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ConfigurationSection<'_> for HttpConfig {
|
||||
fn path() -> &'static str {
|
||||
"http"
|
||||
}
|
||||
|
||||
async fn generate() -> anyhow::Result<Self> {
|
||||
Ok(Self::default())
|
||||
}
|
||||
|
||||
fn test() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
}
|
||||
76
crates/config/src/lib.rs
Normal file
76
crates/config/src/lib.rs
Normal file
@@ -0,0 +1,76 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
mod cookies;
|
||||
mod csrf;
|
||||
mod database;
|
||||
mod http;
|
||||
mod oauth2;
|
||||
mod util;
|
||||
|
||||
pub use self::{
|
||||
cookies::CookiesConfig,
|
||||
csrf::CsrfConfig,
|
||||
database::DatabaseConfig,
|
||||
http::HttpConfig,
|
||||
oauth2::{Algorithm, KeySet, OAuth2ClientConfig, OAuth2Config},
|
||||
util::ConfigurationSection,
|
||||
};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct RootConfig {
|
||||
pub oauth2: OAuth2Config,
|
||||
|
||||
#[serde(default)]
|
||||
pub http: HttpConfig,
|
||||
|
||||
#[serde(default)]
|
||||
pub database: DatabaseConfig,
|
||||
|
||||
pub cookies: CookiesConfig,
|
||||
|
||||
#[serde(default)]
|
||||
pub csrf: CsrfConfig,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ConfigurationSection<'_> for RootConfig {
|
||||
fn path() -> &'static str {
|
||||
""
|
||||
}
|
||||
|
||||
async fn generate() -> anyhow::Result<Self> {
|
||||
Ok(Self {
|
||||
oauth2: OAuth2Config::generate().await?,
|
||||
http: HttpConfig::generate().await?,
|
||||
database: DatabaseConfig::generate().await?,
|
||||
cookies: CookiesConfig::generate().await?,
|
||||
csrf: CsrfConfig::generate().await?,
|
||||
})
|
||||
}
|
||||
|
||||
fn test() -> Self {
|
||||
Self {
|
||||
oauth2: OAuth2Config::test(),
|
||||
http: HttpConfig::test(),
|
||||
database: DatabaseConfig::test(),
|
||||
cookies: CookiesConfig::test(),
|
||||
csrf: CsrfConfig::test(),
|
||||
}
|
||||
}
|
||||
}
|
||||
476
crates/config/src/oauth2.rs
Normal file
476
crates/config/src/oauth2.rs
Normal file
@@ -0,0 +1,476 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::convert::TryFrom;
|
||||
|
||||
use anyhow::Context;
|
||||
use async_trait::async_trait;
|
||||
use jwt_compact::{
|
||||
alg::{self, StrongAlg, StrongKey},
|
||||
jwk::JsonWebKey,
|
||||
AlgorithmExt, Claims, Header,
|
||||
};
|
||||
use pkcs8::{FromPrivateKey, ToPrivateKey};
|
||||
use rsa::RsaPrivateKey;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{
|
||||
de::{MapAccess, Visitor},
|
||||
ser::SerializeStruct,
|
||||
Deserialize, Serialize,
|
||||
};
|
||||
use serde_with::skip_serializing_none;
|
||||
use thiserror::Error;
|
||||
use tokio::task;
|
||||
use tracing::info;
|
||||
use url::Url;
|
||||
|
||||
use super::ConfigurationSection;
|
||||
|
||||
// TODO: a lot of the signing logic should go out somewhere else
|
||||
|
||||
const RS256: StrongAlg<alg::Rsa> = StrongAlg(alg::Rsa::rs256());
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
|
||||
#[serde(rename_all = "UPPERCASE")]
|
||||
pub enum Algorithm {
|
||||
Rs256,
|
||||
Es256k,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Clone)]
|
||||
pub struct Jwk {
|
||||
kid: String,
|
||||
alg: Algorithm,
|
||||
|
||||
#[serde(flatten)]
|
||||
inner: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Clone)]
|
||||
pub struct Jwks {
|
||||
keys: Vec<Jwk>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[serde(transparent)]
|
||||
pub struct KeySet(Vec<Key>);
|
||||
|
||||
impl KeySet {
|
||||
pub fn to_public_jwks(&self) -> Jwks {
|
||||
let keys = self.0.iter().map(Key::to_public_jwk).collect();
|
||||
Jwks { keys }
|
||||
}
|
||||
|
||||
#[tracing::instrument(err)]
|
||||
pub async fn token<T>(
|
||||
&self,
|
||||
alg: Algorithm,
|
||||
header: Header,
|
||||
claims: Claims<T>,
|
||||
) -> anyhow::Result<String>
|
||||
where
|
||||
T: std::fmt::Debug + Serialize + Send + Sync + 'static,
|
||||
{
|
||||
match alg {
|
||||
Algorithm::Rs256 => {
|
||||
let (kid, key) = self
|
||||
.0
|
||||
.iter()
|
||||
.find_map(Key::rsa)
|
||||
.context("could not find RSA key")?;
|
||||
let header = header.with_key_id(kid);
|
||||
|
||||
// TODO: store them as strong keys
|
||||
let key = StrongKey::try_from(key.clone())?;
|
||||
task::spawn_blocking(move || {
|
||||
RS256
|
||||
.token(header, &claims, &key)
|
||||
.context("failed to sign token")
|
||||
})
|
||||
.await?
|
||||
}
|
||||
Algorithm::Es256k => {
|
||||
// TODO: make this const with lazy_static?
|
||||
let es256k: alg::Es256k = alg::Es256k::default();
|
||||
let (kid, key) = self
|
||||
.0
|
||||
.iter()
|
||||
.find_map(Key::ecdsa)
|
||||
.context("could not find ECDSA key")?;
|
||||
let key = k256::ecdsa::SigningKey::from(key);
|
||||
let header = header.with_key_id(kid);
|
||||
// TODO: use StrongAlg
|
||||
|
||||
task::spawn_blocking(move || {
|
||||
es256k
|
||||
.token(header, &claims, &key)
|
||||
.context("failed to sign token")
|
||||
})
|
||||
.await?
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[non_exhaustive]
|
||||
pub enum Key {
|
||||
Rsa { key: RsaPrivateKey, kid: String },
|
||||
Ecdsa { key: k256::SecretKey, kid: String },
|
||||
}
|
||||
|
||||
impl Key {
|
||||
fn from_ecdsa(key: k256::SecretKey) -> Self {
|
||||
// TODO: hash the key and use as KID
|
||||
let kid = String::from("ecdsa-kid");
|
||||
Self::Ecdsa { kid, key }
|
||||
}
|
||||
|
||||
fn from_ecdsa_pem(key: &str) -> anyhow::Result<Self> {
|
||||
let key = k256::SecretKey::from_pkcs8_pem(key)?;
|
||||
Ok(Self::from_ecdsa(key))
|
||||
}
|
||||
|
||||
fn from_rsa(key: RsaPrivateKey) -> Self {
|
||||
// TODO: hash the key and use as KID
|
||||
let kid = String::from("rsa-kid");
|
||||
Self::Rsa { kid, key }
|
||||
}
|
||||
|
||||
fn from_rsa_pem(key: &str) -> anyhow::Result<Self> {
|
||||
let key = RsaPrivateKey::from_pkcs8_pem(key)?;
|
||||
Ok(Self::from_rsa(key))
|
||||
}
|
||||
|
||||
fn to_public_jwk(&self) -> Jwk {
|
||||
match self {
|
||||
Key::Rsa { key, kid } => {
|
||||
let pubkey = key.to_public_key();
|
||||
let inner = JsonWebKey::from(&pubkey);
|
||||
let inner = serde_json::to_value(&inner).unwrap();
|
||||
let kid = kid.to_string();
|
||||
let alg = Algorithm::Rs256;
|
||||
Jwk { kid, alg, inner }
|
||||
}
|
||||
Key::Ecdsa { key, kid } => {
|
||||
let pubkey = k256::ecdsa::VerifyingKey::from(key.public_key());
|
||||
let inner = JsonWebKey::from(&pubkey);
|
||||
let inner = serde_json::to_value(&inner).unwrap();
|
||||
let kid = kid.to_string();
|
||||
let alg = Algorithm::Es256k;
|
||||
Jwk { kid, alg, inner }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn rsa(&self) -> Option<(&str, &RsaPrivateKey)> {
|
||||
match self {
|
||||
Key::Rsa { key, kid } => Some((kid, key)),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn ecdsa(&self) -> Option<(&str, &k256::SecretKey)> {
|
||||
match self {
|
||||
Key::Ecdsa { key, kid } => Some((kid, key)),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for Key {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
let mut map = serializer.serialize_struct("Key", 2)?;
|
||||
match self {
|
||||
Key::Rsa { key, kid: _ } => {
|
||||
map.serialize_field("type", "rsa")?;
|
||||
let pem = key.to_pkcs8_pem().map_err(serde::ser::Error::custom)?;
|
||||
map.serialize_field("key", pem.as_str())?;
|
||||
}
|
||||
Key::Ecdsa { key, kid: _ } => {
|
||||
map.serialize_field("type", "ecdsa")?;
|
||||
let pem = key.to_pkcs8_pem().map_err(serde::ser::Error::custom)?;
|
||||
map.serialize_field("key", pem.as_str())?;
|
||||
}
|
||||
}
|
||||
|
||||
map.end()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for Key {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
#[derive(Deserialize, Debug)]
|
||||
#[serde(field_identifier, rename_all = "lowercase")]
|
||||
enum Field {
|
||||
Type,
|
||||
Key,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
enum KeyType {
|
||||
Rsa,
|
||||
Ecdsa,
|
||||
}
|
||||
|
||||
struct KeyVisitor;
|
||||
|
||||
impl<'de> Visitor<'de> for KeyVisitor {
|
||||
type Value = Key;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
formatter.write_str("struct Key")
|
||||
}
|
||||
|
||||
fn visit_map<V>(self, mut map: V) -> Result<Key, V::Error>
|
||||
where
|
||||
V: MapAccess<'de>,
|
||||
{
|
||||
let mut key_type = None;
|
||||
let mut key_key = None;
|
||||
while let Some(key) = map.next_key()? {
|
||||
match key {
|
||||
Field::Type => {
|
||||
if key_type.is_some() {
|
||||
return Err(serde::de::Error::duplicate_field("type"));
|
||||
}
|
||||
key_type = Some(map.next_value()?);
|
||||
}
|
||||
Field::Key => {
|
||||
if key_key.is_some() {
|
||||
return Err(serde::de::Error::duplicate_field("key"));
|
||||
}
|
||||
key_key = Some(map.next_value()?);
|
||||
}
|
||||
}
|
||||
}
|
||||
let key_type: KeyType =
|
||||
key_type.ok_or_else(|| serde::de::Error::missing_field("type"))?;
|
||||
let key_key: String =
|
||||
key_key.ok_or_else(|| serde::de::Error::missing_field("key"))?;
|
||||
|
||||
match key_type {
|
||||
KeyType::Rsa => Key::from_rsa_pem(&key_key).map_err(serde::de::Error::custom),
|
||||
KeyType::Ecdsa => {
|
||||
Key::from_ecdsa_pem(&key_key).map_err(serde::de::Error::custom)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_struct("Key", &["type", "key"], KeyVisitor)
|
||||
}
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct OAuth2ClientConfig {
|
||||
pub client_id: String,
|
||||
|
||||
#[serde(default)]
|
||||
pub client_secret: Option<String>,
|
||||
|
||||
#[serde(default)]
|
||||
pub redirect_uris: Vec<Url>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("Invalid redirect URI")]
|
||||
pub struct InvalidRedirectUriError;
|
||||
|
||||
impl OAuth2ClientConfig {
|
||||
pub fn resolve_redirect_uri<'a>(
|
||||
&'a self,
|
||||
suggested_uri: &'a Option<Url>,
|
||||
) -> Result<&'a Url, InvalidRedirectUriError> {
|
||||
suggested_uri.as_ref().map_or_else(
|
||||
|| self.redirect_uris.get(0).ok_or(InvalidRedirectUriError),
|
||||
|suggested_uri| self.check_redirect_uri(suggested_uri),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn check_redirect_uri<'a>(
|
||||
&self,
|
||||
redirect_uri: &'a Url,
|
||||
) -> Result<&'a Url, InvalidRedirectUriError> {
|
||||
if self.redirect_uris.contains(redirect_uri) {
|
||||
Ok(redirect_uri)
|
||||
} else {
|
||||
Err(InvalidRedirectUriError)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn default_oauth2_issuer() -> Url {
|
||||
"http://[::]:8080".parse().unwrap()
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct OAuth2Config {
|
||||
#[serde(default = "default_oauth2_issuer")]
|
||||
pub issuer: Url,
|
||||
|
||||
#[serde(default)]
|
||||
pub clients: Vec<OAuth2ClientConfig>,
|
||||
|
||||
#[schemars(with = "Vec<String>")] // TODO: this is a lie
|
||||
pub keys: KeySet,
|
||||
}
|
||||
|
||||
impl OAuth2Config {
|
||||
pub fn discovery_url(&self) -> Url {
|
||||
self.issuer
|
||||
.join(".well-known/openid-configuration")
|
||||
.expect("could not build discovery url")
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ConfigurationSection<'_> for OAuth2Config {
|
||||
fn path() -> &'static str {
|
||||
"oauth2"
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
async fn generate() -> anyhow::Result<Self> {
|
||||
info!("Generating keys...");
|
||||
|
||||
let span = tracing::info_span!("rsa");
|
||||
let rsa_key = task::spawn_blocking(move || {
|
||||
let _entered = span.enter();
|
||||
let mut rng = rand::thread_rng();
|
||||
let ret =
|
||||
RsaPrivateKey::new(&mut rng, 2048).context("could not generate RSA private key");
|
||||
info!("Done generating RSA key");
|
||||
ret
|
||||
})
|
||||
.await
|
||||
.context("could not join blocking task")??;
|
||||
|
||||
let span = tracing::info_span!("ecdsa");
|
||||
let ecdsa_key = task::spawn_blocking(move || {
|
||||
let _entered = span.enter();
|
||||
let rng = rand::thread_rng();
|
||||
let ret = k256::SecretKey::random(rng);
|
||||
info!("Done generating ECDSA key");
|
||||
ret
|
||||
})
|
||||
.await
|
||||
.context("could not join blocking task")?;
|
||||
|
||||
Ok(Self {
|
||||
issuer: default_oauth2_issuer(),
|
||||
clients: Vec::new(),
|
||||
keys: KeySet(vec![Key::from_rsa(rsa_key), Key::from_ecdsa(ecdsa_key)]),
|
||||
})
|
||||
}
|
||||
|
||||
fn test() -> Self {
|
||||
let rsa_key = Key::from_rsa_pem(indoc::indoc! {r#"
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MIIBVQIBADANBgkqhkiG9w0BAQEFAASCAT8wggE7AgEAAkEAymS2RkeIZo7pUeEN
|
||||
QUGCG4GLJru5jzxomO9jiNr5D/oRcerhpQVc9aCpBfAAg4l4a1SmYdBzWqX0X5pU
|
||||
scgTtQIDAQABAkEArNIMlrxUK4bSklkCcXtXdtdKE9vuWfGyOw0GyAB69fkEUBxh
|
||||
3j65u+u3ZmW+bpMWHgp1FtdobE9nGwb2VBTWAQIhAOyU1jiUEkrwKK004+6b5QRE
|
||||
vC9UI2vDWy5vioMNx5Y1AiEA2wGAJ6ETF8FF2Vd+kZlkKK7J0em9cl0gbJDsWIEw
|
||||
N4ECIEyWYkMurD1WQdTQqnk0Po+DMOihdFYOiBYgRdbnPxWBAiEAmtd0xJAd7622
|
||||
tPQniMnrBtiN2NxqFXHCev/8Gpc8gAECIBcaPcF59qVeRmYrfqzKBxFm7LmTwlAl
|
||||
Gh7BNzCeN+D6
|
||||
-----END PRIVATE KEY-----
|
||||
"#})
|
||||
.unwrap();
|
||||
let ecdsa_key = Key::from_ecdsa_pem(indoc::indoc! {r#"
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MIGEAgEAMBAGByqGSM49AgEGBSuBBAAKBG0wawIBAQQgqfn5mYO/5Qq/wOOiWgHA
|
||||
NaiDiepgUJ2GI5eq2V8D8nahRANCAARMK9aKUd/H28qaU+0qvS6bSJItzAge1VHn
|
||||
OhBAAUVci1RpmUA+KdCL5sw9nadAEiONeiGr+28RYHZmlB9qXnjC
|
||||
-----END PRIVATE KEY-----
|
||||
"#})
|
||||
.unwrap();
|
||||
|
||||
Self {
|
||||
issuer: default_oauth2_issuer(),
|
||||
clients: Vec::new(),
|
||||
keys: KeySet(vec![rsa_key, ecdsa_key]),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use figment::Jail;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn load_config() {
|
||||
Jail::expect_with(|jail| {
|
||||
jail.create_file(
|
||||
"config.yaml",
|
||||
r#"
|
||||
oauth2:
|
||||
keys:
|
||||
- type: rsa
|
||||
key: |
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MIIBVQIBADANBgkqhkiG9w0BAQEFAASCAT8wggE7AgEAAkEAymS2RkeIZo7pUeEN
|
||||
QUGCG4GLJru5jzxomO9jiNr5D/oRcerhpQVc9aCpBfAAg4l4a1SmYdBzWqX0X5pU
|
||||
scgTtQIDAQABAkEArNIMlrxUK4bSklkCcXtXdtdKE9vuWfGyOw0GyAB69fkEUBxh
|
||||
3j65u+u3ZmW+bpMWHgp1FtdobE9nGwb2VBTWAQIhAOyU1jiUEkrwKK004+6b5QRE
|
||||
vC9UI2vDWy5vioMNx5Y1AiEA2wGAJ6ETF8FF2Vd+kZlkKK7J0em9cl0gbJDsWIEw
|
||||
N4ECIEyWYkMurD1WQdTQqnk0Po+DMOihdFYOiBYgRdbnPxWBAiEAmtd0xJAd7622
|
||||
tPQniMnrBtiN2NxqFXHCev/8Gpc8gAECIBcaPcF59qVeRmYrfqzKBxFm7LmTwlAl
|
||||
Gh7BNzCeN+D6
|
||||
-----END PRIVATE KEY-----
|
||||
- type: ecdsa
|
||||
key: |
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MIGEAgEAMBAGByqGSM49AgEGBSuBBAAKBG0wawIBAQQgqfn5mYO/5Qq/wOOiWgHA
|
||||
NaiDiepgUJ2GI5eq2V8D8nahRANCAARMK9aKUd/H28qaU+0qvS6bSJItzAge1VHn
|
||||
OhBAAUVci1RpmUA+KdCL5sw9nadAEiONeiGr+28RYHZmlB9qXnjC
|
||||
-----END PRIVATE KEY-----
|
||||
issuer: https://example.com
|
||||
clients:
|
||||
- client_id: hello
|
||||
redirect_uris:
|
||||
- https://exemple.fr/callback
|
||||
- client_id: world
|
||||
"#,
|
||||
)?;
|
||||
|
||||
let config = OAuth2Config::load_from_file("config.yaml")?;
|
||||
|
||||
assert_eq!(config.issuer, "https://example.com".parse().unwrap());
|
||||
assert_eq!(config.clients.len(), 2);
|
||||
|
||||
assert_eq!(config.clients[0].client_id, "hello");
|
||||
assert_eq!(
|
||||
config.clients[0].redirect_uris,
|
||||
vec!["https://exemple.fr/callback".parse().unwrap()]
|
||||
);
|
||||
|
||||
assert_eq!(config.clients[1].client_id, "world");
|
||||
assert_eq!(config.clients[1].redirect_uris, Vec::new());
|
||||
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
}
|
||||
72
crates/config/src/util.rs
Normal file
72
crates/config/src/util.rs
Normal file
@@ -0,0 +1,72 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::path::Path;
|
||||
|
||||
use anyhow::Context;
|
||||
use async_trait::async_trait;
|
||||
use figment::{
|
||||
error::Error as FigmentError,
|
||||
providers::{Env, Format, Serialized, Yaml},
|
||||
Figment, Profile,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[async_trait]
|
||||
/// Trait implemented by all configuration section to help loading specific part
|
||||
/// of the config and generate the sample config.
|
||||
pub trait ConfigurationSection<'a>: Sized + Deserialize<'a> + Serialize {
|
||||
/// Specify where this section should live relative to the root.
|
||||
fn path() -> &'static str;
|
||||
|
||||
/// Generate a sample configuration for this section.
|
||||
async fn generate() -> anyhow::Result<Self>;
|
||||
|
||||
/// Generate a sample configuration and override it with environment
|
||||
/// variables.
|
||||
///
|
||||
/// This is what backs the `config generate` subcommand, allowing to
|
||||
/// programatically generate a configuration file, e.g.
|
||||
///
|
||||
/// ```sh
|
||||
/// export MAS_OAUTH2_ISSUER=https://example.com/
|
||||
/// export MAS_HTTP_ADDRESS=127.0.0.1:1234
|
||||
/// matrix-authentication-service config generate
|
||||
/// ```
|
||||
async fn load_and_generate() -> anyhow::Result<Self> {
|
||||
let base = Self::generate()
|
||||
.await
|
||||
.context("could not generate configuration")?;
|
||||
|
||||
Figment::new()
|
||||
.merge(Serialized::from(&base, Profile::Default))
|
||||
.merge(Env::prefixed("MAS_").split("_"))
|
||||
.extract_inner(Self::path())
|
||||
.context("could not load configuration")
|
||||
}
|
||||
|
||||
/// Load configuration from a file and environment variables.
|
||||
fn load_from_file<P>(path: P) -> Result<Self, FigmentError>
|
||||
where
|
||||
P: AsRef<Path>,
|
||||
{
|
||||
Figment::new()
|
||||
.merge(Env::prefixed("MAS_").split("_"))
|
||||
.merge(Yaml::file(path))
|
||||
.extract_inner(Self::path())
|
||||
}
|
||||
|
||||
/// Generate config used in unit tests
|
||||
fn test() -> Self;
|
||||
}
|
||||
Reference in New Issue
Block a user