You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-31 09:24:31 +03:00
Rewrite authorization code grant callback logic
This commit is contained in:
@ -30,8 +30,8 @@ pub(crate) mod users;
|
|||||||
|
|
||||||
pub use self::{
|
pub use self::{
|
||||||
oauth2::{
|
oauth2::{
|
||||||
AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, JwksOrJwksUri,
|
AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client,
|
||||||
Pkce, Session,
|
InvalidRedirectUriError, JwksOrJwksUri, Pkce, Session,
|
||||||
},
|
},
|
||||||
tokens::{AccessToken, RefreshToken, TokenFormatError, TokenType},
|
tokens::{AccessToken, RefreshToken, TokenFormatError, TokenType},
|
||||||
traits::{StorageBackend, StorageBackendMarker},
|
traits::{StorageBackend, StorageBackendMarker},
|
||||||
|
@ -18,6 +18,6 @@ pub(self) mod session;
|
|||||||
|
|
||||||
pub use self::{
|
pub use self::{
|
||||||
authorization_grant::{AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Pkce},
|
authorization_grant::{AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Pkce},
|
||||||
client::{Client, JwksOrJwksUri},
|
client::{Client, InvalidRedirectUriError, JwksOrJwksUri},
|
||||||
session::Session,
|
session::Session,
|
||||||
};
|
};
|
||||||
|
@ -12,12 +12,10 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use std::collections::HashMap;
|
use anyhow::{anyhow, Context};
|
||||||
|
|
||||||
use anyhow::Context;
|
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{Extension, Form, Query},
|
extract::{Extension, Form, Query},
|
||||||
response::{Html, IntoResponse, Redirect, Response},
|
response::{IntoResponse, Redirect, Response},
|
||||||
};
|
};
|
||||||
use axum_extra::extract::PrivateCookieJar;
|
use axum_extra::extract::PrivateCookieJar;
|
||||||
use chrono::Duration;
|
use chrono::Duration;
|
||||||
@ -44,7 +42,7 @@ use mas_storage::{
|
|||||||
},
|
},
|
||||||
PostgresqlBackend,
|
PostgresqlBackend,
|
||||||
};
|
};
|
||||||
use mas_templates::{FormPostContext, Templates};
|
use mas_templates::Templates;
|
||||||
use oauth2_types::{
|
use oauth2_types::{
|
||||||
errors::{
|
errors::{
|
||||||
INVALID_REQUEST, LOGIN_REQUIRED, REGISTRATION_NOT_SUPPORTED, REQUEST_NOT_SUPPORTED,
|
INVALID_REQUEST, LOGIN_REQUIRED, REGISTRATION_NOT_SUPPORTED, REQUEST_NOT_SUPPORTED,
|
||||||
@ -62,26 +60,55 @@ use rand::{distributions::Alphanumeric, thread_rng, Rng};
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use sqlx::{PgConnection, PgPool, Postgres, Transaction};
|
use sqlx::{PgConnection, PgPool, Postgres, Transaction};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use url::Url;
|
|
||||||
|
|
||||||
|
use self::callback::CallbackDestination;
|
||||||
use super::consent::ConsentRequest;
|
use super::consent::ConsentRequest;
|
||||||
use crate::views::{LoginRequest, PostAuthAction, ReauthRequest, RegisterRequest};
|
use crate::views::{LoginRequest, PostAuthAction, ReauthRequest, RegisterRequest};
|
||||||
|
|
||||||
|
mod callback;
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
pub enum RouteError {
|
pub enum RouteError {
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
|
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
|
||||||
|
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
Anyhow(anyhow::Error),
|
Anyhow(anyhow::Error),
|
||||||
|
|
||||||
#[error("could not find client")]
|
#[error("could not find client")]
|
||||||
ClientNotFound,
|
ClientNotFound,
|
||||||
|
|
||||||
#[error("invalid redirect uri")]
|
#[error("invalid redirect uri")]
|
||||||
InvalidRedirectUri,
|
InvalidRedirectUri(#[from] self::callback::InvalidRedirectUriError),
|
||||||
|
|
||||||
|
#[error("invalid redirect uri")]
|
||||||
|
UnknownRedirectUri(#[from] mas_data_model::InvalidRedirectUriError),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl IntoResponse for RouteError {
|
impl IntoResponse for RouteError {
|
||||||
fn into_response(self) -> axum::response::Response {
|
fn into_response(self) -> axum::response::Response {
|
||||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
// TODO: better error pages
|
||||||
|
match self {
|
||||||
|
RouteError::Internal(e) => {
|
||||||
|
(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response()
|
||||||
|
}
|
||||||
|
RouteError::Anyhow(e) => {
|
||||||
|
(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response()
|
||||||
|
}
|
||||||
|
RouteError::ClientNotFound => {
|
||||||
|
(StatusCode::BAD_REQUEST, "could not find client").into_response()
|
||||||
|
}
|
||||||
|
RouteError::InvalidRedirectUri(e) => (
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
format!("Invalid redirect URI ({})", e),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
RouteError::UnknownRedirectUri(e) => (
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
format!("Invalid redirect URI ({})", e),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -91,6 +118,12 @@ impl From<sqlx::Error> for RouteError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<self::callback::CallbackDestinationError> for RouteError {
|
||||||
|
fn from(e: self::callback::CallbackDestinationError) -> Self {
|
||||||
|
Self::Internal(Box::new(e))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl From<ClientFetchError> for RouteError {
|
impl From<ClientFetchError> for RouteError {
|
||||||
fn from(e: ClientFetchError) -> Self {
|
fn from(e: ClientFetchError) -> Self {
|
||||||
if e.not_found() {
|
if e.not_found() {
|
||||||
@ -107,92 +140,6 @@ impl From<anyhow::Error> for RouteError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn back_to_client<T>(
|
|
||||||
redirect_uri: &Url,
|
|
||||||
response_mode: ResponseMode,
|
|
||||||
state: Option<String>,
|
|
||||||
params: T,
|
|
||||||
templates: &Templates,
|
|
||||||
) -> Result<Response, RouteError>
|
|
||||||
where
|
|
||||||
T: Serialize,
|
|
||||||
{
|
|
||||||
#[derive(Serialize)]
|
|
||||||
struct AllParams<'s, T> {
|
|
||||||
#[serde(flatten, skip_serializing_if = "Option::is_none")]
|
|
||||||
existing: Option<HashMap<&'s str, &'s str>>,
|
|
||||||
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
state: Option<String>,
|
|
||||||
|
|
||||||
#[serde(flatten)]
|
|
||||||
params: T,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
struct ParamsWithState<T> {
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
state: Option<String>,
|
|
||||||
|
|
||||||
#[serde(flatten)]
|
|
||||||
params: T,
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut redirect_uri = redirect_uri.clone();
|
|
||||||
|
|
||||||
match response_mode {
|
|
||||||
ResponseMode::Query => {
|
|
||||||
let existing: Option<HashMap<&str, &str>> = redirect_uri
|
|
||||||
.query()
|
|
||||||
.map(serde_urlencoded::from_str)
|
|
||||||
.transpose()
|
|
||||||
.map_err(|_e| RouteError::InvalidRedirectUri)?;
|
|
||||||
|
|
||||||
let merged = AllParams {
|
|
||||||
existing,
|
|
||||||
state,
|
|
||||||
params,
|
|
||||||
};
|
|
||||||
|
|
||||||
let new_qs = serde_urlencoded::to_string(merged)
|
|
||||||
.context("could not serialize redirect URI query params")?;
|
|
||||||
|
|
||||||
redirect_uri.set_query(Some(&new_qs));
|
|
||||||
|
|
||||||
Ok(Redirect::to(redirect_uri.as_str()).into_response())
|
|
||||||
}
|
|
||||||
ResponseMode::Fragment => {
|
|
||||||
let existing: Option<HashMap<&str, &str>> = redirect_uri
|
|
||||||
.fragment()
|
|
||||||
.map(serde_urlencoded::from_str)
|
|
||||||
.transpose()
|
|
||||||
.map_err(|_e| RouteError::InvalidRedirectUri)?;
|
|
||||||
|
|
||||||
let merged = AllParams {
|
|
||||||
existing,
|
|
||||||
state,
|
|
||||||
params,
|
|
||||||
};
|
|
||||||
|
|
||||||
let new_qs = serde_urlencoded::to_string(merged)
|
|
||||||
.context("could not serialize redirect URI fragment params")?;
|
|
||||||
|
|
||||||
redirect_uri.set_fragment(Some(&new_qs));
|
|
||||||
|
|
||||||
Ok(Redirect::to(redirect_uri.as_str()).into_response())
|
|
||||||
}
|
|
||||||
ResponseMode::FormPost => {
|
|
||||||
let merged = ParamsWithState { state, params };
|
|
||||||
let ctx = FormPostContext::new(redirect_uri, merged);
|
|
||||||
let rendered = templates
|
|
||||||
.render_form_post(&ctx)
|
|
||||||
.await
|
|
||||||
.context("failed to render form_post.html")?;
|
|
||||||
Ok(Html(rendered).into_response())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
pub(crate) struct Params {
|
pub(crate) struct Params {
|
||||||
#[serde(flatten)]
|
#[serde(flatten)]
|
||||||
@ -217,7 +164,7 @@ fn resolve_response_mode(
|
|||||||
if response_type.has_token() || response_type.has_id_token() {
|
if response_type.has_token() || response_type.has_id_token() {
|
||||||
match suggested_response_mode {
|
match suggested_response_mode {
|
||||||
None => Ok(M::Fragment),
|
None => Ok(M::Fragment),
|
||||||
Some(M::Query) => Err(anyhow::anyhow!("invalid response mode")),
|
Some(M::Query) => Err(anyhow!("invalid response mode")),
|
||||||
Some(mode) => Ok(mode),
|
Some(mode) => Ok(mode),
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -248,59 +195,44 @@ pub(crate) async fn get(
|
|||||||
let client = lookup_client_by_client_id(&mut txn, ¶ms.auth.client_id).await?;
|
let client = lookup_client_by_client_id(&mut txn, ¶ms.auth.client_id).await?;
|
||||||
|
|
||||||
let redirect_uri = client
|
let redirect_uri = client
|
||||||
.resolve_redirect_uri(¶ms.auth.redirect_uri)
|
.resolve_redirect_uri(¶ms.auth.redirect_uri)?
|
||||||
.map_err(|_e| RouteError::InvalidRedirectUri)?
|
|
||||||
.clone();
|
.clone();
|
||||||
let response_type = params.auth.response_type;
|
let response_type = params.auth.response_type;
|
||||||
let response_mode = resolve_response_mode(response_type, params.auth.response_mode)?;
|
let response_mode = resolve_response_mode(response_type, params.auth.response_mode)?;
|
||||||
|
|
||||||
|
let callback_destination = CallbackDestination::try_new(
|
||||||
|
response_mode,
|
||||||
|
redirect_uri.clone(),
|
||||||
|
params.auth.state.clone(),
|
||||||
|
)?;
|
||||||
|
|
||||||
// One day, we will have try blocks
|
// One day, we will have try blocks
|
||||||
let res: Result<Response, RouteError> = (async move {
|
let res: Result<Response, RouteError> = (async move {
|
||||||
// Check if the request/request_uri/registration params are used. If so, reply
|
// Check if the request/request_uri/registration params are used. If so, reply
|
||||||
// with the right error since we don't support them.
|
// with the right error since we don't support them.
|
||||||
if params.auth.request.is_some() {
|
if params.auth.request.is_some() {
|
||||||
return back_to_client(
|
return Ok(callback_destination
|
||||||
&redirect_uri,
|
.go(&templates, REQUEST_NOT_SUPPORTED)
|
||||||
response_mode,
|
.await?);
|
||||||
params.auth.state,
|
|
||||||
REQUEST_NOT_SUPPORTED,
|
|
||||||
&templates,
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if params.auth.request_uri.is_some() {
|
if params.auth.request_uri.is_some() {
|
||||||
return back_to_client(
|
return Ok(callback_destination
|
||||||
&redirect_uri,
|
.go(&templates, REQUEST_URI_NOT_SUPPORTED)
|
||||||
response_mode,
|
.await?);
|
||||||
params.auth.state,
|
|
||||||
REQUEST_URI_NOT_SUPPORTED,
|
|
||||||
&templates,
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if params.auth.registration.is_some() {
|
if params.auth.registration.is_some() {
|
||||||
return back_to_client(
|
return Ok(callback_destination
|
||||||
&redirect_uri,
|
.go(&templates, REGISTRATION_NOT_SUPPORTED)
|
||||||
response_mode,
|
.await?);
|
||||||
params.auth.state,
|
|
||||||
REGISTRATION_NOT_SUPPORTED,
|
|
||||||
&templates,
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if it is allowed to use this grant type
|
// Check if it is allowed to use this grant type
|
||||||
if !client.grant_types.contains(&GrantType::AuthorizationCode) {
|
if !client.grant_types.contains(&GrantType::AuthorizationCode) {
|
||||||
return back_to_client(
|
return Ok(callback_destination
|
||||||
&redirect_uri,
|
.go(&templates, UNAUTHORIZED_CLIENT)
|
||||||
response_mode,
|
.await?);
|
||||||
params.auth.state,
|
|
||||||
UNAUTHORIZED_CLIENT,
|
|
||||||
&templates,
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let code: Option<AuthorizationCode> = if response_type.has_code() {
|
let code: Option<AuthorizationCode> = if response_type.has_code() {
|
||||||
@ -321,14 +253,7 @@ pub(crate) async fn get(
|
|||||||
// If the request had PKCE params but no code asked, it should get back with an
|
// If the request had PKCE params but no code asked, it should get back with an
|
||||||
// error
|
// error
|
||||||
if params.pkce.is_some() {
|
if params.pkce.is_some() {
|
||||||
return back_to_client(
|
return Ok(callback_destination.go(&templates, INVALID_REQUEST).await?);
|
||||||
&redirect_uri,
|
|
||||||
response_mode,
|
|
||||||
params.auth.state,
|
|
||||||
INVALID_REQUEST,
|
|
||||||
&templates,
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
None
|
None
|
||||||
@ -373,14 +298,7 @@ pub(crate) async fn get(
|
|||||||
(None, Some(Prompt::None)) => {
|
(None, Some(Prompt::None)) => {
|
||||||
// If there is no session and prompt=none was asked, go back to the client
|
// If there is no session and prompt=none was asked, go back to the client
|
||||||
txn.commit().await?;
|
txn.commit().await?;
|
||||||
Ok(back_to_client(
|
Ok(callback_destination.go(&templates, LOGIN_REQUIRED).await?)
|
||||||
&redirect_uri,
|
|
||||||
response_mode,
|
|
||||||
params.auth.state,
|
|
||||||
LOGIN_REQUIRED,
|
|
||||||
&templates,
|
|
||||||
)
|
|
||||||
.await?)
|
|
||||||
}
|
}
|
||||||
(Some(_), Some(Prompt::Consent)) => {
|
(Some(_), Some(Prompt::Consent)) => {
|
||||||
// We're already logged in but consent was asked
|
// We're already logged in but consent was asked
|
||||||
@ -516,8 +434,10 @@ async fn step(
|
|||||||
// request using a signed cookie
|
// request using a signed cookie
|
||||||
let grant = next.fetch_authorization_grant(&mut txn).await?;
|
let grant = next.fetch_authorization_grant(&mut txn).await?;
|
||||||
|
|
||||||
|
let callback_destination = CallbackDestination::try_from(&grant)?;
|
||||||
|
|
||||||
if !matches!(grant.stage, AuthorizationGrantStage::Pending) {
|
if !matches!(grant.stage, AuthorizationGrantStage::Pending) {
|
||||||
return Err(anyhow::anyhow!("authorization grant not pending").into());
|
return Err(anyhow!("authorization grant not pending").into());
|
||||||
}
|
}
|
||||||
|
|
||||||
let current_consent =
|
let current_consent =
|
||||||
@ -568,19 +488,14 @@ async fn step(
|
|||||||
|
|
||||||
// Did they request an ID token?
|
// Did they request an ID token?
|
||||||
if grant.response_type_id_token {
|
if grant.response_type_id_token {
|
||||||
todo!("id tokens are not implemented yet");
|
return Err(RouteError::Anyhow(anyhow!(
|
||||||
|
"id tokens are not implemented yet"
|
||||||
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
let params = serde_json::to_value(¶ms).unwrap();
|
let params = serde_json::to_value(¶ms).unwrap();
|
||||||
|
|
||||||
back_to_client(
|
callback_destination.go(templates, params).await?
|
||||||
&grant.redirect_uri,
|
|
||||||
grant.response_mode,
|
|
||||||
grant.state,
|
|
||||||
params,
|
|
||||||
templates,
|
|
||||||
)
|
|
||||||
.await?
|
|
||||||
}
|
}
|
||||||
(true, Some(Authentication { created_at, .. })) if created_at > &grant.max_auth_time() => {
|
(true, Some(Authentication { created_at, .. })) if created_at > &grant.max_auth_time() => {
|
||||||
let next: ConsentRequest = next.into();
|
let next: ConsentRequest = next.into();
|
||||||
|
166
crates/handlers/src/oauth2/authorization/callback.rs
Normal file
166
crates/handlers/src/oauth2/authorization/callback.rs
Normal file
@ -0,0 +1,166 @@
|
|||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#![allow(clippy::module_name_repetitions)]
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use axum::response::{Html, IntoResponse, Redirect, Response};
|
||||||
|
use mas_data_model::{AuthorizationGrant, StorageBackend};
|
||||||
|
use mas_templates::{FormPostContext, Templates};
|
||||||
|
use oauth2_types::requests::ResponseMode;
|
||||||
|
use serde::Serialize;
|
||||||
|
use thiserror::Error;
|
||||||
|
use url::Url;
|
||||||
|
|
||||||
|
enum CallbackDestinationMode {
|
||||||
|
Query {
|
||||||
|
existing_params: HashMap<String, String>,
|
||||||
|
},
|
||||||
|
Fragment,
|
||||||
|
FormPost,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct CallbackDestination {
|
||||||
|
mode: CallbackDestinationMode,
|
||||||
|
safe_redirect_uri: Url,
|
||||||
|
state: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum InvalidRedirectUriError {
|
||||||
|
#[error("Redirect URI can't have a fragment")]
|
||||||
|
FragmentNotAllowed,
|
||||||
|
|
||||||
|
#[error("Existing query parameters are not valid")]
|
||||||
|
InvalidQueryParams(#[from] serde_urlencoded::de::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum CallbackDestinationError {
|
||||||
|
#[error("Failed to render the form_post template")]
|
||||||
|
FormPostRender(#[from] mas_templates::TemplateError),
|
||||||
|
|
||||||
|
#[error("Failed to serialize parameters query string")]
|
||||||
|
ParamsSerialization(#[from] serde_urlencoded::ser::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: StorageBackend> TryFrom<&AuthorizationGrant<S>> for CallbackDestination {
|
||||||
|
type Error = InvalidRedirectUriError;
|
||||||
|
|
||||||
|
fn try_from(value: &AuthorizationGrant<S>) -> Result<Self, Self::Error> {
|
||||||
|
Self::try_new(
|
||||||
|
value.response_mode,
|
||||||
|
value.redirect_uri.clone(),
|
||||||
|
value.state.clone(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CallbackDestination {
|
||||||
|
pub fn try_new(
|
||||||
|
mode: ResponseMode,
|
||||||
|
mut redirect_uri: Url,
|
||||||
|
state: Option<String>,
|
||||||
|
) -> Result<Self, InvalidRedirectUriError> {
|
||||||
|
if redirect_uri.fragment().is_some() {
|
||||||
|
return Err(InvalidRedirectUriError::FragmentNotAllowed);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mode = match mode {
|
||||||
|
ResponseMode::Query => {
|
||||||
|
let existing_params = redirect_uri
|
||||||
|
.query()
|
||||||
|
.map(serde_urlencoded::from_str)
|
||||||
|
.transpose()?
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
// Remove the query from the URL
|
||||||
|
redirect_uri.set_query(None);
|
||||||
|
|
||||||
|
CallbackDestinationMode::Query { existing_params }
|
||||||
|
}
|
||||||
|
ResponseMode::Fragment => CallbackDestinationMode::Fragment,
|
||||||
|
ResponseMode::FormPost => CallbackDestinationMode::FormPost,
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
mode,
|
||||||
|
safe_redirect_uri: redirect_uri,
|
||||||
|
state,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn go<T: Serialize>(
|
||||||
|
self,
|
||||||
|
templates: &Templates,
|
||||||
|
params: T,
|
||||||
|
) -> Result<Response, CallbackDestinationError> {
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct AllParams<'s, T> {
|
||||||
|
#[serde(flatten, skip_serializing_if = "Option::is_none")]
|
||||||
|
existing: Option<&'s HashMap<String, String>>,
|
||||||
|
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
state: Option<String>,
|
||||||
|
|
||||||
|
#[serde(flatten)]
|
||||||
|
params: T,
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut redirect_uri = self.safe_redirect_uri;
|
||||||
|
let state = self.state;
|
||||||
|
|
||||||
|
match self.mode {
|
||||||
|
CallbackDestinationMode::Query { existing_params } => {
|
||||||
|
let merged = AllParams {
|
||||||
|
existing: Some(&existing_params),
|
||||||
|
state,
|
||||||
|
params,
|
||||||
|
};
|
||||||
|
|
||||||
|
let new_qs = serde_urlencoded::to_string(merged)?;
|
||||||
|
|
||||||
|
redirect_uri.set_query(Some(&new_qs));
|
||||||
|
|
||||||
|
Ok(Redirect::to(redirect_uri.as_str()).into_response())
|
||||||
|
}
|
||||||
|
|
||||||
|
CallbackDestinationMode::Fragment => {
|
||||||
|
let merged = AllParams {
|
||||||
|
existing: None,
|
||||||
|
state,
|
||||||
|
params,
|
||||||
|
};
|
||||||
|
|
||||||
|
let new_qs = serde_urlencoded::to_string(merged)?;
|
||||||
|
|
||||||
|
redirect_uri.set_fragment(Some(&new_qs));
|
||||||
|
|
||||||
|
Ok(Redirect::to(redirect_uri.as_str()).into_response())
|
||||||
|
}
|
||||||
|
|
||||||
|
CallbackDestinationMode::FormPost => {
|
||||||
|
let merged = AllParams {
|
||||||
|
existing: None,
|
||||||
|
state,
|
||||||
|
params,
|
||||||
|
};
|
||||||
|
let ctx = FormPostContext::new(redirect_uri, merged);
|
||||||
|
let rendered = templates.render_form_post(&ctx).await?;
|
||||||
|
Ok(Html(rendered).into_response())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Reference in New Issue
Block a user