1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Include "state" in authorization request errors

This commit is contained in:
Quentin Gliech
2021-09-17 18:13:30 +02:00
parent dc0d54aaf5
commit 1813984a1c
2 changed files with 24 additions and 12 deletions

View File

@ -69,6 +69,7 @@ use crate::{
struct PartialParams { struct PartialParams {
client_id: Option<String>, client_id: Option<String>,
redirect_uri: Option<String>, redirect_uri: Option<String>,
state: Option<String>,
/* /*
response_type: Option<String>, response_type: Option<String>,
response_mode: Option<String>, response_mode: Option<String>,
@ -81,6 +82,7 @@ enum ReplyOrBackToClient {
params: Value, params: Value,
redirect_uri: Url, redirect_uri: Url,
response_mode: ResponseMode, response_mode: ResponseMode,
state: Option<String>,
}, },
Error(Box<dyn OAuth2Error>), Error(Box<dyn OAuth2Error>),
} }
@ -88,6 +90,7 @@ enum ReplyOrBackToClient {
fn back_to_client<T>( fn back_to_client<T>(
mut redirect_uri: Url, mut redirect_uri: Url,
response_mode: ResponseMode, response_mode: ResponseMode,
state: Option<String>,
params: T, params: T,
templates: &Templates, templates: &Templates,
) -> anyhow::Result<Box<dyn Reply>> ) -> anyhow::Result<Box<dyn Reply>>
@ -99,6 +102,9 @@ where
#[serde(flatten, skip_serializing_if = "Option::is_none")] #[serde(flatten, skip_serializing_if = "Option::is_none")]
existing: Option<HashMap<&'s str, &'s str>>, existing: Option<HashMap<&'s str, &'s str>>,
#[serde(skip_serializing_if = "Option::is_none")]
state: Option<String>,
#[serde(flatten)] #[serde(flatten)]
params: T, params: T,
} }
@ -110,7 +116,11 @@ where
.map(|qs| serde_urlencoded::from_str(qs)) .map(|qs| serde_urlencoded::from_str(qs))
.transpose()?; .transpose()?;
let merged = AllParams { existing, params }; let merged = AllParams {
existing,
state,
params,
};
let new_qs = serde_urlencoded::to_string(merged)?; let new_qs = serde_urlencoded::to_string(merged)?;
@ -128,7 +138,11 @@ where
.map(|qs| serde_urlencoded::from_str(qs)) .map(|qs| serde_urlencoded::from_str(qs))
.transpose()?; .transpose()?;
let merged = AllParams { existing, params }; let merged = AllParams {
existing,
state,
params,
};
let new_qs = serde_urlencoded::to_string(merged)?; let new_qs = serde_urlencoded::to_string(merged)?;
@ -230,17 +244,19 @@ async fn actually_reply(
clients: Vec<OAuth2ClientConfig>, clients: Vec<OAuth2ClientConfig>,
templates: Templates, templates: Templates,
) -> Result<impl Reply, Rejection> { ) -> Result<impl Reply, Rejection> {
let (redirect_uri, response_mode, params) = match rep { let (redirect_uri, response_mode, state, params) = match rep {
ReplyOrBackToClient::Reply(r) => return Ok(r), ReplyOrBackToClient::Reply(r) => return Ok(r),
ReplyOrBackToClient::BackToClient { ReplyOrBackToClient::BackToClient {
redirect_uri, redirect_uri,
response_mode, response_mode,
params, params,
} => (redirect_uri, response_mode, params), state,
} => (redirect_uri, response_mode, state, params),
ReplyOrBackToClient::Error(error) => { ReplyOrBackToClient::Error(error) => {
let PartialParams { let PartialParams {
client_id, client_id,
redirect_uri, redirect_uri,
state,
.. ..
} = q; } = q;
@ -271,12 +287,11 @@ async fn actually_reply(
let reply: ErrorResponse = error.into(); let reply: ErrorResponse = error.into();
let reply = serde_json::to_value(&reply).wrap_error()?; let reply = serde_json::to_value(&reply).wrap_error()?;
// TODO: resolve response mode // TODO: resolve response mode
(redirect_uri.clone(), ResponseMode::Query, reply) (redirect_uri.clone(), ResponseMode::Query, state, reply)
} }
}; };
// TODO: we should include the state param in errors back_to_client(redirect_uri, response_mode, state, params, &templates).wrap_error()
back_to_client(redirect_uri, response_mode, params, &templates).wrap_error()
} }
async fn get( async fn get(
@ -400,10 +415,7 @@ async fn step(
&& user_session.last_authd_at >= oauth2_session.max_auth_time() && user_session.last_authd_at >= oauth2_session.max_auth_time()
{ {
// Yep! Let's complete the auth now // Yep! Let's complete the auth now
let mut params = AuthorizationResponse { let mut params = AuthorizationResponse::default();
state: oauth2_session.state.clone(),
..AuthorizationResponse::default()
};
// Did they request an auth code? // Did they request an auth code?
if response_type.contains(&ResponseType::Code) { if response_type.contains(&ResponseType::Code) {
@ -446,6 +458,7 @@ async fn step(
ReplyOrBackToClient::BackToClient { ReplyOrBackToClient::BackToClient {
redirect_uri, redirect_uri,
response_mode, response_mode,
state: oauth2_session.state.clone(),
params, params,
} }
} else { } else {

View File

@ -177,7 +177,6 @@ pub struct AuthorizationRequest {
#[derive(Serialize, Deserialize, Default)] #[derive(Serialize, Deserialize, Default)]
pub struct AuthorizationResponse<R> { pub struct AuthorizationResponse<R> {
pub code: Option<String>, pub code: Option<String>,
pub state: Option<String>,
#[serde(flatten)] #[serde(flatten)]
pub response: R, pub response: R,
} }