diff --git a/crates/data-model/src/lib.rs b/crates/data-model/src/lib.rs index f8310897..25690f49 100644 --- a/crates/data-model/src/lib.rs +++ b/crates/data-model/src/lib.rs @@ -30,8 +30,8 @@ pub(crate) mod users; pub use self::{ oauth2::{ - AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, JwksOrJwksUri, - Pkce, Session, + AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, + InvalidRedirectUriError, JwksOrJwksUri, Pkce, Session, }, tokens::{AccessToken, RefreshToken, TokenFormatError, TokenType}, traits::{StorageBackend, StorageBackendMarker}, diff --git a/crates/data-model/src/oauth2/mod.rs b/crates/data-model/src/oauth2/mod.rs index a0c7237b..ef512260 100644 --- a/crates/data-model/src/oauth2/mod.rs +++ b/crates/data-model/src/oauth2/mod.rs @@ -18,6 +18,6 @@ pub(self) mod session; pub use self::{ authorization_grant::{AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Pkce}, - client::{Client, JwksOrJwksUri}, + client::{Client, InvalidRedirectUriError, JwksOrJwksUri}, session::Session, }; diff --git a/crates/handlers/src/oauth2/authorization.rs b/crates/handlers/src/oauth2/authorization.rs index 8e94116d..70810d8d 100644 --- a/crates/handlers/src/oauth2/authorization.rs +++ b/crates/handlers/src/oauth2/authorization.rs @@ -12,12 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::HashMap; - -use anyhow::Context; +use anyhow::{anyhow, Context}; use axum::{ extract::{Extension, Form, Query}, - response::{Html, IntoResponse, Redirect, Response}, + response::{IntoResponse, Redirect, Response}, }; use axum_extra::extract::PrivateCookieJar; use chrono::Duration; @@ -44,7 +42,7 @@ use mas_storage::{ }, PostgresqlBackend, }; -use mas_templates::{FormPostContext, Templates}; +use mas_templates::Templates; use oauth2_types::{ errors::{ INVALID_REQUEST, LOGIN_REQUIRED, REGISTRATION_NOT_SUPPORTED, REQUEST_NOT_SUPPORTED, @@ -62,26 +60,55 @@ use rand::{distributions::Alphanumeric, thread_rng, Rng}; use serde::{Deserialize, Serialize}; use sqlx::{PgConnection, PgPool, Postgres, Transaction}; use thiserror::Error; -use url::Url; +use self::callback::CallbackDestination; use super::consent::ConsentRequest; use crate::views::{LoginRequest, PostAuthAction, ReauthRequest, RegisterRequest}; +mod callback; + #[derive(Debug, Error)] pub enum RouteError { #[error(transparent)] Internal(Box), + #[error(transparent)] Anyhow(anyhow::Error), + #[error("could not find client")] ClientNotFound, + #[error("invalid redirect uri")] - InvalidRedirectUri, + InvalidRedirectUri(#[from] self::callback::InvalidRedirectUriError), + + #[error("invalid redirect uri")] + UnknownRedirectUri(#[from] mas_data_model::InvalidRedirectUriError), } impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { - StatusCode::INTERNAL_SERVER_ERROR.into_response() + // TODO: better error pages + match self { + RouteError::Internal(e) => { + (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response() + } + RouteError::Anyhow(e) => { + (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response() + } + RouteError::ClientNotFound => { + (StatusCode::BAD_REQUEST, "could not find client").into_response() + } + RouteError::InvalidRedirectUri(e) => ( + StatusCode::BAD_REQUEST, + format!("Invalid redirect URI ({})", e), + ) + .into_response(), + RouteError::UnknownRedirectUri(e) => ( + StatusCode::BAD_REQUEST, + format!("Invalid redirect URI ({})", e), + ) + .into_response(), + } } } @@ -91,6 +118,12 @@ impl From for RouteError { } } +impl From for RouteError { + fn from(e: self::callback::CallbackDestinationError) -> Self { + Self::Internal(Box::new(e)) + } +} + impl From for RouteError { fn from(e: ClientFetchError) -> Self { if e.not_found() { @@ -107,92 +140,6 @@ impl From for RouteError { } } -async fn back_to_client( - redirect_uri: &Url, - response_mode: ResponseMode, - state: Option, - params: T, - templates: &Templates, -) -> Result -where - T: Serialize, -{ - #[derive(Serialize)] - struct AllParams<'s, T> { - #[serde(flatten, skip_serializing_if = "Option::is_none")] - existing: Option>, - - #[serde(skip_serializing_if = "Option::is_none")] - state: Option, - - #[serde(flatten)] - params: T, - } - - #[derive(Serialize)] - struct ParamsWithState { - #[serde(skip_serializing_if = "Option::is_none")] - state: Option, - - #[serde(flatten)] - params: T, - } - - let mut redirect_uri = redirect_uri.clone(); - - match response_mode { - ResponseMode::Query => { - let existing: Option> = redirect_uri - .query() - .map(serde_urlencoded::from_str) - .transpose() - .map_err(|_e| RouteError::InvalidRedirectUri)?; - - let merged = AllParams { - existing, - state, - params, - }; - - let new_qs = serde_urlencoded::to_string(merged) - .context("could not serialize redirect URI query params")?; - - redirect_uri.set_query(Some(&new_qs)); - - Ok(Redirect::to(redirect_uri.as_str()).into_response()) - } - ResponseMode::Fragment => { - let existing: Option> = redirect_uri - .fragment() - .map(serde_urlencoded::from_str) - .transpose() - .map_err(|_e| RouteError::InvalidRedirectUri)?; - - let merged = AllParams { - existing, - state, - params, - }; - - let new_qs = serde_urlencoded::to_string(merged) - .context("could not serialize redirect URI fragment params")?; - - redirect_uri.set_fragment(Some(&new_qs)); - - Ok(Redirect::to(redirect_uri.as_str()).into_response()) - } - ResponseMode::FormPost => { - let merged = ParamsWithState { state, params }; - let ctx = FormPostContext::new(redirect_uri, merged); - let rendered = templates - .render_form_post(&ctx) - .await - .context("failed to render form_post.html")?; - Ok(Html(rendered).into_response()) - } - } -} - #[derive(Deserialize)] pub(crate) struct Params { #[serde(flatten)] @@ -217,7 +164,7 @@ fn resolve_response_mode( if response_type.has_token() || response_type.has_id_token() { match suggested_response_mode { None => Ok(M::Fragment), - Some(M::Query) => Err(anyhow::anyhow!("invalid response mode")), + Some(M::Query) => Err(anyhow!("invalid response mode")), Some(mode) => Ok(mode), } } else { @@ -248,59 +195,44 @@ pub(crate) async fn get( let client = lookup_client_by_client_id(&mut txn, ¶ms.auth.client_id).await?; let redirect_uri = client - .resolve_redirect_uri(¶ms.auth.redirect_uri) - .map_err(|_e| RouteError::InvalidRedirectUri)? + .resolve_redirect_uri(¶ms.auth.redirect_uri)? .clone(); let response_type = params.auth.response_type; let response_mode = resolve_response_mode(response_type, params.auth.response_mode)?; + let callback_destination = CallbackDestination::try_new( + response_mode, + redirect_uri.clone(), + params.auth.state.clone(), + )?; + // One day, we will have try blocks let res: Result = (async move { // Check if the request/request_uri/registration params are used. If so, reply // with the right error since we don't support them. if params.auth.request.is_some() { - return back_to_client( - &redirect_uri, - response_mode, - params.auth.state, - REQUEST_NOT_SUPPORTED, - &templates, - ) - .await; + return Ok(callback_destination + .go(&templates, REQUEST_NOT_SUPPORTED) + .await?); } if params.auth.request_uri.is_some() { - return back_to_client( - &redirect_uri, - response_mode, - params.auth.state, - REQUEST_URI_NOT_SUPPORTED, - &templates, - ) - .await; + return Ok(callback_destination + .go(&templates, REQUEST_URI_NOT_SUPPORTED) + .await?); } if params.auth.registration.is_some() { - return back_to_client( - &redirect_uri, - response_mode, - params.auth.state, - REGISTRATION_NOT_SUPPORTED, - &templates, - ) - .await; + return Ok(callback_destination + .go(&templates, REGISTRATION_NOT_SUPPORTED) + .await?); } // Check if it is allowed to use this grant type if !client.grant_types.contains(&GrantType::AuthorizationCode) { - return back_to_client( - &redirect_uri, - response_mode, - params.auth.state, - UNAUTHORIZED_CLIENT, - &templates, - ) - .await; + return Ok(callback_destination + .go(&templates, UNAUTHORIZED_CLIENT) + .await?); } let code: Option = if response_type.has_code() { @@ -321,14 +253,7 @@ pub(crate) async fn get( // If the request had PKCE params but no code asked, it should get back with an // error if params.pkce.is_some() { - return back_to_client( - &redirect_uri, - response_mode, - params.auth.state, - INVALID_REQUEST, - &templates, - ) - .await; + return Ok(callback_destination.go(&templates, INVALID_REQUEST).await?); } None @@ -373,14 +298,7 @@ pub(crate) async fn get( (None, Some(Prompt::None)) => { // If there is no session and prompt=none was asked, go back to the client txn.commit().await?; - Ok(back_to_client( - &redirect_uri, - response_mode, - params.auth.state, - LOGIN_REQUIRED, - &templates, - ) - .await?) + Ok(callback_destination.go(&templates, LOGIN_REQUIRED).await?) } (Some(_), Some(Prompt::Consent)) => { // We're already logged in but consent was asked @@ -516,8 +434,10 @@ async fn step( // request using a signed cookie let grant = next.fetch_authorization_grant(&mut txn).await?; + let callback_destination = CallbackDestination::try_from(&grant)?; + if !matches!(grant.stage, AuthorizationGrantStage::Pending) { - return Err(anyhow::anyhow!("authorization grant not pending").into()); + return Err(anyhow!("authorization grant not pending").into()); } let current_consent = @@ -568,19 +488,14 @@ async fn step( // Did they request an ID token? if grant.response_type_id_token { - todo!("id tokens are not implemented yet"); + return Err(RouteError::Anyhow(anyhow!( + "id tokens are not implemented yet" + ))); } let params = serde_json::to_value(¶ms).unwrap(); - back_to_client( - &grant.redirect_uri, - grant.response_mode, - grant.state, - params, - templates, - ) - .await? + callback_destination.go(templates, params).await? } (true, Some(Authentication { created_at, .. })) if created_at > &grant.max_auth_time() => { let next: ConsentRequest = next.into(); diff --git a/crates/handlers/src/oauth2/authorization/callback.rs b/crates/handlers/src/oauth2/authorization/callback.rs new file mode 100644 index 00000000..3af2e094 --- /dev/null +++ b/crates/handlers/src/oauth2/authorization/callback.rs @@ -0,0 +1,166 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#![allow(clippy::module_name_repetitions)] + +use std::collections::HashMap; + +use axum::response::{Html, IntoResponse, Redirect, Response}; +use mas_data_model::{AuthorizationGrant, StorageBackend}; +use mas_templates::{FormPostContext, Templates}; +use oauth2_types::requests::ResponseMode; +use serde::Serialize; +use thiserror::Error; +use url::Url; + +enum CallbackDestinationMode { + Query { + existing_params: HashMap, + }, + Fragment, + FormPost, +} + +pub struct CallbackDestination { + mode: CallbackDestinationMode, + safe_redirect_uri: Url, + state: Option, +} + +#[derive(Debug, Error)] +pub enum InvalidRedirectUriError { + #[error("Redirect URI can't have a fragment")] + FragmentNotAllowed, + + #[error("Existing query parameters are not valid")] + InvalidQueryParams(#[from] serde_urlencoded::de::Error), +} + +#[derive(Debug, Error)] +pub enum CallbackDestinationError { + #[error("Failed to render the form_post template")] + FormPostRender(#[from] mas_templates::TemplateError), + + #[error("Failed to serialize parameters query string")] + ParamsSerialization(#[from] serde_urlencoded::ser::Error), +} + +impl TryFrom<&AuthorizationGrant> for CallbackDestination { + type Error = InvalidRedirectUriError; + + fn try_from(value: &AuthorizationGrant) -> Result { + Self::try_new( + value.response_mode, + value.redirect_uri.clone(), + value.state.clone(), + ) + } +} + +impl CallbackDestination { + pub fn try_new( + mode: ResponseMode, + mut redirect_uri: Url, + state: Option, + ) -> Result { + if redirect_uri.fragment().is_some() { + return Err(InvalidRedirectUriError::FragmentNotAllowed); + } + + let mode = match mode { + ResponseMode::Query => { + let existing_params = redirect_uri + .query() + .map(serde_urlencoded::from_str) + .transpose()? + .unwrap_or_default(); + + // Remove the query from the URL + redirect_uri.set_query(None); + + CallbackDestinationMode::Query { existing_params } + } + ResponseMode::Fragment => CallbackDestinationMode::Fragment, + ResponseMode::FormPost => CallbackDestinationMode::FormPost, + }; + + Ok(Self { + mode, + safe_redirect_uri: redirect_uri, + state, + }) + } + + pub async fn go( + self, + templates: &Templates, + params: T, + ) -> Result { + #[derive(Serialize)] + struct AllParams<'s, T> { + #[serde(flatten, skip_serializing_if = "Option::is_none")] + existing: Option<&'s HashMap>, + + #[serde(skip_serializing_if = "Option::is_none")] + state: Option, + + #[serde(flatten)] + params: T, + } + + let mut redirect_uri = self.safe_redirect_uri; + let state = self.state; + + match self.mode { + CallbackDestinationMode::Query { existing_params } => { + let merged = AllParams { + existing: Some(&existing_params), + state, + params, + }; + + let new_qs = serde_urlencoded::to_string(merged)?; + + redirect_uri.set_query(Some(&new_qs)); + + Ok(Redirect::to(redirect_uri.as_str()).into_response()) + } + + CallbackDestinationMode::Fragment => { + let merged = AllParams { + existing: None, + state, + params, + }; + + let new_qs = serde_urlencoded::to_string(merged)?; + + redirect_uri.set_fragment(Some(&new_qs)); + + Ok(Redirect::to(redirect_uri.as_str()).into_response()) + } + + CallbackDestinationMode::FormPost => { + let merged = AllParams { + existing: None, + state, + params, + }; + let ctx = FormPostContext::new(redirect_uri, merged); + let rendered = templates.render_form_post(&ctx).await?; + Ok(Html(rendered).into_response()) + } + } + } +}