From ba4ba75f73a07d1fe9c199e258a1549f87005f56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Commaille?= Date: Tue, 8 Aug 2023 12:25:35 +0200 Subject: [PATCH] Merge data structs and use builder pattern MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Kévin Commaille --- .../handlers/src/upstream_oauth2/authorize.rs | 13 +- .../src/requests/authorization_code.rs | 121 +++++++++++++----- .../tests/it/requests/authorization_code.rs | 39 +++--- 3 files changed, 116 insertions(+), 57 deletions(-) diff --git a/crates/handlers/src/upstream_oauth2/authorize.rs b/crates/handlers/src/upstream_oauth2/authorize.rs index 96eee236..479ea095 100644 --- a/crates/handlers/src/upstream_oauth2/authorize.rs +++ b/crates/handlers/src/upstream_oauth2/authorize.rs @@ -86,18 +86,17 @@ pub(crate) async fn get( let redirect_uri = url_builder.upstream_oauth_callback(provider.id); - let data = AuthorizationRequestData { - client_id: &provider.client_id, - scope: &provider.scope, - redirect_uri: &redirect_uri, - code_challenge_methods_supported: metadata.code_challenge_methods_supported.as_deref(), - }; + let data = AuthorizationRequestData::new( + provider.client_id.clone(), + provider.scope.clone(), + redirect_uri, + ); // Build an authorization request for it let (url, data) = mas_oidc_client::requests::authorization_code::build_authorization_url( metadata.authorization_endpoint().clone(), data, - None, + metadata.code_challenge_methods_supported.as_deref(), &mut rng, )?; diff --git a/crates/oidc-client/src/requests/authorization_code.rs b/crates/oidc-client/src/requests/authorization_code.rs index 6ed51a74..136b4b6f 100644 --- a/crates/oidc-client/src/requests/authorization_code.rs +++ b/crates/oidc-client/src/requests/authorization_code.rs @@ -59,29 +59,22 @@ use crate::{ }; /// The data necessary to build an authorization request. -#[derive(Debug, Clone, Copy)] -pub struct AuthorizationRequestData<'a> { +#[derive(Debug, Clone)] +pub struct AuthorizationRequestData { /// The ID obtained when registering the client. - pub client_id: &'a str, - - /// The PKCE methods supported by the issuer, from its metadata. - pub code_challenge_methods_supported: Option<&'a [PkceCodeChallengeMethod]>, + pub client_id: String, /// The scope to authorize. /// /// If the OpenID Connect scope token (`openid`) is not included, it will be /// added. - pub scope: &'a Scope, + pub scope: Scope, /// The URI to redirect the end-user to after the authorization. /// /// It must be one of the redirect URIs provided during registration. - pub redirect_uri: &'a Url, -} + pub redirect_uri: Url, -/// Extra parameters to influence the authorization flow. -#[derive(Debug, Default, Clone)] -pub struct AuthorizationRequestExtraParameters { /// How the Authorization Server should display the authentication and /// consent user interface pages to the End-User. pub display: Option, @@ -112,6 +105,74 @@ pub struct AuthorizationRequestExtraParameters { pub acr_values: Option>, } +impl AuthorizationRequestData { + /// Constructs a new `AuthorizationRequestData` with all the required fields. + #[must_use] + pub fn new(client_id: String, scope: Scope, redirect_uri: Url) -> Self { + Self { + client_id, + scope, + redirect_uri, + display: None, + prompt: None, + max_age: None, + ui_locales: None, + id_token_hint: None, + login_hint: None, + acr_values: None, + } + } + + /// Set the `display` field of this `AuthorizationRequestData`. + #[must_use] + pub fn with_display(mut self, display: Display) -> Self { + self.display = Some(display); + self + } + + /// Set the `prompt` field of this `AuthorizationRequestData`. + #[must_use] + pub fn with_prompt(mut self, prompt: Vec) -> Self { + self.prompt = Some(prompt); + self + } + + /// Set the `max_age` field of this `AuthorizationRequestData`. + #[must_use] + pub fn with_max_age(mut self, max_age: NonZeroU32) -> Self { + self.max_age = Some(max_age); + self + } + + /// Set the `ui_locales` field of this `AuthorizationRequestData`. + #[must_use] + pub fn with_ui_locales(mut self, ui_locales: Vec) -> Self { + self.ui_locales = Some(ui_locales); + self + } + + /// Set the `id_token_hint` field of this `AuthorizationRequestData`. + #[must_use] + pub fn with_id_token_hint(mut self, id_token_hint: String) -> Self { + self.id_token_hint = Some(id_token_hint); + self + } + + /// Set the `login_hint` field of this `AuthorizationRequestData`. + #[must_use] + pub fn with_login_hint(mut self, login_hint: String) -> Self { + self.login_hint = Some(login_hint); + self + } + + /// Set the `acr_values` field of this `AuthorizationRequestData`. + #[must_use] + pub fn with_acr_values(mut self, acr_values: HashSet) -> Self { + self.acr_values = Some(acr_values); + self + } +} + /// The data necessary to validate a response from the Token endpoint in the /// Authorization Code flow. #[derive(Debug, Clone, PartialEq, Eq)] @@ -140,17 +201,14 @@ struct FullAuthorizationRequest { /// Build the authorization request. fn build_authorization_request( - authorization_data: AuthorizationRequestData<'_>, - extra_params: Option, + authorization_data: AuthorizationRequestData, + code_challenge_methods_supported: Option<&[PkceCodeChallengeMethod]>, rng: &mut impl Rng, ) -> Result<(FullAuthorizationRequest, AuthorizationValidationData), AuthorizationError> { let AuthorizationRequestData { client_id, - code_challenge_methods_supported, - scope, + mut scope, redirect_uri, - } = authorization_data; - let AuthorizationRequestExtraParameters { display, prompt, max_age, @@ -158,8 +216,7 @@ fn build_authorization_request( id_token_hint, login_hint, acr_values, - } = extra_params.unwrap_or_default(); - let mut scope = scope.clone(); + } = authorization_data; // Generate a random CSRF "state" token and a nonce. let state = Alphanumeric.sample_string(rng, 16); @@ -192,7 +249,7 @@ fn build_authorization_request( let auth_request = FullAuthorizationRequest { inner: AuthorizationRequest { response_type: OAuthAuthorizationEndpointResponseType::Code.into(), - client_id: client_id.to_owned(), + client_id, redirect_uri: Some(redirect_uri.clone()), scope, state: Some(state.clone()), @@ -215,7 +272,7 @@ fn build_authorization_request( let auth_data = AuthorizationValidationData { state, nonce, - redirect_uri: redirect_uri.clone(), + redirect_uri, code_challenge_verifier, }; @@ -231,6 +288,9 @@ fn build_authorization_request( /// * `authorization_data` - The data necessary to build the authorization /// request. /// +/// * `code_challenge_methods_supported` - The PKCE methods supported by the +/// issuer, from its metadata. +/// /// * `rng` - A random number generator. /// /// # Returns @@ -255,8 +315,8 @@ fn build_authorization_request( #[allow(clippy::too_many_lines)] pub fn build_authorization_url( authorization_endpoint: Url, - authorization_data: AuthorizationRequestData<'_>, - extra_params: Option, + authorization_data: AuthorizationRequestData, + code_challenge_methods_supported: Option<&[PkceCodeChallengeMethod]>, rng: &mut impl Rng, ) -> Result<(Url, AuthorizationValidationData), AuthorizationError> { tracing::debug!( @@ -265,7 +325,7 @@ pub fn build_authorization_url( ); let (authorization_request, validation_data) = - build_authorization_request(authorization_data, extra_params, rng)?; + build_authorization_request(authorization_data, code_challenge_methods_supported, rng)?; let authorization_query = serde_urlencoded::to_string(authorization_request)?; @@ -304,6 +364,9 @@ pub fn build_authorization_url( /// * `authorization_data` - The data necessary to build the authorization /// request. /// +/// * `code_challenge_methods_supported` - The PKCE methods supported by the +/// issuer, from its metadata. +/// /// * `now` - The current time. /// /// * `rng` - A random number generator. @@ -328,15 +391,15 @@ pub fn build_authorization_url( /// /// [Pushed Authorization Request]: https://oauth.net/2/pushed-authorization-requests/ /// [`ClientErrorCode`]: oauth2_types::errors::ClientErrorCode -#[allow(clippy::too_many_lines, clippy::too_many_arguments)] +#[allow(clippy::too_many_arguments)] #[tracing::instrument(skip_all, fields(par_endpoint))] pub async fn build_par_authorization_url( http_service: &HttpService, client_credentials: ClientCredentials, par_endpoint: &Url, authorization_endpoint: Url, - authorization_data: AuthorizationRequestData<'_>, - extra_params: Option, + authorization_data: AuthorizationRequestData, + code_challenge_methods_supported: Option<&[PkceCodeChallengeMethod]>, now: DateTime, rng: &mut impl Rng, ) -> Result<(Url, AuthorizationValidationData), AuthorizationError> { @@ -348,7 +411,7 @@ pub async fn build_par_authorization_url( let client_id = client_credentials.client_id().to_owned(); let (authorization_request, validation_data) = - build_authorization_request(authorization_data, extra_params, rng)?; + build_authorization_request(authorization_data, code_challenge_methods_supported, rng)?; let par_request = http::Request::post(par_endpoint.as_str()) .header(CONTENT_TYPE, mime::APPLICATION_WWW_FORM_URLENCODED.as_ref()) diff --git a/crates/oidc-client/tests/it/requests/authorization_code.rs b/crates/oidc-client/tests/it/requests/authorization_code.rs index 1db1916f..2e206d1c 100644 --- a/crates/oidc-client/tests/it/requests/authorization_code.rs +++ b/crates/oidc-client/tests/it/requests/authorization_code.rs @@ -59,13 +59,12 @@ fn pass_authorization_url() { let (url, validation_data) = build_authorization_url( authorization_endpoint, - AuthorizationRequestData { - client_id: CLIENT_ID, - code_challenge_methods_supported: Some(&[PkceCodeChallengeMethod::S256]), - scope: &[ScopeToken::Openid].into_iter().collect(), - redirect_uri: &redirect_uri, - }, - None, + AuthorizationRequestData::new( + CLIENT_ID.to_owned(), + [ScopeToken::Openid].into_iter().collect(), + redirect_uri, + ), + Some(&[PkceCodeChallengeMethod::S256]), &mut rng, ) .unwrap(); @@ -130,13 +129,12 @@ async fn pass_pushed_authorization_request() { client_credentials, &par_endpoint, authorization_endpoint, - AuthorizationRequestData { - client_id: CLIENT_ID, - code_challenge_methods_supported: Some(&[PkceCodeChallengeMethod::S256]), - scope: &[ScopeToken::Openid].into_iter().collect(), - redirect_uri: &redirect_uri, - }, - None, + AuthorizationRequestData::new( + CLIENT_ID.to_owned(), + [ScopeToken::Openid].into_iter().collect(), + redirect_uri, + ), + Some(&[PkceCodeChallengeMethod::S256]), now(), &mut rng, ) @@ -182,13 +180,12 @@ async fn fail_pushed_authorization_request_404() { client_credentials, &par_endpoint, authorization_endpoint, - AuthorizationRequestData { - client_id: CLIENT_ID, - code_challenge_methods_supported: Some(&[PkceCodeChallengeMethod::S256]), - scope: &[ScopeToken::Openid].into_iter().collect(), - redirect_uri: &redirect_uri, - }, - None, + AuthorizationRequestData::new( + CLIENT_ID.to_owned(), + [ScopeToken::Openid].into_iter().collect(), + redirect_uri, + ), + Some(&[PkceCodeChallengeMethod::S256]), now(), &mut rng, )