diff --git a/crates/handlers/src/upstream_oauth2/authorize.rs b/crates/handlers/src/upstream_oauth2/authorize.rs index fffb044c..96eee236 100644 --- a/crates/handlers/src/upstream_oauth2/authorize.rs +++ b/crates/handlers/src/upstream_oauth2/authorize.rs @@ -89,7 +89,6 @@ pub(crate) async fn get( let data = AuthorizationRequestData { client_id: &provider.client_id, scope: &provider.scope, - prompt: None, redirect_uri: &redirect_uri, code_challenge_methods_supported: metadata.code_challenge_methods_supported.as_deref(), }; @@ -98,6 +97,7 @@ pub(crate) async fn get( let (url, data) = mas_oidc_client::requests::authorization_code::build_authorization_url( metadata.authorization_endpoint().clone(), data, + None, &mut rng, )?; diff --git a/crates/oidc-client/src/requests/authorization_code.rs b/crates/oidc-client/src/requests/authorization_code.rs index 066a61b9..6ed51a74 100644 --- a/crates/oidc-client/src/requests/authorization_code.rs +++ b/crates/oidc-client/src/requests/authorization_code.rs @@ -16,9 +16,12 @@ //! //! [Authorization Code flow]: https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowAuth +use std::{collections::HashSet, num::NonZeroU32}; + use base64ct::{Base64UrlUnpadded, Encoding}; use chrono::{DateTime, Utc}; use http::header::CONTENT_TYPE; +use language_tags::LanguageTag; use mas_http::{CatchHttpCodesLayer, FormUrlencodedRequestLayer, JsonResponseLayer}; use mas_iana::oauth::{OAuthAuthorizationEndpointResponseType, PkceCodeChallengeMethod}; use mas_jose::claims::{self, TokenHash}; @@ -27,7 +30,7 @@ use oauth2_types::{ prelude::CodeChallengeMethodExt, requests::{ AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, AuthorizationRequest, - Prompt, PushedAuthorizationResponse, + Display, Prompt, PushedAuthorizationResponse, }, scope::Scope, }; @@ -74,9 +77,39 @@ pub struct AuthorizationRequestData<'a> { /// /// It must be one of the redirect URIs provided during registration. pub redirect_uri: &'a Url, +} - /// Optional hints for the action to be performed. - pub prompt: Option<&'a [Prompt]>, +/// Extra parameters to influence the authorization flow. +#[derive(Debug, Default, Clone)] +pub struct AuthorizationRequestExtraParameters { + /// How the Authorization Server should display the authentication and + /// consent user interface pages to the End-User. + pub display: Option, + + /// Whether the Authorization Server should prompt the End-User for + /// reauthentication and consent. + /// + /// If [`Prompt::None`] is used, it must be the only value. + pub prompt: Option>, + + /// The allowable elapsed time in seconds since the last time the End-User + /// was actively authenticated by the OpenID Provider. + pub max_age: Option, + + /// End-User's preferred languages and scripts for the user interface. + pub ui_locales: Option>, + + /// ID Token previously issued by the Authorization Server being passed as a + /// hint about the End-User's current or past authenticated session with the + /// Client. + pub id_token_hint: Option, + + /// Hint to the Authorization Server about the login identifier the End-User + /// might use to log in. + pub login_hint: Option, + + /// Requested Authentication Context Class Reference values. + pub acr_values: Option>, } /// The data necessary to validate a response from the Token endpoint in the @@ -108,6 +141,7 @@ struct FullAuthorizationRequest { /// Build the authorization request. fn build_authorization_request( authorization_data: AuthorizationRequestData<'_>, + extra_params: Option, rng: &mut impl Rng, ) -> Result<(FullAuthorizationRequest, AuthorizationValidationData), AuthorizationError> { let AuthorizationRequestData { @@ -115,8 +149,16 @@ fn build_authorization_request( code_challenge_methods_supported, scope, redirect_uri, - prompt, } = authorization_data; + let AuthorizationRequestExtraParameters { + display, + prompt, + max_age, + ui_locales, + id_token_hint, + login_hint, + acr_values, + } = extra_params.unwrap_or_default(); let mut scope = scope.clone(); // Generate a random CSRF "state" token and a nonce. @@ -156,13 +198,13 @@ fn build_authorization_request( state: Some(state.clone()), response_mode: None, nonce: Some(nonce.clone()), - display: None, - prompt: prompt.map(ToOwned::to_owned), - max_age: None, - ui_locales: None, - id_token_hint: None, - login_hint: None, - acr_values: None, + display, + prompt, + max_age, + ui_locales, + id_token_hint, + login_hint, + acr_values, request: None, request_uri: None, registration: None, @@ -214,6 +256,7 @@ fn build_authorization_request( pub fn build_authorization_url( authorization_endpoint: Url, authorization_data: AuthorizationRequestData<'_>, + extra_params: Option, rng: &mut impl Rng, ) -> Result<(Url, AuthorizationValidationData), AuthorizationError> { tracing::debug!( @@ -222,7 +265,7 @@ pub fn build_authorization_url( ); let (authorization_request, validation_data) = - build_authorization_request(authorization_data, rng)?; + build_authorization_request(authorization_data, extra_params, rng)?; let authorization_query = serde_urlencoded::to_string(authorization_request)?; @@ -285,7 +328,7 @@ pub fn build_authorization_url( /// /// [Pushed Authorization Request]: https://oauth.net/2/pushed-authorization-requests/ /// [`ClientErrorCode`]: oauth2_types::errors::ClientErrorCode -#[allow(clippy::too_many_lines)] +#[allow(clippy::too_many_lines, clippy::too_many_arguments)] #[tracing::instrument(skip_all, fields(par_endpoint))] pub async fn build_par_authorization_url( http_service: &HttpService, @@ -293,6 +336,7 @@ pub async fn build_par_authorization_url( par_endpoint: &Url, authorization_endpoint: Url, authorization_data: AuthorizationRequestData<'_>, + extra_params: Option, now: DateTime, rng: &mut impl Rng, ) -> Result<(Url, AuthorizationValidationData), AuthorizationError> { @@ -304,7 +348,7 @@ pub async fn build_par_authorization_url( let client_id = client_credentials.client_id().to_owned(); let (authorization_request, validation_data) = - build_authorization_request(authorization_data, rng)?; + build_authorization_request(authorization_data, extra_params, rng)?; let par_request = http::Request::post(par_endpoint.as_str()) .header(CONTENT_TYPE, mime::APPLICATION_WWW_FORM_URLENCODED.as_ref()) diff --git a/crates/oidc-client/tests/it/requests/authorization_code.rs b/crates/oidc-client/tests/it/requests/authorization_code.rs index 52dea447..1db1916f 100644 --- a/crates/oidc-client/tests/it/requests/authorization_code.rs +++ b/crates/oidc-client/tests/it/requests/authorization_code.rs @@ -64,8 +64,8 @@ fn pass_authorization_url() { code_challenge_methods_supported: Some(&[PkceCodeChallengeMethod::S256]), scope: &[ScopeToken::Openid].into_iter().collect(), redirect_uri: &redirect_uri, - prompt: None, }, + None, &mut rng, ) .unwrap(); @@ -135,8 +135,8 @@ async fn pass_pushed_authorization_request() { code_challenge_methods_supported: Some(&[PkceCodeChallengeMethod::S256]), scope: &[ScopeToken::Openid].into_iter().collect(), redirect_uri: &redirect_uri, - prompt: None, }, + None, now(), &mut rng, ) @@ -187,8 +187,8 @@ async fn fail_pushed_authorization_request_404() { code_challenge_methods_supported: Some(&[PkceCodeChallengeMethod::S256]), scope: &[ScopeToken::Openid].into_iter().collect(), redirect_uri: &redirect_uri, - prompt: None, }, + None, now(), &mut rng, )