1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Axum migration: WIP client authentication

This commit is contained in:
Quentin Gliech
2022-04-04 10:16:40 +02:00
parent 9dad21475e
commit ed49624c3a
9 changed files with 638 additions and 11 deletions

View File

@ -15,14 +15,22 @@ data-encoding = "2.3.2"
futures-util = "0.3.21"
headers = "0.3.7"
http = "0.2.6"
mime = "0.3.16"
rand = "0.8.5"
serde = "1.0.136"
serde_with = "1.12.0"
serde_urlencoded = "0.7.1"
serde_json = "1.0.79"
sqlx = "0.5.11"
thiserror = "1.0.30"
tokio = "1.17.0"
tracing = "0.1.32"
url = "2.2.2"
# TODO: remove the config dependency by moving out the encrypter
mas-config = { path = "../config" }
mas-templates = { path = "../templates" }
mas-storage = { path = "../storage" }
mas-data-model = { path = "../data-model" }
mas-jose = { path = "../jose" }
mas-iana = { path = "../iana" }

View File

@ -0,0 +1,560 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use async_trait::async_trait;
use axum::{
body::HttpBody,
extract::{
rejection::{FailedToDeserializeQueryString, FormRejection, TypedHeaderRejectionReason},
Form, FromRequest, RequestParts, TypedHeader,
},
response::IntoResponse,
};
use headers::{authorization::Basic, Authorization};
use mas_config::Encrypter;
use mas_data_model::{Client, StorageBackend};
use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_jose::{
DecodedJsonWebToken, DynamicJwksStore, Either, JsonWebTokenParts, JwtHeader, SharedSecret,
StaticJwksStore,
};
use mas_storage::{
oauth2::client::{lookup_client_by_client_id, ClientFetchError},
PostgresqlBackend,
};
use serde::{de::DeserializeOwned, Deserialize};
use serde_json::Value;
use sqlx::PgExecutor;
static JWT_BEARER_CLIENT_ASSERTION: &str = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer";
#[derive(Deserialize)]
struct AuthorizedForm<F = ()> {
client_id: Option<String>,
client_secret: Option<String>,
client_assertion_type: Option<String>,
client_assertion: Option<String>,
#[serde(flatten)]
inner: F,
}
#[derive(Debug, PartialEq, Eq)]
pub enum Credentials {
None {
client_id: String,
},
ClientSecretBasic {
client_id: String,
client_secret: String,
},
ClientSecretPost {
client_id: String,
client_secret: String,
},
ClientAssertionJwtBearer {
client_id: String,
jwt: JsonWebTokenParts,
header: Box<JwtHeader>,
claims: HashMap<String, Value>,
},
}
impl Credentials {
pub async fn fetch(
&self,
executor: impl PgExecutor<'_>,
) -> Result<Client<PostgresqlBackend>, ClientFetchError> {
let client_id = match self {
Credentials::None { client_id }
| Credentials::ClientSecretBasic { client_id, .. }
| Credentials::ClientSecretPost { client_id, .. }
| Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id,
};
lookup_client_by_client_id(executor, client_id).await
}
pub async fn verify<S: StorageBackend>(
&self,
encrypter: &Encrypter,
method: OAuthClientAuthenticationMethod,
client: &Client<S>,
) -> Result<(), CredentialsVerificationError> {
match (self, method) {
(Credentials::None { .. }, OAuthClientAuthenticationMethod::None) => {}
(
Credentials::ClientSecretPost { client_secret, .. },
OAuthClientAuthenticationMethod::ClientSecretPost,
)
| (
Credentials::ClientSecretBasic { client_secret, .. },
OAuthClientAuthenticationMethod::ClientSecretBasic,
) => {
// Decrypt the client_secret
let encrypted_client_secret = client
.encrypted_client_secret
.as_ref()
.ok_or(CredentialsVerificationError::InvalidClientConfig)?;
let decrypted_client_secret = encrypter
.decrypt_string(encrypted_client_secret)
.map_err(|_e| CredentialsVerificationError::DecryptionError)?;
// Check if the client_secret matches
if client_secret.as_bytes() != decrypted_client_secret {
return Err(CredentialsVerificationError::ClientSecretMismatch);
}
}
(
Credentials::ClientAssertionJwtBearer { jwt, header, .. },
OAuthClientAuthenticationMethod::ClientSecretJwt,
) => {
// Get the client JWKS
let jwks = client
.jwks
.as_ref()
.ok_or(CredentialsVerificationError::InvalidClientConfig)?;
let store: Either<StaticJwksStore, DynamicJwksStore> = jwks_key_store(jwks);
jwt.verify(header, &store)
.await
.map_err(|_| CredentialsVerificationError::InvalidAssertionSignature)?;
}
(
Credentials::ClientAssertionJwtBearer { jwt, header, .. },
OAuthClientAuthenticationMethod::PrivateKeyJwt,
) => {
// Decrypt the client_secret
let encrypted_client_secret = client
.encrypted_client_secret
.as_ref()
.ok_or(CredentialsVerificationError::InvalidClientConfig)?;
let decrypted_client_secret = encrypter
.decrypt_string(encrypted_client_secret)
.map_err(|_e| CredentialsVerificationError::DecryptionError)?;
let store = SharedSecret::new(&decrypted_client_secret);
jwt.verify(header, &store)
.await
.map_err(|_| CredentialsVerificationError::InvalidAssertionSignature)?;
}
(_, _) => {
return Err(CredentialsVerificationError::AuthenticationMethodMismatch);
}
};
Ok(())
}
}
fn jwks_key_store(
_jwks: &mas_data_model::JwksOrJwksUri,
) -> Either<StaticJwksStore, DynamicJwksStore> {
todo!()
}
pub enum CredentialsVerificationError {
DecryptionError,
InvalidClientConfig,
ClientSecretMismatch,
AuthenticationMethodMismatch,
InvalidAssertionSignature,
}
#[derive(Debug, PartialEq, Eq)]
pub struct ClientAuthorization<F = ()> {
credentials: Credentials,
form: Option<F>,
}
#[derive(Debug)]
pub enum ClientAuthorizationError {
InvalidHeader,
BadForm(FailedToDeserializeQueryString),
ClientIdMismatch { credential: String, form: String },
UnsupportedClientAssertion { client_assertion_type: String },
MissingCredentials,
InvalidRequest,
InvalidAssertion,
InternalError(Box<dyn std::error::Error>),
}
impl IntoResponse for ClientAuthorizationError {
fn into_response(self) -> axum::response::Response {
todo!()
}
}
#[async_trait]
impl<B, F> FromRequest<B> for ClientAuthorization<F>
where
B: Send + HttpBody,
B::Data: Send,
B::Error: std::error::Error + Send + Sync + 'static,
F: DeserializeOwned,
{
type Rejection = ClientAuthorizationError;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let header = TypedHeader::<Authorization<Basic>>::from_request(req).await;
// Take the Authorization header
let credentials_from_header = match header {
Ok(header) => Some((header.username().to_string(), header.password().to_string())),
Err(err) => match err.reason() {
// If it's missing it is fine
TypedHeaderRejectionReason::Missing => None,
// If the header could not be parsed, return the error
TypedHeaderRejectionReason::Error(_) => {
return Err(ClientAuthorizationError::InvalidHeader)
}
},
};
// Take the form value
let (
client_id_from_form,
client_secret_from_form,
client_assertion_type,
client_assertion,
form,
) = match Form::<AuthorizedForm<F>>::from_request(req).await {
Ok(Form(form)) => (
form.client_id,
form.client_secret,
form.client_assertion_type,
form.client_assertion,
Some(form.inner),
),
// If it is not a form, continue
Err(FormRejection::InvalidFormContentType(_err)) => (None, None, None, None, None),
// If the form could not be read, return a Bad Request error
Err(FormRejection::FailedToDeserializeQueryString(err)) => {
return Err(ClientAuthorizationError::BadForm(err))
}
// Other errors (body read twice, byte stream broke) return an internal error
Err(e) => return Err(ClientAuthorizationError::InternalError(Box::new(e))),
};
// And now, figure out the actual auth method
let credentials = match (
credentials_from_header,
client_id_from_form,
client_secret_from_form,
client_assertion_type,
client_assertion,
) {
(Some((client_id, client_secret)), client_id_from_form, None, None, None) => {
if let Some(client_id_from_form) = client_id_from_form {
// If the client_id was in the body, verify it matches with the header
if client_id != client_id_from_form {
return Err(ClientAuthorizationError::ClientIdMismatch {
credential: client_id,
form: client_id_from_form,
});
}
}
Credentials::ClientSecretBasic {
client_id,
client_secret,
}
}
(None, Some(client_id), Some(client_secret), None, None) => {
// Got both client_id and client_secret from the form
Credentials::ClientSecretPost {
client_id,
client_secret,
}
}
(None, Some(client_id), None, None, None) => {
// Only got a client_id in the form
Credentials::None { client_id }
}
(
None,
client_id_from_form,
None,
Some(client_assertion_type),
Some(client_assertion),
) if client_assertion_type == JWT_BEARER_CLIENT_ASSERTION => {
// Got a JWT bearer client_assertion
let jwt: JsonWebTokenParts = client_assertion
.parse()
.map_err(|_| ClientAuthorizationError::InvalidAssertion)?;
let decoded: DecodedJsonWebToken<HashMap<String, Value>> = jwt
.decode()
.map_err(|_| ClientAuthorizationError::InvalidAssertion)?;
let (header, claims) = decoded.split();
let client_id = if let Some(Value::String(client_id)) = claims.get("sub") {
client_id.clone()
} else {
return Err(ClientAuthorizationError::InvalidAssertion);
};
if let Some(client_id_from_form) = client_id_from_form {
// If the client_id was in the body, verify it matches the one in the JWT
if client_id != client_id_from_form {
return Err(ClientAuthorizationError::ClientIdMismatch {
credential: client_id,
form: client_id_from_form,
});
}
}
Credentials::ClientAssertionJwtBearer {
client_id,
jwt,
header: Box::new(header),
claims,
}
}
(None, None, None, Some(client_assertion_type), Some(_client_assertion)) => {
// Got another unsupported client_assertion
return Err(ClientAuthorizationError::UnsupportedClientAssertion {
client_assertion_type,
});
}
(None, None, None, None, None) => {
// Special case when there are no credentials anywhere
return Err(ClientAuthorizationError::MissingCredentials);
}
_ => {
// Every other combination is an invalid request
return Err(ClientAuthorizationError::InvalidRequest);
}
};
Ok(ClientAuthorization { credentials, form })
}
}
#[cfg(test)]
mod tests {
use axum::body::{Bytes, Full};
use http::{Method, Request};
use super::*;
#[tokio::test]
async fn none_test() {
let mut req = RequestParts::new(
Request::builder()
.method(Method::POST)
.header(
http::header::CONTENT_TYPE,
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
)
.body(Full::<Bytes>::new("client_id=client-id&foo=bar".into()))
.unwrap(),
);
assert_eq!(
ClientAuthorization::<serde_json::Value>::from_request(&mut req)
.await
.unwrap(),
ClientAuthorization {
credentials: Credentials::None {
client_id: "client-id".to_string(),
},
form: Some(serde_json::json!({"foo": "bar"})),
}
);
}
#[tokio::test]
async fn client_secret_basic_test() {
let mut req = RequestParts::new(
Request::builder()
.method(Method::POST)
.header(
http::header::CONTENT_TYPE,
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
)
.header(
http::header::AUTHORIZATION,
"Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=",
)
.body(Full::<Bytes>::new("foo=bar".into()))
.unwrap(),
);
assert_eq!(
ClientAuthorization::<serde_json::Value>::from_request(&mut req)
.await
.unwrap(),
ClientAuthorization {
credentials: Credentials::ClientSecretBasic {
client_id: "client-id".to_string(),
client_secret: "client-secret".to_string(),
},
form: Some(serde_json::json!({"foo": "bar"})),
}
);
// client_id in both header and body
let mut req = RequestParts::new(
Request::builder()
.method(Method::POST)
.header(
http::header::CONTENT_TYPE,
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
)
.header(
http::header::AUTHORIZATION,
"Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=",
)
.body(Full::<Bytes>::new("client_id=client-id&foo=bar".into()))
.unwrap(),
);
assert_eq!(
ClientAuthorization::<serde_json::Value>::from_request(&mut req)
.await
.unwrap(),
ClientAuthorization {
credentials: Credentials::ClientSecretBasic {
client_id: "client-id".to_string(),
client_secret: "client-secret".to_string(),
},
form: Some(serde_json::json!({"foo": "bar"})),
}
);
// client_id in both header and body mismatch
let mut req = RequestParts::new(
Request::builder()
.method(Method::POST)
.header(
http::header::CONTENT_TYPE,
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
)
.header(
http::header::AUTHORIZATION,
"Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=",
)
.body(Full::<Bytes>::new("client_id=mismatch-id&foo=bar".into()))
.unwrap(),
);
assert!(matches!(
ClientAuthorization::<serde_json::Value>::from_request(&mut req).await,
Err(ClientAuthorizationError::ClientIdMismatch { .. }),
));
// Invalid header
let mut req = RequestParts::new(
Request::builder()
.method(Method::POST)
.header(
http::header::CONTENT_TYPE,
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
)
.header(http::header::AUTHORIZATION, "Basic invalid")
.body(Full::<Bytes>::new("foo=bar".into()))
.unwrap(),
);
assert!(matches!(
ClientAuthorization::<serde_json::Value>::from_request(&mut req).await,
Err(ClientAuthorizationError::InvalidHeader),
));
}
#[tokio::test]
async fn client_secret_post_test() {
let mut req = RequestParts::new(
Request::builder()
.method(Method::POST)
.header(
http::header::CONTENT_TYPE,
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
)
.body(Full::<Bytes>::new(
"client_id=client-id&client_secret=client-secret&foo=bar".into(),
))
.unwrap(),
);
assert_eq!(
ClientAuthorization::<serde_json::Value>::from_request(&mut req)
.await
.unwrap(),
ClientAuthorization {
credentials: Credentials::ClientSecretPost {
client_id: "client-id".to_string(),
client_secret: "client-secret".to_string(),
},
form: Some(serde_json::json!({"foo": "bar"})),
}
);
}
#[tokio::test]
async fn client_assertion_test() {
// Signed with client_secret = "client-secret"
let jwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJjbGllbnQtaWQiLCJzdWIiOiJjbGllbnQtaWQiLCJhdWQiOiJodHRwczovL2V4YW1wbGUuY29tL29hdXRoMi9pbnRyb3NwZWN0IiwianRpIjoiYWFiYmNjIiwiZXhwIjoxNTE2MjM5MzIyLCJpYXQiOjE1MTYyMzkwMjJ9.XTaACG_Rww0GPecSZvkbem-AczNy9LLNBueCLCiQajU";
let mut req = RequestParts::new(
Request::builder()
.method(Method::POST)
.header(
http::header::CONTENT_TYPE,
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
)
.body(Full::<Bytes>::new(
format!(
"client_assertion_type={}&client_assertion={}&foo=bar",
JWT_BEARER_CLIENT_ASSERTION, jwt,
)
.into(),
))
.unwrap(),
);
let authz = ClientAuthorization::<serde_json::Value>::from_request(&mut req)
.await
.unwrap();
assert_eq!(authz.form, Some(serde_json::json!({"foo": "bar"})));
let (client_id, _jwt, _header, _claims) = if let Credentials::ClientAssertionJwtBearer {
client_id,
jwt,
header,
claims,
} = authz.credentials
{
(client_id, jwt, header, claims)
} else {
panic!("expected a JWT client_assertion");
};
assert_eq!(client_id, "client-id");
// TODO: test more things
}
}

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod client_authorization;
pub mod cookies;
pub mod csrf;
pub mod fancy_error;

View File

@ -280,26 +280,34 @@ where
) -> Result<Self, Self::Rejection> {
let header = TypedHeader::<Authorization<Bearer>>::from_request(req).await;
// Take the Authorization header
let token_from_header = match header {
Ok(header) => Some(header.token().to_string()),
Err(err) => match err.reason() {
// If it's missing it is fine
TypedHeaderRejectionReason::Missing => None,
// If the header could not be parsed, return the error
TypedHeaderRejectionReason::Error(_) => {
return Err(UserAuthorizationError::InvalidHeader)
}
},
};
// Take the form value
let (token_from_form, form) = match Form::<AuthorizedForm<F>>::from_request(req).await {
Ok(Form(form)) => (form.access_token, Some(form.inner)),
// If it is not a form, continue
Err(FormRejection::InvalidFormContentType(_err)) => (None, None),
// If the form could not be read, return a Bad Request error
Err(FormRejection::FailedToDeserializeQueryString(err)) => {
return Err(UserAuthorizationError::BadForm(err))
}
// Other errors (body read twice, byte stream broke) return an internal error
Err(e) => return Err(UserAuthorizationError::InternalError(Box::new(e))),
};
let access_token = match (token_from_header, token_from_form) {
// Ensure the token should not be in both the form and the access token
(Some(_), Some(_)) => return Err(UserAuthorizationError::TokenInFormAndHeader),
(Some(t), None) => AccessToken::Header(t),
(None, Some(t)) => AccessToken::Form(t),

View File

@ -32,6 +32,7 @@ chacha20poly1305 = { version = "0.9.0", features = ["std"] }
elliptic-curve = { version = "0.11.12", features = ["pem", "pkcs8"] }
pem-rfc7468 = "0.3.1"
cookie = { version = "0.16.0", features = ["private", "key-expansion"] }
data-encoding = "2.3.2"
indoc = "1.0.4"

View File

@ -21,6 +21,7 @@ use chacha20poly1305::{
ChaCha20Poly1305,
};
use cookie::Key;
use data_encoding::BASE64;
use mas_jose::StaticKeystore;
use pkcs8::DecodePrivateKey;
use rsa::{
@ -42,6 +43,7 @@ pub struct Encrypter {
aead: Arc<ChaCha20Poly1305>,
}
// TODO: move this somewhere else
impl Encrypter {
/// Creates an [`Encrypter`] out of an encryption key
#[must_use]
@ -75,6 +77,41 @@ impl Encrypter {
let encrypted = self.aead.decrypt(nonce, encrypted)?;
Ok(encrypted)
}
/// Encrypt a payload to a self-contained base64-encoded string
///
/// # Errors
///
/// Will return `Err` when the payload failed to encrypt
pub fn encryt_to_string(&self, decrypted: &[u8]) -> anyhow::Result<String> {
let nonce = rand::random();
let encrypted = self.encrypt(&nonce, decrypted)?;
let encrypted = [&nonce[..], &encrypted].concat();
let encrypted = BASE64.encode(&encrypted);
Ok(encrypted)
}
/// Decrypt a payload from a self-contained base64-encoded string
///
/// # Errors
///
/// Will return `Err` when the payload failed to decrypt
pub fn decrypt_string(&self, encrypted: &str) -> anyhow::Result<Vec<u8>> {
let encrypted = BASE64.decode(encrypted.as_bytes())?;
let nonce: &[u8; 12] = encrypted
.get(0..12)
.ok_or_else(|| anyhow::anyhow!("invalid payload serialization"))?
.try_into()?;
let payload = encrypted
.get(12..)
.ok_or_else(|| anyhow::anyhow!("invalid payload serialization"))?;
let decrypted_client_secret = self.decrypt(nonce, payload)?;
Ok(decrypted_client_secret)
}
}
impl From<Encrypter> for Key {

View File

@ -30,7 +30,7 @@ use crate::{jwk::JsonWebKey, SigningKeystore, VerifyingKeystore};
#[serde_as]
#[skip_serializing_none]
#[derive(Debug, Serialize, Deserialize, Clone)]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
pub struct JwtHeader {
alg: JsonWebSignatureAlg,
@ -126,7 +126,7 @@ impl FromStr for JwtHeader {
}
}
#[derive(Debug)]
#[derive(Debug, PartialEq, Eq)]
pub struct JsonWebTokenParts {
payload: String,
signature: Vec<u8>,
@ -178,6 +178,14 @@ impl<T> DecodedJsonWebToken<T> {
pub fn claims(&self) -> &T {
&self.payload
}
pub fn header(&self) -> &JwtHeader {
&self.header
}
pub fn split(self) -> (JwtHeader, T) {
(self.header, self.payload)
}
}
impl<T> FromStr for DecodedJsonWebToken<T>
@ -205,15 +213,11 @@ impl JsonWebTokenParts {
Ok(decoded)
}
pub fn verify<T, S: VerifyingKeystore>(
&self,
decoded: &DecodedJsonWebToken<T>,
store: &S,
) -> S::Future
pub fn verify<S: VerifyingKeystore>(&self, header: &JwtHeader, store: &S) -> S::Future
where
S::Error: std::error::Error + Send + Sync + 'static,
{
store.verify(&decoded.header, self.payload.as_bytes(), &self.signature)
store.verify(header, self.payload.as_bytes(), &self.signature)
}
pub async fn decode_and_verify<T: DeserializeOwned, S: VerifyingKeystore>(
@ -224,7 +228,7 @@ impl JsonWebTokenParts {
S::Error: std::error::Error + Send + Sync + 'static,
{
let decoded = self.decode()?;
self.verify(&decoded, store).await?;
self.verify(&decoded.header, store).await?;
Ok(decoded)
}

View File

@ -318,7 +318,7 @@ async fn authenticate_client<T>(
})?;
let store = SharedSecret::new(&client_secret);
let fut = token.verify(&decoded, &store);
let fut = token.verify(decoded.header(), &store);
fut.await.wrap_error()?;
}
OAuthClientAuthenticationMethod::PrivateKeyJwt => {
@ -329,7 +329,7 @@ async fn authenticate_client<T>(
})?;
let store = jwks_key_store(jwks);
let fut = token.verify(&decoded, &store);
let fut = token.verify(decoded.header(), &store);
fut.await.wrap_error()?;
}
_ => {