1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Use iana generated types in more places

This commit is contained in:
Quentin Gliech
2022-01-12 12:22:54 +01:00
parent 2844706bb1
commit 5b9c35a079
20 changed files with 222 additions and 211 deletions

3
Cargo.lock generated
View File

@ -1548,6 +1548,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"chrono", "chrono",
"crc", "crc",
"mas-iana",
"oauth2-types", "oauth2-types",
"rand", "rand",
"serde", "serde",
@ -1670,6 +1671,7 @@ dependencies = [
"argon2", "argon2",
"chrono", "chrono",
"mas-data-model", "mas-data-model",
"mas-iana",
"oauth2-types", "oauth2-types",
"password-hash", "password-hash",
"rand", "rand",
@ -1729,6 +1731,7 @@ dependencies = [
"hyper", "hyper",
"mas-config", "mas-config",
"mas-data-model", "mas-data-model",
"mas-iana",
"mas-jose", "mas-jose",
"mas-storage", "mas-storage",
"mas-templates", "mas-templates",

View File

@ -13,4 +13,5 @@ url = { version = "2.2.2", features = ["serde"] }
crc = "2.1.0" crc = "2.1.0"
rand = "0.8.4" rand = "0.8.4"
mas-iana = { path = "../iana" }
oauth2-types = { path = "../oauth2-types" } oauth2-types = { path = "../oauth2-types" }

View File

@ -15,7 +15,8 @@
use std::num::NonZeroU32; use std::num::NonZeroU32;
use chrono::{DateTime, Duration, Utc}; use chrono::{DateTime, Duration, Utc};
use oauth2_types::{pkce::CodeChallengeMethod, requests::ResponseMode}; use mas_iana::oauth::PkceCodeChallengeMethod;
use oauth2_types::{pkce::CodeChallengeMethodExt, requests::ResponseMode};
use serde::Serialize; use serde::Serialize;
use thiserror::Error; use thiserror::Error;
use url::Url; use url::Url;
@ -25,13 +26,13 @@ use crate::{traits::StorageBackend, StorageBackendMarker};
#[derive(Debug, Clone, PartialEq, Eq, Serialize)] #[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct Pkce { pub struct Pkce {
pub challenge_method: CodeChallengeMethod, pub challenge_method: PkceCodeChallengeMethod,
pub challenge: String, pub challenge: String,
} }
impl Pkce { impl Pkce {
#[must_use] #[must_use]
pub fn new(challenge_method: CodeChallengeMethod, challenge: String) -> Self { pub fn new(challenge_method: PkceCodeChallengeMethod, challenge: String) -> Self {
Pkce { Pkce {
challenge_method, challenge_method,
challenge, challenge,

View File

@ -14,7 +14,7 @@
use chrono::{DateTime, Duration, Utc}; use chrono::{DateTime, Duration, Utc};
use crc::{Crc, CRC_32_ISO_HDLC}; use crc::{Crc, CRC_32_ISO_HDLC};
use oauth2_types::requests::TokenTypeHint; use mas_iana::oauth::OAuthTokenTypeHint;
use rand::{distributions::Alphanumeric, Rng}; use rand::{distributions::Alphanumeric, Rng};
use thiserror::Error; use thiserror::Error;
@ -159,12 +159,12 @@ impl TokenType {
} }
} }
impl PartialEq<TokenTypeHint> for TokenType { impl PartialEq<OAuthTokenTypeHint> for TokenType {
fn eq(&self, other: &TokenTypeHint) -> bool { fn eq(&self, other: &OAuthTokenTypeHint) -> bool {
matches!( matches!(
(self, other), (self, other),
(TokenType::AccessToken, TokenTypeHint::AccessToken) (TokenType::AccessToken, OAuthTokenTypeHint::AccessToken)
| (TokenType::RefreshToken, TokenTypeHint::RefreshToken) | (TokenType::RefreshToken, OAuthTokenTypeHint::RefreshToken)
) )
} }
} }

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::collections::{HashMap, HashSet}; use std::collections::HashMap;
use chrono::Duration; use chrono::Duration;
use hyper::{ use hyper::{
@ -25,6 +25,7 @@ use mas_data_model::{
Authentication, AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, BrowserSession, Authentication, AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, BrowserSession,
Pkce, StorageBackend, TokenType, Pkce, StorageBackend, TokenType,
}; };
use mas_iana::oauth::OAuthAuthorizationEndpointResponseType;
use mas_storage::{ use mas_storage::{
oauth2::{ oauth2::{
access_token::add_access_token, access_token::add_access_token,
@ -50,9 +51,9 @@ use oauth2_types::{
RegistrationNotSupported, RequestNotSupported, RequestUriNotSupported, RegistrationNotSupported, RequestNotSupported, RequestUriNotSupported,
}, },
pkce, pkce,
prelude::*,
requests::{ requests::{
AccessTokenResponse, AuthorizationRequest, AuthorizationResponse, Prompt, ResponseMode, AccessTokenResponse, AuthorizationRequest, AuthorizationResponse, Prompt, ResponseMode,
ResponseType,
}, },
scope::ScopeToken, scope::ScopeToken,
}; };
@ -191,16 +192,15 @@ struct Params {
/// figure out what response mode must be used, and emit an error if the /// figure out what response mode must be used, and emit an error if the
/// suggested response mode isn't allowed for the given response types. /// suggested response mode isn't allowed for the given response types.
fn resolve_response_mode( fn resolve_response_mode(
response_type: &HashSet<ResponseType>, response_type: OAuthAuthorizationEndpointResponseType,
suggested_response_mode: Option<ResponseMode>, suggested_response_mode: Option<ResponseMode>,
) -> anyhow::Result<ResponseMode> { ) -> anyhow::Result<ResponseMode> {
use ResponseMode as M; use ResponseMode as M;
use ResponseType as T;
// If the response type includes either "token" or "id_token", the default // If the response type includes either "token" or "id_token", the default
// response mode is "fragment" and the response mode "query" must not be // response mode is "fragment" and the response mode "query" must not be
// used // used
if response_type.contains(&T::Token) || response_type.contains(&T::IdToken) { if response_type.has_token() || response_type.has_id_token() {
match suggested_response_mode { match suggested_response_mode {
None => Ok(M::Fragment), None => Ok(M::Fragment),
Some(M::Query) => Err(anyhow::anyhow!("invalid response mode")), Some(M::Query) => Err(anyhow::anyhow!("invalid response mode")),
@ -345,11 +345,11 @@ async fn get(
let redirect_uri = client let redirect_uri = client
.resolve_redirect_uri(&params.auth.redirect_uri) .resolve_redirect_uri(&params.auth.redirect_uri)
.wrap_error()?; .wrap_error()?;
let response_type = &params.auth.response_type; let response_type = params.auth.response_type;
let response_mode = let response_mode =
resolve_response_mode(response_type, params.auth.response_mode).wrap_error()?; resolve_response_mode(response_type, params.auth.response_mode).wrap_error()?;
let code: Option<AuthorizationCode> = if response_type.contains(&ResponseType::Code) { let code: Option<AuthorizationCode> = if response_type.has_code() {
// 32 random alphanumeric characters, about 190bit of entropy // 32 random alphanumeric characters, about 190bit of entropy
let code: String = thread_rng() let code: String = thread_rng()
.sample_iter(&Alphanumeric) .sample_iter(&Alphanumeric)
@ -400,8 +400,8 @@ async fn get(
params.auth.max_age, params.auth.max_age,
None, None,
response_mode, response_mode,
response_type.contains(&ResponseType::Token), response_type.has_token(),
response_type.contains(&ResponseType::IdToken), response_type.has_id_token(),
) )
.await .await
.wrap_error()?; .wrap_error()?;

View File

@ -15,12 +15,17 @@
use std::collections::HashSet; use std::collections::HashSet;
use mas_config::OAuth2Config; use mas_config::OAuth2Config;
use mas_iana::jose::JsonWebSignatureAlg; use mas_iana::{
jose::JsonWebSignatureAlg,
oauth::{
OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod,
PkceCodeChallengeMethod,
},
};
use mas_jose::SigningKeystore; use mas_jose::SigningKeystore;
use oauth2_types::{ use oauth2_types::{
oidc::{ClaimType, Metadata, SubjectType}, oidc::{ClaimType, Metadata, SubjectType},
pkce::CodeChallengeMethod, requests::{Display, GrantType, ResponseMode},
requests::{ClientAuthenticationMethod, Display, GrantType, ResponseMode},
}; };
use warp::{filters::BoxedFilter, Filter, Reply}; use warp::{filters::BoxedFilter, Filter, Reply};
@ -34,11 +39,11 @@ pub(super) fn filter(
// This is how clients can authenticate // This is how clients can authenticate
let client_auth_methods_supported = Some({ let client_auth_methods_supported = Some({
let mut s = HashSet::new(); let mut s = HashSet::new();
s.insert(ClientAuthenticationMethod::ClientSecretBasic); s.insert(OAuthClientAuthenticationMethod::ClientSecretBasic);
s.insert(ClientAuthenticationMethod::ClientSecretPost); s.insert(OAuthClientAuthenticationMethod::ClientSecretPost);
s.insert(ClientAuthenticationMethod::ClientSecretJwt); s.insert(OAuthClientAuthenticationMethod::ClientSecretJwt);
s.insert(ClientAuthenticationMethod::PrivateKeyJwt); s.insert(OAuthClientAuthenticationMethod::PrivateKeyJwt);
s.insert(ClientAuthenticationMethod::None); s.insert(OAuthClientAuthenticationMethod::None);
s s
}); });
@ -72,13 +77,13 @@ pub(super) fn filter(
let response_types_supported = Some({ let response_types_supported = Some({
let mut s = HashSet::new(); let mut s = HashSet::new();
s.insert("code".to_string()); s.insert(OAuthAuthorizationEndpointResponseType::Code);
s.insert("token".to_string()); s.insert(OAuthAuthorizationEndpointResponseType::Token);
s.insert("id_token".to_string()); s.insert(OAuthAuthorizationEndpointResponseType::IdToken);
s.insert("code token".to_string()); s.insert(OAuthAuthorizationEndpointResponseType::CodeToken);
s.insert("code id_token".to_string()); s.insert(OAuthAuthorizationEndpointResponseType::CodeIdToken);
s.insert("token id_token".to_string()); s.insert(OAuthAuthorizationEndpointResponseType::IdTokenToken);
s.insert("code token id_token".to_string()); s.insert(OAuthAuthorizationEndpointResponseType::CodeIdToken);
s s
}); });
@ -107,8 +112,8 @@ pub(super) fn filter(
let code_challenge_methods_supported = Some({ let code_challenge_methods_supported = Some({
let mut s = HashSet::new(); let mut s = HashSet::new();
s.insert(CodeChallengeMethod::Plain); s.insert(PkceCodeChallengeMethod::Plain);
s.insert(CodeChallengeMethod::S256); s.insert(PkceCodeChallengeMethod::S256);
s s
}); });

View File

@ -14,6 +14,7 @@
use mas_config::{OAuth2ClientConfig, OAuth2Config}; use mas_config::{OAuth2ClientConfig, OAuth2Config};
use mas_data_model::TokenType; use mas_data_model::TokenType;
use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint};
use mas_storage::oauth2::{ use mas_storage::oauth2::{
access_token::lookup_active_access_token, refresh_token::lookup_active_refresh_token, access_token::lookup_active_access_token, refresh_token::lookup_active_refresh_token,
}; };
@ -21,9 +22,7 @@ use mas_warp_utils::{
errors::WrapError, errors::WrapError,
filters::{client::client_authentication, database::connection}, filters::{client::client_authentication, database::connection},
}; };
use oauth2_types::requests::{ use oauth2_types::requests::{IntrospectionRequest, IntrospectionResponse};
ClientAuthenticationMethod, IntrospectionRequest, IntrospectionResponse, TokenTypeHint,
};
use sqlx::{pool::PoolConnection, PgPool, Postgres}; use sqlx::{pool::PoolConnection, PgPool, Postgres};
use tracing::{info, warn}; use tracing::{info, warn};
use warp::{filters::BoxedFilter, Filter, Rejection, Reply}; use warp::{filters::BoxedFilter, Filter, Rejection, Reply};
@ -64,12 +63,12 @@ const INACTIVE: IntrospectionResponse = IntrospectionResponse {
async fn introspect( async fn introspect(
mut conn: PoolConnection<Postgres>, mut conn: PoolConnection<Postgres>,
auth: ClientAuthenticationMethod, auth: OAuthClientAuthenticationMethod,
client: OAuth2ClientConfig, client: OAuth2ClientConfig,
params: IntrospectionRequest, params: IntrospectionRequest,
) -> Result<Box<dyn Reply>, Rejection> { ) -> Result<Box<dyn Reply>, Rejection> {
// Token introspection is only allowed by confidential clients // Token introspection is only allowed by confidential clients
if auth.public() { if auth == OAuthClientAuthenticationMethod::None {
warn!(?client, "Client tried to introspect"); warn!(?client, "Client tried to introspect");
// TODO: have a nice error here // TODO: have a nice error here
return Ok(Box::new(warp::reply::json(&INACTIVE))); return Ok(Box::new(warp::reply::json(&INACTIVE)));
@ -96,7 +95,7 @@ async fn introspect(
scope: Some(session.scope), scope: Some(session.scope),
client_id: Some(session.client.client_id), client_id: Some(session.client.client_id),
username: Some(session.browser_session.user.username), username: Some(session.browser_session.user.username),
token_type: Some(TokenTypeHint::AccessToken), token_type: Some(OAuthTokenTypeHint::AccessToken),
exp: Some(exp), exp: Some(exp),
iat: Some(token.created_at), iat: Some(token.created_at),
nbf: Some(token.created_at), nbf: Some(token.created_at),
@ -116,7 +115,7 @@ async fn introspect(
scope: Some(session.scope), scope: Some(session.scope),
client_id: Some(session.client.client_id), client_id: Some(session.client.client_id),
username: Some(session.browser_session.user.username), username: Some(session.browser_session.user.username),
token_type: Some(TokenTypeHint::RefreshToken), token_type: Some(OAuthTokenTypeHint::RefreshToken),
exp: None, exp: None,
iat: Some(token.created_at), iat: Some(token.created_at),
nbf: Some(token.created_at), nbf: Some(token.created_at),

View File

@ -21,7 +21,7 @@ use headers::{CacheControl, Pragma};
use hyper::StatusCode; use hyper::StatusCode;
use mas_config::{OAuth2ClientConfig, OAuth2Config}; use mas_config::{OAuth2ClientConfig, OAuth2Config};
use mas_data_model::{AuthorizationGrantStage, TokenType}; use mas_data_model::{AuthorizationGrantStage, TokenType};
use mas_iana::jose::JsonWebSignatureAlg; use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use mas_jose::{ use mas_jose::{
claims::{AT_HASH, AUD, AUTH_TIME, C_HASH, EXP, IAT, ISS, NONCE, SUB}, claims::{AT_HASH, AUD, AUTH_TIME, C_HASH, EXP, IAT, ISS, NONCE, SUB},
DecodedJsonWebToken, SigningKeystore, StaticKeystore, DecodedJsonWebToken, SigningKeystore, StaticKeystore,
@ -42,8 +42,7 @@ use mas_warp_utils::{
use oauth2_types::{ use oauth2_types::{
errors::{InvalidGrant, InvalidRequest, OAuth2Error, OAuth2ErrorCode, UnauthorizedClient}, errors::{InvalidGrant, InvalidRequest, OAuth2Error, OAuth2ErrorCode, UnauthorizedClient},
requests::{ requests::{
AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, RefreshTokenGrant,
ClientAuthenticationMethod, RefreshTokenGrant,
}, },
scope::OPENID, scope::OPENID,
}; };
@ -131,7 +130,7 @@ async fn recover(rejection: Rejection) -> Result<Box<dyn Reply>, Rejection> {
} }
async fn token( async fn token(
_auth: ClientAuthenticationMethod, _auth: OAuthClientAuthenticationMethod,
client: OAuth2ClientConfig, client: OAuth2ClientConfig,
req: AccessTokenRequest, req: AccessTokenRequest,
key_store: Arc<StaticKeystore>, key_store: Arc<StaticKeystore>,

View File

@ -182,10 +182,12 @@ async fn generate_oauth(client: &Arc<Client>, path: PathBuf) -> anyhow::Result<(
"https://www.iana.org/assignments/jose/jose.xhtml", "https://www.iana.org/assignments/jose/jose.xhtml",
client.clone(), client.clone(),
) )
.load::<TokenTypeHint>() .load::<AccessTokenType>()
.await? .await?
.load::<AuthorizationEndpointResponseType>() .load::<AuthorizationEndpointResponseType>()
.await? .await?
.load::<TokenTypeHint>()
.await?
.load::<TokenEndpointAuthenticationMethod>() .load::<TokenEndpointAuthenticationMethod>()
.await? .await?
.load::<PkceCodeChallengeMethod>() .load::<PkceCodeChallengeMethod>()

View File

@ -21,22 +21,25 @@ use crate::{
#[allow(dead_code)] #[allow(dead_code)]
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct TokenTypeHint { pub struct AccessTokenType {
#[serde(rename = "Hint Value")] #[serde(rename = "Name")]
name: String, name: String,
#[serde(rename = "Additional Token Endpoint Response Parameters")]
additional_parameters: String,
#[serde(rename = "HTTP Authentication Scheme(s)")]
http_schemes: String,
#[serde(rename = "Change Controller")] #[serde(rename = "Change Controller")]
change_controller: String, change_controller: String,
#[serde(rename = "Reference")] #[serde(rename = "Reference")]
reference: String, reference: String,
} }
impl EnumEntry for TokenTypeHint { impl EnumEntry for AccessTokenType {
const URL: &'static str = const URL: &'static str = "https://www.iana.org/assignments/oauth-parameters/token-types.csv";
"https://www.iana.org/assignments/oauth-parameters/token-type-hint.csv"; const SECTIONS: &'static [Section] = &[s("OAuthAccessTokenType", "OAuth Access Token Type")];
const SECTIONS: &'static [Section] = &[s("OAuthTokenTypeHint", "OAuth Token Type Hint")];
fn key(&self) -> Option<&'static str> { fn key(&self) -> Option<&'static str> {
Some("OAuthTokenTypeHint") Some("OAuthAccessTokenType")
} }
fn name(&self) -> &str { fn name(&self) -> &str {
@ -82,16 +85,41 @@ pub struct TokenEndpointAuthenticationMethod {
reference: String, reference: String,
} }
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
pub struct TokenTypeHint {
#[serde(rename = "Hint Value")]
name: String,
#[serde(rename = "Change Controller")]
change_controller: String,
#[serde(rename = "Reference")]
reference: String,
}
impl EnumEntry for TokenTypeHint {
const URL: &'static str =
"https://www.iana.org/assignments/oauth-parameters/token-type-hint.csv";
const SECTIONS: &'static [Section] = &[s("OAuthTokenTypeHint", "OAuth Token Type Hint")];
fn key(&self) -> Option<&'static str> {
Some("OAuthTokenTypeHint")
}
fn name(&self) -> &str {
&self.name
}
}
impl EnumEntry for TokenEndpointAuthenticationMethod { impl EnumEntry for TokenEndpointAuthenticationMethod {
const URL: &'static str = const URL: &'static str =
"https://www.iana.org/assignments/oauth-parameters/token-endpoint-auth-method.csv"; "https://www.iana.org/assignments/oauth-parameters/token-endpoint-auth-method.csv";
const SECTIONS: &'static [Section] = &[s( const SECTIONS: &'static [Section] = &[s(
"OAuthTokenEndpointAuthenticationMethod", "OAuthClientAuthenticationMethod",
"OAuth Token Endpoint Authentication Method", "OAuth Token Endpoint Authentication Method",
)]; )];
fn key(&self) -> Option<&'static str> { fn key(&self) -> Option<&'static str> {
Some("OAuthTokenEndpointAuthenticationMethod") Some("OAuthClientAuthenticationMethod")
} }
fn name(&self) -> &str { fn name(&self) -> &str {

View File

@ -61,7 +61,11 @@ pub trait EnumEntry: DeserializeOwned + Send + Sync {
None None
} }
fn enum_name(&self) -> String { fn enum_name(&self) -> String {
self.name().replace('+', "_").to_case(Case::Pascal) // Do the case transformation twice to have "N_A" turned to "Na" instead of "NA"
self.name()
.replace('+', "_")
.to_case(Case::Pascal)
.to_case(Case::Pascal)
} }
async fn fetch(client: &Client) -> anyhow::Result<Vec<(&'static str, EnumMember)>> { async fn fetch(client: &Client) -> anyhow::Result<Vec<(&'static str, EnumMember)>> {

View File

@ -21,24 +21,24 @@ use parse_display::{Display, FromStr};
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
/// OAuth Token Type Hint /// OAuth Access Token Type
/// ///
/// Source: <https://www.iana.org/assignments/oauth-parameters/token-type-hint.csv> /// Source: <https://www.iana.org/assignments/oauth-parameters/token-types.csv>
#[derive( #[derive(
Debug, Clone, Copy, PartialEq, Eq, Hash, Display, FromStr, Serialize, Deserialize, JsonSchema, Debug, Clone, Copy, PartialEq, Eq, Hash, Display, FromStr, Serialize, Deserialize, JsonSchema,
)] )]
pub enum OAuthTokenTypeHint { pub enum OAuthAccessTokenType {
#[serde(rename = "access_token")] #[serde(rename = "Bearer")]
#[display("access_token")] #[display("Bearer")]
AccessToken, Bearer,
#[serde(rename = "refresh_token")] #[serde(rename = "N_A")]
#[display("refresh_token")] #[display("N_A")]
RefreshToken, Na,
#[serde(rename = "pct")] #[serde(rename = "PoP")]
#[display("pct")] #[display("PoP")]
Pct, PoP,
} }
/// OAuth Authorization Endpoint Response Type /// OAuth Authorization Endpoint Response Type
@ -81,13 +81,33 @@ pub enum OAuthAuthorizationEndpointResponseType {
Token, Token,
} }
/// OAuth Token Type Hint
///
/// Source: <https://www.iana.org/assignments/oauth-parameters/token-type-hint.csv>
#[derive(
Debug, Clone, Copy, PartialEq, Eq, Hash, Display, FromStr, Serialize, Deserialize, JsonSchema,
)]
pub enum OAuthTokenTypeHint {
#[serde(rename = "access_token")]
#[display("access_token")]
AccessToken,
#[serde(rename = "refresh_token")]
#[display("refresh_token")]
RefreshToken,
#[serde(rename = "pct")]
#[display("pct")]
Pct,
}
/// OAuth Token Endpoint Authentication Method /// OAuth Token Endpoint Authentication Method
/// ///
/// Source: <https://www.iana.org/assignments/oauth-parameters/token-endpoint-auth-method.csv> /// Source: <https://www.iana.org/assignments/oauth-parameters/token-endpoint-auth-method.csv>
#[derive( #[derive(
Debug, Clone, Copy, PartialEq, Eq, Hash, Display, FromStr, Serialize, Deserialize, JsonSchema, Debug, Clone, Copy, PartialEq, Eq, Hash, Display, FromStr, Serialize, Deserialize, JsonSchema,
)] )]
pub enum OAuthTokenEndpointAuthenticationMethod { pub enum OAuthClientAuthenticationMethod {
#[serde(rename = "none")] #[serde(rename = "none")]
#[display("none")] #[display("none")]
None, None,

View File

@ -16,11 +16,46 @@
#![deny(clippy::all)] #![deny(clippy::all)]
#![warn(clippy::pedantic)] #![warn(clippy::pedantic)]
use mas_iana::oauth::OAuthAuthorizationEndpointResponseType;
pub trait ResponseTypeExt {
fn has_code(&self) -> bool;
fn has_token(&self) -> bool;
fn has_id_token(&self) -> bool;
}
impl ResponseTypeExt for OAuthAuthorizationEndpointResponseType {
fn has_code(&self) -> bool {
matches!(
self,
Self::Code | Self::CodeToken | Self::CodeIdToken | Self::CodeIdTokenToken
)
}
fn has_token(&self) -> bool {
matches!(
self,
Self::Token | Self::CodeToken | Self::IdTokenToken | Self::CodeIdTokenToken
)
}
fn has_id_token(&self) -> bool {
matches!(
self,
Self::IdToken | Self::IdTokenToken | Self::CodeIdToken | Self::CodeIdTokenToken
)
}
}
pub mod errors; pub mod errors;
pub mod oidc; pub mod oidc;
pub mod pkce; pub mod pkce;
pub mod requests; pub mod requests;
pub mod scope; pub mod scope;
pub mod prelude {
pub use crate::{pkce::CodeChallengeMethodExt, ResponseTypeExt};
}
#[cfg(test)] #[cfg(test)]
mod test_utils; mod test_utils;

View File

@ -14,15 +14,18 @@
use std::collections::HashSet; use std::collections::HashSet;
use mas_iana::jose::{JsonWebEncryptionAlg, JsonWebEncryptionEnc, JsonWebSignatureAlg}; use mas_iana::{
jose::{JsonWebEncryptionAlg, JsonWebEncryptionEnc, JsonWebSignatureAlg},
oauth::{
OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod,
PkceCodeChallengeMethod,
},
};
use serde::Serialize; use serde::Serialize;
use serde_with::skip_serializing_none; use serde_with::skip_serializing_none;
use url::Url; use url::Url;
use crate::{ use crate::requests::{Display, GrantType, ResponseMode};
pkce::CodeChallengeMethod,
requests::{ClientAuthenticationMethod, Display, GrantType, ResponseMode},
};
#[derive(Serialize, Clone, Copy, PartialEq, Eq, Hash)] #[derive(Serialize, Clone, Copy, PartialEq, Eq, Hash)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
@ -66,7 +69,7 @@ pub struct Metadata {
/// JSON array containing a list of the OAuth 2.0 "response_type" values /// JSON array containing a list of the OAuth 2.0 "response_type" values
/// that this authorization server supports. /// that this authorization server supports.
pub response_types_supported: Option<HashSet<String>>, pub response_types_supported: Option<HashSet<OAuthAuthorizationEndpointResponseType>>,
/// JSON array containing a list of the OAuth 2.0 "response_mode" values /// JSON array containing a list of the OAuth 2.0 "response_mode" values
/// that this authorization server supports. /// that this authorization server supports.
@ -78,7 +81,7 @@ pub struct Metadata {
/// JSON array containing a list of client authentication methods supported /// JSON array containing a list of client authentication methods supported
/// by this token endpoint. /// by this token endpoint.
pub token_endpoint_auth_methods_supported: Option<HashSet<ClientAuthenticationMethod>>, pub token_endpoint_auth_methods_supported: Option<HashSet<OAuthClientAuthenticationMethod>>,
/// JSON array containing a list of the JWS signing algorithms supported by /// JSON array containing a list of the JWS signing algorithms supported by
/// the token endpoint for the signature on the JWT used to authenticate the /// the token endpoint for the signature on the JWT used to authenticate the
@ -109,7 +112,8 @@ pub struct Metadata {
/// JSON array containing a list of client authentication methods supported /// JSON array containing a list of client authentication methods supported
/// by this revocation endpoint. /// by this revocation endpoint.
pub revocation_endpoint_auth_methods_supported: Option<HashSet<ClientAuthenticationMethod>>, pub revocation_endpoint_auth_methods_supported:
Option<HashSet<OAuthClientAuthenticationMethod>>,
/// JSON array containing a list of the JWS signing algorithms supported by /// JSON array containing a list of the JWS signing algorithms supported by
/// the revocation endpoint for the signature on the JWT used to /// the revocation endpoint for the signature on the JWT used to
@ -121,7 +125,8 @@ pub struct Metadata {
/// JSON array containing a list of client authentication methods supported /// JSON array containing a list of client authentication methods supported
/// by this introspection endpoint. /// by this introspection endpoint.
pub introspection_endpoint_auth_methods_supported: Option<HashSet<ClientAuthenticationMethod>>, pub introspection_endpoint_auth_methods_supported:
Option<HashSet<OAuthClientAuthenticationMethod>>,
/// JSON array containing a list of the JWS signing algorithms supported by /// JSON array containing a list of the JWS signing algorithms supported by
/// the introspection endpoint for the signature on the JWT used to /// the introspection endpoint for the signature on the JWT used to
@ -130,7 +135,7 @@ pub struct Metadata {
Option<HashSet<JsonWebSignatureAlg>>, Option<HashSet<JsonWebSignatureAlg>>,
/// PKCE code challenge methods supported by this authorization server. /// PKCE code challenge methods supported by this authorization server.
pub code_challenge_methods_supported: Option<HashSet<CodeChallengeMethod>>, pub code_challenge_methods_supported: Option<HashSet<PkceCodeChallengeMethod>>,
/// URL of the OP's UserInfo Endpoint. /// URL of the OP's UserInfo Endpoint.
pub userinfo_endpoint: Option<Url>, pub userinfo_endpoint: Option<Url>,

View File

@ -15,40 +15,23 @@
use std::borrow::Cow; use std::borrow::Cow;
use data_encoding::BASE64URL_NOPAD; use data_encoding::BASE64URL_NOPAD;
use parse_display::{Display, FromStr}; use mas_iana::oauth::PkceCodeChallengeMethod;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256}; use sha2::{Digest, Sha256};
#[derive( pub trait CodeChallengeMethodExt {
Debug, #[must_use]
Hash, fn compute_challenge(self, verifier: &str) -> Cow<'_, str>;
PartialEq,
Eq,
PartialOrd,
Ord,
Clone,
Copy,
Display,
FromStr,
Serialize,
Deserialize,
)]
pub enum CodeChallengeMethod {
#[serde(rename = "plain")]
#[display("plain")]
Plain,
#[serde(rename = "S256")] #[must_use]
#[display("S256")] fn verify(self, challenge: &str, verifier: &str) -> bool;
S256,
} }
impl CodeChallengeMethod { impl CodeChallengeMethodExt for PkceCodeChallengeMethod {
#[must_use] fn compute_challenge(self, verifier: &str) -> Cow<'_, str> {
pub fn compute_challenge(self, verifier: &str) -> Cow<'_, str> {
match self { match self {
CodeChallengeMethod::Plain => verifier.into(), Self::Plain => verifier.into(),
CodeChallengeMethod::S256 => { Self::S256 => {
let mut hasher = Sha256::new(); let mut hasher = Sha256::new();
hasher.update(verifier.as_bytes()); hasher.update(verifier.as_bytes());
let hash = hasher.finalize(); let hash = hasher.finalize();
@ -58,15 +41,14 @@ impl CodeChallengeMethod {
} }
} }
#[must_use] fn verify(self, challenge: &str, verifier: &str) -> bool {
pub fn verify(self, challenge: &str, verifier: &str) -> bool {
self.compute_challenge(verifier) == challenge self.compute_challenge(verifier) == challenge
} }
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct AuthorizationRequest { pub struct AuthorizationRequest {
pub code_challenge_method: CodeChallengeMethod, pub code_challenge_method: PkceCodeChallengeMethod,
pub code_challenge: String, pub code_challenge: String,
} }

View File

@ -16,6 +16,9 @@ use std::{collections::HashSet, hash::Hash, num::NonZeroU32};
use chrono::{DateTime, Duration, Utc}; use chrono::{DateTime, Duration, Utc};
use language_tags::LanguageTag; use language_tags::LanguageTag;
use mas_iana::oauth::{
OAuthAccessTokenType, OAuthAuthorizationEndpointResponseType, OAuthTokenTypeHint,
};
use parse_display::{Display, FromStr}; use parse_display::{Display, FromStr};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_with::{ use serde_with::{
@ -28,29 +31,6 @@ use crate::scope::Scope;
// ref: https://www.iana.org/assignments/oauth-parameters/oauth-parameters.xhtml // ref: https://www.iana.org/assignments/oauth-parameters/oauth-parameters.xhtml
#[derive(
Debug,
Hash,
PartialEq,
Eq,
PartialOrd,
Ord,
Clone,
Copy,
Display,
FromStr,
Serialize,
Deserialize,
)]
#[display(style = "snake_case")]
#[serde(rename_all = "snake_case")]
pub enum ResponseType {
Code,
IdToken,
Token,
None,
}
#[derive( #[derive(
Debug, Debug,
Hash, Hash,
@ -72,37 +52,6 @@ pub enum ResponseMode {
FormPost, FormPost,
} }
#[derive(
Debug,
Hash,
PartialEq,
Eq,
PartialOrd,
Ord,
Clone,
Copy,
Display,
FromStr,
Serialize,
Deserialize,
)]
#[serde(rename_all = "snake_case")]
pub enum ClientAuthenticationMethod {
None,
ClientSecretPost,
ClientSecretBasic,
ClientSecretJwt,
PrivateKeyJwt,
}
impl ClientAuthenticationMethod {
#[must_use]
/// Check if the authentication method is for public client or not
pub fn public(&self) -> bool {
matches!(self, &Self::None)
}
}
#[derive( #[derive(
Debug, Debug,
Hash, Hash,
@ -151,8 +100,7 @@ pub enum Prompt {
#[serde_as] #[serde_as]
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct AuthorizationRequest { pub struct AuthorizationRequest {
#[serde_as(as = "StringWithSeparator::<SpaceSeparator, ResponseType>")] pub response_type: OAuthAuthorizationEndpointResponseType,
pub response_type: HashSet<ResponseType>,
pub client_id: String, pub client_id: String,
@ -200,25 +148,6 @@ pub struct AuthorizationResponse<R> {
pub response: R, pub response: R,
} }
#[derive(
Debug,
Hash,
PartialEq,
Eq,
PartialOrd,
Ord,
Clone,
Copy,
Display,
FromStr,
Serialize,
Deserialize,
)]
#[serde(rename_all = "snake_case")]
pub enum TokenType {
Bearer,
}
#[skip_serializing_none] #[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, PartialEq)] #[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct AuthorizationCodeGrant { pub struct AuthorizationCodeGrant {
@ -285,7 +214,7 @@ pub struct AccessTokenResponse {
// TODO: this should be somewhere else // TODO: this should be somewhere else
id_token: Option<String>, id_token: Option<String>,
token_type: TokenType, token_type: OAuthAccessTokenType,
#[serde_as(as = "Option<DurationSeconds<i64>>")] #[serde_as(as = "Option<DurationSeconds<i64>>")]
expires_in: Option<Duration>, expires_in: Option<Duration>,
@ -300,7 +229,7 @@ impl AccessTokenResponse {
access_token, access_token,
refresh_token: None, refresh_token: None,
id_token: None, id_token: None,
token_type: TokenType::Bearer, token_type: OAuthAccessTokenType::Bearer,
expires_in: None, expires_in: None,
scope: None, scope: None,
} }
@ -331,20 +260,13 @@ impl AccessTokenResponse {
} }
} }
#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum TokenTypeHint {
AccessToken,
RefreshToken,
}
#[skip_serializing_none] #[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, PartialEq)] #[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct IntrospectionRequest { pub struct IntrospectionRequest {
pub token: String, pub token: String,
#[serde(default)] #[serde(default)]
pub token_type_hint: Option<TokenTypeHint>, pub token_type_hint: Option<OAuthTokenTypeHint>,
} }
#[serde_as] #[serde_as]
@ -359,7 +281,7 @@ pub struct IntrospectionResponse {
pub username: Option<String>, pub username: Option<String>,
pub token_type: Option<TokenTypeHint>, pub token_type: Option<OAuthTokenTypeHint>,
#[serde_as(as = "Option<TimestampSeconds>")] #[serde_as(as = "Option<TimestampSeconds>")]
pub exp: Option<DateTime<Utc>>, pub exp: Option<DateTime<Utc>>,

View File

@ -23,3 +23,4 @@ url = { version = "2.2.2", features = ["serde"] }
oauth2-types = { path = "../oauth2-types" } oauth2-types = { path = "../oauth2-types" }
mas-data-model = { path = "../data-model" } mas-data-model = { path = "../data-model" }
mas-iana = { path = "../iana" }

View File

@ -22,7 +22,8 @@ use mas_data_model::{
Authentication, AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, BrowserSession, Authentication, AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, BrowserSession,
Client, Pkce, Session, User, Client, Pkce, Session, User,
}; };
use oauth2_types::{pkce::CodeChallengeMethod, requests::ResponseMode, scope::Scope}; use mas_iana::oauth::PkceCodeChallengeMethod;
use oauth2_types::{requests::ResponseMode, scope::Scope};
use sqlx::PgExecutor; use sqlx::PgExecutor;
use url::Url; use url::Url;
@ -237,12 +238,12 @@ impl TryInto<AuthorizationGrant<PostgresqlBackend>> for GrantLookup {
let pkce = match (self.grant_code_challenge, self.grant_code_challenge_method) { let pkce = match (self.grant_code_challenge, self.grant_code_challenge_method) {
(Some(challenge), Some(challenge_method)) if challenge_method == "plain" => { (Some(challenge), Some(challenge_method)) if challenge_method == "plain" => {
Some(Pkce { Some(Pkce {
challenge_method: CodeChallengeMethod::Plain, challenge_method: PkceCodeChallengeMethod::Plain,
challenge, challenge,
}) })
} }
(Some(challenge), Some(challenge_method)) if challenge_method == "S256" => Some(Pkce { (Some(challenge), Some(challenge_method)) if challenge_method == "S256" => Some(Pkce {
challenge_method: CodeChallengeMethod::S256, challenge_method: PkceCodeChallengeMethod::S256,
challenge, challenge,
}), }),
(None, None) => None, (None, None) => None,

View File

@ -35,3 +35,4 @@ mas-templates = { path = "../templates" }
mas-data-model = { path = "../data-model" } mas-data-model = { path = "../data-model" }
mas-storage = { path = "../storage" } mas-storage = { path = "../storage" }
mas-jose = { path = "../jose" } mas-jose = { path = "../jose" }
mas-iana = { path = "../iana" }

View File

@ -18,11 +18,11 @@ use std::collections::HashMap;
use headers::{authorization::Basic, Authorization}; use headers::{authorization::Basic, Authorization};
use mas_config::{OAuth2ClientAuthMethodConfig, OAuth2ClientConfig, OAuth2Config}; use mas_config::{OAuth2ClientAuthMethodConfig, OAuth2ClientConfig, OAuth2Config};
use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_jose::{ use mas_jose::{
claims::{TimeOptions, AUD, EXP, IAT, ISS, JTI, NBF, SUB}, claims::{TimeOptions, AUD, EXP, IAT, ISS, JTI, NBF, SUB},
DecodedJsonWebToken, JsonWebTokenParts, SharedSecret, DecodedJsonWebToken, JsonWebTokenParts, SharedSecret,
}; };
use oauth2_types::requests::ClientAuthenticationMethod;
use serde::{de::DeserializeOwned, Deserialize}; use serde::{de::DeserializeOwned, Deserialize};
use thiserror::Error; use thiserror::Error;
use warp::{reject::Reject, Filter, Rejection}; use warp::{reject::Reject, Filter, Rejection};
@ -35,7 +35,7 @@ use crate::errors::WrapError;
pub fn client_authentication<T: DeserializeOwned + Send + 'static>( pub fn client_authentication<T: DeserializeOwned + Send + 'static>(
oauth2_config: &OAuth2Config, oauth2_config: &OAuth2Config,
audience: String, audience: String,
) -> impl Filter<Extract = (ClientAuthenticationMethod, OAuth2ClientConfig, T), Error = Rejection> ) -> impl Filter<Extract = (OAuthClientAuthenticationMethod, OAuth2ClientConfig, T), Error = Rejection>
+ Clone + Clone
+ Send + Send
+ Sync + Sync
@ -99,7 +99,7 @@ async fn authenticate_client<T>(
audience: String, audience: String,
credentials: ClientCredentials, credentials: ClientCredentials,
body: T, body: T,
) -> Result<(ClientAuthenticationMethod, OAuth2ClientConfig, T), Rejection> { ) -> Result<(OAuthClientAuthenticationMethod, OAuth2ClientConfig, T), Rejection> {
let (auth_method, client) = match credentials { let (auth_method, client) = match credentials {
ClientCredentials::Pair { ClientCredentials::Pair {
client_id, client_id,
@ -114,7 +114,9 @@ async fn authenticate_client<T>(
})?; })?;
let auth_method = match (&client.client_auth_method, client_secret, via) { let auth_method = match (&client.client_auth_method, client_secret, via) {
(OAuth2ClientAuthMethodConfig::None, None, _) => ClientAuthenticationMethod::None, (OAuth2ClientAuthMethodConfig::None, None, _) => {
OAuthClientAuthenticationMethod::None
}
( (
OAuth2ClientAuthMethodConfig::ClientSecretBasic { OAuth2ClientAuthMethodConfig::ClientSecretBasic {
@ -129,7 +131,7 @@ async fn authenticate_client<T>(
); );
} }
ClientAuthenticationMethod::ClientSecretBasic OAuthClientAuthenticationMethod::ClientSecretBasic
} }
( (
@ -145,7 +147,7 @@ async fn authenticate_client<T>(
); );
} }
ClientAuthenticationMethod::ClientSecretPost OAuthClientAuthenticationMethod::ClientSecretPost
} }
_ => { _ => {
@ -204,13 +206,13 @@ async fn authenticate_client<T>(
OAuth2ClientAuthMethodConfig::PrivateKeyJwt(jwks) => { OAuth2ClientAuthMethodConfig::PrivateKeyJwt(jwks) => {
let store = jwks.key_store(); let store = jwks.key_store();
token.verify(&decoded, &store).await.wrap_error()?; token.verify(&decoded, &store).await.wrap_error()?;
ClientAuthenticationMethod::PrivateKeyJwt OAuthClientAuthenticationMethod::PrivateKeyJwt
} }
OAuth2ClientAuthMethodConfig::ClientSecretJwt { client_secret } => { OAuth2ClientAuthMethodConfig::ClientSecretJwt { client_secret } => {
let store = SharedSecret::new(client_secret); let store = SharedSecret::new(client_secret);
token.verify(&decoded, &store).await.wrap_error()?; token.verify(&decoded, &store).await.wrap_error()?;
ClientAuthenticationMethod::ClientSecretJwt OAuthClientAuthenticationMethod::ClientSecretJwt
} }
_ => { _ => {
@ -428,7 +430,7 @@ mod tests {
.await .await
.unwrap(); .unwrap();
assert_eq!(auth, ClientAuthenticationMethod::ClientSecretJwt); assert_eq!(auth, OAuthClientAuthenticationMethod::ClientSecretJwt);
assert_eq!(client.client_id, "secret-jwt"); assert_eq!(client.client_id, "secret-jwt");
assert_eq!(body.foo, "baz"); assert_eq!(body.foo, "baz");
assert_eq!(body.bar, "foobar"); assert_eq!(body.bar, "foobar");
@ -515,7 +517,7 @@ mod tests {
.await .await
.unwrap(); .unwrap();
assert_eq!(auth, ClientAuthenticationMethod::PrivateKeyJwt); assert_eq!(auth, OAuthClientAuthenticationMethod::PrivateKeyJwt);
assert_eq!(client.client_id, "private-key-jwt"); assert_eq!(client.client_id, "private-key-jwt");
assert_eq!(body.foo, "baz"); assert_eq!(body.foo, "baz");
assert_eq!(body.bar, "foobar"); assert_eq!(body.bar, "foobar");
@ -575,7 +577,7 @@ mod tests {
.await .await
.unwrap(); .unwrap();
assert_eq!(auth, ClientAuthenticationMethod::ClientSecretPost); assert_eq!(auth, OAuthClientAuthenticationMethod::ClientSecretPost);
assert_eq!(client.client_id, "secret-post"); assert_eq!(client.client_id, "secret-post");
assert_eq!(body.foo, "baz"); assert_eq!(body.foo, "baz");
assert_eq!(body.bar, "foobar"); assert_eq!(body.bar, "foobar");
@ -607,7 +609,7 @@ mod tests {
.await .await
.unwrap(); .unwrap();
assert_eq!(auth, ClientAuthenticationMethod::ClientSecretBasic); assert_eq!(auth, OAuthClientAuthenticationMethod::ClientSecretBasic);
assert_eq!(client.client_id, "secret-basic"); assert_eq!(client.client_id, "secret-basic");
assert_eq!(body.foo, "baz"); assert_eq!(body.foo, "baz");
assert_eq!(body.bar, "foobar"); assert_eq!(body.bar, "foobar");
@ -638,7 +640,7 @@ mod tests {
.await .await
.unwrap(); .unwrap();
assert_eq!(auth, ClientAuthenticationMethod::None); assert_eq!(auth, OAuthClientAuthenticationMethod::None);
assert_eq!(client.client_id, "public"); assert_eq!(client.client_id, "public");
assert_eq!(body.foo, "baz"); assert_eq!(body.foo, "baz");
assert_eq!(body.bar, "foobar"); assert_eq!(body.bar, "foobar");