From 1813984a1ca291fdfe94e7a2001adc4e03373f4a Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Fri, 17 Sep 2021 18:13:30 +0200 Subject: [PATCH] Include "state" in authorization request errors --- .../core/src/handlers/oauth2/authorization.rs | 35 +++++++++++++------ crates/oauth2-types/src/requests.rs | 1 - 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/crates/core/src/handlers/oauth2/authorization.rs b/crates/core/src/handlers/oauth2/authorization.rs index c4c2c1c4..92c6100e 100644 --- a/crates/core/src/handlers/oauth2/authorization.rs +++ b/crates/core/src/handlers/oauth2/authorization.rs @@ -69,6 +69,7 @@ use crate::{ struct PartialParams { client_id: Option, redirect_uri: Option, + state: Option, /* response_type: Option, response_mode: Option, @@ -81,6 +82,7 @@ enum ReplyOrBackToClient { params: Value, redirect_uri: Url, response_mode: ResponseMode, + state: Option, }, Error(Box), } @@ -88,6 +90,7 @@ enum ReplyOrBackToClient { fn back_to_client( mut redirect_uri: Url, response_mode: ResponseMode, + state: Option, params: T, templates: &Templates, ) -> anyhow::Result> @@ -99,6 +102,9 @@ where #[serde(flatten, skip_serializing_if = "Option::is_none")] existing: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + state: Option, + #[serde(flatten)] params: T, } @@ -110,7 +116,11 @@ where .map(|qs| serde_urlencoded::from_str(qs)) .transpose()?; - let merged = AllParams { existing, params }; + let merged = AllParams { + existing, + state, + params, + }; let new_qs = serde_urlencoded::to_string(merged)?; @@ -128,7 +138,11 @@ where .map(|qs| serde_urlencoded::from_str(qs)) .transpose()?; - let merged = AllParams { existing, params }; + let merged = AllParams { + existing, + state, + params, + }; let new_qs = serde_urlencoded::to_string(merged)?; @@ -230,17 +244,19 @@ async fn actually_reply( clients: Vec, templates: Templates, ) -> Result { - let (redirect_uri, response_mode, params) = match rep { + let (redirect_uri, response_mode, state, params) = match rep { ReplyOrBackToClient::Reply(r) => return Ok(r), ReplyOrBackToClient::BackToClient { redirect_uri, response_mode, params, - } => (redirect_uri, response_mode, params), + state, + } => (redirect_uri, response_mode, state, params), ReplyOrBackToClient::Error(error) => { let PartialParams { client_id, redirect_uri, + state, .. } = q; @@ -271,12 +287,11 @@ async fn actually_reply( let reply: ErrorResponse = error.into(); let reply = serde_json::to_value(&reply).wrap_error()?; // TODO: resolve response mode - (redirect_uri.clone(), ResponseMode::Query, reply) + (redirect_uri.clone(), ResponseMode::Query, state, reply) } }; - // TODO: we should include the state param in errors - back_to_client(redirect_uri, response_mode, params, &templates).wrap_error() + back_to_client(redirect_uri, response_mode, state, params, &templates).wrap_error() } async fn get( @@ -400,10 +415,7 @@ async fn step( && user_session.last_authd_at >= oauth2_session.max_auth_time() { // Yep! Let's complete the auth now - let mut params = AuthorizationResponse { - state: oauth2_session.state.clone(), - ..AuthorizationResponse::default() - }; + let mut params = AuthorizationResponse::default(); // Did they request an auth code? if response_type.contains(&ResponseType::Code) { @@ -446,6 +458,7 @@ async fn step( ReplyOrBackToClient::BackToClient { redirect_uri, response_mode, + state: oauth2_session.state.clone(), params, } } else { diff --git a/crates/oauth2-types/src/requests.rs b/crates/oauth2-types/src/requests.rs index 7e318c6a..82c0ebac 100644 --- a/crates/oauth2-types/src/requests.rs +++ b/crates/oauth2-types/src/requests.rs @@ -177,7 +177,6 @@ pub struct AuthorizationRequest { #[derive(Serialize, Deserialize, Default)] pub struct AuthorizationResponse { pub code: Option, - pub state: Option, #[serde(flatten)] pub response: R, }