diff --git a/crates/handlers/src/upstream_oauth2/authorize.rs b/crates/handlers/src/upstream_oauth2/authorize.rs index 479ea095..392b47a0 100644 --- a/crates/handlers/src/upstream_oauth2/authorize.rs +++ b/crates/handlers/src/upstream_oauth2/authorize.rs @@ -86,17 +86,20 @@ pub(crate) async fn get( let redirect_uri = url_builder.upstream_oauth_callback(provider.id); - let data = AuthorizationRequestData::new( + let mut data = AuthorizationRequestData::new( provider.client_id.clone(), provider.scope.clone(), redirect_uri, ); + if let Some(methods) = metadata.code_challenge_methods_supported.clone() { + data = data.with_code_challenge_methods_supported(methods); + } + // Build an authorization request for it let (url, data) = mas_oidc_client::requests::authorization_code::build_authorization_url( metadata.authorization_endpoint().clone(), data, - metadata.code_challenge_methods_supported.as_deref(), &mut rng, )?; diff --git a/crates/oidc-client/src/requests/authorization_code.rs b/crates/oidc-client/src/requests/authorization_code.rs index f7fc79d0..3567f4c0 100644 --- a/crates/oidc-client/src/requests/authorization_code.rs +++ b/crates/oidc-client/src/requests/authorization_code.rs @@ -75,6 +75,12 @@ pub struct AuthorizationRequestData { /// It must be one of the redirect URIs provided during registration. pub redirect_uri: Url, + /// The PKCE methods supported by the issuer. + /// + /// This field should be cloned from the provider metadata. If it is not + /// set, this security measure will not be used. + pub code_challenge_methods_supported: Option>, + /// How the Authorization Server should display the authentication and /// consent user interface pages to the End-User. pub display: Option, @@ -114,6 +120,7 @@ impl AuthorizationRequestData { client_id, scope, redirect_uri, + code_challenge_methods_supported: None, display: None, prompt: None, max_age: None, @@ -124,6 +131,17 @@ impl AuthorizationRequestData { } } + /// Set the `code_challenge_methods_supported` field of this + /// `AuthorizationRequestData`. + #[must_use] + pub fn with_code_challenge_methods_supported( + mut self, + code_challenge_methods_supported: Vec, + ) -> Self { + self.code_challenge_methods_supported = Some(code_challenge_methods_supported); + self + } + /// Set the `display` field of this `AuthorizationRequestData`. #[must_use] pub fn with_display(mut self, display: Display) -> Self { @@ -203,13 +221,13 @@ struct FullAuthorizationRequest { /// Build the authorization request. fn build_authorization_request( authorization_data: AuthorizationRequestData, - code_challenge_methods_supported: Option<&[PkceCodeChallengeMethod]>, rng: &mut impl Rng, ) -> Result<(FullAuthorizationRequest, AuthorizationValidationData), AuthorizationError> { let AuthorizationRequestData { client_id, mut scope, redirect_uri, + code_challenge_methods_supported, display, prompt, max_age, @@ -289,9 +307,6 @@ fn build_authorization_request( /// * `authorization_data` - The data necessary to build the authorization /// request. /// -/// * `code_challenge_methods_supported` - The PKCE methods supported by the -/// issuer, from its metadata. -/// /// * `rng` - A random number generator. /// /// # Returns @@ -317,7 +332,6 @@ fn build_authorization_request( pub fn build_authorization_url( authorization_endpoint: Url, authorization_data: AuthorizationRequestData, - code_challenge_methods_supported: Option<&[PkceCodeChallengeMethod]>, rng: &mut impl Rng, ) -> Result<(Url, AuthorizationValidationData), AuthorizationError> { tracing::debug!( @@ -326,7 +340,7 @@ pub fn build_authorization_url( ); let (authorization_request, validation_data) = - build_authorization_request(authorization_data, code_challenge_methods_supported, rng)?; + build_authorization_request(authorization_data, rng)?; let authorization_query = serde_urlencoded::to_string(authorization_request)?; @@ -365,9 +379,6 @@ pub fn build_authorization_url( /// * `authorization_data` - The data necessary to build the authorization /// request. /// -/// * `code_challenge_methods_supported` - The PKCE methods supported by the -/// issuer, from its metadata. -/// /// * `now` - The current time. /// /// * `rng` - A random number generator. @@ -392,7 +403,6 @@ pub fn build_authorization_url( /// /// [Pushed Authorization Request]: https://oauth.net/2/pushed-authorization-requests/ /// [`ClientErrorCode`]: oauth2_types::errors::ClientErrorCode -#[allow(clippy::too_many_arguments)] #[tracing::instrument(skip_all, fields(par_endpoint))] pub async fn build_par_authorization_url( http_service: &HttpService, @@ -400,7 +410,6 @@ pub async fn build_par_authorization_url( par_endpoint: &Url, authorization_endpoint: Url, authorization_data: AuthorizationRequestData, - code_challenge_methods_supported: Option<&[PkceCodeChallengeMethod]>, now: DateTime, rng: &mut impl Rng, ) -> Result<(Url, AuthorizationValidationData), AuthorizationError> { @@ -412,7 +421,7 @@ pub async fn build_par_authorization_url( let client_id = client_credentials.client_id().to_owned(); let (authorization_request, validation_data) = - build_authorization_request(authorization_data, code_challenge_methods_supported, rng)?; + build_authorization_request(authorization_data, rng)?; let par_request = http::Request::post(par_endpoint.as_str()) .header(CONTENT_TYPE, mime::APPLICATION_WWW_FORM_URLENCODED.as_ref()) diff --git a/crates/oidc-client/tests/it/requests/authorization_code.rs b/crates/oidc-client/tests/it/requests/authorization_code.rs index e8a2d4a7..edd8d214 100644 --- a/crates/oidc-client/tests/it/requests/authorization_code.rs +++ b/crates/oidc-client/tests/it/requests/authorization_code.rs @@ -64,8 +64,8 @@ fn pass_authorization_url() { CLIENT_ID.to_owned(), [ScopeToken::Openid].into_iter().collect(), redirect_uri, - ), - Some(&[PkceCodeChallengeMethod::S256]), + ) + .with_code_challenge_methods_supported(vec![PkceCodeChallengeMethod::S256]), &mut rng, ) .unwrap(); @@ -117,8 +117,7 @@ fn pass_full_authorization_url() { .with_acr_values(["custom".to_owned()].into()); let (url, validation_data) = - build_authorization_url(authorization_endpoint, authorization_data, None, &mut rng) - .unwrap(); + build_authorization_url(authorization_endpoint, authorization_data, &mut rng).unwrap(); assert_eq!(validation_data.state, "OrJ8xbWovSpJUTKz"); assert_eq!(validation_data.code_challenge_verifier, None); @@ -190,8 +189,8 @@ async fn pass_pushed_authorization_request() { CLIENT_ID.to_owned(), [ScopeToken::Openid].into_iter().collect(), redirect_uri, - ), - Some(&[PkceCodeChallengeMethod::S256]), + ) + .with_code_challenge_methods_supported(vec![PkceCodeChallengeMethod::S256]), now(), &mut rng, ) @@ -241,8 +240,8 @@ async fn fail_pushed_authorization_request_404() { CLIENT_ID.to_owned(), [ScopeToken::Openid].into_iter().collect(), redirect_uri, - ), - Some(&[PkceCodeChallengeMethod::S256]), + ) + .with_code_challenge_methods_supported(vec![PkceCodeChallengeMethod::S256]), now(), &mut rng, )