diff --git a/crates/axum-utils/src/client_authorization.rs b/crates/axum-utils/src/client_authorization.rs index dedf81cf..94abc9f2 100644 --- a/crates/axum-utils/src/client_authorization.rs +++ b/crates/axum-utils/src/client_authorization.rs @@ -30,10 +30,7 @@ use mas_config::Encrypter; use mas_data_model::{Client, JwksOrJwksUri, StorageBackend}; use mas_http::HttpServiceExt; use mas_iana::oauth::OAuthClientAuthenticationMethod; -use mas_jose::{ - jwk::PublicJsonWebKeySet, DecodedJsonWebToken, DynamicJwksStore, Either, - JsonWebSignatureHeader, JsonWebTokenParts, SharedSecret, StaticJwksStore, VerifyingKeystore, -}; +use mas_jose::{jwk::PublicJsonWebKeySet, Jwt}; use mas_storage::{ oauth2::client::{lookup_client_by_client_id, ClientFetchError}, PostgresqlBackend, @@ -72,9 +69,7 @@ pub enum Credentials { }, ClientAssertionJwtBearer { client_id: String, - jwt: JsonWebTokenParts, - header: Box, - claims: HashMap, + jwt: Box>>, }, } @@ -128,7 +123,7 @@ impl Credentials { } ( - Credentials::ClientAssertionJwtBearer { jwt, header, .. }, + Credentials::ClientAssertionJwtBearer { jwt, .. }, OAuthClientAuthenticationMethod::PrivateKeyJwt, ) => { // Get the client JWKS @@ -137,14 +132,16 @@ impl Credentials { .as_ref() .ok_or(CredentialsVerificationError::InvalidClientConfig)?; - let store: Either = jwks_key_store(jwks); - let fut = jwt.verify(header, &store); - fut.await + let jwks = fetch_jwks(jwks) + .await + .map_err(|_| CredentialsVerificationError::JwksFetchFailed)?; + + jwt.verify_from_jwks(&jwks) .map_err(|_| CredentialsVerificationError::InvalidAssertionSignature)?; } ( - Credentials::ClientAssertionJwtBearer { jwt, header, .. }, + Credentials::ClientAssertionJwtBearer { jwt, .. }, OAuthClientAuthenticationMethod::ClientSecretJwt, ) => { // Decrypt the client_secret @@ -157,9 +154,7 @@ impl Credentials { .decrypt_string(encrypted_client_secret) .map_err(|_e| CredentialsVerificationError::DecryptionError)?; - let store = SharedSecret::new(&decrypted_client_secret); - let fut = jwt.verify(header, &store); - fut.await + jwt.verify_from_shared_secret(decrypted_client_secret) .map_err(|_| CredentialsVerificationError::InvalidAssertionSignature)?; } @@ -171,38 +166,25 @@ impl Credentials { } } -fn jwks_key_store(jwks: &JwksOrJwksUri) -> Either { - // Assert that the output is both a VerifyingKeystore and Send - fn assert(t: T) -> T { - t - } - - let inner = match jwks { - JwksOrJwksUri::Jwks(jwks) => Either::Left(StaticJwksStore::new(jwks.clone())), - JwksOrJwksUri::JwksUri(uri) => { - let uri = uri.clone(); - - // TODO: get the client from somewhere else? - let exporter = mas_http::client("fetch-jwks") - .response_body_to_bytes() - .json_response::() - .map_request(move |_: ()| { - http::Request::builder() - .method("GET") - // TODO: change the Uri type in config to avoid reparsing here - .uri(uri.to_string()) - .body(http_body::Empty::new()) - .unwrap() - }) - .map_response(http::Response::into_body) - .map_err(BoxError::from) - .boxed_clone(); - - Either::Right(DynamicJwksStore::new(exporter)) - } +async fn fetch_jwks(jwks: &JwksOrJwksUri) -> Result { + let uri = match jwks { + JwksOrJwksUri::Jwks(j) => return Ok(j.clone()), + JwksOrJwksUri::JwksUri(u) => u, }; - assert(inner) + let request = http::Request::builder() + .uri(uri.as_str()) + .body(http_body::Empty::new()) + .unwrap(); + + let client = mas_http::client("fetch-jwks") + .response_body_to_bytes() + .json_response::() + .map_err(Box::new); + + let response = client.oneshot(request).await?; + + Ok(response.into_body()) } #[derive(Debug, Error)] @@ -221,6 +203,9 @@ pub enum CredentialsVerificationError { #[error("invalid assertion signature")] InvalidAssertionSignature, + + #[error("failed to fetch jwks")] + JwksFetchFailed, } #[derive(Debug, PartialEq, Eq)] @@ -344,16 +329,10 @@ where Some(client_assertion), ) if client_assertion_type == JWT_BEARER_CLIENT_ASSERTION => { // Got a JWT bearer client_assertion - - let jwt: JsonWebTokenParts = client_assertion - .parse() + let jwt: Jwt<'static, HashMap> = Jwt::try_from(client_assertion) .map_err(|_| ClientAuthorizationError::InvalidAssertion)?; - let decoded: DecodedJsonWebToken> = jwt - .decode() - .map_err(|_| ClientAuthorizationError::InvalidAssertion)?; - let (header, claims) = decoded.split(); - let client_id = if let Some(Value::String(client_id)) = claims.get("sub") { + let client_id = if let Some(Value::String(client_id)) = jwt.payload().get("sub") { client_id.clone() } else { return Err(ClientAuthorizationError::InvalidAssertion); @@ -371,9 +350,7 @@ where Credentials::ClientAssertionJwtBearer { client_id, - jwt, - header: Box::new(header), - claims, + jwt: Box::new(jwt), } } @@ -585,19 +562,15 @@ mod tests { .unwrap(); assert_eq!(authz.form, Some(serde_json::json!({"foo": "bar"}))); - let (client_id, _jwt, _header, _claims) = if let Credentials::ClientAssertionJwtBearer { - client_id, - jwt, - header, - claims, - } = authz.credentials - { - (client_id, jwt, header, claims) - } else { - panic!("expected a JWT client_assertion"); - }; + let (client_id, jwt) = + if let Credentials::ClientAssertionJwtBearer { client_id, jwt } = authz.credentials { + (client_id, jwt) + } else { + panic!("expected a JWT client_assertion"); + }; assert_eq!(client_id, "client-id"); - // TODO: test more things + jwt.verify_from_shared_secret(b"client-secret".to_vec()) + .unwrap(); } } diff --git a/crates/handlers/src/oauth2/keys.rs b/crates/handlers/src/oauth2/keys.rs index b9d50cea..a883eddc 100644 --- a/crates/handlers/src/oauth2/keys.rs +++ b/crates/handlers/src/oauth2/keys.rs @@ -16,12 +16,10 @@ use std::{convert::Infallible, sync::Arc}; use axum::{extract::Extension, response::IntoResponse, Json}; use mas_jose::StaticKeystore; -use tower::{Service, ServiceExt}; pub(crate) async fn get( Extension(key_store): Extension>, ) -> Result { - let mut key_store: &StaticKeystore = key_store.as_ref(); - let jwks = key_store.ready().await?.call(()).await?; + let jwks = key_store.to_public_jwks(); Ok(Json(jwks)) } diff --git a/crates/jose/src/jwk/mod.rs b/crates/jose/src/jwk/mod.rs index dfb59b95..bdcc89c3 100644 --- a/crates/jose/src/jwk/mod.rs +++ b/crates/jose/src/jwk/mod.rs @@ -260,7 +260,8 @@ mod tests { let jwks: PublicJsonWebKeySet = serde_json::from_value(jwks).unwrap(); // Both keys are RSA public keys for jwk in &jwks.keys { - rsa::RsaPublicKey::try_from(jwk.parameters.clone()).unwrap(); + let p = jwk.params().rsa().expect("an RSA key"); + rsa::RsaPublicKey::try_from(p).unwrap(); } let constraints = ConstraintSet::default() @@ -394,12 +395,23 @@ mod tests { let jwks: PublicJsonWebKeySet = serde_json::from_value(jwks).unwrap(); // The first 6 keys are RSA, 7th is P-256 let mut keys = jwks.keys.into_iter(); - rsa::RsaPublicKey::try_from(keys.next().unwrap().parameters).unwrap(); - rsa::RsaPublicKey::try_from(keys.next().unwrap().parameters).unwrap(); - rsa::RsaPublicKey::try_from(keys.next().unwrap().parameters).unwrap(); - rsa::RsaPublicKey::try_from(keys.next().unwrap().parameters).unwrap(); - rsa::RsaPublicKey::try_from(keys.next().unwrap().parameters).unwrap(); - rsa::RsaPublicKey::try_from(keys.next().unwrap().parameters).unwrap(); - ecdsa::VerifyingKey::try_from(keys.next().unwrap().parameters).unwrap(); + rsa::RsaPublicKey::try_from(keys.next().unwrap().params().rsa().unwrap()).unwrap(); + rsa::RsaPublicKey::try_from(keys.next().unwrap().params().rsa().unwrap()).unwrap(); + rsa::RsaPublicKey::try_from(keys.next().unwrap().params().rsa().unwrap()).unwrap(); + rsa::RsaPublicKey::try_from(keys.next().unwrap().params().rsa().unwrap()).unwrap(); + rsa::RsaPublicKey::try_from(keys.next().unwrap().params().rsa().unwrap()).unwrap(); + rsa::RsaPublicKey::try_from(keys.next().unwrap().params().rsa().unwrap()).unwrap(); + // 7th is P-256 + ecdsa::VerifyingKey::::try_from( + keys.next().unwrap().params().ec().unwrap(), + ) + .unwrap(); + // 8th is P-384 + ecdsa::VerifyingKey::::try_from( + keys.next().unwrap().params().ec().unwrap(), + ) + .unwrap(); + // 8th is P-521, but we don't support it yet + keys.next().unwrap().params().ec().unwrap(); } } diff --git a/crates/jose/src/jwk/private_parameters.rs b/crates/jose/src/jwk/private_parameters.rs index 2f182abe..47408e75 100644 --- a/crates/jose/src/jwk/private_parameters.rs +++ b/crates/jose/src/jwk/private_parameters.rs @@ -43,6 +43,40 @@ pub enum JsonWebKeyPrivateParameters { Okp(OkpPrivateParameters), } +impl JsonWebKeyPrivateParameters { + #[must_use] + pub const fn oct(&self) -> Option<&OctPrivateParameters> { + match self { + Self::Oct(params) => Some(params), + _ => None, + } + } + + #[must_use] + pub const fn rsa(&self) -> Option<&RsaPrivateParameters> { + match self { + Self::Rsa(params) => Some(params), + _ => None, + } + } + + #[must_use] + pub const fn ec(&self) -> Option<&EcPrivateParameters> { + match self { + Self::Ec(params) => Some(params), + _ => None, + } + } + + #[must_use] + pub const fn okp(&self) -> Option<&OkpPrivateParameters> { + match self { + Self::Okp(params) => Some(params), + _ => None, + } + } +} + impl ParametersInfo for JsonWebKeyPrivateParameters { fn kty(&self) -> JsonWebKeyType { match self { diff --git a/crates/jose/src/jwk/public_parameters.rs b/crates/jose/src/jwk/public_parameters.rs index 2771ca4d..ab3a875a 100644 --- a/crates/jose/src/jwk/public_parameters.rs +++ b/crates/jose/src/jwk/public_parameters.rs @@ -39,6 +39,32 @@ pub enum JsonWebKeyPublicParameters { Okp(OkpPublicParameters), } +impl JsonWebKeyPublicParameters { + #[must_use] + pub const fn rsa(&self) -> Option<&RsaPublicParameters> { + match self { + Self::Rsa(params) => Some(params), + _ => None, + } + } + + #[must_use] + pub const fn ec(&self) -> Option<&EcPublicParameters> { + match self { + Self::Ec(params) => Some(params), + _ => None, + } + } + + #[must_use] + pub const fn okp(&self) -> Option<&OkpPublicParameters> { + match self { + Self::Okp(params) => Some(params), + _ => None, + } + } +} + impl ParametersInfo for JsonWebKeyPublicParameters { fn kty(&self) -> JsonWebKeyType { match self { @@ -163,11 +189,38 @@ impl OkpPublicParameters { mod rsa_impls { use digest::DynDigest; - use rsa::{BigUint, RsaPublicKey}; + use rsa::{BigUint, PublicKeyParts, RsaPublicKey}; - use super::RsaPublicParameters; + use super::{JsonWebKeyPublicParameters, RsaPublicParameters}; use crate::jwa::rsa::RsaHashIdentifier; + impl From for JsonWebKeyPublicParameters { + fn from(key: RsaPublicKey) -> Self { + Self::from(&key) + } + } + + impl From<&RsaPublicKey> for JsonWebKeyPublicParameters { + fn from(key: &RsaPublicKey) -> Self { + Self::Rsa(key.into()) + } + } + + impl From for RsaPublicParameters { + fn from(key: RsaPublicKey) -> Self { + Self::from(&key) + } + } + + impl From<&RsaPublicKey> for RsaPublicParameters { + fn from(key: &RsaPublicKey) -> Self { + Self { + n: key.n().to_bytes_be(), + e: key.e().to_bytes_be(), + } + } + } + impl TryFrom for crate::jwa::rsa::pkcs1v15::VerifyingKey where H: RsaHashIdentifier, @@ -236,7 +289,56 @@ mod ec_impls { AffinePoint, Curve, FieldBytes, FieldSize, ProjectiveArithmetic, PublicKey, }; - use super::{super::JwkEcCurve, EcPublicParameters}; + use super::{super::JwkEcCurve, EcPublicParameters, JsonWebKeyPublicParameters}; + + impl From> for JsonWebKeyPublicParameters + where + C: PrimeCurve + ProjectiveArithmetic + JwkEcCurve, + AffinePoint: FromEncodedPoint + ToEncodedPoint, + FieldSize: ModulusSize, + { + fn from(key: ecdsa::VerifyingKey) -> Self { + Self::from(&key) + } + } + + impl From<&ecdsa::VerifyingKey> for JsonWebKeyPublicParameters + where + C: PrimeCurve + ProjectiveArithmetic + JwkEcCurve, + AffinePoint: FromEncodedPoint + ToEncodedPoint, + FieldSize: ModulusSize, + { + fn from(key: &ecdsa::VerifyingKey) -> Self { + Self::Ec(key.into()) + } + } + + impl From> for EcPublicParameters + where + C: PrimeCurve + ProjectiveArithmetic + JwkEcCurve, + AffinePoint: FromEncodedPoint + ToEncodedPoint, + FieldSize: ModulusSize, + { + fn from(key: ecdsa::VerifyingKey) -> Self { + Self::from(&key) + } + } + + impl From<&ecdsa::VerifyingKey> for EcPublicParameters + where + C: PrimeCurve + ProjectiveArithmetic + JwkEcCurve, + AffinePoint: FromEncodedPoint + ToEncodedPoint, + FieldSize: ModulusSize, + { + fn from(key: &ecdsa::VerifyingKey) -> Self { + let points = key.to_encoded_point(false); + EcPublicParameters { + x: points.x().unwrap().to_vec(), + y: points.y().unwrap().to_vec(), + crv: C::CRV, + } + } + } impl TryFrom for VerifyingKey where @@ -311,71 +413,3 @@ mod ec_impls { } } } - -/// Some legacy implementations to remove -mod legacy { - use anyhow::bail; - use mas_iana::jose::JsonWebKeyEcEllipticCurve; - use p256::NistP256; - use rsa::{BigUint, PublicKeyParts}; - - use super::{EcPublicParameters, JsonWebKeyPublicParameters, RsaPublicParameters}; - - impl TryFrom for ecdsa::VerifyingKey { - type Error = anyhow::Error; - - fn try_from(params: JsonWebKeyPublicParameters) -> Result { - let (x, y): ([u8; 32], [u8; 32]) = match params { - JsonWebKeyPublicParameters::Ec(EcPublicParameters { - x, - y, - crv: JsonWebKeyEcEllipticCurve::P256, - }) => ( - x.try_into() - .map_err(|_| anyhow::anyhow!("invalid curve parameter x"))?, - y.try_into() - .map_err(|_| anyhow::anyhow!("invalid curve parameter y"))?, - ), - _ => bail!("Wrong curve"), - }; - - let point = sec1::EncodedPoint::from_affine_coordinates(&x.into(), &y.into(), false); - let key = ecdsa::VerifyingKey::from_encoded_point(&point)?; - Ok(key) - } - } - - impl From> for JsonWebKeyPublicParameters { - fn from(key: ecdsa::VerifyingKey) -> Self { - let points = key.to_encoded_point(false); - JsonWebKeyPublicParameters::Ec(EcPublicParameters { - x: points.x().unwrap().to_vec(), - y: points.y().unwrap().to_vec(), - crv: JsonWebKeyEcEllipticCurve::P256, - }) - } - } - - impl TryFrom for rsa::RsaPublicKey { - type Error = anyhow::Error; - - fn try_from(params: JsonWebKeyPublicParameters) -> Result { - let (n, e) = match ¶ms { - JsonWebKeyPublicParameters::Rsa(RsaPublicParameters { n, e }) => (n, e), - _ => bail!("Wrong key type"), - }; - let n = BigUint::from_bytes_be(n); - let e = BigUint::from_bytes_be(e); - Ok(rsa::RsaPublicKey::new(n, e)?) - } - } - - impl From for JsonWebKeyPublicParameters { - fn from(key: rsa::RsaPublicKey) -> Self { - JsonWebKeyPublicParameters::Rsa(RsaPublicParameters { - n: key.n().to_bytes_be(), - e: key.e().to_bytes_be(), - }) - } - } -} diff --git a/crates/jose/src/jwt/mod.rs b/crates/jose/src/jwt/mod.rs index 3d2459cc..7d4abbaf 100644 --- a/crates/jose/src/jwt/mod.rs +++ b/crates/jose/src/jwt/mod.rs @@ -190,32 +190,3 @@ impl JsonWebTokenParts { format!("{}.{}", payload, signature) } } - -#[cfg(test)] -mod tests { - use mas_iana::jose::JsonWebSignatureAlg; - - use super::*; - use crate::SharedSecret; - - #[tokio::test] - async fn decode_hs256() { - let jwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c"; - let jwt: JsonWebTokenParts = jwt.parse().unwrap(); - let secret = "your-256-bit-secret"; - let store = SharedSecret::new(&secret); - let jwt: DecodedJsonWebToken = - jwt.decode_and_verify(&store).await.unwrap(); - - assert_eq!(jwt.header.typ(), Some("JWT")); - assert_eq!(jwt.header.alg(), JsonWebSignatureAlg::Hs256); - assert_eq!( - jwt.payload, - serde_json::json!({ - "sub": "1234567890", - "name": "John Doe", - "iat": 1_516_239_022 - }) - ); - } -} diff --git a/crates/jose/src/jwt/raw.rs b/crates/jose/src/jwt/raw.rs index 1c29a56e..45289a64 100644 --- a/crates/jose/src/jwt/raw.rs +++ b/crates/jose/src/jwt/raw.rs @@ -16,6 +16,7 @@ use std::{borrow::Cow, ops::Deref}; use thiserror::Error; +#[derive(Clone, PartialEq, Eq)] pub struct RawJwt<'a> { inner: Cow<'a, str>, first_dot: usize, @@ -54,6 +55,14 @@ impl<'a> RawJwt<'a> { pub fn signed_part(&'a self) -> &'a str { &self.inner[..self.second_dot] } + + pub fn into_owned(self) -> RawJwt<'static> { + RawJwt { + inner: self.inner.into_owned().into(), + first_dot: self.first_dot, + second_dot: self.second_dot, + } + } } impl<'a> Deref for RawJwt<'a> { @@ -97,3 +106,25 @@ impl<'a> TryFrom<&'a str> for RawJwt<'a> { }) } } + +impl TryFrom for RawJwt<'static> { + type Error = DecodeError; + fn try_from(value: String) -> Result { + let mut indices = value + .char_indices() + .filter_map(|(idx, c)| (c == '.').then(|| idx)); + + let first_dot = indices.next().ok_or(DecodeError::NoDots)?; + let second_dot = indices.next().ok_or(DecodeError::OnlyOneDot)?; + + if indices.next().is_some() { + return Err(DecodeError::TooManyDots); + } + + Ok(Self { + inner: value.into(), + first_dot, + second_dot, + }) + } +} diff --git a/crates/jose/src/jwt/signed.rs b/crates/jose/src/jwt/signed.rs index 97b50a6d..f9fc788a 100644 --- a/crates/jose/src/jwt/signed.rs +++ b/crates/jose/src/jwt/signed.rs @@ -18,7 +18,9 @@ use signature::{Signature, Signer, Verifier}; use thiserror::Error; use super::{header::JsonWebSignatureHeader, raw::RawJwt}; +use crate::{constraints::ConstraintSet, jwk::PublicJsonWebKeySet}; +#[derive(Clone, PartialEq, Eq)] pub struct Jwt<'a, T> { raw: RawJwt<'a>, header: JsonWebSignatureHeader, @@ -107,14 +109,12 @@ impl JwtDecodeError { } } -impl<'a, T> TryFrom<&'a str> for Jwt<'a, T> +impl<'a, T> TryFrom> for Jwt<'a, T> where T: DeserializeOwned, { type Error = JwtDecodeError; - fn try_from(value: &'a str) -> Result { - let raw = RawJwt::try_from(value)?; - + fn try_from(raw: RawJwt<'a>) -> Result { let header_reader = base64ct::Decoder::<'_, Base64UrlUnpadded>::new(raw.header().as_bytes()) .map_err(JwtDecodeError::decode_header)?; @@ -139,6 +139,28 @@ where } } +impl<'a, T> TryFrom<&'a str> for Jwt<'a, T> +where + T: DeserializeOwned, +{ + type Error = JwtDecodeError; + fn try_from(value: &'a str) -> Result { + let raw = RawJwt::try_from(value)?; + Self::try_from(raw) + } +} + +impl TryFrom for Jwt<'static, T> +where + T: DeserializeOwned, +{ + type Error = JwtDecodeError; + fn try_from(value: String) -> Result { + let raw = RawJwt::try_from(value)?; + Self::try_from(raw) + } +} + #[derive(Debug, Error)] pub enum JwtVerificationError { #[error("failed to parse signature")] @@ -164,6 +186,12 @@ impl JwtVerificationError { } } +#[derive(Debug, Error, Default)] +#[error("none of the keys worked")] +pub struct NoKeyWorked { + _inner: (), +} + impl<'a, T> Jwt<'a, T> { pub fn header(&self) -> &JsonWebSignatureHeader { &self.header @@ -173,6 +201,15 @@ impl<'a, T> Jwt<'a, T> { &self.payload } + pub fn into_owned(self) -> Jwt<'static, T> { + Jwt { + raw: self.raw.into_owned(), + header: self.header, + payload: self.payload, + signature: self.signature, + } + } + pub fn verify(&self, key: &K) -> Result<(), JwtVerificationError> where K: Verifier, @@ -185,6 +222,37 @@ impl<'a, T> Jwt<'a, T> { .map_err(JwtVerificationError::verify) } + pub fn verify_from_shared_secret(&self, secret: Vec) -> Result<(), NoKeyWorked> { + let verifier = crate::verifier::Verifier::for_oct_and_alg(secret, self.header().alg()) + .map_err(|_| NoKeyWorked::default())?; + + self.verify(&verifier).map_err(|_| NoKeyWorked::default())?; + + Ok(()) + } + + pub fn verify_from_jwks(&self, jwks: &PublicJsonWebKeySet) -> Result<(), NoKeyWorked> { + let constraints = ConstraintSet::from(self.header()); + let candidates = constraints.filter(&**jwks); + + for candidate in candidates { + let verifier = match crate::verifier::Verifier::for_jwk_and_alg( + candidate.params(), + self.header().alg(), + ) { + Ok(v) => v, + Err(_) => continue, + }; + + match self.verify(&verifier) { + Ok(_) => return Ok(()), + Err(_) => continue, + } + } + + Err(NoKeyWorked::default()) + } + pub fn as_str(&'a self) -> &'a str { &self.raw } diff --git a/crates/jose/src/keystore/jwks/dynamic_store.rs b/crates/jose/src/keystore/jwks/dynamic_store.rs deleted file mode 100644 index 1448c583..00000000 --- a/crates/jose/src/keystore/jwks/dynamic_store.rs +++ /dev/null @@ -1,165 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::sync::Arc; - -use chrono::{DateTime, Duration, Utc}; -use futures_util::future::BoxFuture; -use thiserror::Error; -use tokio::sync::RwLock; -use tower::{ - util::{BoxCloneService, ServiceExt}, - BoxError, Service, -}; - -use super::StaticJwksStore; -use crate::{jwk::PublicJsonWebKeySet, JsonWebSignatureHeader, VerifyingKeystore}; - -#[derive(Debug, Error)] -pub enum Error { - #[error("cache in inconsistent state")] - InconsistentCache, - - #[error(transparent)] - Cached(Arc), - - #[error("todo")] - Todo, - - #[error(transparent)] - Verification(#[from] super::static_store::Error), -} - -enum State { - Pending, - Errored { - at: DateTime, - error: E, - }, - Fulfilled { - at: DateTime, - store: StaticJwksStore, - }, -} - -impl Default for State { - fn default() -> Self { - Self::Pending - } -} - -impl State { - fn fullfill(&mut self, key_set: PublicJsonWebKeySet) { - *self = Self::Fulfilled { - at: Utc::now(), - store: StaticJwksStore::new(key_set), - } - } - - fn error(&mut self, error: E) { - *self = Self::Errored { - at: Utc::now(), - error, - } - } - - fn should_refresh(&self) -> bool { - let now = Utc::now(); - match self { - Self::Pending => true, - Self::Errored { at, .. } if *at - now > Duration::minutes(5) => true, - Self::Fulfilled { at, .. } if *at - now > Duration::hours(1) => true, - _ => false, - } - } - - fn should_force_refresh(&self) -> bool { - let now = Utc::now(); - match self { - Self::Pending => true, - Self::Errored { at, .. } | Self::Fulfilled { at, .. } - if *at - now > Duration::minutes(5) => - { - true - } - _ => false, - } - } -} - -#[derive(Clone)] -pub struct DynamicJwksStore { - exporter: BoxCloneService<(), PublicJsonWebKeySet, BoxError>, - cache: Arc>>>, -} - -impl DynamicJwksStore { - pub fn new(exporter: T) -> Self - where - T: Service<(), Response = PublicJsonWebKeySet, Error = BoxError> + Send + Clone + 'static, - T::Future: Send, - { - Self { - exporter: exporter.boxed_clone(), - cache: Arc::default(), - } - } -} - -impl VerifyingKeystore for DynamicJwksStore { - type Error = Error; - type Future = BoxFuture<'static, Result<(), Self::Error>>; - - fn verify( - &self, - header: &JsonWebSignatureHeader, - payload: &[u8], - signature: &[u8], - ) -> Self::Future { - let cache = self.cache.clone(); - let exporter = self.exporter.clone(); - let header = header.clone(); - let payload = payload.to_owned(); - let signature = signature.to_owned(); - - let fut = async move { - if cache.read().await.should_refresh() { - let mut cache = cache.write().await; - - if cache.should_force_refresh() { - let jwks = async move { exporter.ready_oneshot().await?.call(()).await }.await; - - match jwks { - Ok(jwks) => cache.fullfill(jwks), - Err(err) => cache.error(Arc::new(err)), - } - } - } - - let cache = cache.read().await; - // TODO: we could bubble up the underlying error here - let store = match &*cache { - State::Pending => return Err(Error::InconsistentCache), - State::Errored { error, .. } => return Err(Error::Cached(error.clone())), - State::Fulfilled { store, .. } => store, - }; - - store.verify(&header, &payload, &signature).await?; - - Ok(()) - }; - - Box::pin(fut) - } -} diff --git a/crates/jose/src/keystore/jwks/mod.rs b/crates/jose/src/keystore/jwks/mod.rs deleted file mode 100644 index 93849d86..00000000 --- a/crates/jose/src/keystore/jwks/mod.rs +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -mod dynamic_store; -mod static_store; - -pub use self::{dynamic_store::DynamicJwksStore, static_store::StaticJwksStore}; diff --git a/crates/jose/src/keystore/jwks/static_store.rs b/crates/jose/src/keystore/jwks/static_store.rs deleted file mode 100644 index 607cbe36..00000000 --- a/crates/jose/src/keystore/jwks/static_store.rs +++ /dev/null @@ -1,245 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::future::Ready; - -use digest::Digest; -use mas_iana::jose::{JsonWebKeyType, JsonWebSignatureAlg}; -use rsa::{PublicKey, RsaPublicKey}; -use sha2::{Sha256, Sha384, Sha512}; -use signature::{Signature, Verifier}; -use thiserror::Error; - -use crate::{ - constraints::Constrainable, - jwk::{PublicJsonWebKey, PublicJsonWebKeySet}, - JsonWebSignatureHeader, VerifyingKeystore, -}; - -#[derive(Debug, Error)] -pub enum Error { - #[error("key not found")] - KeyNotFound, - - #[error("multiple key matched")] - MultipleKeyMatched, - - #[error(r#"missing "kid" field in header"#)] - MissingKid, - - #[error(transparent)] - Rsa(#[from] rsa::errors::Error), - - #[error("unsupported algorithm {alg}")] - UnsupportedAlgorithm { alg: JsonWebSignatureAlg }, - - #[error(transparent)] - Signature(#[from] signature::Error), - - #[error("invalid {kty} key")] - InvalidKey { - kty: JsonWebKeyType, - source: anyhow::Error, - }, -} - -struct KeyConstraint<'a> { - kty: Option, - alg: Option, - kid: Option<&'a str>, -} - -impl<'a> KeyConstraint<'a> { - fn matches(&self, key: &'a PublicJsonWebKey) -> bool { - // If a specific KID was asked, match the key only if it has a matching kid - // field - if let Some(kid) = self.kid { - if key.kid() != Some(kid) { - return false; - } - } - - if let Some(kty) = self.kty { - if key.kty() != kty { - return false; - } - } - - if let Some(alg) = self.alg { - if key.alg() != None && key.alg() != Some(alg) { - return false; - } - } - - true - } - - fn find_keys(&self, key_set: &'a PublicJsonWebKeySet) -> Vec<&'a PublicJsonWebKey> { - key_set.iter().filter(|k| self.matches(k)).collect() - } -} - -pub struct StaticJwksStore { - key_set: PublicJsonWebKeySet, -} - -impl StaticJwksStore { - #[must_use] - pub fn new(key_set: PublicJsonWebKeySet) -> Self { - Self { key_set } - } - - fn find_key<'a>( - &'a self, - constraint: &KeyConstraint<'a>, - ) -> Result<&'a PublicJsonWebKey, Error> { - let keys = constraint.find_keys(&self.key_set); - - match &keys[..] { - [one] => Ok(one), - [] => Err(Error::KeyNotFound), - _ => Err(Error::MultipleKeyMatched), - } - } - - fn find_rsa_key(&self, kid: Option<&str>) -> Result { - let constraint = KeyConstraint { - kty: Some(JsonWebKeyType::Rsa), - kid, - alg: None, - }; - - let key = self.find_key(&constraint)?; - - let key = key - .params() - .clone() - .try_into() - .map_err(|source| Error::InvalidKey { - kty: JsonWebKeyType::Rsa, - source, - })?; - - Ok(key) - } - - fn find_ecdsa_key( - &self, - kid: Option<&str>, - ) -> Result, Error> { - let constraint = KeyConstraint { - kty: Some(JsonWebKeyType::Ec), - kid, - alg: None, - }; - - let key = self.find_key(&constraint)?; - - let key = key - .params() - .clone() - .try_into() - .map_err(|source| Error::InvalidKey { - kty: JsonWebKeyType::Ec, - source, - })?; - - Ok(key) - } - - #[tracing::instrument(skip(self))] - fn verify_sync( - &self, - header: &JsonWebSignatureHeader, - payload: &[u8], - signature: &[u8], - ) -> Result<(), Error> { - let kid = header.kid(); - match header.alg() { - JsonWebSignatureAlg::Rs256 => { - let key = self.find_rsa_key(kid)?; - - let digest = { - let mut digest = Sha256::new(); - digest.update(&payload); - digest.finalize() - }; - - key.verify( - rsa::PaddingScheme::new_pkcs1v15_sign(Some(rsa::Hash::SHA2_256)), - &digest, - signature, - )?; - } - - JsonWebSignatureAlg::Rs384 => { - let key = self.find_rsa_key(kid)?; - - let digest = { - let mut digest = Sha384::new(); - digest.update(&payload); - digest.finalize() - }; - - key.verify( - rsa::PaddingScheme::new_pkcs1v15_sign(Some(rsa::Hash::SHA2_384)), - &digest, - signature, - )?; - } - - JsonWebSignatureAlg::Rs512 => { - let key = self.find_rsa_key(kid)?; - - let digest = { - let mut digest = Sha512::new(); - digest.update(&payload); - digest.finalize() - }; - - key.verify( - rsa::PaddingScheme::new_pkcs1v15_sign(Some(rsa::Hash::SHA2_512)), - &digest, - signature, - )?; - } - - JsonWebSignatureAlg::Es256 => { - let key = self.find_ecdsa_key(kid)?; - - let signature = ecdsa::Signature::from_bytes(signature)?; - - key.verify(payload, &signature)?; - } - - alg => return Err(Error::UnsupportedAlgorithm { alg }), - }; - - Ok(()) - } -} - -impl VerifyingKeystore for StaticJwksStore { - type Error = Error; - type Future = Ready>; - - fn verify( - &self, - header: &JsonWebSignatureHeader, - payload: &[u8], - signature: &[u8], - ) -> Self::Future { - std::future::ready(self.verify_sync(header, payload, signature)) - } -} diff --git a/crates/jose/src/keystore/mod.rs b/crates/jose/src/keystore/mod.rs index d030af50..21a8c326 100644 --- a/crates/jose/src/keystore/mod.rs +++ b/crates/jose/src/keystore/mod.rs @@ -12,14 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -mod jwks; -mod shared_secret; mod static_keystore; mod traits; pub use self::{ - jwks::{DynamicJwksStore, StaticJwksStore}, - shared_secret::SharedSecret, static_keystore::StaticKeystore, traits::{SigningKeystore, VerifyingKeystore}, }; diff --git a/crates/jose/src/keystore/shared_secret.rs b/crates/jose/src/keystore/shared_secret.rs deleted file mode 100644 index 8b0023ef..00000000 --- a/crates/jose/src/keystore/shared_secret.rs +++ /dev/null @@ -1,172 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::{collections::HashSet, future::Ready}; - -use anyhow::bail; -use async_trait::async_trait; -use digest::{InvalidLength, MacError}; -use hmac::{Hmac, Mac}; -use mas_iana::jose::JsonWebSignatureAlg; -use sha2::{Sha256, Sha384, Sha512}; -use thiserror::Error; - -use super::{SigningKeystore, VerifyingKeystore}; -use crate::JsonWebSignatureHeader; - -#[derive(Debug, Error)] -pub enum Error { - #[error("invalid key")] - InvalidKey(#[from] InvalidLength), - - #[error("unsupported algorithm {alg}")] - UnsupportedAlgorithm { alg: JsonWebSignatureAlg }, - - #[error("signature verification failed")] - Verification(#[from] MacError), -} - -pub struct SharedSecret<'a> { - inner: &'a [u8], -} - -impl<'a> SharedSecret<'a> { - pub fn new(source: &'a impl AsRef<[u8]>) -> Self { - Self { - inner: source.as_ref(), - } - } - - fn verify_sync( - &self, - header: &JsonWebSignatureHeader, - payload: &[u8], - signature: &[u8], - ) -> Result<(), Error> { - match header.alg() { - JsonWebSignatureAlg::Hs256 => { - let mut mac = Hmac::::new_from_slice(self.inner)?; - mac.update(payload); - mac.verify(signature.into())?; - } - - JsonWebSignatureAlg::Hs384 => { - let mut mac = Hmac::::new_from_slice(self.inner)?; - mac.update(payload); - mac.verify(signature.into())?; - } - - JsonWebSignatureAlg::Hs512 => { - let mut mac = Hmac::::new_from_slice(self.inner)?; - mac.update(payload); - mac.verify(signature.into())?; - } - - alg => return Err(Error::UnsupportedAlgorithm { alg }), - }; - - Ok(()) - } -} - -#[async_trait] -impl<'a> SigningKeystore for SharedSecret<'a> { - fn supported_algorithms(&self) -> HashSet { - let mut algorithms = HashSet::with_capacity(3); - - algorithms.insert(JsonWebSignatureAlg::Hs256); - algorithms.insert(JsonWebSignatureAlg::Hs384); - algorithms.insert(JsonWebSignatureAlg::Hs512); - - algorithms - } - - async fn prepare_header( - &self, - alg: JsonWebSignatureAlg, - ) -> anyhow::Result { - if !matches!( - alg, - JsonWebSignatureAlg::Hs256 | JsonWebSignatureAlg::Hs384 | JsonWebSignatureAlg::Hs512, - ) { - bail!("unsupported algorithm") - } - - Ok(JsonWebSignatureHeader::new(alg)) - } - - async fn sign(&self, header: &JsonWebSignatureHeader, msg: &[u8]) -> anyhow::Result> { - // TODO: do the signing in a blocking task - // TODO: should we bail out if the key is too small? - let signature = match header.alg() { - JsonWebSignatureAlg::Hs256 => { - let mut mac = Hmac::::new_from_slice(self.inner)?; - mac.update(msg); - mac.finalize().into_bytes().to_vec() - } - - JsonWebSignatureAlg::Hs384 => { - let mut mac = Hmac::::new_from_slice(self.inner)?; - mac.update(msg); - mac.finalize().into_bytes().to_vec() - } - - JsonWebSignatureAlg::Hs512 => { - let mut mac = Hmac::::new_from_slice(self.inner)?; - mac.update(msg); - mac.finalize().into_bytes().to_vec() - } - - _ => bail!("unsupported algorithm"), - }; - - Ok(signature) - } -} - -impl<'a> VerifyingKeystore for SharedSecret<'a> { - type Error = Error; - type Future = Ready>; - - fn verify( - &self, - header: &JsonWebSignatureHeader, - payload: &[u8], - signature: &[u8], - ) -> Self::Future { - std::future::ready(self.verify_sync(header, payload, signature)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn test_shared_secret() { - let secret = "super-complicated-secret-that-should-be-big-enough-for-sha512"; - let message = "this is the message to sign".as_bytes(); - let store = SharedSecret::new(&secret); - for alg in [ - JsonWebSignatureAlg::Hs256, - JsonWebSignatureAlg::Hs384, - JsonWebSignatureAlg::Hs512, - ] { - let header = store.prepare_header(alg).await.unwrap(); - assert_eq!(header.alg(), alg); - let signature = store.sign(&header, message).await.unwrap(); - store.verify(&header, message, &signature).await.unwrap(); - } - } -} diff --git a/crates/jose/src/keystore/static_keystore.rs b/crates/jose/src/keystore/static_keystore.rs index 09697f63..dc2f6432 100644 --- a/crates/jose/src/keystore/static_keystore.rs +++ b/crates/jose/src/keystore/static_keystore.rs @@ -14,9 +14,7 @@ use std::{ collections::{HashMap, HashSet}, - convert::Infallible, future::Ready, - task::Poll, }; use anyhow::bail; @@ -31,11 +29,10 @@ use pkcs8::{DecodePrivateKey, EncodePublicKey}; use rsa::{PublicKey as _, RsaPrivateKey, RsaPublicKey}; use sha2::{Sha256, Sha384, Sha512}; use signature::{Signature, Signer, Verifier}; -use tower::Service; use super::{SigningKeystore, VerifyingKeystore}; use crate::{ - jwk::{JsonWebKey, JsonWebKeySet, PublicJsonWebKeySet}, + jwk::{JsonWebKey, PublicJsonWebKeySet}, JsonWebSignatureHeader, }; @@ -133,6 +130,27 @@ impl StaticKeystore { Ok(()) } + #[must_use] + pub fn to_public_jwks(&self) -> PublicJsonWebKeySet { + let rsa = self.rsa_keys.iter().map(|(kid, key)| { + let pubkey = RsaPublicKey::from(key); + JsonWebKey::new(pubkey.into()) + .with_kid(kid) + .with_use(JsonWebKeyUse::Sig) + }); + + let es256 = self.es256_keys.iter().map(|(kid, key)| { + let pubkey = ecdsa::VerifyingKey::from(key); + JsonWebKey::new(pubkey.into()) + .with_kid(kid) + .with_use(JsonWebKeyUse::Sig) + .with_alg(JsonWebSignatureAlg::Es256) + }); + + let keys = rsa.chain(es256).collect(); + PublicJsonWebKeySet::new(keys) + } + fn verify_sync( &self, header: &JsonWebSignatureHeader, @@ -366,36 +384,6 @@ impl VerifyingKeystore for StaticKeystore { } } -impl Service<()> for &StaticKeystore { - type Future = Ready>; - type Response = PublicJsonWebKeySet; - type Error = Infallible; - - fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, _req: ()) -> Self::Future { - let rsa = self.rsa_keys.iter().map(|(kid, key)| { - let pubkey = RsaPublicKey::from(key); - JsonWebKey::new(pubkey.into()) - .with_kid(kid) - .with_use(JsonWebKeyUse::Sig) - }); - - let es256 = self.es256_keys.iter().map(|(kid, key)| { - let pubkey = ecdsa::VerifyingKey::from(key); - JsonWebKey::new(pubkey.into()) - .with_kid(kid) - .with_use(JsonWebKeyUse::Sig) - .with_alg(JsonWebSignatureAlg::Es256) - }); - - let keys = rsa.chain(es256).collect(); - std::future::ready(Ok(JsonWebKeySet::new(keys))) - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/crates/jose/src/keystore/traits.rs b/crates/jose/src/keystore/traits.rs index ae678753..34f1bb4c 100644 --- a/crates/jose/src/keystore/traits.rs +++ b/crates/jose/src/keystore/traits.rs @@ -12,15 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{collections::HashSet, future::Future, sync::Arc}; +use std::{collections::HashSet, future::Future}; use async_trait::async_trait; -use futures_util::{ - future::{Either, MapErr}, - TryFutureExt, -}; use mas_iana::jose::JsonWebSignatureAlg; -use thiserror::Error; use crate::JsonWebSignatureHeader; @@ -43,61 +38,3 @@ pub trait VerifyingKeystore { fn verify(&self, header: &JsonWebSignatureHeader, msg: &[u8], signature: &[u8]) -> Self::Future; } - -#[derive(Debug, Error)] -pub enum EitherError { - #[error(transparent)] - Left(A), - #[error(transparent)] - Right(B), -} - -impl VerifyingKeystore for Either -where - L: VerifyingKeystore, - R: VerifyingKeystore, -{ - type Error = EitherError; - - #[allow(clippy::type_complexity)] - type Future = Either< - MapErr Self::Error>, - MapErr Self::Error>, - >; - - fn verify( - &self, - header: &JsonWebSignatureHeader, - msg: &[u8], - signature: &[u8], - ) -> Self::Future { - match self { - Either::Left(left) => Either::Left( - left.verify(header, msg, signature) - .map_err(EitherError::Left), - ), - Either::Right(right) => Either::Right( - right - .verify(header, msg, signature) - .map_err(EitherError::Right), - ), - } - } -} - -impl VerifyingKeystore for Arc -where - T: VerifyingKeystore, -{ - type Error = T::Error; - type Future = T::Future; - - fn verify( - &self, - header: &JsonWebSignatureHeader, - msg: &[u8], - signature: &[u8], - ) -> Self::Future { - self.as_ref().verify(header, msg, signature) - } -} diff --git a/crates/jose/src/lib.rs b/crates/jose/src/lib.rs index 033874a3..8d51434d 100644 --- a/crates/jose/src/lib.rs +++ b/crates/jose/src/lib.rs @@ -30,8 +30,5 @@ pub use futures_util::future::Either; pub use self::{ jwt::{DecodedJsonWebToken, JsonWebSignatureHeader, JsonWebTokenParts, Jwt, JwtSignatureError}, - keystore::{ - DynamicJwksStore, SharedSecret, SigningKeystore, StaticJwksStore, StaticKeystore, - VerifyingKeystore, - }, + keystore::{SigningKeystore, StaticKeystore, VerifyingKeystore}, };