1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Rewrite authorization code grant callback logic

This commit is contained in:
Quentin Gliech
2022-05-04 16:36:14 +02:00
parent b26be37f5a
commit 7a4dbd2910
4 changed files with 239 additions and 158 deletions

View File

@ -30,8 +30,8 @@ pub(crate) mod users;
pub use self::{
oauth2::{
AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, JwksOrJwksUri,
Pkce, Session,
AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client,
InvalidRedirectUriError, JwksOrJwksUri, Pkce, Session,
},
tokens::{AccessToken, RefreshToken, TokenFormatError, TokenType},
traits::{StorageBackend, StorageBackendMarker},

View File

@ -18,6 +18,6 @@ pub(self) mod session;
pub use self::{
authorization_grant::{AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Pkce},
client::{Client, JwksOrJwksUri},
client::{Client, InvalidRedirectUriError, JwksOrJwksUri},
session::Session,
};

View File

@ -12,12 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use anyhow::Context;
use anyhow::{anyhow, Context};
use axum::{
extract::{Extension, Form, Query},
response::{Html, IntoResponse, Redirect, Response},
response::{IntoResponse, Redirect, Response},
};
use axum_extra::extract::PrivateCookieJar;
use chrono::Duration;
@ -44,7 +42,7 @@ use mas_storage::{
},
PostgresqlBackend,
};
use mas_templates::{FormPostContext, Templates};
use mas_templates::Templates;
use oauth2_types::{
errors::{
INVALID_REQUEST, LOGIN_REQUIRED, REGISTRATION_NOT_SUPPORTED, REQUEST_NOT_SUPPORTED,
@ -62,26 +60,55 @@ use rand::{distributions::Alphanumeric, thread_rng, Rng};
use serde::{Deserialize, Serialize};
use sqlx::{PgConnection, PgPool, Postgres, Transaction};
use thiserror::Error;
use url::Url;
use self::callback::CallbackDestination;
use super::consent::ConsentRequest;
use crate::views::{LoginRequest, PostAuthAction, ReauthRequest, RegisterRequest};
mod callback;
#[derive(Debug, Error)]
pub enum RouteError {
#[error(transparent)]
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error(transparent)]
Anyhow(anyhow::Error),
#[error("could not find client")]
ClientNotFound,
#[error("invalid redirect uri")]
InvalidRedirectUri,
InvalidRedirectUri(#[from] self::callback::InvalidRedirectUriError),
#[error("invalid redirect uri")]
UnknownRedirectUri(#[from] mas_data_model::InvalidRedirectUriError),
}
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
StatusCode::INTERNAL_SERVER_ERROR.into_response()
// TODO: better error pages
match self {
RouteError::Internal(e) => {
(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response()
}
RouteError::Anyhow(e) => {
(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response()
}
RouteError::ClientNotFound => {
(StatusCode::BAD_REQUEST, "could not find client").into_response()
}
RouteError::InvalidRedirectUri(e) => (
StatusCode::BAD_REQUEST,
format!("Invalid redirect URI ({})", e),
)
.into_response(),
RouteError::UnknownRedirectUri(e) => (
StatusCode::BAD_REQUEST,
format!("Invalid redirect URI ({})", e),
)
.into_response(),
}
}
}
@ -91,6 +118,12 @@ impl From<sqlx::Error> for RouteError {
}
}
impl From<self::callback::CallbackDestinationError> for RouteError {
fn from(e: self::callback::CallbackDestinationError) -> Self {
Self::Internal(Box::new(e))
}
}
impl From<ClientFetchError> for RouteError {
fn from(e: ClientFetchError) -> Self {
if e.not_found() {
@ -107,92 +140,6 @@ impl From<anyhow::Error> for RouteError {
}
}
async fn back_to_client<T>(
redirect_uri: &Url,
response_mode: ResponseMode,
state: Option<String>,
params: T,
templates: &Templates,
) -> Result<Response, RouteError>
where
T: Serialize,
{
#[derive(Serialize)]
struct AllParams<'s, T> {
#[serde(flatten, skip_serializing_if = "Option::is_none")]
existing: Option<HashMap<&'s str, &'s str>>,
#[serde(skip_serializing_if = "Option::is_none")]
state: Option<String>,
#[serde(flatten)]
params: T,
}
#[derive(Serialize)]
struct ParamsWithState<T> {
#[serde(skip_serializing_if = "Option::is_none")]
state: Option<String>,
#[serde(flatten)]
params: T,
}
let mut redirect_uri = redirect_uri.clone();
match response_mode {
ResponseMode::Query => {
let existing: Option<HashMap<&str, &str>> = redirect_uri
.query()
.map(serde_urlencoded::from_str)
.transpose()
.map_err(|_e| RouteError::InvalidRedirectUri)?;
let merged = AllParams {
existing,
state,
params,
};
let new_qs = serde_urlencoded::to_string(merged)
.context("could not serialize redirect URI query params")?;
redirect_uri.set_query(Some(&new_qs));
Ok(Redirect::to(redirect_uri.as_str()).into_response())
}
ResponseMode::Fragment => {
let existing: Option<HashMap<&str, &str>> = redirect_uri
.fragment()
.map(serde_urlencoded::from_str)
.transpose()
.map_err(|_e| RouteError::InvalidRedirectUri)?;
let merged = AllParams {
existing,
state,
params,
};
let new_qs = serde_urlencoded::to_string(merged)
.context("could not serialize redirect URI fragment params")?;
redirect_uri.set_fragment(Some(&new_qs));
Ok(Redirect::to(redirect_uri.as_str()).into_response())
}
ResponseMode::FormPost => {
let merged = ParamsWithState { state, params };
let ctx = FormPostContext::new(redirect_uri, merged);
let rendered = templates
.render_form_post(&ctx)
.await
.context("failed to render form_post.html")?;
Ok(Html(rendered).into_response())
}
}
}
#[derive(Deserialize)]
pub(crate) struct Params {
#[serde(flatten)]
@ -217,7 +164,7 @@ fn resolve_response_mode(
if response_type.has_token() || response_type.has_id_token() {
match suggested_response_mode {
None => Ok(M::Fragment),
Some(M::Query) => Err(anyhow::anyhow!("invalid response mode")),
Some(M::Query) => Err(anyhow!("invalid response mode")),
Some(mode) => Ok(mode),
}
} else {
@ -248,59 +195,44 @@ pub(crate) async fn get(
let client = lookup_client_by_client_id(&mut txn, &params.auth.client_id).await?;
let redirect_uri = client
.resolve_redirect_uri(&params.auth.redirect_uri)
.map_err(|_e| RouteError::InvalidRedirectUri)?
.resolve_redirect_uri(&params.auth.redirect_uri)?
.clone();
let response_type = params.auth.response_type;
let response_mode = resolve_response_mode(response_type, params.auth.response_mode)?;
let callback_destination = CallbackDestination::try_new(
response_mode,
redirect_uri.clone(),
params.auth.state.clone(),
)?;
// One day, we will have try blocks
let res: Result<Response, RouteError> = (async move {
// Check if the request/request_uri/registration params are used. If so, reply
// with the right error since we don't support them.
if params.auth.request.is_some() {
return back_to_client(
&redirect_uri,
response_mode,
params.auth.state,
REQUEST_NOT_SUPPORTED,
&templates,
)
.await;
return Ok(callback_destination
.go(&templates, REQUEST_NOT_SUPPORTED)
.await?);
}
if params.auth.request_uri.is_some() {
return back_to_client(
&redirect_uri,
response_mode,
params.auth.state,
REQUEST_URI_NOT_SUPPORTED,
&templates,
)
.await;
return Ok(callback_destination
.go(&templates, REQUEST_URI_NOT_SUPPORTED)
.await?);
}
if params.auth.registration.is_some() {
return back_to_client(
&redirect_uri,
response_mode,
params.auth.state,
REGISTRATION_NOT_SUPPORTED,
&templates,
)
.await;
return Ok(callback_destination
.go(&templates, REGISTRATION_NOT_SUPPORTED)
.await?);
}
// Check if it is allowed to use this grant type
if !client.grant_types.contains(&GrantType::AuthorizationCode) {
return back_to_client(
&redirect_uri,
response_mode,
params.auth.state,
UNAUTHORIZED_CLIENT,
&templates,
)
.await;
return Ok(callback_destination
.go(&templates, UNAUTHORIZED_CLIENT)
.await?);
}
let code: Option<AuthorizationCode> = if response_type.has_code() {
@ -321,14 +253,7 @@ pub(crate) async fn get(
// If the request had PKCE params but no code asked, it should get back with an
// error
if params.pkce.is_some() {
return back_to_client(
&redirect_uri,
response_mode,
params.auth.state,
INVALID_REQUEST,
&templates,
)
.await;
return Ok(callback_destination.go(&templates, INVALID_REQUEST).await?);
}
None
@ -373,14 +298,7 @@ pub(crate) async fn get(
(None, Some(Prompt::None)) => {
// If there is no session and prompt=none was asked, go back to the client
txn.commit().await?;
Ok(back_to_client(
&redirect_uri,
response_mode,
params.auth.state,
LOGIN_REQUIRED,
&templates,
)
.await?)
Ok(callback_destination.go(&templates, LOGIN_REQUIRED).await?)
}
(Some(_), Some(Prompt::Consent)) => {
// We're already logged in but consent was asked
@ -516,8 +434,10 @@ async fn step(
// request using a signed cookie
let grant = next.fetch_authorization_grant(&mut txn).await?;
let callback_destination = CallbackDestination::try_from(&grant)?;
if !matches!(grant.stage, AuthorizationGrantStage::Pending) {
return Err(anyhow::anyhow!("authorization grant not pending").into());
return Err(anyhow!("authorization grant not pending").into());
}
let current_consent =
@ -568,19 +488,14 @@ async fn step(
// Did they request an ID token?
if grant.response_type_id_token {
todo!("id tokens are not implemented yet");
return Err(RouteError::Anyhow(anyhow!(
"id tokens are not implemented yet"
)));
}
let params = serde_json::to_value(&params).unwrap();
back_to_client(
&grant.redirect_uri,
grant.response_mode,
grant.state,
params,
templates,
)
.await?
callback_destination.go(templates, params).await?
}
(true, Some(Authentication { created_at, .. })) if created_at > &grant.max_auth_time() => {
let next: ConsentRequest = next.into();

View File

@ -0,0 +1,166 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#![allow(clippy::module_name_repetitions)]
use std::collections::HashMap;
use axum::response::{Html, IntoResponse, Redirect, Response};
use mas_data_model::{AuthorizationGrant, StorageBackend};
use mas_templates::{FormPostContext, Templates};
use oauth2_types::requests::ResponseMode;
use serde::Serialize;
use thiserror::Error;
use url::Url;
enum CallbackDestinationMode {
Query {
existing_params: HashMap<String, String>,
},
Fragment,
FormPost,
}
pub struct CallbackDestination {
mode: CallbackDestinationMode,
safe_redirect_uri: Url,
state: Option<String>,
}
#[derive(Debug, Error)]
pub enum InvalidRedirectUriError {
#[error("Redirect URI can't have a fragment")]
FragmentNotAllowed,
#[error("Existing query parameters are not valid")]
InvalidQueryParams(#[from] serde_urlencoded::de::Error),
}
#[derive(Debug, Error)]
pub enum CallbackDestinationError {
#[error("Failed to render the form_post template")]
FormPostRender(#[from] mas_templates::TemplateError),
#[error("Failed to serialize parameters query string")]
ParamsSerialization(#[from] serde_urlencoded::ser::Error),
}
impl<S: StorageBackend> TryFrom<&AuthorizationGrant<S>> for CallbackDestination {
type Error = InvalidRedirectUriError;
fn try_from(value: &AuthorizationGrant<S>) -> Result<Self, Self::Error> {
Self::try_new(
value.response_mode,
value.redirect_uri.clone(),
value.state.clone(),
)
}
}
impl CallbackDestination {
pub fn try_new(
mode: ResponseMode,
mut redirect_uri: Url,
state: Option<String>,
) -> Result<Self, InvalidRedirectUriError> {
if redirect_uri.fragment().is_some() {
return Err(InvalidRedirectUriError::FragmentNotAllowed);
}
let mode = match mode {
ResponseMode::Query => {
let existing_params = redirect_uri
.query()
.map(serde_urlencoded::from_str)
.transpose()?
.unwrap_or_default();
// Remove the query from the URL
redirect_uri.set_query(None);
CallbackDestinationMode::Query { existing_params }
}
ResponseMode::Fragment => CallbackDestinationMode::Fragment,
ResponseMode::FormPost => CallbackDestinationMode::FormPost,
};
Ok(Self {
mode,
safe_redirect_uri: redirect_uri,
state,
})
}
pub async fn go<T: Serialize>(
self,
templates: &Templates,
params: T,
) -> Result<Response, CallbackDestinationError> {
#[derive(Serialize)]
struct AllParams<'s, T> {
#[serde(flatten, skip_serializing_if = "Option::is_none")]
existing: Option<&'s HashMap<String, String>>,
#[serde(skip_serializing_if = "Option::is_none")]
state: Option<String>,
#[serde(flatten)]
params: T,
}
let mut redirect_uri = self.safe_redirect_uri;
let state = self.state;
match self.mode {
CallbackDestinationMode::Query { existing_params } => {
let merged = AllParams {
existing: Some(&existing_params),
state,
params,
};
let new_qs = serde_urlencoded::to_string(merged)?;
redirect_uri.set_query(Some(&new_qs));
Ok(Redirect::to(redirect_uri.as_str()).into_response())
}
CallbackDestinationMode::Fragment => {
let merged = AllParams {
existing: None,
state,
params,
};
let new_qs = serde_urlencoded::to_string(merged)?;
redirect_uri.set_fragment(Some(&new_qs));
Ok(Redirect::to(redirect_uri.as_str()).into_response())
}
CallbackDestinationMode::FormPost => {
let merged = AllParams {
existing: None,
state,
params,
};
let ctx = FormPostContext::new(redirect_uri, merged);
let rendered = templates.render_form_post(&ctx).await?;
Ok(Html(rendered).into_response())
}
}
}
}