You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-31 09:24:31 +03:00
Flatten the clients config
This commit is contained in:
@ -12,16 +12,15 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::ops::{Deref, DerefMut};
|
||||
use std::ops::Deref;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use figment::Figment;
|
||||
use mas_iana::oauth::OAuthClientAuthenticationMethod;
|
||||
use mas_jose::jwk::PublicJsonWebKeySet;
|
||||
use rand::Rng;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::skip_serializing_none;
|
||||
use thiserror::Error;
|
||||
use serde::{de::Error, Deserialize, Serialize};
|
||||
use ulid::Ulid;
|
||||
use url::Url;
|
||||
|
||||
@ -41,40 +40,42 @@ impl From<PublicJsonWebKeySet> for JwksOrJwksUri {
|
||||
}
|
||||
|
||||
/// Authentication method used by clients
|
||||
#[derive(JsonSchema, Serialize, Deserialize, Clone, Debug)]
|
||||
#[serde(tag = "client_auth_method", rename_all = "snake_case")]
|
||||
#[derive(JsonSchema, Serialize, Deserialize, Copy, Clone, Debug)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ClientAuthMethodConfig {
|
||||
/// `none`: No authentication
|
||||
None,
|
||||
|
||||
/// `client_secret_basic`: `client_id` and `client_secret` used as basic
|
||||
/// authorization credentials
|
||||
ClientSecretBasic {
|
||||
/// The client secret
|
||||
client_secret: String,
|
||||
},
|
||||
ClientSecretBasic,
|
||||
|
||||
/// `client_secret_post`: `client_id` and `client_secret` sent in the
|
||||
/// request body
|
||||
ClientSecretPost {
|
||||
/// The client secret
|
||||
client_secret: String,
|
||||
},
|
||||
ClientSecretPost,
|
||||
|
||||
/// `client_secret_basic`: a `client_assertion` sent in the request body and
|
||||
/// signed using the `client_secret`
|
||||
ClientSecretJwt {
|
||||
/// The client secret
|
||||
client_secret: String,
|
||||
},
|
||||
ClientSecretJwt,
|
||||
|
||||
/// `client_secret_basic`: a `client_assertion` sent in the request body and
|
||||
/// signed by an asymmetric key
|
||||
PrivateKeyJwt(JwksOrJwksUri),
|
||||
PrivateKeyJwt,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ClientAuthMethodConfig {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ClientAuthMethodConfig::None => write!(f, "none"),
|
||||
ClientAuthMethodConfig::ClientSecretBasic => write!(f, "client_secret_basic"),
|
||||
ClientAuthMethodConfig::ClientSecretPost => write!(f, "client_secret_post"),
|
||||
ClientAuthMethodConfig::ClientSecretJwt => write!(f, "client_secret_jwt"),
|
||||
ClientAuthMethodConfig::PrivateKeyJwt => write!(f, "private_key_jwt"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// An OAuth 2.0 client configuration
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct ClientConfig {
|
||||
/// The client ID
|
||||
@ -86,67 +87,121 @@ pub struct ClientConfig {
|
||||
pub client_id: Ulid,
|
||||
|
||||
/// Authentication method used for this client
|
||||
#[serde(flatten)]
|
||||
pub client_auth_method: ClientAuthMethodConfig,
|
||||
client_auth_method: ClientAuthMethodConfig,
|
||||
|
||||
/// The client secret, used by the `client_secret_basic`,
|
||||
/// `client_secret_post` and `client_secret_jwt` authentication methods
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub client_secret: Option<String>,
|
||||
|
||||
/// The JSON Web Key Set (JWKS) used by the `private_key_jwt` authentication
|
||||
/// method. Mutually exclusive with `jwks_uri`
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub jwks: Option<PublicJsonWebKeySet>,
|
||||
|
||||
/// The URL of the JSON Web Key Set (JWKS) used by the `private_key_jwt`
|
||||
/// authentication method. Mutually exclusive with `jwks`
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub jwks_uri: Option<Url>,
|
||||
|
||||
/// List of allowed redirect URIs
|
||||
#[serde(default)]
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub redirect_uris: Vec<Url>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("Invalid redirect URI")]
|
||||
pub struct InvalidRedirectUriError;
|
||||
|
||||
impl ClientConfig {
|
||||
#[doc(hidden)]
|
||||
#[must_use]
|
||||
pub fn client_secret(&self) -> Option<&str> {
|
||||
match &self.client_auth_method {
|
||||
ClientAuthMethodConfig::ClientSecretPost { client_secret }
|
||||
| ClientAuthMethodConfig::ClientSecretBasic { client_secret }
|
||||
| ClientAuthMethodConfig::ClientSecretJwt { client_secret } => Some(client_secret),
|
||||
_ => None,
|
||||
fn validate(&self) -> Result<(), figment::error::Error> {
|
||||
let auth_method = self.client_auth_method;
|
||||
match self.client_auth_method {
|
||||
ClientAuthMethodConfig::PrivateKeyJwt => {
|
||||
if self.jwks.is_none() && self.jwks_uri.is_none() {
|
||||
let error = figment::error::Error::custom(
|
||||
"jwks or jwks_uri is required for private_key_jwt",
|
||||
);
|
||||
return Err(error.with_path("client_auth_method"));
|
||||
}
|
||||
|
||||
if self.jwks.is_some() && self.jwks_uri.is_some() {
|
||||
let error =
|
||||
figment::error::Error::custom("jwks and jwks_uri are mutually exclusive");
|
||||
return Err(error.with_path("jwks"));
|
||||
}
|
||||
|
||||
if self.client_secret.is_some() {
|
||||
let error = figment::error::Error::custom(
|
||||
"client_secret is not allowed with private_key_jwt",
|
||||
);
|
||||
return Err(error.with_path("client_secret"));
|
||||
}
|
||||
}
|
||||
|
||||
ClientAuthMethodConfig::ClientSecretPost
|
||||
| ClientAuthMethodConfig::ClientSecretBasic
|
||||
| ClientAuthMethodConfig::ClientSecretJwt => {
|
||||
if self.client_secret.is_none() {
|
||||
let error = figment::error::Error::custom(format!(
|
||||
"client_secret is required for {auth_method}"
|
||||
));
|
||||
return Err(error.with_path("client_auth_method"));
|
||||
}
|
||||
|
||||
if self.jwks.is_some() {
|
||||
let error = figment::error::Error::custom(format!(
|
||||
"jwks is not allowed with {auth_method}"
|
||||
));
|
||||
return Err(error.with_path("jwks"));
|
||||
}
|
||||
|
||||
if self.jwks_uri.is_some() {
|
||||
let error = figment::error::Error::custom(format!(
|
||||
"jwks_uri is not allowed with {auth_method}"
|
||||
));
|
||||
return Err(error.with_path("jwks_uri"));
|
||||
}
|
||||
}
|
||||
|
||||
ClientAuthMethodConfig::None => {
|
||||
if self.client_secret.is_some() {
|
||||
let error = figment::error::Error::custom(
|
||||
"client_secret is not allowed with none authentication method",
|
||||
);
|
||||
return Err(error.with_path("client_secret"));
|
||||
}
|
||||
|
||||
if self.jwks.is_some() {
|
||||
let error = figment::error::Error::custom(
|
||||
"jwks is not allowed with none authentication method",
|
||||
);
|
||||
return Err(error);
|
||||
}
|
||||
|
||||
if self.jwks_uri.is_some() {
|
||||
let error = figment::error::Error::custom(
|
||||
"jwks_uri is not allowed with none authentication method",
|
||||
);
|
||||
return Err(error);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
/// Authentication method used for this client
|
||||
#[must_use]
|
||||
pub fn client_auth_method(&self) -> OAuthClientAuthenticationMethod {
|
||||
match &self.client_auth_method {
|
||||
match self.client_auth_method {
|
||||
ClientAuthMethodConfig::None => OAuthClientAuthenticationMethod::None,
|
||||
ClientAuthMethodConfig::ClientSecretBasic { .. } => {
|
||||
ClientAuthMethodConfig::ClientSecretBasic => {
|
||||
OAuthClientAuthenticationMethod::ClientSecretBasic
|
||||
}
|
||||
ClientAuthMethodConfig::ClientSecretPost { .. } => {
|
||||
ClientAuthMethodConfig::ClientSecretPost => {
|
||||
OAuthClientAuthenticationMethod::ClientSecretPost
|
||||
}
|
||||
ClientAuthMethodConfig::ClientSecretJwt { .. } => {
|
||||
ClientAuthMethodConfig::ClientSecretJwt => {
|
||||
OAuthClientAuthenticationMethod::ClientSecretJwt
|
||||
}
|
||||
ClientAuthMethodConfig::PrivateKeyJwt(_) => {
|
||||
OAuthClientAuthenticationMethod::PrivateKeyJwt
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
#[must_use]
|
||||
pub fn jwks(&self) -> Option<&PublicJsonWebKeySet> {
|
||||
match &self.client_auth_method {
|
||||
ClientAuthMethodConfig::PrivateKeyJwt(JwksOrJwksUri::Jwks(jwks)) => Some(jwks),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
#[must_use]
|
||||
pub fn jwks_uri(&self) -> Option<&Url> {
|
||||
match &self.client_auth_method {
|
||||
ClientAuthMethodConfig::PrivateKeyJwt(JwksOrJwksUri::JwksUri(jwks_uri)) => {
|
||||
Some(jwks_uri)
|
||||
}
|
||||
_ => None,
|
||||
ClientAuthMethodConfig::PrivateKeyJwt => OAuthClientAuthenticationMethod::PrivateKeyJwt,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -154,7 +209,7 @@ impl ClientConfig {
|
||||
/// List of OAuth 2.0/OIDC clients config
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema)]
|
||||
#[serde(transparent)]
|
||||
pub struct ClientsConfig(Vec<ClientConfig>);
|
||||
pub struct ClientsConfig(#[schemars(with = "Vec::<ClientConfig>")] Vec<ClientConfig>);
|
||||
|
||||
impl Deref for ClientsConfig {
|
||||
type Target = Vec<ClientConfig>;
|
||||
@ -164,12 +219,6 @@ impl Deref for ClientsConfig {
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for ClientsConfig {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoIterator for ClientsConfig {
|
||||
type Item = ClientConfig;
|
||||
type IntoIter = std::vec::IntoIter<ClientConfig>;
|
||||
@ -190,6 +239,21 @@ impl ConfigurationSection for ClientsConfig {
|
||||
Ok(Self::default())
|
||||
}
|
||||
|
||||
fn validate(&self, figment: &Figment) -> Result<(), figment::error::Error> {
|
||||
for (index, client) in self.0.iter().enumerate() {
|
||||
client.validate().map_err(|mut err| {
|
||||
// Save the error location information in the error
|
||||
err.metadata = figment.find_metadata(Self::PATH.unwrap()).cloned();
|
||||
err.profile = Some(figment::Profile::Default);
|
||||
err.path.insert(0, Self::PATH.unwrap().to_owned());
|
||||
err.path.insert(1, format!("{index}"));
|
||||
err
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn test() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
@ -29,17 +29,29 @@ pub trait ConfigurationSection: Sized + DeserializeOwned + Serialize {
|
||||
where
|
||||
R: Rng + Send;
|
||||
|
||||
/// Validate the configuration section
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the configuration is invalid
|
||||
fn validate(&self, _figment: &Figment) -> Result<(), FigmentError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Extract configuration from a Figment instance.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the configuration could not be loaded
|
||||
fn extract(figment: &Figment) -> Result<Self, FigmentError> {
|
||||
if let Some(path) = Self::PATH {
|
||||
figment.extract_inner(path)
|
||||
let this: Self = if let Some(path) = Self::PATH {
|
||||
figment.extract_inner(path)?
|
||||
} else {
|
||||
figment.extract()
|
||||
}
|
||||
figment.extract()?
|
||||
};
|
||||
|
||||
this.validate(figment)?;
|
||||
Ok(this)
|
||||
}
|
||||
|
||||
/// Generate config used in unit tests
|
||||
|
Reference in New Issue
Block a user