1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-07 17:03:01 +03:00

Implement client_secret_jwt authentication method

This commit is contained in:
Quentin Gliech
2021-11-05 17:18:21 +01:00
parent 6f9213c5f4
commit 16fe5a8d76
7 changed files with 505 additions and 91 deletions

View File

@@ -14,44 +14,32 @@
//! Handle client authentication //! Handle client authentication
use std::borrow::Cow;
use chrono::{Duration, Utc};
use headers::{authorization::Basic, Authorization}; use headers::{authorization::Basic, Authorization};
use serde::{de::DeserializeOwned, Deserialize}; use jwt_compact::{
alg::{Hs256, Hs256Key, Hs384, Hs384Key, Hs512, Hs512Key},
Algorithm, AlgorithmExt, AlgorithmSignature, TimeOptions, Token, UntrustedToken,
};
use oauth2_types::requests::ClientAuthenticationMethod;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_with::skip_serializing_none;
use thiserror::Error; use thiserror::Error;
use warp::{reject::Reject, Filter, Rejection}; use warp::{reject::Reject, Filter, Rejection};
use super::headers::typed_header; use super::headers::typed_header;
use crate::config::{OAuth2ClientConfig, OAuth2Config}; use crate::{
config::{OAuth2ClientConfig, OAuth2Config},
/// Type of client authentication that succeeded errors::WrapError,
#[derive(Debug, PartialEq, Eq)] };
pub enum ClientAuthentication {
/// `client_secret_basic` authentication, where the `client_id` and
/// `client_secret` are sent through the `Authorization` header with
/// `Basic` authentication
ClientSecretBasic,
/// `client_secret_post` authentication, where the `client_id` and
/// `client_secret` are sent in the request body
ClientSecretPost,
/// `none` authentication for public clients, where only the `client_id` is
/// sent in the request body
None,
}
impl ClientAuthentication {
#[must_use]
/// Check if the authenticated client is public or not
pub fn public(&self) -> bool {
matches!(self, &Self::None)
}
}
/// Protect an enpoint with client authentication /// Protect an enpoint with client authentication
#[must_use] #[must_use]
pub fn client_authentication<T: DeserializeOwned + Send + 'static>( pub fn client_authentication<T: DeserializeOwned + Send + 'static>(
oauth2_config: &OAuth2Config, oauth2_config: &OAuth2Config,
) -> impl Filter<Extract = (ClientAuthentication, OAuth2ClientConfig, T), Error = Rejection> audience: String,
) -> impl Filter<Extract = (ClientAuthenticationMethod, OAuth2ClientConfig, T), Error = Rejection>
+ Clone + Clone
+ Send + Send
+ Sync + Sync
@@ -64,25 +52,19 @@ pub fn client_authentication<T: DeserializeOwned + Send + 'static>(
let client_id = auth.0.username().to_string(); let client_id = auth.0.username().to_string();
let client_secret = Some(auth.0.password().to_string()); let client_secret = Some(auth.0.password().to_string());
( (
ClientAuthentication::ClientSecretBasic, ClientCredentials::Pair {
client_id, via: CredentialsVia::AuthorizationHeader,
client_secret, client_id,
client_secret,
},
body, body,
) )
}) })
// Or from the form body // Or from the form body
.or(warp::body::form().map(|form: ClientAuthForm<T>| { .or(warp::body::form().map(|form: ClientAuthForm<T>| {
let ClientAuthForm { let ClientAuthForm { credentials, body } = form;
client_id,
client_secret, (credentials, body)
body,
} = form;
let auth_type = if client_secret.is_some() {
ClientAuthentication::ClientSecretPost
} else {
ClientAuthentication::None
};
(auth_type, client_id, client_secret, body)
})) }))
.unify() .unify()
.untuple_one(); .untuple_one();
@@ -90,6 +72,7 @@ pub fn client_authentication<T: DeserializeOwned + Send + 'static>(
let clients = oauth2_config.clients.clone(); let clients = oauth2_config.clients.clone();
warp::any() warp::any()
.map(move || clients.clone()) .map(move || clients.clone())
.and(warp::any().map(move || audience.clone()))
.and(credentials) .and(credentials)
.and_then(authenticate_client) .and_then(authenticate_client)
.untuple_one() .untuple_one()
@@ -108,39 +91,257 @@ enum ClientAuthenticationError {
#[error("client secret required for client {client_id:?}")] #[error("client secret required for client {client_id:?}")]
ClientSecretRequired { client_id: String }, ClientSecretRequired { client_id: String },
#[error("wrong audience in client assertion: expected {expected:?}, got {got:?}")]
AudienceMismatch { expected: String, got: String },
#[error("invalid client assertion")]
InvalidAssertion,
} }
impl Reject for ClientAuthenticationError {} impl Reject for ClientAuthenticationError {}
#[skip_serializing_none]
#[derive(Serialize, Deserialize)]
struct ClientAssertionClaims {
#[serde(rename = "iss")]
issuer: String,
#[serde(rename = "sub")]
subject: String,
#[serde(rename = "aud")]
audience: String,
// TODO: use the JTI and ensure it is only used once
#[serde(default, rename = "jti")]
jwt_id: Option<String>,
}
struct UnsignedSignature(Vec<u8>);
impl AlgorithmSignature for UnsignedSignature {
fn try_from_slice(slice: &[u8]) -> anyhow::Result<Self> {
Ok(Self(slice.to_vec()))
}
fn as_bytes(&self) -> std::borrow::Cow<'_, [u8]> {
Cow::Borrowed(&self.0)
}
}
struct Unsigned<'a>(&'a str);
impl<'a> Algorithm for Unsigned<'a> {
type SigningKey = ();
type VerifyingKey = ();
type Signature = UnsignedSignature;
fn name(&self) -> std::borrow::Cow<'static, str> {
Cow::Owned(self.0.to_string())
}
fn sign(&self, _signing_key: &Self::SigningKey, _message: &[u8]) -> Self::Signature {
UnsignedSignature(Vec::new())
}
fn verify_signature(
&self,
_signature: &Self::Signature,
_verifying_key: &Self::VerifyingKey,
_message: &[u8],
) -> bool {
true
}
}
fn verify_token(
untrusted_token: &UntrustedToken,
key: &str,
) -> anyhow::Result<Token<ClientAssertionClaims>> {
match untrusted_token.algorithm() {
"HS256" => {
let key = Hs256Key::new(key);
let token = Hs256.validate_integrity(untrusted_token, &key)?;
Ok(token)
}
"HS384" => {
let key = Hs384Key::new(key);
let token = Hs384.validate_integrity(untrusted_token, &key)?;
Ok(token)
}
"HS512" => {
let key = Hs512Key::new(key);
let token = Hs512.validate_integrity(untrusted_token, &key)?;
Ok(token)
}
alg => anyhow::bail!("unsupported signing algorithm {}", alg),
}
}
async fn authenticate_client<T>( async fn authenticate_client<T>(
clients: Vec<OAuth2ClientConfig>, clients: Vec<OAuth2ClientConfig>,
auth_type: ClientAuthentication, audience: String,
client_id: String, credentials: ClientCredentials,
client_secret: Option<String>,
body: T, body: T,
) -> Result<(ClientAuthentication, OAuth2ClientConfig, T), Rejection> { ) -> Result<(ClientAuthenticationMethod, OAuth2ClientConfig, T), Rejection> {
let client = clients let auth_type = credentials.authentication_type();
.iter() let client = match credentials {
.find(|client| client.client_id == client_id) ClientCredentials::Pair {
.ok_or_else(|| ClientAuthenticationError::ClientNotFound { client_id,
client_id: client_id.to_string(), client_secret,
})?; ..
} => {
let client = clients
.iter()
.find(|client| client.client_id == client_id)
.ok_or_else(|| ClientAuthenticationError::ClientNotFound {
client_id: client_id.to_string(),
})?;
let client = match (client_secret, client.client_secret.as_ref()) { match (client_secret, client.client_secret.as_ref()) {
(None, None) => Ok(client), (None, None) => Ok(client),
(Some(ref given), Some(expected)) if given == expected => Ok(client), (Some(ref given), Some(expected)) if given == expected => Ok(client),
(Some(_), Some(_)) => Err(ClientAuthenticationError::ClientSecretMismatch { client_id }), (Some(_), Some(_)) => {
(Some(_), None) => Err(ClientAuthenticationError::NoClientSecret { client_id }), Err(ClientAuthenticationError::ClientSecretMismatch { client_id })
(None, Some(_)) => Err(ClientAuthenticationError::ClientSecretRequired { client_id }), }
(Some(_), None) => Err(ClientAuthenticationError::NoClientSecret { client_id }),
(None, Some(_)) => {
Err(ClientAuthenticationError::ClientSecretRequired { client_id })
}
}
}
ClientCredentials::Assertion {
client_id,
client_assertion_type: ClientAssertionType::JwtBearer,
client_assertion,
} => {
let untrusted_token = UntrustedToken::new(&client_assertion).wrap_error()?;
// client_id might have been passed as parameter. If not, it should be inferred
// from the token, as per rfc7521 sec. 4.2
// TODO: this is not a pretty way to do it
let client_id = client_id
.ok_or(()) // Dumb error type
.or_else(|()| {
let alg = Unsigned(untrusted_token.algorithm());
// We need to deserialize the token once without verifying the signature to get
// the client_id
let token: Token<ClientAssertionClaims> =
alg.validate_integrity(&untrusted_token, &())?;
Ok::<_, anyhow::Error>(token.claims().custom.subject.clone())
})
.wrap_error()?;
let client = clients
.iter()
.find(|client| client.client_id == client_id)
.ok_or_else(|| ClientAuthenticationError::ClientNotFound {
client_id: client_id.to_string(),
})?;
if let Some(client_secret) = &client.client_secret {
let token = verify_token(&untrusted_token, client_secret).wrap_error()?;
let time_options = TimeOptions::new(Duration::minutes(1), Utc::now);
// rfc7523 sec. 3.4: expiration must be set and validated
let claims = token
.claims()
.validate_expiration(&time_options)
.wrap_error()?;
// rfc7523 sec. 3.5: "not before" can be set and must be validated if present
if claims.not_before.is_some() {
claims.validate_maturity(&time_options).wrap_error()?;
}
// rfc7523 sec. 3.3: the audience is the URL being called
if claims.custom.audience != audience {
Err(ClientAuthenticationError::AudienceMismatch {
expected: audience,
got: claims.custom.audience.clone(),
})
// rfc7523 sec. 3.1 & 3.2: both the issuer and the subject must
// match the client_id
} else if claims.custom.issuer != claims.custom.subject
|| claims.custom.issuer != client_id
{
Err(ClientAuthenticationError::InvalidAssertion)
} else {
Ok(client)
}
} else {
Err(ClientAuthenticationError::ClientSecretRequired {
client_id: client_id.to_string(),
})
}
}
}?; }?;
Ok((auth_type, client.clone(), body)) Ok((auth_type, client.clone(), body))
} }
#[derive(Deserialize)]
enum ClientAssertionType {
#[serde(rename = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")]
JwtBearer,
}
enum CredentialsVia {
FormBody,
AuthorizationHeader,
}
impl Default for CredentialsVia {
fn default() -> Self {
Self::FormBody
}
}
#[derive(Deserialize)]
#[serde(untagged)]
enum ClientCredentials {
// Order here is important: serde tries to deserialize enum variants in order, so if "Pair"
// was before "Assertion", a client_assertion with a client_id would match the "Pair"
// variant first
Assertion {
client_id: Option<String>,
client_assertion_type: ClientAssertionType,
client_assertion: String,
},
Pair {
#[serde(skip)]
via: CredentialsVia,
client_id: String,
client_secret: Option<String>,
},
}
impl ClientCredentials {
fn authentication_type(&self) -> ClientAuthenticationMethod {
match self {
ClientCredentials::Pair {
via: CredentialsVia::FormBody,
client_secret: None,
..
} => ClientAuthenticationMethod::None,
ClientCredentials::Pair {
via: CredentialsVia::FormBody,
client_secret: Some(_),
..
} => ClientAuthenticationMethod::ClientSecretPost,
ClientCredentials::Pair {
via: CredentialsVia::AuthorizationHeader,
..
} => ClientAuthenticationMethod::ClientSecretBasic,
ClientCredentials::Assertion { .. } => ClientAuthenticationMethod::ClientSecretJwt,
}
}
}
#[derive(Deserialize)] #[derive(Deserialize)]
struct ClientAuthForm<T> { struct ClientAuthForm<T> {
client_id: String, #[serde(flatten)]
client_secret: Option<String>, credentials: ClientCredentials,
#[serde(flatten)] #[serde(flatten)]
body: T, body: T,
@@ -148,10 +349,16 @@ struct ClientAuthForm<T> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use headers::authorization::Credentials;
use jwt_compact::{Claims, Header};
use mas_config::ConfigurationSection; use mas_config::ConfigurationSection;
use serde_json::json;
use super::*; use super::*;
// Long client_secret to support it as a HS512 key
const CLIENT_SECRET: &str = "leek2zaeyeb8thai7piehea3vah6ool9oanin9aeraThuci9EeghaekaiD1upe4Quoh7xeMae2meitohj0Waaveiwaorah1yazohr6Vae7iebeiRaWene5IeWeeciezu";
fn oauth2_config() -> OAuth2Config { fn oauth2_config() -> OAuth2Config {
let mut config = OAuth2Config::test(); let mut config = OAuth2Config::test();
config.clients.push(OAuth2ClientConfig { config.clients.push(OAuth2ClientConfig {
@@ -161,7 +368,12 @@ mod tests {
}); });
config.clients.push(OAuth2ClientConfig { config.clients.push(OAuth2ClientConfig {
client_id: "confidential".to_string(), client_id: "confidential".to_string(),
client_secret: Some("secret".to_string()), client_secret: Some(CLIENT_SECRET.to_string()),
redirect_uris: Vec::new(),
});
config.clients.push(OAuth2ClientConfig {
client_id: "confidential-2".to_string(),
client_secret: Some(CLIENT_SECRET.to_string()),
redirect_uris: Vec::new(), redirect_uris: Vec::new(),
}); });
config config
@@ -174,17 +386,126 @@ mod tests {
} }
#[tokio::test] #[tokio::test]
async fn client_secret_post() { async fn client_secret_jwt_hs256() {
let filter = client_authentication::<Form>(&oauth2_config()); client_secret_jwt::<'_, Hs256>().await;
}
#[tokio::test]
async fn client_secret_jwt_hs384() {
client_secret_jwt::<'_, Hs384>().await;
}
#[tokio::test]
async fn client_secret_jwt_hs512() {
client_secret_jwt::<'_, Hs512>().await;
}
async fn client_secret_jwt<'k, A>()
where
A: Algorithm + Default,
A::SigningKey: From<&'k [u8]>,
{
let audience = "https://example.com/token".to_string();
let filter = client_authentication::<Form>(&oauth2_config(), audience.clone());
let time_options = TimeOptions::default();
let key = A::SigningKey::from(CLIENT_SECRET.as_bytes());
let alg = A::default();
let header = Header::default();
let claims = Claims::new(ClientAssertionClaims {
issuer: "confidential".to_string(),
subject: "confidential".to_string(),
audience,
jwt_id: None,
})
.set_duration_and_issuance(&time_options, Duration::seconds(15));
// TODO: test failing cases
// - expired token
// - "not before" in the future
// - subject/issuer mismatch
// - audience mismatch
// - wrong secret/signature
let token = alg
.token(header, &claims, &key)
.expect("could not sign token");
let (auth, client, body) = warp::test::request() let (auth, client, body) = warp::test::request()
.method("POST") .method("POST")
.body("client_id=confidential&client_secret=secret&foo=baz&bar=foobar") .header("Content-Type", mime::APPLICATION_WWW_FORM_URLENCODED.to_string())
.body(serde_urlencoded::to_string(json!({
"client_id": "confidential",
"client_assertion": token,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"foo": "baz",
"bar": "foobar",
})).unwrap())
.filter(&filter) .filter(&filter)
.await .await
.unwrap(); .unwrap();
assert_eq!(auth, ClientAuthentication::ClientSecretPost); assert_eq!(auth, ClientAuthenticationMethod::ClientSecretJwt);
assert_eq!(client.client_id, "confidential");
assert_eq!(body.foo, "baz");
assert_eq!(body.bar, "foobar");
// Without client_id
let res = warp::test::request()
.method("POST")
.header("Content-Type", mime::APPLICATION_WWW_FORM_URLENCODED.to_string())
.body(serde_urlencoded::to_string(json!({
"client_assertion": token,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"foo": "baz",
"bar": "foobar",
})).unwrap())
.filter(&filter)
.await;
assert!(res.is_ok());
// client_id mismatch
let res = warp::test::request()
.method("POST")
.body(serde_urlencoded::to_string(json!({
"client_id": "confidential-2",
"client_assertion": token,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"foo": "baz",
"bar": "foobar",
})).unwrap())
.filter(&filter)
.await;
assert!(res.is_err());
}
#[tokio::test]
async fn client_secret_post() {
let filter = client_authentication::<Form>(
&oauth2_config(),
"https://example.com/token".to_string(),
);
let (auth, client, body) = warp::test::request()
.method("POST")
.header(
"Content-Type",
mime::APPLICATION_WWW_FORM_URLENCODED.to_string(),
)
.body(
serde_urlencoded::to_string(json!({
"client_id": "confidential",
"client_secret": CLIENT_SECRET,
"foo": "baz",
"bar": "foobar",
}))
.unwrap(),
)
.filter(&filter)
.await
.unwrap();
assert_eq!(auth, ClientAuthenticationMethod::ClientSecretPost);
assert_eq!(client.client_id, "confidential"); assert_eq!(client.client_id, "confidential");
assert_eq!(body.foo, "baz"); assert_eq!(body.foo, "baz");
assert_eq!(body.bar, "foobar"); assert_eq!(body.bar, "foobar");
@@ -192,17 +513,31 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn client_secret_basic() { async fn client_secret_basic() {
let filter = client_authentication::<Form>(&oauth2_config()); let filter = client_authentication::<Form>(
&oauth2_config(),
"https://example.com/token".to_string(),
);
let auth = Authorization::basic("confidential", CLIENT_SECRET);
let (auth, client, body) = warp::test::request() let (auth, client, body) = warp::test::request()
.method("POST") .method("POST")
.header("Authorization", "Basic Y29uZmlkZW50aWFsOnNlY3JldA==") .header(
.body("foo=baz&bar=foobar") "Content-Type",
mime::APPLICATION_WWW_FORM_URLENCODED.to_string(),
)
.header("Authorization", auth.0.encode())
.body(
serde_urlencoded::to_string(json!({
"foo": "baz",
"bar": "foobar",
}))
.unwrap(),
)
.filter(&filter) .filter(&filter)
.await .await
.unwrap(); .unwrap();
assert_eq!(auth, ClientAuthentication::ClientSecretBasic); assert_eq!(auth, ClientAuthenticationMethod::ClientSecretBasic);
assert_eq!(client.client_id, "confidential"); assert_eq!(client.client_id, "confidential");
assert_eq!(body.foo, "baz"); assert_eq!(body.foo, "baz");
assert_eq!(body.bar, "foobar"); assert_eq!(body.bar, "foobar");
@@ -210,16 +545,30 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn none() { async fn none() {
let filter = client_authentication::<Form>(&oauth2_config()); let filter = client_authentication::<Form>(
&oauth2_config(),
"https://example.com/token".to_string(),
);
let (auth, client, body) = warp::test::request() let (auth, client, body) = warp::test::request()
.method("POST") .method("POST")
.body("client_id=public&foo=baz&bar=foobar") .header(
"Content-Type",
mime::APPLICATION_WWW_FORM_URLENCODED.to_string(),
)
.body(
serde_urlencoded::to_string(json!({
"client_id": "public",
"foo": "baz",
"bar": "foobar",
}))
.unwrap(),
)
.filter(&filter) .filter(&filter)
.await .await
.unwrap(); .unwrap();
assert_eq!(auth, ClientAuthentication::None); assert_eq!(auth, ClientAuthenticationMethod::None);
assert_eq!(client.client_id, "public"); assert_eq!(client.client_id, "public");
assert_eq!(body.foo, "baz"); assert_eq!(body.foo, "baz");
assert_eq!(body.bar, "foobar"); assert_eq!(body.bar, "foobar");

View File

@@ -113,6 +113,15 @@ where
params: T, params: T,
} }
#[derive(Serialize)]
struct ParamsWithState<T> {
#[serde(skip_serializing_if = "Option::is_none")]
state: Option<String>,
#[serde(flatten)]
params: T,
}
match response_mode { match response_mode {
ResponseMode::Query => { ResponseMode::Query => {
let existing: Option<HashMap<&str, &str>> = redirect_uri let existing: Option<HashMap<&str, &str>> = redirect_uri
@@ -159,7 +168,8 @@ where
))) )))
} }
ResponseMode::FormPost => { ResponseMode::FormPost => {
let ctx = FormPostContext::new(redirect_uri, params); let merged = ParamsWithState { state, params };
let ctx = FormPostContext::new(redirect_uri, merged);
let rendered = templates.render_form_post(&ctx)?; let rendered = templates.render_form_post(&ctx)?;
Ok(Box::new(html(rendered))) Ok(Box::new(html(rendered)))
} }

View File

@@ -17,7 +17,7 @@ use std::collections::HashSet;
use hyper::Method; use hyper::Method;
use mas_config::OAuth2Config; use mas_config::OAuth2Config;
use oauth2_types::{ use oauth2_types::{
oidc::Metadata, oidc::{Metadata, SigningAlgorithm},
pkce::CodeChallengeMethod, pkce::CodeChallengeMethod,
requests::{ClientAuthenticationMethod, GrantType, ResponseMode}, requests::{ClientAuthenticationMethod, GrantType, ResponseMode},
}; };
@@ -61,10 +61,19 @@ pub(super) fn filter(
let mut s = HashSet::new(); let mut s = HashSet::new();
s.insert(ClientAuthenticationMethod::ClientSecretBasic); s.insert(ClientAuthenticationMethod::ClientSecretBasic);
s.insert(ClientAuthenticationMethod::ClientSecretPost); s.insert(ClientAuthenticationMethod::ClientSecretPost);
s.insert(ClientAuthenticationMethod::ClientSecretJwt);
s.insert(ClientAuthenticationMethod::None); s.insert(ClientAuthenticationMethod::None);
s s
}); });
let token_endpoint_auth_signing_alg_values_supported = Some({
let mut s = HashSet::new();
s.insert(SigningAlgorithm::Hs256);
s.insert(SigningAlgorithm::Hs384);
s.insert(SigningAlgorithm::Hs512);
s
});
let code_challenge_methods_supported = Some({ let code_challenge_methods_supported = Some({
let mut s = HashSet::new(); let mut s = HashSet::new();
s.insert(CodeChallengeMethod::Plain); s.insert(CodeChallengeMethod::Plain);
@@ -85,6 +94,7 @@ pub(super) fn filter(
response_modes_supported, response_modes_supported,
grant_types_supported, grant_types_supported,
token_endpoint_auth_methods_supported, token_endpoint_auth_methods_supported,
token_endpoint_auth_signing_alg_values_supported,
code_challenge_methods_supported, code_challenge_methods_supported,
}; };

View File

@@ -13,7 +13,9 @@
// limitations under the License. // limitations under the License.
use hyper::Method; use hyper::Method;
use oauth2_types::requests::{IntrospectionRequest, IntrospectionResponse, TokenTypeHint}; use oauth2_types::requests::{
ClientAuthenticationMethod, IntrospectionRequest, IntrospectionResponse, TokenTypeHint,
};
use sqlx::{pool::PoolConnection, PgPool, Postgres}; use sqlx::{pool::PoolConnection, PgPool, Postgres};
use tracing::{info, warn}; use tracing::{info, warn};
use warp::{Filter, Rejection, Reply}; use warp::{Filter, Rejection, Reply};
@@ -21,11 +23,7 @@ use warp::{Filter, Rejection, Reply};
use crate::{ use crate::{
config::{OAuth2ClientConfig, OAuth2Config}, config::{OAuth2ClientConfig, OAuth2Config},
errors::WrapError, errors::WrapError,
filters::{ filters::{client::client_authentication, cors::cors, database::connection},
client::{client_authentication, ClientAuthentication},
cors::cors,
database::connection,
},
storage::oauth2::{ storage::oauth2::{
access_token::lookup_active_access_token, refresh_token::lookup_active_refresh_token, access_token::lookup_active_access_token, refresh_token::lookup_active_refresh_token,
}, },
@@ -36,10 +34,16 @@ pub fn filter(
pool: &PgPool, pool: &PgPool,
oauth2_config: &OAuth2Config, oauth2_config: &OAuth2Config,
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static { ) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
let audience = oauth2_config
.issuer
.join("/oauth2/introspect")
.unwrap()
.to_string();
warp::path!("oauth2" / "introspect").and( warp::path!("oauth2" / "introspect").and(
warp::post() warp::post()
.and(connection(pool)) .and(connection(pool))
.and(client_authentication(oauth2_config)) .and(client_authentication(oauth2_config, audience))
.and_then(introspect) .and_then(introspect)
.recover(recover) .recover(recover)
.with(cors().allow_method(Method::POST)), .with(cors().allow_method(Method::POST)),
@@ -63,7 +67,7 @@ const INACTIVE: IntrospectionResponse = IntrospectionResponse {
async fn introspect( async fn introspect(
mut conn: PoolConnection<Postgres>, mut conn: PoolConnection<Postgres>,
auth: ClientAuthentication, auth: ClientAuthenticationMethod,
client: OAuth2ClientConfig, client: OAuth2ClientConfig,
params: IntrospectionRequest, params: IntrospectionRequest,
) -> Result<impl Reply, Rejection> { ) -> Result<impl Reply, Rejection> {

View File

@@ -22,7 +22,8 @@ use mas_data_model::AuthorizationGrantStage;
use oauth2_types::{ use oauth2_types::{
errors::{InvalidGrant, InvalidRequest, OAuth2Error, OAuth2ErrorCode, UnauthorizedClient}, errors::{InvalidGrant, InvalidRequest, OAuth2Error, OAuth2ErrorCode, UnauthorizedClient},
requests::{ requests::{
AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, RefreshTokenGrant, AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant,
ClientAuthenticationMethod, RefreshTokenGrant,
}, },
scope::OPENID, scope::OPENID,
}; };
@@ -42,12 +43,7 @@ use warp::{
use crate::{ use crate::{
config::{KeySet, OAuth2ClientConfig, OAuth2Config}, config::{KeySet, OAuth2ClientConfig, OAuth2Config},
errors::WrapError, errors::WrapError,
filters::{ filters::{client::client_authentication, cors::cors, database::connection, with_keys},
client::{client_authentication, ClientAuthentication},
cors::cors,
database::connection,
with_keys,
},
reply::with_typed_header, reply::with_typed_header,
storage::{ storage::{
oauth2::{ oauth2::{
@@ -97,10 +93,16 @@ pub fn filter(
pool: &PgPool, pool: &PgPool,
oauth2_config: &OAuth2Config, oauth2_config: &OAuth2Config,
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static { ) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
let audience = oauth2_config
.issuer
.join("/oauth2/token")
.unwrap()
.to_string();
let issuer = oauth2_config.issuer.clone(); let issuer = oauth2_config.issuer.clone();
warp::path!("oauth2" / "token").and( warp::path!("oauth2" / "token").and(
warp::post() warp::post()
.and(client_authentication(oauth2_config)) .and(client_authentication(oauth2_config, audience))
.and(with_keys(oauth2_config)) .and(with_keys(oauth2_config))
.and(warp::any().map(move || issuer.clone())) .and(warp::any().map(move || issuer.clone()))
.and(connection(pool)) .and(connection(pool))
@@ -119,7 +121,7 @@ async fn recover(rejection: Rejection) -> Result<impl Reply, Rejection> {
} }
async fn token( async fn token(
_auth: ClientAuthentication, _auth: ClientAuthenticationMethod,
client: OAuth2ClientConfig, client: OAuth2ClientConfig,
req: AccessTokenRequest, req: AccessTokenRequest,
keys: KeySet, keys: KeySet,

View File

@@ -23,6 +23,28 @@ use crate::{
requests::{ClientAuthenticationMethod, GrantType, ResponseMode}, requests::{ClientAuthenticationMethod, GrantType, ResponseMode},
}; };
#[derive(Serialize, Clone, Copy, PartialEq, Eq, Hash)]
#[serde(rename_all = "UPPERCASE")]
pub enum SigningAlgorithm {
#[serde(rename = "none")]
None,
Hs256,
Hs384,
Hs512,
Ps256,
Ps384,
Ps512,
Rs256,
Rs384,
Rs512,
Es256,
Es256K,
Es384,
Es512,
#[serde(rename = "EcDSA")]
EcDsa,
}
// TODO: https://datatracker.ietf.org/doc/html/rfc8414#section-2 // TODO: https://datatracker.ietf.org/doc/html/rfc8414#section-2
#[skip_serializing_none] #[skip_serializing_none]
#[derive(Serialize, Clone)] #[derive(Serialize, Clone)]
@@ -65,6 +87,13 @@ pub struct Metadata {
/// by this token endpoint. /// by this token endpoint.
pub token_endpoint_auth_methods_supported: Option<HashSet<ClientAuthenticationMethod>>, pub token_endpoint_auth_methods_supported: Option<HashSet<ClientAuthenticationMethod>>,
/// JSON array containing a list of the JWS signing algorithms supported by
/// the Token Endpoint for the signature on the JWT used to authenticate
/// the Client at the Token Endpoint for the private_key_jwt and
/// client_secret_jwt authentication methods. Servers SHOULD support
/// RS256. The value none MUST NOT be used.
pub token_endpoint_auth_signing_alg_values_supported: Option<HashSet<SigningAlgorithm>>,
/// PKCE code challenge methods supported by this authorization server /// PKCE code challenge methods supported by this authorization server
pub code_challenge_methods_supported: Option<HashSet<CodeChallengeMethod>>, pub code_challenge_methods_supported: Option<HashSet<CodeChallengeMethod>>,

View File

@@ -91,6 +91,16 @@ pub enum ClientAuthenticationMethod {
None, None,
ClientSecretPost, ClientSecretPost,
ClientSecretBasic, ClientSecretBasic,
ClientSecretJwt,
PrivateKeyJwt,
}
impl ClientAuthenticationMethod {
#[must_use]
/// Check if the authentication method is for public client or not
pub fn public(&self) -> bool {
matches!(self, &Self::None)
}
} }
#[derive( #[derive(