You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-31 09:24:31 +03:00
Support private_key_jwt client auth
Which includes having a verifying keystore out of JWKS (and soon out of a JWKS URI)
This commit is contained in:
2
Cargo.lock
generated
2
Cargo.lock
generated
@ -1570,6 +1570,7 @@ dependencies = [
|
|||||||
"anyhow",
|
"anyhow",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"base64ct",
|
"base64ct",
|
||||||
|
"chrono",
|
||||||
"crypto-mac",
|
"crypto-mac",
|
||||||
"digest 0.10.1",
|
"digest 0.10.1",
|
||||||
"ecdsa",
|
"ecdsa",
|
||||||
@ -1580,6 +1581,7 @@ dependencies = [
|
|||||||
"pkcs8",
|
"pkcs8",
|
||||||
"rand",
|
"rand",
|
||||||
"rsa",
|
"rsa",
|
||||||
|
"schemars",
|
||||||
"sec1",
|
"sec1",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
@ -38,7 +38,7 @@ pub use self::{
|
|||||||
csrf::CsrfConfig,
|
csrf::CsrfConfig,
|
||||||
database::DatabaseConfig,
|
database::DatabaseConfig,
|
||||||
http::HttpConfig,
|
http::HttpConfig,
|
||||||
oauth2::{OAuth2ClientConfig, OAuth2Config},
|
oauth2::{OAuth2ClientAuthMethodConfig, OAuth2ClientConfig, OAuth2Config},
|
||||||
telemetry::{
|
telemetry::{
|
||||||
MetricsConfig, MetricsExporterConfig, Propagator, TelemetryConfig, TracingConfig,
|
MetricsConfig, MetricsExporterConfig, Propagator, TelemetryConfig, TracingConfig,
|
||||||
TracingExporterConfig,
|
TracingExporterConfig,
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
|
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use mas_jose::StaticKeystore;
|
use mas_jose::{JsonWebKeySet, StaticJwksStore, StaticKeystore};
|
||||||
use pkcs8::{DecodePrivateKey, EncodePrivateKey};
|
use pkcs8::{DecodePrivateKey, EncodePrivateKey};
|
||||||
use rsa::{
|
use rsa::{
|
||||||
pkcs1::{DecodeRsaPrivateKey, EncodeRsaPrivateKey},
|
pkcs1::{DecodeRsaPrivateKey, EncodeRsaPrivateKey},
|
||||||
@ -43,13 +43,41 @@ pub struct KeyConfig {
|
|||||||
key: String,
|
key: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(JsonSchema, Serialize, Deserialize, Clone, Debug)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum JwksOrJwksUri {
|
||||||
|
Jwks(JsonWebKeySet),
|
||||||
|
JwksUri(Url),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl JwksOrJwksUri {
|
||||||
|
pub fn key_store(&self) -> StaticJwksStore {
|
||||||
|
let jwks = match self {
|
||||||
|
Self::Jwks(jwks) => jwks.clone(),
|
||||||
|
Self::JwksUri(_) => unimplemented!("jwks_uri are not implemented yet"),
|
||||||
|
};
|
||||||
|
|
||||||
|
StaticJwksStore::new(jwks)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(JsonSchema, Serialize, Deserialize, Clone, Debug)]
|
||||||
|
#[serde(tag = "client_auth_method", rename_all = "snake_case")]
|
||||||
|
pub enum OAuth2ClientAuthMethodConfig {
|
||||||
|
None,
|
||||||
|
ClientSecretBasic { client_secret: String },
|
||||||
|
ClientSecretPost { client_secret: String },
|
||||||
|
ClientSecretJwt { client_secret: String },
|
||||||
|
PrivateKeyJwt(JwksOrJwksUri),
|
||||||
|
}
|
||||||
|
|
||||||
#[skip_serializing_none]
|
#[skip_serializing_none]
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||||
pub struct OAuth2ClientConfig {
|
pub struct OAuth2ClientConfig {
|
||||||
pub client_id: String,
|
pub client_id: String,
|
||||||
|
|
||||||
#[serde(default)]
|
#[serde(flatten)]
|
||||||
pub client_secret: Option<String>,
|
pub client_auth_method: OAuth2ClientAuthMethodConfig,
|
||||||
|
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub redirect_uris: Vec<Url>,
|
pub redirect_uris: Vec<Url>,
|
||||||
@ -246,25 +274,55 @@ mod tests {
|
|||||||
-----END PRIVATE KEY-----
|
-----END PRIVATE KEY-----
|
||||||
issuer: https://example.com
|
issuer: https://example.com
|
||||||
clients:
|
clients:
|
||||||
- client_id: hello
|
- client_id: public
|
||||||
|
client_auth_method: none
|
||||||
redirect_uris:
|
redirect_uris:
|
||||||
- https://exemple.fr/callback
|
- https://exemple.fr/callback
|
||||||
- client_id: world
|
|
||||||
|
- client_id: secret-basic
|
||||||
|
client_auth_method: client_secret_basic
|
||||||
|
client_secret: hello
|
||||||
|
|
||||||
|
- client_id: secret-post
|
||||||
|
client_auth_method: client_secret_post
|
||||||
|
client_secret: hello
|
||||||
|
|
||||||
|
- client_id: secret-jwk
|
||||||
|
client_auth_method: client_secret_jwt
|
||||||
|
client_secret: hello
|
||||||
|
|
||||||
|
- client_id: jwks
|
||||||
|
client_auth_method: private_key_jwt
|
||||||
|
jwks:
|
||||||
|
keys:
|
||||||
|
- kid: "03e84aed4ef4431014e8617567864c4efaaaede9"
|
||||||
|
kty: "RSA"
|
||||||
|
alg: "RS256"
|
||||||
|
use: "sig"
|
||||||
|
e: "AQAB"
|
||||||
|
n: "ma2uRyBeSEOatGuDpCiV9oIxlDWix_KypDYuhQfEzqi_BiF4fV266OWfyjcABbam59aJMNvOnKW3u_eZM-PhMCBij5MZ-vcBJ4GfxDJeKSn-GP_dJ09rpDcILh8HaWAnPmMoi4DC0nrfE241wPISvZaaZnGHkOrfN_EnA5DligLgVUbrA5rJhQ1aSEQO_gf1raEOW3DZ_ACU3qhtgO0ZBG3a5h7BPiRs2sXqb2UCmBBgwyvYLDebnpE7AotF6_xBIlR-Cykdap3GHVMXhrIpvU195HF30ZoBU4dMd-AeG6HgRt4Cqy1moGoDgMQfbmQ48Hlunv9_Vi2e2CLvYECcBw"
|
||||||
|
|
||||||
|
- kid: "d01c1abe249269f72ef7ca2613a86c9f05e59567"
|
||||||
|
kty: "RSA"
|
||||||
|
alg: "RS256"
|
||||||
|
use: "sig"
|
||||||
|
e: "AQAB"
|
||||||
|
n: "0hukqytPwrj1RbMYhYoepCi3CN5k7DwYkTe_Cmb7cP9_qv4ok78KdvFXt5AnQxCRwBD7-qTNkkfMWO2RxUMBdQD0ED6tsSb1n5dp0XY8dSWiBDCX8f6Hr-KolOpvMLZKRy01HdAWcM6RoL9ikbjYHUEW1C8IJnw3MzVHkpKFDL354aptdNLaAdTCBvKzU9WpXo10g-5ctzSlWWjQuecLMQ4G1mNdsR1LHhUENEnOvgT8cDkX0fJzLbEbyBYkdMgKggyVPEB1bg6evG4fTKawgnf0IDSPxIU-wdS9wdSP9ZCJJPLi5CEp-6t6rE_sb2dGcnzjCGlembC57VwpkUvyMw"
|
||||||
"#,
|
"#,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let config = OAuth2Config::load_from_file("config.yaml")?;
|
let config = OAuth2Config::load_from_file("config.yaml")?;
|
||||||
|
|
||||||
assert_eq!(config.issuer, "https://example.com".parse().unwrap());
|
assert_eq!(config.issuer, "https://example.com".parse().unwrap());
|
||||||
assert_eq!(config.clients.len(), 2);
|
assert_eq!(config.clients.len(), 5);
|
||||||
|
|
||||||
assert_eq!(config.clients[0].client_id, "hello");
|
assert_eq!(config.clients[0].client_id, "public");
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
config.clients[0].redirect_uris,
|
config.clients[0].redirect_uris,
|
||||||
vec!["https://exemple.fr/callback".parse().unwrap()]
|
vec!["https://exemple.fr/callback".parse().unwrap()]
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_eq!(config.clients[1].client_id, "world");
|
assert_eq!(config.clients[1].client_id, "secret-basic");
|
||||||
assert_eq!(config.clients[1].redirect_uris, Vec::new());
|
assert_eq!(config.clients[1].redirect_uris, Vec::new());
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -15,6 +15,7 @@
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use mas_jose::{ExportJwks, StaticKeystore};
|
use mas_jose::{ExportJwks, StaticKeystore};
|
||||||
|
use mas_warp_utils::errors::WrapError;
|
||||||
use warp::{filters::BoxedFilter, Filter, Rejection, Reply};
|
use warp::{filters::BoxedFilter, Filter, Rejection, Reply};
|
||||||
|
|
||||||
pub(super) fn filter(key_store: &Arc<StaticKeystore>) -> BoxedFilter<(Box<dyn Reply>,)> {
|
pub(super) fn filter(key_store: &Arc<StaticKeystore>) -> BoxedFilter<(Box<dyn Reply>,)> {
|
||||||
@ -25,7 +26,7 @@ pub(super) fn filter(key_store: &Arc<StaticKeystore>) -> BoxedFilter<(Box<dyn Re
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn get(key_store: Arc<StaticKeystore>) -> Result<Box<dyn Reply>, Rejection> {
|
async fn get(key_store: Arc<StaticKeystore>) -> Result<Box<dyn Reply>, Rejection> {
|
||||||
let jwks = key_store.export_jwks().await;
|
let jwks = key_store.export_jwks().await.wrap_error()?;
|
||||||
|
|
||||||
Ok(Box::new(warp::reply::json(&jwks)))
|
Ok(Box::new(warp::reply::json(&jwks)))
|
||||||
}
|
}
|
||||||
|
@ -9,6 +9,7 @@ license = "Apache-2.0"
|
|||||||
anyhow = "1.0.52"
|
anyhow = "1.0.52"
|
||||||
async-trait = "0.1.52"
|
async-trait = "0.1.52"
|
||||||
base64ct = { version = "1.0.1", features = ["std"] }
|
base64ct = { version = "1.0.1", features = ["std"] }
|
||||||
|
chrono = "0.4.19"
|
||||||
crypto-mac = { version = "0.11.1", features = ["std"] }
|
crypto-mac = { version = "0.11.1", features = ["std"] }
|
||||||
digest = "0.10.1"
|
digest = "0.10.1"
|
||||||
ecdsa = { version = "0.13.3", features = ["sign", "verify", "pem", "pkcs8"] }
|
ecdsa = { version = "0.13.3", features = ["sign", "verify", "pem", "pkcs8"] }
|
||||||
@ -19,6 +20,7 @@ pkcs1 = { version = "0.3.1", features = ["pem", "pkcs8"] }
|
|||||||
pkcs8 = { version = "0.8.0", features = ["pem"] }
|
pkcs8 = { version = "0.8.0", features = ["pem"] }
|
||||||
rand = "0.8.4"
|
rand = "0.8.4"
|
||||||
rsa = { git = "https://github.com/sandhose/RSA.git", branch = "bump-pkcs" }
|
rsa = { git = "https://github.com/sandhose/RSA.git", branch = "bump-pkcs" }
|
||||||
|
schemars = "0.8.8"
|
||||||
sec1 = "0.2.1"
|
sec1 = "0.2.1"
|
||||||
serde = { version = "1.0.133", features = ["derive"] }
|
serde = { version = "1.0.133", features = ["derive"] }
|
||||||
serde_json = "1.0.74"
|
serde_json = "1.0.74"
|
||||||
@ -26,6 +28,6 @@ serde_with = { version = "1.11.0", features = ["base64"] }
|
|||||||
sha2 = "0.10.0"
|
sha2 = "0.10.0"
|
||||||
signature = "1.4.0"
|
signature = "1.4.0"
|
||||||
thiserror = "1.0.30"
|
thiserror = "1.0.30"
|
||||||
tokio = { version = "1.15.0", features = ["macros", "rt"] }
|
tokio = { version = "1.15.0", features = ["macros", "rt", "sync"] }
|
||||||
url = { version = "2.2.2", features = ["serde"] }
|
url = { version = "2.2.2", features = ["serde"] }
|
||||||
zeroize = "1.4.3"
|
zeroize = "1.4.3"
|
||||||
|
25
crates/jose/src/claims.rs
Normal file
25
crates/jose/src/claims.rs
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
trait ClaimSet {
|
||||||
|
fn validate(&self) -> anyhow::Result<()>;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct UnvalidatedClaim<T>(T);
|
||||||
|
|
||||||
|
impl<T> ClaimSet for UnvalidatedClaim<T> {
|
||||||
|
fn validate(&self) -> anyhow::Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
@ -16,9 +16,10 @@
|
|||||||
//!
|
//!
|
||||||
//! <https://www.iana.org/assignments/jose/jose.xhtml>
|
//! <https://www.iana.org/assignments/jose/jose.xhtml>
|
||||||
|
|
||||||
|
use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
|
||||||
pub enum JsonWebSignatureAlgorithm {
|
pub enum JsonWebSignatureAlgorithm {
|
||||||
/// HMAC using SHA-256
|
/// HMAC using SHA-256
|
||||||
#[serde(rename = "HS256")]
|
#[serde(rename = "HS256")]
|
||||||
@ -157,7 +158,7 @@ pub enum JsonWebSignatureAlgorithm {
|
|||||||
Es256K,
|
Es256K,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||||
pub enum JsonWebEncryptionAlgorithm {
|
pub enum JsonWebEncryptionAlgorithm {
|
||||||
/// AES_128_CBC_HMAC_SHA_256 authenticated encryption algorithm
|
/// AES_128_CBC_HMAC_SHA_256 authenticated encryption algorithm
|
||||||
#[serde(rename = "A128CBC-HS256")]
|
#[serde(rename = "A128CBC-HS256")]
|
||||||
@ -184,14 +185,14 @@ pub enum JsonWebEncryptionAlgorithm {
|
|||||||
A256Gcm,
|
A256Gcm,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||||
pub enum JsonWebEncryptionCompressionAlgorithm {
|
pub enum JsonWebEncryptionCompressionAlgorithm {
|
||||||
/// DEFLATE
|
/// DEFLATE
|
||||||
#[serde(rename = "DEF")]
|
#[serde(rename = "DEF")]
|
||||||
Def,
|
Def,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||||
pub enum JsonWebKeyType {
|
pub enum JsonWebKeyType {
|
||||||
/// Elliptic Curve
|
/// Elliptic Curve
|
||||||
#[serde(rename = "EC")]
|
#[serde(rename = "EC")]
|
||||||
@ -210,7 +211,7 @@ pub enum JsonWebKeyType {
|
|||||||
Okp,
|
Okp,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
|
||||||
pub enum JsonWebKeyEcEllipticCurve {
|
pub enum JsonWebKeyEcEllipticCurve {
|
||||||
/// P-256 Curve
|
/// P-256 Curve
|
||||||
#[serde(rename = "P-256")]
|
#[serde(rename = "P-256")]
|
||||||
@ -229,7 +230,7 @@ pub enum JsonWebKeyEcEllipticCurve {
|
|||||||
Secp256K1,
|
Secp256K1,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
|
||||||
pub enum JsonWebKeyOkpEllipticCurve {
|
pub enum JsonWebKeyOkpEllipticCurve {
|
||||||
/// Ed25519 signature algorithm key pairs
|
/// Ed25519 signature algorithm key pairs
|
||||||
#[serde(rename = "Ed25519")]
|
#[serde(rename = "Ed25519")]
|
||||||
@ -248,7 +249,7 @@ pub enum JsonWebKeyOkpEllipticCurve {
|
|||||||
X448,
|
X448,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
|
||||||
pub enum JsonWebKeyUse {
|
pub enum JsonWebKeyUse {
|
||||||
/// Digital Signature or MAC
|
/// Digital Signature or MAC
|
||||||
#[serde(rename = "sig")]
|
#[serde(rename = "sig")]
|
||||||
@ -259,7 +260,7 @@ pub enum JsonWebKeyUse {
|
|||||||
Enc,
|
Enc,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
|
||||||
pub enum JsonWebKeyOperation {
|
pub enum JsonWebKeyOperation {
|
||||||
/// Compute digital signature or MAC
|
/// Compute digital signature or MAC
|
||||||
#[serde(rename = "sign")]
|
#[serde(rename = "sign")]
|
||||||
|
@ -17,6 +17,7 @@
|
|||||||
use anyhow::bail;
|
use anyhow::bail;
|
||||||
use p256::NistP256;
|
use p256::NistP256;
|
||||||
use rsa::{BigUint, PublicKeyParts};
|
use rsa::{BigUint, PublicKeyParts};
|
||||||
|
use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_with::{
|
use serde_with::{
|
||||||
base64::{Base64, Standard, UrlSafe},
|
base64::{Base64, Standard, UrlSafe},
|
||||||
@ -25,14 +26,17 @@ use serde_with::{
|
|||||||
};
|
};
|
||||||
use url::Url;
|
use url::Url;
|
||||||
|
|
||||||
use crate::iana::{
|
use crate::{
|
||||||
|
iana::{
|
||||||
JsonWebKeyEcEllipticCurve, JsonWebKeyOkpEllipticCurve, JsonWebKeyOperation, JsonWebKeyUse,
|
JsonWebKeyEcEllipticCurve, JsonWebKeyOkpEllipticCurve, JsonWebKeyOperation, JsonWebKeyUse,
|
||||||
JsonWebSignatureAlgorithm,
|
JsonWebSignatureAlgorithm,
|
||||||
|
},
|
||||||
|
JsonWebKeyType,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[serde_as]
|
#[serde_as]
|
||||||
#[skip_serializing_none]
|
#[skip_serializing_none]
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
|
||||||
pub struct JsonWebKey {
|
pub struct JsonWebKey {
|
||||||
#[serde(flatten)]
|
#[serde(flatten)]
|
||||||
parameters: JsonWebKeyParameters,
|
parameters: JsonWebKeyParameters,
|
||||||
@ -49,17 +53,21 @@ pub struct JsonWebKey {
|
|||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
kid: Option<String>,
|
kid: Option<String>,
|
||||||
|
|
||||||
|
#[schemars(with = "Option<String>")]
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
x5u: Option<Url>,
|
x5u: Option<Url>,
|
||||||
|
|
||||||
|
#[schemars(with = "Vec<String>")]
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[serde_as(as = "Option<Vec<Base64<Standard, Padded>>>")]
|
#[serde_as(as = "Option<Vec<Base64<Standard, Padded>>>")]
|
||||||
x5c: Option<Vec<Vec<u8>>>,
|
x5c: Option<Vec<Vec<u8>>>,
|
||||||
|
|
||||||
|
#[schemars(with = "Option<String>")]
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[serde_as(as = "Option<Base64<UrlSafe, Unpadded>>")]
|
#[serde_as(as = "Option<Base64<UrlSafe, Unpadded>>")]
|
||||||
x5t: Option<Vec<u8>>,
|
x5t: Option<Vec<u8>>,
|
||||||
|
|
||||||
|
#[schemars(with = "Option<String>")]
|
||||||
#[serde(default, rename = "x5t#S256")]
|
#[serde(default, rename = "x5t#S256")]
|
||||||
#[serde_as(as = "Option<Base64<UrlSafe, Unpadded>>")]
|
#[serde_as(as = "Option<Base64<UrlSafe, Unpadded>>")]
|
||||||
x5t_s256: Option<Vec<u8>>,
|
x5t_s256: Option<Vec<u8>>,
|
||||||
@ -104,13 +112,40 @@ impl JsonWebKey {
|
|||||||
self.kid = Some(kid.into());
|
self.kid = Some(kid.into());
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn kty(&self) -> JsonWebKeyType {
|
||||||
|
match self.parameters {
|
||||||
|
JsonWebKeyParameters::Ec { .. } => JsonWebKeyType::Ec,
|
||||||
|
JsonWebKeyParameters::Rsa { .. } => JsonWebKeyType::Rsa,
|
||||||
|
JsonWebKeyParameters::Okp { .. } => JsonWebKeyType::Okp,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn kid(&self) -> Option<&str> {
|
||||||
|
self.kid.as_deref()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn params(&self) -> &JsonWebKeyParameters {
|
||||||
|
&self.parameters
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
|
||||||
pub struct JsonWebKeySet {
|
pub struct JsonWebKeySet {
|
||||||
keys: Vec<JsonWebKey>,
|
keys: Vec<JsonWebKey>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl std::ops::Deref for JsonWebKeySet {
|
||||||
|
type Target = Vec<JsonWebKey>;
|
||||||
|
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
&self.keys
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl JsonWebKeySet {
|
impl JsonWebKeySet {
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn new(keys: Vec<JsonWebKey>) -> Self {
|
pub fn new(keys: Vec<JsonWebKey>) -> Self {
|
||||||
@ -119,27 +154,36 @@ impl JsonWebKeySet {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[serde_as]
|
#[serde_as]
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
|
||||||
#[serde(tag = "kty")]
|
#[serde(tag = "kty")]
|
||||||
pub enum JsonWebKeyParameters {
|
pub enum JsonWebKeyParameters {
|
||||||
#[serde(rename = "RSA")]
|
#[serde(rename = "RSA")]
|
||||||
Rsa {
|
Rsa {
|
||||||
|
#[schemars(with = "String")]
|
||||||
#[serde_as(as = "Base64<UrlSafe, Unpadded>")]
|
#[serde_as(as = "Base64<UrlSafe, Unpadded>")]
|
||||||
n: Vec<u8>,
|
n: Vec<u8>,
|
||||||
|
|
||||||
|
#[schemars(with = "String")]
|
||||||
#[serde_as(as = "Base64<UrlSafe, Unpadded>")]
|
#[serde_as(as = "Base64<UrlSafe, Unpadded>")]
|
||||||
e: Vec<u8>,
|
e: Vec<u8>,
|
||||||
},
|
},
|
||||||
#[serde(rename = "EC")]
|
#[serde(rename = "EC")]
|
||||||
Ec {
|
Ec {
|
||||||
crv: JsonWebKeyEcEllipticCurve,
|
crv: JsonWebKeyEcEllipticCurve,
|
||||||
|
|
||||||
|
#[schemars(with = "String")]
|
||||||
#[serde_as(as = "Base64<UrlSafe, Unpadded>")]
|
#[serde_as(as = "Base64<UrlSafe, Unpadded>")]
|
||||||
x: Vec<u8>,
|
x: Vec<u8>,
|
||||||
|
|
||||||
|
#[schemars(with = "String")]
|
||||||
#[serde_as(as = "Base64<UrlSafe, Unpadded>")]
|
#[serde_as(as = "Base64<UrlSafe, Unpadded>")]
|
||||||
y: Vec<u8>,
|
y: Vec<u8>,
|
||||||
},
|
},
|
||||||
#[serde(rename = "OKP")]
|
#[serde(rename = "OKP")]
|
||||||
Okp {
|
Okp {
|
||||||
crv: JsonWebKeyOkpEllipticCurve,
|
crv: JsonWebKeyOkpEllipticCurve,
|
||||||
|
|
||||||
|
#[schemars(with = "String")]
|
||||||
#[serde_as(as = "Base64<UrlSafe, Unpadded>")]
|
#[serde_as(as = "Base64<UrlSafe, Unpadded>")]
|
||||||
x: Vec<u8>,
|
x: Vec<u8>,
|
||||||
},
|
},
|
||||||
|
278
crates/jose/src/keystore/jwks.rs
Normal file
278
crates/jose/src/keystore/jwks.rs
Normal file
@ -0,0 +1,278 @@
|
|||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use anyhow::bail;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use chrono::{DateTime, Duration, Utc};
|
||||||
|
use digest::Digest;
|
||||||
|
use rsa::{PublicKey, RsaPublicKey};
|
||||||
|
use sha2::{Sha256, Sha384, Sha512};
|
||||||
|
use signature::{Signature, Verifier};
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
ExportJwks, JsonWebKeySet, JsonWebKeyType, JsonWebSignatureAlgorithm, JwtHeader,
|
||||||
|
VerifyingKeystore,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub struct StaticJwksStore {
|
||||||
|
key_set: JsonWebKeySet,
|
||||||
|
index: HashMap<(JsonWebKeyType, String), usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl StaticJwksStore {
|
||||||
|
#[must_use]
|
||||||
|
pub fn new(key_set: JsonWebKeySet) -> Self {
|
||||||
|
let index = key_set
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.filter_map(|(index, key)| {
|
||||||
|
let kid = key.kid()?.to_string();
|
||||||
|
let kty = key.kty();
|
||||||
|
|
||||||
|
Some(((kty, kid), index))
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Self { key_set, index }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn find_rsa_key(&self, kid: String) -> anyhow::Result<RsaPublicKey> {
|
||||||
|
let index = *self
|
||||||
|
.index
|
||||||
|
.get(&(JsonWebKeyType::Rsa, kid))
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("key not found"))?;
|
||||||
|
|
||||||
|
let key = self
|
||||||
|
.key_set
|
||||||
|
.get(index)
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("invalid index"))?;
|
||||||
|
|
||||||
|
let key = key.params().clone().try_into()?;
|
||||||
|
|
||||||
|
Ok(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn find_ecdsa_key(&self, kid: String) -> anyhow::Result<ecdsa::VerifyingKey<p256::NistP256>> {
|
||||||
|
let index = *self
|
||||||
|
.index
|
||||||
|
.get(&(JsonWebKeyType::Ec, kid))
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("key not found"))?;
|
||||||
|
|
||||||
|
let key = self
|
||||||
|
.key_set
|
||||||
|
.get(index)
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("invalid index"))?;
|
||||||
|
|
||||||
|
let key = key.params().clone().try_into()?;
|
||||||
|
|
||||||
|
Ok(key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl VerifyingKeystore for &StaticJwksStore {
|
||||||
|
async fn verify(
|
||||||
|
self,
|
||||||
|
header: &JwtHeader,
|
||||||
|
payload: &[u8],
|
||||||
|
signature: &[u8],
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
let kid = header
|
||||||
|
.kid()
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("missing kid"))?
|
||||||
|
.to_string();
|
||||||
|
match header.alg() {
|
||||||
|
JsonWebSignatureAlgorithm::Rs256 => {
|
||||||
|
let key = self.find_rsa_key(kid)?;
|
||||||
|
|
||||||
|
let digest = {
|
||||||
|
let mut digest = Sha256::new();
|
||||||
|
digest.update(&payload);
|
||||||
|
digest.finalize()
|
||||||
|
};
|
||||||
|
|
||||||
|
key.verify(
|
||||||
|
rsa::PaddingScheme::new_pkcs1v15_sign(Some(rsa::Hash::SHA2_256)),
|
||||||
|
&digest,
|
||||||
|
signature,
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
JsonWebSignatureAlgorithm::Rs384 => {
|
||||||
|
let key = self.find_rsa_key(kid)?;
|
||||||
|
|
||||||
|
let digest = {
|
||||||
|
let mut digest = Sha384::new();
|
||||||
|
digest.update(&payload);
|
||||||
|
digest.finalize()
|
||||||
|
};
|
||||||
|
|
||||||
|
key.verify(
|
||||||
|
rsa::PaddingScheme::new_pkcs1v15_sign(Some(rsa::Hash::SHA2_384)),
|
||||||
|
&digest,
|
||||||
|
signature,
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
JsonWebSignatureAlgorithm::Rs512 => {
|
||||||
|
let key = self.find_rsa_key(kid)?;
|
||||||
|
|
||||||
|
let digest = {
|
||||||
|
let mut digest = Sha512::new();
|
||||||
|
digest.update(&payload);
|
||||||
|
digest.finalize()
|
||||||
|
};
|
||||||
|
|
||||||
|
key.verify(
|
||||||
|
rsa::PaddingScheme::new_pkcs1v15_sign(Some(rsa::Hash::SHA2_512)),
|
||||||
|
&digest,
|
||||||
|
signature,
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
JsonWebSignatureAlgorithm::Es256 => {
|
||||||
|
let key = self.find_ecdsa_key(kid)?;
|
||||||
|
|
||||||
|
let signature = ecdsa::Signature::from_bytes(signature)?;
|
||||||
|
|
||||||
|
key.verify(payload, &signature)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
_ => bail!("unsupported algorithm"),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
enum RemoteKeySet {
|
||||||
|
Pending,
|
||||||
|
Errored {
|
||||||
|
at: DateTime<Utc>,
|
||||||
|
error: anyhow::Error,
|
||||||
|
},
|
||||||
|
Fulfilled {
|
||||||
|
at: DateTime<Utc>,
|
||||||
|
store: StaticJwksStore,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for RemoteKeySet {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::Pending
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RemoteKeySet {
|
||||||
|
fn fullfill(&mut self, key_set: JsonWebKeySet) {
|
||||||
|
*self = Self::Fulfilled {
|
||||||
|
at: Utc::now(),
|
||||||
|
store: StaticJwksStore::new(key_set),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn error(&mut self, error: anyhow::Error) {
|
||||||
|
*self = Self::Errored {
|
||||||
|
at: Utc::now(),
|
||||||
|
error,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn should_refresh(&self) -> bool {
|
||||||
|
let now = Utc::now();
|
||||||
|
match self {
|
||||||
|
Self::Pending => true,
|
||||||
|
Self::Errored { at, .. } if *at - now > Duration::minutes(5) => true,
|
||||||
|
Self::Fulfilled { at, .. } if *at - now > Duration::hours(1) => true,
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn should_force_refresh(&self) -> bool {
|
||||||
|
let now = Utc::now();
|
||||||
|
match self {
|
||||||
|
Self::Pending => true,
|
||||||
|
Self::Errored { at, .. } | Self::Fulfilled { at, .. }
|
||||||
|
if *at - now > Duration::minutes(5) =>
|
||||||
|
{
|
||||||
|
true
|
||||||
|
}
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct JwksStore<T>
|
||||||
|
where
|
||||||
|
T: ExportJwks,
|
||||||
|
{
|
||||||
|
exporter: T,
|
||||||
|
cache: RwLock<RemoteKeySet>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: ExportJwks> JwksStore<T> {
|
||||||
|
pub fn new(exporter: T) -> Self {
|
||||||
|
Self {
|
||||||
|
exporter,
|
||||||
|
cache: RwLock::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn should_refresh(&self) -> bool {
|
||||||
|
let cache = self.cache.read().await;
|
||||||
|
cache.should_refresh()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn refresh(&self) {
|
||||||
|
let mut cache = self.cache.write().await;
|
||||||
|
|
||||||
|
if cache.should_force_refresh() {
|
||||||
|
let jwks = self.exporter.export_jwks().await;
|
||||||
|
|
||||||
|
match jwks {
|
||||||
|
Ok(jwks) => cache.fullfill(jwks),
|
||||||
|
Err(err) => cache.error(err),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<T: ExportJwks + Send + Sync> VerifyingKeystore for &JwksStore<T> {
|
||||||
|
async fn verify(
|
||||||
|
self,
|
||||||
|
header: &JwtHeader,
|
||||||
|
payload: &[u8],
|
||||||
|
signature: &[u8],
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
if self.should_refresh().await {
|
||||||
|
self.refresh().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
let cache = self.cache.read().await;
|
||||||
|
// TODO: we could bubble up the underlying error here
|
||||||
|
let store = match &*cache {
|
||||||
|
RemoteKeySet::Pending => bail!("inconsistent cache state"),
|
||||||
|
RemoteKeySet::Errored { error, .. } => bail!("cache in error state {}", error),
|
||||||
|
RemoteKeySet::Fulfilled { store, .. } => store,
|
||||||
|
};
|
||||||
|
|
||||||
|
store.verify(header, payload, signature).await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
@ -12,11 +12,13 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
mod jwks;
|
||||||
mod shared_secret;
|
mod shared_secret;
|
||||||
mod static_keystore;
|
mod static_keystore;
|
||||||
mod traits;
|
mod traits;
|
||||||
|
|
||||||
pub use self::{
|
pub use self::{
|
||||||
|
jwks::{JwksStore, StaticJwksStore},
|
||||||
shared_secret::SharedSecret,
|
shared_secret::SharedSecret,
|
||||||
static_keystore::StaticKeystore,
|
static_keystore::StaticKeystore,
|
||||||
traits::{ExportJwks, SigningKeystore, VerifyingKeystore},
|
traits::{ExportJwks, SigningKeystore, VerifyingKeystore},
|
||||||
|
@ -27,10 +27,7 @@ use sha2::{Sha256, Sha384, Sha512};
|
|||||||
use signature::{Signature, Signer, Verifier};
|
use signature::{Signature, Signer, Verifier};
|
||||||
|
|
||||||
use super::{ExportJwks, SigningKeystore, VerifyingKeystore};
|
use super::{ExportJwks, SigningKeystore, VerifyingKeystore};
|
||||||
use crate::{
|
use crate::{iana::JsonWebSignatureAlgorithm, JsonWebKey, JsonWebKeySet, JwtHeader};
|
||||||
iana::{JsonWebKeyOperation, JsonWebSignatureAlgorithm},
|
|
||||||
JsonWebKey, JsonWebKeySet, JwtHeader,
|
|
||||||
};
|
|
||||||
|
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
pub struct StaticKeystore {
|
pub struct StaticKeystore {
|
||||||
@ -276,23 +273,13 @@ impl VerifyingKeystore for &StaticKeystore {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl ExportJwks for &StaticKeystore {
|
impl ExportJwks for StaticKeystore {
|
||||||
async fn export_jwks(self) -> JsonWebKeySet {
|
async fn export_jwks(&self) -> anyhow::Result<JsonWebKeySet> {
|
||||||
let rsa = self.rsa_keys.iter().flat_map(|(kid, key)| {
|
let rsa = self.rsa_keys.iter().map(|(kid, key)| {
|
||||||
let pubkey = RsaPublicKey::from(key);
|
let pubkey = RsaPublicKey::from(key);
|
||||||
let basekey = JsonWebKey::new(pubkey.into())
|
JsonWebKey::new(pubkey.into())
|
||||||
.with_kid(kid)
|
.with_kid(kid)
|
||||||
.with_use(crate::JsonWebKeyUse::Sig)
|
.with_use(crate::JsonWebKeyUse::Sig)
|
||||||
.with_key_ops(vec![JsonWebKeyOperation::Sign]);
|
|
||||||
|
|
||||||
let algs = [
|
|
||||||
JsonWebSignatureAlgorithm::Rs256,
|
|
||||||
JsonWebSignatureAlgorithm::Rs384,
|
|
||||||
JsonWebSignatureAlgorithm::Rs512,
|
|
||||||
];
|
|
||||||
|
|
||||||
algs.into_iter()
|
|
||||||
.map(move |alg| basekey.clone().with_alg(alg))
|
|
||||||
});
|
});
|
||||||
|
|
||||||
let es256 = self.es256_keys.iter().map(|(kid, key)| {
|
let es256 = self.es256_keys.iter().map(|(kid, key)| {
|
||||||
@ -300,12 +287,11 @@ impl ExportJwks for &StaticKeystore {
|
|||||||
JsonWebKey::new(pubkey.into())
|
JsonWebKey::new(pubkey.into())
|
||||||
.with_kid(kid)
|
.with_kid(kid)
|
||||||
.with_use(crate::JsonWebKeyUse::Sig)
|
.with_use(crate::JsonWebKeyUse::Sig)
|
||||||
.with_key_ops(vec![JsonWebKeyOperation::Sign])
|
|
||||||
.with_alg(JsonWebSignatureAlgorithm::Es256)
|
.with_alg(JsonWebSignatureAlgorithm::Es256)
|
||||||
});
|
});
|
||||||
|
|
||||||
let keys = rsa.chain(es256).collect();
|
let keys = rsa.chain(es256).collect();
|
||||||
JsonWebKeySet::new(keys)
|
Ok(JsonWebKeySet::new(keys))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -14,9 +14,7 @@
|
|||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
|
||||||
use crate::{
|
use crate::{iana::JsonWebSignatureAlgorithm, JsonWebKeySet, JwtHeader};
|
||||||
iana::JsonWebSignatureAlgorithm, JsonWebKeySet, JwtHeader,
|
|
||||||
};
|
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait SigningKeystore {
|
pub trait SigningKeystore {
|
||||||
@ -32,6 +30,5 @@ pub trait VerifyingKeystore {
|
|||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait ExportJwks {
|
pub trait ExportJwks {
|
||||||
async fn export_jwks(self) -> JsonWebKeySet;
|
async fn export_jwks(&self) -> anyhow::Result<JsonWebKeySet>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -19,6 +19,7 @@
|
|||||||
#![allow(clippy::missing_errors_doc)]
|
#![allow(clippy::missing_errors_doc)]
|
||||||
#![allow(clippy::module_name_repetitions)]
|
#![allow(clippy::module_name_repetitions)]
|
||||||
|
|
||||||
|
mod claims;
|
||||||
pub(crate) mod iana;
|
pub(crate) mod iana;
|
||||||
pub(crate) mod jwk;
|
pub(crate) mod jwk;
|
||||||
pub(crate) mod jwt;
|
pub(crate) mod jwt;
|
||||||
@ -31,5 +32,8 @@ pub use self::{
|
|||||||
},
|
},
|
||||||
jwk::{JsonWebKey, JsonWebKeySet},
|
jwk::{JsonWebKey, JsonWebKeySet},
|
||||||
jwt::{DecodedJsonWebToken, JsonWebTokenParts, JwtHeader},
|
jwt::{DecodedJsonWebToken, JsonWebTokenParts, JwtHeader},
|
||||||
keystore::{ExportJwks, SharedSecret, SigningKeystore, StaticKeystore, VerifyingKeystore},
|
keystore::{
|
||||||
|
ExportJwks, JwksStore, SharedSecret, SigningKeystore, StaticJwksStore, StaticKeystore,
|
||||||
|
VerifyingKeystore,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
@ -15,7 +15,7 @@
|
|||||||
//! Handle client authentication
|
//! Handle client authentication
|
||||||
|
|
||||||
use headers::{authorization::Basic, Authorization};
|
use headers::{authorization::Basic, Authorization};
|
||||||
use mas_config::{OAuth2ClientConfig, OAuth2Config};
|
use mas_config::{OAuth2ClientAuthMethodConfig, OAuth2ClientConfig, OAuth2Config};
|
||||||
use mas_jose::{DecodedJsonWebToken, JsonWebTokenParts, SharedSecret};
|
use mas_jose::{DecodedJsonWebToken, JsonWebTokenParts, SharedSecret};
|
||||||
use oauth2_types::requests::ClientAuthenticationMethod;
|
use oauth2_types::requests::ClientAuthenticationMethod;
|
||||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||||
@ -72,17 +72,14 @@ pub fn client_authentication<T: DeserializeOwned + Send + 'static>(
|
|||||||
|
|
||||||
#[derive(Error, Debug)]
|
#[derive(Error, Debug)]
|
||||||
enum ClientAuthenticationError {
|
enum ClientAuthenticationError {
|
||||||
#[error("no client secret found for client {client_id:?}")]
|
|
||||||
NoClientSecret { client_id: String },
|
|
||||||
|
|
||||||
#[error("wrong client secret for client {client_id:?}")]
|
#[error("wrong client secret for client {client_id:?}")]
|
||||||
ClientSecretMismatch { client_id: String },
|
ClientSecretMismatch { client_id: String },
|
||||||
|
|
||||||
#[error("could not find client {client_id:?}")]
|
#[error("could not find client {client_id:?}")]
|
||||||
ClientNotFound { client_id: String },
|
ClientNotFound { client_id: String },
|
||||||
|
|
||||||
#[error("client secret required for client {client_id:?}")]
|
#[error("wrong client authentication method for client {client_id:?}")]
|
||||||
ClientSecretRequired { client_id: String },
|
WrongAuthenticationMethod { client_id: String },
|
||||||
|
|
||||||
#[error("wrong audience in client assertion: expected {expected:?}, got {got:?}")]
|
#[error("wrong audience in client assertion: expected {expected:?}, got {got:?}")]
|
||||||
AudienceMismatch { expected: String, got: String },
|
AudienceMismatch { expected: String, got: String },
|
||||||
@ -113,12 +110,11 @@ async fn authenticate_client<T>(
|
|||||||
credentials: ClientCredentials,
|
credentials: ClientCredentials,
|
||||||
body: T,
|
body: T,
|
||||||
) -> Result<(ClientAuthenticationMethod, OAuth2ClientConfig, T), Rejection> {
|
) -> Result<(ClientAuthenticationMethod, OAuth2ClientConfig, T), Rejection> {
|
||||||
let auth_type = credentials.authentication_type();
|
let (auth_method, client) = match credentials {
|
||||||
let client = match credentials {
|
|
||||||
ClientCredentials::Pair {
|
ClientCredentials::Pair {
|
||||||
client_id,
|
client_id,
|
||||||
client_secret,
|
client_secret,
|
||||||
..
|
via,
|
||||||
} => {
|
} => {
|
||||||
let client = clients
|
let client = clients
|
||||||
.iter()
|
.iter()
|
||||||
@ -127,17 +123,49 @@ async fn authenticate_client<T>(
|
|||||||
client_id: client_id.to_string(),
|
client_id: client_id.to_string(),
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
match (client_secret, client.client_secret.as_ref()) {
|
let auth_method = match (&client.client_auth_method, client_secret, via) {
|
||||||
(None, None) => Ok(client),
|
(OAuth2ClientAuthMethodConfig::None, None, _) => ClientAuthenticationMethod::None,
|
||||||
(Some(ref given), Some(expected)) if given == expected => Ok(client),
|
|
||||||
(Some(_), Some(_)) => {
|
(
|
||||||
Err(ClientAuthenticationError::ClientSecretMismatch { client_id })
|
OAuth2ClientAuthMethodConfig::ClientSecretBasic {
|
||||||
|
client_secret: ref expected_client_secret,
|
||||||
|
},
|
||||||
|
Some(ref given_client_secret),
|
||||||
|
CredentialsVia::AuthorizationHeader,
|
||||||
|
) => {
|
||||||
|
if expected_client_secret != given_client_secret {
|
||||||
|
return Err(
|
||||||
|
ClientAuthenticationError::ClientSecretMismatch { client_id }.into(),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
(Some(_), None) => Err(ClientAuthenticationError::NoClientSecret { client_id }),
|
|
||||||
(None, Some(_)) => {
|
ClientAuthenticationMethod::ClientSecretBasic
|
||||||
Err(ClientAuthenticationError::ClientSecretRequired { client_id })
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
(
|
||||||
|
OAuth2ClientAuthMethodConfig::ClientSecretPost {
|
||||||
|
client_secret: ref expected_client_secret,
|
||||||
|
},
|
||||||
|
Some(ref given_client_secret),
|
||||||
|
CredentialsVia::FormBody,
|
||||||
|
) => {
|
||||||
|
if expected_client_secret != given_client_secret {
|
||||||
|
return Err(
|
||||||
|
ClientAuthenticationError::ClientSecretMismatch { client_id }.into(),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ClientAuthenticationMethod::ClientSecretPost
|
||||||
|
}
|
||||||
|
|
||||||
|
_ => {
|
||||||
|
return Err(
|
||||||
|
ClientAuthenticationError::WrongAuthenticationMethod { client_id }.into(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
(auth_method, client)
|
||||||
}
|
}
|
||||||
ClientCredentials::Assertion {
|
ClientCredentials::Assertion {
|
||||||
client_id,
|
client_id,
|
||||||
@ -150,43 +178,61 @@ async fn authenticate_client<T>(
|
|||||||
|
|
||||||
// client_id might have been passed as parameter. If not, it should be inferred
|
// client_id might have been passed as parameter. If not, it should be inferred
|
||||||
// from the token, as per rfc7521 sec. 4.2
|
// from the token, as per rfc7521 sec. 4.2
|
||||||
let client_id = client_id.unwrap_or_else(|| decoded.claims().subject.clone());
|
let client_id = client_id
|
||||||
|
.as_ref()
|
||||||
|
.unwrap_or_else(|| &decoded.claims().subject);
|
||||||
|
|
||||||
let client = clients
|
let client = clients
|
||||||
.iter()
|
.iter()
|
||||||
.find(|client| client.client_id == client_id)
|
.find(|client| &client.client_id == client_id)
|
||||||
.ok_or_else(|| ClientAuthenticationError::ClientNotFound {
|
.ok_or_else(|| ClientAuthenticationError::ClientNotFound {
|
||||||
client_id: client_id.to_string(),
|
client_id: client_id.to_string(),
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
if let Some(client_secret) = &client.client_secret {
|
let auth_method = match &client.client_auth_method {
|
||||||
|
OAuth2ClientAuthMethodConfig::PrivateKeyJwt(jwks) => {
|
||||||
|
let store = jwks.key_store();
|
||||||
|
token.verify(&decoded, &store).await.wrap_error()?;
|
||||||
|
ClientAuthenticationMethod::PrivateKeyJwt
|
||||||
|
}
|
||||||
|
|
||||||
|
OAuth2ClientAuthMethodConfig::ClientSecretJwt { client_secret } => {
|
||||||
let store = SharedSecret::new(client_secret);
|
let store = SharedSecret::new(client_secret);
|
||||||
token.verify(&decoded, &store).await.wrap_error()?;
|
token.verify(&decoded, &store).await.wrap_error()?;
|
||||||
|
ClientAuthenticationMethod::ClientSecretJwt
|
||||||
|
}
|
||||||
|
|
||||||
|
_ => {
|
||||||
|
return Err(ClientAuthenticationError::WrongAuthenticationMethod {
|
||||||
|
client_id: client_id.clone(),
|
||||||
|
}
|
||||||
|
.into())
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let claims = decoded.claims();
|
let claims = decoded.claims();
|
||||||
// TODO: validate the times again
|
// TODO: validate the times again
|
||||||
|
|
||||||
// rfc7523 sec. 3.3: the audience is the URL being called
|
// rfc7523 sec. 3.3: the audience is the URL being called
|
||||||
if claims.audience != audience {
|
if claims.audience != audience {
|
||||||
Err(ClientAuthenticationError::AudienceMismatch {
|
return Err(ClientAuthenticationError::AudienceMismatch {
|
||||||
expected: audience,
|
expected: audience,
|
||||||
got: claims.audience.clone(),
|
got: claims.audience.clone(),
|
||||||
})
|
}
|
||||||
|
.into());
|
||||||
|
}
|
||||||
|
|
||||||
// rfc7523 sec. 3.1 & 3.2: both the issuer and the subject must
|
// rfc7523 sec. 3.1 & 3.2: both the issuer and the subject must
|
||||||
// match the client_id
|
// match the client_id
|
||||||
} else if claims.issuer != claims.subject || claims.issuer != client_id {
|
if claims.issuer != claims.subject || &claims.issuer != client_id {
|
||||||
Err(ClientAuthenticationError::InvalidAssertion)
|
return Err(ClientAuthenticationError::InvalidAssertion.into());
|
||||||
} else {
|
|
||||||
Ok(client)
|
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
Err(ClientAuthenticationError::ClientSecretRequired {
|
|
||||||
client_id: client_id.to_string(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}?;
|
|
||||||
|
|
||||||
Ok((auth_type, client.clone(), body))
|
(auth_method, client)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok((auth_method, client.clone(), body))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
@ -225,28 +271,6 @@ enum ClientCredentials {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ClientCredentials {
|
|
||||||
fn authentication_type(&self) -> ClientAuthenticationMethod {
|
|
||||||
match self {
|
|
||||||
ClientCredentials::Pair {
|
|
||||||
via: CredentialsVia::FormBody,
|
|
||||||
client_secret: None,
|
|
||||||
..
|
|
||||||
} => ClientAuthenticationMethod::None,
|
|
||||||
ClientCredentials::Pair {
|
|
||||||
via: CredentialsVia::FormBody,
|
|
||||||
client_secret: Some(_),
|
|
||||||
..
|
|
||||||
} => ClientAuthenticationMethod::ClientSecretPost,
|
|
||||||
ClientCredentials::Pair {
|
|
||||||
via: CredentialsVia::AuthorizationHeader,
|
|
||||||
..
|
|
||||||
} => ClientAuthenticationMethod::ClientSecretBasic,
|
|
||||||
ClientCredentials::Assertion { .. } => ClientAuthenticationMethod::ClientSecretJwt,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct ClientAuthForm<T> {
|
struct ClientAuthForm<T> {
|
||||||
#[serde(flatten)]
|
#[serde(flatten)]
|
||||||
@ -259,7 +283,7 @@ struct ClientAuthForm<T> {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use headers::authorization::Credentials;
|
use headers::authorization::Credentials;
|
||||||
use mas_config::ConfigurationSection;
|
use mas_config::{ConfigurationSection, OAuth2ClientAuthMethodConfig};
|
||||||
use mas_jose::{JsonWebSignatureAlgorithm, SigningKeystore};
|
use mas_jose::{JsonWebSignatureAlgorithm, SigningKeystore};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
|
||||||
@ -272,17 +296,21 @@ mod tests {
|
|||||||
let mut config = OAuth2Config::test();
|
let mut config = OAuth2Config::test();
|
||||||
config.clients.push(OAuth2ClientConfig {
|
config.clients.push(OAuth2ClientConfig {
|
||||||
client_id: "public".to_string(),
|
client_id: "public".to_string(),
|
||||||
client_secret: None,
|
client_auth_method: OAuth2ClientAuthMethodConfig::None,
|
||||||
redirect_uris: Vec::new(),
|
redirect_uris: Vec::new(),
|
||||||
});
|
});
|
||||||
config.clients.push(OAuth2ClientConfig {
|
config.clients.push(OAuth2ClientConfig {
|
||||||
client_id: "confidential".to_string(),
|
client_id: "secret-basic".to_string(),
|
||||||
client_secret: Some(CLIENT_SECRET.to_string()),
|
client_auth_method: OAuth2ClientAuthMethodConfig::ClientSecretBasic {
|
||||||
|
client_secret: CLIENT_SECRET.to_string(),
|
||||||
|
},
|
||||||
redirect_uris: Vec::new(),
|
redirect_uris: Vec::new(),
|
||||||
});
|
});
|
||||||
config.clients.push(OAuth2ClientConfig {
|
config.clients.push(OAuth2ClientConfig {
|
||||||
client_id: "confidential-2".to_string(),
|
client_id: "secret-post".to_string(),
|
||||||
client_secret: Some(CLIENT_SECRET.to_string()),
|
client_auth_method: OAuth2ClientAuthMethodConfig::ClientSecretPost {
|
||||||
|
client_secret: CLIENT_SECRET.to_string(),
|
||||||
|
},
|
||||||
redirect_uris: Vec::new(),
|
redirect_uris: Vec::new(),
|
||||||
});
|
});
|
||||||
config
|
config
|
||||||
@ -395,7 +423,7 @@ mod tests {
|
|||||||
)
|
)
|
||||||
.body(
|
.body(
|
||||||
serde_urlencoded::to_string(json!({
|
serde_urlencoded::to_string(json!({
|
||||||
"client_id": "confidential",
|
"client_id": "secret-post",
|
||||||
"client_secret": CLIENT_SECRET,
|
"client_secret": CLIENT_SECRET,
|
||||||
"foo": "baz",
|
"foo": "baz",
|
||||||
"bar": "foobar",
|
"bar": "foobar",
|
||||||
@ -407,7 +435,7 @@ mod tests {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(auth, ClientAuthenticationMethod::ClientSecretPost);
|
assert_eq!(auth, ClientAuthenticationMethod::ClientSecretPost);
|
||||||
assert_eq!(client.client_id, "confidential");
|
assert_eq!(client.client_id, "secret-post");
|
||||||
assert_eq!(body.foo, "baz");
|
assert_eq!(body.foo, "baz");
|
||||||
assert_eq!(body.bar, "foobar");
|
assert_eq!(body.bar, "foobar");
|
||||||
}
|
}
|
||||||
@ -419,7 +447,7 @@ mod tests {
|
|||||||
"https://example.com/token".to_string(),
|
"https://example.com/token".to_string(),
|
||||||
);
|
);
|
||||||
|
|
||||||
let auth = Authorization::basic("confidential", CLIENT_SECRET);
|
let auth = Authorization::basic("secret-basic", CLIENT_SECRET);
|
||||||
let (auth, client, body) = warp::test::request()
|
let (auth, client, body) = warp::test::request()
|
||||||
.method("POST")
|
.method("POST")
|
||||||
.header(
|
.header(
|
||||||
@ -439,7 +467,7 @@ mod tests {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(auth, ClientAuthenticationMethod::ClientSecretBasic);
|
assert_eq!(auth, ClientAuthenticationMethod::ClientSecretBasic);
|
||||||
assert_eq!(client.client_id, "confidential");
|
assert_eq!(client.client_id, "secret-basic");
|
||||||
assert_eq!(body.foo, "baz");
|
assert_eq!(body.foo, "baz");
|
||||||
assert_eq!(body.bar, "foobar");
|
assert_eq!(body.bar, "foobar");
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user