1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Support private_key_jwt client auth

Which includes having a verifying keystore out of JWKS (and soon out of
a JWKS URI)
This commit is contained in:
Quentin Gliech
2022-01-05 21:07:18 +01:00
parent f7706f2351
commit a965e488e2
14 changed files with 557 additions and 129 deletions

View File

@ -38,7 +38,7 @@ pub use self::{
csrf::CsrfConfig,
database::DatabaseConfig,
http::HttpConfig,
oauth2::{OAuth2ClientConfig, OAuth2Config},
oauth2::{OAuth2ClientAuthMethodConfig, OAuth2ClientConfig, OAuth2Config},
telemetry::{
MetricsConfig, MetricsExporterConfig, Propagator, TelemetryConfig, TracingConfig,
TracingExporterConfig,

View File

@ -14,7 +14,7 @@
use anyhow::Context;
use async_trait::async_trait;
use mas_jose::StaticKeystore;
use mas_jose::{JsonWebKeySet, StaticJwksStore, StaticKeystore};
use pkcs8::{DecodePrivateKey, EncodePrivateKey};
use rsa::{
pkcs1::{DecodeRsaPrivateKey, EncodeRsaPrivateKey},
@ -43,13 +43,41 @@ pub struct KeyConfig {
key: String,
}
#[derive(JsonSchema, Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum JwksOrJwksUri {
Jwks(JsonWebKeySet),
JwksUri(Url),
}
impl JwksOrJwksUri {
pub fn key_store(&self) -> StaticJwksStore {
let jwks = match self {
Self::Jwks(jwks) => jwks.clone(),
Self::JwksUri(_) => unimplemented!("jwks_uri are not implemented yet"),
};
StaticJwksStore::new(jwks)
}
}
#[derive(JsonSchema, Serialize, Deserialize, Clone, Debug)]
#[serde(tag = "client_auth_method", rename_all = "snake_case")]
pub enum OAuth2ClientAuthMethodConfig {
None,
ClientSecretBasic { client_secret: String },
ClientSecretPost { client_secret: String },
ClientSecretJwt { client_secret: String },
PrivateKeyJwt(JwksOrJwksUri),
}
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct OAuth2ClientConfig {
pub client_id: String,
#[serde(default)]
pub client_secret: Option<String>,
#[serde(flatten)]
pub client_auth_method: OAuth2ClientAuthMethodConfig,
#[serde(default)]
pub redirect_uris: Vec<Url>,
@ -246,25 +274,55 @@ mod tests {
-----END PRIVATE KEY-----
issuer: https://example.com
clients:
- client_id: hello
- client_id: public
client_auth_method: none
redirect_uris:
- https://exemple.fr/callback
- client_id: world
- client_id: secret-basic
client_auth_method: client_secret_basic
client_secret: hello
- client_id: secret-post
client_auth_method: client_secret_post
client_secret: hello
- client_id: secret-jwk
client_auth_method: client_secret_jwt
client_secret: hello
- client_id: jwks
client_auth_method: private_key_jwt
jwks:
keys:
- kid: "03e84aed4ef4431014e8617567864c4efaaaede9"
kty: "RSA"
alg: "RS256"
use: "sig"
e: "AQAB"
n: "ma2uRyBeSEOatGuDpCiV9oIxlDWix_KypDYuhQfEzqi_BiF4fV266OWfyjcABbam59aJMNvOnKW3u_eZM-PhMCBij5MZ-vcBJ4GfxDJeKSn-GP_dJ09rpDcILh8HaWAnPmMoi4DC0nrfE241wPISvZaaZnGHkOrfN_EnA5DligLgVUbrA5rJhQ1aSEQO_gf1raEOW3DZ_ACU3qhtgO0ZBG3a5h7BPiRs2sXqb2UCmBBgwyvYLDebnpE7AotF6_xBIlR-Cykdap3GHVMXhrIpvU195HF30ZoBU4dMd-AeG6HgRt4Cqy1moGoDgMQfbmQ48Hlunv9_Vi2e2CLvYECcBw"
- kid: "d01c1abe249269f72ef7ca2613a86c9f05e59567"
kty: "RSA"
alg: "RS256"
use: "sig"
e: "AQAB"
n: "0hukqytPwrj1RbMYhYoepCi3CN5k7DwYkTe_Cmb7cP9_qv4ok78KdvFXt5AnQxCRwBD7-qTNkkfMWO2RxUMBdQD0ED6tsSb1n5dp0XY8dSWiBDCX8f6Hr-KolOpvMLZKRy01HdAWcM6RoL9ikbjYHUEW1C8IJnw3MzVHkpKFDL354aptdNLaAdTCBvKzU9WpXo10g-5ctzSlWWjQuecLMQ4G1mNdsR1LHhUENEnOvgT8cDkX0fJzLbEbyBYkdMgKggyVPEB1bg6evG4fTKawgnf0IDSPxIU-wdS9wdSP9ZCJJPLi5CEp-6t6rE_sb2dGcnzjCGlembC57VwpkUvyMw"
"#,
)?;
let config = OAuth2Config::load_from_file("config.yaml")?;
assert_eq!(config.issuer, "https://example.com".parse().unwrap());
assert_eq!(config.clients.len(), 2);
assert_eq!(config.clients.len(), 5);
assert_eq!(config.clients[0].client_id, "hello");
assert_eq!(config.clients[0].client_id, "public");
assert_eq!(
config.clients[0].redirect_uris,
vec!["https://exemple.fr/callback".parse().unwrap()]
);
assert_eq!(config.clients[1].client_id, "world");
assert_eq!(config.clients[1].client_id, "secret-basic");
assert_eq!(config.clients[1].redirect_uris, Vec::new());
Ok(())

View File

@ -15,6 +15,7 @@
use std::sync::Arc;
use mas_jose::{ExportJwks, StaticKeystore};
use mas_warp_utils::errors::WrapError;
use warp::{filters::BoxedFilter, Filter, Rejection, Reply};
pub(super) fn filter(key_store: &Arc<StaticKeystore>) -> BoxedFilter<(Box<dyn Reply>,)> {
@ -25,7 +26,7 @@ pub(super) fn filter(key_store: &Arc<StaticKeystore>) -> BoxedFilter<(Box<dyn Re
}
async fn get(key_store: Arc<StaticKeystore>) -> Result<Box<dyn Reply>, Rejection> {
let jwks = key_store.export_jwks().await;
let jwks = key_store.export_jwks().await.wrap_error()?;
Ok(Box::new(warp::reply::json(&jwks)))
}

View File

@ -9,6 +9,7 @@ license = "Apache-2.0"
anyhow = "1.0.52"
async-trait = "0.1.52"
base64ct = { version = "1.0.1", features = ["std"] }
chrono = "0.4.19"
crypto-mac = { version = "0.11.1", features = ["std"] }
digest = "0.10.1"
ecdsa = { version = "0.13.3", features = ["sign", "verify", "pem", "pkcs8"] }
@ -19,6 +20,7 @@ pkcs1 = { version = "0.3.1", features = ["pem", "pkcs8"] }
pkcs8 = { version = "0.8.0", features = ["pem"] }
rand = "0.8.4"
rsa = { git = "https://github.com/sandhose/RSA.git", branch = "bump-pkcs" }
schemars = "0.8.8"
sec1 = "0.2.1"
serde = { version = "1.0.133", features = ["derive"] }
serde_json = "1.0.74"
@ -26,6 +28,6 @@ serde_with = { version = "1.11.0", features = ["base64"] }
sha2 = "0.10.0"
signature = "1.4.0"
thiserror = "1.0.30"
tokio = { version = "1.15.0", features = ["macros", "rt"] }
tokio = { version = "1.15.0", features = ["macros", "rt", "sync"] }
url = { version = "2.2.2", features = ["serde"] }
zeroize = "1.4.3"

25
crates/jose/src/claims.rs Normal file
View File

@ -0,0 +1,25 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
trait ClaimSet {
fn validate(&self) -> anyhow::Result<()>;
}
struct UnvalidatedClaim<T>(T);
impl<T> ClaimSet for UnvalidatedClaim<T> {
fn validate(&self) -> anyhow::Result<()> {
Ok(())
}
}

View File

@ -16,9 +16,10 @@
//!
//! <https://www.iana.org/assignments/jose/jose.xhtml>
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
pub enum JsonWebSignatureAlgorithm {
/// HMAC using SHA-256
#[serde(rename = "HS256")]
@ -157,7 +158,7 @@ pub enum JsonWebSignatureAlgorithm {
Es256K,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum JsonWebEncryptionAlgorithm {
/// AES_128_CBC_HMAC_SHA_256 authenticated encryption algorithm
#[serde(rename = "A128CBC-HS256")]
@ -184,14 +185,14 @@ pub enum JsonWebEncryptionAlgorithm {
A256Gcm,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum JsonWebEncryptionCompressionAlgorithm {
/// DEFLATE
#[serde(rename = "DEF")]
Def,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum JsonWebKeyType {
/// Elliptic Curve
#[serde(rename = "EC")]
@ -210,7 +211,7 @@ pub enum JsonWebKeyType {
Okp,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
pub enum JsonWebKeyEcEllipticCurve {
/// P-256 Curve
#[serde(rename = "P-256")]
@ -229,7 +230,7 @@ pub enum JsonWebKeyEcEllipticCurve {
Secp256K1,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
pub enum JsonWebKeyOkpEllipticCurve {
/// Ed25519 signature algorithm key pairs
#[serde(rename = "Ed25519")]
@ -248,7 +249,7 @@ pub enum JsonWebKeyOkpEllipticCurve {
X448,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
pub enum JsonWebKeyUse {
/// Digital Signature or MAC
#[serde(rename = "sig")]
@ -259,7 +260,7 @@ pub enum JsonWebKeyUse {
Enc,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
pub enum JsonWebKeyOperation {
/// Compute digital signature or MAC
#[serde(rename = "sign")]

View File

@ -17,6 +17,7 @@
use anyhow::bail;
use p256::NistP256;
use rsa::{BigUint, PublicKeyParts};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_with::{
base64::{Base64, Standard, UrlSafe},
@ -25,14 +26,17 @@ use serde_with::{
};
use url::Url;
use crate::iana::{
JsonWebKeyEcEllipticCurve, JsonWebKeyOkpEllipticCurve, JsonWebKeyOperation, JsonWebKeyUse,
JsonWebSignatureAlgorithm,
use crate::{
iana::{
JsonWebKeyEcEllipticCurve, JsonWebKeyOkpEllipticCurve, JsonWebKeyOperation, JsonWebKeyUse,
JsonWebSignatureAlgorithm,
},
JsonWebKeyType,
};
#[serde_as]
#[skip_serializing_none]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub struct JsonWebKey {
#[serde(flatten)]
parameters: JsonWebKeyParameters,
@ -49,17 +53,21 @@ pub struct JsonWebKey {
#[serde(default)]
kid: Option<String>,
#[schemars(with = "Option<String>")]
#[serde(default)]
x5u: Option<Url>,
#[schemars(with = "Vec<String>")]
#[serde(default)]
#[serde_as(as = "Option<Vec<Base64<Standard, Padded>>>")]
x5c: Option<Vec<Vec<u8>>>,
#[schemars(with = "Option<String>")]
#[serde(default)]
#[serde_as(as = "Option<Base64<UrlSafe, Unpadded>>")]
x5t: Option<Vec<u8>>,
#[schemars(with = "Option<String>")]
#[serde(default, rename = "x5t#S256")]
#[serde_as(as = "Option<Base64<UrlSafe, Unpadded>>")]
x5t_s256: Option<Vec<u8>>,
@ -104,13 +112,40 @@ impl JsonWebKey {
self.kid = Some(kid.into());
self
}
#[must_use]
pub fn kty(&self) -> JsonWebKeyType {
match self.parameters {
JsonWebKeyParameters::Ec { .. } => JsonWebKeyType::Ec,
JsonWebKeyParameters::Rsa { .. } => JsonWebKeyType::Rsa,
JsonWebKeyParameters::Okp { .. } => JsonWebKeyType::Okp,
}
}
#[must_use]
pub fn kid(&self) -> Option<&str> {
self.kid.as_deref()
}
#[must_use]
pub fn params(&self) -> &JsonWebKeyParameters {
&self.parameters
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub struct JsonWebKeySet {
keys: Vec<JsonWebKey>,
}
impl std::ops::Deref for JsonWebKeySet {
type Target = Vec<JsonWebKey>;
fn deref(&self) -> &Self::Target {
&self.keys
}
}
impl JsonWebKeySet {
#[must_use]
pub fn new(keys: Vec<JsonWebKey>) -> Self {
@ -119,27 +154,36 @@ impl JsonWebKeySet {
}
#[serde_as]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "kty")]
pub enum JsonWebKeyParameters {
#[serde(rename = "RSA")]
Rsa {
#[schemars(with = "String")]
#[serde_as(as = "Base64<UrlSafe, Unpadded>")]
n: Vec<u8>,
#[schemars(with = "String")]
#[serde_as(as = "Base64<UrlSafe, Unpadded>")]
e: Vec<u8>,
},
#[serde(rename = "EC")]
Ec {
crv: JsonWebKeyEcEllipticCurve,
#[schemars(with = "String")]
#[serde_as(as = "Base64<UrlSafe, Unpadded>")]
x: Vec<u8>,
#[schemars(with = "String")]
#[serde_as(as = "Base64<UrlSafe, Unpadded>")]
y: Vec<u8>,
},
#[serde(rename = "OKP")]
Okp {
crv: JsonWebKeyOkpEllipticCurve,
#[schemars(with = "String")]
#[serde_as(as = "Base64<UrlSafe, Unpadded>")]
x: Vec<u8>,
},

View File

@ -0,0 +1,278 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use anyhow::bail;
use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};
use digest::Digest;
use rsa::{PublicKey, RsaPublicKey};
use sha2::{Sha256, Sha384, Sha512};
use signature::{Signature, Verifier};
use tokio::sync::RwLock;
use crate::{
ExportJwks, JsonWebKeySet, JsonWebKeyType, JsonWebSignatureAlgorithm, JwtHeader,
VerifyingKeystore,
};
pub struct StaticJwksStore {
key_set: JsonWebKeySet,
index: HashMap<(JsonWebKeyType, String), usize>,
}
impl StaticJwksStore {
#[must_use]
pub fn new(key_set: JsonWebKeySet) -> Self {
let index = key_set
.iter()
.enumerate()
.filter_map(|(index, key)| {
let kid = key.kid()?.to_string();
let kty = key.kty();
Some(((kty, kid), index))
})
.collect();
Self { key_set, index }
}
fn find_rsa_key(&self, kid: String) -> anyhow::Result<RsaPublicKey> {
let index = *self
.index
.get(&(JsonWebKeyType::Rsa, kid))
.ok_or_else(|| anyhow::anyhow!("key not found"))?;
let key = self
.key_set
.get(index)
.ok_or_else(|| anyhow::anyhow!("invalid index"))?;
let key = key.params().clone().try_into()?;
Ok(key)
}
fn find_ecdsa_key(&self, kid: String) -> anyhow::Result<ecdsa::VerifyingKey<p256::NistP256>> {
let index = *self
.index
.get(&(JsonWebKeyType::Ec, kid))
.ok_or_else(|| anyhow::anyhow!("key not found"))?;
let key = self
.key_set
.get(index)
.ok_or_else(|| anyhow::anyhow!("invalid index"))?;
let key = key.params().clone().try_into()?;
Ok(key)
}
}
#[async_trait]
impl VerifyingKeystore for &StaticJwksStore {
async fn verify(
self,
header: &JwtHeader,
payload: &[u8],
signature: &[u8],
) -> anyhow::Result<()> {
let kid = header
.kid()
.ok_or_else(|| anyhow::anyhow!("missing kid"))?
.to_string();
match header.alg() {
JsonWebSignatureAlgorithm::Rs256 => {
let key = self.find_rsa_key(kid)?;
let digest = {
let mut digest = Sha256::new();
digest.update(&payload);
digest.finalize()
};
key.verify(
rsa::PaddingScheme::new_pkcs1v15_sign(Some(rsa::Hash::SHA2_256)),
&digest,
signature,
)?;
}
JsonWebSignatureAlgorithm::Rs384 => {
let key = self.find_rsa_key(kid)?;
let digest = {
let mut digest = Sha384::new();
digest.update(&payload);
digest.finalize()
};
key.verify(
rsa::PaddingScheme::new_pkcs1v15_sign(Some(rsa::Hash::SHA2_384)),
&digest,
signature,
)?;
}
JsonWebSignatureAlgorithm::Rs512 => {
let key = self.find_rsa_key(kid)?;
let digest = {
let mut digest = Sha512::new();
digest.update(&payload);
digest.finalize()
};
key.verify(
rsa::PaddingScheme::new_pkcs1v15_sign(Some(rsa::Hash::SHA2_512)),
&digest,
signature,
)?;
}
JsonWebSignatureAlgorithm::Es256 => {
let key = self.find_ecdsa_key(kid)?;
let signature = ecdsa::Signature::from_bytes(signature)?;
key.verify(payload, &signature)?;
}
_ => bail!("unsupported algorithm"),
};
Ok(())
}
}
enum RemoteKeySet {
Pending,
Errored {
at: DateTime<Utc>,
error: anyhow::Error,
},
Fulfilled {
at: DateTime<Utc>,
store: StaticJwksStore,
},
}
impl Default for RemoteKeySet {
fn default() -> Self {
Self::Pending
}
}
impl RemoteKeySet {
fn fullfill(&mut self, key_set: JsonWebKeySet) {
*self = Self::Fulfilled {
at: Utc::now(),
store: StaticJwksStore::new(key_set),
}
}
fn error(&mut self, error: anyhow::Error) {
*self = Self::Errored {
at: Utc::now(),
error,
}
}
fn should_refresh(&self) -> bool {
let now = Utc::now();
match self {
Self::Pending => true,
Self::Errored { at, .. } if *at - now > Duration::minutes(5) => true,
Self::Fulfilled { at, .. } if *at - now > Duration::hours(1) => true,
_ => false,
}
}
fn should_force_refresh(&self) -> bool {
let now = Utc::now();
match self {
Self::Pending => true,
Self::Errored { at, .. } | Self::Fulfilled { at, .. }
if *at - now > Duration::minutes(5) =>
{
true
}
_ => false,
}
}
}
pub struct JwksStore<T>
where
T: ExportJwks,
{
exporter: T,
cache: RwLock<RemoteKeySet>,
}
impl<T: ExportJwks> JwksStore<T> {
pub fn new(exporter: T) -> Self {
Self {
exporter,
cache: RwLock::default(),
}
}
async fn should_refresh(&self) -> bool {
let cache = self.cache.read().await;
cache.should_refresh()
}
async fn refresh(&self) {
let mut cache = self.cache.write().await;
if cache.should_force_refresh() {
let jwks = self.exporter.export_jwks().await;
match jwks {
Ok(jwks) => cache.fullfill(jwks),
Err(err) => cache.error(err),
}
}
}
}
#[async_trait]
impl<T: ExportJwks + Send + Sync> VerifyingKeystore for &JwksStore<T> {
async fn verify(
self,
header: &JwtHeader,
payload: &[u8],
signature: &[u8],
) -> anyhow::Result<()> {
if self.should_refresh().await {
self.refresh().await;
}
let cache = self.cache.read().await;
// TODO: we could bubble up the underlying error here
let store = match &*cache {
RemoteKeySet::Pending => bail!("inconsistent cache state"),
RemoteKeySet::Errored { error, .. } => bail!("cache in error state {}", error),
RemoteKeySet::Fulfilled { store, .. } => store,
};
store.verify(header, payload, signature).await?;
Ok(())
}
}

View File

@ -12,11 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
mod jwks;
mod shared_secret;
mod static_keystore;
mod traits;
pub use self::{
jwks::{JwksStore, StaticJwksStore},
shared_secret::SharedSecret,
static_keystore::StaticKeystore,
traits::{ExportJwks, SigningKeystore, VerifyingKeystore},

View File

@ -27,10 +27,7 @@ use sha2::{Sha256, Sha384, Sha512};
use signature::{Signature, Signer, Verifier};
use super::{ExportJwks, SigningKeystore, VerifyingKeystore};
use crate::{
iana::{JsonWebKeyOperation, JsonWebSignatureAlgorithm},
JsonWebKey, JsonWebKeySet, JwtHeader,
};
use crate::{iana::JsonWebSignatureAlgorithm, JsonWebKey, JsonWebKeySet, JwtHeader};
#[derive(Default)]
pub struct StaticKeystore {
@ -276,23 +273,13 @@ impl VerifyingKeystore for &StaticKeystore {
}
#[async_trait]
impl ExportJwks for &StaticKeystore {
async fn export_jwks(self) -> JsonWebKeySet {
let rsa = self.rsa_keys.iter().flat_map(|(kid, key)| {
impl ExportJwks for StaticKeystore {
async fn export_jwks(&self) -> anyhow::Result<JsonWebKeySet> {
let rsa = self.rsa_keys.iter().map(|(kid, key)| {
let pubkey = RsaPublicKey::from(key);
let basekey = JsonWebKey::new(pubkey.into())
JsonWebKey::new(pubkey.into())
.with_kid(kid)
.with_use(crate::JsonWebKeyUse::Sig)
.with_key_ops(vec![JsonWebKeyOperation::Sign]);
let algs = [
JsonWebSignatureAlgorithm::Rs256,
JsonWebSignatureAlgorithm::Rs384,
JsonWebSignatureAlgorithm::Rs512,
];
algs.into_iter()
.map(move |alg| basekey.clone().with_alg(alg))
});
let es256 = self.es256_keys.iter().map(|(kid, key)| {
@ -300,12 +287,11 @@ impl ExportJwks for &StaticKeystore {
JsonWebKey::new(pubkey.into())
.with_kid(kid)
.with_use(crate::JsonWebKeyUse::Sig)
.with_key_ops(vec![JsonWebKeyOperation::Sign])
.with_alg(JsonWebSignatureAlgorithm::Es256)
});
let keys = rsa.chain(es256).collect();
JsonWebKeySet::new(keys)
Ok(JsonWebKeySet::new(keys))
}
}

View File

@ -14,9 +14,7 @@
use async_trait::async_trait;
use crate::{
iana::JsonWebSignatureAlgorithm, JsonWebKeySet, JwtHeader,
};
use crate::{iana::JsonWebSignatureAlgorithm, JsonWebKeySet, JwtHeader};
#[async_trait]
pub trait SigningKeystore {
@ -32,6 +30,5 @@ pub trait VerifyingKeystore {
#[async_trait]
pub trait ExportJwks {
async fn export_jwks(self) -> JsonWebKeySet;
async fn export_jwks(&self) -> anyhow::Result<JsonWebKeySet>;
}

View File

@ -19,6 +19,7 @@
#![allow(clippy::missing_errors_doc)]
#![allow(clippy::module_name_repetitions)]
mod claims;
pub(crate) mod iana;
pub(crate) mod jwk;
pub(crate) mod jwt;
@ -31,5 +32,8 @@ pub use self::{
},
jwk::{JsonWebKey, JsonWebKeySet},
jwt::{DecodedJsonWebToken, JsonWebTokenParts, JwtHeader},
keystore::{ExportJwks, SharedSecret, SigningKeystore, StaticKeystore, VerifyingKeystore},
keystore::{
ExportJwks, JwksStore, SharedSecret, SigningKeystore, StaticJwksStore, StaticKeystore,
VerifyingKeystore,
},
};

View File

@ -15,7 +15,7 @@
//! Handle client authentication
use headers::{authorization::Basic, Authorization};
use mas_config::{OAuth2ClientConfig, OAuth2Config};
use mas_config::{OAuth2ClientAuthMethodConfig, OAuth2ClientConfig, OAuth2Config};
use mas_jose::{DecodedJsonWebToken, JsonWebTokenParts, SharedSecret};
use oauth2_types::requests::ClientAuthenticationMethod;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
@ -72,17 +72,14 @@ pub fn client_authentication<T: DeserializeOwned + Send + 'static>(
#[derive(Error, Debug)]
enum ClientAuthenticationError {
#[error("no client secret found for client {client_id:?}")]
NoClientSecret { client_id: String },
#[error("wrong client secret for client {client_id:?}")]
ClientSecretMismatch { client_id: String },
#[error("could not find client {client_id:?}")]
ClientNotFound { client_id: String },
#[error("client secret required for client {client_id:?}")]
ClientSecretRequired { client_id: String },
#[error("wrong client authentication method for client {client_id:?}")]
WrongAuthenticationMethod { client_id: String },
#[error("wrong audience in client assertion: expected {expected:?}, got {got:?}")]
AudienceMismatch { expected: String, got: String },
@ -113,12 +110,11 @@ async fn authenticate_client<T>(
credentials: ClientCredentials,
body: T,
) -> Result<(ClientAuthenticationMethod, OAuth2ClientConfig, T), Rejection> {
let auth_type = credentials.authentication_type();
let client = match credentials {
let (auth_method, client) = match credentials {
ClientCredentials::Pair {
client_id,
client_secret,
..
via,
} => {
let client = clients
.iter()
@ -127,17 +123,49 @@ async fn authenticate_client<T>(
client_id: client_id.to_string(),
})?;
match (client_secret, client.client_secret.as_ref()) {
(None, None) => Ok(client),
(Some(ref given), Some(expected)) if given == expected => Ok(client),
(Some(_), Some(_)) => {
Err(ClientAuthenticationError::ClientSecretMismatch { client_id })
let auth_method = match (&client.client_auth_method, client_secret, via) {
(OAuth2ClientAuthMethodConfig::None, None, _) => ClientAuthenticationMethod::None,
(
OAuth2ClientAuthMethodConfig::ClientSecretBasic {
client_secret: ref expected_client_secret,
},
Some(ref given_client_secret),
CredentialsVia::AuthorizationHeader,
) => {
if expected_client_secret != given_client_secret {
return Err(
ClientAuthenticationError::ClientSecretMismatch { client_id }.into(),
);
}
ClientAuthenticationMethod::ClientSecretBasic
}
(Some(_), None) => Err(ClientAuthenticationError::NoClientSecret { client_id }),
(None, Some(_)) => {
Err(ClientAuthenticationError::ClientSecretRequired { client_id })
(
OAuth2ClientAuthMethodConfig::ClientSecretPost {
client_secret: ref expected_client_secret,
},
Some(ref given_client_secret),
CredentialsVia::FormBody,
) => {
if expected_client_secret != given_client_secret {
return Err(
ClientAuthenticationError::ClientSecretMismatch { client_id }.into(),
);
}
ClientAuthenticationMethod::ClientSecretPost
}
}
_ => {
return Err(
ClientAuthenticationError::WrongAuthenticationMethod { client_id }.into(),
)
}
};
(auth_method, client)
}
ClientCredentials::Assertion {
client_id,
@ -150,43 +178,61 @@ async fn authenticate_client<T>(
// client_id might have been passed as parameter. If not, it should be inferred
// from the token, as per rfc7521 sec. 4.2
let client_id = client_id.unwrap_or_else(|| decoded.claims().subject.clone());
let client_id = client_id
.as_ref()
.unwrap_or_else(|| &decoded.claims().subject);
let client = clients
.iter()
.find(|client| client.client_id == client_id)
.find(|client| &client.client_id == client_id)
.ok_or_else(|| ClientAuthenticationError::ClientNotFound {
client_id: client_id.to_string(),
})?;
if let Some(client_secret) = &client.client_secret {
let store = SharedSecret::new(client_secret);
token.verify(&decoded, &store).await.wrap_error()?;
let claims = decoded.claims();
// TODO: validate the times again
// rfc7523 sec. 3.3: the audience is the URL being called
if claims.audience != audience {
Err(ClientAuthenticationError::AudienceMismatch {
expected: audience,
got: claims.audience.clone(),
})
// rfc7523 sec. 3.1 & 3.2: both the issuer and the subject must
// match the client_id
} else if claims.issuer != claims.subject || claims.issuer != client_id {
Err(ClientAuthenticationError::InvalidAssertion)
} else {
Ok(client)
let auth_method = match &client.client_auth_method {
OAuth2ClientAuthMethodConfig::PrivateKeyJwt(jwks) => {
let store = jwks.key_store();
token.verify(&decoded, &store).await.wrap_error()?;
ClientAuthenticationMethod::PrivateKeyJwt
}
} else {
Err(ClientAuthenticationError::ClientSecretRequired {
client_id: client_id.to_string(),
})
}
}
}?;
Ok((auth_type, client.clone(), body))
OAuth2ClientAuthMethodConfig::ClientSecretJwt { client_secret } => {
let store = SharedSecret::new(client_secret);
token.verify(&decoded, &store).await.wrap_error()?;
ClientAuthenticationMethod::ClientSecretJwt
}
_ => {
return Err(ClientAuthenticationError::WrongAuthenticationMethod {
client_id: client_id.clone(),
}
.into())
}
};
let claims = decoded.claims();
// TODO: validate the times again
// rfc7523 sec. 3.3: the audience is the URL being called
if claims.audience != audience {
return Err(ClientAuthenticationError::AudienceMismatch {
expected: audience,
got: claims.audience.clone(),
}
.into());
}
// rfc7523 sec. 3.1 & 3.2: both the issuer and the subject must
// match the client_id
if claims.issuer != claims.subject || &claims.issuer != client_id {
return Err(ClientAuthenticationError::InvalidAssertion.into());
}
(auth_method, client)
}
};
Ok((auth_method, client.clone(), body))
}
#[derive(Deserialize)]
@ -225,28 +271,6 @@ enum ClientCredentials {
},
}
impl ClientCredentials {
fn authentication_type(&self) -> ClientAuthenticationMethod {
match self {
ClientCredentials::Pair {
via: CredentialsVia::FormBody,
client_secret: None,
..
} => ClientAuthenticationMethod::None,
ClientCredentials::Pair {
via: CredentialsVia::FormBody,
client_secret: Some(_),
..
} => ClientAuthenticationMethod::ClientSecretPost,
ClientCredentials::Pair {
via: CredentialsVia::AuthorizationHeader,
..
} => ClientAuthenticationMethod::ClientSecretBasic,
ClientCredentials::Assertion { .. } => ClientAuthenticationMethod::ClientSecretJwt,
}
}
}
#[derive(Deserialize)]
struct ClientAuthForm<T> {
#[serde(flatten)]
@ -259,7 +283,7 @@ struct ClientAuthForm<T> {
#[cfg(test)]
mod tests {
use headers::authorization::Credentials;
use mas_config::ConfigurationSection;
use mas_config::{ConfigurationSection, OAuth2ClientAuthMethodConfig};
use mas_jose::{JsonWebSignatureAlgorithm, SigningKeystore};
use serde_json::json;
@ -272,17 +296,21 @@ mod tests {
let mut config = OAuth2Config::test();
config.clients.push(OAuth2ClientConfig {
client_id: "public".to_string(),
client_secret: None,
client_auth_method: OAuth2ClientAuthMethodConfig::None,
redirect_uris: Vec::new(),
});
config.clients.push(OAuth2ClientConfig {
client_id: "confidential".to_string(),
client_secret: Some(CLIENT_SECRET.to_string()),
client_id: "secret-basic".to_string(),
client_auth_method: OAuth2ClientAuthMethodConfig::ClientSecretBasic {
client_secret: CLIENT_SECRET.to_string(),
},
redirect_uris: Vec::new(),
});
config.clients.push(OAuth2ClientConfig {
client_id: "confidential-2".to_string(),
client_secret: Some(CLIENT_SECRET.to_string()),
client_id: "secret-post".to_string(),
client_auth_method: OAuth2ClientAuthMethodConfig::ClientSecretPost {
client_secret: CLIENT_SECRET.to_string(),
},
redirect_uris: Vec::new(),
});
config
@ -395,7 +423,7 @@ mod tests {
)
.body(
serde_urlencoded::to_string(json!({
"client_id": "confidential",
"client_id": "secret-post",
"client_secret": CLIENT_SECRET,
"foo": "baz",
"bar": "foobar",
@ -407,7 +435,7 @@ mod tests {
.unwrap();
assert_eq!(auth, ClientAuthenticationMethod::ClientSecretPost);
assert_eq!(client.client_id, "confidential");
assert_eq!(client.client_id, "secret-post");
assert_eq!(body.foo, "baz");
assert_eq!(body.bar, "foobar");
}
@ -419,7 +447,7 @@ mod tests {
"https://example.com/token".to_string(),
);
let auth = Authorization::basic("confidential", CLIENT_SECRET);
let auth = Authorization::basic("secret-basic", CLIENT_SECRET);
let (auth, client, body) = warp::test::request()
.method("POST")
.header(
@ -439,7 +467,7 @@ mod tests {
.unwrap();
assert_eq!(auth, ClientAuthenticationMethod::ClientSecretBasic);
assert_eq!(client.client_id, "confidential");
assert_eq!(client.client_id, "secret-basic");
assert_eq!(body.foo, "baz");
assert_eq!(body.bar, "foobar");
}