diff --git a/Cargo.lock b/Cargo.lock index 158e5921..dc5cb59e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1548,6 +1548,7 @@ version = "0.1.0" dependencies = [ "chrono", "crc", + "mas-iana", "oauth2-types", "rand", "serde", @@ -1670,6 +1671,7 @@ dependencies = [ "argon2", "chrono", "mas-data-model", + "mas-iana", "oauth2-types", "password-hash", "rand", @@ -1729,6 +1731,7 @@ dependencies = [ "hyper", "mas-config", "mas-data-model", + "mas-iana", "mas-jose", "mas-storage", "mas-templates", diff --git a/crates/data-model/Cargo.toml b/crates/data-model/Cargo.toml index e9b684f0..1215198a 100644 --- a/crates/data-model/Cargo.toml +++ b/crates/data-model/Cargo.toml @@ -13,4 +13,5 @@ url = { version = "2.2.2", features = ["serde"] } crc = "2.1.0" rand = "0.8.4" +mas-iana = { path = "../iana" } oauth2-types = { path = "../oauth2-types" } diff --git a/crates/data-model/src/oauth2/authorization_grant.rs b/crates/data-model/src/oauth2/authorization_grant.rs index a7ccf1f2..7cffc4e3 100644 --- a/crates/data-model/src/oauth2/authorization_grant.rs +++ b/crates/data-model/src/oauth2/authorization_grant.rs @@ -15,7 +15,8 @@ use std::num::NonZeroU32; use chrono::{DateTime, Duration, Utc}; -use oauth2_types::{pkce::CodeChallengeMethod, requests::ResponseMode}; +use mas_iana::oauth::PkceCodeChallengeMethod; +use oauth2_types::{pkce::CodeChallengeMethodExt, requests::ResponseMode}; use serde::Serialize; use thiserror::Error; use url::Url; @@ -25,13 +26,13 @@ use crate::{traits::StorageBackend, StorageBackendMarker}; #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct Pkce { - pub challenge_method: CodeChallengeMethod, + pub challenge_method: PkceCodeChallengeMethod, pub challenge: String, } impl Pkce { #[must_use] - pub fn new(challenge_method: CodeChallengeMethod, challenge: String) -> Self { + pub fn new(challenge_method: PkceCodeChallengeMethod, challenge: String) -> Self { Pkce { challenge_method, challenge, diff --git a/crates/data-model/src/tokens.rs b/crates/data-model/src/tokens.rs index d65abeef..2932e4f2 100644 --- a/crates/data-model/src/tokens.rs +++ b/crates/data-model/src/tokens.rs @@ -14,7 +14,7 @@ use chrono::{DateTime, Duration, Utc}; use crc::{Crc, CRC_32_ISO_HDLC}; -use oauth2_types::requests::TokenTypeHint; +use mas_iana::oauth::OAuthTokenTypeHint; use rand::{distributions::Alphanumeric, Rng}; use thiserror::Error; @@ -159,12 +159,12 @@ impl TokenType { } } -impl PartialEq for TokenType { - fn eq(&self, other: &TokenTypeHint) -> bool { +impl PartialEq for TokenType { + fn eq(&self, other: &OAuthTokenTypeHint) -> bool { matches!( (self, other), - (TokenType::AccessToken, TokenTypeHint::AccessToken) - | (TokenType::RefreshToken, TokenTypeHint::RefreshToken) + (TokenType::AccessToken, OAuthTokenTypeHint::AccessToken) + | (TokenType::RefreshToken, OAuthTokenTypeHint::RefreshToken) ) } } diff --git a/crates/handlers/src/oauth2/authorization.rs b/crates/handlers/src/oauth2/authorization.rs index 7c658335..52b10801 100644 --- a/crates/handlers/src/oauth2/authorization.rs +++ b/crates/handlers/src/oauth2/authorization.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use chrono::Duration; use hyper::{ @@ -25,6 +25,7 @@ use mas_data_model::{ Authentication, AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, BrowserSession, Pkce, StorageBackend, TokenType, }; +use mas_iana::oauth::OAuthAuthorizationEndpointResponseType; use mas_storage::{ oauth2::{ access_token::add_access_token, @@ -50,9 +51,9 @@ use oauth2_types::{ RegistrationNotSupported, RequestNotSupported, RequestUriNotSupported, }, pkce, + prelude::*, requests::{ AccessTokenResponse, AuthorizationRequest, AuthorizationResponse, Prompt, ResponseMode, - ResponseType, }, scope::ScopeToken, }; @@ -191,16 +192,15 @@ struct Params { /// figure out what response mode must be used, and emit an error if the /// suggested response mode isn't allowed for the given response types. fn resolve_response_mode( - response_type: &HashSet, + response_type: OAuthAuthorizationEndpointResponseType, suggested_response_mode: Option, ) -> anyhow::Result { use ResponseMode as M; - use ResponseType as T; // If the response type includes either "token" or "id_token", the default // response mode is "fragment" and the response mode "query" must not be // used - if response_type.contains(&T::Token) || response_type.contains(&T::IdToken) { + if response_type.has_token() || response_type.has_id_token() { match suggested_response_mode { None => Ok(M::Fragment), Some(M::Query) => Err(anyhow::anyhow!("invalid response mode")), @@ -345,11 +345,11 @@ async fn get( let redirect_uri = client .resolve_redirect_uri(¶ms.auth.redirect_uri) .wrap_error()?; - let response_type = ¶ms.auth.response_type; + let response_type = params.auth.response_type; let response_mode = resolve_response_mode(response_type, params.auth.response_mode).wrap_error()?; - let code: Option = if response_type.contains(&ResponseType::Code) { + let code: Option = if response_type.has_code() { // 32 random alphanumeric characters, about 190bit of entropy let code: String = thread_rng() .sample_iter(&Alphanumeric) @@ -400,8 +400,8 @@ async fn get( params.auth.max_age, None, response_mode, - response_type.contains(&ResponseType::Token), - response_type.contains(&ResponseType::IdToken), + response_type.has_token(), + response_type.has_id_token(), ) .await .wrap_error()?; diff --git a/crates/handlers/src/oauth2/discovery.rs b/crates/handlers/src/oauth2/discovery.rs index 87b7bb40..45e0597b 100644 --- a/crates/handlers/src/oauth2/discovery.rs +++ b/crates/handlers/src/oauth2/discovery.rs @@ -15,12 +15,17 @@ use std::collections::HashSet; use mas_config::OAuth2Config; -use mas_iana::jose::JsonWebSignatureAlg; +use mas_iana::{ + jose::JsonWebSignatureAlg, + oauth::{ + OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod, + PkceCodeChallengeMethod, + }, +}; use mas_jose::SigningKeystore; use oauth2_types::{ oidc::{ClaimType, Metadata, SubjectType}, - pkce::CodeChallengeMethod, - requests::{ClientAuthenticationMethod, Display, GrantType, ResponseMode}, + requests::{Display, GrantType, ResponseMode}, }; use warp::{filters::BoxedFilter, Filter, Reply}; @@ -34,11 +39,11 @@ pub(super) fn filter( // This is how clients can authenticate let client_auth_methods_supported = Some({ let mut s = HashSet::new(); - s.insert(ClientAuthenticationMethod::ClientSecretBasic); - s.insert(ClientAuthenticationMethod::ClientSecretPost); - s.insert(ClientAuthenticationMethod::ClientSecretJwt); - s.insert(ClientAuthenticationMethod::PrivateKeyJwt); - s.insert(ClientAuthenticationMethod::None); + s.insert(OAuthClientAuthenticationMethod::ClientSecretBasic); + s.insert(OAuthClientAuthenticationMethod::ClientSecretPost); + s.insert(OAuthClientAuthenticationMethod::ClientSecretJwt); + s.insert(OAuthClientAuthenticationMethod::PrivateKeyJwt); + s.insert(OAuthClientAuthenticationMethod::None); s }); @@ -72,13 +77,13 @@ pub(super) fn filter( let response_types_supported = Some({ let mut s = HashSet::new(); - s.insert("code".to_string()); - s.insert("token".to_string()); - s.insert("id_token".to_string()); - s.insert("code token".to_string()); - s.insert("code id_token".to_string()); - s.insert("token id_token".to_string()); - s.insert("code token id_token".to_string()); + s.insert(OAuthAuthorizationEndpointResponseType::Code); + s.insert(OAuthAuthorizationEndpointResponseType::Token); + s.insert(OAuthAuthorizationEndpointResponseType::IdToken); + s.insert(OAuthAuthorizationEndpointResponseType::CodeToken); + s.insert(OAuthAuthorizationEndpointResponseType::CodeIdToken); + s.insert(OAuthAuthorizationEndpointResponseType::IdTokenToken); + s.insert(OAuthAuthorizationEndpointResponseType::CodeIdToken); s }); @@ -107,8 +112,8 @@ pub(super) fn filter( let code_challenge_methods_supported = Some({ let mut s = HashSet::new(); - s.insert(CodeChallengeMethod::Plain); - s.insert(CodeChallengeMethod::S256); + s.insert(PkceCodeChallengeMethod::Plain); + s.insert(PkceCodeChallengeMethod::S256); s }); diff --git a/crates/handlers/src/oauth2/introspection.rs b/crates/handlers/src/oauth2/introspection.rs index 8a043b6d..4dda8634 100644 --- a/crates/handlers/src/oauth2/introspection.rs +++ b/crates/handlers/src/oauth2/introspection.rs @@ -14,6 +14,7 @@ use mas_config::{OAuth2ClientConfig, OAuth2Config}; use mas_data_model::TokenType; +use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint}; use mas_storage::oauth2::{ access_token::lookup_active_access_token, refresh_token::lookup_active_refresh_token, }; @@ -21,9 +22,7 @@ use mas_warp_utils::{ errors::WrapError, filters::{client::client_authentication, database::connection}, }; -use oauth2_types::requests::{ - ClientAuthenticationMethod, IntrospectionRequest, IntrospectionResponse, TokenTypeHint, -}; +use oauth2_types::requests::{IntrospectionRequest, IntrospectionResponse}; use sqlx::{pool::PoolConnection, PgPool, Postgres}; use tracing::{info, warn}; use warp::{filters::BoxedFilter, Filter, Rejection, Reply}; @@ -64,12 +63,12 @@ const INACTIVE: IntrospectionResponse = IntrospectionResponse { async fn introspect( mut conn: PoolConnection, - auth: ClientAuthenticationMethod, + auth: OAuthClientAuthenticationMethod, client: OAuth2ClientConfig, params: IntrospectionRequest, ) -> Result, Rejection> { // Token introspection is only allowed by confidential clients - if auth.public() { + if auth == OAuthClientAuthenticationMethod::None { warn!(?client, "Client tried to introspect"); // TODO: have a nice error here return Ok(Box::new(warp::reply::json(&INACTIVE))); @@ -96,7 +95,7 @@ async fn introspect( scope: Some(session.scope), client_id: Some(session.client.client_id), username: Some(session.browser_session.user.username), - token_type: Some(TokenTypeHint::AccessToken), + token_type: Some(OAuthTokenTypeHint::AccessToken), exp: Some(exp), iat: Some(token.created_at), nbf: Some(token.created_at), @@ -116,7 +115,7 @@ async fn introspect( scope: Some(session.scope), client_id: Some(session.client.client_id), username: Some(session.browser_session.user.username), - token_type: Some(TokenTypeHint::RefreshToken), + token_type: Some(OAuthTokenTypeHint::RefreshToken), exp: None, iat: Some(token.created_at), nbf: Some(token.created_at), diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index 128798e5..0d9f4b30 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -21,7 +21,7 @@ use headers::{CacheControl, Pragma}; use hyper::StatusCode; use mas_config::{OAuth2ClientConfig, OAuth2Config}; use mas_data_model::{AuthorizationGrantStage, TokenType}; -use mas_iana::jose::JsonWebSignatureAlg; +use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; use mas_jose::{ claims::{AT_HASH, AUD, AUTH_TIME, C_HASH, EXP, IAT, ISS, NONCE, SUB}, DecodedJsonWebToken, SigningKeystore, StaticKeystore, @@ -42,8 +42,7 @@ use mas_warp_utils::{ use oauth2_types::{ errors::{InvalidGrant, InvalidRequest, OAuth2Error, OAuth2ErrorCode, UnauthorizedClient}, requests::{ - AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, - ClientAuthenticationMethod, RefreshTokenGrant, + AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, RefreshTokenGrant, }, scope::OPENID, }; @@ -131,7 +130,7 @@ async fn recover(rejection: Rejection) -> Result, Rejection> { } async fn token( - _auth: ClientAuthenticationMethod, + _auth: OAuthClientAuthenticationMethod, client: OAuth2ClientConfig, req: AccessTokenRequest, key_store: Arc, diff --git a/crates/iana-codegen/src/main.rs b/crates/iana-codegen/src/main.rs index abf4f0f5..02fb3d79 100644 --- a/crates/iana-codegen/src/main.rs +++ b/crates/iana-codegen/src/main.rs @@ -182,10 +182,12 @@ async fn generate_oauth(client: &Arc, path: PathBuf) -> anyhow::Result<( "https://www.iana.org/assignments/jose/jose.xhtml", client.clone(), ) - .load::() + .load::() .await? .load::() .await? + .load::() + .await? .load::() .await? .load::() diff --git a/crates/iana-codegen/src/oauth.rs b/crates/iana-codegen/src/oauth.rs index 13a893b7..e21b6768 100644 --- a/crates/iana-codegen/src/oauth.rs +++ b/crates/iana-codegen/src/oauth.rs @@ -21,22 +21,25 @@ use crate::{ #[allow(dead_code)] #[derive(Debug, Deserialize)] -pub struct TokenTypeHint { - #[serde(rename = "Hint Value")] +pub struct AccessTokenType { + #[serde(rename = "Name")] name: String, + #[serde(rename = "Additional Token Endpoint Response Parameters")] + additional_parameters: String, + #[serde(rename = "HTTP Authentication Scheme(s)")] + http_schemes: String, #[serde(rename = "Change Controller")] change_controller: String, #[serde(rename = "Reference")] reference: String, } -impl EnumEntry for TokenTypeHint { - const URL: &'static str = - "https://www.iana.org/assignments/oauth-parameters/token-type-hint.csv"; - const SECTIONS: &'static [Section] = &[s("OAuthTokenTypeHint", "OAuth Token Type Hint")]; +impl EnumEntry for AccessTokenType { + const URL: &'static str = "https://www.iana.org/assignments/oauth-parameters/token-types.csv"; + const SECTIONS: &'static [Section] = &[s("OAuthAccessTokenType", "OAuth Access Token Type")]; fn key(&self) -> Option<&'static str> { - Some("OAuthTokenTypeHint") + Some("OAuthAccessTokenType") } fn name(&self) -> &str { @@ -82,16 +85,41 @@ pub struct TokenEndpointAuthenticationMethod { reference: String, } +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +pub struct TokenTypeHint { + #[serde(rename = "Hint Value")] + name: String, + #[serde(rename = "Change Controller")] + change_controller: String, + #[serde(rename = "Reference")] + reference: String, +} + +impl EnumEntry for TokenTypeHint { + const URL: &'static str = + "https://www.iana.org/assignments/oauth-parameters/token-type-hint.csv"; + const SECTIONS: &'static [Section] = &[s("OAuthTokenTypeHint", "OAuth Token Type Hint")]; + + fn key(&self) -> Option<&'static str> { + Some("OAuthTokenTypeHint") + } + + fn name(&self) -> &str { + &self.name + } +} + impl EnumEntry for TokenEndpointAuthenticationMethod { const URL: &'static str = "https://www.iana.org/assignments/oauth-parameters/token-endpoint-auth-method.csv"; const SECTIONS: &'static [Section] = &[s( - "OAuthTokenEndpointAuthenticationMethod", + "OAuthClientAuthenticationMethod", "OAuth Token Endpoint Authentication Method", )]; fn key(&self) -> Option<&'static str> { - Some("OAuthTokenEndpointAuthenticationMethod") + Some("OAuthClientAuthenticationMethod") } fn name(&self) -> &str { diff --git a/crates/iana-codegen/src/traits.rs b/crates/iana-codegen/src/traits.rs index cc66f152..6fec86e7 100644 --- a/crates/iana-codegen/src/traits.rs +++ b/crates/iana-codegen/src/traits.rs @@ -61,7 +61,11 @@ pub trait EnumEntry: DeserializeOwned + Send + Sync { None } fn enum_name(&self) -> String { - self.name().replace('+', "_").to_case(Case::Pascal) + // Do the case transformation twice to have "N_A" turned to "Na" instead of "NA" + self.name() + .replace('+', "_") + .to_case(Case::Pascal) + .to_case(Case::Pascal) } async fn fetch(client: &Client) -> anyhow::Result> { diff --git a/crates/iana/src/oauth.rs b/crates/iana/src/oauth.rs index e3e70ada..8b64294e 100644 --- a/crates/iana/src/oauth.rs +++ b/crates/iana/src/oauth.rs @@ -21,24 +21,24 @@ use parse_display::{Display, FromStr}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -/// OAuth Token Type Hint +/// OAuth Access Token Type /// -/// Source: +/// Source: #[derive( Debug, Clone, Copy, PartialEq, Eq, Hash, Display, FromStr, Serialize, Deserialize, JsonSchema, )] -pub enum OAuthTokenTypeHint { - #[serde(rename = "access_token")] - #[display("access_token")] - AccessToken, +pub enum OAuthAccessTokenType { + #[serde(rename = "Bearer")] + #[display("Bearer")] + Bearer, - #[serde(rename = "refresh_token")] - #[display("refresh_token")] - RefreshToken, + #[serde(rename = "N_A")] + #[display("N_A")] + Na, - #[serde(rename = "pct")] - #[display("pct")] - Pct, + #[serde(rename = "PoP")] + #[display("PoP")] + PoP, } /// OAuth Authorization Endpoint Response Type @@ -81,13 +81,33 @@ pub enum OAuthAuthorizationEndpointResponseType { Token, } +/// OAuth Token Type Hint +/// +/// Source: +#[derive( + Debug, Clone, Copy, PartialEq, Eq, Hash, Display, FromStr, Serialize, Deserialize, JsonSchema, +)] +pub enum OAuthTokenTypeHint { + #[serde(rename = "access_token")] + #[display("access_token")] + AccessToken, + + #[serde(rename = "refresh_token")] + #[display("refresh_token")] + RefreshToken, + + #[serde(rename = "pct")] + #[display("pct")] + Pct, +} + /// OAuth Token Endpoint Authentication Method /// /// Source: #[derive( Debug, Clone, Copy, PartialEq, Eq, Hash, Display, FromStr, Serialize, Deserialize, JsonSchema, )] -pub enum OAuthTokenEndpointAuthenticationMethod { +pub enum OAuthClientAuthenticationMethod { #[serde(rename = "none")] #[display("none")] None, diff --git a/crates/oauth2-types/src/lib.rs b/crates/oauth2-types/src/lib.rs index 09d52b4b..67b84e62 100644 --- a/crates/oauth2-types/src/lib.rs +++ b/crates/oauth2-types/src/lib.rs @@ -16,11 +16,46 @@ #![deny(clippy::all)] #![warn(clippy::pedantic)] +use mas_iana::oauth::OAuthAuthorizationEndpointResponseType; + +pub trait ResponseTypeExt { + fn has_code(&self) -> bool; + fn has_token(&self) -> bool; + fn has_id_token(&self) -> bool; +} + +impl ResponseTypeExt for OAuthAuthorizationEndpointResponseType { + fn has_code(&self) -> bool { + matches!( + self, + Self::Code | Self::CodeToken | Self::CodeIdToken | Self::CodeIdTokenToken + ) + } + + fn has_token(&self) -> bool { + matches!( + self, + Self::Token | Self::CodeToken | Self::IdTokenToken | Self::CodeIdTokenToken + ) + } + + fn has_id_token(&self) -> bool { + matches!( + self, + Self::IdToken | Self::IdTokenToken | Self::CodeIdToken | Self::CodeIdTokenToken + ) + } +} + pub mod errors; pub mod oidc; pub mod pkce; pub mod requests; pub mod scope; +pub mod prelude { + pub use crate::{pkce::CodeChallengeMethodExt, ResponseTypeExt}; +} + #[cfg(test)] mod test_utils; diff --git a/crates/oauth2-types/src/oidc.rs b/crates/oauth2-types/src/oidc.rs index 9baef27c..9a376dd5 100644 --- a/crates/oauth2-types/src/oidc.rs +++ b/crates/oauth2-types/src/oidc.rs @@ -14,15 +14,18 @@ use std::collections::HashSet; -use mas_iana::jose::{JsonWebEncryptionAlg, JsonWebEncryptionEnc, JsonWebSignatureAlg}; +use mas_iana::{ + jose::{JsonWebEncryptionAlg, JsonWebEncryptionEnc, JsonWebSignatureAlg}, + oauth::{ + OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod, + PkceCodeChallengeMethod, + }, +}; use serde::Serialize; use serde_with::skip_serializing_none; use url::Url; -use crate::{ - pkce::CodeChallengeMethod, - requests::{ClientAuthenticationMethod, Display, GrantType, ResponseMode}, -}; +use crate::requests::{Display, GrantType, ResponseMode}; #[derive(Serialize, Clone, Copy, PartialEq, Eq, Hash)] #[serde(rename_all = "lowercase")] @@ -66,7 +69,7 @@ pub struct Metadata { /// JSON array containing a list of the OAuth 2.0 "response_type" values /// that this authorization server supports. - pub response_types_supported: Option>, + pub response_types_supported: Option>, /// JSON array containing a list of the OAuth 2.0 "response_mode" values /// that this authorization server supports. @@ -78,7 +81,7 @@ pub struct Metadata { /// JSON array containing a list of client authentication methods supported /// by this token endpoint. - pub token_endpoint_auth_methods_supported: Option>, + pub token_endpoint_auth_methods_supported: Option>, /// JSON array containing a list of the JWS signing algorithms supported by /// the token endpoint for the signature on the JWT used to authenticate the @@ -109,7 +112,8 @@ pub struct Metadata { /// JSON array containing a list of client authentication methods supported /// by this revocation endpoint. - pub revocation_endpoint_auth_methods_supported: Option>, + pub revocation_endpoint_auth_methods_supported: + Option>, /// JSON array containing a list of the JWS signing algorithms supported by /// the revocation endpoint for the signature on the JWT used to @@ -121,7 +125,8 @@ pub struct Metadata { /// JSON array containing a list of client authentication methods supported /// by this introspection endpoint. - pub introspection_endpoint_auth_methods_supported: Option>, + pub introspection_endpoint_auth_methods_supported: + Option>, /// JSON array containing a list of the JWS signing algorithms supported by /// the introspection endpoint for the signature on the JWT used to @@ -130,7 +135,7 @@ pub struct Metadata { Option>, /// PKCE code challenge methods supported by this authorization server. - pub code_challenge_methods_supported: Option>, + pub code_challenge_methods_supported: Option>, /// URL of the OP's UserInfo Endpoint. pub userinfo_endpoint: Option, diff --git a/crates/oauth2-types/src/pkce.rs b/crates/oauth2-types/src/pkce.rs index 76bbc238..980f4a4f 100644 --- a/crates/oauth2-types/src/pkce.rs +++ b/crates/oauth2-types/src/pkce.rs @@ -15,40 +15,23 @@ use std::borrow::Cow; use data_encoding::BASE64URL_NOPAD; -use parse_display::{Display, FromStr}; +use mas_iana::oauth::PkceCodeChallengeMethod; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; -#[derive( - Debug, - Hash, - PartialEq, - Eq, - PartialOrd, - Ord, - Clone, - Copy, - Display, - FromStr, - Serialize, - Deserialize, -)] -pub enum CodeChallengeMethod { - #[serde(rename = "plain")] - #[display("plain")] - Plain, +pub trait CodeChallengeMethodExt { + #[must_use] + fn compute_challenge(self, verifier: &str) -> Cow<'_, str>; - #[serde(rename = "S256")] - #[display("S256")] - S256, + #[must_use] + fn verify(self, challenge: &str, verifier: &str) -> bool; } -impl CodeChallengeMethod { - #[must_use] - pub fn compute_challenge(self, verifier: &str) -> Cow<'_, str> { +impl CodeChallengeMethodExt for PkceCodeChallengeMethod { + fn compute_challenge(self, verifier: &str) -> Cow<'_, str> { match self { - CodeChallengeMethod::Plain => verifier.into(), - CodeChallengeMethod::S256 => { + Self::Plain => verifier.into(), + Self::S256 => { let mut hasher = Sha256::new(); hasher.update(verifier.as_bytes()); let hash = hasher.finalize(); @@ -58,15 +41,14 @@ impl CodeChallengeMethod { } } - #[must_use] - pub fn verify(self, challenge: &str, verifier: &str) -> bool { + fn verify(self, challenge: &str, verifier: &str) -> bool { self.compute_challenge(verifier) == challenge } } #[derive(Serialize, Deserialize)] pub struct AuthorizationRequest { - pub code_challenge_method: CodeChallengeMethod, + pub code_challenge_method: PkceCodeChallengeMethod, pub code_challenge: String, } diff --git a/crates/oauth2-types/src/requests.rs b/crates/oauth2-types/src/requests.rs index a81d3c46..71f97a20 100644 --- a/crates/oauth2-types/src/requests.rs +++ b/crates/oauth2-types/src/requests.rs @@ -16,6 +16,9 @@ use std::{collections::HashSet, hash::Hash, num::NonZeroU32}; use chrono::{DateTime, Duration, Utc}; use language_tags::LanguageTag; +use mas_iana::oauth::{ + OAuthAccessTokenType, OAuthAuthorizationEndpointResponseType, OAuthTokenTypeHint, +}; use parse_display::{Display, FromStr}; use serde::{Deserialize, Serialize}; use serde_with::{ @@ -28,29 +31,6 @@ use crate::scope::Scope; // ref: https://www.iana.org/assignments/oauth-parameters/oauth-parameters.xhtml -#[derive( - Debug, - Hash, - PartialEq, - Eq, - PartialOrd, - Ord, - Clone, - Copy, - Display, - FromStr, - Serialize, - Deserialize, -)] -#[display(style = "snake_case")] -#[serde(rename_all = "snake_case")] -pub enum ResponseType { - Code, - IdToken, - Token, - None, -} - #[derive( Debug, Hash, @@ -72,37 +52,6 @@ pub enum ResponseMode { FormPost, } -#[derive( - Debug, - Hash, - PartialEq, - Eq, - PartialOrd, - Ord, - Clone, - Copy, - Display, - FromStr, - Serialize, - Deserialize, -)] -#[serde(rename_all = "snake_case")] -pub enum ClientAuthenticationMethod { - None, - ClientSecretPost, - ClientSecretBasic, - ClientSecretJwt, - PrivateKeyJwt, -} - -impl ClientAuthenticationMethod { - #[must_use] - /// Check if the authentication method is for public client or not - pub fn public(&self) -> bool { - matches!(self, &Self::None) - } -} - #[derive( Debug, Hash, @@ -151,8 +100,7 @@ pub enum Prompt { #[serde_as] #[derive(Serialize, Deserialize)] pub struct AuthorizationRequest { - #[serde_as(as = "StringWithSeparator::")] - pub response_type: HashSet, + pub response_type: OAuthAuthorizationEndpointResponseType, pub client_id: String, @@ -200,25 +148,6 @@ pub struct AuthorizationResponse { pub response: R, } -#[derive( - Debug, - Hash, - PartialEq, - Eq, - PartialOrd, - Ord, - Clone, - Copy, - Display, - FromStr, - Serialize, - Deserialize, -)] -#[serde(rename_all = "snake_case")] -pub enum TokenType { - Bearer, -} - #[skip_serializing_none] #[derive(Serialize, Deserialize, Debug, PartialEq)] pub struct AuthorizationCodeGrant { @@ -285,7 +214,7 @@ pub struct AccessTokenResponse { // TODO: this should be somewhere else id_token: Option, - token_type: TokenType, + token_type: OAuthAccessTokenType, #[serde_as(as = "Option>")] expires_in: Option, @@ -300,7 +229,7 @@ impl AccessTokenResponse { access_token, refresh_token: None, id_token: None, - token_type: TokenType::Bearer, + token_type: OAuthAccessTokenType::Bearer, expires_in: None, scope: None, } @@ -331,20 +260,13 @@ impl AccessTokenResponse { } } -#[derive(Serialize, Deserialize, Debug, PartialEq)] -#[serde(rename_all = "snake_case")] -pub enum TokenTypeHint { - AccessToken, - RefreshToken, -} - #[skip_serializing_none] #[derive(Serialize, Deserialize, Debug, PartialEq)] pub struct IntrospectionRequest { pub token: String, #[serde(default)] - pub token_type_hint: Option, + pub token_type_hint: Option, } #[serde_as] @@ -359,7 +281,7 @@ pub struct IntrospectionResponse { pub username: Option, - pub token_type: Option, + pub token_type: Option, #[serde_as(as = "Option")] pub exp: Option>, diff --git a/crates/storage/Cargo.toml b/crates/storage/Cargo.toml index 2840438c..4909af26 100644 --- a/crates/storage/Cargo.toml +++ b/crates/storage/Cargo.toml @@ -23,3 +23,4 @@ url = { version = "2.2.2", features = ["serde"] } oauth2-types = { path = "../oauth2-types" } mas-data-model = { path = "../data-model" } +mas-iana = { path = "../iana" } diff --git a/crates/storage/src/oauth2/authorization_grant.rs b/crates/storage/src/oauth2/authorization_grant.rs index e996308e..54dc92a1 100644 --- a/crates/storage/src/oauth2/authorization_grant.rs +++ b/crates/storage/src/oauth2/authorization_grant.rs @@ -22,7 +22,8 @@ use mas_data_model::{ Authentication, AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, BrowserSession, Client, Pkce, Session, User, }; -use oauth2_types::{pkce::CodeChallengeMethod, requests::ResponseMode, scope::Scope}; +use mas_iana::oauth::PkceCodeChallengeMethod; +use oauth2_types::{requests::ResponseMode, scope::Scope}; use sqlx::PgExecutor; use url::Url; @@ -237,12 +238,12 @@ impl TryInto> for GrantLookup { let pkce = match (self.grant_code_challenge, self.grant_code_challenge_method) { (Some(challenge), Some(challenge_method)) if challenge_method == "plain" => { Some(Pkce { - challenge_method: CodeChallengeMethod::Plain, + challenge_method: PkceCodeChallengeMethod::Plain, challenge, }) } (Some(challenge), Some(challenge_method)) if challenge_method == "S256" => Some(Pkce { - challenge_method: CodeChallengeMethod::S256, + challenge_method: PkceCodeChallengeMethod::S256, challenge, }), (None, None) => None, diff --git a/crates/warp-utils/Cargo.toml b/crates/warp-utils/Cargo.toml index d5ea353c..cec59cc2 100644 --- a/crates/warp-utils/Cargo.toml +++ b/crates/warp-utils/Cargo.toml @@ -35,3 +35,4 @@ mas-templates = { path = "../templates" } mas-data-model = { path = "../data-model" } mas-storage = { path = "../storage" } mas-jose = { path = "../jose" } +mas-iana = { path = "../iana" } diff --git a/crates/warp-utils/src/filters/client.rs b/crates/warp-utils/src/filters/client.rs index d07041af..a73e4eba 100644 --- a/crates/warp-utils/src/filters/client.rs +++ b/crates/warp-utils/src/filters/client.rs @@ -18,11 +18,11 @@ use std::collections::HashMap; use headers::{authorization::Basic, Authorization}; use mas_config::{OAuth2ClientAuthMethodConfig, OAuth2ClientConfig, OAuth2Config}; +use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_jose::{ claims::{TimeOptions, AUD, EXP, IAT, ISS, JTI, NBF, SUB}, DecodedJsonWebToken, JsonWebTokenParts, SharedSecret, }; -use oauth2_types::requests::ClientAuthenticationMethod; use serde::{de::DeserializeOwned, Deserialize}; use thiserror::Error; use warp::{reject::Reject, Filter, Rejection}; @@ -35,7 +35,7 @@ use crate::errors::WrapError; pub fn client_authentication( oauth2_config: &OAuth2Config, audience: String, -) -> impl Filter +) -> impl Filter + Clone + Send + Sync @@ -99,7 +99,7 @@ async fn authenticate_client( audience: String, credentials: ClientCredentials, body: T, -) -> Result<(ClientAuthenticationMethod, OAuth2ClientConfig, T), Rejection> { +) -> Result<(OAuthClientAuthenticationMethod, OAuth2ClientConfig, T), Rejection> { let (auth_method, client) = match credentials { ClientCredentials::Pair { client_id, @@ -114,7 +114,9 @@ async fn authenticate_client( })?; let auth_method = match (&client.client_auth_method, client_secret, via) { - (OAuth2ClientAuthMethodConfig::None, None, _) => ClientAuthenticationMethod::None, + (OAuth2ClientAuthMethodConfig::None, None, _) => { + OAuthClientAuthenticationMethod::None + } ( OAuth2ClientAuthMethodConfig::ClientSecretBasic { @@ -129,7 +131,7 @@ async fn authenticate_client( ); } - ClientAuthenticationMethod::ClientSecretBasic + OAuthClientAuthenticationMethod::ClientSecretBasic } ( @@ -145,7 +147,7 @@ async fn authenticate_client( ); } - ClientAuthenticationMethod::ClientSecretPost + OAuthClientAuthenticationMethod::ClientSecretPost } _ => { @@ -204,13 +206,13 @@ async fn authenticate_client( OAuth2ClientAuthMethodConfig::PrivateKeyJwt(jwks) => { let store = jwks.key_store(); token.verify(&decoded, &store).await.wrap_error()?; - ClientAuthenticationMethod::PrivateKeyJwt + OAuthClientAuthenticationMethod::PrivateKeyJwt } OAuth2ClientAuthMethodConfig::ClientSecretJwt { client_secret } => { let store = SharedSecret::new(client_secret); token.verify(&decoded, &store).await.wrap_error()?; - ClientAuthenticationMethod::ClientSecretJwt + OAuthClientAuthenticationMethod::ClientSecretJwt } _ => { @@ -428,7 +430,7 @@ mod tests { .await .unwrap(); - assert_eq!(auth, ClientAuthenticationMethod::ClientSecretJwt); + assert_eq!(auth, OAuthClientAuthenticationMethod::ClientSecretJwt); assert_eq!(client.client_id, "secret-jwt"); assert_eq!(body.foo, "baz"); assert_eq!(body.bar, "foobar"); @@ -515,7 +517,7 @@ mod tests { .await .unwrap(); - assert_eq!(auth, ClientAuthenticationMethod::PrivateKeyJwt); + assert_eq!(auth, OAuthClientAuthenticationMethod::PrivateKeyJwt); assert_eq!(client.client_id, "private-key-jwt"); assert_eq!(body.foo, "baz"); assert_eq!(body.bar, "foobar"); @@ -575,7 +577,7 @@ mod tests { .await .unwrap(); - assert_eq!(auth, ClientAuthenticationMethod::ClientSecretPost); + assert_eq!(auth, OAuthClientAuthenticationMethod::ClientSecretPost); assert_eq!(client.client_id, "secret-post"); assert_eq!(body.foo, "baz"); assert_eq!(body.bar, "foobar"); @@ -607,7 +609,7 @@ mod tests { .await .unwrap(); - assert_eq!(auth, ClientAuthenticationMethod::ClientSecretBasic); + assert_eq!(auth, OAuthClientAuthenticationMethod::ClientSecretBasic); assert_eq!(client.client_id, "secret-basic"); assert_eq!(body.foo, "baz"); assert_eq!(body.bar, "foobar"); @@ -638,7 +640,7 @@ mod tests { .await .unwrap(); - assert_eq!(auth, ClientAuthenticationMethod::None); + assert_eq!(auth, OAuthClientAuthenticationMethod::None); assert_eq!(client.client_id, "public"); assert_eq!(body.foo, "baz"); assert_eq!(body.bar, "foobar");