You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
Axum migration: WIP client authentication
This commit is contained in:
@ -15,14 +15,22 @@ data-encoding = "2.3.2"
|
||||
futures-util = "0.3.21"
|
||||
headers = "0.3.7"
|
||||
http = "0.2.6"
|
||||
mime = "0.3.16"
|
||||
rand = "0.8.5"
|
||||
serde = "1.0.136"
|
||||
serde_with = "1.12.0"
|
||||
serde_urlencoded = "0.7.1"
|
||||
serde_json = "1.0.79"
|
||||
sqlx = "0.5.11"
|
||||
thiserror = "1.0.30"
|
||||
tokio = "1.17.0"
|
||||
tracing = "0.1.32"
|
||||
url = "2.2.2"
|
||||
|
||||
# TODO: remove the config dependency by moving out the encrypter
|
||||
mas-config = { path = "../config" }
|
||||
mas-templates = { path = "../templates" }
|
||||
mas-storage = { path = "../storage" }
|
||||
mas-data-model = { path = "../data-model" }
|
||||
mas-jose = { path = "../jose" }
|
||||
mas-iana = { path = "../iana" }
|
||||
|
560
crates/axum-utils/src/client_authorization.rs
Normal file
560
crates/axum-utils/src/client_authorization.rs
Normal file
@ -0,0 +1,560 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
body::HttpBody,
|
||||
extract::{
|
||||
rejection::{FailedToDeserializeQueryString, FormRejection, TypedHeaderRejectionReason},
|
||||
Form, FromRequest, RequestParts, TypedHeader,
|
||||
},
|
||||
response::IntoResponse,
|
||||
};
|
||||
use headers::{authorization::Basic, Authorization};
|
||||
use mas_config::Encrypter;
|
||||
use mas_data_model::{Client, StorageBackend};
|
||||
use mas_iana::oauth::OAuthClientAuthenticationMethod;
|
||||
use mas_jose::{
|
||||
DecodedJsonWebToken, DynamicJwksStore, Either, JsonWebTokenParts, JwtHeader, SharedSecret,
|
||||
StaticJwksStore,
|
||||
};
|
||||
use mas_storage::{
|
||||
oauth2::client::{lookup_client_by_client_id, ClientFetchError},
|
||||
PostgresqlBackend,
|
||||
};
|
||||
use serde::{de::DeserializeOwned, Deserialize};
|
||||
use serde_json::Value;
|
||||
use sqlx::PgExecutor;
|
||||
|
||||
static JWT_BEARER_CLIENT_ASSERTION: &str = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer";
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct AuthorizedForm<F = ()> {
|
||||
client_id: Option<String>,
|
||||
client_secret: Option<String>,
|
||||
client_assertion_type: Option<String>,
|
||||
client_assertion: Option<String>,
|
||||
|
||||
#[serde(flatten)]
|
||||
inner: F,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum Credentials {
|
||||
None {
|
||||
client_id: String,
|
||||
},
|
||||
ClientSecretBasic {
|
||||
client_id: String,
|
||||
client_secret: String,
|
||||
},
|
||||
ClientSecretPost {
|
||||
client_id: String,
|
||||
client_secret: String,
|
||||
},
|
||||
ClientAssertionJwtBearer {
|
||||
client_id: String,
|
||||
jwt: JsonWebTokenParts,
|
||||
header: Box<JwtHeader>,
|
||||
claims: HashMap<String, Value>,
|
||||
},
|
||||
}
|
||||
|
||||
impl Credentials {
|
||||
pub async fn fetch(
|
||||
&self,
|
||||
executor: impl PgExecutor<'_>,
|
||||
) -> Result<Client<PostgresqlBackend>, ClientFetchError> {
|
||||
let client_id = match self {
|
||||
Credentials::None { client_id }
|
||||
| Credentials::ClientSecretBasic { client_id, .. }
|
||||
| Credentials::ClientSecretPost { client_id, .. }
|
||||
| Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id,
|
||||
};
|
||||
|
||||
lookup_client_by_client_id(executor, client_id).await
|
||||
}
|
||||
|
||||
pub async fn verify<S: StorageBackend>(
|
||||
&self,
|
||||
encrypter: &Encrypter,
|
||||
method: OAuthClientAuthenticationMethod,
|
||||
client: &Client<S>,
|
||||
) -> Result<(), CredentialsVerificationError> {
|
||||
match (self, method) {
|
||||
(Credentials::None { .. }, OAuthClientAuthenticationMethod::None) => {}
|
||||
|
||||
(
|
||||
Credentials::ClientSecretPost { client_secret, .. },
|
||||
OAuthClientAuthenticationMethod::ClientSecretPost,
|
||||
)
|
||||
| (
|
||||
Credentials::ClientSecretBasic { client_secret, .. },
|
||||
OAuthClientAuthenticationMethod::ClientSecretBasic,
|
||||
) => {
|
||||
// Decrypt the client_secret
|
||||
let encrypted_client_secret = client
|
||||
.encrypted_client_secret
|
||||
.as_ref()
|
||||
.ok_or(CredentialsVerificationError::InvalidClientConfig)?;
|
||||
|
||||
let decrypted_client_secret = encrypter
|
||||
.decrypt_string(encrypted_client_secret)
|
||||
.map_err(|_e| CredentialsVerificationError::DecryptionError)?;
|
||||
|
||||
// Check if the client_secret matches
|
||||
if client_secret.as_bytes() != decrypted_client_secret {
|
||||
return Err(CredentialsVerificationError::ClientSecretMismatch);
|
||||
}
|
||||
}
|
||||
|
||||
(
|
||||
Credentials::ClientAssertionJwtBearer { jwt, header, .. },
|
||||
OAuthClientAuthenticationMethod::ClientSecretJwt,
|
||||
) => {
|
||||
// Get the client JWKS
|
||||
let jwks = client
|
||||
.jwks
|
||||
.as_ref()
|
||||
.ok_or(CredentialsVerificationError::InvalidClientConfig)?;
|
||||
|
||||
let store: Either<StaticJwksStore, DynamicJwksStore> = jwks_key_store(jwks);
|
||||
jwt.verify(header, &store)
|
||||
.await
|
||||
.map_err(|_| CredentialsVerificationError::InvalidAssertionSignature)?;
|
||||
}
|
||||
|
||||
(
|
||||
Credentials::ClientAssertionJwtBearer { jwt, header, .. },
|
||||
OAuthClientAuthenticationMethod::PrivateKeyJwt,
|
||||
) => {
|
||||
// Decrypt the client_secret
|
||||
let encrypted_client_secret = client
|
||||
.encrypted_client_secret
|
||||
.as_ref()
|
||||
.ok_or(CredentialsVerificationError::InvalidClientConfig)?;
|
||||
|
||||
let decrypted_client_secret = encrypter
|
||||
.decrypt_string(encrypted_client_secret)
|
||||
.map_err(|_e| CredentialsVerificationError::DecryptionError)?;
|
||||
|
||||
let store = SharedSecret::new(&decrypted_client_secret);
|
||||
jwt.verify(header, &store)
|
||||
.await
|
||||
.map_err(|_| CredentialsVerificationError::InvalidAssertionSignature)?;
|
||||
}
|
||||
|
||||
(_, _) => {
|
||||
return Err(CredentialsVerificationError::AuthenticationMethodMismatch);
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn jwks_key_store(
|
||||
_jwks: &mas_data_model::JwksOrJwksUri,
|
||||
) -> Either<StaticJwksStore, DynamicJwksStore> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
pub enum CredentialsVerificationError {
|
||||
DecryptionError,
|
||||
InvalidClientConfig,
|
||||
ClientSecretMismatch,
|
||||
AuthenticationMethodMismatch,
|
||||
InvalidAssertionSignature,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub struct ClientAuthorization<F = ()> {
|
||||
credentials: Credentials,
|
||||
form: Option<F>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ClientAuthorizationError {
|
||||
InvalidHeader,
|
||||
BadForm(FailedToDeserializeQueryString),
|
||||
ClientIdMismatch { credential: String, form: String },
|
||||
UnsupportedClientAssertion { client_assertion_type: String },
|
||||
MissingCredentials,
|
||||
InvalidRequest,
|
||||
InvalidAssertion,
|
||||
InternalError(Box<dyn std::error::Error>),
|
||||
}
|
||||
|
||||
impl IntoResponse for ClientAuthorizationError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<B, F> FromRequest<B> for ClientAuthorization<F>
|
||||
where
|
||||
B: Send + HttpBody,
|
||||
B::Data: Send,
|
||||
B::Error: std::error::Error + Send + Sync + 'static,
|
||||
F: DeserializeOwned,
|
||||
{
|
||||
type Rejection = ClientAuthorizationError;
|
||||
|
||||
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
||||
let header = TypedHeader::<Authorization<Basic>>::from_request(req).await;
|
||||
|
||||
// Take the Authorization header
|
||||
let credentials_from_header = match header {
|
||||
Ok(header) => Some((header.username().to_string(), header.password().to_string())),
|
||||
Err(err) => match err.reason() {
|
||||
// If it's missing it is fine
|
||||
TypedHeaderRejectionReason::Missing => None,
|
||||
// If the header could not be parsed, return the error
|
||||
TypedHeaderRejectionReason::Error(_) => {
|
||||
return Err(ClientAuthorizationError::InvalidHeader)
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
// Take the form value
|
||||
let (
|
||||
client_id_from_form,
|
||||
client_secret_from_form,
|
||||
client_assertion_type,
|
||||
client_assertion,
|
||||
form,
|
||||
) = match Form::<AuthorizedForm<F>>::from_request(req).await {
|
||||
Ok(Form(form)) => (
|
||||
form.client_id,
|
||||
form.client_secret,
|
||||
form.client_assertion_type,
|
||||
form.client_assertion,
|
||||
Some(form.inner),
|
||||
),
|
||||
// If it is not a form, continue
|
||||
Err(FormRejection::InvalidFormContentType(_err)) => (None, None, None, None, None),
|
||||
// If the form could not be read, return a Bad Request error
|
||||
Err(FormRejection::FailedToDeserializeQueryString(err)) => {
|
||||
return Err(ClientAuthorizationError::BadForm(err))
|
||||
}
|
||||
// Other errors (body read twice, byte stream broke) return an internal error
|
||||
Err(e) => return Err(ClientAuthorizationError::InternalError(Box::new(e))),
|
||||
};
|
||||
|
||||
// And now, figure out the actual auth method
|
||||
let credentials = match (
|
||||
credentials_from_header,
|
||||
client_id_from_form,
|
||||
client_secret_from_form,
|
||||
client_assertion_type,
|
||||
client_assertion,
|
||||
) {
|
||||
(Some((client_id, client_secret)), client_id_from_form, None, None, None) => {
|
||||
if let Some(client_id_from_form) = client_id_from_form {
|
||||
// If the client_id was in the body, verify it matches with the header
|
||||
if client_id != client_id_from_form {
|
||||
return Err(ClientAuthorizationError::ClientIdMismatch {
|
||||
credential: client_id,
|
||||
form: client_id_from_form,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Credentials::ClientSecretBasic {
|
||||
client_id,
|
||||
client_secret,
|
||||
}
|
||||
}
|
||||
|
||||
(None, Some(client_id), Some(client_secret), None, None) => {
|
||||
// Got both client_id and client_secret from the form
|
||||
Credentials::ClientSecretPost {
|
||||
client_id,
|
||||
client_secret,
|
||||
}
|
||||
}
|
||||
|
||||
(None, Some(client_id), None, None, None) => {
|
||||
// Only got a client_id in the form
|
||||
Credentials::None { client_id }
|
||||
}
|
||||
|
||||
(
|
||||
None,
|
||||
client_id_from_form,
|
||||
None,
|
||||
Some(client_assertion_type),
|
||||
Some(client_assertion),
|
||||
) if client_assertion_type == JWT_BEARER_CLIENT_ASSERTION => {
|
||||
// Got a JWT bearer client_assertion
|
||||
|
||||
let jwt: JsonWebTokenParts = client_assertion
|
||||
.parse()
|
||||
.map_err(|_| ClientAuthorizationError::InvalidAssertion)?;
|
||||
let decoded: DecodedJsonWebToken<HashMap<String, Value>> = jwt
|
||||
.decode()
|
||||
.map_err(|_| ClientAuthorizationError::InvalidAssertion)?;
|
||||
let (header, claims) = decoded.split();
|
||||
|
||||
let client_id = if let Some(Value::String(client_id)) = claims.get("sub") {
|
||||
client_id.clone()
|
||||
} else {
|
||||
return Err(ClientAuthorizationError::InvalidAssertion);
|
||||
};
|
||||
|
||||
if let Some(client_id_from_form) = client_id_from_form {
|
||||
// If the client_id was in the body, verify it matches the one in the JWT
|
||||
if client_id != client_id_from_form {
|
||||
return Err(ClientAuthorizationError::ClientIdMismatch {
|
||||
credential: client_id,
|
||||
form: client_id_from_form,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Credentials::ClientAssertionJwtBearer {
|
||||
client_id,
|
||||
jwt,
|
||||
header: Box::new(header),
|
||||
claims,
|
||||
}
|
||||
}
|
||||
|
||||
(None, None, None, Some(client_assertion_type), Some(_client_assertion)) => {
|
||||
// Got another unsupported client_assertion
|
||||
return Err(ClientAuthorizationError::UnsupportedClientAssertion {
|
||||
client_assertion_type,
|
||||
});
|
||||
}
|
||||
|
||||
(None, None, None, None, None) => {
|
||||
// Special case when there are no credentials anywhere
|
||||
return Err(ClientAuthorizationError::MissingCredentials);
|
||||
}
|
||||
|
||||
_ => {
|
||||
// Every other combination is an invalid request
|
||||
return Err(ClientAuthorizationError::InvalidRequest);
|
||||
}
|
||||
};
|
||||
|
||||
Ok(ClientAuthorization { credentials, form })
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use axum::body::{Bytes, Full};
|
||||
use http::{Method, Request};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn none_test() {
|
||||
let mut req = RequestParts::new(
|
||||
Request::builder()
|
||||
.method(Method::POST)
|
||||
.header(
|
||||
http::header::CONTENT_TYPE,
|
||||
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||
)
|
||||
.body(Full::<Bytes>::new("client_id=client-id&foo=bar".into()))
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
ClientAuthorization::<serde_json::Value>::from_request(&mut req)
|
||||
.await
|
||||
.unwrap(),
|
||||
ClientAuthorization {
|
||||
credentials: Credentials::None {
|
||||
client_id: "client-id".to_string(),
|
||||
},
|
||||
form: Some(serde_json::json!({"foo": "bar"})),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn client_secret_basic_test() {
|
||||
let mut req = RequestParts::new(
|
||||
Request::builder()
|
||||
.method(Method::POST)
|
||||
.header(
|
||||
http::header::CONTENT_TYPE,
|
||||
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||
)
|
||||
.header(
|
||||
http::header::AUTHORIZATION,
|
||||
"Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=",
|
||||
)
|
||||
.body(Full::<Bytes>::new("foo=bar".into()))
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
ClientAuthorization::<serde_json::Value>::from_request(&mut req)
|
||||
.await
|
||||
.unwrap(),
|
||||
ClientAuthorization {
|
||||
credentials: Credentials::ClientSecretBasic {
|
||||
client_id: "client-id".to_string(),
|
||||
client_secret: "client-secret".to_string(),
|
||||
},
|
||||
form: Some(serde_json::json!({"foo": "bar"})),
|
||||
}
|
||||
);
|
||||
|
||||
// client_id in both header and body
|
||||
let mut req = RequestParts::new(
|
||||
Request::builder()
|
||||
.method(Method::POST)
|
||||
.header(
|
||||
http::header::CONTENT_TYPE,
|
||||
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||
)
|
||||
.header(
|
||||
http::header::AUTHORIZATION,
|
||||
"Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=",
|
||||
)
|
||||
.body(Full::<Bytes>::new("client_id=client-id&foo=bar".into()))
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
ClientAuthorization::<serde_json::Value>::from_request(&mut req)
|
||||
.await
|
||||
.unwrap(),
|
||||
ClientAuthorization {
|
||||
credentials: Credentials::ClientSecretBasic {
|
||||
client_id: "client-id".to_string(),
|
||||
client_secret: "client-secret".to_string(),
|
||||
},
|
||||
form: Some(serde_json::json!({"foo": "bar"})),
|
||||
}
|
||||
);
|
||||
|
||||
// client_id in both header and body mismatch
|
||||
let mut req = RequestParts::new(
|
||||
Request::builder()
|
||||
.method(Method::POST)
|
||||
.header(
|
||||
http::header::CONTENT_TYPE,
|
||||
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||
)
|
||||
.header(
|
||||
http::header::AUTHORIZATION,
|
||||
"Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=",
|
||||
)
|
||||
.body(Full::<Bytes>::new("client_id=mismatch-id&foo=bar".into()))
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
assert!(matches!(
|
||||
ClientAuthorization::<serde_json::Value>::from_request(&mut req).await,
|
||||
Err(ClientAuthorizationError::ClientIdMismatch { .. }),
|
||||
));
|
||||
|
||||
// Invalid header
|
||||
let mut req = RequestParts::new(
|
||||
Request::builder()
|
||||
.method(Method::POST)
|
||||
.header(
|
||||
http::header::CONTENT_TYPE,
|
||||
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||
)
|
||||
.header(http::header::AUTHORIZATION, "Basic invalid")
|
||||
.body(Full::<Bytes>::new("foo=bar".into()))
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
assert!(matches!(
|
||||
ClientAuthorization::<serde_json::Value>::from_request(&mut req).await,
|
||||
Err(ClientAuthorizationError::InvalidHeader),
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn client_secret_post_test() {
|
||||
let mut req = RequestParts::new(
|
||||
Request::builder()
|
||||
.method(Method::POST)
|
||||
.header(
|
||||
http::header::CONTENT_TYPE,
|
||||
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||
)
|
||||
.body(Full::<Bytes>::new(
|
||||
"client_id=client-id&client_secret=client-secret&foo=bar".into(),
|
||||
))
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
ClientAuthorization::<serde_json::Value>::from_request(&mut req)
|
||||
.await
|
||||
.unwrap(),
|
||||
ClientAuthorization {
|
||||
credentials: Credentials::ClientSecretPost {
|
||||
client_id: "client-id".to_string(),
|
||||
client_secret: "client-secret".to_string(),
|
||||
},
|
||||
form: Some(serde_json::json!({"foo": "bar"})),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn client_assertion_test() {
|
||||
// Signed with client_secret = "client-secret"
|
||||
let jwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJjbGllbnQtaWQiLCJzdWIiOiJjbGllbnQtaWQiLCJhdWQiOiJodHRwczovL2V4YW1wbGUuY29tL29hdXRoMi9pbnRyb3NwZWN0IiwianRpIjoiYWFiYmNjIiwiZXhwIjoxNTE2MjM5MzIyLCJpYXQiOjE1MTYyMzkwMjJ9.XTaACG_Rww0GPecSZvkbem-AczNy9LLNBueCLCiQajU";
|
||||
let mut req = RequestParts::new(
|
||||
Request::builder()
|
||||
.method(Method::POST)
|
||||
.header(
|
||||
http::header::CONTENT_TYPE,
|
||||
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||
)
|
||||
.body(Full::<Bytes>::new(
|
||||
format!(
|
||||
"client_assertion_type={}&client_assertion={}&foo=bar",
|
||||
JWT_BEARER_CLIENT_ASSERTION, jwt,
|
||||
)
|
||||
.into(),
|
||||
))
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
let authz = ClientAuthorization::<serde_json::Value>::from_request(&mut req)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(authz.form, Some(serde_json::json!({"foo": "bar"})));
|
||||
|
||||
let (client_id, _jwt, _header, _claims) = if let Credentials::ClientAssertionJwtBearer {
|
||||
client_id,
|
||||
jwt,
|
||||
header,
|
||||
claims,
|
||||
} = authz.credentials
|
||||
{
|
||||
(client_id, jwt, header, claims)
|
||||
} else {
|
||||
panic!("expected a JWT client_assertion");
|
||||
};
|
||||
|
||||
assert_eq!(client_id, "client-id");
|
||||
// TODO: test more things
|
||||
}
|
||||
}
|
@ -12,6 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
pub mod client_authorization;
|
||||
pub mod cookies;
|
||||
pub mod csrf;
|
||||
pub mod fancy_error;
|
||||
|
@ -280,26 +280,34 @@ where
|
||||
) -> Result<Self, Self::Rejection> {
|
||||
let header = TypedHeader::<Authorization<Bearer>>::from_request(req).await;
|
||||
|
||||
// Take the Authorization header
|
||||
let token_from_header = match header {
|
||||
Ok(header) => Some(header.token().to_string()),
|
||||
Err(err) => match err.reason() {
|
||||
// If it's missing it is fine
|
||||
TypedHeaderRejectionReason::Missing => None,
|
||||
// If the header could not be parsed, return the error
|
||||
TypedHeaderRejectionReason::Error(_) => {
|
||||
return Err(UserAuthorizationError::InvalidHeader)
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
// Take the form value
|
||||
let (token_from_form, form) = match Form::<AuthorizedForm<F>>::from_request(req).await {
|
||||
Ok(Form(form)) => (form.access_token, Some(form.inner)),
|
||||
// If it is not a form, continue
|
||||
Err(FormRejection::InvalidFormContentType(_err)) => (None, None),
|
||||
// If the form could not be read, return a Bad Request error
|
||||
Err(FormRejection::FailedToDeserializeQueryString(err)) => {
|
||||
return Err(UserAuthorizationError::BadForm(err))
|
||||
}
|
||||
// Other errors (body read twice, byte stream broke) return an internal error
|
||||
Err(e) => return Err(UserAuthorizationError::InternalError(Box::new(e))),
|
||||
};
|
||||
|
||||
let access_token = match (token_from_header, token_from_form) {
|
||||
// Ensure the token should not be in both the form and the access token
|
||||
(Some(_), Some(_)) => return Err(UserAuthorizationError::TokenInFormAndHeader),
|
||||
(Some(t), None) => AccessToken::Header(t),
|
||||
(None, Some(t)) => AccessToken::Form(t),
|
||||
|
@ -32,6 +32,7 @@ chacha20poly1305 = { version = "0.9.0", features = ["std"] }
|
||||
elliptic-curve = { version = "0.11.12", features = ["pem", "pkcs8"] }
|
||||
pem-rfc7468 = "0.3.1"
|
||||
cookie = { version = "0.16.0", features = ["private", "key-expansion"] }
|
||||
data-encoding = "2.3.2"
|
||||
|
||||
indoc = "1.0.4"
|
||||
|
||||
|
@ -21,6 +21,7 @@ use chacha20poly1305::{
|
||||
ChaCha20Poly1305,
|
||||
};
|
||||
use cookie::Key;
|
||||
use data_encoding::BASE64;
|
||||
use mas_jose::StaticKeystore;
|
||||
use pkcs8::DecodePrivateKey;
|
||||
use rsa::{
|
||||
@ -42,6 +43,7 @@ pub struct Encrypter {
|
||||
aead: Arc<ChaCha20Poly1305>,
|
||||
}
|
||||
|
||||
// TODO: move this somewhere else
|
||||
impl Encrypter {
|
||||
/// Creates an [`Encrypter`] out of an encryption key
|
||||
#[must_use]
|
||||
@ -75,6 +77,41 @@ impl Encrypter {
|
||||
let encrypted = self.aead.decrypt(nonce, encrypted)?;
|
||||
Ok(encrypted)
|
||||
}
|
||||
|
||||
/// Encrypt a payload to a self-contained base64-encoded string
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Will return `Err` when the payload failed to encrypt
|
||||
pub fn encryt_to_string(&self, decrypted: &[u8]) -> anyhow::Result<String> {
|
||||
let nonce = rand::random();
|
||||
let encrypted = self.encrypt(&nonce, decrypted)?;
|
||||
let encrypted = [&nonce[..], &encrypted].concat();
|
||||
let encrypted = BASE64.encode(&encrypted);
|
||||
Ok(encrypted)
|
||||
}
|
||||
|
||||
/// Decrypt a payload from a self-contained base64-encoded string
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Will return `Err` when the payload failed to decrypt
|
||||
pub fn decrypt_string(&self, encrypted: &str) -> anyhow::Result<Vec<u8>> {
|
||||
let encrypted = BASE64.decode(encrypted.as_bytes())?;
|
||||
|
||||
let nonce: &[u8; 12] = encrypted
|
||||
.get(0..12)
|
||||
.ok_or_else(|| anyhow::anyhow!("invalid payload serialization"))?
|
||||
.try_into()?;
|
||||
|
||||
let payload = encrypted
|
||||
.get(12..)
|
||||
.ok_or_else(|| anyhow::anyhow!("invalid payload serialization"))?;
|
||||
|
||||
let decrypted_client_secret = self.decrypt(nonce, payload)?;
|
||||
|
||||
Ok(decrypted_client_secret)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Encrypter> for Key {
|
||||
|
@ -30,7 +30,7 @@ use crate::{jwk::JsonWebKey, SigningKeystore, VerifyingKeystore};
|
||||
|
||||
#[serde_as]
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
|
||||
pub struct JwtHeader {
|
||||
alg: JsonWebSignatureAlg,
|
||||
|
||||
@ -126,7 +126,7 @@ impl FromStr for JwtHeader {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub struct JsonWebTokenParts {
|
||||
payload: String,
|
||||
signature: Vec<u8>,
|
||||
@ -178,6 +178,14 @@ impl<T> DecodedJsonWebToken<T> {
|
||||
pub fn claims(&self) -> &T {
|
||||
&self.payload
|
||||
}
|
||||
|
||||
pub fn header(&self) -> &JwtHeader {
|
||||
&self.header
|
||||
}
|
||||
|
||||
pub fn split(self) -> (JwtHeader, T) {
|
||||
(self.header, self.payload)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> FromStr for DecodedJsonWebToken<T>
|
||||
@ -205,15 +213,11 @@ impl JsonWebTokenParts {
|
||||
Ok(decoded)
|
||||
}
|
||||
|
||||
pub fn verify<T, S: VerifyingKeystore>(
|
||||
&self,
|
||||
decoded: &DecodedJsonWebToken<T>,
|
||||
store: &S,
|
||||
) -> S::Future
|
||||
pub fn verify<S: VerifyingKeystore>(&self, header: &JwtHeader, store: &S) -> S::Future
|
||||
where
|
||||
S::Error: std::error::Error + Send + Sync + 'static,
|
||||
{
|
||||
store.verify(&decoded.header, self.payload.as_bytes(), &self.signature)
|
||||
store.verify(header, self.payload.as_bytes(), &self.signature)
|
||||
}
|
||||
|
||||
pub async fn decode_and_verify<T: DeserializeOwned, S: VerifyingKeystore>(
|
||||
@ -224,7 +228,7 @@ impl JsonWebTokenParts {
|
||||
S::Error: std::error::Error + Send + Sync + 'static,
|
||||
{
|
||||
let decoded = self.decode()?;
|
||||
self.verify(&decoded, store).await?;
|
||||
self.verify(&decoded.header, store).await?;
|
||||
Ok(decoded)
|
||||
}
|
||||
|
||||
|
@ -318,7 +318,7 @@ async fn authenticate_client<T>(
|
||||
})?;
|
||||
|
||||
let store = SharedSecret::new(&client_secret);
|
||||
let fut = token.verify(&decoded, &store);
|
||||
let fut = token.verify(decoded.header(), &store);
|
||||
fut.await.wrap_error()?;
|
||||
}
|
||||
OAuthClientAuthenticationMethod::PrivateKeyJwt => {
|
||||
@ -329,7 +329,7 @@ async fn authenticate_client<T>(
|
||||
})?;
|
||||
|
||||
let store = jwks_key_store(jwks);
|
||||
let fut = token.verify(&decoded, &store);
|
||||
let fut = token.verify(decoded.header(), &store);
|
||||
fut.await.wrap_error()?;
|
||||
}
|
||||
_ => {
|
||||
|
Reference in New Issue
Block a user