From a965e488e29987d93c830d45d6a867eb0e8974c1 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 5 Jan 2022 21:07:18 +0100 Subject: [PATCH] Support private_key_jwt client auth Which includes having a verifying keystore out of JWKS (and soon out of a JWKS URI) --- Cargo.lock | 2 + crates/config/src/lib.rs | 2 +- crates/config/src/oauth2.rs | 74 +++++- crates/handlers/src/oauth2/keys.rs | 3 +- crates/jose/Cargo.toml | 4 +- crates/jose/src/claims.rs | 25 ++ crates/jose/src/iana.rs | 17 +- crates/jose/src/jwk.rs | 56 +++- crates/jose/src/keystore/jwks.rs | 278 ++++++++++++++++++++ crates/jose/src/keystore/mod.rs | 2 + crates/jose/src/keystore/static_keystore.rs | 26 +- crates/jose/src/keystore/traits.rs | 7 +- crates/jose/src/lib.rs | 6 +- crates/warp-utils/src/filters/client.rs | 184 +++++++------ 14 files changed, 557 insertions(+), 129 deletions(-) create mode 100644 crates/jose/src/claims.rs create mode 100644 crates/jose/src/keystore/jwks.rs diff --git a/Cargo.lock b/Cargo.lock index 47f4ae8c..7535226f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1570,6 +1570,7 @@ dependencies = [ "anyhow", "async-trait", "base64ct", + "chrono", "crypto-mac", "digest 0.10.1", "ecdsa", @@ -1580,6 +1581,7 @@ dependencies = [ "pkcs8", "rand", "rsa", + "schemars", "sec1", "serde", "serde_json", diff --git a/crates/config/src/lib.rs b/crates/config/src/lib.rs index 443552db..17ccef09 100644 --- a/crates/config/src/lib.rs +++ b/crates/config/src/lib.rs @@ -38,7 +38,7 @@ pub use self::{ csrf::CsrfConfig, database::DatabaseConfig, http::HttpConfig, - oauth2::{OAuth2ClientConfig, OAuth2Config}, + oauth2::{OAuth2ClientAuthMethodConfig, OAuth2ClientConfig, OAuth2Config}, telemetry::{ MetricsConfig, MetricsExporterConfig, Propagator, TelemetryConfig, TracingConfig, TracingExporterConfig, diff --git a/crates/config/src/oauth2.rs b/crates/config/src/oauth2.rs index fe4c9c47..3e57a290 100644 --- a/crates/config/src/oauth2.rs +++ b/crates/config/src/oauth2.rs @@ -14,7 +14,7 @@ use anyhow::Context; use async_trait::async_trait; -use mas_jose::StaticKeystore; +use mas_jose::{JsonWebKeySet, StaticJwksStore, StaticKeystore}; use pkcs8::{DecodePrivateKey, EncodePrivateKey}; use rsa::{ pkcs1::{DecodeRsaPrivateKey, EncodeRsaPrivateKey}, @@ -43,13 +43,41 @@ pub struct KeyConfig { key: String, } +#[derive(JsonSchema, Serialize, Deserialize, Clone, Debug)] +#[serde(rename_all = "snake_case")] +pub enum JwksOrJwksUri { + Jwks(JsonWebKeySet), + JwksUri(Url), +} + +impl JwksOrJwksUri { + pub fn key_store(&self) -> StaticJwksStore { + let jwks = match self { + Self::Jwks(jwks) => jwks.clone(), + Self::JwksUri(_) => unimplemented!("jwks_uri are not implemented yet"), + }; + + StaticJwksStore::new(jwks) + } +} + +#[derive(JsonSchema, Serialize, Deserialize, Clone, Debug)] +#[serde(tag = "client_auth_method", rename_all = "snake_case")] +pub enum OAuth2ClientAuthMethodConfig { + None, + ClientSecretBasic { client_secret: String }, + ClientSecretPost { client_secret: String }, + ClientSecretJwt { client_secret: String }, + PrivateKeyJwt(JwksOrJwksUri), +} + #[skip_serializing_none] #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct OAuth2ClientConfig { pub client_id: String, - #[serde(default)] - pub client_secret: Option, + #[serde(flatten)] + pub client_auth_method: OAuth2ClientAuthMethodConfig, #[serde(default)] pub redirect_uris: Vec, @@ -246,25 +274,55 @@ mod tests { -----END PRIVATE KEY----- issuer: https://example.com clients: - - client_id: hello + - client_id: public + client_auth_method: none redirect_uris: - https://exemple.fr/callback - - client_id: world + + - client_id: secret-basic + client_auth_method: client_secret_basic + client_secret: hello + + - client_id: secret-post + client_auth_method: client_secret_post + client_secret: hello + + - client_id: secret-jwk + client_auth_method: client_secret_jwt + client_secret: hello + + - client_id: jwks + client_auth_method: private_key_jwt + jwks: + keys: + - kid: "03e84aed4ef4431014e8617567864c4efaaaede9" + kty: "RSA" + alg: "RS256" + use: "sig" + e: "AQAB" + n: "ma2uRyBeSEOatGuDpCiV9oIxlDWix_KypDYuhQfEzqi_BiF4fV266OWfyjcABbam59aJMNvOnKW3u_eZM-PhMCBij5MZ-vcBJ4GfxDJeKSn-GP_dJ09rpDcILh8HaWAnPmMoi4DC0nrfE241wPISvZaaZnGHkOrfN_EnA5DligLgVUbrA5rJhQ1aSEQO_gf1raEOW3DZ_ACU3qhtgO0ZBG3a5h7BPiRs2sXqb2UCmBBgwyvYLDebnpE7AotF6_xBIlR-Cykdap3GHVMXhrIpvU195HF30ZoBU4dMd-AeG6HgRt4Cqy1moGoDgMQfbmQ48Hlunv9_Vi2e2CLvYECcBw" + + - kid: "d01c1abe249269f72ef7ca2613a86c9f05e59567" + kty: "RSA" + alg: "RS256" + use: "sig" + e: "AQAB" + n: "0hukqytPwrj1RbMYhYoepCi3CN5k7DwYkTe_Cmb7cP9_qv4ok78KdvFXt5AnQxCRwBD7-qTNkkfMWO2RxUMBdQD0ED6tsSb1n5dp0XY8dSWiBDCX8f6Hr-KolOpvMLZKRy01HdAWcM6RoL9ikbjYHUEW1C8IJnw3MzVHkpKFDL354aptdNLaAdTCBvKzU9WpXo10g-5ctzSlWWjQuecLMQ4G1mNdsR1LHhUENEnOvgT8cDkX0fJzLbEbyBYkdMgKggyVPEB1bg6evG4fTKawgnf0IDSPxIU-wdS9wdSP9ZCJJPLi5CEp-6t6rE_sb2dGcnzjCGlembC57VwpkUvyMw" "#, )?; let config = OAuth2Config::load_from_file("config.yaml")?; assert_eq!(config.issuer, "https://example.com".parse().unwrap()); - assert_eq!(config.clients.len(), 2); + assert_eq!(config.clients.len(), 5); - assert_eq!(config.clients[0].client_id, "hello"); + assert_eq!(config.clients[0].client_id, "public"); assert_eq!( config.clients[0].redirect_uris, vec!["https://exemple.fr/callback".parse().unwrap()] ); - assert_eq!(config.clients[1].client_id, "world"); + assert_eq!(config.clients[1].client_id, "secret-basic"); assert_eq!(config.clients[1].redirect_uris, Vec::new()); Ok(()) diff --git a/crates/handlers/src/oauth2/keys.rs b/crates/handlers/src/oauth2/keys.rs index d33005ce..d426b9ab 100644 --- a/crates/handlers/src/oauth2/keys.rs +++ b/crates/handlers/src/oauth2/keys.rs @@ -15,6 +15,7 @@ use std::sync::Arc; use mas_jose::{ExportJwks, StaticKeystore}; +use mas_warp_utils::errors::WrapError; use warp::{filters::BoxedFilter, Filter, Rejection, Reply}; pub(super) fn filter(key_store: &Arc) -> BoxedFilter<(Box,)> { @@ -25,7 +26,7 @@ pub(super) fn filter(key_store: &Arc) -> BoxedFilter<(Box) -> Result, Rejection> { - let jwks = key_store.export_jwks().await; + let jwks = key_store.export_jwks().await.wrap_error()?; Ok(Box::new(warp::reply::json(&jwks))) } diff --git a/crates/jose/Cargo.toml b/crates/jose/Cargo.toml index 333a2e6a..299464af 100644 --- a/crates/jose/Cargo.toml +++ b/crates/jose/Cargo.toml @@ -9,6 +9,7 @@ license = "Apache-2.0" anyhow = "1.0.52" async-trait = "0.1.52" base64ct = { version = "1.0.1", features = ["std"] } +chrono = "0.4.19" crypto-mac = { version = "0.11.1", features = ["std"] } digest = "0.10.1" ecdsa = { version = "0.13.3", features = ["sign", "verify", "pem", "pkcs8"] } @@ -19,6 +20,7 @@ pkcs1 = { version = "0.3.1", features = ["pem", "pkcs8"] } pkcs8 = { version = "0.8.0", features = ["pem"] } rand = "0.8.4" rsa = { git = "https://github.com/sandhose/RSA.git", branch = "bump-pkcs" } +schemars = "0.8.8" sec1 = "0.2.1" serde = { version = "1.0.133", features = ["derive"] } serde_json = "1.0.74" @@ -26,6 +28,6 @@ serde_with = { version = "1.11.0", features = ["base64"] } sha2 = "0.10.0" signature = "1.4.0" thiserror = "1.0.30" -tokio = { version = "1.15.0", features = ["macros", "rt"] } +tokio = { version = "1.15.0", features = ["macros", "rt", "sync"] } url = { version = "2.2.2", features = ["serde"] } zeroize = "1.4.3" diff --git a/crates/jose/src/claims.rs b/crates/jose/src/claims.rs new file mode 100644 index 00000000..ffe7a9f2 --- /dev/null +++ b/crates/jose/src/claims.rs @@ -0,0 +1,25 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +trait ClaimSet { + fn validate(&self) -> anyhow::Result<()>; +} + +struct UnvalidatedClaim(T); + +impl ClaimSet for UnvalidatedClaim { + fn validate(&self) -> anyhow::Result<()> { + Ok(()) + } +} diff --git a/crates/jose/src/iana.rs b/crates/jose/src/iana.rs index 1783cc7e..c31d3ac8 100644 --- a/crates/jose/src/iana.rs +++ b/crates/jose/src/iana.rs @@ -16,9 +16,10 @@ //! //! +use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)] pub enum JsonWebSignatureAlgorithm { /// HMAC using SHA-256 #[serde(rename = "HS256")] @@ -157,7 +158,7 @@ pub enum JsonWebSignatureAlgorithm { Es256K, } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] pub enum JsonWebEncryptionAlgorithm { /// AES_128_CBC_HMAC_SHA_256 authenticated encryption algorithm #[serde(rename = "A128CBC-HS256")] @@ -184,14 +185,14 @@ pub enum JsonWebEncryptionAlgorithm { A256Gcm, } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] pub enum JsonWebEncryptionCompressionAlgorithm { /// DEFLATE #[serde(rename = "DEF")] Def, } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] pub enum JsonWebKeyType { /// Elliptic Curve #[serde(rename = "EC")] @@ -210,7 +211,7 @@ pub enum JsonWebKeyType { Okp, } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)] pub enum JsonWebKeyEcEllipticCurve { /// P-256 Curve #[serde(rename = "P-256")] @@ -229,7 +230,7 @@ pub enum JsonWebKeyEcEllipticCurve { Secp256K1, } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)] pub enum JsonWebKeyOkpEllipticCurve { /// Ed25519 signature algorithm key pairs #[serde(rename = "Ed25519")] @@ -248,7 +249,7 @@ pub enum JsonWebKeyOkpEllipticCurve { X448, } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)] pub enum JsonWebKeyUse { /// Digital Signature or MAC #[serde(rename = "sig")] @@ -259,7 +260,7 @@ pub enum JsonWebKeyUse { Enc, } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)] pub enum JsonWebKeyOperation { /// Compute digital signature or MAC #[serde(rename = "sign")] diff --git a/crates/jose/src/jwk.rs b/crates/jose/src/jwk.rs index afe6719c..754bd2ac 100644 --- a/crates/jose/src/jwk.rs +++ b/crates/jose/src/jwk.rs @@ -17,6 +17,7 @@ use anyhow::bail; use p256::NistP256; use rsa::{BigUint, PublicKeyParts}; +use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use serde_with::{ base64::{Base64, Standard, UrlSafe}, @@ -25,14 +26,17 @@ use serde_with::{ }; use url::Url; -use crate::iana::{ - JsonWebKeyEcEllipticCurve, JsonWebKeyOkpEllipticCurve, JsonWebKeyOperation, JsonWebKeyUse, - JsonWebSignatureAlgorithm, +use crate::{ + iana::{ + JsonWebKeyEcEllipticCurve, JsonWebKeyOkpEllipticCurve, JsonWebKeyOperation, JsonWebKeyUse, + JsonWebSignatureAlgorithm, + }, + JsonWebKeyType, }; #[serde_as] #[skip_serializing_none] -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] pub struct JsonWebKey { #[serde(flatten)] parameters: JsonWebKeyParameters, @@ -49,17 +53,21 @@ pub struct JsonWebKey { #[serde(default)] kid: Option, + #[schemars(with = "Option")] #[serde(default)] x5u: Option, + #[schemars(with = "Vec")] #[serde(default)] #[serde_as(as = "Option>>")] x5c: Option>>, + #[schemars(with = "Option")] #[serde(default)] #[serde_as(as = "Option>")] x5t: Option>, + #[schemars(with = "Option")] #[serde(default, rename = "x5t#S256")] #[serde_as(as = "Option>")] x5t_s256: Option>, @@ -104,13 +112,40 @@ impl JsonWebKey { self.kid = Some(kid.into()); self } + + #[must_use] + pub fn kty(&self) -> JsonWebKeyType { + match self.parameters { + JsonWebKeyParameters::Ec { .. } => JsonWebKeyType::Ec, + JsonWebKeyParameters::Rsa { .. } => JsonWebKeyType::Rsa, + JsonWebKeyParameters::Okp { .. } => JsonWebKeyType::Okp, + } + } + + #[must_use] + pub fn kid(&self) -> Option<&str> { + self.kid.as_deref() + } + + #[must_use] + pub fn params(&self) -> &JsonWebKeyParameters { + &self.parameters + } } -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] pub struct JsonWebKeySet { keys: Vec, } +impl std::ops::Deref for JsonWebKeySet { + type Target = Vec; + + fn deref(&self) -> &Self::Target { + &self.keys + } +} + impl JsonWebKeySet { #[must_use] pub fn new(keys: Vec) -> Self { @@ -119,27 +154,36 @@ impl JsonWebKeySet { } #[serde_as] -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] #[serde(tag = "kty")] pub enum JsonWebKeyParameters { #[serde(rename = "RSA")] Rsa { + #[schemars(with = "String")] #[serde_as(as = "Base64")] n: Vec, + + #[schemars(with = "String")] #[serde_as(as = "Base64")] e: Vec, }, #[serde(rename = "EC")] Ec { crv: JsonWebKeyEcEllipticCurve, + + #[schemars(with = "String")] #[serde_as(as = "Base64")] x: Vec, + + #[schemars(with = "String")] #[serde_as(as = "Base64")] y: Vec, }, #[serde(rename = "OKP")] Okp { crv: JsonWebKeyOkpEllipticCurve, + + #[schemars(with = "String")] #[serde_as(as = "Base64")] x: Vec, }, diff --git a/crates/jose/src/keystore/jwks.rs b/crates/jose/src/keystore/jwks.rs new file mode 100644 index 00000000..b2ad3724 --- /dev/null +++ b/crates/jose/src/keystore/jwks.rs @@ -0,0 +1,278 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::HashMap; + +use anyhow::bail; +use async_trait::async_trait; +use chrono::{DateTime, Duration, Utc}; +use digest::Digest; +use rsa::{PublicKey, RsaPublicKey}; +use sha2::{Sha256, Sha384, Sha512}; +use signature::{Signature, Verifier}; +use tokio::sync::RwLock; + +use crate::{ + ExportJwks, JsonWebKeySet, JsonWebKeyType, JsonWebSignatureAlgorithm, JwtHeader, + VerifyingKeystore, +}; + +pub struct StaticJwksStore { + key_set: JsonWebKeySet, + index: HashMap<(JsonWebKeyType, String), usize>, +} + +impl StaticJwksStore { + #[must_use] + pub fn new(key_set: JsonWebKeySet) -> Self { + let index = key_set + .iter() + .enumerate() + .filter_map(|(index, key)| { + let kid = key.kid()?.to_string(); + let kty = key.kty(); + + Some(((kty, kid), index)) + }) + .collect(); + + Self { key_set, index } + } + + fn find_rsa_key(&self, kid: String) -> anyhow::Result { + let index = *self + .index + .get(&(JsonWebKeyType::Rsa, kid)) + .ok_or_else(|| anyhow::anyhow!("key not found"))?; + + let key = self + .key_set + .get(index) + .ok_or_else(|| anyhow::anyhow!("invalid index"))?; + + let key = key.params().clone().try_into()?; + + Ok(key) + } + + fn find_ecdsa_key(&self, kid: String) -> anyhow::Result> { + let index = *self + .index + .get(&(JsonWebKeyType::Ec, kid)) + .ok_or_else(|| anyhow::anyhow!("key not found"))?; + + let key = self + .key_set + .get(index) + .ok_or_else(|| anyhow::anyhow!("invalid index"))?; + + let key = key.params().clone().try_into()?; + + Ok(key) + } +} + +#[async_trait] +impl VerifyingKeystore for &StaticJwksStore { + async fn verify( + self, + header: &JwtHeader, + payload: &[u8], + signature: &[u8], + ) -> anyhow::Result<()> { + let kid = header + .kid() + .ok_or_else(|| anyhow::anyhow!("missing kid"))? + .to_string(); + match header.alg() { + JsonWebSignatureAlgorithm::Rs256 => { + let key = self.find_rsa_key(kid)?; + + let digest = { + let mut digest = Sha256::new(); + digest.update(&payload); + digest.finalize() + }; + + key.verify( + rsa::PaddingScheme::new_pkcs1v15_sign(Some(rsa::Hash::SHA2_256)), + &digest, + signature, + )?; + } + + JsonWebSignatureAlgorithm::Rs384 => { + let key = self.find_rsa_key(kid)?; + + let digest = { + let mut digest = Sha384::new(); + digest.update(&payload); + digest.finalize() + }; + + key.verify( + rsa::PaddingScheme::new_pkcs1v15_sign(Some(rsa::Hash::SHA2_384)), + &digest, + signature, + )?; + } + + JsonWebSignatureAlgorithm::Rs512 => { + let key = self.find_rsa_key(kid)?; + + let digest = { + let mut digest = Sha512::new(); + digest.update(&payload); + digest.finalize() + }; + + key.verify( + rsa::PaddingScheme::new_pkcs1v15_sign(Some(rsa::Hash::SHA2_512)), + &digest, + signature, + )?; + } + + JsonWebSignatureAlgorithm::Es256 => { + let key = self.find_ecdsa_key(kid)?; + + let signature = ecdsa::Signature::from_bytes(signature)?; + + key.verify(payload, &signature)?; + } + + _ => bail!("unsupported algorithm"), + }; + + Ok(()) + } +} + +enum RemoteKeySet { + Pending, + Errored { + at: DateTime, + error: anyhow::Error, + }, + Fulfilled { + at: DateTime, + store: StaticJwksStore, + }, +} + +impl Default for RemoteKeySet { + fn default() -> Self { + Self::Pending + } +} + +impl RemoteKeySet { + fn fullfill(&mut self, key_set: JsonWebKeySet) { + *self = Self::Fulfilled { + at: Utc::now(), + store: StaticJwksStore::new(key_set), + } + } + + fn error(&mut self, error: anyhow::Error) { + *self = Self::Errored { + at: Utc::now(), + error, + } + } + + fn should_refresh(&self) -> bool { + let now = Utc::now(); + match self { + Self::Pending => true, + Self::Errored { at, .. } if *at - now > Duration::minutes(5) => true, + Self::Fulfilled { at, .. } if *at - now > Duration::hours(1) => true, + _ => false, + } + } + + fn should_force_refresh(&self) -> bool { + let now = Utc::now(); + match self { + Self::Pending => true, + Self::Errored { at, .. } | Self::Fulfilled { at, .. } + if *at - now > Duration::minutes(5) => + { + true + } + _ => false, + } + } +} + +pub struct JwksStore +where + T: ExportJwks, +{ + exporter: T, + cache: RwLock, +} + +impl JwksStore { + pub fn new(exporter: T) -> Self { + Self { + exporter, + cache: RwLock::default(), + } + } + + async fn should_refresh(&self) -> bool { + let cache = self.cache.read().await; + cache.should_refresh() + } + + async fn refresh(&self) { + let mut cache = self.cache.write().await; + + if cache.should_force_refresh() { + let jwks = self.exporter.export_jwks().await; + + match jwks { + Ok(jwks) => cache.fullfill(jwks), + Err(err) => cache.error(err), + } + } + } +} + +#[async_trait] +impl VerifyingKeystore for &JwksStore { + async fn verify( + self, + header: &JwtHeader, + payload: &[u8], + signature: &[u8], + ) -> anyhow::Result<()> { + if self.should_refresh().await { + self.refresh().await; + } + + let cache = self.cache.read().await; + // TODO: we could bubble up the underlying error here + let store = match &*cache { + RemoteKeySet::Pending => bail!("inconsistent cache state"), + RemoteKeySet::Errored { error, .. } => bail!("cache in error state {}", error), + RemoteKeySet::Fulfilled { store, .. } => store, + }; + + store.verify(header, payload, signature).await?; + + Ok(()) + } +} diff --git a/crates/jose/src/keystore/mod.rs b/crates/jose/src/keystore/mod.rs index 4203db4b..350e40a8 100644 --- a/crates/jose/src/keystore/mod.rs +++ b/crates/jose/src/keystore/mod.rs @@ -12,11 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +mod jwks; mod shared_secret; mod static_keystore; mod traits; pub use self::{ + jwks::{JwksStore, StaticJwksStore}, shared_secret::SharedSecret, static_keystore::StaticKeystore, traits::{ExportJwks, SigningKeystore, VerifyingKeystore}, diff --git a/crates/jose/src/keystore/static_keystore.rs b/crates/jose/src/keystore/static_keystore.rs index 1fb2bdb6..0d50a3d3 100644 --- a/crates/jose/src/keystore/static_keystore.rs +++ b/crates/jose/src/keystore/static_keystore.rs @@ -27,10 +27,7 @@ use sha2::{Sha256, Sha384, Sha512}; use signature::{Signature, Signer, Verifier}; use super::{ExportJwks, SigningKeystore, VerifyingKeystore}; -use crate::{ - iana::{JsonWebKeyOperation, JsonWebSignatureAlgorithm}, - JsonWebKey, JsonWebKeySet, JwtHeader, -}; +use crate::{iana::JsonWebSignatureAlgorithm, JsonWebKey, JsonWebKeySet, JwtHeader}; #[derive(Default)] pub struct StaticKeystore { @@ -276,23 +273,13 @@ impl VerifyingKeystore for &StaticKeystore { } #[async_trait] -impl ExportJwks for &StaticKeystore { - async fn export_jwks(self) -> JsonWebKeySet { - let rsa = self.rsa_keys.iter().flat_map(|(kid, key)| { +impl ExportJwks for StaticKeystore { + async fn export_jwks(&self) -> anyhow::Result { + let rsa = self.rsa_keys.iter().map(|(kid, key)| { let pubkey = RsaPublicKey::from(key); - let basekey = JsonWebKey::new(pubkey.into()) + JsonWebKey::new(pubkey.into()) .with_kid(kid) .with_use(crate::JsonWebKeyUse::Sig) - .with_key_ops(vec![JsonWebKeyOperation::Sign]); - - let algs = [ - JsonWebSignatureAlgorithm::Rs256, - JsonWebSignatureAlgorithm::Rs384, - JsonWebSignatureAlgorithm::Rs512, - ]; - - algs.into_iter() - .map(move |alg| basekey.clone().with_alg(alg)) }); let es256 = self.es256_keys.iter().map(|(kid, key)| { @@ -300,12 +287,11 @@ impl ExportJwks for &StaticKeystore { JsonWebKey::new(pubkey.into()) .with_kid(kid) .with_use(crate::JsonWebKeyUse::Sig) - .with_key_ops(vec![JsonWebKeyOperation::Sign]) .with_alg(JsonWebSignatureAlgorithm::Es256) }); let keys = rsa.chain(es256).collect(); - JsonWebKeySet::new(keys) + Ok(JsonWebKeySet::new(keys)) } } diff --git a/crates/jose/src/keystore/traits.rs b/crates/jose/src/keystore/traits.rs index 2343c3df..9ebe7023 100644 --- a/crates/jose/src/keystore/traits.rs +++ b/crates/jose/src/keystore/traits.rs @@ -14,9 +14,7 @@ use async_trait::async_trait; -use crate::{ - iana::JsonWebSignatureAlgorithm, JsonWebKeySet, JwtHeader, -}; +use crate::{iana::JsonWebSignatureAlgorithm, JsonWebKeySet, JwtHeader}; #[async_trait] pub trait SigningKeystore { @@ -32,6 +30,5 @@ pub trait VerifyingKeystore { #[async_trait] pub trait ExportJwks { - async fn export_jwks(self) -> JsonWebKeySet; + async fn export_jwks(&self) -> anyhow::Result; } - diff --git a/crates/jose/src/lib.rs b/crates/jose/src/lib.rs index f55ff85e..efb80b9b 100644 --- a/crates/jose/src/lib.rs +++ b/crates/jose/src/lib.rs @@ -19,6 +19,7 @@ #![allow(clippy::missing_errors_doc)] #![allow(clippy::module_name_repetitions)] +mod claims; pub(crate) mod iana; pub(crate) mod jwk; pub(crate) mod jwt; @@ -31,5 +32,8 @@ pub use self::{ }, jwk::{JsonWebKey, JsonWebKeySet}, jwt::{DecodedJsonWebToken, JsonWebTokenParts, JwtHeader}, - keystore::{ExportJwks, SharedSecret, SigningKeystore, StaticKeystore, VerifyingKeystore}, + keystore::{ + ExportJwks, JwksStore, SharedSecret, SigningKeystore, StaticJwksStore, StaticKeystore, + VerifyingKeystore, + }, }; diff --git a/crates/warp-utils/src/filters/client.rs b/crates/warp-utils/src/filters/client.rs index fda58a65..3982ac4b 100644 --- a/crates/warp-utils/src/filters/client.rs +++ b/crates/warp-utils/src/filters/client.rs @@ -15,7 +15,7 @@ //! Handle client authentication use headers::{authorization::Basic, Authorization}; -use mas_config::{OAuth2ClientConfig, OAuth2Config}; +use mas_config::{OAuth2ClientAuthMethodConfig, OAuth2ClientConfig, OAuth2Config}; use mas_jose::{DecodedJsonWebToken, JsonWebTokenParts, SharedSecret}; use oauth2_types::requests::ClientAuthenticationMethod; use serde::{de::DeserializeOwned, Deserialize, Serialize}; @@ -72,17 +72,14 @@ pub fn client_authentication( #[derive(Error, Debug)] enum ClientAuthenticationError { - #[error("no client secret found for client {client_id:?}")] - NoClientSecret { client_id: String }, - #[error("wrong client secret for client {client_id:?}")] ClientSecretMismatch { client_id: String }, #[error("could not find client {client_id:?}")] ClientNotFound { client_id: String }, - #[error("client secret required for client {client_id:?}")] - ClientSecretRequired { client_id: String }, + #[error("wrong client authentication method for client {client_id:?}")] + WrongAuthenticationMethod { client_id: String }, #[error("wrong audience in client assertion: expected {expected:?}, got {got:?}")] AudienceMismatch { expected: String, got: String }, @@ -113,12 +110,11 @@ async fn authenticate_client( credentials: ClientCredentials, body: T, ) -> Result<(ClientAuthenticationMethod, OAuth2ClientConfig, T), Rejection> { - let auth_type = credentials.authentication_type(); - let client = match credentials { + let (auth_method, client) = match credentials { ClientCredentials::Pair { client_id, client_secret, - .. + via, } => { let client = clients .iter() @@ -127,17 +123,49 @@ async fn authenticate_client( client_id: client_id.to_string(), })?; - match (client_secret, client.client_secret.as_ref()) { - (None, None) => Ok(client), - (Some(ref given), Some(expected)) if given == expected => Ok(client), - (Some(_), Some(_)) => { - Err(ClientAuthenticationError::ClientSecretMismatch { client_id }) + let auth_method = match (&client.client_auth_method, client_secret, via) { + (OAuth2ClientAuthMethodConfig::None, None, _) => ClientAuthenticationMethod::None, + + ( + OAuth2ClientAuthMethodConfig::ClientSecretBasic { + client_secret: ref expected_client_secret, + }, + Some(ref given_client_secret), + CredentialsVia::AuthorizationHeader, + ) => { + if expected_client_secret != given_client_secret { + return Err( + ClientAuthenticationError::ClientSecretMismatch { client_id }.into(), + ); + } + + ClientAuthenticationMethod::ClientSecretBasic } - (Some(_), None) => Err(ClientAuthenticationError::NoClientSecret { client_id }), - (None, Some(_)) => { - Err(ClientAuthenticationError::ClientSecretRequired { client_id }) + + ( + OAuth2ClientAuthMethodConfig::ClientSecretPost { + client_secret: ref expected_client_secret, + }, + Some(ref given_client_secret), + CredentialsVia::FormBody, + ) => { + if expected_client_secret != given_client_secret { + return Err( + ClientAuthenticationError::ClientSecretMismatch { client_id }.into(), + ); + } + + ClientAuthenticationMethod::ClientSecretPost } - } + + _ => { + return Err( + ClientAuthenticationError::WrongAuthenticationMethod { client_id }.into(), + ) + } + }; + + (auth_method, client) } ClientCredentials::Assertion { client_id, @@ -150,43 +178,61 @@ async fn authenticate_client( // client_id might have been passed as parameter. If not, it should be inferred // from the token, as per rfc7521 sec. 4.2 - let client_id = client_id.unwrap_or_else(|| decoded.claims().subject.clone()); + let client_id = client_id + .as_ref() + .unwrap_or_else(|| &decoded.claims().subject); let client = clients .iter() - .find(|client| client.client_id == client_id) + .find(|client| &client.client_id == client_id) .ok_or_else(|| ClientAuthenticationError::ClientNotFound { client_id: client_id.to_string(), })?; - if let Some(client_secret) = &client.client_secret { - let store = SharedSecret::new(client_secret); - token.verify(&decoded, &store).await.wrap_error()?; - let claims = decoded.claims(); - // TODO: validate the times again - - // rfc7523 sec. 3.3: the audience is the URL being called - if claims.audience != audience { - Err(ClientAuthenticationError::AudienceMismatch { - expected: audience, - got: claims.audience.clone(), - }) - // rfc7523 sec. 3.1 & 3.2: both the issuer and the subject must - // match the client_id - } else if claims.issuer != claims.subject || claims.issuer != client_id { - Err(ClientAuthenticationError::InvalidAssertion) - } else { - Ok(client) + let auth_method = match &client.client_auth_method { + OAuth2ClientAuthMethodConfig::PrivateKeyJwt(jwks) => { + let store = jwks.key_store(); + token.verify(&decoded, &store).await.wrap_error()?; + ClientAuthenticationMethod::PrivateKeyJwt } - } else { - Err(ClientAuthenticationError::ClientSecretRequired { - client_id: client_id.to_string(), - }) - } - } - }?; - Ok((auth_type, client.clone(), body)) + OAuth2ClientAuthMethodConfig::ClientSecretJwt { client_secret } => { + let store = SharedSecret::new(client_secret); + token.verify(&decoded, &store).await.wrap_error()?; + ClientAuthenticationMethod::ClientSecretJwt + } + + _ => { + return Err(ClientAuthenticationError::WrongAuthenticationMethod { + client_id: client_id.clone(), + } + .into()) + } + }; + + let claims = decoded.claims(); + // TODO: validate the times again + + // rfc7523 sec. 3.3: the audience is the URL being called + if claims.audience != audience { + return Err(ClientAuthenticationError::AudienceMismatch { + expected: audience, + got: claims.audience.clone(), + } + .into()); + } + + // rfc7523 sec. 3.1 & 3.2: both the issuer and the subject must + // match the client_id + if claims.issuer != claims.subject || &claims.issuer != client_id { + return Err(ClientAuthenticationError::InvalidAssertion.into()); + } + + (auth_method, client) + } + }; + + Ok((auth_method, client.clone(), body)) } #[derive(Deserialize)] @@ -225,28 +271,6 @@ enum ClientCredentials { }, } -impl ClientCredentials { - fn authentication_type(&self) -> ClientAuthenticationMethod { - match self { - ClientCredentials::Pair { - via: CredentialsVia::FormBody, - client_secret: None, - .. - } => ClientAuthenticationMethod::None, - ClientCredentials::Pair { - via: CredentialsVia::FormBody, - client_secret: Some(_), - .. - } => ClientAuthenticationMethod::ClientSecretPost, - ClientCredentials::Pair { - via: CredentialsVia::AuthorizationHeader, - .. - } => ClientAuthenticationMethod::ClientSecretBasic, - ClientCredentials::Assertion { .. } => ClientAuthenticationMethod::ClientSecretJwt, - } - } -} - #[derive(Deserialize)] struct ClientAuthForm { #[serde(flatten)] @@ -259,7 +283,7 @@ struct ClientAuthForm { #[cfg(test)] mod tests { use headers::authorization::Credentials; - use mas_config::ConfigurationSection; + use mas_config::{ConfigurationSection, OAuth2ClientAuthMethodConfig}; use mas_jose::{JsonWebSignatureAlgorithm, SigningKeystore}; use serde_json::json; @@ -272,17 +296,21 @@ mod tests { let mut config = OAuth2Config::test(); config.clients.push(OAuth2ClientConfig { client_id: "public".to_string(), - client_secret: None, + client_auth_method: OAuth2ClientAuthMethodConfig::None, redirect_uris: Vec::new(), }); config.clients.push(OAuth2ClientConfig { - client_id: "confidential".to_string(), - client_secret: Some(CLIENT_SECRET.to_string()), + client_id: "secret-basic".to_string(), + client_auth_method: OAuth2ClientAuthMethodConfig::ClientSecretBasic { + client_secret: CLIENT_SECRET.to_string(), + }, redirect_uris: Vec::new(), }); config.clients.push(OAuth2ClientConfig { - client_id: "confidential-2".to_string(), - client_secret: Some(CLIENT_SECRET.to_string()), + client_id: "secret-post".to_string(), + client_auth_method: OAuth2ClientAuthMethodConfig::ClientSecretPost { + client_secret: CLIENT_SECRET.to_string(), + }, redirect_uris: Vec::new(), }); config @@ -395,7 +423,7 @@ mod tests { ) .body( serde_urlencoded::to_string(json!({ - "client_id": "confidential", + "client_id": "secret-post", "client_secret": CLIENT_SECRET, "foo": "baz", "bar": "foobar", @@ -407,7 +435,7 @@ mod tests { .unwrap(); assert_eq!(auth, ClientAuthenticationMethod::ClientSecretPost); - assert_eq!(client.client_id, "confidential"); + assert_eq!(client.client_id, "secret-post"); assert_eq!(body.foo, "baz"); assert_eq!(body.bar, "foobar"); } @@ -419,7 +447,7 @@ mod tests { "https://example.com/token".to_string(), ); - let auth = Authorization::basic("confidential", CLIENT_SECRET); + let auth = Authorization::basic("secret-basic", CLIENT_SECRET); let (auth, client, body) = warp::test::request() .method("POST") .header( @@ -439,7 +467,7 @@ mod tests { .unwrap(); assert_eq!(auth, ClientAuthenticationMethod::ClientSecretBasic); - assert_eq!(client.client_id, "confidential"); + assert_eq!(client.client_id, "secret-basic"); assert_eq!(body.foo, "baz"); assert_eq!(body.bar, "foobar"); }