1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Flatten the clients config

This commit is contained in:
Quentin Gliech
2024-03-20 16:59:26 +01:00
parent 48b6013c4f
commit cba431d20e
4 changed files with 219 additions and 197 deletions

View File

@ -12,16 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::ops::{Deref, DerefMut};
use std::ops::Deref;
use async_trait::async_trait;
use figment::Figment;
use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_jose::jwk::PublicJsonWebKeySet;
use rand::Rng;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none;
use thiserror::Error;
use serde::{de::Error, Deserialize, Serialize};
use ulid::Ulid;
use url::Url;
@ -41,40 +40,42 @@ impl From<PublicJsonWebKeySet> for JwksOrJwksUri {
}
/// Authentication method used by clients
#[derive(JsonSchema, Serialize, Deserialize, Clone, Debug)]
#[serde(tag = "client_auth_method", rename_all = "snake_case")]
#[derive(JsonSchema, Serialize, Deserialize, Copy, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum ClientAuthMethodConfig {
/// `none`: No authentication
None,
/// `client_secret_basic`: `client_id` and `client_secret` used as basic
/// authorization credentials
ClientSecretBasic {
/// The client secret
client_secret: String,
},
ClientSecretBasic,
/// `client_secret_post`: `client_id` and `client_secret` sent in the
/// request body
ClientSecretPost {
/// The client secret
client_secret: String,
},
ClientSecretPost,
/// `client_secret_basic`: a `client_assertion` sent in the request body and
/// signed using the `client_secret`
ClientSecretJwt {
/// The client secret
client_secret: String,
},
ClientSecretJwt,
/// `client_secret_basic`: a `client_assertion` sent in the request body and
/// signed by an asymmetric key
PrivateKeyJwt(JwksOrJwksUri),
PrivateKeyJwt,
}
impl std::fmt::Display for ClientAuthMethodConfig {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ClientAuthMethodConfig::None => write!(f, "none"),
ClientAuthMethodConfig::ClientSecretBasic => write!(f, "client_secret_basic"),
ClientAuthMethodConfig::ClientSecretPost => write!(f, "client_secret_post"),
ClientAuthMethodConfig::ClientSecretJwt => write!(f, "client_secret_jwt"),
ClientAuthMethodConfig::PrivateKeyJwt => write!(f, "private_key_jwt"),
}
}
}
/// An OAuth 2.0 client configuration
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ClientConfig {
/// The client ID
@ -86,67 +87,121 @@ pub struct ClientConfig {
pub client_id: Ulid,
/// Authentication method used for this client
#[serde(flatten)]
pub client_auth_method: ClientAuthMethodConfig,
client_auth_method: ClientAuthMethodConfig,
/// The client secret, used by the `client_secret_basic`,
/// `client_secret_post` and `client_secret_jwt` authentication methods
#[serde(skip_serializing_if = "Option::is_none")]
pub client_secret: Option<String>,
/// The JSON Web Key Set (JWKS) used by the `private_key_jwt` authentication
/// method. Mutually exclusive with `jwks_uri`
#[serde(skip_serializing_if = "Option::is_none")]
pub jwks: Option<PublicJsonWebKeySet>,
/// The URL of the JSON Web Key Set (JWKS) used by the `private_key_jwt`
/// authentication method. Mutually exclusive with `jwks`
#[serde(skip_serializing_if = "Option::is_none")]
pub jwks_uri: Option<Url>,
/// List of allowed redirect URIs
#[serde(default)]
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub redirect_uris: Vec<Url>,
}
#[derive(Debug, Error)]
#[error("Invalid redirect URI")]
pub struct InvalidRedirectUriError;
impl ClientConfig {
#[doc(hidden)]
#[must_use]
pub fn client_secret(&self) -> Option<&str> {
match &self.client_auth_method {
ClientAuthMethodConfig::ClientSecretPost { client_secret }
| ClientAuthMethodConfig::ClientSecretBasic { client_secret }
| ClientAuthMethodConfig::ClientSecretJwt { client_secret } => Some(client_secret),
_ => None,
fn validate(&self) -> Result<(), figment::error::Error> {
let auth_method = self.client_auth_method;
match self.client_auth_method {
ClientAuthMethodConfig::PrivateKeyJwt => {
if self.jwks.is_none() && self.jwks_uri.is_none() {
let error = figment::error::Error::custom(
"jwks or jwks_uri is required for private_key_jwt",
);
return Err(error.with_path("client_auth_method"));
}
if self.jwks.is_some() && self.jwks_uri.is_some() {
let error =
figment::error::Error::custom("jwks and jwks_uri are mutually exclusive");
return Err(error.with_path("jwks"));
}
if self.client_secret.is_some() {
let error = figment::error::Error::custom(
"client_secret is not allowed with private_key_jwt",
);
return Err(error.with_path("client_secret"));
}
}
ClientAuthMethodConfig::ClientSecretPost
| ClientAuthMethodConfig::ClientSecretBasic
| ClientAuthMethodConfig::ClientSecretJwt => {
if self.client_secret.is_none() {
let error = figment::error::Error::custom(format!(
"client_secret is required for {auth_method}"
));
return Err(error.with_path("client_auth_method"));
}
if self.jwks.is_some() {
let error = figment::error::Error::custom(format!(
"jwks is not allowed with {auth_method}"
));
return Err(error.with_path("jwks"));
}
if self.jwks_uri.is_some() {
let error = figment::error::Error::custom(format!(
"jwks_uri is not allowed with {auth_method}"
));
return Err(error.with_path("jwks_uri"));
}
}
ClientAuthMethodConfig::None => {
if self.client_secret.is_some() {
let error = figment::error::Error::custom(
"client_secret is not allowed with none authentication method",
);
return Err(error.with_path("client_secret"));
}
if self.jwks.is_some() {
let error = figment::error::Error::custom(
"jwks is not allowed with none authentication method",
);
return Err(error);
}
if self.jwks_uri.is_some() {
let error = figment::error::Error::custom(
"jwks_uri is not allowed with none authentication method",
);
return Err(error);
}
}
}
Ok(())
}
#[doc(hidden)]
/// Authentication method used for this client
#[must_use]
pub fn client_auth_method(&self) -> OAuthClientAuthenticationMethod {
match &self.client_auth_method {
match self.client_auth_method {
ClientAuthMethodConfig::None => OAuthClientAuthenticationMethod::None,
ClientAuthMethodConfig::ClientSecretBasic { .. } => {
ClientAuthMethodConfig::ClientSecretBasic => {
OAuthClientAuthenticationMethod::ClientSecretBasic
}
ClientAuthMethodConfig::ClientSecretPost { .. } => {
ClientAuthMethodConfig::ClientSecretPost => {
OAuthClientAuthenticationMethod::ClientSecretPost
}
ClientAuthMethodConfig::ClientSecretJwt { .. } => {
ClientAuthMethodConfig::ClientSecretJwt => {
OAuthClientAuthenticationMethod::ClientSecretJwt
}
ClientAuthMethodConfig::PrivateKeyJwt(_) => {
OAuthClientAuthenticationMethod::PrivateKeyJwt
}
}
}
#[doc(hidden)]
#[must_use]
pub fn jwks(&self) -> Option<&PublicJsonWebKeySet> {
match &self.client_auth_method {
ClientAuthMethodConfig::PrivateKeyJwt(JwksOrJwksUri::Jwks(jwks)) => Some(jwks),
_ => None,
}
}
#[doc(hidden)]
#[must_use]
pub fn jwks_uri(&self) -> Option<&Url> {
match &self.client_auth_method {
ClientAuthMethodConfig::PrivateKeyJwt(JwksOrJwksUri::JwksUri(jwks_uri)) => {
Some(jwks_uri)
}
_ => None,
ClientAuthMethodConfig::PrivateKeyJwt => OAuthClientAuthenticationMethod::PrivateKeyJwt,
}
}
}
@ -154,7 +209,7 @@ impl ClientConfig {
/// List of OAuth 2.0/OIDC clients config
#[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema)]
#[serde(transparent)]
pub struct ClientsConfig(Vec<ClientConfig>);
pub struct ClientsConfig(#[schemars(with = "Vec::<ClientConfig>")] Vec<ClientConfig>);
impl Deref for ClientsConfig {
type Target = Vec<ClientConfig>;
@ -164,12 +219,6 @@ impl Deref for ClientsConfig {
}
}
impl DerefMut for ClientsConfig {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl IntoIterator for ClientsConfig {
type Item = ClientConfig;
type IntoIter = std::vec::IntoIter<ClientConfig>;
@ -190,6 +239,21 @@ impl ConfigurationSection for ClientsConfig {
Ok(Self::default())
}
fn validate(&self, figment: &Figment) -> Result<(), figment::error::Error> {
for (index, client) in self.0.iter().enumerate() {
client.validate().map_err(|mut err| {
// Save the error location information in the error
err.metadata = figment.find_metadata(Self::PATH.unwrap()).cloned();
err.profile = Some(figment::Profile::Default);
err.path.insert(0, Self::PATH.unwrap().to_owned());
err.path.insert(1, format!("{index}"));
err
})?;
}
Ok(())
}
fn test() -> Self {
Self::default()
}

View File

@ -29,17 +29,29 @@ pub trait ConfigurationSection: Sized + DeserializeOwned + Serialize {
where
R: Rng + Send;
/// Validate the configuration section
///
/// # Errors
///
/// Returns an error if the configuration is invalid
fn validate(&self, _figment: &Figment) -> Result<(), FigmentError> {
Ok(())
}
/// Extract configuration from a Figment instance.
///
/// # Errors
///
/// Returns an error if the configuration could not be loaded
fn extract(figment: &Figment) -> Result<Self, FigmentError> {
if let Some(path) = Self::PATH {
figment.extract_inner(path)
let this: Self = if let Some(path) = Self::PATH {
figment.extract_inner(path)?
} else {
figment.extract()
}
figment.extract()?
};
this.validate(figment)?;
Ok(this)
}
/// Generate config used in unit tests