1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-09 04:22:45 +03:00

Get rid of legacy JWKS store

This commit is contained in:
Quentin Gliech
2022-08-29 18:29:22 +02:00
parent 84c793dae0
commit 2c400d4cc1
16 changed files with 328 additions and 889 deletions

View File

@@ -30,10 +30,7 @@ use mas_config::Encrypter;
use mas_data_model::{Client, JwksOrJwksUri, StorageBackend};
use mas_http::HttpServiceExt;
use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_jose::{
jwk::PublicJsonWebKeySet, DecodedJsonWebToken, DynamicJwksStore, Either,
JsonWebSignatureHeader, JsonWebTokenParts, SharedSecret, StaticJwksStore, VerifyingKeystore,
};
use mas_jose::{jwk::PublicJsonWebKeySet, Jwt};
use mas_storage::{
oauth2::client::{lookup_client_by_client_id, ClientFetchError},
PostgresqlBackend,
@@ -72,9 +69,7 @@ pub enum Credentials {
},
ClientAssertionJwtBearer {
client_id: String,
jwt: JsonWebTokenParts,
header: Box<JsonWebSignatureHeader>,
claims: HashMap<String, Value>,
jwt: Box<Jwt<'static, HashMap<String, serde_json::Value>>>,
},
}
@@ -128,7 +123,7 @@ impl Credentials {
}
(
Credentials::ClientAssertionJwtBearer { jwt, header, .. },
Credentials::ClientAssertionJwtBearer { jwt, .. },
OAuthClientAuthenticationMethod::PrivateKeyJwt,
) => {
// Get the client JWKS
@@ -137,14 +132,16 @@ impl Credentials {
.as_ref()
.ok_or(CredentialsVerificationError::InvalidClientConfig)?;
let store: Either<StaticJwksStore, DynamicJwksStore> = jwks_key_store(jwks);
let fut = jwt.verify(header, &store);
fut.await
let jwks = fetch_jwks(jwks)
.await
.map_err(|_| CredentialsVerificationError::JwksFetchFailed)?;
jwt.verify_from_jwks(&jwks)
.map_err(|_| CredentialsVerificationError::InvalidAssertionSignature)?;
}
(
Credentials::ClientAssertionJwtBearer { jwt, header, .. },
Credentials::ClientAssertionJwtBearer { jwt, .. },
OAuthClientAuthenticationMethod::ClientSecretJwt,
) => {
// Decrypt the client_secret
@@ -157,9 +154,7 @@ impl Credentials {
.decrypt_string(encrypted_client_secret)
.map_err(|_e| CredentialsVerificationError::DecryptionError)?;
let store = SharedSecret::new(&decrypted_client_secret);
let fut = jwt.verify(header, &store);
fut.await
jwt.verify_from_shared_secret(decrypted_client_secret)
.map_err(|_| CredentialsVerificationError::InvalidAssertionSignature)?;
}
@@ -171,38 +166,25 @@ impl Credentials {
}
}
fn jwks_key_store(jwks: &JwksOrJwksUri) -> Either<StaticJwksStore, DynamicJwksStore> {
// Assert that the output is both a VerifyingKeystore and Send
fn assert<T: Send + VerifyingKeystore>(t: T) -> T {
t
}
let inner = match jwks {
JwksOrJwksUri::Jwks(jwks) => Either::Left(StaticJwksStore::new(jwks.clone())),
JwksOrJwksUri::JwksUri(uri) => {
let uri = uri.clone();
// TODO: get the client from somewhere else?
let exporter = mas_http::client("fetch-jwks")
.response_body_to_bytes()
.json_response::<PublicJsonWebKeySet>()
.map_request(move |_: ()| {
http::Request::builder()
.method("GET")
// TODO: change the Uri type in config to avoid reparsing here
.uri(uri.to_string())
.body(http_body::Empty::new())
.unwrap()
})
.map_response(http::Response::into_body)
.map_err(BoxError::from)
.boxed_clone();
Either::Right(DynamicJwksStore::new(exporter))
}
async fn fetch_jwks(jwks: &JwksOrJwksUri) -> Result<PublicJsonWebKeySet, BoxError> {
let uri = match jwks {
JwksOrJwksUri::Jwks(j) => return Ok(j.clone()),
JwksOrJwksUri::JwksUri(u) => u,
};
assert(inner)
let request = http::Request::builder()
.uri(uri.as_str())
.body(http_body::Empty::new())
.unwrap();
let client = mas_http::client("fetch-jwks")
.response_body_to_bytes()
.json_response::<PublicJsonWebKeySet>()
.map_err(Box::new);
let response = client.oneshot(request).await?;
Ok(response.into_body())
}
#[derive(Debug, Error)]
@@ -221,6 +203,9 @@ pub enum CredentialsVerificationError {
#[error("invalid assertion signature")]
InvalidAssertionSignature,
#[error("failed to fetch jwks")]
JwksFetchFailed,
}
#[derive(Debug, PartialEq, Eq)]
@@ -344,16 +329,10 @@ where
Some(client_assertion),
) if client_assertion_type == JWT_BEARER_CLIENT_ASSERTION => {
// Got a JWT bearer client_assertion
let jwt: JsonWebTokenParts = client_assertion
.parse()
let jwt: Jwt<'static, HashMap<String, Value>> = Jwt::try_from(client_assertion)
.map_err(|_| ClientAuthorizationError::InvalidAssertion)?;
let decoded: DecodedJsonWebToken<HashMap<String, Value>> = jwt
.decode()
.map_err(|_| ClientAuthorizationError::InvalidAssertion)?;
let (header, claims) = decoded.split();
let client_id = if let Some(Value::String(client_id)) = claims.get("sub") {
let client_id = if let Some(Value::String(client_id)) = jwt.payload().get("sub") {
client_id.clone()
} else {
return Err(ClientAuthorizationError::InvalidAssertion);
@@ -371,9 +350,7 @@ where
Credentials::ClientAssertionJwtBearer {
client_id,
jwt,
header: Box::new(header),
claims,
jwt: Box::new(jwt),
}
}
@@ -585,19 +562,15 @@ mod tests {
.unwrap();
assert_eq!(authz.form, Some(serde_json::json!({"foo": "bar"})));
let (client_id, _jwt, _header, _claims) = if let Credentials::ClientAssertionJwtBearer {
client_id,
jwt,
header,
claims,
} = authz.credentials
{
(client_id, jwt, header, claims)
let (client_id, jwt) =
if let Credentials::ClientAssertionJwtBearer { client_id, jwt } = authz.credentials {
(client_id, jwt)
} else {
panic!("expected a JWT client_assertion");
};
assert_eq!(client_id, "client-id");
// TODO: test more things
jwt.verify_from_shared_secret(b"client-secret".to_vec())
.unwrap();
}
}

View File

@@ -16,12 +16,10 @@ use std::{convert::Infallible, sync::Arc};
use axum::{extract::Extension, response::IntoResponse, Json};
use mas_jose::StaticKeystore;
use tower::{Service, ServiceExt};
pub(crate) async fn get(
Extension(key_store): Extension<Arc<StaticKeystore>>,
) -> Result<impl IntoResponse, Infallible> {
let mut key_store: &StaticKeystore = key_store.as_ref();
let jwks = key_store.ready().await?.call(()).await?;
let jwks = key_store.to_public_jwks();
Ok(Json(jwks))
}

View File

@@ -260,7 +260,8 @@ mod tests {
let jwks: PublicJsonWebKeySet = serde_json::from_value(jwks).unwrap();
// Both keys are RSA public keys
for jwk in &jwks.keys {
rsa::RsaPublicKey::try_from(jwk.parameters.clone()).unwrap();
let p = jwk.params().rsa().expect("an RSA key");
rsa::RsaPublicKey::try_from(p).unwrap();
}
let constraints = ConstraintSet::default()
@@ -394,12 +395,23 @@ mod tests {
let jwks: PublicJsonWebKeySet = serde_json::from_value(jwks).unwrap();
// The first 6 keys are RSA, 7th is P-256
let mut keys = jwks.keys.into_iter();
rsa::RsaPublicKey::try_from(keys.next().unwrap().parameters).unwrap();
rsa::RsaPublicKey::try_from(keys.next().unwrap().parameters).unwrap();
rsa::RsaPublicKey::try_from(keys.next().unwrap().parameters).unwrap();
rsa::RsaPublicKey::try_from(keys.next().unwrap().parameters).unwrap();
rsa::RsaPublicKey::try_from(keys.next().unwrap().parameters).unwrap();
rsa::RsaPublicKey::try_from(keys.next().unwrap().parameters).unwrap();
ecdsa::VerifyingKey::try_from(keys.next().unwrap().parameters).unwrap();
rsa::RsaPublicKey::try_from(keys.next().unwrap().params().rsa().unwrap()).unwrap();
rsa::RsaPublicKey::try_from(keys.next().unwrap().params().rsa().unwrap()).unwrap();
rsa::RsaPublicKey::try_from(keys.next().unwrap().params().rsa().unwrap()).unwrap();
rsa::RsaPublicKey::try_from(keys.next().unwrap().params().rsa().unwrap()).unwrap();
rsa::RsaPublicKey::try_from(keys.next().unwrap().params().rsa().unwrap()).unwrap();
rsa::RsaPublicKey::try_from(keys.next().unwrap().params().rsa().unwrap()).unwrap();
// 7th is P-256
ecdsa::VerifyingKey::<p256::NistP256>::try_from(
keys.next().unwrap().params().ec().unwrap(),
)
.unwrap();
// 8th is P-384
ecdsa::VerifyingKey::<p384::NistP384>::try_from(
keys.next().unwrap().params().ec().unwrap(),
)
.unwrap();
// 8th is P-521, but we don't support it yet
keys.next().unwrap().params().ec().unwrap();
}
}

View File

@@ -43,6 +43,40 @@ pub enum JsonWebKeyPrivateParameters {
Okp(OkpPrivateParameters),
}
impl JsonWebKeyPrivateParameters {
#[must_use]
pub const fn oct(&self) -> Option<&OctPrivateParameters> {
match self {
Self::Oct(params) => Some(params),
_ => None,
}
}
#[must_use]
pub const fn rsa(&self) -> Option<&RsaPrivateParameters> {
match self {
Self::Rsa(params) => Some(params),
_ => None,
}
}
#[must_use]
pub const fn ec(&self) -> Option<&EcPrivateParameters> {
match self {
Self::Ec(params) => Some(params),
_ => None,
}
}
#[must_use]
pub const fn okp(&self) -> Option<&OkpPrivateParameters> {
match self {
Self::Okp(params) => Some(params),
_ => None,
}
}
}
impl ParametersInfo for JsonWebKeyPrivateParameters {
fn kty(&self) -> JsonWebKeyType {
match self {

View File

@@ -39,6 +39,32 @@ pub enum JsonWebKeyPublicParameters {
Okp(OkpPublicParameters),
}
impl JsonWebKeyPublicParameters {
#[must_use]
pub const fn rsa(&self) -> Option<&RsaPublicParameters> {
match self {
Self::Rsa(params) => Some(params),
_ => None,
}
}
#[must_use]
pub const fn ec(&self) -> Option<&EcPublicParameters> {
match self {
Self::Ec(params) => Some(params),
_ => None,
}
}
#[must_use]
pub const fn okp(&self) -> Option<&OkpPublicParameters> {
match self {
Self::Okp(params) => Some(params),
_ => None,
}
}
}
impl ParametersInfo for JsonWebKeyPublicParameters {
fn kty(&self) -> JsonWebKeyType {
match self {
@@ -163,11 +189,38 @@ impl OkpPublicParameters {
mod rsa_impls {
use digest::DynDigest;
use rsa::{BigUint, RsaPublicKey};
use rsa::{BigUint, PublicKeyParts, RsaPublicKey};
use super::RsaPublicParameters;
use super::{JsonWebKeyPublicParameters, RsaPublicParameters};
use crate::jwa::rsa::RsaHashIdentifier;
impl From<RsaPublicKey> for JsonWebKeyPublicParameters {
fn from(key: RsaPublicKey) -> Self {
Self::from(&key)
}
}
impl From<&RsaPublicKey> for JsonWebKeyPublicParameters {
fn from(key: &RsaPublicKey) -> Self {
Self::Rsa(key.into())
}
}
impl From<RsaPublicKey> for RsaPublicParameters {
fn from(key: RsaPublicKey) -> Self {
Self::from(&key)
}
}
impl From<&RsaPublicKey> for RsaPublicParameters {
fn from(key: &RsaPublicKey) -> Self {
Self {
n: key.n().to_bytes_be(),
e: key.e().to_bytes_be(),
}
}
}
impl<H> TryFrom<RsaPublicParameters> for crate::jwa::rsa::pkcs1v15::VerifyingKey<H>
where
H: RsaHashIdentifier,
@@ -236,7 +289,56 @@ mod ec_impls {
AffinePoint, Curve, FieldBytes, FieldSize, ProjectiveArithmetic, PublicKey,
};
use super::{super::JwkEcCurve, EcPublicParameters};
use super::{super::JwkEcCurve, EcPublicParameters, JsonWebKeyPublicParameters};
impl<C> From<ecdsa::VerifyingKey<C>> for JsonWebKeyPublicParameters
where
C: PrimeCurve + ProjectiveArithmetic + JwkEcCurve,
AffinePoint<C>: FromEncodedPoint<C> + ToEncodedPoint<C>,
FieldSize<C>: ModulusSize,
{
fn from(key: ecdsa::VerifyingKey<C>) -> Self {
Self::from(&key)
}
}
impl<C> From<&ecdsa::VerifyingKey<C>> for JsonWebKeyPublicParameters
where
C: PrimeCurve + ProjectiveArithmetic + JwkEcCurve,
AffinePoint<C>: FromEncodedPoint<C> + ToEncodedPoint<C>,
FieldSize<C>: ModulusSize,
{
fn from(key: &ecdsa::VerifyingKey<C>) -> Self {
Self::Ec(key.into())
}
}
impl<C> From<ecdsa::VerifyingKey<C>> for EcPublicParameters
where
C: PrimeCurve + ProjectiveArithmetic + JwkEcCurve,
AffinePoint<C>: FromEncodedPoint<C> + ToEncodedPoint<C>,
FieldSize<C>: ModulusSize,
{
fn from(key: ecdsa::VerifyingKey<C>) -> Self {
Self::from(&key)
}
}
impl<C> From<&ecdsa::VerifyingKey<C>> for EcPublicParameters
where
C: PrimeCurve + ProjectiveArithmetic + JwkEcCurve,
AffinePoint<C>: FromEncodedPoint<C> + ToEncodedPoint<C>,
FieldSize<C>: ModulusSize,
{
fn from(key: &ecdsa::VerifyingKey<C>) -> Self {
let points = key.to_encoded_point(false);
EcPublicParameters {
x: points.x().unwrap().to_vec(),
y: points.y().unwrap().to_vec(),
crv: C::CRV,
}
}
}
impl<C> TryFrom<EcPublicParameters> for VerifyingKey<C>
where
@@ -311,71 +413,3 @@ mod ec_impls {
}
}
}
/// Some legacy implementations to remove
mod legacy {
use anyhow::bail;
use mas_iana::jose::JsonWebKeyEcEllipticCurve;
use p256::NistP256;
use rsa::{BigUint, PublicKeyParts};
use super::{EcPublicParameters, JsonWebKeyPublicParameters, RsaPublicParameters};
impl TryFrom<JsonWebKeyPublicParameters> for ecdsa::VerifyingKey<NistP256> {
type Error = anyhow::Error;
fn try_from(params: JsonWebKeyPublicParameters) -> Result<Self, Self::Error> {
let (x, y): ([u8; 32], [u8; 32]) = match params {
JsonWebKeyPublicParameters::Ec(EcPublicParameters {
x,
y,
crv: JsonWebKeyEcEllipticCurve::P256,
}) => (
x.try_into()
.map_err(|_| anyhow::anyhow!("invalid curve parameter x"))?,
y.try_into()
.map_err(|_| anyhow::anyhow!("invalid curve parameter y"))?,
),
_ => bail!("Wrong curve"),
};
let point = sec1::EncodedPoint::from_affine_coordinates(&x.into(), &y.into(), false);
let key = ecdsa::VerifyingKey::from_encoded_point(&point)?;
Ok(key)
}
}
impl From<ecdsa::VerifyingKey<NistP256>> for JsonWebKeyPublicParameters {
fn from(key: ecdsa::VerifyingKey<NistP256>) -> Self {
let points = key.to_encoded_point(false);
JsonWebKeyPublicParameters::Ec(EcPublicParameters {
x: points.x().unwrap().to_vec(),
y: points.y().unwrap().to_vec(),
crv: JsonWebKeyEcEllipticCurve::P256,
})
}
}
impl TryFrom<JsonWebKeyPublicParameters> for rsa::RsaPublicKey {
type Error = anyhow::Error;
fn try_from(params: JsonWebKeyPublicParameters) -> Result<Self, Self::Error> {
let (n, e) = match &params {
JsonWebKeyPublicParameters::Rsa(RsaPublicParameters { n, e }) => (n, e),
_ => bail!("Wrong key type"),
};
let n = BigUint::from_bytes_be(n);
let e = BigUint::from_bytes_be(e);
Ok(rsa::RsaPublicKey::new(n, e)?)
}
}
impl From<rsa::RsaPublicKey> for JsonWebKeyPublicParameters {
fn from(key: rsa::RsaPublicKey) -> Self {
JsonWebKeyPublicParameters::Rsa(RsaPublicParameters {
n: key.n().to_bytes_be(),
e: key.e().to_bytes_be(),
})
}
}
}

View File

@@ -190,32 +190,3 @@ impl JsonWebTokenParts {
format!("{}.{}", payload, signature)
}
}
#[cfg(test)]
mod tests {
use mas_iana::jose::JsonWebSignatureAlg;
use super::*;
use crate::SharedSecret;
#[tokio::test]
async fn decode_hs256() {
let jwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c";
let jwt: JsonWebTokenParts = jwt.parse().unwrap();
let secret = "your-256-bit-secret";
let store = SharedSecret::new(&secret);
let jwt: DecodedJsonWebToken<serde_json::Value> =
jwt.decode_and_verify(&store).await.unwrap();
assert_eq!(jwt.header.typ(), Some("JWT"));
assert_eq!(jwt.header.alg(), JsonWebSignatureAlg::Hs256);
assert_eq!(
jwt.payload,
serde_json::json!({
"sub": "1234567890",
"name": "John Doe",
"iat": 1_516_239_022
})
);
}
}

View File

@@ -16,6 +16,7 @@ use std::{borrow::Cow, ops::Deref};
use thiserror::Error;
#[derive(Clone, PartialEq, Eq)]
pub struct RawJwt<'a> {
inner: Cow<'a, str>,
first_dot: usize,
@@ -54,6 +55,14 @@ impl<'a> RawJwt<'a> {
pub fn signed_part(&'a self) -> &'a str {
&self.inner[..self.second_dot]
}
pub fn into_owned(self) -> RawJwt<'static> {
RawJwt {
inner: self.inner.into_owned().into(),
first_dot: self.first_dot,
second_dot: self.second_dot,
}
}
}
impl<'a> Deref for RawJwt<'a> {
@@ -97,3 +106,25 @@ impl<'a> TryFrom<&'a str> for RawJwt<'a> {
})
}
}
impl TryFrom<String> for RawJwt<'static> {
type Error = DecodeError;
fn try_from(value: String) -> Result<Self, Self::Error> {
let mut indices = value
.char_indices()
.filter_map(|(idx, c)| (c == '.').then(|| idx));
let first_dot = indices.next().ok_or(DecodeError::NoDots)?;
let second_dot = indices.next().ok_or(DecodeError::OnlyOneDot)?;
if indices.next().is_some() {
return Err(DecodeError::TooManyDots);
}
Ok(Self {
inner: value.into(),
first_dot,
second_dot,
})
}
}

View File

@@ -18,7 +18,9 @@ use signature::{Signature, Signer, Verifier};
use thiserror::Error;
use super::{header::JsonWebSignatureHeader, raw::RawJwt};
use crate::{constraints::ConstraintSet, jwk::PublicJsonWebKeySet};
#[derive(Clone, PartialEq, Eq)]
pub struct Jwt<'a, T> {
raw: RawJwt<'a>,
header: JsonWebSignatureHeader,
@@ -107,14 +109,12 @@ impl JwtDecodeError {
}
}
impl<'a, T> TryFrom<&'a str> for Jwt<'a, T>
impl<'a, T> TryFrom<RawJwt<'a>> for Jwt<'a, T>
where
T: DeserializeOwned,
{
type Error = JwtDecodeError;
fn try_from(value: &'a str) -> Result<Self, Self::Error> {
let raw = RawJwt::try_from(value)?;
fn try_from(raw: RawJwt<'a>) -> Result<Self, Self::Error> {
let header_reader =
base64ct::Decoder::<'_, Base64UrlUnpadded>::new(raw.header().as_bytes())
.map_err(JwtDecodeError::decode_header)?;
@@ -139,6 +139,28 @@ where
}
}
impl<'a, T> TryFrom<&'a str> for Jwt<'a, T>
where
T: DeserializeOwned,
{
type Error = JwtDecodeError;
fn try_from(value: &'a str) -> Result<Self, Self::Error> {
let raw = RawJwt::try_from(value)?;
Self::try_from(raw)
}
}
impl<T> TryFrom<String> for Jwt<'static, T>
where
T: DeserializeOwned,
{
type Error = JwtDecodeError;
fn try_from(value: String) -> Result<Self, Self::Error> {
let raw = RawJwt::try_from(value)?;
Self::try_from(raw)
}
}
#[derive(Debug, Error)]
pub enum JwtVerificationError {
#[error("failed to parse signature")]
@@ -164,6 +186,12 @@ impl JwtVerificationError {
}
}
#[derive(Debug, Error, Default)]
#[error("none of the keys worked")]
pub struct NoKeyWorked {
_inner: (),
}
impl<'a, T> Jwt<'a, T> {
pub fn header(&self) -> &JsonWebSignatureHeader {
&self.header
@@ -173,6 +201,15 @@ impl<'a, T> Jwt<'a, T> {
&self.payload
}
pub fn into_owned(self) -> Jwt<'static, T> {
Jwt {
raw: self.raw.into_owned(),
header: self.header,
payload: self.payload,
signature: self.signature,
}
}
pub fn verify<K, S>(&self, key: &K) -> Result<(), JwtVerificationError>
where
K: Verifier<S>,
@@ -185,6 +222,37 @@ impl<'a, T> Jwt<'a, T> {
.map_err(JwtVerificationError::verify)
}
pub fn verify_from_shared_secret(&self, secret: Vec<u8>) -> Result<(), NoKeyWorked> {
let verifier = crate::verifier::Verifier::for_oct_and_alg(secret, self.header().alg())
.map_err(|_| NoKeyWorked::default())?;
self.verify(&verifier).map_err(|_| NoKeyWorked::default())?;
Ok(())
}
pub fn verify_from_jwks(&self, jwks: &PublicJsonWebKeySet) -> Result<(), NoKeyWorked> {
let constraints = ConstraintSet::from(self.header());
let candidates = constraints.filter(&**jwks);
for candidate in candidates {
let verifier = match crate::verifier::Verifier::for_jwk_and_alg(
candidate.params(),
self.header().alg(),
) {
Ok(v) => v,
Err(_) => continue,
};
match self.verify(&verifier) {
Ok(_) => return Ok(()),
Err(_) => continue,
}
}
Err(NoKeyWorked::default())
}
pub fn as_str(&'a self) -> &'a str {
&self.raw
}

View File

@@ -1,165 +0,0 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use chrono::{DateTime, Duration, Utc};
use futures_util::future::BoxFuture;
use thiserror::Error;
use tokio::sync::RwLock;
use tower::{
util::{BoxCloneService, ServiceExt},
BoxError, Service,
};
use super::StaticJwksStore;
use crate::{jwk::PublicJsonWebKeySet, JsonWebSignatureHeader, VerifyingKeystore};
#[derive(Debug, Error)]
pub enum Error {
#[error("cache in inconsistent state")]
InconsistentCache,
#[error(transparent)]
Cached(Arc<BoxError>),
#[error("todo")]
Todo,
#[error(transparent)]
Verification(#[from] super::static_store::Error),
}
enum State<E> {
Pending,
Errored {
at: DateTime<Utc>,
error: E,
},
Fulfilled {
at: DateTime<Utc>,
store: StaticJwksStore,
},
}
impl<E> Default for State<E> {
fn default() -> Self {
Self::Pending
}
}
impl<E> State<E> {
fn fullfill(&mut self, key_set: PublicJsonWebKeySet) {
*self = Self::Fulfilled {
at: Utc::now(),
store: StaticJwksStore::new(key_set),
}
}
fn error(&mut self, error: E) {
*self = Self::Errored {
at: Utc::now(),
error,
}
}
fn should_refresh(&self) -> bool {
let now = Utc::now();
match self {
Self::Pending => true,
Self::Errored { at, .. } if *at - now > Duration::minutes(5) => true,
Self::Fulfilled { at, .. } if *at - now > Duration::hours(1) => true,
_ => false,
}
}
fn should_force_refresh(&self) -> bool {
let now = Utc::now();
match self {
Self::Pending => true,
Self::Errored { at, .. } | Self::Fulfilled { at, .. }
if *at - now > Duration::minutes(5) =>
{
true
}
_ => false,
}
}
}
#[derive(Clone)]
pub struct DynamicJwksStore {
exporter: BoxCloneService<(), PublicJsonWebKeySet, BoxError>,
cache: Arc<RwLock<State<Arc<BoxError>>>>,
}
impl DynamicJwksStore {
pub fn new<T>(exporter: T) -> Self
where
T: Service<(), Response = PublicJsonWebKeySet, Error = BoxError> + Send + Clone + 'static,
T::Future: Send,
{
Self {
exporter: exporter.boxed_clone(),
cache: Arc::default(),
}
}
}
impl VerifyingKeystore for DynamicJwksStore {
type Error = Error;
type Future = BoxFuture<'static, Result<(), Self::Error>>;
fn verify(
&self,
header: &JsonWebSignatureHeader,
payload: &[u8],
signature: &[u8],
) -> Self::Future {
let cache = self.cache.clone();
let exporter = self.exporter.clone();
let header = header.clone();
let payload = payload.to_owned();
let signature = signature.to_owned();
let fut = async move {
if cache.read().await.should_refresh() {
let mut cache = cache.write().await;
if cache.should_force_refresh() {
let jwks = async move { exporter.ready_oneshot().await?.call(()).await }.await;
match jwks {
Ok(jwks) => cache.fullfill(jwks),
Err(err) => cache.error(Arc::new(err)),
}
}
}
let cache = cache.read().await;
// TODO: we could bubble up the underlying error here
let store = match &*cache {
State::Pending => return Err(Error::InconsistentCache),
State::Errored { error, .. } => return Err(Error::Cached(error.clone())),
State::Fulfilled { store, .. } => store,
};
store.verify(&header, &payload, &signature).await?;
Ok(())
};
Box::pin(fut)
}
}

View File

@@ -1,18 +0,0 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
mod dynamic_store;
mod static_store;
pub use self::{dynamic_store::DynamicJwksStore, static_store::StaticJwksStore};

View File

@@ -1,245 +0,0 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::future::Ready;
use digest::Digest;
use mas_iana::jose::{JsonWebKeyType, JsonWebSignatureAlg};
use rsa::{PublicKey, RsaPublicKey};
use sha2::{Sha256, Sha384, Sha512};
use signature::{Signature, Verifier};
use thiserror::Error;
use crate::{
constraints::Constrainable,
jwk::{PublicJsonWebKey, PublicJsonWebKeySet},
JsonWebSignatureHeader, VerifyingKeystore,
};
#[derive(Debug, Error)]
pub enum Error {
#[error("key not found")]
KeyNotFound,
#[error("multiple key matched")]
MultipleKeyMatched,
#[error(r#"missing "kid" field in header"#)]
MissingKid,
#[error(transparent)]
Rsa(#[from] rsa::errors::Error),
#[error("unsupported algorithm {alg}")]
UnsupportedAlgorithm { alg: JsonWebSignatureAlg },
#[error(transparent)]
Signature(#[from] signature::Error),
#[error("invalid {kty} key")]
InvalidKey {
kty: JsonWebKeyType,
source: anyhow::Error,
},
}
struct KeyConstraint<'a> {
kty: Option<JsonWebKeyType>,
alg: Option<JsonWebSignatureAlg>,
kid: Option<&'a str>,
}
impl<'a> KeyConstraint<'a> {
fn matches(&self, key: &'a PublicJsonWebKey) -> bool {
// If a specific KID was asked, match the key only if it has a matching kid
// field
if let Some(kid) = self.kid {
if key.kid() != Some(kid) {
return false;
}
}
if let Some(kty) = self.kty {
if key.kty() != kty {
return false;
}
}
if let Some(alg) = self.alg {
if key.alg() != None && key.alg() != Some(alg) {
return false;
}
}
true
}
fn find_keys(&self, key_set: &'a PublicJsonWebKeySet) -> Vec<&'a PublicJsonWebKey> {
key_set.iter().filter(|k| self.matches(k)).collect()
}
}
pub struct StaticJwksStore {
key_set: PublicJsonWebKeySet,
}
impl StaticJwksStore {
#[must_use]
pub fn new(key_set: PublicJsonWebKeySet) -> Self {
Self { key_set }
}
fn find_key<'a>(
&'a self,
constraint: &KeyConstraint<'a>,
) -> Result<&'a PublicJsonWebKey, Error> {
let keys = constraint.find_keys(&self.key_set);
match &keys[..] {
[one] => Ok(one),
[] => Err(Error::KeyNotFound),
_ => Err(Error::MultipleKeyMatched),
}
}
fn find_rsa_key(&self, kid: Option<&str>) -> Result<RsaPublicKey, Error> {
let constraint = KeyConstraint {
kty: Some(JsonWebKeyType::Rsa),
kid,
alg: None,
};
let key = self.find_key(&constraint)?;
let key = key
.params()
.clone()
.try_into()
.map_err(|source| Error::InvalidKey {
kty: JsonWebKeyType::Rsa,
source,
})?;
Ok(key)
}
fn find_ecdsa_key(
&self,
kid: Option<&str>,
) -> Result<ecdsa::VerifyingKey<p256::NistP256>, Error> {
let constraint = KeyConstraint {
kty: Some(JsonWebKeyType::Ec),
kid,
alg: None,
};
let key = self.find_key(&constraint)?;
let key = key
.params()
.clone()
.try_into()
.map_err(|source| Error::InvalidKey {
kty: JsonWebKeyType::Ec,
source,
})?;
Ok(key)
}
#[tracing::instrument(skip(self))]
fn verify_sync(
&self,
header: &JsonWebSignatureHeader,
payload: &[u8],
signature: &[u8],
) -> Result<(), Error> {
let kid = header.kid();
match header.alg() {
JsonWebSignatureAlg::Rs256 => {
let key = self.find_rsa_key(kid)?;
let digest = {
let mut digest = Sha256::new();
digest.update(&payload);
digest.finalize()
};
key.verify(
rsa::PaddingScheme::new_pkcs1v15_sign(Some(rsa::Hash::SHA2_256)),
&digest,
signature,
)?;
}
JsonWebSignatureAlg::Rs384 => {
let key = self.find_rsa_key(kid)?;
let digest = {
let mut digest = Sha384::new();
digest.update(&payload);
digest.finalize()
};
key.verify(
rsa::PaddingScheme::new_pkcs1v15_sign(Some(rsa::Hash::SHA2_384)),
&digest,
signature,
)?;
}
JsonWebSignatureAlg::Rs512 => {
let key = self.find_rsa_key(kid)?;
let digest = {
let mut digest = Sha512::new();
digest.update(&payload);
digest.finalize()
};
key.verify(
rsa::PaddingScheme::new_pkcs1v15_sign(Some(rsa::Hash::SHA2_512)),
&digest,
signature,
)?;
}
JsonWebSignatureAlg::Es256 => {
let key = self.find_ecdsa_key(kid)?;
let signature = ecdsa::Signature::from_bytes(signature)?;
key.verify(payload, &signature)?;
}
alg => return Err(Error::UnsupportedAlgorithm { alg }),
};
Ok(())
}
}
impl VerifyingKeystore for StaticJwksStore {
type Error = Error;
type Future = Ready<Result<(), Self::Error>>;
fn verify(
&self,
header: &JsonWebSignatureHeader,
payload: &[u8],
signature: &[u8],
) -> Self::Future {
std::future::ready(self.verify_sync(header, payload, signature))
}
}

View File

@@ -12,14 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
mod jwks;
mod shared_secret;
mod static_keystore;
mod traits;
pub use self::{
jwks::{DynamicJwksStore, StaticJwksStore},
shared_secret::SharedSecret,
static_keystore::StaticKeystore,
traits::{SigningKeystore, VerifyingKeystore},
};

View File

@@ -1,172 +0,0 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{collections::HashSet, future::Ready};
use anyhow::bail;
use async_trait::async_trait;
use digest::{InvalidLength, MacError};
use hmac::{Hmac, Mac};
use mas_iana::jose::JsonWebSignatureAlg;
use sha2::{Sha256, Sha384, Sha512};
use thiserror::Error;
use super::{SigningKeystore, VerifyingKeystore};
use crate::JsonWebSignatureHeader;
#[derive(Debug, Error)]
pub enum Error {
#[error("invalid key")]
InvalidKey(#[from] InvalidLength),
#[error("unsupported algorithm {alg}")]
UnsupportedAlgorithm { alg: JsonWebSignatureAlg },
#[error("signature verification failed")]
Verification(#[from] MacError),
}
pub struct SharedSecret<'a> {
inner: &'a [u8],
}
impl<'a> SharedSecret<'a> {
pub fn new(source: &'a impl AsRef<[u8]>) -> Self {
Self {
inner: source.as_ref(),
}
}
fn verify_sync(
&self,
header: &JsonWebSignatureHeader,
payload: &[u8],
signature: &[u8],
) -> Result<(), Error> {
match header.alg() {
JsonWebSignatureAlg::Hs256 => {
let mut mac = Hmac::<Sha256>::new_from_slice(self.inner)?;
mac.update(payload);
mac.verify(signature.into())?;
}
JsonWebSignatureAlg::Hs384 => {
let mut mac = Hmac::<Sha384>::new_from_slice(self.inner)?;
mac.update(payload);
mac.verify(signature.into())?;
}
JsonWebSignatureAlg::Hs512 => {
let mut mac = Hmac::<Sha512>::new_from_slice(self.inner)?;
mac.update(payload);
mac.verify(signature.into())?;
}
alg => return Err(Error::UnsupportedAlgorithm { alg }),
};
Ok(())
}
}
#[async_trait]
impl<'a> SigningKeystore for SharedSecret<'a> {
fn supported_algorithms(&self) -> HashSet<JsonWebSignatureAlg> {
let mut algorithms = HashSet::with_capacity(3);
algorithms.insert(JsonWebSignatureAlg::Hs256);
algorithms.insert(JsonWebSignatureAlg::Hs384);
algorithms.insert(JsonWebSignatureAlg::Hs512);
algorithms
}
async fn prepare_header(
&self,
alg: JsonWebSignatureAlg,
) -> anyhow::Result<JsonWebSignatureHeader> {
if !matches!(
alg,
JsonWebSignatureAlg::Hs256 | JsonWebSignatureAlg::Hs384 | JsonWebSignatureAlg::Hs512,
) {
bail!("unsupported algorithm")
}
Ok(JsonWebSignatureHeader::new(alg))
}
async fn sign(&self, header: &JsonWebSignatureHeader, msg: &[u8]) -> anyhow::Result<Vec<u8>> {
// TODO: do the signing in a blocking task
// TODO: should we bail out if the key is too small?
let signature = match header.alg() {
JsonWebSignatureAlg::Hs256 => {
let mut mac = Hmac::<Sha256>::new_from_slice(self.inner)?;
mac.update(msg);
mac.finalize().into_bytes().to_vec()
}
JsonWebSignatureAlg::Hs384 => {
let mut mac = Hmac::<Sha384>::new_from_slice(self.inner)?;
mac.update(msg);
mac.finalize().into_bytes().to_vec()
}
JsonWebSignatureAlg::Hs512 => {
let mut mac = Hmac::<Sha512>::new_from_slice(self.inner)?;
mac.update(msg);
mac.finalize().into_bytes().to_vec()
}
_ => bail!("unsupported algorithm"),
};
Ok(signature)
}
}
impl<'a> VerifyingKeystore for SharedSecret<'a> {
type Error = Error;
type Future = Ready<Result<(), Self::Error>>;
fn verify(
&self,
header: &JsonWebSignatureHeader,
payload: &[u8],
signature: &[u8],
) -> Self::Future {
std::future::ready(self.verify_sync(header, payload, signature))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_shared_secret() {
let secret = "super-complicated-secret-that-should-be-big-enough-for-sha512";
let message = "this is the message to sign".as_bytes();
let store = SharedSecret::new(&secret);
for alg in [
JsonWebSignatureAlg::Hs256,
JsonWebSignatureAlg::Hs384,
JsonWebSignatureAlg::Hs512,
] {
let header = store.prepare_header(alg).await.unwrap();
assert_eq!(header.alg(), alg);
let signature = store.sign(&header, message).await.unwrap();
store.verify(&header, message, &signature).await.unwrap();
}
}
}

View File

@@ -14,9 +14,7 @@
use std::{
collections::{HashMap, HashSet},
convert::Infallible,
future::Ready,
task::Poll,
};
use anyhow::bail;
@@ -31,11 +29,10 @@ use pkcs8::{DecodePrivateKey, EncodePublicKey};
use rsa::{PublicKey as _, RsaPrivateKey, RsaPublicKey};
use sha2::{Sha256, Sha384, Sha512};
use signature::{Signature, Signer, Verifier};
use tower::Service;
use super::{SigningKeystore, VerifyingKeystore};
use crate::{
jwk::{JsonWebKey, JsonWebKeySet, PublicJsonWebKeySet},
jwk::{JsonWebKey, PublicJsonWebKeySet},
JsonWebSignatureHeader,
};
@@ -133,6 +130,27 @@ impl StaticKeystore {
Ok(())
}
#[must_use]
pub fn to_public_jwks(&self) -> PublicJsonWebKeySet {
let rsa = self.rsa_keys.iter().map(|(kid, key)| {
let pubkey = RsaPublicKey::from(key);
JsonWebKey::new(pubkey.into())
.with_kid(kid)
.with_use(JsonWebKeyUse::Sig)
});
let es256 = self.es256_keys.iter().map(|(kid, key)| {
let pubkey = ecdsa::VerifyingKey::from(key);
JsonWebKey::new(pubkey.into())
.with_kid(kid)
.with_use(JsonWebKeyUse::Sig)
.with_alg(JsonWebSignatureAlg::Es256)
});
let keys = rsa.chain(es256).collect();
PublicJsonWebKeySet::new(keys)
}
fn verify_sync(
&self,
header: &JsonWebSignatureHeader,
@@ -366,36 +384,6 @@ impl VerifyingKeystore for StaticKeystore {
}
}
impl Service<()> for &StaticKeystore {
type Future = Ready<Result<Self::Response, Self::Error>>;
type Response = PublicJsonWebKeySet;
type Error = Infallible;
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, _req: ()) -> Self::Future {
let rsa = self.rsa_keys.iter().map(|(kid, key)| {
let pubkey = RsaPublicKey::from(key);
JsonWebKey::new(pubkey.into())
.with_kid(kid)
.with_use(JsonWebKeyUse::Sig)
});
let es256 = self.es256_keys.iter().map(|(kid, key)| {
let pubkey = ecdsa::VerifyingKey::from(key);
JsonWebKey::new(pubkey.into())
.with_kid(kid)
.with_use(JsonWebKeyUse::Sig)
.with_alg(JsonWebSignatureAlg::Es256)
});
let keys = rsa.chain(es256).collect();
std::future::ready(Ok(JsonWebKeySet::new(keys)))
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -12,15 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{collections::HashSet, future::Future, sync::Arc};
use std::{collections::HashSet, future::Future};
use async_trait::async_trait;
use futures_util::{
future::{Either, MapErr},
TryFutureExt,
};
use mas_iana::jose::JsonWebSignatureAlg;
use thiserror::Error;
use crate::JsonWebSignatureHeader;
@@ -43,61 +38,3 @@ pub trait VerifyingKeystore {
fn verify(&self, header: &JsonWebSignatureHeader, msg: &[u8], signature: &[u8])
-> Self::Future;
}
#[derive(Debug, Error)]
pub enum EitherError<A, B> {
#[error(transparent)]
Left(A),
#[error(transparent)]
Right(B),
}
impl<L, R> VerifyingKeystore for Either<L, R>
where
L: VerifyingKeystore,
R: VerifyingKeystore,
{
type Error = EitherError<L::Error, R::Error>;
#[allow(clippy::type_complexity)]
type Future = Either<
MapErr<L::Future, fn(L::Error) -> Self::Error>,
MapErr<R::Future, fn(R::Error) -> Self::Error>,
>;
fn verify(
&self,
header: &JsonWebSignatureHeader,
msg: &[u8],
signature: &[u8],
) -> Self::Future {
match self {
Either::Left(left) => Either::Left(
left.verify(header, msg, signature)
.map_err(EitherError::Left),
),
Either::Right(right) => Either::Right(
right
.verify(header, msg, signature)
.map_err(EitherError::Right),
),
}
}
}
impl<T> VerifyingKeystore for Arc<T>
where
T: VerifyingKeystore,
{
type Error = T::Error;
type Future = T::Future;
fn verify(
&self,
header: &JsonWebSignatureHeader,
msg: &[u8],
signature: &[u8],
) -> Self::Future {
self.as_ref().verify(header, msg, signature)
}
}

View File

@@ -30,8 +30,5 @@ pub use futures_util::future::Either;
pub use self::{
jwt::{DecodedJsonWebToken, JsonWebSignatureHeader, JsonWebTokenParts, Jwt, JwtSignatureError},
keystore::{
DynamicJwksStore, SharedSecret, SigningKeystore, StaticJwksStore, StaticKeystore,
VerifyingKeystore,
},
keystore::{SigningKeystore, StaticKeystore, VerifyingKeystore},
};