1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Flatten the upstream_oauth2 config section

This commit is contained in:
Quentin Gliech
2024-03-22 10:09:44 +01:00
parent aa6178abe6
commit fc7489c5f8
4 changed files with 259 additions and 254 deletions

View File

@ -53,8 +53,7 @@ pub use self::{
upstream_oauth2::{
ClaimsImports as UpstreamOAuth2ClaimsImports, DiscoveryMode as UpstreamOAuth2DiscoveryMode,
EmailImportPreference as UpstreamOAuth2EmailImportPreference,
ImportAction as UpstreamOAuth2ImportAction,
ImportPreference as UpstreamOAuth2ImportPreference, PkceMethod as UpstreamOAuth2PkceMethod,
ImportAction as UpstreamOAuth2ImportAction, PkceMethod as UpstreamOAuth2PkceMethod,
SetEmailVerification as UpstreamOAuth2SetEmailVerification, UpstreamOAuth2Config,
},
};

View File

@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{collections::BTreeMap, ops::Deref};
use std::collections::BTreeMap;
use async_trait::async_trait;
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use rand::Rng;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde::{de::Error, Deserialize, Serialize};
use serde_with::skip_serializing_none;
use ulid::Ulid;
use url::Url;
@ -43,42 +43,104 @@ impl ConfigurationSection for UpstreamOAuth2Config {
Ok(Self::default())
}
fn validate(&self, figment: &figment::Figment) -> Result<(), figment::Error> {
for (index, provider) in self.providers.iter().enumerate() {
let annotate = |mut error: figment::Error| {
error.metadata = figment
.find_metadata(&format!("{root}.providers", root = Self::PATH.unwrap()))
.cloned();
error.profile = Some(figment::Profile::Default);
error.path = vec![
Self::PATH.unwrap().to_owned(),
"providers".to_owned(),
index.to_string(),
];
Err(error)
};
match provider.token_endpoint_auth_method {
TokenAuthMethod::None | TokenAuthMethod::PrivateKeyJwt => {
if provider.client_secret.is_some() {
return annotate(figment::Error::custom("Unexpected field `client_secret` for the selected authentication method"));
}
}
TokenAuthMethod::ClientSecretBasic
| TokenAuthMethod::ClientSecretPost
| TokenAuthMethod::ClientSecretJwt => {
if provider.client_secret.is_none() {
return annotate(figment::Error::missing_field("client_secret"));
}
}
}
match provider.token_endpoint_auth_method {
TokenAuthMethod::None
| TokenAuthMethod::ClientSecretBasic
| TokenAuthMethod::ClientSecretPost => {
if provider.token_endpoint_auth_signing_alg.is_some() {
return annotate(figment::Error::custom(
"Unexpected field `token_endpoint_auth_signing_alg` for the selected authentication method",
));
}
}
TokenAuthMethod::ClientSecretJwt | TokenAuthMethod::PrivateKeyJwt => {
if provider.token_endpoint_auth_signing_alg.is_none() {
return annotate(figment::Error::missing_field(
"token_endpoint_auth_signing_alg",
));
}
}
}
}
Ok(())
}
fn test() -> Self {
Self::default()
}
}
/// Authentication methods used against the OAuth 2.0 provider
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "token_endpoint_auth_method", rename_all = "snake_case")]
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum TokenAuthMethod {
/// `none`: No authentication
None,
/// `client_secret_basic`: `client_id` and `client_secret` used as basic
/// authorization credentials
ClientSecretBasic { client_secret: String },
ClientSecretBasic,
/// `client_secret_post`: `client_id` and `client_secret` sent in the
/// request body
ClientSecretPost { client_secret: String },
ClientSecretPost,
/// `client_secret_basic`: a `client_assertion` sent in the request body and
/// `client_secret_jwt`: a `client_assertion` sent in the request body and
/// signed using the `client_secret`
ClientSecretJwt {
client_secret: String,
token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
},
ClientSecretJwt,
/// `client_secret_basic`: a `client_assertion` sent in the request body and
/// `private_key_jwt`: a `client_assertion` sent in the request body and
/// signed by an asymmetric key
PrivateKeyJwt {
token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
},
PrivateKeyJwt,
}
impl From<TokenAuthMethod> for OAuthClientAuthenticationMethod {
fn from(method: TokenAuthMethod) -> Self {
match method {
TokenAuthMethod::None => OAuthClientAuthenticationMethod::None,
TokenAuthMethod::ClientSecretBasic => {
OAuthClientAuthenticationMethod::ClientSecretBasic
}
TokenAuthMethod::ClientSecretPost => OAuthClientAuthenticationMethod::ClientSecretPost,
TokenAuthMethod::ClientSecretJwt => OAuthClientAuthenticationMethod::ClientSecretJwt,
TokenAuthMethod::PrivateKeyJwt => OAuthClientAuthenticationMethod::PrivateKeyJwt,
}
}
}
/// How to handle a claim
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum ImportAction {
/// Ignore the claim
@ -95,16 +157,15 @@ pub enum ImportAction {
Require,
}
/// What should be done with a attribute
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct ImportPreference {
/// How to handle the attribute
#[serde(default)]
pub action: ImportAction,
impl ImportAction {
#[allow(clippy::trivially_copy_pass_by_ref)]
const fn is_default(&self) -> bool {
matches!(self, ImportAction::Ignore)
}
}
/// Should the email address be marked as verified
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum SetEmailVerification {
/// Mark the email address as verified
@ -119,85 +180,130 @@ pub enum SetEmailVerification {
Import,
}
impl SetEmailVerification {
#[allow(clippy::trivially_copy_pass_by_ref)]
const fn is_default(&self) -> bool {
matches!(self, SetEmailVerification::Import)
}
}
/// What should be done for the subject attribute
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct SubjectImportPreference {
/// The Jinja2 template to use for the subject attribute
///
/// If not provided, the default template is `{{ user.sub }}`
#[serde(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub template: Option<String>,
}
impl SubjectImportPreference {
const fn is_default(&self) -> bool {
self.template.is_none()
}
}
/// What should be done for the localpart attribute
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct LocalpartImportPreference {
/// How to handle the attribute
#[serde(default)]
#[serde(default, skip_serializing_if = "ImportAction::is_default")]
pub action: ImportAction,
/// The Jinja2 template to use for the localpart attribute
///
/// If not provided, the default template is `{{ user.preferred_username }}`
#[serde(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub template: Option<String>,
}
impl LocalpartImportPreference {
const fn is_default(&self) -> bool {
self.action.is_default() && self.template.is_none()
}
}
/// What should be done for the displayname attribute
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct DisplaynameImportPreference {
/// How to handle the attribute
#[serde(default)]
#[serde(default, skip_serializing_if = "ImportAction::is_default")]
pub action: ImportAction,
/// The Jinja2 template to use for the displayname attribute
///
/// If not provided, the default template is `{{ user.name }}`
#[serde(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub template: Option<String>,
}
impl DisplaynameImportPreference {
const fn is_default(&self) -> bool {
self.action.is_default() && self.template.is_none()
}
}
/// What should be done with the email attribute
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct EmailImportPreference {
/// How to handle the claim
#[serde(default)]
#[serde(default, skip_serializing_if = "ImportAction::is_default")]
pub action: ImportAction,
/// The Jinja2 template to use for the email address attribute
///
/// If not provided, the default template is `{{ user.email }}`
#[serde(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub template: Option<String>,
/// Should the email address be marked as verified
#[serde(default)]
#[serde(default, skip_serializing_if = "SetEmailVerification::is_default")]
pub set_email_verification: SetEmailVerification,
}
impl EmailImportPreference {
const fn is_default(&self) -> bool {
self.action.is_default()
&& self.template.is_none()
&& self.set_email_verification.is_default()
}
}
/// How claims should be imported
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct ClaimsImports {
/// How to determine the subject of the user
#[serde(default)]
#[serde(default, skip_serializing_if = "SubjectImportPreference::is_default")]
pub subject: SubjectImportPreference,
/// Import the localpart of the MXID
#[serde(default)]
#[serde(default, skip_serializing_if = "LocalpartImportPreference::is_default")]
pub localpart: LocalpartImportPreference,
/// Import the displayname of the user.
#[serde(default)]
#[serde(
default,
skip_serializing_if = "DisplaynameImportPreference::is_default"
)]
pub displayname: DisplaynameImportPreference,
/// Import the email address of the user based on the `email` and
/// `email_verified` claims
#[serde(default)]
#[serde(default, skip_serializing_if = "EmailImportPreference::is_default")]
pub email: EmailImportPreference,
}
impl ClaimsImports {
const fn is_default(&self) -> bool {
self.subject.is_default()
&& self.localpart.is_default()
&& self.displayname.is_default()
&& self.email.is_default()
}
}
/// How to discover the provider's configuration
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum DiscoveryMode {
/// Use OIDC discovery with strict metadata verification
@ -211,9 +317,16 @@ pub enum DiscoveryMode {
Disabled,
}
impl DiscoveryMode {
#[allow(clippy::trivially_copy_pass_by_ref)]
const fn is_default(&self) -> bool {
matches!(self, DiscoveryMode::Oidc)
}
}
/// Whether to use proof key for code exchange (PKCE) when requesting and
/// exchanging the token.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum PkceMethod {
/// Use PKCE if the provider supports it
@ -229,6 +342,13 @@ pub enum PkceMethod {
Never,
}
impl PkceMethod {
#[allow(clippy::trivially_copy_pass_by_ref)]
const fn is_default(&self) -> bool {
matches!(self, PkceMethod::Auto)
}
}
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct Provider {
@ -244,6 +364,7 @@ pub struct Provider {
pub issuer: String,
/// A human-readable name for the provider, that will be shown to users
#[serde(skip_serializing_if = "Option::is_none")]
pub human_name: Option<String>,
/// A brand identifier used to customise the UI, e.g. `apple`, `google`,
@ -257,108 +378,72 @@ pub struct Provider {
/// - `github`
/// - `gitlab`
/// - `twitter`
#[serde(skip_serializing_if = "Option::is_none")]
pub brand_name: Option<String>,
/// The client ID to use when authenticating with the provider
pub client_id: String,
/// The client secret to use when authenticating with the provider
///
/// Used by the `client_secret_basic`, `client_secret_post`, and
/// `client_secret_jwt` methods
#[serde(skip_serializing_if = "Option::is_none")]
pub client_secret: Option<String>,
/// The method to authenticate the client with the provider
pub token_endpoint_auth_method: TokenAuthMethod,
/// The JWS algorithm to use when authenticating the client with the
/// provider
///
/// Used by the `client_secret_jwt` and `private_key_jwt` methods
#[serde(skip_serializing_if = "Option::is_none")]
pub token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
/// The scopes to request from the provider
pub scope: String,
#[serde(flatten)]
pub token_auth_method: TokenAuthMethod,
/// How to discover the provider's configuration
///
/// Defaults to use OIDC discovery with strict metadata verification
#[serde(default)]
/// Defaults to `oidc`, which uses OIDC discovery with strict metadata
/// verification
#[serde(default, skip_serializing_if = "DiscoveryMode::is_default")]
pub discovery_mode: DiscoveryMode,
/// Whether to use proof key for code exchange (PKCE) when requesting and
/// exchanging the token.
///
/// Defaults to `auto`, which uses PKCE if the provider supports it.
#[serde(default)]
#[serde(default, skip_serializing_if = "PkceMethod::is_default")]
pub pkce_method: PkceMethod,
/// The URL to use for the provider's authorization endpoint
///
/// Defaults to the `authorization_endpoint` provided through discovery
#[serde(skip_serializing_if = "Option::is_none")]
pub authorization_endpoint: Option<Url>,
/// The URL to use for the provider's token endpoint
///
/// Defaults to the `token_endpoint` provided through discovery
#[serde(skip_serializing_if = "Option::is_none")]
pub token_endpoint: Option<Url>,
/// The URL to use for getting the provider's public keys
///
/// Defaults to the `jwks_uri` provided through discovery
#[serde(skip_serializing_if = "Option::is_none")]
pub jwks_uri: Option<Url>,
/// How claims should be imported from the `id_token` provided by the
/// provider
#[serde(default)]
#[serde(default, skip_serializing_if = "ClaimsImports::is_default")]
pub claims_imports: ClaimsImports,
/// Additional parameters to include in the authorization request
///
/// Orders of the keys are not preserved.
#[serde(default)]
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
pub additional_authorization_parameters: BTreeMap<String, String>,
}
impl Deref for Provider {
type Target = TokenAuthMethod;
fn deref(&self) -> &Self::Target {
&self.token_auth_method
}
}
impl TokenAuthMethod {
#[doc(hidden)]
#[must_use]
pub fn client_auth_method(&self) -> OAuthClientAuthenticationMethod {
match self {
TokenAuthMethod::None => OAuthClientAuthenticationMethod::None,
TokenAuthMethod::ClientSecretBasic { .. } => {
OAuthClientAuthenticationMethod::ClientSecretBasic
}
TokenAuthMethod::ClientSecretPost { .. } => {
OAuthClientAuthenticationMethod::ClientSecretPost
}
TokenAuthMethod::ClientSecretJwt { .. } => {
OAuthClientAuthenticationMethod::ClientSecretJwt
}
TokenAuthMethod::PrivateKeyJwt { .. } => OAuthClientAuthenticationMethod::PrivateKeyJwt,
}
}
#[doc(hidden)]
#[must_use]
pub fn client_secret(&self) -> Option<&str> {
match self {
TokenAuthMethod::None | TokenAuthMethod::PrivateKeyJwt { .. } => None,
TokenAuthMethod::ClientSecretBasic { client_secret }
| TokenAuthMethod::ClientSecretPost { client_secret }
| TokenAuthMethod::ClientSecretJwt { client_secret, .. } => Some(client_secret),
}
}
#[doc(hidden)]
#[must_use]
pub fn client_auth_signing_alg(&self) -> Option<JsonWebSignatureAlg> {
match self {
TokenAuthMethod::ClientSecretJwt {
token_endpoint_auth_signing_alg,
..
}
| TokenAuthMethod::PrivateKeyJwt {
token_endpoint_auth_signing_alg,
..
} => token_endpoint_auth_signing_alg.clone(),
_ => None,
}
}
}