1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Merge data structs and use builder pattern

Signed-off-by: Kévin Commaille <zecakeh@tedomum.fr>
This commit is contained in:
Kévin Commaille
2023-08-08 12:25:35 +02:00
committed by Quentin Gliech
parent c67a00ddd6
commit ba4ba75f73
3 changed files with 116 additions and 57 deletions

View File

@ -86,18 +86,17 @@ pub(crate) async fn get(
let redirect_uri = url_builder.upstream_oauth_callback(provider.id); let redirect_uri = url_builder.upstream_oauth_callback(provider.id);
let data = AuthorizationRequestData { let data = AuthorizationRequestData::new(
client_id: &provider.client_id, provider.client_id.clone(),
scope: &provider.scope, provider.scope.clone(),
redirect_uri: &redirect_uri, redirect_uri,
code_challenge_methods_supported: metadata.code_challenge_methods_supported.as_deref(), );
};
// Build an authorization request for it // Build an authorization request for it
let (url, data) = mas_oidc_client::requests::authorization_code::build_authorization_url( let (url, data) = mas_oidc_client::requests::authorization_code::build_authorization_url(
metadata.authorization_endpoint().clone(), metadata.authorization_endpoint().clone(),
data, data,
None, metadata.code_challenge_methods_supported.as_deref(),
&mut rng, &mut rng,
)?; )?;

View File

@ -59,29 +59,22 @@ use crate::{
}; };
/// The data necessary to build an authorization request. /// The data necessary to build an authorization request.
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone)]
pub struct AuthorizationRequestData<'a> { pub struct AuthorizationRequestData {
/// The ID obtained when registering the client. /// The ID obtained when registering the client.
pub client_id: &'a str, pub client_id: String,
/// The PKCE methods supported by the issuer, from its metadata.
pub code_challenge_methods_supported: Option<&'a [PkceCodeChallengeMethod]>,
/// The scope to authorize. /// The scope to authorize.
/// ///
/// If the OpenID Connect scope token (`openid`) is not included, it will be /// If the OpenID Connect scope token (`openid`) is not included, it will be
/// added. /// added.
pub scope: &'a Scope, pub scope: Scope,
/// The URI to redirect the end-user to after the authorization. /// The URI to redirect the end-user to after the authorization.
/// ///
/// It must be one of the redirect URIs provided during registration. /// It must be one of the redirect URIs provided during registration.
pub redirect_uri: &'a Url, pub redirect_uri: Url,
}
/// Extra parameters to influence the authorization flow.
#[derive(Debug, Default, Clone)]
pub struct AuthorizationRequestExtraParameters {
/// How the Authorization Server should display the authentication and /// How the Authorization Server should display the authentication and
/// consent user interface pages to the End-User. /// consent user interface pages to the End-User.
pub display: Option<Display>, pub display: Option<Display>,
@ -112,6 +105,74 @@ pub struct AuthorizationRequestExtraParameters {
pub acr_values: Option<HashSet<String>>, pub acr_values: Option<HashSet<String>>,
} }
impl AuthorizationRequestData {
/// Constructs a new `AuthorizationRequestData` with all the required fields.
#[must_use]
pub fn new(client_id: String, scope: Scope, redirect_uri: Url) -> Self {
Self {
client_id,
scope,
redirect_uri,
display: None,
prompt: None,
max_age: None,
ui_locales: None,
id_token_hint: None,
login_hint: None,
acr_values: None,
}
}
/// Set the `display` field of this `AuthorizationRequestData`.
#[must_use]
pub fn with_display(mut self, display: Display) -> Self {
self.display = Some(display);
self
}
/// Set the `prompt` field of this `AuthorizationRequestData`.
#[must_use]
pub fn with_prompt(mut self, prompt: Vec<Prompt>) -> Self {
self.prompt = Some(prompt);
self
}
/// Set the `max_age` field of this `AuthorizationRequestData`.
#[must_use]
pub fn with_max_age(mut self, max_age: NonZeroU32) -> Self {
self.max_age = Some(max_age);
self
}
/// Set the `ui_locales` field of this `AuthorizationRequestData`.
#[must_use]
pub fn with_ui_locales(mut self, ui_locales: Vec<LanguageTag>) -> Self {
self.ui_locales = Some(ui_locales);
self
}
/// Set the `id_token_hint` field of this `AuthorizationRequestData`.
#[must_use]
pub fn with_id_token_hint(mut self, id_token_hint: String) -> Self {
self.id_token_hint = Some(id_token_hint);
self
}
/// Set the `login_hint` field of this `AuthorizationRequestData`.
#[must_use]
pub fn with_login_hint(mut self, login_hint: String) -> Self {
self.login_hint = Some(login_hint);
self
}
/// Set the `acr_values` field of this `AuthorizationRequestData`.
#[must_use]
pub fn with_acr_values(mut self, acr_values: HashSet<String>) -> Self {
self.acr_values = Some(acr_values);
self
}
}
/// The data necessary to validate a response from the Token endpoint in the /// The data necessary to validate a response from the Token endpoint in the
/// Authorization Code flow. /// Authorization Code flow.
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
@ -140,17 +201,14 @@ struct FullAuthorizationRequest {
/// Build the authorization request. /// Build the authorization request.
fn build_authorization_request( fn build_authorization_request(
authorization_data: AuthorizationRequestData<'_>, authorization_data: AuthorizationRequestData,
extra_params: Option<AuthorizationRequestExtraParameters>, code_challenge_methods_supported: Option<&[PkceCodeChallengeMethod]>,
rng: &mut impl Rng, rng: &mut impl Rng,
) -> Result<(FullAuthorizationRequest, AuthorizationValidationData), AuthorizationError> { ) -> Result<(FullAuthorizationRequest, AuthorizationValidationData), AuthorizationError> {
let AuthorizationRequestData { let AuthorizationRequestData {
client_id, client_id,
code_challenge_methods_supported, mut scope,
scope,
redirect_uri, redirect_uri,
} = authorization_data;
let AuthorizationRequestExtraParameters {
display, display,
prompt, prompt,
max_age, max_age,
@ -158,8 +216,7 @@ fn build_authorization_request(
id_token_hint, id_token_hint,
login_hint, login_hint,
acr_values, acr_values,
} = extra_params.unwrap_or_default(); } = authorization_data;
let mut scope = scope.clone();
// Generate a random CSRF "state" token and a nonce. // Generate a random CSRF "state" token and a nonce.
let state = Alphanumeric.sample_string(rng, 16); let state = Alphanumeric.sample_string(rng, 16);
@ -192,7 +249,7 @@ fn build_authorization_request(
let auth_request = FullAuthorizationRequest { let auth_request = FullAuthorizationRequest {
inner: AuthorizationRequest { inner: AuthorizationRequest {
response_type: OAuthAuthorizationEndpointResponseType::Code.into(), response_type: OAuthAuthorizationEndpointResponseType::Code.into(),
client_id: client_id.to_owned(), client_id,
redirect_uri: Some(redirect_uri.clone()), redirect_uri: Some(redirect_uri.clone()),
scope, scope,
state: Some(state.clone()), state: Some(state.clone()),
@ -215,7 +272,7 @@ fn build_authorization_request(
let auth_data = AuthorizationValidationData { let auth_data = AuthorizationValidationData {
state, state,
nonce, nonce,
redirect_uri: redirect_uri.clone(), redirect_uri,
code_challenge_verifier, code_challenge_verifier,
}; };
@ -231,6 +288,9 @@ fn build_authorization_request(
/// * `authorization_data` - The data necessary to build the authorization /// * `authorization_data` - The data necessary to build the authorization
/// request. /// request.
/// ///
/// * `code_challenge_methods_supported` - The PKCE methods supported by the
/// issuer, from its metadata.
///
/// * `rng` - A random number generator. /// * `rng` - A random number generator.
/// ///
/// # Returns /// # Returns
@ -255,8 +315,8 @@ fn build_authorization_request(
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
pub fn build_authorization_url( pub fn build_authorization_url(
authorization_endpoint: Url, authorization_endpoint: Url,
authorization_data: AuthorizationRequestData<'_>, authorization_data: AuthorizationRequestData,
extra_params: Option<AuthorizationRequestExtraParameters>, code_challenge_methods_supported: Option<&[PkceCodeChallengeMethod]>,
rng: &mut impl Rng, rng: &mut impl Rng,
) -> Result<(Url, AuthorizationValidationData), AuthorizationError> { ) -> Result<(Url, AuthorizationValidationData), AuthorizationError> {
tracing::debug!( tracing::debug!(
@ -265,7 +325,7 @@ pub fn build_authorization_url(
); );
let (authorization_request, validation_data) = let (authorization_request, validation_data) =
build_authorization_request(authorization_data, extra_params, rng)?; build_authorization_request(authorization_data, code_challenge_methods_supported, rng)?;
let authorization_query = serde_urlencoded::to_string(authorization_request)?; let authorization_query = serde_urlencoded::to_string(authorization_request)?;
@ -304,6 +364,9 @@ pub fn build_authorization_url(
/// * `authorization_data` - The data necessary to build the authorization /// * `authorization_data` - The data necessary to build the authorization
/// request. /// request.
/// ///
/// * `code_challenge_methods_supported` - The PKCE methods supported by the
/// issuer, from its metadata.
///
/// * `now` - The current time. /// * `now` - The current time.
/// ///
/// * `rng` - A random number generator. /// * `rng` - A random number generator.
@ -328,15 +391,15 @@ pub fn build_authorization_url(
/// ///
/// [Pushed Authorization Request]: https://oauth.net/2/pushed-authorization-requests/ /// [Pushed Authorization Request]: https://oauth.net/2/pushed-authorization-requests/
/// [`ClientErrorCode`]: oauth2_types::errors::ClientErrorCode /// [`ClientErrorCode`]: oauth2_types::errors::ClientErrorCode
#[allow(clippy::too_many_lines, clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
#[tracing::instrument(skip_all, fields(par_endpoint))] #[tracing::instrument(skip_all, fields(par_endpoint))]
pub async fn build_par_authorization_url( pub async fn build_par_authorization_url(
http_service: &HttpService, http_service: &HttpService,
client_credentials: ClientCredentials, client_credentials: ClientCredentials,
par_endpoint: &Url, par_endpoint: &Url,
authorization_endpoint: Url, authorization_endpoint: Url,
authorization_data: AuthorizationRequestData<'_>, authorization_data: AuthorizationRequestData,
extra_params: Option<AuthorizationRequestExtraParameters>, code_challenge_methods_supported: Option<&[PkceCodeChallengeMethod]>,
now: DateTime<Utc>, now: DateTime<Utc>,
rng: &mut impl Rng, rng: &mut impl Rng,
) -> Result<(Url, AuthorizationValidationData), AuthorizationError> { ) -> Result<(Url, AuthorizationValidationData), AuthorizationError> {
@ -348,7 +411,7 @@ pub async fn build_par_authorization_url(
let client_id = client_credentials.client_id().to_owned(); let client_id = client_credentials.client_id().to_owned();
let (authorization_request, validation_data) = let (authorization_request, validation_data) =
build_authorization_request(authorization_data, extra_params, rng)?; build_authorization_request(authorization_data, code_challenge_methods_supported, rng)?;
let par_request = http::Request::post(par_endpoint.as_str()) let par_request = http::Request::post(par_endpoint.as_str())
.header(CONTENT_TYPE, mime::APPLICATION_WWW_FORM_URLENCODED.as_ref()) .header(CONTENT_TYPE, mime::APPLICATION_WWW_FORM_URLENCODED.as_ref())

View File

@ -59,13 +59,12 @@ fn pass_authorization_url() {
let (url, validation_data) = build_authorization_url( let (url, validation_data) = build_authorization_url(
authorization_endpoint, authorization_endpoint,
AuthorizationRequestData { AuthorizationRequestData::new(
client_id: CLIENT_ID, CLIENT_ID.to_owned(),
code_challenge_methods_supported: Some(&[PkceCodeChallengeMethod::S256]), [ScopeToken::Openid].into_iter().collect(),
scope: &[ScopeToken::Openid].into_iter().collect(), redirect_uri,
redirect_uri: &redirect_uri, ),
}, Some(&[PkceCodeChallengeMethod::S256]),
None,
&mut rng, &mut rng,
) )
.unwrap(); .unwrap();
@ -130,13 +129,12 @@ async fn pass_pushed_authorization_request() {
client_credentials, client_credentials,
&par_endpoint, &par_endpoint,
authorization_endpoint, authorization_endpoint,
AuthorizationRequestData { AuthorizationRequestData::new(
client_id: CLIENT_ID, CLIENT_ID.to_owned(),
code_challenge_methods_supported: Some(&[PkceCodeChallengeMethod::S256]), [ScopeToken::Openid].into_iter().collect(),
scope: &[ScopeToken::Openid].into_iter().collect(), redirect_uri,
redirect_uri: &redirect_uri, ),
}, Some(&[PkceCodeChallengeMethod::S256]),
None,
now(), now(),
&mut rng, &mut rng,
) )
@ -182,13 +180,12 @@ async fn fail_pushed_authorization_request_404() {
client_credentials, client_credentials,
&par_endpoint, &par_endpoint,
authorization_endpoint, authorization_endpoint,
AuthorizationRequestData { AuthorizationRequestData::new(
client_id: CLIENT_ID, CLIENT_ID.to_owned(),
code_challenge_methods_supported: Some(&[PkceCodeChallengeMethod::S256]), [ScopeToken::Openid].into_iter().collect(),
scope: &[ScopeToken::Openid].into_iter().collect(), redirect_uri,
redirect_uri: &redirect_uri, ),
}, Some(&[PkceCodeChallengeMethod::S256]),
None,
now(), now(),
&mut rng, &mut rng,
) )