1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Rewrite authorization code grant callback logic

This commit is contained in:
Quentin Gliech
2022-05-04 16:36:14 +02:00
parent b26be37f5a
commit 7a4dbd2910
4 changed files with 239 additions and 158 deletions

View File

@ -30,8 +30,8 @@ pub(crate) mod users;
pub use self::{ pub use self::{
oauth2::{ oauth2::{
AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, JwksOrJwksUri, AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client,
Pkce, Session, InvalidRedirectUriError, JwksOrJwksUri, Pkce, Session,
}, },
tokens::{AccessToken, RefreshToken, TokenFormatError, TokenType}, tokens::{AccessToken, RefreshToken, TokenFormatError, TokenType},
traits::{StorageBackend, StorageBackendMarker}, traits::{StorageBackend, StorageBackendMarker},

View File

@ -18,6 +18,6 @@ pub(self) mod session;
pub use self::{ pub use self::{
authorization_grant::{AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Pkce}, authorization_grant::{AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Pkce},
client::{Client, JwksOrJwksUri}, client::{Client, InvalidRedirectUriError, JwksOrJwksUri},
session::Session, session::Session,
}; };

View File

@ -12,12 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::collections::HashMap; use anyhow::{anyhow, Context};
use anyhow::Context;
use axum::{ use axum::{
extract::{Extension, Form, Query}, extract::{Extension, Form, Query},
response::{Html, IntoResponse, Redirect, Response}, response::{IntoResponse, Redirect, Response},
}; };
use axum_extra::extract::PrivateCookieJar; use axum_extra::extract::PrivateCookieJar;
use chrono::Duration; use chrono::Duration;
@ -44,7 +42,7 @@ use mas_storage::{
}, },
PostgresqlBackend, PostgresqlBackend,
}; };
use mas_templates::{FormPostContext, Templates}; use mas_templates::Templates;
use oauth2_types::{ use oauth2_types::{
errors::{ errors::{
INVALID_REQUEST, LOGIN_REQUIRED, REGISTRATION_NOT_SUPPORTED, REQUEST_NOT_SUPPORTED, INVALID_REQUEST, LOGIN_REQUIRED, REGISTRATION_NOT_SUPPORTED, REQUEST_NOT_SUPPORTED,
@ -62,26 +60,55 @@ use rand::{distributions::Alphanumeric, thread_rng, Rng};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::{PgConnection, PgPool, Postgres, Transaction}; use sqlx::{PgConnection, PgPool, Postgres, Transaction};
use thiserror::Error; use thiserror::Error;
use url::Url;
use self::callback::CallbackDestination;
use super::consent::ConsentRequest; use super::consent::ConsentRequest;
use crate::views::{LoginRequest, PostAuthAction, ReauthRequest, RegisterRequest}; use crate::views::{LoginRequest, PostAuthAction, ReauthRequest, RegisterRequest};
mod callback;
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum RouteError { pub enum RouteError {
#[error(transparent)] #[error(transparent)]
Internal(Box<dyn std::error::Error + Send + Sync + 'static>), Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error(transparent)] #[error(transparent)]
Anyhow(anyhow::Error), Anyhow(anyhow::Error),
#[error("could not find client")] #[error("could not find client")]
ClientNotFound, ClientNotFound,
#[error("invalid redirect uri")] #[error("invalid redirect uri")]
InvalidRedirectUri, InvalidRedirectUri(#[from] self::callback::InvalidRedirectUriError),
#[error("invalid redirect uri")]
UnknownRedirectUri(#[from] mas_data_model::InvalidRedirectUriError),
} }
impl IntoResponse for RouteError { impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response { fn into_response(self) -> axum::response::Response {
StatusCode::INTERNAL_SERVER_ERROR.into_response() // TODO: better error pages
match self {
RouteError::Internal(e) => {
(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response()
}
RouteError::Anyhow(e) => {
(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response()
}
RouteError::ClientNotFound => {
(StatusCode::BAD_REQUEST, "could not find client").into_response()
}
RouteError::InvalidRedirectUri(e) => (
StatusCode::BAD_REQUEST,
format!("Invalid redirect URI ({})", e),
)
.into_response(),
RouteError::UnknownRedirectUri(e) => (
StatusCode::BAD_REQUEST,
format!("Invalid redirect URI ({})", e),
)
.into_response(),
}
} }
} }
@ -91,6 +118,12 @@ impl From<sqlx::Error> for RouteError {
} }
} }
impl From<self::callback::CallbackDestinationError> for RouteError {
fn from(e: self::callback::CallbackDestinationError) -> Self {
Self::Internal(Box::new(e))
}
}
impl From<ClientFetchError> for RouteError { impl From<ClientFetchError> for RouteError {
fn from(e: ClientFetchError) -> Self { fn from(e: ClientFetchError) -> Self {
if e.not_found() { if e.not_found() {
@ -107,92 +140,6 @@ impl From<anyhow::Error> for RouteError {
} }
} }
async fn back_to_client<T>(
redirect_uri: &Url,
response_mode: ResponseMode,
state: Option<String>,
params: T,
templates: &Templates,
) -> Result<Response, RouteError>
where
T: Serialize,
{
#[derive(Serialize)]
struct AllParams<'s, T> {
#[serde(flatten, skip_serializing_if = "Option::is_none")]
existing: Option<HashMap<&'s str, &'s str>>,
#[serde(skip_serializing_if = "Option::is_none")]
state: Option<String>,
#[serde(flatten)]
params: T,
}
#[derive(Serialize)]
struct ParamsWithState<T> {
#[serde(skip_serializing_if = "Option::is_none")]
state: Option<String>,
#[serde(flatten)]
params: T,
}
let mut redirect_uri = redirect_uri.clone();
match response_mode {
ResponseMode::Query => {
let existing: Option<HashMap<&str, &str>> = redirect_uri
.query()
.map(serde_urlencoded::from_str)
.transpose()
.map_err(|_e| RouteError::InvalidRedirectUri)?;
let merged = AllParams {
existing,
state,
params,
};
let new_qs = serde_urlencoded::to_string(merged)
.context("could not serialize redirect URI query params")?;
redirect_uri.set_query(Some(&new_qs));
Ok(Redirect::to(redirect_uri.as_str()).into_response())
}
ResponseMode::Fragment => {
let existing: Option<HashMap<&str, &str>> = redirect_uri
.fragment()
.map(serde_urlencoded::from_str)
.transpose()
.map_err(|_e| RouteError::InvalidRedirectUri)?;
let merged = AllParams {
existing,
state,
params,
};
let new_qs = serde_urlencoded::to_string(merged)
.context("could not serialize redirect URI fragment params")?;
redirect_uri.set_fragment(Some(&new_qs));
Ok(Redirect::to(redirect_uri.as_str()).into_response())
}
ResponseMode::FormPost => {
let merged = ParamsWithState { state, params };
let ctx = FormPostContext::new(redirect_uri, merged);
let rendered = templates
.render_form_post(&ctx)
.await
.context("failed to render form_post.html")?;
Ok(Html(rendered).into_response())
}
}
}
#[derive(Deserialize)] #[derive(Deserialize)]
pub(crate) struct Params { pub(crate) struct Params {
#[serde(flatten)] #[serde(flatten)]
@ -217,7 +164,7 @@ fn resolve_response_mode(
if response_type.has_token() || response_type.has_id_token() { if response_type.has_token() || response_type.has_id_token() {
match suggested_response_mode { match suggested_response_mode {
None => Ok(M::Fragment), None => Ok(M::Fragment),
Some(M::Query) => Err(anyhow::anyhow!("invalid response mode")), Some(M::Query) => Err(anyhow!("invalid response mode")),
Some(mode) => Ok(mode), Some(mode) => Ok(mode),
} }
} else { } else {
@ -248,59 +195,44 @@ pub(crate) async fn get(
let client = lookup_client_by_client_id(&mut txn, &params.auth.client_id).await?; let client = lookup_client_by_client_id(&mut txn, &params.auth.client_id).await?;
let redirect_uri = client let redirect_uri = client
.resolve_redirect_uri(&params.auth.redirect_uri) .resolve_redirect_uri(&params.auth.redirect_uri)?
.map_err(|_e| RouteError::InvalidRedirectUri)?
.clone(); .clone();
let response_type = params.auth.response_type; let response_type = params.auth.response_type;
let response_mode = resolve_response_mode(response_type, params.auth.response_mode)?; let response_mode = resolve_response_mode(response_type, params.auth.response_mode)?;
let callback_destination = CallbackDestination::try_new(
response_mode,
redirect_uri.clone(),
params.auth.state.clone(),
)?;
// One day, we will have try blocks // One day, we will have try blocks
let res: Result<Response, RouteError> = (async move { let res: Result<Response, RouteError> = (async move {
// Check if the request/request_uri/registration params are used. If so, reply // Check if the request/request_uri/registration params are used. If so, reply
// with the right error since we don't support them. // with the right error since we don't support them.
if params.auth.request.is_some() { if params.auth.request.is_some() {
return back_to_client( return Ok(callback_destination
&redirect_uri, .go(&templates, REQUEST_NOT_SUPPORTED)
response_mode, .await?);
params.auth.state,
REQUEST_NOT_SUPPORTED,
&templates,
)
.await;
} }
if params.auth.request_uri.is_some() { if params.auth.request_uri.is_some() {
return back_to_client( return Ok(callback_destination
&redirect_uri, .go(&templates, REQUEST_URI_NOT_SUPPORTED)
response_mode, .await?);
params.auth.state,
REQUEST_URI_NOT_SUPPORTED,
&templates,
)
.await;
} }
if params.auth.registration.is_some() { if params.auth.registration.is_some() {
return back_to_client( return Ok(callback_destination
&redirect_uri, .go(&templates, REGISTRATION_NOT_SUPPORTED)
response_mode, .await?);
params.auth.state,
REGISTRATION_NOT_SUPPORTED,
&templates,
)
.await;
} }
// Check if it is allowed to use this grant type // Check if it is allowed to use this grant type
if !client.grant_types.contains(&GrantType::AuthorizationCode) { if !client.grant_types.contains(&GrantType::AuthorizationCode) {
return back_to_client( return Ok(callback_destination
&redirect_uri, .go(&templates, UNAUTHORIZED_CLIENT)
response_mode, .await?);
params.auth.state,
UNAUTHORIZED_CLIENT,
&templates,
)
.await;
} }
let code: Option<AuthorizationCode> = if response_type.has_code() { let code: Option<AuthorizationCode> = if response_type.has_code() {
@ -321,14 +253,7 @@ pub(crate) async fn get(
// If the request had PKCE params but no code asked, it should get back with an // If the request had PKCE params but no code asked, it should get back with an
// error // error
if params.pkce.is_some() { if params.pkce.is_some() {
return back_to_client( return Ok(callback_destination.go(&templates, INVALID_REQUEST).await?);
&redirect_uri,
response_mode,
params.auth.state,
INVALID_REQUEST,
&templates,
)
.await;
} }
None None
@ -373,14 +298,7 @@ pub(crate) async fn get(
(None, Some(Prompt::None)) => { (None, Some(Prompt::None)) => {
// If there is no session and prompt=none was asked, go back to the client // If there is no session and prompt=none was asked, go back to the client
txn.commit().await?; txn.commit().await?;
Ok(back_to_client( Ok(callback_destination.go(&templates, LOGIN_REQUIRED).await?)
&redirect_uri,
response_mode,
params.auth.state,
LOGIN_REQUIRED,
&templates,
)
.await?)
} }
(Some(_), Some(Prompt::Consent)) => { (Some(_), Some(Prompt::Consent)) => {
// We're already logged in but consent was asked // We're already logged in but consent was asked
@ -516,8 +434,10 @@ async fn step(
// request using a signed cookie // request using a signed cookie
let grant = next.fetch_authorization_grant(&mut txn).await?; let grant = next.fetch_authorization_grant(&mut txn).await?;
let callback_destination = CallbackDestination::try_from(&grant)?;
if !matches!(grant.stage, AuthorizationGrantStage::Pending) { if !matches!(grant.stage, AuthorizationGrantStage::Pending) {
return Err(anyhow::anyhow!("authorization grant not pending").into()); return Err(anyhow!("authorization grant not pending").into());
} }
let current_consent = let current_consent =
@ -568,19 +488,14 @@ async fn step(
// Did they request an ID token? // Did they request an ID token?
if grant.response_type_id_token { if grant.response_type_id_token {
todo!("id tokens are not implemented yet"); return Err(RouteError::Anyhow(anyhow!(
"id tokens are not implemented yet"
)));
} }
let params = serde_json::to_value(&params).unwrap(); let params = serde_json::to_value(&params).unwrap();
back_to_client( callback_destination.go(templates, params).await?
&grant.redirect_uri,
grant.response_mode,
grant.state,
params,
templates,
)
.await?
} }
(true, Some(Authentication { created_at, .. })) if created_at > &grant.max_auth_time() => { (true, Some(Authentication { created_at, .. })) if created_at > &grant.max_auth_time() => {
let next: ConsentRequest = next.into(); let next: ConsentRequest = next.into();

View File

@ -0,0 +1,166 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#![allow(clippy::module_name_repetitions)]
use std::collections::HashMap;
use axum::response::{Html, IntoResponse, Redirect, Response};
use mas_data_model::{AuthorizationGrant, StorageBackend};
use mas_templates::{FormPostContext, Templates};
use oauth2_types::requests::ResponseMode;
use serde::Serialize;
use thiserror::Error;
use url::Url;
enum CallbackDestinationMode {
Query {
existing_params: HashMap<String, String>,
},
Fragment,
FormPost,
}
pub struct CallbackDestination {
mode: CallbackDestinationMode,
safe_redirect_uri: Url,
state: Option<String>,
}
#[derive(Debug, Error)]
pub enum InvalidRedirectUriError {
#[error("Redirect URI can't have a fragment")]
FragmentNotAllowed,
#[error("Existing query parameters are not valid")]
InvalidQueryParams(#[from] serde_urlencoded::de::Error),
}
#[derive(Debug, Error)]
pub enum CallbackDestinationError {
#[error("Failed to render the form_post template")]
FormPostRender(#[from] mas_templates::TemplateError),
#[error("Failed to serialize parameters query string")]
ParamsSerialization(#[from] serde_urlencoded::ser::Error),
}
impl<S: StorageBackend> TryFrom<&AuthorizationGrant<S>> for CallbackDestination {
type Error = InvalidRedirectUriError;
fn try_from(value: &AuthorizationGrant<S>) -> Result<Self, Self::Error> {
Self::try_new(
value.response_mode,
value.redirect_uri.clone(),
value.state.clone(),
)
}
}
impl CallbackDestination {
pub fn try_new(
mode: ResponseMode,
mut redirect_uri: Url,
state: Option<String>,
) -> Result<Self, InvalidRedirectUriError> {
if redirect_uri.fragment().is_some() {
return Err(InvalidRedirectUriError::FragmentNotAllowed);
}
let mode = match mode {
ResponseMode::Query => {
let existing_params = redirect_uri
.query()
.map(serde_urlencoded::from_str)
.transpose()?
.unwrap_or_default();
// Remove the query from the URL
redirect_uri.set_query(None);
CallbackDestinationMode::Query { existing_params }
}
ResponseMode::Fragment => CallbackDestinationMode::Fragment,
ResponseMode::FormPost => CallbackDestinationMode::FormPost,
};
Ok(Self {
mode,
safe_redirect_uri: redirect_uri,
state,
})
}
pub async fn go<T: Serialize>(
self,
templates: &Templates,
params: T,
) -> Result<Response, CallbackDestinationError> {
#[derive(Serialize)]
struct AllParams<'s, T> {
#[serde(flatten, skip_serializing_if = "Option::is_none")]
existing: Option<&'s HashMap<String, String>>,
#[serde(skip_serializing_if = "Option::is_none")]
state: Option<String>,
#[serde(flatten)]
params: T,
}
let mut redirect_uri = self.safe_redirect_uri;
let state = self.state;
match self.mode {
CallbackDestinationMode::Query { existing_params } => {
let merged = AllParams {
existing: Some(&existing_params),
state,
params,
};
let new_qs = serde_urlencoded::to_string(merged)?;
redirect_uri.set_query(Some(&new_qs));
Ok(Redirect::to(redirect_uri.as_str()).into_response())
}
CallbackDestinationMode::Fragment => {
let merged = AllParams {
existing: None,
state,
params,
};
let new_qs = serde_urlencoded::to_string(merged)?;
redirect_uri.set_fragment(Some(&new_qs));
Ok(Redirect::to(redirect_uri.as_str()).into_response())
}
CallbackDestinationMode::FormPost => {
let merged = AllParams {
existing: None,
state,
params,
};
let ctx = FormPostContext::new(redirect_uri, merged);
let rendered = templates.render_form_post(&ctx).await?;
Ok(Html(rendered).into_response())
}
}
}
}