You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-08-06 06:02:40 +03:00
Axum migration: WIP client authentication
This commit is contained in:
8
Cargo.lock
generated
8
Cargo.lock
generated
@@ -1940,14 +1940,21 @@ dependencies = [
|
|||||||
"futures-util",
|
"futures-util",
|
||||||
"headers",
|
"headers",
|
||||||
"http",
|
"http",
|
||||||
|
"mas-config",
|
||||||
"mas-data-model",
|
"mas-data-model",
|
||||||
|
"mas-iana",
|
||||||
|
"mas-jose",
|
||||||
"mas-storage",
|
"mas-storage",
|
||||||
"mas-templates",
|
"mas-templates",
|
||||||
|
"mime",
|
||||||
"rand",
|
"rand",
|
||||||
"serde",
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"serde_urlencoded",
|
||||||
"serde_with",
|
"serde_with",
|
||||||
"sqlx",
|
"sqlx",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
|
"tokio",
|
||||||
"tracing",
|
"tracing",
|
||||||
"url",
|
"url",
|
||||||
]
|
]
|
||||||
@@ -2004,6 +2011,7 @@ dependencies = [
|
|||||||
"chacha20poly1305",
|
"chacha20poly1305",
|
||||||
"chrono",
|
"chrono",
|
||||||
"cookie",
|
"cookie",
|
||||||
|
"data-encoding",
|
||||||
"elliptic-curve",
|
"elliptic-curve",
|
||||||
"figment",
|
"figment",
|
||||||
"indoc",
|
"indoc",
|
||||||
|
@@ -15,14 +15,22 @@ data-encoding = "2.3.2"
|
|||||||
futures-util = "0.3.21"
|
futures-util = "0.3.21"
|
||||||
headers = "0.3.7"
|
headers = "0.3.7"
|
||||||
http = "0.2.6"
|
http = "0.2.6"
|
||||||
|
mime = "0.3.16"
|
||||||
rand = "0.8.5"
|
rand = "0.8.5"
|
||||||
serde = "1.0.136"
|
serde = "1.0.136"
|
||||||
serde_with = "1.12.0"
|
serde_with = "1.12.0"
|
||||||
|
serde_urlencoded = "0.7.1"
|
||||||
|
serde_json = "1.0.79"
|
||||||
sqlx = "0.5.11"
|
sqlx = "0.5.11"
|
||||||
thiserror = "1.0.30"
|
thiserror = "1.0.30"
|
||||||
|
tokio = "1.17.0"
|
||||||
tracing = "0.1.32"
|
tracing = "0.1.32"
|
||||||
url = "2.2.2"
|
url = "2.2.2"
|
||||||
|
|
||||||
|
# TODO: remove the config dependency by moving out the encrypter
|
||||||
|
mas-config = { path = "../config" }
|
||||||
mas-templates = { path = "../templates" }
|
mas-templates = { path = "../templates" }
|
||||||
mas-storage = { path = "../storage" }
|
mas-storage = { path = "../storage" }
|
||||||
mas-data-model = { path = "../data-model" }
|
mas-data-model = { path = "../data-model" }
|
||||||
|
mas-jose = { path = "../jose" }
|
||||||
|
mas-iana = { path = "../iana" }
|
||||||
|
560
crates/axum-utils/src/client_authorization.rs
Normal file
560
crates/axum-utils/src/client_authorization.rs
Normal file
@@ -0,0 +1,560 @@
|
|||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use axum::{
|
||||||
|
body::HttpBody,
|
||||||
|
extract::{
|
||||||
|
rejection::{FailedToDeserializeQueryString, FormRejection, TypedHeaderRejectionReason},
|
||||||
|
Form, FromRequest, RequestParts, TypedHeader,
|
||||||
|
},
|
||||||
|
response::IntoResponse,
|
||||||
|
};
|
||||||
|
use headers::{authorization::Basic, Authorization};
|
||||||
|
use mas_config::Encrypter;
|
||||||
|
use mas_data_model::{Client, StorageBackend};
|
||||||
|
use mas_iana::oauth::OAuthClientAuthenticationMethod;
|
||||||
|
use mas_jose::{
|
||||||
|
DecodedJsonWebToken, DynamicJwksStore, Either, JsonWebTokenParts, JwtHeader, SharedSecret,
|
||||||
|
StaticJwksStore,
|
||||||
|
};
|
||||||
|
use mas_storage::{
|
||||||
|
oauth2::client::{lookup_client_by_client_id, ClientFetchError},
|
||||||
|
PostgresqlBackend,
|
||||||
|
};
|
||||||
|
use serde::{de::DeserializeOwned, Deserialize};
|
||||||
|
use serde_json::Value;
|
||||||
|
use sqlx::PgExecutor;
|
||||||
|
|
||||||
|
static JWT_BEARER_CLIENT_ASSERTION: &str = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer";
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct AuthorizedForm<F = ()> {
|
||||||
|
client_id: Option<String>,
|
||||||
|
client_secret: Option<String>,
|
||||||
|
client_assertion_type: Option<String>,
|
||||||
|
client_assertion: Option<String>,
|
||||||
|
|
||||||
|
#[serde(flatten)]
|
||||||
|
inner: F,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
|
pub enum Credentials {
|
||||||
|
None {
|
||||||
|
client_id: String,
|
||||||
|
},
|
||||||
|
ClientSecretBasic {
|
||||||
|
client_id: String,
|
||||||
|
client_secret: String,
|
||||||
|
},
|
||||||
|
ClientSecretPost {
|
||||||
|
client_id: String,
|
||||||
|
client_secret: String,
|
||||||
|
},
|
||||||
|
ClientAssertionJwtBearer {
|
||||||
|
client_id: String,
|
||||||
|
jwt: JsonWebTokenParts,
|
||||||
|
header: Box<JwtHeader>,
|
||||||
|
claims: HashMap<String, Value>,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Credentials {
|
||||||
|
pub async fn fetch(
|
||||||
|
&self,
|
||||||
|
executor: impl PgExecutor<'_>,
|
||||||
|
) -> Result<Client<PostgresqlBackend>, ClientFetchError> {
|
||||||
|
let client_id = match self {
|
||||||
|
Credentials::None { client_id }
|
||||||
|
| Credentials::ClientSecretBasic { client_id, .. }
|
||||||
|
| Credentials::ClientSecretPost { client_id, .. }
|
||||||
|
| Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id,
|
||||||
|
};
|
||||||
|
|
||||||
|
lookup_client_by_client_id(executor, client_id).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn verify<S: StorageBackend>(
|
||||||
|
&self,
|
||||||
|
encrypter: &Encrypter,
|
||||||
|
method: OAuthClientAuthenticationMethod,
|
||||||
|
client: &Client<S>,
|
||||||
|
) -> Result<(), CredentialsVerificationError> {
|
||||||
|
match (self, method) {
|
||||||
|
(Credentials::None { .. }, OAuthClientAuthenticationMethod::None) => {}
|
||||||
|
|
||||||
|
(
|
||||||
|
Credentials::ClientSecretPost { client_secret, .. },
|
||||||
|
OAuthClientAuthenticationMethod::ClientSecretPost,
|
||||||
|
)
|
||||||
|
| (
|
||||||
|
Credentials::ClientSecretBasic { client_secret, .. },
|
||||||
|
OAuthClientAuthenticationMethod::ClientSecretBasic,
|
||||||
|
) => {
|
||||||
|
// Decrypt the client_secret
|
||||||
|
let encrypted_client_secret = client
|
||||||
|
.encrypted_client_secret
|
||||||
|
.as_ref()
|
||||||
|
.ok_or(CredentialsVerificationError::InvalidClientConfig)?;
|
||||||
|
|
||||||
|
let decrypted_client_secret = encrypter
|
||||||
|
.decrypt_string(encrypted_client_secret)
|
||||||
|
.map_err(|_e| CredentialsVerificationError::DecryptionError)?;
|
||||||
|
|
||||||
|
// Check if the client_secret matches
|
||||||
|
if client_secret.as_bytes() != decrypted_client_secret {
|
||||||
|
return Err(CredentialsVerificationError::ClientSecretMismatch);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(
|
||||||
|
Credentials::ClientAssertionJwtBearer { jwt, header, .. },
|
||||||
|
OAuthClientAuthenticationMethod::ClientSecretJwt,
|
||||||
|
) => {
|
||||||
|
// Get the client JWKS
|
||||||
|
let jwks = client
|
||||||
|
.jwks
|
||||||
|
.as_ref()
|
||||||
|
.ok_or(CredentialsVerificationError::InvalidClientConfig)?;
|
||||||
|
|
||||||
|
let store: Either<StaticJwksStore, DynamicJwksStore> = jwks_key_store(jwks);
|
||||||
|
jwt.verify(header, &store)
|
||||||
|
.await
|
||||||
|
.map_err(|_| CredentialsVerificationError::InvalidAssertionSignature)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
(
|
||||||
|
Credentials::ClientAssertionJwtBearer { jwt, header, .. },
|
||||||
|
OAuthClientAuthenticationMethod::PrivateKeyJwt,
|
||||||
|
) => {
|
||||||
|
// Decrypt the client_secret
|
||||||
|
let encrypted_client_secret = client
|
||||||
|
.encrypted_client_secret
|
||||||
|
.as_ref()
|
||||||
|
.ok_or(CredentialsVerificationError::InvalidClientConfig)?;
|
||||||
|
|
||||||
|
let decrypted_client_secret = encrypter
|
||||||
|
.decrypt_string(encrypted_client_secret)
|
||||||
|
.map_err(|_e| CredentialsVerificationError::DecryptionError)?;
|
||||||
|
|
||||||
|
let store = SharedSecret::new(&decrypted_client_secret);
|
||||||
|
jwt.verify(header, &store)
|
||||||
|
.await
|
||||||
|
.map_err(|_| CredentialsVerificationError::InvalidAssertionSignature)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
(_, _) => {
|
||||||
|
return Err(CredentialsVerificationError::AuthenticationMethodMismatch);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn jwks_key_store(
|
||||||
|
_jwks: &mas_data_model::JwksOrJwksUri,
|
||||||
|
) -> Either<StaticJwksStore, DynamicJwksStore> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum CredentialsVerificationError {
|
||||||
|
DecryptionError,
|
||||||
|
InvalidClientConfig,
|
||||||
|
ClientSecretMismatch,
|
||||||
|
AuthenticationMethodMismatch,
|
||||||
|
InvalidAssertionSignature,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
|
pub struct ClientAuthorization<F = ()> {
|
||||||
|
credentials: Credentials,
|
||||||
|
form: Option<F>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum ClientAuthorizationError {
|
||||||
|
InvalidHeader,
|
||||||
|
BadForm(FailedToDeserializeQueryString),
|
||||||
|
ClientIdMismatch { credential: String, form: String },
|
||||||
|
UnsupportedClientAssertion { client_assertion_type: String },
|
||||||
|
MissingCredentials,
|
||||||
|
InvalidRequest,
|
||||||
|
InvalidAssertion,
|
||||||
|
InternalError(Box<dyn std::error::Error>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoResponse for ClientAuthorizationError {
|
||||||
|
fn into_response(self) -> axum::response::Response {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<B, F> FromRequest<B> for ClientAuthorization<F>
|
||||||
|
where
|
||||||
|
B: Send + HttpBody,
|
||||||
|
B::Data: Send,
|
||||||
|
B::Error: std::error::Error + Send + Sync + 'static,
|
||||||
|
F: DeserializeOwned,
|
||||||
|
{
|
||||||
|
type Rejection = ClientAuthorizationError;
|
||||||
|
|
||||||
|
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
||||||
|
let header = TypedHeader::<Authorization<Basic>>::from_request(req).await;
|
||||||
|
|
||||||
|
// Take the Authorization header
|
||||||
|
let credentials_from_header = match header {
|
||||||
|
Ok(header) => Some((header.username().to_string(), header.password().to_string())),
|
||||||
|
Err(err) => match err.reason() {
|
||||||
|
// If it's missing it is fine
|
||||||
|
TypedHeaderRejectionReason::Missing => None,
|
||||||
|
// If the header could not be parsed, return the error
|
||||||
|
TypedHeaderRejectionReason::Error(_) => {
|
||||||
|
return Err(ClientAuthorizationError::InvalidHeader)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
// Take the form value
|
||||||
|
let (
|
||||||
|
client_id_from_form,
|
||||||
|
client_secret_from_form,
|
||||||
|
client_assertion_type,
|
||||||
|
client_assertion,
|
||||||
|
form,
|
||||||
|
) = match Form::<AuthorizedForm<F>>::from_request(req).await {
|
||||||
|
Ok(Form(form)) => (
|
||||||
|
form.client_id,
|
||||||
|
form.client_secret,
|
||||||
|
form.client_assertion_type,
|
||||||
|
form.client_assertion,
|
||||||
|
Some(form.inner),
|
||||||
|
),
|
||||||
|
// If it is not a form, continue
|
||||||
|
Err(FormRejection::InvalidFormContentType(_err)) => (None, None, None, None, None),
|
||||||
|
// If the form could not be read, return a Bad Request error
|
||||||
|
Err(FormRejection::FailedToDeserializeQueryString(err)) => {
|
||||||
|
return Err(ClientAuthorizationError::BadForm(err))
|
||||||
|
}
|
||||||
|
// Other errors (body read twice, byte stream broke) return an internal error
|
||||||
|
Err(e) => return Err(ClientAuthorizationError::InternalError(Box::new(e))),
|
||||||
|
};
|
||||||
|
|
||||||
|
// And now, figure out the actual auth method
|
||||||
|
let credentials = match (
|
||||||
|
credentials_from_header,
|
||||||
|
client_id_from_form,
|
||||||
|
client_secret_from_form,
|
||||||
|
client_assertion_type,
|
||||||
|
client_assertion,
|
||||||
|
) {
|
||||||
|
(Some((client_id, client_secret)), client_id_from_form, None, None, None) => {
|
||||||
|
if let Some(client_id_from_form) = client_id_from_form {
|
||||||
|
// If the client_id was in the body, verify it matches with the header
|
||||||
|
if client_id != client_id_from_form {
|
||||||
|
return Err(ClientAuthorizationError::ClientIdMismatch {
|
||||||
|
credential: client_id,
|
||||||
|
form: client_id_from_form,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Credentials::ClientSecretBasic {
|
||||||
|
client_id,
|
||||||
|
client_secret,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(None, Some(client_id), Some(client_secret), None, None) => {
|
||||||
|
// Got both client_id and client_secret from the form
|
||||||
|
Credentials::ClientSecretPost {
|
||||||
|
client_id,
|
||||||
|
client_secret,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(None, Some(client_id), None, None, None) => {
|
||||||
|
// Only got a client_id in the form
|
||||||
|
Credentials::None { client_id }
|
||||||
|
}
|
||||||
|
|
||||||
|
(
|
||||||
|
None,
|
||||||
|
client_id_from_form,
|
||||||
|
None,
|
||||||
|
Some(client_assertion_type),
|
||||||
|
Some(client_assertion),
|
||||||
|
) if client_assertion_type == JWT_BEARER_CLIENT_ASSERTION => {
|
||||||
|
// Got a JWT bearer client_assertion
|
||||||
|
|
||||||
|
let jwt: JsonWebTokenParts = client_assertion
|
||||||
|
.parse()
|
||||||
|
.map_err(|_| ClientAuthorizationError::InvalidAssertion)?;
|
||||||
|
let decoded: DecodedJsonWebToken<HashMap<String, Value>> = jwt
|
||||||
|
.decode()
|
||||||
|
.map_err(|_| ClientAuthorizationError::InvalidAssertion)?;
|
||||||
|
let (header, claims) = decoded.split();
|
||||||
|
|
||||||
|
let client_id = if let Some(Value::String(client_id)) = claims.get("sub") {
|
||||||
|
client_id.clone()
|
||||||
|
} else {
|
||||||
|
return Err(ClientAuthorizationError::InvalidAssertion);
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(client_id_from_form) = client_id_from_form {
|
||||||
|
// If the client_id was in the body, verify it matches the one in the JWT
|
||||||
|
if client_id != client_id_from_form {
|
||||||
|
return Err(ClientAuthorizationError::ClientIdMismatch {
|
||||||
|
credential: client_id,
|
||||||
|
form: client_id_from_form,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Credentials::ClientAssertionJwtBearer {
|
||||||
|
client_id,
|
||||||
|
jwt,
|
||||||
|
header: Box::new(header),
|
||||||
|
claims,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(None, None, None, Some(client_assertion_type), Some(_client_assertion)) => {
|
||||||
|
// Got another unsupported client_assertion
|
||||||
|
return Err(ClientAuthorizationError::UnsupportedClientAssertion {
|
||||||
|
client_assertion_type,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
(None, None, None, None, None) => {
|
||||||
|
// Special case when there are no credentials anywhere
|
||||||
|
return Err(ClientAuthorizationError::MissingCredentials);
|
||||||
|
}
|
||||||
|
|
||||||
|
_ => {
|
||||||
|
// Every other combination is an invalid request
|
||||||
|
return Err(ClientAuthorizationError::InvalidRequest);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(ClientAuthorization { credentials, form })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use axum::body::{Bytes, Full};
|
||||||
|
use http::{Method, Request};
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn none_test() {
|
||||||
|
let mut req = RequestParts::new(
|
||||||
|
Request::builder()
|
||||||
|
.method(Method::POST)
|
||||||
|
.header(
|
||||||
|
http::header::CONTENT_TYPE,
|
||||||
|
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||||
|
)
|
||||||
|
.body(Full::<Bytes>::new("client_id=client-id&foo=bar".into()))
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
ClientAuthorization::<serde_json::Value>::from_request(&mut req)
|
||||||
|
.await
|
||||||
|
.unwrap(),
|
||||||
|
ClientAuthorization {
|
||||||
|
credentials: Credentials::None {
|
||||||
|
client_id: "client-id".to_string(),
|
||||||
|
},
|
||||||
|
form: Some(serde_json::json!({"foo": "bar"})),
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn client_secret_basic_test() {
|
||||||
|
let mut req = RequestParts::new(
|
||||||
|
Request::builder()
|
||||||
|
.method(Method::POST)
|
||||||
|
.header(
|
||||||
|
http::header::CONTENT_TYPE,
|
||||||
|
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||||
|
)
|
||||||
|
.header(
|
||||||
|
http::header::AUTHORIZATION,
|
||||||
|
"Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=",
|
||||||
|
)
|
||||||
|
.body(Full::<Bytes>::new("foo=bar".into()))
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
ClientAuthorization::<serde_json::Value>::from_request(&mut req)
|
||||||
|
.await
|
||||||
|
.unwrap(),
|
||||||
|
ClientAuthorization {
|
||||||
|
credentials: Credentials::ClientSecretBasic {
|
||||||
|
client_id: "client-id".to_string(),
|
||||||
|
client_secret: "client-secret".to_string(),
|
||||||
|
},
|
||||||
|
form: Some(serde_json::json!({"foo": "bar"})),
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
// client_id in both header and body
|
||||||
|
let mut req = RequestParts::new(
|
||||||
|
Request::builder()
|
||||||
|
.method(Method::POST)
|
||||||
|
.header(
|
||||||
|
http::header::CONTENT_TYPE,
|
||||||
|
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||||
|
)
|
||||||
|
.header(
|
||||||
|
http::header::AUTHORIZATION,
|
||||||
|
"Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=",
|
||||||
|
)
|
||||||
|
.body(Full::<Bytes>::new("client_id=client-id&foo=bar".into()))
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
ClientAuthorization::<serde_json::Value>::from_request(&mut req)
|
||||||
|
.await
|
||||||
|
.unwrap(),
|
||||||
|
ClientAuthorization {
|
||||||
|
credentials: Credentials::ClientSecretBasic {
|
||||||
|
client_id: "client-id".to_string(),
|
||||||
|
client_secret: "client-secret".to_string(),
|
||||||
|
},
|
||||||
|
form: Some(serde_json::json!({"foo": "bar"})),
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
// client_id in both header and body mismatch
|
||||||
|
let mut req = RequestParts::new(
|
||||||
|
Request::builder()
|
||||||
|
.method(Method::POST)
|
||||||
|
.header(
|
||||||
|
http::header::CONTENT_TYPE,
|
||||||
|
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||||
|
)
|
||||||
|
.header(
|
||||||
|
http::header::AUTHORIZATION,
|
||||||
|
"Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=",
|
||||||
|
)
|
||||||
|
.body(Full::<Bytes>::new("client_id=mismatch-id&foo=bar".into()))
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(matches!(
|
||||||
|
ClientAuthorization::<serde_json::Value>::from_request(&mut req).await,
|
||||||
|
Err(ClientAuthorizationError::ClientIdMismatch { .. }),
|
||||||
|
));
|
||||||
|
|
||||||
|
// Invalid header
|
||||||
|
let mut req = RequestParts::new(
|
||||||
|
Request::builder()
|
||||||
|
.method(Method::POST)
|
||||||
|
.header(
|
||||||
|
http::header::CONTENT_TYPE,
|
||||||
|
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||||
|
)
|
||||||
|
.header(http::header::AUTHORIZATION, "Basic invalid")
|
||||||
|
.body(Full::<Bytes>::new("foo=bar".into()))
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(matches!(
|
||||||
|
ClientAuthorization::<serde_json::Value>::from_request(&mut req).await,
|
||||||
|
Err(ClientAuthorizationError::InvalidHeader),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn client_secret_post_test() {
|
||||||
|
let mut req = RequestParts::new(
|
||||||
|
Request::builder()
|
||||||
|
.method(Method::POST)
|
||||||
|
.header(
|
||||||
|
http::header::CONTENT_TYPE,
|
||||||
|
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||||
|
)
|
||||||
|
.body(Full::<Bytes>::new(
|
||||||
|
"client_id=client-id&client_secret=client-secret&foo=bar".into(),
|
||||||
|
))
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
ClientAuthorization::<serde_json::Value>::from_request(&mut req)
|
||||||
|
.await
|
||||||
|
.unwrap(),
|
||||||
|
ClientAuthorization {
|
||||||
|
credentials: Credentials::ClientSecretPost {
|
||||||
|
client_id: "client-id".to_string(),
|
||||||
|
client_secret: "client-secret".to_string(),
|
||||||
|
},
|
||||||
|
form: Some(serde_json::json!({"foo": "bar"})),
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn client_assertion_test() {
|
||||||
|
// Signed with client_secret = "client-secret"
|
||||||
|
let jwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJjbGllbnQtaWQiLCJzdWIiOiJjbGllbnQtaWQiLCJhdWQiOiJodHRwczovL2V4YW1wbGUuY29tL29hdXRoMi9pbnRyb3NwZWN0IiwianRpIjoiYWFiYmNjIiwiZXhwIjoxNTE2MjM5MzIyLCJpYXQiOjE1MTYyMzkwMjJ9.XTaACG_Rww0GPecSZvkbem-AczNy9LLNBueCLCiQajU";
|
||||||
|
let mut req = RequestParts::new(
|
||||||
|
Request::builder()
|
||||||
|
.method(Method::POST)
|
||||||
|
.header(
|
||||||
|
http::header::CONTENT_TYPE,
|
||||||
|
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||||
|
)
|
||||||
|
.body(Full::<Bytes>::new(
|
||||||
|
format!(
|
||||||
|
"client_assertion_type={}&client_assertion={}&foo=bar",
|
||||||
|
JWT_BEARER_CLIENT_ASSERTION, jwt,
|
||||||
|
)
|
||||||
|
.into(),
|
||||||
|
))
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let authz = ClientAuthorization::<serde_json::Value>::from_request(&mut req)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(authz.form, Some(serde_json::json!({"foo": "bar"})));
|
||||||
|
|
||||||
|
let (client_id, _jwt, _header, _claims) = if let Credentials::ClientAssertionJwtBearer {
|
||||||
|
client_id,
|
||||||
|
jwt,
|
||||||
|
header,
|
||||||
|
claims,
|
||||||
|
} = authz.credentials
|
||||||
|
{
|
||||||
|
(client_id, jwt, header, claims)
|
||||||
|
} else {
|
||||||
|
panic!("expected a JWT client_assertion");
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(client_id, "client-id");
|
||||||
|
// TODO: test more things
|
||||||
|
}
|
||||||
|
}
|
@@ -12,6 +12,7 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
pub mod client_authorization;
|
||||||
pub mod cookies;
|
pub mod cookies;
|
||||||
pub mod csrf;
|
pub mod csrf;
|
||||||
pub mod fancy_error;
|
pub mod fancy_error;
|
||||||
|
@@ -280,26 +280,34 @@ where
|
|||||||
) -> Result<Self, Self::Rejection> {
|
) -> Result<Self, Self::Rejection> {
|
||||||
let header = TypedHeader::<Authorization<Bearer>>::from_request(req).await;
|
let header = TypedHeader::<Authorization<Bearer>>::from_request(req).await;
|
||||||
|
|
||||||
|
// Take the Authorization header
|
||||||
let token_from_header = match header {
|
let token_from_header = match header {
|
||||||
Ok(header) => Some(header.token().to_string()),
|
Ok(header) => Some(header.token().to_string()),
|
||||||
Err(err) => match err.reason() {
|
Err(err) => match err.reason() {
|
||||||
|
// If it's missing it is fine
|
||||||
TypedHeaderRejectionReason::Missing => None,
|
TypedHeaderRejectionReason::Missing => None,
|
||||||
|
// If the header could not be parsed, return the error
|
||||||
TypedHeaderRejectionReason::Error(_) => {
|
TypedHeaderRejectionReason::Error(_) => {
|
||||||
return Err(UserAuthorizationError::InvalidHeader)
|
return Err(UserAuthorizationError::InvalidHeader)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Take the form value
|
||||||
let (token_from_form, form) = match Form::<AuthorizedForm<F>>::from_request(req).await {
|
let (token_from_form, form) = match Form::<AuthorizedForm<F>>::from_request(req).await {
|
||||||
Ok(Form(form)) => (form.access_token, Some(form.inner)),
|
Ok(Form(form)) => (form.access_token, Some(form.inner)),
|
||||||
|
// If it is not a form, continue
|
||||||
Err(FormRejection::InvalidFormContentType(_err)) => (None, None),
|
Err(FormRejection::InvalidFormContentType(_err)) => (None, None),
|
||||||
|
// If the form could not be read, return a Bad Request error
|
||||||
Err(FormRejection::FailedToDeserializeQueryString(err)) => {
|
Err(FormRejection::FailedToDeserializeQueryString(err)) => {
|
||||||
return Err(UserAuthorizationError::BadForm(err))
|
return Err(UserAuthorizationError::BadForm(err))
|
||||||
}
|
}
|
||||||
|
// Other errors (body read twice, byte stream broke) return an internal error
|
||||||
Err(e) => return Err(UserAuthorizationError::InternalError(Box::new(e))),
|
Err(e) => return Err(UserAuthorizationError::InternalError(Box::new(e))),
|
||||||
};
|
};
|
||||||
|
|
||||||
let access_token = match (token_from_header, token_from_form) {
|
let access_token = match (token_from_header, token_from_form) {
|
||||||
|
// Ensure the token should not be in both the form and the access token
|
||||||
(Some(_), Some(_)) => return Err(UserAuthorizationError::TokenInFormAndHeader),
|
(Some(_), Some(_)) => return Err(UserAuthorizationError::TokenInFormAndHeader),
|
||||||
(Some(t), None) => AccessToken::Header(t),
|
(Some(t), None) => AccessToken::Header(t),
|
||||||
(None, Some(t)) => AccessToken::Form(t),
|
(None, Some(t)) => AccessToken::Form(t),
|
||||||
|
@@ -32,6 +32,7 @@ chacha20poly1305 = { version = "0.9.0", features = ["std"] }
|
|||||||
elliptic-curve = { version = "0.11.12", features = ["pem", "pkcs8"] }
|
elliptic-curve = { version = "0.11.12", features = ["pem", "pkcs8"] }
|
||||||
pem-rfc7468 = "0.3.1"
|
pem-rfc7468 = "0.3.1"
|
||||||
cookie = { version = "0.16.0", features = ["private", "key-expansion"] }
|
cookie = { version = "0.16.0", features = ["private", "key-expansion"] }
|
||||||
|
data-encoding = "2.3.2"
|
||||||
|
|
||||||
indoc = "1.0.4"
|
indoc = "1.0.4"
|
||||||
|
|
||||||
|
@@ -21,6 +21,7 @@ use chacha20poly1305::{
|
|||||||
ChaCha20Poly1305,
|
ChaCha20Poly1305,
|
||||||
};
|
};
|
||||||
use cookie::Key;
|
use cookie::Key;
|
||||||
|
use data_encoding::BASE64;
|
||||||
use mas_jose::StaticKeystore;
|
use mas_jose::StaticKeystore;
|
||||||
use pkcs8::DecodePrivateKey;
|
use pkcs8::DecodePrivateKey;
|
||||||
use rsa::{
|
use rsa::{
|
||||||
@@ -42,6 +43,7 @@ pub struct Encrypter {
|
|||||||
aead: Arc<ChaCha20Poly1305>,
|
aead: Arc<ChaCha20Poly1305>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: move this somewhere else
|
||||||
impl Encrypter {
|
impl Encrypter {
|
||||||
/// Creates an [`Encrypter`] out of an encryption key
|
/// Creates an [`Encrypter`] out of an encryption key
|
||||||
#[must_use]
|
#[must_use]
|
||||||
@@ -75,6 +77,41 @@ impl Encrypter {
|
|||||||
let encrypted = self.aead.decrypt(nonce, encrypted)?;
|
let encrypted = self.aead.decrypt(nonce, encrypted)?;
|
||||||
Ok(encrypted)
|
Ok(encrypted)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Encrypt a payload to a self-contained base64-encoded string
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Will return `Err` when the payload failed to encrypt
|
||||||
|
pub fn encryt_to_string(&self, decrypted: &[u8]) -> anyhow::Result<String> {
|
||||||
|
let nonce = rand::random();
|
||||||
|
let encrypted = self.encrypt(&nonce, decrypted)?;
|
||||||
|
let encrypted = [&nonce[..], &encrypted].concat();
|
||||||
|
let encrypted = BASE64.encode(&encrypted);
|
||||||
|
Ok(encrypted)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decrypt a payload from a self-contained base64-encoded string
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Will return `Err` when the payload failed to decrypt
|
||||||
|
pub fn decrypt_string(&self, encrypted: &str) -> anyhow::Result<Vec<u8>> {
|
||||||
|
let encrypted = BASE64.decode(encrypted.as_bytes())?;
|
||||||
|
|
||||||
|
let nonce: &[u8; 12] = encrypted
|
||||||
|
.get(0..12)
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("invalid payload serialization"))?
|
||||||
|
.try_into()?;
|
||||||
|
|
||||||
|
let payload = encrypted
|
||||||
|
.get(12..)
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("invalid payload serialization"))?;
|
||||||
|
|
||||||
|
let decrypted_client_secret = self.decrypt(nonce, payload)?;
|
||||||
|
|
||||||
|
Ok(decrypted_client_secret)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<Encrypter> for Key {
|
impl From<Encrypter> for Key {
|
||||||
|
@@ -30,7 +30,7 @@ use crate::{jwk::JsonWebKey, SigningKeystore, VerifyingKeystore};
|
|||||||
|
|
||||||
#[serde_as]
|
#[serde_as]
|
||||||
#[skip_serializing_none]
|
#[skip_serializing_none]
|
||||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
|
||||||
pub struct JwtHeader {
|
pub struct JwtHeader {
|
||||||
alg: JsonWebSignatureAlg,
|
alg: JsonWebSignatureAlg,
|
||||||
|
|
||||||
@@ -126,7 +126,7 @@ impl FromStr for JwtHeader {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
pub struct JsonWebTokenParts {
|
pub struct JsonWebTokenParts {
|
||||||
payload: String,
|
payload: String,
|
||||||
signature: Vec<u8>,
|
signature: Vec<u8>,
|
||||||
@@ -178,6 +178,14 @@ impl<T> DecodedJsonWebToken<T> {
|
|||||||
pub fn claims(&self) -> &T {
|
pub fn claims(&self) -> &T {
|
||||||
&self.payload
|
&self.payload
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn header(&self) -> &JwtHeader {
|
||||||
|
&self.header
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn split(self) -> (JwtHeader, T) {
|
||||||
|
(self.header, self.payload)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> FromStr for DecodedJsonWebToken<T>
|
impl<T> FromStr for DecodedJsonWebToken<T>
|
||||||
@@ -205,15 +213,11 @@ impl JsonWebTokenParts {
|
|||||||
Ok(decoded)
|
Ok(decoded)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn verify<T, S: VerifyingKeystore>(
|
pub fn verify<S: VerifyingKeystore>(&self, header: &JwtHeader, store: &S) -> S::Future
|
||||||
&self,
|
|
||||||
decoded: &DecodedJsonWebToken<T>,
|
|
||||||
store: &S,
|
|
||||||
) -> S::Future
|
|
||||||
where
|
where
|
||||||
S::Error: std::error::Error + Send + Sync + 'static,
|
S::Error: std::error::Error + Send + Sync + 'static,
|
||||||
{
|
{
|
||||||
store.verify(&decoded.header, self.payload.as_bytes(), &self.signature)
|
store.verify(header, self.payload.as_bytes(), &self.signature)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn decode_and_verify<T: DeserializeOwned, S: VerifyingKeystore>(
|
pub async fn decode_and_verify<T: DeserializeOwned, S: VerifyingKeystore>(
|
||||||
@@ -224,7 +228,7 @@ impl JsonWebTokenParts {
|
|||||||
S::Error: std::error::Error + Send + Sync + 'static,
|
S::Error: std::error::Error + Send + Sync + 'static,
|
||||||
{
|
{
|
||||||
let decoded = self.decode()?;
|
let decoded = self.decode()?;
|
||||||
self.verify(&decoded, store).await?;
|
self.verify(&decoded.header, store).await?;
|
||||||
Ok(decoded)
|
Ok(decoded)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -318,7 +318,7 @@ async fn authenticate_client<T>(
|
|||||||
})?;
|
})?;
|
||||||
|
|
||||||
let store = SharedSecret::new(&client_secret);
|
let store = SharedSecret::new(&client_secret);
|
||||||
let fut = token.verify(&decoded, &store);
|
let fut = token.verify(decoded.header(), &store);
|
||||||
fut.await.wrap_error()?;
|
fut.await.wrap_error()?;
|
||||||
}
|
}
|
||||||
OAuthClientAuthenticationMethod::PrivateKeyJwt => {
|
OAuthClientAuthenticationMethod::PrivateKeyJwt => {
|
||||||
@@ -329,7 +329,7 @@ async fn authenticate_client<T>(
|
|||||||
})?;
|
})?;
|
||||||
|
|
||||||
let store = jwks_key_store(jwks);
|
let store = jwks_key_store(jwks);
|
||||||
let fut = token.verify(&decoded, &store);
|
let fut = token.verify(decoded.header(), &store);
|
||||||
fut.await.wrap_error()?;
|
fut.await.wrap_error()?;
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
|
Reference in New Issue
Block a user