1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Use iana generated types in more places

This commit is contained in:
Quentin Gliech
2022-01-12 12:22:54 +01:00
parent 2844706bb1
commit 5b9c35a079
20 changed files with 222 additions and 211 deletions

3
Cargo.lock generated
View File

@ -1548,6 +1548,7 @@ version = "0.1.0"
dependencies = [
"chrono",
"crc",
"mas-iana",
"oauth2-types",
"rand",
"serde",
@ -1670,6 +1671,7 @@ dependencies = [
"argon2",
"chrono",
"mas-data-model",
"mas-iana",
"oauth2-types",
"password-hash",
"rand",
@ -1729,6 +1731,7 @@ dependencies = [
"hyper",
"mas-config",
"mas-data-model",
"mas-iana",
"mas-jose",
"mas-storage",
"mas-templates",

View File

@ -13,4 +13,5 @@ url = { version = "2.2.2", features = ["serde"] }
crc = "2.1.0"
rand = "0.8.4"
mas-iana = { path = "../iana" }
oauth2-types = { path = "../oauth2-types" }

View File

@ -15,7 +15,8 @@
use std::num::NonZeroU32;
use chrono::{DateTime, Duration, Utc};
use oauth2_types::{pkce::CodeChallengeMethod, requests::ResponseMode};
use mas_iana::oauth::PkceCodeChallengeMethod;
use oauth2_types::{pkce::CodeChallengeMethodExt, requests::ResponseMode};
use serde::Serialize;
use thiserror::Error;
use url::Url;
@ -25,13 +26,13 @@ use crate::{traits::StorageBackend, StorageBackendMarker};
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct Pkce {
pub challenge_method: CodeChallengeMethod,
pub challenge_method: PkceCodeChallengeMethod,
pub challenge: String,
}
impl Pkce {
#[must_use]
pub fn new(challenge_method: CodeChallengeMethod, challenge: String) -> Self {
pub fn new(challenge_method: PkceCodeChallengeMethod, challenge: String) -> Self {
Pkce {
challenge_method,
challenge,

View File

@ -14,7 +14,7 @@
use chrono::{DateTime, Duration, Utc};
use crc::{Crc, CRC_32_ISO_HDLC};
use oauth2_types::requests::TokenTypeHint;
use mas_iana::oauth::OAuthTokenTypeHint;
use rand::{distributions::Alphanumeric, Rng};
use thiserror::Error;
@ -159,12 +159,12 @@ impl TokenType {
}
}
impl PartialEq<TokenTypeHint> for TokenType {
fn eq(&self, other: &TokenTypeHint) -> bool {
impl PartialEq<OAuthTokenTypeHint> for TokenType {
fn eq(&self, other: &OAuthTokenTypeHint) -> bool {
matches!(
(self, other),
(TokenType::AccessToken, TokenTypeHint::AccessToken)
| (TokenType::RefreshToken, TokenTypeHint::RefreshToken)
(TokenType::AccessToken, OAuthTokenTypeHint::AccessToken)
| (TokenType::RefreshToken, OAuthTokenTypeHint::RefreshToken)
)
}
}

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::{HashMap, HashSet};
use std::collections::HashMap;
use chrono::Duration;
use hyper::{
@ -25,6 +25,7 @@ use mas_data_model::{
Authentication, AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, BrowserSession,
Pkce, StorageBackend, TokenType,
};
use mas_iana::oauth::OAuthAuthorizationEndpointResponseType;
use mas_storage::{
oauth2::{
access_token::add_access_token,
@ -50,9 +51,9 @@ use oauth2_types::{
RegistrationNotSupported, RequestNotSupported, RequestUriNotSupported,
},
pkce,
prelude::*,
requests::{
AccessTokenResponse, AuthorizationRequest, AuthorizationResponse, Prompt, ResponseMode,
ResponseType,
},
scope::ScopeToken,
};
@ -191,16 +192,15 @@ struct Params {
/// figure out what response mode must be used, and emit an error if the
/// suggested response mode isn't allowed for the given response types.
fn resolve_response_mode(
response_type: &HashSet<ResponseType>,
response_type: OAuthAuthorizationEndpointResponseType,
suggested_response_mode: Option<ResponseMode>,
) -> anyhow::Result<ResponseMode> {
use ResponseMode as M;
use ResponseType as T;
// If the response type includes either "token" or "id_token", the default
// response mode is "fragment" and the response mode "query" must not be
// used
if response_type.contains(&T::Token) || response_type.contains(&T::IdToken) {
if response_type.has_token() || response_type.has_id_token() {
match suggested_response_mode {
None => Ok(M::Fragment),
Some(M::Query) => Err(anyhow::anyhow!("invalid response mode")),
@ -345,11 +345,11 @@ async fn get(
let redirect_uri = client
.resolve_redirect_uri(&params.auth.redirect_uri)
.wrap_error()?;
let response_type = &params.auth.response_type;
let response_type = params.auth.response_type;
let response_mode =
resolve_response_mode(response_type, params.auth.response_mode).wrap_error()?;
let code: Option<AuthorizationCode> = if response_type.contains(&ResponseType::Code) {
let code: Option<AuthorizationCode> = if response_type.has_code() {
// 32 random alphanumeric characters, about 190bit of entropy
let code: String = thread_rng()
.sample_iter(&Alphanumeric)
@ -400,8 +400,8 @@ async fn get(
params.auth.max_age,
None,
response_mode,
response_type.contains(&ResponseType::Token),
response_type.contains(&ResponseType::IdToken),
response_type.has_token(),
response_type.has_id_token(),
)
.await
.wrap_error()?;

View File

@ -15,12 +15,17 @@
use std::collections::HashSet;
use mas_config::OAuth2Config;
use mas_iana::jose::JsonWebSignatureAlg;
use mas_iana::{
jose::JsonWebSignatureAlg,
oauth::{
OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod,
PkceCodeChallengeMethod,
},
};
use mas_jose::SigningKeystore;
use oauth2_types::{
oidc::{ClaimType, Metadata, SubjectType},
pkce::CodeChallengeMethod,
requests::{ClientAuthenticationMethod, Display, GrantType, ResponseMode},
requests::{Display, GrantType, ResponseMode},
};
use warp::{filters::BoxedFilter, Filter, Reply};
@ -34,11 +39,11 @@ pub(super) fn filter(
// This is how clients can authenticate
let client_auth_methods_supported = Some({
let mut s = HashSet::new();
s.insert(ClientAuthenticationMethod::ClientSecretBasic);
s.insert(ClientAuthenticationMethod::ClientSecretPost);
s.insert(ClientAuthenticationMethod::ClientSecretJwt);
s.insert(ClientAuthenticationMethod::PrivateKeyJwt);
s.insert(ClientAuthenticationMethod::None);
s.insert(OAuthClientAuthenticationMethod::ClientSecretBasic);
s.insert(OAuthClientAuthenticationMethod::ClientSecretPost);
s.insert(OAuthClientAuthenticationMethod::ClientSecretJwt);
s.insert(OAuthClientAuthenticationMethod::PrivateKeyJwt);
s.insert(OAuthClientAuthenticationMethod::None);
s
});
@ -72,13 +77,13 @@ pub(super) fn filter(
let response_types_supported = Some({
let mut s = HashSet::new();
s.insert("code".to_string());
s.insert("token".to_string());
s.insert("id_token".to_string());
s.insert("code token".to_string());
s.insert("code id_token".to_string());
s.insert("token id_token".to_string());
s.insert("code token id_token".to_string());
s.insert(OAuthAuthorizationEndpointResponseType::Code);
s.insert(OAuthAuthorizationEndpointResponseType::Token);
s.insert(OAuthAuthorizationEndpointResponseType::IdToken);
s.insert(OAuthAuthorizationEndpointResponseType::CodeToken);
s.insert(OAuthAuthorizationEndpointResponseType::CodeIdToken);
s.insert(OAuthAuthorizationEndpointResponseType::IdTokenToken);
s.insert(OAuthAuthorizationEndpointResponseType::CodeIdToken);
s
});
@ -107,8 +112,8 @@ pub(super) fn filter(
let code_challenge_methods_supported = Some({
let mut s = HashSet::new();
s.insert(CodeChallengeMethod::Plain);
s.insert(CodeChallengeMethod::S256);
s.insert(PkceCodeChallengeMethod::Plain);
s.insert(PkceCodeChallengeMethod::S256);
s
});

View File

@ -14,6 +14,7 @@
use mas_config::{OAuth2ClientConfig, OAuth2Config};
use mas_data_model::TokenType;
use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint};
use mas_storage::oauth2::{
access_token::lookup_active_access_token, refresh_token::lookup_active_refresh_token,
};
@ -21,9 +22,7 @@ use mas_warp_utils::{
errors::WrapError,
filters::{client::client_authentication, database::connection},
};
use oauth2_types::requests::{
ClientAuthenticationMethod, IntrospectionRequest, IntrospectionResponse, TokenTypeHint,
};
use oauth2_types::requests::{IntrospectionRequest, IntrospectionResponse};
use sqlx::{pool::PoolConnection, PgPool, Postgres};
use tracing::{info, warn};
use warp::{filters::BoxedFilter, Filter, Rejection, Reply};
@ -64,12 +63,12 @@ const INACTIVE: IntrospectionResponse = IntrospectionResponse {
async fn introspect(
mut conn: PoolConnection<Postgres>,
auth: ClientAuthenticationMethod,
auth: OAuthClientAuthenticationMethod,
client: OAuth2ClientConfig,
params: IntrospectionRequest,
) -> Result<Box<dyn Reply>, Rejection> {
// Token introspection is only allowed by confidential clients
if auth.public() {
if auth == OAuthClientAuthenticationMethod::None {
warn!(?client, "Client tried to introspect");
// TODO: have a nice error here
return Ok(Box::new(warp::reply::json(&INACTIVE)));
@ -96,7 +95,7 @@ async fn introspect(
scope: Some(session.scope),
client_id: Some(session.client.client_id),
username: Some(session.browser_session.user.username),
token_type: Some(TokenTypeHint::AccessToken),
token_type: Some(OAuthTokenTypeHint::AccessToken),
exp: Some(exp),
iat: Some(token.created_at),
nbf: Some(token.created_at),
@ -116,7 +115,7 @@ async fn introspect(
scope: Some(session.scope),
client_id: Some(session.client.client_id),
username: Some(session.browser_session.user.username),
token_type: Some(TokenTypeHint::RefreshToken),
token_type: Some(OAuthTokenTypeHint::RefreshToken),
exp: None,
iat: Some(token.created_at),
nbf: Some(token.created_at),

View File

@ -21,7 +21,7 @@ use headers::{CacheControl, Pragma};
use hyper::StatusCode;
use mas_config::{OAuth2ClientConfig, OAuth2Config};
use mas_data_model::{AuthorizationGrantStage, TokenType};
use mas_iana::jose::JsonWebSignatureAlg;
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use mas_jose::{
claims::{AT_HASH, AUD, AUTH_TIME, C_HASH, EXP, IAT, ISS, NONCE, SUB},
DecodedJsonWebToken, SigningKeystore, StaticKeystore,
@ -42,8 +42,7 @@ use mas_warp_utils::{
use oauth2_types::{
errors::{InvalidGrant, InvalidRequest, OAuth2Error, OAuth2ErrorCode, UnauthorizedClient},
requests::{
AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant,
ClientAuthenticationMethod, RefreshTokenGrant,
AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, RefreshTokenGrant,
},
scope::OPENID,
};
@ -131,7 +130,7 @@ async fn recover(rejection: Rejection) -> Result<Box<dyn Reply>, Rejection> {
}
async fn token(
_auth: ClientAuthenticationMethod,
_auth: OAuthClientAuthenticationMethod,
client: OAuth2ClientConfig,
req: AccessTokenRequest,
key_store: Arc<StaticKeystore>,

View File

@ -182,10 +182,12 @@ async fn generate_oauth(client: &Arc<Client>, path: PathBuf) -> anyhow::Result<(
"https://www.iana.org/assignments/jose/jose.xhtml",
client.clone(),
)
.load::<TokenTypeHint>()
.load::<AccessTokenType>()
.await?
.load::<AuthorizationEndpointResponseType>()
.await?
.load::<TokenTypeHint>()
.await?
.load::<TokenEndpointAuthenticationMethod>()
.await?
.load::<PkceCodeChallengeMethod>()

View File

@ -21,22 +21,25 @@ use crate::{
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
pub struct TokenTypeHint {
#[serde(rename = "Hint Value")]
pub struct AccessTokenType {
#[serde(rename = "Name")]
name: String,
#[serde(rename = "Additional Token Endpoint Response Parameters")]
additional_parameters: String,
#[serde(rename = "HTTP Authentication Scheme(s)")]
http_schemes: String,
#[serde(rename = "Change Controller")]
change_controller: String,
#[serde(rename = "Reference")]
reference: String,
}
impl EnumEntry for TokenTypeHint {
const URL: &'static str =
"https://www.iana.org/assignments/oauth-parameters/token-type-hint.csv";
const SECTIONS: &'static [Section] = &[s("OAuthTokenTypeHint", "OAuth Token Type Hint")];
impl EnumEntry for AccessTokenType {
const URL: &'static str = "https://www.iana.org/assignments/oauth-parameters/token-types.csv";
const SECTIONS: &'static [Section] = &[s("OAuthAccessTokenType", "OAuth Access Token Type")];
fn key(&self) -> Option<&'static str> {
Some("OAuthTokenTypeHint")
Some("OAuthAccessTokenType")
}
fn name(&self) -> &str {
@ -82,16 +85,41 @@ pub struct TokenEndpointAuthenticationMethod {
reference: String,
}
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
pub struct TokenTypeHint {
#[serde(rename = "Hint Value")]
name: String,
#[serde(rename = "Change Controller")]
change_controller: String,
#[serde(rename = "Reference")]
reference: String,
}
impl EnumEntry for TokenTypeHint {
const URL: &'static str =
"https://www.iana.org/assignments/oauth-parameters/token-type-hint.csv";
const SECTIONS: &'static [Section] = &[s("OAuthTokenTypeHint", "OAuth Token Type Hint")];
fn key(&self) -> Option<&'static str> {
Some("OAuthTokenTypeHint")
}
fn name(&self) -> &str {
&self.name
}
}
impl EnumEntry for TokenEndpointAuthenticationMethod {
const URL: &'static str =
"https://www.iana.org/assignments/oauth-parameters/token-endpoint-auth-method.csv";
const SECTIONS: &'static [Section] = &[s(
"OAuthTokenEndpointAuthenticationMethod",
"OAuthClientAuthenticationMethod",
"OAuth Token Endpoint Authentication Method",
)];
fn key(&self) -> Option<&'static str> {
Some("OAuthTokenEndpointAuthenticationMethod")
Some("OAuthClientAuthenticationMethod")
}
fn name(&self) -> &str {

View File

@ -61,7 +61,11 @@ pub trait EnumEntry: DeserializeOwned + Send + Sync {
None
}
fn enum_name(&self) -> String {
self.name().replace('+', "_").to_case(Case::Pascal)
// Do the case transformation twice to have "N_A" turned to "Na" instead of "NA"
self.name()
.replace('+', "_")
.to_case(Case::Pascal)
.to_case(Case::Pascal)
}
async fn fetch(client: &Client) -> anyhow::Result<Vec<(&'static str, EnumMember)>> {

View File

@ -21,24 +21,24 @@ use parse_display::{Display, FromStr};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
/// OAuth Token Type Hint
/// OAuth Access Token Type
///
/// Source: <https://www.iana.org/assignments/oauth-parameters/token-type-hint.csv>
/// Source: <https://www.iana.org/assignments/oauth-parameters/token-types.csv>
#[derive(
Debug, Clone, Copy, PartialEq, Eq, Hash, Display, FromStr, Serialize, Deserialize, JsonSchema,
)]
pub enum OAuthTokenTypeHint {
#[serde(rename = "access_token")]
#[display("access_token")]
AccessToken,
pub enum OAuthAccessTokenType {
#[serde(rename = "Bearer")]
#[display("Bearer")]
Bearer,
#[serde(rename = "refresh_token")]
#[display("refresh_token")]
RefreshToken,
#[serde(rename = "N_A")]
#[display("N_A")]
Na,
#[serde(rename = "pct")]
#[display("pct")]
Pct,
#[serde(rename = "PoP")]
#[display("PoP")]
PoP,
}
/// OAuth Authorization Endpoint Response Type
@ -81,13 +81,33 @@ pub enum OAuthAuthorizationEndpointResponseType {
Token,
}
/// OAuth Token Type Hint
///
/// Source: <https://www.iana.org/assignments/oauth-parameters/token-type-hint.csv>
#[derive(
Debug, Clone, Copy, PartialEq, Eq, Hash, Display, FromStr, Serialize, Deserialize, JsonSchema,
)]
pub enum OAuthTokenTypeHint {
#[serde(rename = "access_token")]
#[display("access_token")]
AccessToken,
#[serde(rename = "refresh_token")]
#[display("refresh_token")]
RefreshToken,
#[serde(rename = "pct")]
#[display("pct")]
Pct,
}
/// OAuth Token Endpoint Authentication Method
///
/// Source: <https://www.iana.org/assignments/oauth-parameters/token-endpoint-auth-method.csv>
#[derive(
Debug, Clone, Copy, PartialEq, Eq, Hash, Display, FromStr, Serialize, Deserialize, JsonSchema,
)]
pub enum OAuthTokenEndpointAuthenticationMethod {
pub enum OAuthClientAuthenticationMethod {
#[serde(rename = "none")]
#[display("none")]
None,

View File

@ -16,11 +16,46 @@
#![deny(clippy::all)]
#![warn(clippy::pedantic)]
use mas_iana::oauth::OAuthAuthorizationEndpointResponseType;
pub trait ResponseTypeExt {
fn has_code(&self) -> bool;
fn has_token(&self) -> bool;
fn has_id_token(&self) -> bool;
}
impl ResponseTypeExt for OAuthAuthorizationEndpointResponseType {
fn has_code(&self) -> bool {
matches!(
self,
Self::Code | Self::CodeToken | Self::CodeIdToken | Self::CodeIdTokenToken
)
}
fn has_token(&self) -> bool {
matches!(
self,
Self::Token | Self::CodeToken | Self::IdTokenToken | Self::CodeIdTokenToken
)
}
fn has_id_token(&self) -> bool {
matches!(
self,
Self::IdToken | Self::IdTokenToken | Self::CodeIdToken | Self::CodeIdTokenToken
)
}
}
pub mod errors;
pub mod oidc;
pub mod pkce;
pub mod requests;
pub mod scope;
pub mod prelude {
pub use crate::{pkce::CodeChallengeMethodExt, ResponseTypeExt};
}
#[cfg(test)]
mod test_utils;

View File

@ -14,15 +14,18 @@
use std::collections::HashSet;
use mas_iana::jose::{JsonWebEncryptionAlg, JsonWebEncryptionEnc, JsonWebSignatureAlg};
use mas_iana::{
jose::{JsonWebEncryptionAlg, JsonWebEncryptionEnc, JsonWebSignatureAlg},
oauth::{
OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod,
PkceCodeChallengeMethod,
},
};
use serde::Serialize;
use serde_with::skip_serializing_none;
use url::Url;
use crate::{
pkce::CodeChallengeMethod,
requests::{ClientAuthenticationMethod, Display, GrantType, ResponseMode},
};
use crate::requests::{Display, GrantType, ResponseMode};
#[derive(Serialize, Clone, Copy, PartialEq, Eq, Hash)]
#[serde(rename_all = "lowercase")]
@ -66,7 +69,7 @@ pub struct Metadata {
/// JSON array containing a list of the OAuth 2.0 "response_type" values
/// that this authorization server supports.
pub response_types_supported: Option<HashSet<String>>,
pub response_types_supported: Option<HashSet<OAuthAuthorizationEndpointResponseType>>,
/// JSON array containing a list of the OAuth 2.0 "response_mode" values
/// that this authorization server supports.
@ -78,7 +81,7 @@ pub struct Metadata {
/// JSON array containing a list of client authentication methods supported
/// by this token endpoint.
pub token_endpoint_auth_methods_supported: Option<HashSet<ClientAuthenticationMethod>>,
pub token_endpoint_auth_methods_supported: Option<HashSet<OAuthClientAuthenticationMethod>>,
/// JSON array containing a list of the JWS signing algorithms supported by
/// the token endpoint for the signature on the JWT used to authenticate the
@ -109,7 +112,8 @@ pub struct Metadata {
/// JSON array containing a list of client authentication methods supported
/// by this revocation endpoint.
pub revocation_endpoint_auth_methods_supported: Option<HashSet<ClientAuthenticationMethod>>,
pub revocation_endpoint_auth_methods_supported:
Option<HashSet<OAuthClientAuthenticationMethod>>,
/// JSON array containing a list of the JWS signing algorithms supported by
/// the revocation endpoint for the signature on the JWT used to
@ -121,7 +125,8 @@ pub struct Metadata {
/// JSON array containing a list of client authentication methods supported
/// by this introspection endpoint.
pub introspection_endpoint_auth_methods_supported: Option<HashSet<ClientAuthenticationMethod>>,
pub introspection_endpoint_auth_methods_supported:
Option<HashSet<OAuthClientAuthenticationMethod>>,
/// JSON array containing a list of the JWS signing algorithms supported by
/// the introspection endpoint for the signature on the JWT used to
@ -130,7 +135,7 @@ pub struct Metadata {
Option<HashSet<JsonWebSignatureAlg>>,
/// PKCE code challenge methods supported by this authorization server.
pub code_challenge_methods_supported: Option<HashSet<CodeChallengeMethod>>,
pub code_challenge_methods_supported: Option<HashSet<PkceCodeChallengeMethod>>,
/// URL of the OP's UserInfo Endpoint.
pub userinfo_endpoint: Option<Url>,

View File

@ -15,40 +15,23 @@
use std::borrow::Cow;
use data_encoding::BASE64URL_NOPAD;
use parse_display::{Display, FromStr};
use mas_iana::oauth::PkceCodeChallengeMethod;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
#[derive(
Debug,
Hash,
PartialEq,
Eq,
PartialOrd,
Ord,
Clone,
Copy,
Display,
FromStr,
Serialize,
Deserialize,
)]
pub enum CodeChallengeMethod {
#[serde(rename = "plain")]
#[display("plain")]
Plain,
pub trait CodeChallengeMethodExt {
#[must_use]
fn compute_challenge(self, verifier: &str) -> Cow<'_, str>;
#[serde(rename = "S256")]
#[display("S256")]
S256,
#[must_use]
fn verify(self, challenge: &str, verifier: &str) -> bool;
}
impl CodeChallengeMethod {
#[must_use]
pub fn compute_challenge(self, verifier: &str) -> Cow<'_, str> {
impl CodeChallengeMethodExt for PkceCodeChallengeMethod {
fn compute_challenge(self, verifier: &str) -> Cow<'_, str> {
match self {
CodeChallengeMethod::Plain => verifier.into(),
CodeChallengeMethod::S256 => {
Self::Plain => verifier.into(),
Self::S256 => {
let mut hasher = Sha256::new();
hasher.update(verifier.as_bytes());
let hash = hasher.finalize();
@ -58,15 +41,14 @@ impl CodeChallengeMethod {
}
}
#[must_use]
pub fn verify(self, challenge: &str, verifier: &str) -> bool {
fn verify(self, challenge: &str, verifier: &str) -> bool {
self.compute_challenge(verifier) == challenge
}
}
#[derive(Serialize, Deserialize)]
pub struct AuthorizationRequest {
pub code_challenge_method: CodeChallengeMethod,
pub code_challenge_method: PkceCodeChallengeMethod,
pub code_challenge: String,
}

View File

@ -16,6 +16,9 @@ use std::{collections::HashSet, hash::Hash, num::NonZeroU32};
use chrono::{DateTime, Duration, Utc};
use language_tags::LanguageTag;
use mas_iana::oauth::{
OAuthAccessTokenType, OAuthAuthorizationEndpointResponseType, OAuthTokenTypeHint,
};
use parse_display::{Display, FromStr};
use serde::{Deserialize, Serialize};
use serde_with::{
@ -28,29 +31,6 @@ use crate::scope::Scope;
// ref: https://www.iana.org/assignments/oauth-parameters/oauth-parameters.xhtml
#[derive(
Debug,
Hash,
PartialEq,
Eq,
PartialOrd,
Ord,
Clone,
Copy,
Display,
FromStr,
Serialize,
Deserialize,
)]
#[display(style = "snake_case")]
#[serde(rename_all = "snake_case")]
pub enum ResponseType {
Code,
IdToken,
Token,
None,
}
#[derive(
Debug,
Hash,
@ -72,37 +52,6 @@ pub enum ResponseMode {
FormPost,
}
#[derive(
Debug,
Hash,
PartialEq,
Eq,
PartialOrd,
Ord,
Clone,
Copy,
Display,
FromStr,
Serialize,
Deserialize,
)]
#[serde(rename_all = "snake_case")]
pub enum ClientAuthenticationMethod {
None,
ClientSecretPost,
ClientSecretBasic,
ClientSecretJwt,
PrivateKeyJwt,
}
impl ClientAuthenticationMethod {
#[must_use]
/// Check if the authentication method is for public client or not
pub fn public(&self) -> bool {
matches!(self, &Self::None)
}
}
#[derive(
Debug,
Hash,
@ -151,8 +100,7 @@ pub enum Prompt {
#[serde_as]
#[derive(Serialize, Deserialize)]
pub struct AuthorizationRequest {
#[serde_as(as = "StringWithSeparator::<SpaceSeparator, ResponseType>")]
pub response_type: HashSet<ResponseType>,
pub response_type: OAuthAuthorizationEndpointResponseType,
pub client_id: String,
@ -200,25 +148,6 @@ pub struct AuthorizationResponse<R> {
pub response: R,
}
#[derive(
Debug,
Hash,
PartialEq,
Eq,
PartialOrd,
Ord,
Clone,
Copy,
Display,
FromStr,
Serialize,
Deserialize,
)]
#[serde(rename_all = "snake_case")]
pub enum TokenType {
Bearer,
}
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct AuthorizationCodeGrant {
@ -285,7 +214,7 @@ pub struct AccessTokenResponse {
// TODO: this should be somewhere else
id_token: Option<String>,
token_type: TokenType,
token_type: OAuthAccessTokenType,
#[serde_as(as = "Option<DurationSeconds<i64>>")]
expires_in: Option<Duration>,
@ -300,7 +229,7 @@ impl AccessTokenResponse {
access_token,
refresh_token: None,
id_token: None,
token_type: TokenType::Bearer,
token_type: OAuthAccessTokenType::Bearer,
expires_in: None,
scope: None,
}
@ -331,20 +260,13 @@ impl AccessTokenResponse {
}
}
#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum TokenTypeHint {
AccessToken,
RefreshToken,
}
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct IntrospectionRequest {
pub token: String,
#[serde(default)]
pub token_type_hint: Option<TokenTypeHint>,
pub token_type_hint: Option<OAuthTokenTypeHint>,
}
#[serde_as]
@ -359,7 +281,7 @@ pub struct IntrospectionResponse {
pub username: Option<String>,
pub token_type: Option<TokenTypeHint>,
pub token_type: Option<OAuthTokenTypeHint>,
#[serde_as(as = "Option<TimestampSeconds>")]
pub exp: Option<DateTime<Utc>>,

View File

@ -23,3 +23,4 @@ url = { version = "2.2.2", features = ["serde"] }
oauth2-types = { path = "../oauth2-types" }
mas-data-model = { path = "../data-model" }
mas-iana = { path = "../iana" }

View File

@ -22,7 +22,8 @@ use mas_data_model::{
Authentication, AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, BrowserSession,
Client, Pkce, Session, User,
};
use oauth2_types::{pkce::CodeChallengeMethod, requests::ResponseMode, scope::Scope};
use mas_iana::oauth::PkceCodeChallengeMethod;
use oauth2_types::{requests::ResponseMode, scope::Scope};
use sqlx::PgExecutor;
use url::Url;
@ -237,12 +238,12 @@ impl TryInto<AuthorizationGrant<PostgresqlBackend>> for GrantLookup {
let pkce = match (self.grant_code_challenge, self.grant_code_challenge_method) {
(Some(challenge), Some(challenge_method)) if challenge_method == "plain" => {
Some(Pkce {
challenge_method: CodeChallengeMethod::Plain,
challenge_method: PkceCodeChallengeMethod::Plain,
challenge,
})
}
(Some(challenge), Some(challenge_method)) if challenge_method == "S256" => Some(Pkce {
challenge_method: CodeChallengeMethod::S256,
challenge_method: PkceCodeChallengeMethod::S256,
challenge,
}),
(None, None) => None,

View File

@ -35,3 +35,4 @@ mas-templates = { path = "../templates" }
mas-data-model = { path = "../data-model" }
mas-storage = { path = "../storage" }
mas-jose = { path = "../jose" }
mas-iana = { path = "../iana" }

View File

@ -18,11 +18,11 @@ use std::collections::HashMap;
use headers::{authorization::Basic, Authorization};
use mas_config::{OAuth2ClientAuthMethodConfig, OAuth2ClientConfig, OAuth2Config};
use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_jose::{
claims::{TimeOptions, AUD, EXP, IAT, ISS, JTI, NBF, SUB},
DecodedJsonWebToken, JsonWebTokenParts, SharedSecret,
};
use oauth2_types::requests::ClientAuthenticationMethod;
use serde::{de::DeserializeOwned, Deserialize};
use thiserror::Error;
use warp::{reject::Reject, Filter, Rejection};
@ -35,7 +35,7 @@ use crate::errors::WrapError;
pub fn client_authentication<T: DeserializeOwned + Send + 'static>(
oauth2_config: &OAuth2Config,
audience: String,
) -> impl Filter<Extract = (ClientAuthenticationMethod, OAuth2ClientConfig, T), Error = Rejection>
) -> impl Filter<Extract = (OAuthClientAuthenticationMethod, OAuth2ClientConfig, T), Error = Rejection>
+ Clone
+ Send
+ Sync
@ -99,7 +99,7 @@ async fn authenticate_client<T>(
audience: String,
credentials: ClientCredentials,
body: T,
) -> Result<(ClientAuthenticationMethod, OAuth2ClientConfig, T), Rejection> {
) -> Result<(OAuthClientAuthenticationMethod, OAuth2ClientConfig, T), Rejection> {
let (auth_method, client) = match credentials {
ClientCredentials::Pair {
client_id,
@ -114,7 +114,9 @@ async fn authenticate_client<T>(
})?;
let auth_method = match (&client.client_auth_method, client_secret, via) {
(OAuth2ClientAuthMethodConfig::None, None, _) => ClientAuthenticationMethod::None,
(OAuth2ClientAuthMethodConfig::None, None, _) => {
OAuthClientAuthenticationMethod::None
}
(
OAuth2ClientAuthMethodConfig::ClientSecretBasic {
@ -129,7 +131,7 @@ async fn authenticate_client<T>(
);
}
ClientAuthenticationMethod::ClientSecretBasic
OAuthClientAuthenticationMethod::ClientSecretBasic
}
(
@ -145,7 +147,7 @@ async fn authenticate_client<T>(
);
}
ClientAuthenticationMethod::ClientSecretPost
OAuthClientAuthenticationMethod::ClientSecretPost
}
_ => {
@ -204,13 +206,13 @@ async fn authenticate_client<T>(
OAuth2ClientAuthMethodConfig::PrivateKeyJwt(jwks) => {
let store = jwks.key_store();
token.verify(&decoded, &store).await.wrap_error()?;
ClientAuthenticationMethod::PrivateKeyJwt
OAuthClientAuthenticationMethod::PrivateKeyJwt
}
OAuth2ClientAuthMethodConfig::ClientSecretJwt { client_secret } => {
let store = SharedSecret::new(client_secret);
token.verify(&decoded, &store).await.wrap_error()?;
ClientAuthenticationMethod::ClientSecretJwt
OAuthClientAuthenticationMethod::ClientSecretJwt
}
_ => {
@ -428,7 +430,7 @@ mod tests {
.await
.unwrap();
assert_eq!(auth, ClientAuthenticationMethod::ClientSecretJwt);
assert_eq!(auth, OAuthClientAuthenticationMethod::ClientSecretJwt);
assert_eq!(client.client_id, "secret-jwt");
assert_eq!(body.foo, "baz");
assert_eq!(body.bar, "foobar");
@ -515,7 +517,7 @@ mod tests {
.await
.unwrap();
assert_eq!(auth, ClientAuthenticationMethod::PrivateKeyJwt);
assert_eq!(auth, OAuthClientAuthenticationMethod::PrivateKeyJwt);
assert_eq!(client.client_id, "private-key-jwt");
assert_eq!(body.foo, "baz");
assert_eq!(body.bar, "foobar");
@ -575,7 +577,7 @@ mod tests {
.await
.unwrap();
assert_eq!(auth, ClientAuthenticationMethod::ClientSecretPost);
assert_eq!(auth, OAuthClientAuthenticationMethod::ClientSecretPost);
assert_eq!(client.client_id, "secret-post");
assert_eq!(body.foo, "baz");
assert_eq!(body.bar, "foobar");
@ -607,7 +609,7 @@ mod tests {
.await
.unwrap();
assert_eq!(auth, ClientAuthenticationMethod::ClientSecretBasic);
assert_eq!(auth, OAuthClientAuthenticationMethod::ClientSecretBasic);
assert_eq!(client.client_id, "secret-basic");
assert_eq!(body.foo, "baz");
assert_eq!(body.bar, "foobar");
@ -638,7 +640,7 @@ mod tests {
.await
.unwrap();
assert_eq!(auth, ClientAuthenticationMethod::None);
assert_eq!(auth, OAuthClientAuthenticationMethod::None);
assert_eq!(client.client_id, "public");
assert_eq!(body.foo, "baz");
assert_eq!(body.bar, "foobar");