// Copyright 2022 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #![allow(clippy::module_name_repetitions)] use std::collections::HashMap; use axum::response::{Html, IntoResponse, Redirect, Response}; use mas_data_model::{AuthorizationGrant, StorageBackend}; use mas_templates::{FormPostContext, Templates}; use oauth2_types::requests::ResponseMode; use serde::Serialize; use thiserror::Error; use url::Url; #[derive(Debug, Clone)] enum CallbackDestinationMode { Query { existing_params: HashMap, }, Fragment, FormPost, } #[derive(Debug, Clone)] pub struct CallbackDestination { mode: CallbackDestinationMode, safe_redirect_uri: Url, state: Option, } #[derive(Debug, Error)] pub enum IntoCallbackDestinationError { #[error("Redirect URI can't have a fragment")] RedirectUriFragmentNotAllowed, #[error("Existing query parameters are not valid")] RedirectUriInvalidQueryParams(#[from] serde_urlencoded::de::Error), #[error("Requested response_mode is not supported")] UnsupportedResponseMode, } #[derive(Debug, Error)] pub enum CallbackDestinationError { #[error("Failed to render the form_post template")] FormPostRender(#[from] mas_templates::TemplateError), #[error("Failed to serialize parameters query string")] ParamsSerialization(#[from] serde_urlencoded::ser::Error), } impl TryFrom<&AuthorizationGrant> for CallbackDestination { type Error = IntoCallbackDestinationError; fn try_from(value: &AuthorizationGrant) -> Result { Self::try_new( &value.response_mode, value.redirect_uri.clone(), value.state.clone(), ) } } impl CallbackDestination { pub fn try_new( mode: &ResponseMode, mut redirect_uri: Url, state: Option, ) -> Result { if redirect_uri.fragment().is_some() { return Err(IntoCallbackDestinationError::RedirectUriFragmentNotAllowed); } let mode = match mode { ResponseMode::Query => { let existing_params = redirect_uri .query() .map(serde_urlencoded::from_str) .transpose()? .unwrap_or_default(); // Remove the query from the URL redirect_uri.set_query(None); CallbackDestinationMode::Query { existing_params } } ResponseMode::Fragment => CallbackDestinationMode::Fragment, ResponseMode::FormPost => CallbackDestinationMode::FormPost, _ => return Err(IntoCallbackDestinationError::UnsupportedResponseMode), }; Ok(Self { mode, safe_redirect_uri: redirect_uri, state, }) } pub async fn go( self, templates: &Templates, params: T, ) -> Result { #[derive(Serialize)] struct AllParams<'s, T> { #[serde(flatten, skip_serializing_if = "Option::is_none")] existing: Option<&'s HashMap>, #[serde(skip_serializing_if = "Option::is_none")] state: Option, #[serde(flatten)] params: T, } let mut redirect_uri = self.safe_redirect_uri; let state = self.state; match self.mode { CallbackDestinationMode::Query { existing_params } => { let merged = AllParams { existing: Some(&existing_params), state, params, }; let new_qs = serde_urlencoded::to_string(merged)?; redirect_uri.set_query(Some(&new_qs)); Ok(Redirect::to(redirect_uri.as_str()).into_response()) } CallbackDestinationMode::Fragment => { let merged = AllParams { existing: None, state, params, }; let new_qs = serde_urlencoded::to_string(merged)?; redirect_uri.set_fragment(Some(&new_qs)); Ok(Redirect::to(redirect_uri.as_str()).into_response()) } CallbackDestinationMode::FormPost => { let merged = AllParams { existing: None, state, params, }; let ctx = FormPostContext::new(redirect_uri, merged); let rendered = templates.render_form_post(&ctx).await?; Ok(Html(rendered).into_response()) } } } }