1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-09 04:22:45 +03:00

Implement private_key_jwks client authentication

This involves a lot of things, including:
 - better VerifyingKeystore trait
 - better errors in the JOSE crate
 - getting rid of async_trait in some JOSE traits
This commit is contained in:
Quentin Gliech
2022-02-17 15:42:44 +01:00
parent c5858e6ed5
commit 035e2d7829
25 changed files with 1008 additions and 796 deletions

507
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -35,3 +35,8 @@ pem-rfc7468 = "0.3.1"
indoc = "1.0.3" indoc = "1.0.3"
mas-jose = { path = "../jose" } mas-jose = { path = "../jose" }
mas-http = { path = "../http" }
tower = { version = "0.4.11", features = ["util"] }
http = "0.2.6"
http-body = "0.4.4"
futures-util = "0.3.21"

View File

@@ -15,11 +15,15 @@
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
use async_trait::async_trait; use async_trait::async_trait;
use mas_jose::{JsonWebKeySet, StaticJwksStore}; use futures_util::future::Either;
use http::Request;
use mas_http::HttpServiceExt;
use mas_jose::{DynamicJwksStore, JsonWebKeySet, StaticJwksStore, VerifyingKeystore};
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none; use serde_with::skip_serializing_none;
use thiserror::Error; use thiserror::Error;
use tower::{BoxError, ServiceExt};
use url::Url; use url::Url;
use super::ConfigurationSection; use super::ConfigurationSection;
@@ -32,13 +36,37 @@ pub enum JwksOrJwksUri {
} }
impl JwksOrJwksUri { impl JwksOrJwksUri {
pub fn key_store(&self) -> StaticJwksStore { pub fn key_store(&self) -> Either<StaticJwksStore, DynamicJwksStore> {
let jwks = match self { // Assert that the output is both a VerifyingKeystore and Send
Self::Jwks(jwks) => jwks.clone(), fn assert<T: Send + VerifyingKeystore>(t: T) -> T {
Self::JwksUri(_) => unimplemented!("jwks_uri are not implemented yet"), t
}
let inner = match self {
Self::Jwks(jwks) => Either::Left(StaticJwksStore::new(jwks.clone())),
Self::JwksUri(uri) => {
let uri = uri.clone();
// TODO: get the client from somewhere else?
let exporter = mas_http::client("fetch-jwks")
.json::<JsonWebKeySet>()
.map_request(move |_: ()| {
Request::builder()
.method("GET")
// TODO: change the Uri type in config to avoid reparsing here
.uri(uri.to_string())
.body(http_body::Empty::new())
.unwrap()
})
.map_response(http::Response::into_body)
.map_err(BoxError::from)
.boxed_clone();
Either::Right(DynamicJwksStore::new(exporter))
}
}; };
StaticJwksStore::new(jwks) assert(inner)
} }
} }

View File

@@ -63,6 +63,7 @@ mas-static-files = { path = "../static-files" }
mas-storage = { path = "../storage" } mas-storage = { path = "../storage" }
mas-templates = { path = "../templates" } mas-templates = { path = "../templates" }
mas-warp-utils = { path = "../warp-utils" } mas-warp-utils = { path = "../warp-utils" }
tower = "0.4.11"
[dev-dependencies] [dev-dependencies]
indoc = "1.0.3" indoc = "1.0.3"

View File

@@ -32,7 +32,7 @@ use warp::{filters::BoxedFilter, Filter, Reply};
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
pub(super) fn filter( pub(super) fn filter(
key_store: impl SigningKeystore, key_store: &impl SigningKeystore,
http_config: &HttpConfig, http_config: &HttpConfig,
) -> BoxedFilter<(Box<dyn Reply>,)> { ) -> BoxedFilter<(Box<dyn Reply>,)> {
let builder = UrlBuilder::from(http_config); let builder = UrlBuilder::from(http_config);

View File

@@ -14,8 +14,9 @@
use std::sync::Arc; use std::sync::Arc;
use mas_jose::{ExportJwks, StaticKeystore}; use mas_jose::StaticKeystore;
use mas_warp_utils::{errors::WrapError, filters}; use mas_warp_utils::filters;
use tower::{Service, ServiceExt};
use warp::{filters::BoxedFilter, Filter, Rejection, Reply}; use warp::{filters::BoxedFilter, Filter, Rejection, Reply};
pub(super) fn filter(key_store: &Arc<StaticKeystore>) -> BoxedFilter<(Box<dyn Reply>,)> { pub(super) fn filter(key_store: &Arc<StaticKeystore>) -> BoxedFilter<(Box<dyn Reply>,)> {
@@ -27,7 +28,7 @@ pub(super) fn filter(key_store: &Arc<StaticKeystore>) -> BoxedFilter<(Box<dyn Re
} }
async fn get(key_store: Arc<StaticKeystore>) -> Result<Box<dyn Reply>, Rejection> { async fn get(key_store: Arc<StaticKeystore>) -> Result<Box<dyn Reply>, Rejection> {
let jwks = key_store.export_jwks().await.wrap_error()?; let mut key_store: &StaticKeystore = key_store.as_ref();
let jwks = key_store.ready().await?.call(()).await?;
Ok(Box::new(warp::reply::json(&jwks))) Ok(Box::new(warp::reply::json(&jwks)))
} }

View File

@@ -12,20 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use crate::layers::{get::Get, json::Json}; use crate::layers::json::Json;
pub trait ServiceExt: Sized { pub trait ServiceExt: Sized {
fn json<T>(self) -> Json<Self, T>; fn json<T>(self) -> Json<Self, T>;
fn get(self) -> Get<Self>;
} }
impl<S> ServiceExt for S { impl<S> ServiceExt for S {
fn json<T>(self) -> Json<Self, T> { fn json<T>(self) -> Json<Self, T> {
Json::new(self) Json::new(self)
} }
fn get(self) -> Get<Self> {
Get::new(self)
}
} }

View File

@@ -54,14 +54,14 @@ pub type ClientResponse<B> = Response<
DecompressionBody<BoxBody<<B as http_body::Body>::Data, <B as http_body::Body>::Error>>, DecompressionBody<BoxBody<<B as http_body::Body>::Data, <B as http_body::Body>::Error>>,
>; >;
impl<ReqBody, ResBody, S> Layer<S> for ClientLayer<ReqBody> impl<ReqBody, ResBody, S, E> Layer<S> for ClientLayer<ReqBody>
where where
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static, S: Service<Request<ReqBody>, Response = Response<ResBody>, Error = E> + Clone + Send + 'static,
ReqBody: http_body::Body + Default + Send + 'static, ReqBody: http_body::Body + Default + Send + 'static,
ResBody: http_body::Body + Sync + Send + 'static, ResBody: http_body::Body + Sync + Send + 'static,
ResBody::Error: std::fmt::Display + 'static, ResBody::Error: std::fmt::Display + 'static,
S::Future: Send + 'static, S::Future: Send + 'static,
S::Error: Into<BoxError>, E: Into<BoxError>,
{ {
type Service = BoxCloneService<Request<ReqBody>, ClientResponse<ResBody>, BoxError>; type Service = BoxCloneService<Request<ReqBody>, ClientResponse<ResBody>, BoxError>;

View File

@@ -1,66 +0,0 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use http::{Request, Uri};
use tower::{Layer, Service};
pub struct Get<S> {
inner: S,
}
impl<S> Get<S> {
pub const fn new(inner: S) -> Self {
Self { inner }
}
}
impl<S> Service<Uri> for Get<S>
where
S: Service<Request<http_body::Empty<()>>>,
{
type Error = S::Error;
type Response = S::Response;
type Future = S::Future;
fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Uri) -> Self::Future {
let body = http_body::Empty::new();
let req = Request::builder()
.method("GET")
.uri(req)
.body(body)
.unwrap();
self.inner.call(req)
}
}
#[derive(Default, Clone, Copy)]
pub struct GetLayer;
impl<S> Layer<S> for GetLayer
where
S: Service<Request<http_body::Empty<()>>>,
{
type Service = Get<S>;
fn layer(&self, inner: S) -> Self::Service {
Get::new(inner)
}
}

View File

@@ -53,6 +53,7 @@ impl<S, B> Error<S, B> {
} }
} }
#[derive(Clone)]
pub struct Json<S, T> { pub struct Json<S, T> {
inner: S, inner: S,
_t: PhantomData<T>, _t: PhantomData<T>,

View File

@@ -13,7 +13,6 @@
// limitations under the License. // limitations under the License.
pub(crate) mod client; pub(crate) mod client;
pub(crate) mod get;
pub(crate) mod json; pub(crate) mod json;
pub(crate) mod server; pub(crate) mod server;
pub(crate) mod trace; pub(crate) mod trace;

View File

@@ -29,15 +29,16 @@ pub struct ServerLayer<ReqBody> {
_t: PhantomData<ReqBody>, _t: PhantomData<ReqBody>,
} }
impl<ReqBody, ResBody, S> Layer<S> for ServerLayer<ReqBody> impl<ReqBody, ResBody, S, E> Layer<S> for ServerLayer<ReqBody>
where where
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static, S: Service<Request<ReqBody>, Response = Response<ResBody>, Error = E> + Clone + Send + 'static,
ReqBody: http_body::Body + 'static, ReqBody: http_body::Body + 'static,
ResBody: http_body::Body + Sync + Send + 'static, ResBody: http_body::Body + Sync + Send + 'static,
ResBody::Error: std::fmt::Display + 'static, ResBody::Error: std::fmt::Display + 'static,
S::Future: Send + 'static, S::Future: Send + 'static,
S::Error: Into<BoxError>, E: Into<BoxError>,
{ {
#[allow(clippy::type_complexity)]
type Service = BoxCloneService< type Service = BoxCloneService<
Request<ReqBody>, Request<ReqBody>,
Response<CompressionBody<BoxBody<ResBody::Data, ResBody::Error>>>, Response<CompressionBody<BoxBody<ResBody::Data, ResBody::Error>>>,

View File

@@ -14,7 +14,9 @@ crypto-mac = { version = "0.11.1", features = ["std"] }
digest = "0.10.1" digest = "0.10.1"
ecdsa = { version = "0.13.4", features = ["sign", "verify", "pem", "pkcs8"] } ecdsa = { version = "0.13.4", features = ["sign", "verify", "pem", "pkcs8"] }
elliptic-curve = { version = "0.11.12", features = ["ecdh", "pem"] } elliptic-curve = { version = "0.11.12", features = ["ecdh", "pem"] }
futures-util = "0.3.21"
hmac = "0.12.0" hmac = "0.12.0"
http = "0.2.6"
p256 = { version = "0.10.1", features = ["ecdsa", "pem", "pkcs8"] } p256 = { version = "0.10.1", features = ["ecdsa", "pem", "pkcs8"] }
pkcs1 = { version = "0.3.3", features = ["pem", "pkcs8"] } pkcs1 = { version = "0.3.3", features = ["pem", "pkcs8"] }
pkcs8 = { version = "0.8.0", features = ["pem"] } pkcs8 = { version = "0.8.0", features = ["pem"] }
@@ -29,7 +31,11 @@ sha2 = "0.10.1"
signature = "1.4.0" signature = "1.4.0"
thiserror = "1.0.30" thiserror = "1.0.30"
tokio = { version = "1.16.1", features = ["macros", "rt", "sync"] } tokio = { version = "1.16.1", features = ["macros", "rt", "sync"] }
tower = "0.4.11"
url = { version = "2.2.2", features = ["serde"] } url = { version = "2.2.2", features = ["serde"] }
zeroize = "1.4.3" zeroize = "1.4.3"
mas-iana = { path = "../iana" } mas-iana = { path = "../iana" }
[dev-dependencies]
mas-http = { path = "../http" }

View File

@@ -30,7 +30,7 @@ use crate::{jwk::JsonWebKey, SigningKeystore, VerifyingKeystore};
#[serde_as] #[serde_as]
#[skip_serializing_none] #[skip_serializing_none]
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize, Clone)]
pub struct JwtHeader { pub struct JwtHeader {
alg: JsonWebSignatureAlg, alg: JsonWebSignatureAlg,
@@ -163,7 +163,7 @@ where
Ok(format!("{}.{}", header, payload)) Ok(format!("{}.{}", header, payload))
} }
pub async fn sign<S: SigningKeystore>(&self, store: S) -> anyhow::Result<JsonWebTokenParts> { pub async fn sign<S: SigningKeystore>(&self, store: &S) -> anyhow::Result<JsonWebTokenParts> {
let payload = self.serialize()?; let payload = self.serialize()?;
let signature = store.sign(&self.header, payload.as_bytes()).await?; let signature = store.sign(&self.header, payload.as_bytes()).await?;
Ok(JsonWebTokenParts { payload, signature }) Ok(JsonWebTokenParts { payload, signature })
@@ -205,22 +205,24 @@ impl JsonWebTokenParts {
Ok(decoded) Ok(decoded)
} }
pub async fn verify<T, S: VerifyingKeystore>( pub fn verify<T, S: VerifyingKeystore>(
&self, &self,
decoded: &DecodedJsonWebToken<T>, decoded: &DecodedJsonWebToken<T>,
store: S, store: &S,
) -> anyhow::Result<()> { ) -> S::Future
store where
.verify(&decoded.header, self.payload.as_bytes(), &self.signature) S::Error: std::error::Error + Send + Sync + 'static,
.await?; {
store.verify(&decoded.header, self.payload.as_bytes(), &self.signature)
Ok(())
} }
pub async fn decode_and_verify<T: DeserializeOwned, S: VerifyingKeystore>( pub async fn decode_and_verify<T: DeserializeOwned, S: VerifyingKeystore>(
&self, &self,
store: S, store: &S,
) -> anyhow::Result<DecodedJsonWebToken<T>> { ) -> anyhow::Result<DecodedJsonWebToken<T>>
where
S::Error: std::error::Error + Send + Sync + 'static,
{
let decoded = self.decode()?; let decoded = self.decode()?;
self.verify(&decoded, store).await?; self.verify(&decoded, store).await?;
Ok(decoded) Ok(decoded)

View File

@@ -1,276 +0,0 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use anyhow::bail;
use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};
use digest::Digest;
use mas_iana::jose::{JsonWebKeyType, JsonWebSignatureAlg};
use rsa::{PublicKey, RsaPublicKey};
use sha2::{Sha256, Sha384, Sha512};
use signature::{Signature, Verifier};
use tokio::sync::RwLock;
use crate::{ExportJwks, JsonWebKeySet, JwtHeader, VerifyingKeystore};
pub struct StaticJwksStore {
key_set: JsonWebKeySet,
index: HashMap<(JsonWebKeyType, String), usize>,
}
impl StaticJwksStore {
#[must_use]
pub fn new(key_set: JsonWebKeySet) -> Self {
let index = key_set
.iter()
.enumerate()
.filter_map(|(index, key)| {
let kid = key.kid()?.to_string();
let kty = key.kty();
Some(((kty, kid), index))
})
.collect();
Self { key_set, index }
}
fn find_rsa_key(&self, kid: String) -> anyhow::Result<RsaPublicKey> {
let index = *self
.index
.get(&(JsonWebKeyType::Rsa, kid))
.ok_or_else(|| anyhow::anyhow!("key not found"))?;
let key = self
.key_set
.get(index)
.ok_or_else(|| anyhow::anyhow!("invalid index"))?;
let key = key.params().clone().try_into()?;
Ok(key)
}
fn find_ecdsa_key(&self, kid: String) -> anyhow::Result<ecdsa::VerifyingKey<p256::NistP256>> {
let index = *self
.index
.get(&(JsonWebKeyType::Ec, kid))
.ok_or_else(|| anyhow::anyhow!("key not found"))?;
let key = self
.key_set
.get(index)
.ok_or_else(|| anyhow::anyhow!("invalid index"))?;
let key = key.params().clone().try_into()?;
Ok(key)
}
}
#[async_trait]
impl VerifyingKeystore for &StaticJwksStore {
async fn verify(
self,
header: &JwtHeader,
payload: &[u8],
signature: &[u8],
) -> anyhow::Result<()> {
let kid = header
.kid()
.ok_or_else(|| anyhow::anyhow!("missing kid"))?
.to_string();
match header.alg() {
JsonWebSignatureAlg::Rs256 => {
let key = self.find_rsa_key(kid)?;
let digest = {
let mut digest = Sha256::new();
digest.update(&payload);
digest.finalize()
};
key.verify(
rsa::PaddingScheme::new_pkcs1v15_sign(Some(rsa::Hash::SHA2_256)),
&digest,
signature,
)?;
}
JsonWebSignatureAlg::Rs384 => {
let key = self.find_rsa_key(kid)?;
let digest = {
let mut digest = Sha384::new();
digest.update(&payload);
digest.finalize()
};
key.verify(
rsa::PaddingScheme::new_pkcs1v15_sign(Some(rsa::Hash::SHA2_384)),
&digest,
signature,
)?;
}
JsonWebSignatureAlg::Rs512 => {
let key = self.find_rsa_key(kid)?;
let digest = {
let mut digest = Sha512::new();
digest.update(&payload);
digest.finalize()
};
key.verify(
rsa::PaddingScheme::new_pkcs1v15_sign(Some(rsa::Hash::SHA2_512)),
&digest,
signature,
)?;
}
JsonWebSignatureAlg::Es256 => {
let key = self.find_ecdsa_key(kid)?;
let signature = ecdsa::Signature::from_bytes(signature)?;
key.verify(payload, &signature)?;
}
_ => bail!("unsupported algorithm"),
};
Ok(())
}
}
enum RemoteKeySet {
Pending,
Errored {
at: DateTime<Utc>,
error: anyhow::Error,
},
Fulfilled {
at: DateTime<Utc>,
store: StaticJwksStore,
},
}
impl Default for RemoteKeySet {
fn default() -> Self {
Self::Pending
}
}
impl RemoteKeySet {
fn fullfill(&mut self, key_set: JsonWebKeySet) {
*self = Self::Fulfilled {
at: Utc::now(),
store: StaticJwksStore::new(key_set),
}
}
fn error(&mut self, error: anyhow::Error) {
*self = Self::Errored {
at: Utc::now(),
error,
}
}
fn should_refresh(&self) -> bool {
let now = Utc::now();
match self {
Self::Pending => true,
Self::Errored { at, .. } if *at - now > Duration::minutes(5) => true,
Self::Fulfilled { at, .. } if *at - now > Duration::hours(1) => true,
_ => false,
}
}
fn should_force_refresh(&self) -> bool {
let now = Utc::now();
match self {
Self::Pending => true,
Self::Errored { at, .. } | Self::Fulfilled { at, .. }
if *at - now > Duration::minutes(5) =>
{
true
}
_ => false,
}
}
}
pub struct JwksStore<T>
where
T: ExportJwks,
{
exporter: T,
cache: RwLock<RemoteKeySet>,
}
impl<T: ExportJwks> JwksStore<T> {
pub fn new(exporter: T) -> Self {
Self {
exporter,
cache: RwLock::default(),
}
}
async fn should_refresh(&self) -> bool {
let cache = self.cache.read().await;
cache.should_refresh()
}
async fn refresh(&self) {
let mut cache = self.cache.write().await;
if cache.should_force_refresh() {
let jwks = self.exporter.export_jwks().await;
match jwks {
Ok(jwks) => cache.fullfill(jwks),
Err(err) => cache.error(err),
}
}
}
}
#[async_trait]
impl<T: ExportJwks + Send + Sync> VerifyingKeystore for &JwksStore<T> {
async fn verify(
self,
header: &JwtHeader,
payload: &[u8],
signature: &[u8],
) -> anyhow::Result<()> {
if self.should_refresh().await {
self.refresh().await;
}
let cache = self.cache.read().await;
// TODO: we could bubble up the underlying error here
let store = match &*cache {
RemoteKeySet::Pending => bail!("inconsistent cache state"),
RemoteKeySet::Errored { error, .. } => bail!("cache in error state {}", error),
RemoteKeySet::Fulfilled { store, .. } => store,
};
store.verify(header, payload, signature).await?;
Ok(())
}
}

View File

@@ -0,0 +1,160 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use chrono::{DateTime, Duration, Utc};
use futures_util::future::BoxFuture;
use thiserror::Error;
use tokio::sync::RwLock;
use tower::{
util::{BoxCloneService, ServiceExt},
BoxError, Service,
};
use super::StaticJwksStore;
use crate::{JsonWebKeySet, JwtHeader, VerifyingKeystore};
#[derive(Debug, Error)]
pub enum Error {
#[error("cache in inconsistent state")]
InconsistentCache,
#[error(transparent)]
Cached(Arc<BoxError>),
#[error("todo")]
Todo,
#[error(transparent)]
Verification(#[from] super::static_store::Error),
}
enum State<E> {
Pending,
Errored {
at: DateTime<Utc>,
error: E,
},
Fulfilled {
at: DateTime<Utc>,
store: StaticJwksStore,
},
}
impl<E> Default for State<E> {
fn default() -> Self {
Self::Pending
}
}
impl<E> State<E> {
fn fullfill(&mut self, key_set: JsonWebKeySet) {
*self = Self::Fulfilled {
at: Utc::now(),
store: StaticJwksStore::new(key_set),
}
}
fn error(&mut self, error: E) {
*self = Self::Errored {
at: Utc::now(),
error,
}
}
fn should_refresh(&self) -> bool {
let now = Utc::now();
match self {
Self::Pending => true,
Self::Errored { at, .. } if *at - now > Duration::minutes(5) => true,
Self::Fulfilled { at, .. } if *at - now > Duration::hours(1) => true,
_ => false,
}
}
fn should_force_refresh(&self) -> bool {
let now = Utc::now();
match self {
Self::Pending => true,
Self::Errored { at, .. } | Self::Fulfilled { at, .. }
if *at - now > Duration::minutes(5) =>
{
true
}
_ => false,
}
}
}
#[derive(Clone)]
pub struct DynamicJwksStore {
exporter: BoxCloneService<(), JsonWebKeySet, BoxError>,
cache: Arc<RwLock<State<Arc<BoxError>>>>,
}
impl DynamicJwksStore {
pub fn new<T>(exporter: T) -> Self
where
T: Service<(), Response = JsonWebKeySet, Error = BoxError> + Send + Clone + 'static,
T::Future: Send,
{
Self {
exporter: exporter.boxed_clone(),
cache: Arc::default(),
}
}
}
impl VerifyingKeystore for DynamicJwksStore {
type Error = Error;
type Future = BoxFuture<'static, Result<(), Self::Error>>;
fn verify(&self, header: &JwtHeader, payload: &[u8], signature: &[u8]) -> Self::Future {
let cache = self.cache.clone();
let exporter = self.exporter.clone();
let header = header.clone();
let payload = payload.to_owned();
let signature = signature.to_owned();
let fut = async move {
if cache.read().await.should_refresh() {
let mut cache = cache.write().await;
if cache.should_force_refresh() {
let jwks = async move { exporter.ready_oneshot().await?.call(()).await }.await;
match jwks {
Ok(jwks) => cache.fullfill(jwks),
Err(err) => cache.error(Arc::new(err)),
}
}
}
let cache = cache.read().await;
// TODO: we could bubble up the underlying error here
let store = match &*cache {
State::Pending => return Err(Error::InconsistentCache),
State::Errored { error, .. } => return Err(Error::Cached(error.clone())),
State::Fulfilled { store, .. } => store,
};
store.verify(&header, &payload, &signature).await?;
Ok(())
};
Box::pin(fut)
}
}

View File

@@ -0,0 +1,18 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
mod dynamic_store;
mod static_store;
pub use self::{dynamic_store::DynamicJwksStore, static_store::StaticJwksStore};

View File

@@ -0,0 +1,196 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{collections::HashMap, future::Ready};
use digest::Digest;
use mas_iana::jose::{JsonWebKeyType, JsonWebSignatureAlg};
use rsa::{PublicKey, RsaPublicKey};
use sha2::{Sha256, Sha384, Sha512};
use signature::{Signature, Verifier};
use thiserror::Error;
use crate::{JsonWebKeySet, JwtHeader, VerifyingKeystore};
#[derive(Debug, Error)]
pub enum Error {
#[error("key not found")]
KeyNotFound,
#[error("invalid index")]
InvalidIndex,
#[error(r#"missing "kid" field in header"#)]
MissingKid,
#[error(transparent)]
Rsa(#[from] rsa::errors::Error),
#[error("unsupported algorithm {alg}")]
UnsupportedAlgorithm { alg: JsonWebSignatureAlg },
#[error(transparent)]
Signature(#[from] signature::Error),
#[error("invalid {kty} key {kid}")]
InvalidKey {
kty: JsonWebKeyType,
kid: String,
source: anyhow::Error,
},
}
pub struct StaticJwksStore {
key_set: JsonWebKeySet,
index: HashMap<(JsonWebKeyType, String), usize>,
}
impl StaticJwksStore {
#[must_use]
pub fn new(key_set: JsonWebKeySet) -> Self {
let index = key_set
.iter()
.enumerate()
.filter_map(|(index, key)| {
let kid = key.kid()?.to_string();
let kty = key.kty();
Some(((kty, kid), index))
})
.collect();
Self { key_set, index }
}
fn find_rsa_key(&self, kid: String) -> Result<RsaPublicKey, Error> {
let index = *self
.index
.get(&(JsonWebKeyType::Rsa, kid.clone()))
.ok_or(Error::KeyNotFound)?;
let key = self.key_set.get(index).ok_or(Error::InvalidIndex)?;
let key = key
.params()
.clone()
.try_into()
.map_err(|source| Error::InvalidKey {
kty: JsonWebKeyType::Rsa,
kid,
source,
})?;
Ok(key)
}
fn find_ecdsa_key(&self, kid: String) -> Result<ecdsa::VerifyingKey<p256::NistP256>, Error> {
let index = *self
.index
.get(&(JsonWebKeyType::Ec, kid.clone()))
.ok_or(Error::KeyNotFound)?;
let key = self.key_set.get(index).ok_or(Error::InvalidIndex)?;
let key = key
.params()
.clone()
.try_into()
.map_err(|source| Error::InvalidKey {
kty: JsonWebKeyType::Ec,
kid,
source,
})?;
Ok(key)
}
fn verify_sync(
&self,
header: &JwtHeader,
payload: &[u8],
signature: &[u8],
) -> Result<(), Error> {
let kid = header.kid().ok_or(Error::MissingKid)?.to_string();
match header.alg() {
JsonWebSignatureAlg::Rs256 => {
let key = self.find_rsa_key(kid)?;
let digest = {
let mut digest = Sha256::new();
digest.update(&payload);
digest.finalize()
};
key.verify(
rsa::PaddingScheme::new_pkcs1v15_sign(Some(rsa::Hash::SHA2_256)),
&digest,
signature,
)?;
}
JsonWebSignatureAlg::Rs384 => {
let key = self.find_rsa_key(kid)?;
let digest = {
let mut digest = Sha384::new();
digest.update(&payload);
digest.finalize()
};
key.verify(
rsa::PaddingScheme::new_pkcs1v15_sign(Some(rsa::Hash::SHA2_384)),
&digest,
signature,
)?;
}
JsonWebSignatureAlg::Rs512 => {
let key = self.find_rsa_key(kid)?;
let digest = {
let mut digest = Sha512::new();
digest.update(&payload);
digest.finalize()
};
key.verify(
rsa::PaddingScheme::new_pkcs1v15_sign(Some(rsa::Hash::SHA2_512)),
&digest,
signature,
)?;
}
JsonWebSignatureAlg::Es256 => {
let key = self.find_ecdsa_key(kid)?;
let signature = ecdsa::Signature::from_bytes(signature)?;
key.verify(payload, &signature)?;
}
alg => return Err(Error::UnsupportedAlgorithm { alg }),
};
Ok(())
}
}
impl VerifyingKeystore for StaticJwksStore {
type Error = Error;
type Future = Ready<Result<(), Self::Error>>;
fn verify(&self, header: &JwtHeader, payload: &[u8], signature: &[u8]) -> Self::Future {
std::future::ready(self.verify_sync(header, payload, signature))
}
}

View File

@@ -18,8 +18,8 @@ mod static_keystore;
mod traits; mod traits;
pub use self::{ pub use self::{
jwks::{JwksStore, StaticJwksStore}, jwks::{DynamicJwksStore, StaticJwksStore},
shared_secret::SharedSecret, shared_secret::SharedSecret,
static_keystore::StaticKeystore, static_keystore::StaticKeystore,
traits::{ExportJwks, SigningKeystore, VerifyingKeystore}, traits::{SigningKeystore, VerifyingKeystore},
}; };

View File

@@ -12,17 +12,31 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::collections::HashSet; use std::{collections::HashSet, future::Ready};
use anyhow::bail; use anyhow::bail;
use async_trait::async_trait; use async_trait::async_trait;
use digest::{InvalidLength, MacError};
use hmac::{Hmac, Mac}; use hmac::{Hmac, Mac};
use mas_iana::jose::JsonWebSignatureAlg; use mas_iana::jose::JsonWebSignatureAlg;
use sha2::{Sha256, Sha384, Sha512}; use sha2::{Sha256, Sha384, Sha512};
use thiserror::Error;
use super::{SigningKeystore, VerifyingKeystore}; use super::{SigningKeystore, VerifyingKeystore};
use crate::JwtHeader; use crate::JwtHeader;
#[derive(Debug, Error)]
pub enum Error {
#[error("invalid key")]
InvalidKey(#[from] InvalidLength),
#[error("unsupported algorithm {alg}")]
UnsupportedAlgorithm { alg: JsonWebSignatureAlg },
#[error("signature verification failed")]
Verification(#[from] MacError),
}
pub struct SharedSecret<'a> { pub struct SharedSecret<'a> {
inner: &'a [u8], inner: &'a [u8],
} }
@@ -33,11 +47,42 @@ impl<'a> SharedSecret<'a> {
inner: source.as_ref(), inner: source.as_ref(),
} }
} }
fn verify_sync(
&self,
header: &JwtHeader,
payload: &[u8],
signature: &[u8],
) -> Result<(), Error> {
match header.alg() {
JsonWebSignatureAlg::Hs256 => {
let mut mac = Hmac::<Sha256>::new_from_slice(self.inner)?;
mac.update(payload);
mac.verify(signature.into())?;
}
JsonWebSignatureAlg::Hs384 => {
let mut mac = Hmac::<Sha384>::new_from_slice(self.inner)?;
mac.update(payload);
mac.verify(signature.into())?;
}
JsonWebSignatureAlg::Hs512 => {
let mut mac = Hmac::<Sha512>::new_from_slice(self.inner)?;
mac.update(payload);
mac.verify(signature.into())?;
}
alg => return Err(Error::UnsupportedAlgorithm { alg }),
};
Ok(())
}
} }
#[async_trait] #[async_trait]
impl<'a> SigningKeystore for &SharedSecret<'a> { impl<'a> SigningKeystore for SharedSecret<'a> {
fn supported_algorithms(self) -> HashSet<JsonWebSignatureAlg> { fn supported_algorithms(&self) -> HashSet<JsonWebSignatureAlg> {
let mut algorithms = HashSet::with_capacity(3); let mut algorithms = HashSet::with_capacity(3);
algorithms.insert(JsonWebSignatureAlg::Hs256); algorithms.insert(JsonWebSignatureAlg::Hs256);
@@ -47,7 +92,7 @@ impl<'a> SigningKeystore for &SharedSecret<'a> {
algorithms algorithms
} }
async fn prepare_header(self, alg: JsonWebSignatureAlg) -> anyhow::Result<JwtHeader> { async fn prepare_header(&self, alg: JsonWebSignatureAlg) -> anyhow::Result<JwtHeader> {
if !matches!( if !matches!(
alg, alg,
JsonWebSignatureAlg::Hs256 | JsonWebSignatureAlg::Hs384 | JsonWebSignatureAlg::Hs512, JsonWebSignatureAlg::Hs256 | JsonWebSignatureAlg::Hs384 | JsonWebSignatureAlg::Hs512,
@@ -58,7 +103,7 @@ impl<'a> SigningKeystore for &SharedSecret<'a> {
Ok(JwtHeader::new(alg)) Ok(JwtHeader::new(alg))
} }
async fn sign(self, header: &JwtHeader, msg: &[u8]) -> anyhow::Result<Vec<u8>> { async fn sign(&self, header: &JwtHeader, msg: &[u8]) -> anyhow::Result<Vec<u8>> {
// TODO: do the signing in a blocking task // TODO: do the signing in a blocking task
// TODO: should we bail out if the key is too small? // TODO: should we bail out if the key is too small?
let signature = match header.alg() { let signature = match header.alg() {
@@ -87,38 +132,12 @@ impl<'a> SigningKeystore for &SharedSecret<'a> {
} }
} }
#[async_trait] impl<'a> VerifyingKeystore for SharedSecret<'a> {
impl<'a> VerifyingKeystore for &SharedSecret<'a> { type Error = Error;
async fn verify( type Future = Ready<Result<(), Self::Error>>;
self,
header: &JwtHeader,
payload: &[u8],
signature: &[u8],
) -> anyhow::Result<()> {
// TODO: do the verification in a blocking task
match header.alg() {
JsonWebSignatureAlg::Hs256 => {
let mut mac = Hmac::<Sha256>::new_from_slice(self.inner)?;
mac.update(payload);
mac.verify(signature.try_into()?)?;
}
JsonWebSignatureAlg::Hs384 => { fn verify(&self, header: &JwtHeader, payload: &[u8], signature: &[u8]) -> Self::Future {
let mut mac = Hmac::<Sha384>::new_from_slice(self.inner)?; std::future::ready(self.verify_sync(header, payload, signature))
mac.update(payload);
mac.verify(signature.try_into()?)?;
}
JsonWebSignatureAlg::Hs512 => {
let mut mac = Hmac::<Sha512>::new_from_slice(self.inner)?;
mac.update(payload);
mac.verify(signature.try_into()?)?;
}
_ => bail!("unsupported algorithm"),
};
Ok(())
} }
} }

View File

@@ -12,7 +12,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::collections::{HashMap, HashSet}; use std::{
collections::{HashMap, HashSet},
convert::Infallible,
future::Ready,
task::Poll,
};
use anyhow::bail; use anyhow::bail;
use async_trait::async_trait; use async_trait::async_trait;
@@ -26,8 +31,9 @@ use pkcs8::{DecodePrivateKey, EncodePublicKey};
use rsa::{PublicKey as _, RsaPrivateKey, RsaPublicKey}; use rsa::{PublicKey as _, RsaPrivateKey, RsaPublicKey};
use sha2::{Sha256, Sha384, Sha512}; use sha2::{Sha256, Sha384, Sha512};
use signature::{Signature, Signer, Verifier}; use signature::{Signature, Signer, Verifier};
use tower::Service;
use super::{ExportJwks, SigningKeystore, VerifyingKeystore}; use super::{SigningKeystore, VerifyingKeystore};
use crate::{JsonWebKey, JsonWebKeySet, JwtHeader}; use crate::{JsonWebKey, JsonWebKeySet, JwtHeader};
// Generate with // Generate with
@@ -123,135 +129,9 @@ impl StaticKeystore {
self.es256_keys.insert(kid, key); self.es256_keys.insert(kid, key);
Ok(()) Ok(())
} }
}
#[async_trait] fn verify_sync(
impl SigningKeystore for &StaticKeystore { &self,
fn supported_algorithms(self) -> HashSet<JsonWebSignatureAlg> {
let has_rsa = !self.rsa_keys.is_empty();
let has_es256 = !self.es256_keys.is_empty();
let capacity = (if has_rsa { 3 } else { 0 }) + (if has_es256 { 1 } else { 0 });
let mut algorithms = HashSet::with_capacity(capacity);
if has_rsa {
algorithms.insert(JsonWebSignatureAlg::Rs256);
algorithms.insert(JsonWebSignatureAlg::Rs384);
algorithms.insert(JsonWebSignatureAlg::Rs512);
}
if has_es256 {
algorithms.insert(JsonWebSignatureAlg::Es256);
}
algorithms
}
async fn prepare_header(self, alg: JsonWebSignatureAlg) -> anyhow::Result<JwtHeader> {
let header = JwtHeader::new(alg);
let kid = match alg {
JsonWebSignatureAlg::Rs256
| JsonWebSignatureAlg::Rs384
| JsonWebSignatureAlg::Rs512 => self
.rsa_keys
.keys()
.next()
.ok_or_else(|| anyhow::anyhow!("no RSA keys in keystore"))?,
JsonWebSignatureAlg::Es256 => self
.es256_keys
.keys()
.next()
.ok_or_else(|| anyhow::anyhow!("no ECDSA keys in keystore"))?,
_ => bail!("unsupported algorithm"),
};
Ok(header.with_kid(kid))
}
async fn sign(self, header: &JwtHeader, msg: &[u8]) -> anyhow::Result<Vec<u8>> {
let kid = header
.kid()
.ok_or_else(|| anyhow::anyhow!("missing kid from the JWT header"))?;
// TODO: do the signing in a blocking task
let signature = match header.alg() {
JsonWebSignatureAlg::Rs256 => {
let key = self
.rsa_keys
.get(kid)
.ok_or_else(|| anyhow::anyhow!("RSA key not found in key store"))?;
let digest = {
let mut digest = Sha256::new();
digest.update(&msg);
digest.finalize()
};
key.sign(
rsa::PaddingScheme::new_pkcs1v15_sign(Some(rsa::Hash::SHA2_256)),
&digest,
)?
}
JsonWebSignatureAlg::Rs384 => {
let key = self
.rsa_keys
.get(kid)
.ok_or_else(|| anyhow::anyhow!("RSA key not found in key store"))?;
let digest = {
let mut digest = Sha384::new();
digest.update(&msg);
digest.finalize()
};
key.sign(
rsa::PaddingScheme::new_pkcs1v15_sign(Some(rsa::Hash::SHA2_384)),
&digest,
)?
}
JsonWebSignatureAlg::Rs512 => {
let key = self
.rsa_keys
.get(kid)
.ok_or_else(|| anyhow::anyhow!("RSA key not found in key store"))?;
let digest = {
let mut digest = Sha512::new();
digest.update(&msg);
digest.finalize()
};
key.sign(
rsa::PaddingScheme::new_pkcs1v15_sign(Some(rsa::Hash::SHA2_512)),
&digest,
)?
}
JsonWebSignatureAlg::Es256 => {
let key = self
.es256_keys
.get(kid)
.ok_or_else(|| anyhow::anyhow!("ECDSA key not found in key store"))?;
let signature = key.try_sign(msg)?;
let signature: &[u8] = signature.as_ref();
signature.to_vec()
}
_ => bail!("Unsupported algorithm"),
};
Ok(signature)
}
}
#[async_trait]
impl VerifyingKeystore for &StaticKeystore {
async fn verify(
self,
header: &JwtHeader, header: &JwtHeader,
payload: &[u8], payload: &[u8],
signature: &[u8], signature: &[u8],
@@ -344,8 +224,147 @@ impl VerifyingKeystore for &StaticKeystore {
} }
#[async_trait] #[async_trait]
impl ExportJwks for StaticKeystore { impl SigningKeystore for StaticKeystore {
async fn export_jwks(&self) -> anyhow::Result<JsonWebKeySet> { fn supported_algorithms(&self) -> HashSet<JsonWebSignatureAlg> {
let has_rsa = !self.rsa_keys.is_empty();
let has_es256 = !self.es256_keys.is_empty();
let capacity = (if has_rsa { 3 } else { 0 }) + (if has_es256 { 1 } else { 0 });
let mut algorithms = HashSet::with_capacity(capacity);
if has_rsa {
algorithms.insert(JsonWebSignatureAlg::Rs256);
algorithms.insert(JsonWebSignatureAlg::Rs384);
algorithms.insert(JsonWebSignatureAlg::Rs512);
}
if has_es256 {
algorithms.insert(JsonWebSignatureAlg::Es256);
}
algorithms
}
async fn prepare_header(&self, alg: JsonWebSignatureAlg) -> anyhow::Result<JwtHeader> {
let header = JwtHeader::new(alg);
let kid = match alg {
JsonWebSignatureAlg::Rs256
| JsonWebSignatureAlg::Rs384
| JsonWebSignatureAlg::Rs512 => self
.rsa_keys
.keys()
.next()
.ok_or_else(|| anyhow::anyhow!("no RSA keys in keystore"))?,
JsonWebSignatureAlg::Es256 => self
.es256_keys
.keys()
.next()
.ok_or_else(|| anyhow::anyhow!("no ECDSA keys in keystore"))?,
_ => bail!("unsupported algorithm"),
};
Ok(header.with_kid(kid))
}
async fn sign(&self, header: &JwtHeader, msg: &[u8]) -> anyhow::Result<Vec<u8>> {
let kid = header
.kid()
.ok_or_else(|| anyhow::anyhow!("missing kid from the JWT header"))?;
// TODO: do the signing in a blocking task
let signature = match header.alg() {
JsonWebSignatureAlg::Rs256 => {
let key = self
.rsa_keys
.get(kid)
.ok_or_else(|| anyhow::anyhow!("RSA key not found in key store"))?;
let digest = {
let mut digest = Sha256::new();
digest.update(&msg);
digest.finalize()
};
key.sign(
rsa::PaddingScheme::new_pkcs1v15_sign(Some(rsa::Hash::SHA2_256)),
&digest,
)?
}
JsonWebSignatureAlg::Rs384 => {
let key = self
.rsa_keys
.get(kid)
.ok_or_else(|| anyhow::anyhow!("RSA key not found in key store"))?;
let digest = {
let mut digest = Sha384::new();
digest.update(&msg);
digest.finalize()
};
key.sign(
rsa::PaddingScheme::new_pkcs1v15_sign(Some(rsa::Hash::SHA2_384)),
&digest,
)?
}
JsonWebSignatureAlg::Rs512 => {
let key = self
.rsa_keys
.get(kid)
.ok_or_else(|| anyhow::anyhow!("RSA key not found in key store"))?;
let digest = {
let mut digest = Sha512::new();
digest.update(&msg);
digest.finalize()
};
key.sign(
rsa::PaddingScheme::new_pkcs1v15_sign(Some(rsa::Hash::SHA2_512)),
&digest,
)?
}
JsonWebSignatureAlg::Es256 => {
let key = self
.es256_keys
.get(kid)
.ok_or_else(|| anyhow::anyhow!("ECDSA key not found in key store"))?;
let signature = key.try_sign(msg)?;
let signature: &[u8] = signature.as_ref();
signature.to_vec()
}
_ => bail!("Unsupported algorithm"),
};
Ok(signature)
}
}
impl VerifyingKeystore for StaticKeystore {
type Error = anyhow::Error;
type Future = Ready<Result<(), Self::Error>>;
fn verify(&self, header: &JwtHeader, msg: &[u8], signature: &[u8]) -> Self::Future {
std::future::ready(self.verify_sync(header, msg, signature))
}
}
impl Service<()> for &StaticKeystore {
type Future = Ready<Result<Self::Response, Self::Error>>;
type Response = JsonWebKeySet;
type Error = Infallible;
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, _req: ()) -> Self::Future {
let rsa = self.rsa_keys.iter().map(|(kid, key)| { let rsa = self.rsa_keys.iter().map(|(kid, key)| {
let pubkey = RsaPublicKey::from(key); let pubkey = RsaPublicKey::from(key);
JsonWebKey::new(pubkey.into()) JsonWebKey::new(pubkey.into())
@@ -362,7 +381,7 @@ impl ExportJwks for StaticKeystore {
}); });
let keys = rsa.chain(es256).collect(); let keys = rsa.chain(es256).collect();
Ok(JsonWebKeySet::new(keys)) std::future::ready(Ok(JsonWebKeySet::new(keys)))
} }
} }

View File

@@ -12,28 +12,78 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::collections::HashSet; use std::{collections::HashSet, future::Future, sync::Arc};
use async_trait::async_trait; use async_trait::async_trait;
use futures_util::{
future::{Either, MapErr},
TryFutureExt,
};
use mas_iana::jose::JsonWebSignatureAlg; use mas_iana::jose::JsonWebSignatureAlg;
use thiserror::Error;
use crate::{JsonWebKeySet, JwtHeader}; use crate::JwtHeader;
#[async_trait] #[async_trait]
pub trait SigningKeystore { pub trait SigningKeystore {
fn supported_algorithms(self) -> HashSet<JsonWebSignatureAlg>; fn supported_algorithms(&self) -> HashSet<JsonWebSignatureAlg>;
async fn prepare_header(self, alg: JsonWebSignatureAlg) -> anyhow::Result<JwtHeader>; async fn prepare_header(&self, alg: JsonWebSignatureAlg) -> anyhow::Result<JwtHeader>;
async fn sign(self, header: &JwtHeader, msg: &[u8]) -> anyhow::Result<Vec<u8>>; async fn sign(&self, header: &JwtHeader, msg: &[u8]) -> anyhow::Result<Vec<u8>>;
} }
#[async_trait]
pub trait VerifyingKeystore { pub trait VerifyingKeystore {
async fn verify(self, header: &JwtHeader, msg: &[u8], signature: &[u8]) -> anyhow::Result<()>; type Error;
type Future: Future<Output = Result<(), Self::Error>>;
fn verify(&self, header: &JwtHeader, msg: &[u8], signature: &[u8]) -> Self::Future;
} }
#[async_trait] #[derive(Debug, Error)]
pub trait ExportJwks { pub enum EitherError<A, B> {
async fn export_jwks(&self) -> anyhow::Result<JsonWebKeySet>; #[error(transparent)]
Left(A),
#[error(transparent)]
Right(B),
}
impl<L, R> VerifyingKeystore for Either<L, R>
where
L: VerifyingKeystore,
R: VerifyingKeystore,
{
type Error = EitherError<L::Error, R::Error>;
#[allow(clippy::type_complexity)]
type Future = Either<
MapErr<L::Future, fn(L::Error) -> Self::Error>,
MapErr<R::Future, fn(R::Error) -> Self::Error>,
>;
fn verify(&self, header: &JwtHeader, msg: &[u8], signature: &[u8]) -> Self::Future {
match self {
Either::Left(left) => Either::Left(
left.verify(header, msg, signature)
.map_err(EitherError::Left),
),
Either::Right(right) => Either::Right(
right
.verify(header, msg, signature)
.map_err(EitherError::Right),
),
}
}
}
impl<T> VerifyingKeystore for Arc<T>
where
T: VerifyingKeystore,
{
type Error = T::Error;
type Future = T::Future;
fn verify(&self, header: &JwtHeader, msg: &[u8], signature: &[u8]) -> Self::Future {
self.as_ref().verify(header, msg, signature)
}
} }

View File

@@ -26,7 +26,7 @@ pub use self::{
jwk::{JsonWebKey, JsonWebKeySet}, jwk::{JsonWebKey, JsonWebKeySet},
jwt::{DecodedJsonWebToken, JsonWebTokenParts, JwtHeader}, jwt::{DecodedJsonWebToken, JsonWebTokenParts, JwtHeader},
keystore::{ keystore::{
ExportJwks, JwksStore, SharedSecret, SigningKeystore, StaticJwksStore, StaticKeystore, DynamicJwksStore, SharedSecret, SigningKeystore, StaticJwksStore, StaticKeystore,
VerifyingKeystore, VerifyingKeystore,
}, },
}; };

View File

@@ -27,6 +27,7 @@ rand = "0.8.4"
mime = "0.3.16" mime = "0.3.16"
bincode = "1.3.3" bincode = "1.3.3"
crc = "2.1.0" crc = "2.1.0"
url = "2.2.2"
oauth2-types = { path = "../oauth2-types" } oauth2-types = { path = "../oauth2-types" }
mas-config = { path = "../config" } mas-config = { path = "../config" }
@@ -35,4 +36,6 @@ mas-data-model = { path = "../data-model" }
mas-storage = { path = "../storage" } mas-storage = { path = "../storage" }
mas-jose = { path = "../jose" } mas-jose = { path = "../jose" }
mas-iana = { path = "../iana" } mas-iana = { path = "../iana" }
url = "2.2.2"
[dev-dependencies]
tower = { version = "0.4.11", features = ["util"] }

View File

@@ -95,6 +95,7 @@ enum ClientAuthenticationError {
impl Reject for ClientAuthenticationError {} impl Reject for ClientAuthenticationError {}
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
#[tracing::instrument(skip_all, fields(enduser.id), err(Debug))]
async fn authenticate_client<T>( async fn authenticate_client<T>(
clients_config: ClientsConfig, clients_config: ClientsConfig,
audience: String, audience: String,
@@ -204,7 +205,8 @@ async fn authenticate_client<T>(
let auth_method = match &client.client_auth_method { let auth_method = match &client.client_auth_method {
ClientAuthMethodConfig::PrivateKeyJwt(jwks) => { ClientAuthMethodConfig::PrivateKeyJwt(jwks) => {
let store = jwks.key_store(); let store = jwks.key_store();
token.verify(&decoded, &store).await.wrap_error()?; let fut = token.verify(&decoded, &store);
fut.await.wrap_error()?;
OAuthClientAuthenticationMethod::PrivateKeyJwt OAuthClientAuthenticationMethod::PrivateKeyJwt
} }
@@ -239,6 +241,8 @@ async fn authenticate_client<T>(
} }
}; };
tracing::Span::current().record("enduser.id", &client.client_id.as_str());
Ok((auth_method, client.clone(), body)) Ok((auth_method, client.clone(), body))
} }
@@ -291,8 +295,9 @@ struct ClientAuthForm<T> {
mod tests { mod tests {
use headers::authorization::Credentials; use headers::authorization::Credentials;
use mas_config::{ClientAuthMethodConfig, ConfigurationSection}; use mas_config::{ClientAuthMethodConfig, ConfigurationSection};
use mas_jose::{ExportJwks, SigningKeystore, StaticKeystore}; use mas_jose::{SigningKeystore, StaticKeystore};
use serde_json::json; use serde_json::json;
use tower::{Service, ServiceExt};
use super::*; use super::*;
@@ -343,7 +348,8 @@ mod tests {
}); });
let store = client_private_keystore(); let store = client_private_keystore();
let jwks = store.export_jwks().await.unwrap(); let jwks = (&store).ready().await.unwrap().call(()).await.unwrap();
//let jwks = store.export_jwks().await.unwrap();
config.push(ClientConfig { config.push(ClientConfig {
client_id: "private-key-jwt".to_string(), client_id: "private-key-jwt".to_string(),
client_auth_method: ClientAuthMethodConfig::PrivateKeyJwt(jwks.clone().into()), client_auth_method: ClientAuthMethodConfig::PrivateKeyJwt(jwks.clone().into()),