You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-08-07 17:03:01 +03:00
Implement client_secret_jwt authentication method
This commit is contained in:
@@ -14,44 +14,32 @@
|
|||||||
|
|
||||||
//! Handle client authentication
|
//! Handle client authentication
|
||||||
|
|
||||||
|
use std::borrow::Cow;
|
||||||
|
|
||||||
|
use chrono::{Duration, Utc};
|
||||||
use headers::{authorization::Basic, Authorization};
|
use headers::{authorization::Basic, Authorization};
|
||||||
use serde::{de::DeserializeOwned, Deserialize};
|
use jwt_compact::{
|
||||||
|
alg::{Hs256, Hs256Key, Hs384, Hs384Key, Hs512, Hs512Key},
|
||||||
|
Algorithm, AlgorithmExt, AlgorithmSignature, TimeOptions, Token, UntrustedToken,
|
||||||
|
};
|
||||||
|
use oauth2_types::requests::ClientAuthenticationMethod;
|
||||||
|
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||||
|
use serde_with::skip_serializing_none;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use warp::{reject::Reject, Filter, Rejection};
|
use warp::{reject::Reject, Filter, Rejection};
|
||||||
|
|
||||||
use super::headers::typed_header;
|
use super::headers::typed_header;
|
||||||
use crate::config::{OAuth2ClientConfig, OAuth2Config};
|
use crate::{
|
||||||
|
config::{OAuth2ClientConfig, OAuth2Config},
|
||||||
/// Type of client authentication that succeeded
|
errors::WrapError,
|
||||||
#[derive(Debug, PartialEq, Eq)]
|
};
|
||||||
pub enum ClientAuthentication {
|
|
||||||
/// `client_secret_basic` authentication, where the `client_id` and
|
|
||||||
/// `client_secret` are sent through the `Authorization` header with
|
|
||||||
/// `Basic` authentication
|
|
||||||
ClientSecretBasic,
|
|
||||||
|
|
||||||
/// `client_secret_post` authentication, where the `client_id` and
|
|
||||||
/// `client_secret` are sent in the request body
|
|
||||||
ClientSecretPost,
|
|
||||||
|
|
||||||
/// `none` authentication for public clients, where only the `client_id` is
|
|
||||||
/// sent in the request body
|
|
||||||
None,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ClientAuthentication {
|
|
||||||
#[must_use]
|
|
||||||
/// Check if the authenticated client is public or not
|
|
||||||
pub fn public(&self) -> bool {
|
|
||||||
matches!(self, &Self::None)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Protect an enpoint with client authentication
|
/// Protect an enpoint with client authentication
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn client_authentication<T: DeserializeOwned + Send + 'static>(
|
pub fn client_authentication<T: DeserializeOwned + Send + 'static>(
|
||||||
oauth2_config: &OAuth2Config,
|
oauth2_config: &OAuth2Config,
|
||||||
) -> impl Filter<Extract = (ClientAuthentication, OAuth2ClientConfig, T), Error = Rejection>
|
audience: String,
|
||||||
|
) -> impl Filter<Extract = (ClientAuthenticationMethod, OAuth2ClientConfig, T), Error = Rejection>
|
||||||
+ Clone
|
+ Clone
|
||||||
+ Send
|
+ Send
|
||||||
+ Sync
|
+ Sync
|
||||||
@@ -64,25 +52,19 @@ pub fn client_authentication<T: DeserializeOwned + Send + 'static>(
|
|||||||
let client_id = auth.0.username().to_string();
|
let client_id = auth.0.username().to_string();
|
||||||
let client_secret = Some(auth.0.password().to_string());
|
let client_secret = Some(auth.0.password().to_string());
|
||||||
(
|
(
|
||||||
ClientAuthentication::ClientSecretBasic,
|
ClientCredentials::Pair {
|
||||||
client_id,
|
via: CredentialsVia::AuthorizationHeader,
|
||||||
client_secret,
|
client_id,
|
||||||
|
client_secret,
|
||||||
|
},
|
||||||
body,
|
body,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
// Or from the form body
|
// Or from the form body
|
||||||
.or(warp::body::form().map(|form: ClientAuthForm<T>| {
|
.or(warp::body::form().map(|form: ClientAuthForm<T>| {
|
||||||
let ClientAuthForm {
|
let ClientAuthForm { credentials, body } = form;
|
||||||
client_id,
|
|
||||||
client_secret,
|
(credentials, body)
|
||||||
body,
|
|
||||||
} = form;
|
|
||||||
let auth_type = if client_secret.is_some() {
|
|
||||||
ClientAuthentication::ClientSecretPost
|
|
||||||
} else {
|
|
||||||
ClientAuthentication::None
|
|
||||||
};
|
|
||||||
(auth_type, client_id, client_secret, body)
|
|
||||||
}))
|
}))
|
||||||
.unify()
|
.unify()
|
||||||
.untuple_one();
|
.untuple_one();
|
||||||
@@ -90,6 +72,7 @@ pub fn client_authentication<T: DeserializeOwned + Send + 'static>(
|
|||||||
let clients = oauth2_config.clients.clone();
|
let clients = oauth2_config.clients.clone();
|
||||||
warp::any()
|
warp::any()
|
||||||
.map(move || clients.clone())
|
.map(move || clients.clone())
|
||||||
|
.and(warp::any().map(move || audience.clone()))
|
||||||
.and(credentials)
|
.and(credentials)
|
||||||
.and_then(authenticate_client)
|
.and_then(authenticate_client)
|
||||||
.untuple_one()
|
.untuple_one()
|
||||||
@@ -108,39 +91,257 @@ enum ClientAuthenticationError {
|
|||||||
|
|
||||||
#[error("client secret required for client {client_id:?}")]
|
#[error("client secret required for client {client_id:?}")]
|
||||||
ClientSecretRequired { client_id: String },
|
ClientSecretRequired { client_id: String },
|
||||||
|
|
||||||
|
#[error("wrong audience in client assertion: expected {expected:?}, got {got:?}")]
|
||||||
|
AudienceMismatch { expected: String, got: String },
|
||||||
|
|
||||||
|
#[error("invalid client assertion")]
|
||||||
|
InvalidAssertion,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Reject for ClientAuthenticationError {}
|
impl Reject for ClientAuthenticationError {}
|
||||||
|
|
||||||
|
#[skip_serializing_none]
|
||||||
|
#[derive(Serialize, Deserialize)]
|
||||||
|
struct ClientAssertionClaims {
|
||||||
|
#[serde(rename = "iss")]
|
||||||
|
issuer: String,
|
||||||
|
#[serde(rename = "sub")]
|
||||||
|
subject: String,
|
||||||
|
#[serde(rename = "aud")]
|
||||||
|
audience: String,
|
||||||
|
// TODO: use the JTI and ensure it is only used once
|
||||||
|
#[serde(default, rename = "jti")]
|
||||||
|
jwt_id: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct UnsignedSignature(Vec<u8>);
|
||||||
|
impl AlgorithmSignature for UnsignedSignature {
|
||||||
|
fn try_from_slice(slice: &[u8]) -> anyhow::Result<Self> {
|
||||||
|
Ok(Self(slice.to_vec()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn as_bytes(&self) -> std::borrow::Cow<'_, [u8]> {
|
||||||
|
Cow::Borrowed(&self.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Unsigned<'a>(&'a str);
|
||||||
|
impl<'a> Algorithm for Unsigned<'a> {
|
||||||
|
type SigningKey = ();
|
||||||
|
|
||||||
|
type VerifyingKey = ();
|
||||||
|
|
||||||
|
type Signature = UnsignedSignature;
|
||||||
|
|
||||||
|
fn name(&self) -> std::borrow::Cow<'static, str> {
|
||||||
|
Cow::Owned(self.0.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sign(&self, _signing_key: &Self::SigningKey, _message: &[u8]) -> Self::Signature {
|
||||||
|
UnsignedSignature(Vec::new())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn verify_signature(
|
||||||
|
&self,
|
||||||
|
_signature: &Self::Signature,
|
||||||
|
_verifying_key: &Self::VerifyingKey,
|
||||||
|
_message: &[u8],
|
||||||
|
) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn verify_token(
|
||||||
|
untrusted_token: &UntrustedToken,
|
||||||
|
key: &str,
|
||||||
|
) -> anyhow::Result<Token<ClientAssertionClaims>> {
|
||||||
|
match untrusted_token.algorithm() {
|
||||||
|
"HS256" => {
|
||||||
|
let key = Hs256Key::new(key);
|
||||||
|
let token = Hs256.validate_integrity(untrusted_token, &key)?;
|
||||||
|
Ok(token)
|
||||||
|
}
|
||||||
|
"HS384" => {
|
||||||
|
let key = Hs384Key::new(key);
|
||||||
|
let token = Hs384.validate_integrity(untrusted_token, &key)?;
|
||||||
|
Ok(token)
|
||||||
|
}
|
||||||
|
"HS512" => {
|
||||||
|
let key = Hs512Key::new(key);
|
||||||
|
let token = Hs512.validate_integrity(untrusted_token, &key)?;
|
||||||
|
Ok(token)
|
||||||
|
}
|
||||||
|
alg => anyhow::bail!("unsupported signing algorithm {}", alg),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn authenticate_client<T>(
|
async fn authenticate_client<T>(
|
||||||
clients: Vec<OAuth2ClientConfig>,
|
clients: Vec<OAuth2ClientConfig>,
|
||||||
auth_type: ClientAuthentication,
|
audience: String,
|
||||||
client_id: String,
|
credentials: ClientCredentials,
|
||||||
client_secret: Option<String>,
|
|
||||||
body: T,
|
body: T,
|
||||||
) -> Result<(ClientAuthentication, OAuth2ClientConfig, T), Rejection> {
|
) -> Result<(ClientAuthenticationMethod, OAuth2ClientConfig, T), Rejection> {
|
||||||
let client = clients
|
let auth_type = credentials.authentication_type();
|
||||||
.iter()
|
let client = match credentials {
|
||||||
.find(|client| client.client_id == client_id)
|
ClientCredentials::Pair {
|
||||||
.ok_or_else(|| ClientAuthenticationError::ClientNotFound {
|
client_id,
|
||||||
client_id: client_id.to_string(),
|
client_secret,
|
||||||
})?;
|
..
|
||||||
|
} => {
|
||||||
|
let client = clients
|
||||||
|
.iter()
|
||||||
|
.find(|client| client.client_id == client_id)
|
||||||
|
.ok_or_else(|| ClientAuthenticationError::ClientNotFound {
|
||||||
|
client_id: client_id.to_string(),
|
||||||
|
})?;
|
||||||
|
|
||||||
let client = match (client_secret, client.client_secret.as_ref()) {
|
match (client_secret, client.client_secret.as_ref()) {
|
||||||
(None, None) => Ok(client),
|
(None, None) => Ok(client),
|
||||||
(Some(ref given), Some(expected)) if given == expected => Ok(client),
|
(Some(ref given), Some(expected)) if given == expected => Ok(client),
|
||||||
(Some(_), Some(_)) => Err(ClientAuthenticationError::ClientSecretMismatch { client_id }),
|
(Some(_), Some(_)) => {
|
||||||
(Some(_), None) => Err(ClientAuthenticationError::NoClientSecret { client_id }),
|
Err(ClientAuthenticationError::ClientSecretMismatch { client_id })
|
||||||
(None, Some(_)) => Err(ClientAuthenticationError::ClientSecretRequired { client_id }),
|
}
|
||||||
|
(Some(_), None) => Err(ClientAuthenticationError::NoClientSecret { client_id }),
|
||||||
|
(None, Some(_)) => {
|
||||||
|
Err(ClientAuthenticationError::ClientSecretRequired { client_id })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ClientCredentials::Assertion {
|
||||||
|
client_id,
|
||||||
|
client_assertion_type: ClientAssertionType::JwtBearer,
|
||||||
|
client_assertion,
|
||||||
|
} => {
|
||||||
|
let untrusted_token = UntrustedToken::new(&client_assertion).wrap_error()?;
|
||||||
|
|
||||||
|
// client_id might have been passed as parameter. If not, it should be inferred
|
||||||
|
// from the token, as per rfc7521 sec. 4.2
|
||||||
|
// TODO: this is not a pretty way to do it
|
||||||
|
let client_id = client_id
|
||||||
|
.ok_or(()) // Dumb error type
|
||||||
|
.or_else(|()| {
|
||||||
|
let alg = Unsigned(untrusted_token.algorithm());
|
||||||
|
// We need to deserialize the token once without verifying the signature to get
|
||||||
|
// the client_id
|
||||||
|
let token: Token<ClientAssertionClaims> =
|
||||||
|
alg.validate_integrity(&untrusted_token, &())?;
|
||||||
|
|
||||||
|
Ok::<_, anyhow::Error>(token.claims().custom.subject.clone())
|
||||||
|
})
|
||||||
|
.wrap_error()?;
|
||||||
|
|
||||||
|
let client = clients
|
||||||
|
.iter()
|
||||||
|
.find(|client| client.client_id == client_id)
|
||||||
|
.ok_or_else(|| ClientAuthenticationError::ClientNotFound {
|
||||||
|
client_id: client_id.to_string(),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
if let Some(client_secret) = &client.client_secret {
|
||||||
|
let token = verify_token(&untrusted_token, client_secret).wrap_error()?;
|
||||||
|
|
||||||
|
let time_options = TimeOptions::new(Duration::minutes(1), Utc::now);
|
||||||
|
|
||||||
|
// rfc7523 sec. 3.4: expiration must be set and validated
|
||||||
|
let claims = token
|
||||||
|
.claims()
|
||||||
|
.validate_expiration(&time_options)
|
||||||
|
.wrap_error()?;
|
||||||
|
|
||||||
|
// rfc7523 sec. 3.5: "not before" can be set and must be validated if present
|
||||||
|
if claims.not_before.is_some() {
|
||||||
|
claims.validate_maturity(&time_options).wrap_error()?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// rfc7523 sec. 3.3: the audience is the URL being called
|
||||||
|
if claims.custom.audience != audience {
|
||||||
|
Err(ClientAuthenticationError::AudienceMismatch {
|
||||||
|
expected: audience,
|
||||||
|
got: claims.custom.audience.clone(),
|
||||||
|
})
|
||||||
|
// rfc7523 sec. 3.1 & 3.2: both the issuer and the subject must
|
||||||
|
// match the client_id
|
||||||
|
} else if claims.custom.issuer != claims.custom.subject
|
||||||
|
|| claims.custom.issuer != client_id
|
||||||
|
{
|
||||||
|
Err(ClientAuthenticationError::InvalidAssertion)
|
||||||
|
} else {
|
||||||
|
Ok(client)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Err(ClientAuthenticationError::ClientSecretRequired {
|
||||||
|
client_id: client_id.to_string(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
}?;
|
}?;
|
||||||
|
|
||||||
Ok((auth_type, client.clone(), body))
|
Ok((auth_type, client.clone(), body))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
enum ClientAssertionType {
|
||||||
|
#[serde(rename = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")]
|
||||||
|
JwtBearer,
|
||||||
|
}
|
||||||
|
|
||||||
|
enum CredentialsVia {
|
||||||
|
FormBody,
|
||||||
|
AuthorizationHeader,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for CredentialsVia {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::FormBody
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
enum ClientCredentials {
|
||||||
|
// Order here is important: serde tries to deserialize enum variants in order, so if "Pair"
|
||||||
|
// was before "Assertion", a client_assertion with a client_id would match the "Pair"
|
||||||
|
// variant first
|
||||||
|
Assertion {
|
||||||
|
client_id: Option<String>,
|
||||||
|
client_assertion_type: ClientAssertionType,
|
||||||
|
client_assertion: String,
|
||||||
|
},
|
||||||
|
Pair {
|
||||||
|
#[serde(skip)]
|
||||||
|
via: CredentialsVia,
|
||||||
|
client_id: String,
|
||||||
|
client_secret: Option<String>,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ClientCredentials {
|
||||||
|
fn authentication_type(&self) -> ClientAuthenticationMethod {
|
||||||
|
match self {
|
||||||
|
ClientCredentials::Pair {
|
||||||
|
via: CredentialsVia::FormBody,
|
||||||
|
client_secret: None,
|
||||||
|
..
|
||||||
|
} => ClientAuthenticationMethod::None,
|
||||||
|
ClientCredentials::Pair {
|
||||||
|
via: CredentialsVia::FormBody,
|
||||||
|
client_secret: Some(_),
|
||||||
|
..
|
||||||
|
} => ClientAuthenticationMethod::ClientSecretPost,
|
||||||
|
ClientCredentials::Pair {
|
||||||
|
via: CredentialsVia::AuthorizationHeader,
|
||||||
|
..
|
||||||
|
} => ClientAuthenticationMethod::ClientSecretBasic,
|
||||||
|
ClientCredentials::Assertion { .. } => ClientAuthenticationMethod::ClientSecretJwt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct ClientAuthForm<T> {
|
struct ClientAuthForm<T> {
|
||||||
client_id: String,
|
#[serde(flatten)]
|
||||||
client_secret: Option<String>,
|
credentials: ClientCredentials,
|
||||||
|
|
||||||
#[serde(flatten)]
|
#[serde(flatten)]
|
||||||
body: T,
|
body: T,
|
||||||
@@ -148,10 +349,16 @@ struct ClientAuthForm<T> {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
use headers::authorization::Credentials;
|
||||||
|
use jwt_compact::{Claims, Header};
|
||||||
use mas_config::ConfigurationSection;
|
use mas_config::ConfigurationSection;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
// Long client_secret to support it as a HS512 key
|
||||||
|
const CLIENT_SECRET: &str = "leek2zaeyeb8thai7piehea3vah6ool9oanin9aeraThuci9EeghaekaiD1upe4Quoh7xeMae2meitohj0Waaveiwaorah1yazohr6Vae7iebeiRaWene5IeWeeciezu";
|
||||||
|
|
||||||
fn oauth2_config() -> OAuth2Config {
|
fn oauth2_config() -> OAuth2Config {
|
||||||
let mut config = OAuth2Config::test();
|
let mut config = OAuth2Config::test();
|
||||||
config.clients.push(OAuth2ClientConfig {
|
config.clients.push(OAuth2ClientConfig {
|
||||||
@@ -161,7 +368,12 @@ mod tests {
|
|||||||
});
|
});
|
||||||
config.clients.push(OAuth2ClientConfig {
|
config.clients.push(OAuth2ClientConfig {
|
||||||
client_id: "confidential".to_string(),
|
client_id: "confidential".to_string(),
|
||||||
client_secret: Some("secret".to_string()),
|
client_secret: Some(CLIENT_SECRET.to_string()),
|
||||||
|
redirect_uris: Vec::new(),
|
||||||
|
});
|
||||||
|
config.clients.push(OAuth2ClientConfig {
|
||||||
|
client_id: "confidential-2".to_string(),
|
||||||
|
client_secret: Some(CLIENT_SECRET.to_string()),
|
||||||
redirect_uris: Vec::new(),
|
redirect_uris: Vec::new(),
|
||||||
});
|
});
|
||||||
config
|
config
|
||||||
@@ -174,17 +386,126 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn client_secret_post() {
|
async fn client_secret_jwt_hs256() {
|
||||||
let filter = client_authentication::<Form>(&oauth2_config());
|
client_secret_jwt::<'_, Hs256>().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn client_secret_jwt_hs384() {
|
||||||
|
client_secret_jwt::<'_, Hs384>().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn client_secret_jwt_hs512() {
|
||||||
|
client_secret_jwt::<'_, Hs512>().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn client_secret_jwt<'k, A>()
|
||||||
|
where
|
||||||
|
A: Algorithm + Default,
|
||||||
|
A::SigningKey: From<&'k [u8]>,
|
||||||
|
{
|
||||||
|
let audience = "https://example.com/token".to_string();
|
||||||
|
let filter = client_authentication::<Form>(&oauth2_config(), audience.clone());
|
||||||
|
let time_options = TimeOptions::default();
|
||||||
|
|
||||||
|
let key = A::SigningKey::from(CLIENT_SECRET.as_bytes());
|
||||||
|
let alg = A::default();
|
||||||
|
let header = Header::default();
|
||||||
|
let claims = Claims::new(ClientAssertionClaims {
|
||||||
|
issuer: "confidential".to_string(),
|
||||||
|
subject: "confidential".to_string(),
|
||||||
|
audience,
|
||||||
|
jwt_id: None,
|
||||||
|
})
|
||||||
|
.set_duration_and_issuance(&time_options, Duration::seconds(15));
|
||||||
|
|
||||||
|
// TODO: test failing cases
|
||||||
|
// - expired token
|
||||||
|
// - "not before" in the future
|
||||||
|
// - subject/issuer mismatch
|
||||||
|
// - audience mismatch
|
||||||
|
// - wrong secret/signature
|
||||||
|
|
||||||
|
let token = alg
|
||||||
|
.token(header, &claims, &key)
|
||||||
|
.expect("could not sign token");
|
||||||
|
|
||||||
let (auth, client, body) = warp::test::request()
|
let (auth, client, body) = warp::test::request()
|
||||||
.method("POST")
|
.method("POST")
|
||||||
.body("client_id=confidential&client_secret=secret&foo=baz&bar=foobar")
|
.header("Content-Type", mime::APPLICATION_WWW_FORM_URLENCODED.to_string())
|
||||||
|
.body(serde_urlencoded::to_string(json!({
|
||||||
|
"client_id": "confidential",
|
||||||
|
"client_assertion": token,
|
||||||
|
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
||||||
|
"foo": "baz",
|
||||||
|
"bar": "foobar",
|
||||||
|
})).unwrap())
|
||||||
.filter(&filter)
|
.filter(&filter)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(auth, ClientAuthentication::ClientSecretPost);
|
assert_eq!(auth, ClientAuthenticationMethod::ClientSecretJwt);
|
||||||
|
assert_eq!(client.client_id, "confidential");
|
||||||
|
assert_eq!(body.foo, "baz");
|
||||||
|
assert_eq!(body.bar, "foobar");
|
||||||
|
|
||||||
|
// Without client_id
|
||||||
|
let res = warp::test::request()
|
||||||
|
.method("POST")
|
||||||
|
.header("Content-Type", mime::APPLICATION_WWW_FORM_URLENCODED.to_string())
|
||||||
|
.body(serde_urlencoded::to_string(json!({
|
||||||
|
"client_assertion": token,
|
||||||
|
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
||||||
|
"foo": "baz",
|
||||||
|
"bar": "foobar",
|
||||||
|
})).unwrap())
|
||||||
|
.filter(&filter)
|
||||||
|
.await;
|
||||||
|
assert!(res.is_ok());
|
||||||
|
|
||||||
|
// client_id mismatch
|
||||||
|
let res = warp::test::request()
|
||||||
|
.method("POST")
|
||||||
|
.body(serde_urlencoded::to_string(json!({
|
||||||
|
"client_id": "confidential-2",
|
||||||
|
"client_assertion": token,
|
||||||
|
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
||||||
|
"foo": "baz",
|
||||||
|
"bar": "foobar",
|
||||||
|
})).unwrap())
|
||||||
|
.filter(&filter)
|
||||||
|
.await;
|
||||||
|
assert!(res.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn client_secret_post() {
|
||||||
|
let filter = client_authentication::<Form>(
|
||||||
|
&oauth2_config(),
|
||||||
|
"https://example.com/token".to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let (auth, client, body) = warp::test::request()
|
||||||
|
.method("POST")
|
||||||
|
.header(
|
||||||
|
"Content-Type",
|
||||||
|
mime::APPLICATION_WWW_FORM_URLENCODED.to_string(),
|
||||||
|
)
|
||||||
|
.body(
|
||||||
|
serde_urlencoded::to_string(json!({
|
||||||
|
"client_id": "confidential",
|
||||||
|
"client_secret": CLIENT_SECRET,
|
||||||
|
"foo": "baz",
|
||||||
|
"bar": "foobar",
|
||||||
|
}))
|
||||||
|
.unwrap(),
|
||||||
|
)
|
||||||
|
.filter(&filter)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(auth, ClientAuthenticationMethod::ClientSecretPost);
|
||||||
assert_eq!(client.client_id, "confidential");
|
assert_eq!(client.client_id, "confidential");
|
||||||
assert_eq!(body.foo, "baz");
|
assert_eq!(body.foo, "baz");
|
||||||
assert_eq!(body.bar, "foobar");
|
assert_eq!(body.bar, "foobar");
|
||||||
@@ -192,17 +513,31 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn client_secret_basic() {
|
async fn client_secret_basic() {
|
||||||
let filter = client_authentication::<Form>(&oauth2_config());
|
let filter = client_authentication::<Form>(
|
||||||
|
&oauth2_config(),
|
||||||
|
"https://example.com/token".to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let auth = Authorization::basic("confidential", CLIENT_SECRET);
|
||||||
let (auth, client, body) = warp::test::request()
|
let (auth, client, body) = warp::test::request()
|
||||||
.method("POST")
|
.method("POST")
|
||||||
.header("Authorization", "Basic Y29uZmlkZW50aWFsOnNlY3JldA==")
|
.header(
|
||||||
.body("foo=baz&bar=foobar")
|
"Content-Type",
|
||||||
|
mime::APPLICATION_WWW_FORM_URLENCODED.to_string(),
|
||||||
|
)
|
||||||
|
.header("Authorization", auth.0.encode())
|
||||||
|
.body(
|
||||||
|
serde_urlencoded::to_string(json!({
|
||||||
|
"foo": "baz",
|
||||||
|
"bar": "foobar",
|
||||||
|
}))
|
||||||
|
.unwrap(),
|
||||||
|
)
|
||||||
.filter(&filter)
|
.filter(&filter)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(auth, ClientAuthentication::ClientSecretBasic);
|
assert_eq!(auth, ClientAuthenticationMethod::ClientSecretBasic);
|
||||||
assert_eq!(client.client_id, "confidential");
|
assert_eq!(client.client_id, "confidential");
|
||||||
assert_eq!(body.foo, "baz");
|
assert_eq!(body.foo, "baz");
|
||||||
assert_eq!(body.bar, "foobar");
|
assert_eq!(body.bar, "foobar");
|
||||||
@@ -210,16 +545,30 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn none() {
|
async fn none() {
|
||||||
let filter = client_authentication::<Form>(&oauth2_config());
|
let filter = client_authentication::<Form>(
|
||||||
|
&oauth2_config(),
|
||||||
|
"https://example.com/token".to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
let (auth, client, body) = warp::test::request()
|
let (auth, client, body) = warp::test::request()
|
||||||
.method("POST")
|
.method("POST")
|
||||||
.body("client_id=public&foo=baz&bar=foobar")
|
.header(
|
||||||
|
"Content-Type",
|
||||||
|
mime::APPLICATION_WWW_FORM_URLENCODED.to_string(),
|
||||||
|
)
|
||||||
|
.body(
|
||||||
|
serde_urlencoded::to_string(json!({
|
||||||
|
"client_id": "public",
|
||||||
|
"foo": "baz",
|
||||||
|
"bar": "foobar",
|
||||||
|
}))
|
||||||
|
.unwrap(),
|
||||||
|
)
|
||||||
.filter(&filter)
|
.filter(&filter)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(auth, ClientAuthentication::None);
|
assert_eq!(auth, ClientAuthenticationMethod::None);
|
||||||
assert_eq!(client.client_id, "public");
|
assert_eq!(client.client_id, "public");
|
||||||
assert_eq!(body.foo, "baz");
|
assert_eq!(body.foo, "baz");
|
||||||
assert_eq!(body.bar, "foobar");
|
assert_eq!(body.bar, "foobar");
|
||||||
|
@@ -113,6 +113,15 @@ where
|
|||||||
params: T,
|
params: T,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct ParamsWithState<T> {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
state: Option<String>,
|
||||||
|
|
||||||
|
#[serde(flatten)]
|
||||||
|
params: T,
|
||||||
|
}
|
||||||
|
|
||||||
match response_mode {
|
match response_mode {
|
||||||
ResponseMode::Query => {
|
ResponseMode::Query => {
|
||||||
let existing: Option<HashMap<&str, &str>> = redirect_uri
|
let existing: Option<HashMap<&str, &str>> = redirect_uri
|
||||||
@@ -159,7 +168,8 @@ where
|
|||||||
)))
|
)))
|
||||||
}
|
}
|
||||||
ResponseMode::FormPost => {
|
ResponseMode::FormPost => {
|
||||||
let ctx = FormPostContext::new(redirect_uri, params);
|
let merged = ParamsWithState { state, params };
|
||||||
|
let ctx = FormPostContext::new(redirect_uri, merged);
|
||||||
let rendered = templates.render_form_post(&ctx)?;
|
let rendered = templates.render_form_post(&ctx)?;
|
||||||
Ok(Box::new(html(rendered)))
|
Ok(Box::new(html(rendered)))
|
||||||
}
|
}
|
||||||
|
@@ -17,7 +17,7 @@ use std::collections::HashSet;
|
|||||||
use hyper::Method;
|
use hyper::Method;
|
||||||
use mas_config::OAuth2Config;
|
use mas_config::OAuth2Config;
|
||||||
use oauth2_types::{
|
use oauth2_types::{
|
||||||
oidc::Metadata,
|
oidc::{Metadata, SigningAlgorithm},
|
||||||
pkce::CodeChallengeMethod,
|
pkce::CodeChallengeMethod,
|
||||||
requests::{ClientAuthenticationMethod, GrantType, ResponseMode},
|
requests::{ClientAuthenticationMethod, GrantType, ResponseMode},
|
||||||
};
|
};
|
||||||
@@ -61,10 +61,19 @@ pub(super) fn filter(
|
|||||||
let mut s = HashSet::new();
|
let mut s = HashSet::new();
|
||||||
s.insert(ClientAuthenticationMethod::ClientSecretBasic);
|
s.insert(ClientAuthenticationMethod::ClientSecretBasic);
|
||||||
s.insert(ClientAuthenticationMethod::ClientSecretPost);
|
s.insert(ClientAuthenticationMethod::ClientSecretPost);
|
||||||
|
s.insert(ClientAuthenticationMethod::ClientSecretJwt);
|
||||||
s.insert(ClientAuthenticationMethod::None);
|
s.insert(ClientAuthenticationMethod::None);
|
||||||
s
|
s
|
||||||
});
|
});
|
||||||
|
|
||||||
|
let token_endpoint_auth_signing_alg_values_supported = Some({
|
||||||
|
let mut s = HashSet::new();
|
||||||
|
s.insert(SigningAlgorithm::Hs256);
|
||||||
|
s.insert(SigningAlgorithm::Hs384);
|
||||||
|
s.insert(SigningAlgorithm::Hs512);
|
||||||
|
s
|
||||||
|
});
|
||||||
|
|
||||||
let code_challenge_methods_supported = Some({
|
let code_challenge_methods_supported = Some({
|
||||||
let mut s = HashSet::new();
|
let mut s = HashSet::new();
|
||||||
s.insert(CodeChallengeMethod::Plain);
|
s.insert(CodeChallengeMethod::Plain);
|
||||||
@@ -85,6 +94,7 @@ pub(super) fn filter(
|
|||||||
response_modes_supported,
|
response_modes_supported,
|
||||||
grant_types_supported,
|
grant_types_supported,
|
||||||
token_endpoint_auth_methods_supported,
|
token_endpoint_auth_methods_supported,
|
||||||
|
token_endpoint_auth_signing_alg_values_supported,
|
||||||
code_challenge_methods_supported,
|
code_challenge_methods_supported,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@@ -13,7 +13,9 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use hyper::Method;
|
use hyper::Method;
|
||||||
use oauth2_types::requests::{IntrospectionRequest, IntrospectionResponse, TokenTypeHint};
|
use oauth2_types::requests::{
|
||||||
|
ClientAuthenticationMethod, IntrospectionRequest, IntrospectionResponse, TokenTypeHint,
|
||||||
|
};
|
||||||
use sqlx::{pool::PoolConnection, PgPool, Postgres};
|
use sqlx::{pool::PoolConnection, PgPool, Postgres};
|
||||||
use tracing::{info, warn};
|
use tracing::{info, warn};
|
||||||
use warp::{Filter, Rejection, Reply};
|
use warp::{Filter, Rejection, Reply};
|
||||||
@@ -21,11 +23,7 @@ use warp::{Filter, Rejection, Reply};
|
|||||||
use crate::{
|
use crate::{
|
||||||
config::{OAuth2ClientConfig, OAuth2Config},
|
config::{OAuth2ClientConfig, OAuth2Config},
|
||||||
errors::WrapError,
|
errors::WrapError,
|
||||||
filters::{
|
filters::{client::client_authentication, cors::cors, database::connection},
|
||||||
client::{client_authentication, ClientAuthentication},
|
|
||||||
cors::cors,
|
|
||||||
database::connection,
|
|
||||||
},
|
|
||||||
storage::oauth2::{
|
storage::oauth2::{
|
||||||
access_token::lookup_active_access_token, refresh_token::lookup_active_refresh_token,
|
access_token::lookup_active_access_token, refresh_token::lookup_active_refresh_token,
|
||||||
},
|
},
|
||||||
@@ -36,10 +34,16 @@ pub fn filter(
|
|||||||
pool: &PgPool,
|
pool: &PgPool,
|
||||||
oauth2_config: &OAuth2Config,
|
oauth2_config: &OAuth2Config,
|
||||||
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||||
|
let audience = oauth2_config
|
||||||
|
.issuer
|
||||||
|
.join("/oauth2/introspect")
|
||||||
|
.unwrap()
|
||||||
|
.to_string();
|
||||||
|
|
||||||
warp::path!("oauth2" / "introspect").and(
|
warp::path!("oauth2" / "introspect").and(
|
||||||
warp::post()
|
warp::post()
|
||||||
.and(connection(pool))
|
.and(connection(pool))
|
||||||
.and(client_authentication(oauth2_config))
|
.and(client_authentication(oauth2_config, audience))
|
||||||
.and_then(introspect)
|
.and_then(introspect)
|
||||||
.recover(recover)
|
.recover(recover)
|
||||||
.with(cors().allow_method(Method::POST)),
|
.with(cors().allow_method(Method::POST)),
|
||||||
@@ -63,7 +67,7 @@ const INACTIVE: IntrospectionResponse = IntrospectionResponse {
|
|||||||
|
|
||||||
async fn introspect(
|
async fn introspect(
|
||||||
mut conn: PoolConnection<Postgres>,
|
mut conn: PoolConnection<Postgres>,
|
||||||
auth: ClientAuthentication,
|
auth: ClientAuthenticationMethod,
|
||||||
client: OAuth2ClientConfig,
|
client: OAuth2ClientConfig,
|
||||||
params: IntrospectionRequest,
|
params: IntrospectionRequest,
|
||||||
) -> Result<impl Reply, Rejection> {
|
) -> Result<impl Reply, Rejection> {
|
||||||
|
@@ -22,7 +22,8 @@ use mas_data_model::AuthorizationGrantStage;
|
|||||||
use oauth2_types::{
|
use oauth2_types::{
|
||||||
errors::{InvalidGrant, InvalidRequest, OAuth2Error, OAuth2ErrorCode, UnauthorizedClient},
|
errors::{InvalidGrant, InvalidRequest, OAuth2Error, OAuth2ErrorCode, UnauthorizedClient},
|
||||||
requests::{
|
requests::{
|
||||||
AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, RefreshTokenGrant,
|
AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant,
|
||||||
|
ClientAuthenticationMethod, RefreshTokenGrant,
|
||||||
},
|
},
|
||||||
scope::OPENID,
|
scope::OPENID,
|
||||||
};
|
};
|
||||||
@@ -42,12 +43,7 @@ use warp::{
|
|||||||
use crate::{
|
use crate::{
|
||||||
config::{KeySet, OAuth2ClientConfig, OAuth2Config},
|
config::{KeySet, OAuth2ClientConfig, OAuth2Config},
|
||||||
errors::WrapError,
|
errors::WrapError,
|
||||||
filters::{
|
filters::{client::client_authentication, cors::cors, database::connection, with_keys},
|
||||||
client::{client_authentication, ClientAuthentication},
|
|
||||||
cors::cors,
|
|
||||||
database::connection,
|
|
||||||
with_keys,
|
|
||||||
},
|
|
||||||
reply::with_typed_header,
|
reply::with_typed_header,
|
||||||
storage::{
|
storage::{
|
||||||
oauth2::{
|
oauth2::{
|
||||||
@@ -97,10 +93,16 @@ pub fn filter(
|
|||||||
pool: &PgPool,
|
pool: &PgPool,
|
||||||
oauth2_config: &OAuth2Config,
|
oauth2_config: &OAuth2Config,
|
||||||
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||||
|
let audience = oauth2_config
|
||||||
|
.issuer
|
||||||
|
.join("/oauth2/token")
|
||||||
|
.unwrap()
|
||||||
|
.to_string();
|
||||||
let issuer = oauth2_config.issuer.clone();
|
let issuer = oauth2_config.issuer.clone();
|
||||||
|
|
||||||
warp::path!("oauth2" / "token").and(
|
warp::path!("oauth2" / "token").and(
|
||||||
warp::post()
|
warp::post()
|
||||||
.and(client_authentication(oauth2_config))
|
.and(client_authentication(oauth2_config, audience))
|
||||||
.and(with_keys(oauth2_config))
|
.and(with_keys(oauth2_config))
|
||||||
.and(warp::any().map(move || issuer.clone()))
|
.and(warp::any().map(move || issuer.clone()))
|
||||||
.and(connection(pool))
|
.and(connection(pool))
|
||||||
@@ -119,7 +121,7 @@ async fn recover(rejection: Rejection) -> Result<impl Reply, Rejection> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn token(
|
async fn token(
|
||||||
_auth: ClientAuthentication,
|
_auth: ClientAuthenticationMethod,
|
||||||
client: OAuth2ClientConfig,
|
client: OAuth2ClientConfig,
|
||||||
req: AccessTokenRequest,
|
req: AccessTokenRequest,
|
||||||
keys: KeySet,
|
keys: KeySet,
|
||||||
|
@@ -23,6 +23,28 @@ use crate::{
|
|||||||
requests::{ClientAuthenticationMethod, GrantType, ResponseMode},
|
requests::{ClientAuthenticationMethod, GrantType, ResponseMode},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#[derive(Serialize, Clone, Copy, PartialEq, Eq, Hash)]
|
||||||
|
#[serde(rename_all = "UPPERCASE")]
|
||||||
|
pub enum SigningAlgorithm {
|
||||||
|
#[serde(rename = "none")]
|
||||||
|
None,
|
||||||
|
Hs256,
|
||||||
|
Hs384,
|
||||||
|
Hs512,
|
||||||
|
Ps256,
|
||||||
|
Ps384,
|
||||||
|
Ps512,
|
||||||
|
Rs256,
|
||||||
|
Rs384,
|
||||||
|
Rs512,
|
||||||
|
Es256,
|
||||||
|
Es256K,
|
||||||
|
Es384,
|
||||||
|
Es512,
|
||||||
|
#[serde(rename = "EcDSA")]
|
||||||
|
EcDsa,
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: https://datatracker.ietf.org/doc/html/rfc8414#section-2
|
// TODO: https://datatracker.ietf.org/doc/html/rfc8414#section-2
|
||||||
#[skip_serializing_none]
|
#[skip_serializing_none]
|
||||||
#[derive(Serialize, Clone)]
|
#[derive(Serialize, Clone)]
|
||||||
@@ -65,6 +87,13 @@ pub struct Metadata {
|
|||||||
/// by this token endpoint.
|
/// by this token endpoint.
|
||||||
pub token_endpoint_auth_methods_supported: Option<HashSet<ClientAuthenticationMethod>>,
|
pub token_endpoint_auth_methods_supported: Option<HashSet<ClientAuthenticationMethod>>,
|
||||||
|
|
||||||
|
/// JSON array containing a list of the JWS signing algorithms supported by
|
||||||
|
/// the Token Endpoint for the signature on the JWT used to authenticate
|
||||||
|
/// the Client at the Token Endpoint for the private_key_jwt and
|
||||||
|
/// client_secret_jwt authentication methods. Servers SHOULD support
|
||||||
|
/// RS256. The value none MUST NOT be used.
|
||||||
|
pub token_endpoint_auth_signing_alg_values_supported: Option<HashSet<SigningAlgorithm>>,
|
||||||
|
|
||||||
/// PKCE code challenge methods supported by this authorization server
|
/// PKCE code challenge methods supported by this authorization server
|
||||||
pub code_challenge_methods_supported: Option<HashSet<CodeChallengeMethod>>,
|
pub code_challenge_methods_supported: Option<HashSet<CodeChallengeMethod>>,
|
||||||
|
|
||||||
|
@@ -91,6 +91,16 @@ pub enum ClientAuthenticationMethod {
|
|||||||
None,
|
None,
|
||||||
ClientSecretPost,
|
ClientSecretPost,
|
||||||
ClientSecretBasic,
|
ClientSecretBasic,
|
||||||
|
ClientSecretJwt,
|
||||||
|
PrivateKeyJwt,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ClientAuthenticationMethod {
|
||||||
|
#[must_use]
|
||||||
|
/// Check if the authentication method is for public client or not
|
||||||
|
pub fn public(&self) -> bool {
|
||||||
|
matches!(self, &Self::None)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(
|
#[derive(
|
||||||
|
Reference in New Issue
Block a user