1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Handle auth errors on the userinfo endpoint

This commit is contained in:
Quentin Gliech
2021-09-17 16:20:10 +02:00
parent 463184bbb1
commit 59df55c2f9
5 changed files with 140 additions and 23 deletions

View File

@ -22,6 +22,10 @@ pub struct WrappedError(anyhow::Error);
impl warp::reject::Reject for WrappedError {}
pub fn wrapped_error<T: Into<anyhow::Error>>(e: T) -> impl Reject {
WrappedError(e.into())
}
pub trait WrapError<T> {
fn wrap_error(self) -> Result<T, Rejection>;
}

View File

@ -14,16 +14,60 @@
use chrono::Utc;
use headers::{authorization::Bearer, Authorization};
use hyper::StatusCode;
use sqlx::{pool::PoolConnection, PgPool, Postgres};
use warp::{Filter, Rejection};
use super::{database::with_connection, headers::with_typed_header};
use crate::{
errors::WrapError,
storage::oauth2::access_token::{lookup_access_token, OAuth2AccessTokenLookup},
tokens,
use thiserror::Error;
use warp::{
reject::{MissingHeader, Reject},
reply::{with_header, with_status},
Filter, Rejection, Reply,
};
use super::{
database::with_connection,
headers::{with_typed_header, InvalidTypedHeader},
};
use crate::{
errors::wrapped_error,
storage::oauth2::access_token::{
lookup_access_token, AccessTokenLookupError, OAuth2AccessTokenLookup,
},
tokens::{self, TokenFormatError, TokenType},
};
#[derive(Debug, Error)]
enum AuthenticationError {
#[error("invalid token format")]
TokenFormat(#[from] TokenFormatError),
#[error("invalid token type {0:?}, expected an access token")]
WrongTokenType(TokenType),
#[error("unknown token")]
TokenNotFound(#[source] AccessTokenLookupError),
#[error("token is not active")]
TokenInactive,
#[error("token expired")]
TokenExpired,
#[error("missing authorization header")]
MissingAuthorizationHeader,
#[error("invalid authorization header")]
InvalidAuthorizationHeader,
}
impl Reject for AuthenticationError {}
/// Authenticate a request using an access token as a bearer authorization
///
/// # Rejections
///
/// This can reject with either a [`AuthenticationError`] or with a generic
/// wrapped sqlx error.
#[must_use]
pub fn with_authentication(
pool: &PgPool,
) -> impl Filter<Extract = (OAuth2AccessTokenLookup,), Error = Rejection> + Clone + Send + Sync + 'static
@ -31,6 +75,8 @@ pub fn with_authentication(
with_connection(pool)
.and(with_typed_header())
.and_then(authenticate)
.recover(recover)
.unify()
}
async fn authenticate(
@ -38,18 +84,58 @@ async fn authenticate(
auth: Authorization<Bearer>,
) -> Result<OAuth2AccessTokenLookup, Rejection> {
let token = auth.0.token();
let token_type = tokens::check(token).wrap_error()?;
let token_type = tokens::check(token).map_err(AuthenticationError::TokenFormat)?;
if token_type != tokens::TokenType::AccessToken {
return Err(anyhow::anyhow!("wrong token type")).wrap_error();
return Err(AuthenticationError::WrongTokenType(token_type).into());
}
let token = lookup_access_token(&mut conn, token).await.wrap_error()?;
let exp = token.exp();
let token = lookup_access_token(&mut conn, token).await.map_err(|e| {
if e.not_found() {
// This error happens if the token was not found and should be recovered
warp::reject::custom(AuthenticationError::TokenNotFound(e))
} else {
// This is a generic database error that we want to propagate
warp::reject::custom(wrapped_error(e))
}
})?;
// Check it is active and did not expire
if !token.active || exp < Utc::now() {
return Err(anyhow::anyhow!("token expired")).wrap_error();
if !token.active {
return Err(AuthenticationError::TokenInactive.into());
}
if token.exp() < Utc::now() {
return Err(AuthenticationError::TokenExpired.into());
}
Ok(token)
}
/// Transform the rejections from the [`with_typed_header`] filter
async fn recover(rejection: Rejection) -> Result<OAuth2AccessTokenLookup, Rejection> {
if rejection.find::<MissingHeader>().is_some() {
return Err(warp::reject::custom(
AuthenticationError::MissingAuthorizationHeader,
));
}
if rejection.find::<InvalidTypedHeader>().is_some() {
return Err(warp::reject::custom(
AuthenticationError::InvalidAuthorizationHeader,
));
}
Err(rejection)
}
pub async fn recover_unauthorized(rejection: Rejection) -> Result<impl Reply, Rejection> {
if rejection.find::<AuthenticationError>().is_some() {
// TODO: have the issuer/realm here
let reply = "invalid token";
let reply = with_status(reply, StatusCode::UNAUTHORIZED);
let reply = with_header(reply, "WWW-Authenticate", r#"Bearer error="invalid_token""#);
return Ok(reply);
}
Err(rejection)
}

View File

@ -13,10 +13,10 @@
// limitations under the License.
//
use headers::{Header, HeaderMapExt, HeaderValue};
use warp::{Filter, Rejection, Reply};
use crate::errors::WrapError;
use thiserror::Error;
use warp::{reject::Reject, Filter, Rejection, Reply};
/// Add a typed header to a reply
pub fn typed_header<R, H>(header: H, reply: R) -> WithTypedHeader<R, H> {
WithTypedHeader { reply, header }
}
@ -38,6 +38,18 @@ where
}
}
#[derive(Debug, Error)]
#[error("could not decode header {1}")]
pub struct InvalidTypedHeader(#[source] headers::Error, &'static str);
impl Reject for InvalidTypedHeader {}
/// Extract a typed header from the request
///
/// # Rejections
///
/// This can reject with either a [`warp::reject::MissingHeader`] or a
/// [`InvalidTypedHeader`].
pub fn with_typed_header<T: Header + Send + 'static>(
) -> impl Filter<Extract = (T,), Error = Rejection> + Clone + Send + Sync + 'static {
warp::header::value(T::name().as_str()).and_then(decode_typed_header)
@ -45,6 +57,6 @@ pub fn with_typed_header<T: Header + Send + 'static>(
async fn decode_typed_header<T: Header>(header: HeaderValue) -> Result<T, Rejection> {
let mut it = std::iter::once(&header);
let decoded = T::decode(&mut it).wrap_error()?;
let decoded = T::decode(&mut it).map_err(|e| InvalidTypedHeader(e, T::name().as_str()))?;
Ok(decoded)
}

View File

@ -17,7 +17,8 @@ use sqlx::PgPool;
use warp::{Filter, Rejection, Reply};
use crate::{
config::OAuth2Config, filters::authenticate::with_authentication,
config::OAuth2Config,
filters::authenticate::{recover_unauthorized, with_authentication},
storage::oauth2::access_token::OAuth2AccessTokenLookup,
};
@ -34,6 +35,7 @@ pub(super) fn filter(
.and(warp::get().or(warp::post()).unify())
.and(with_authentication(pool))
.and_then(userinfo)
.recover(recover_unauthorized)
}
async fn userinfo(token: OAuth2AccessTokenLookup) -> Result<impl Reply, Rejection> {

View File

@ -18,6 +18,7 @@ use anyhow::Context;
use chrono::{DateTime, Duration, Utc};
use serde::Serialize;
use sqlx::{Executor, FromRow, Postgres};
use thiserror::Error;
#[derive(FromRow, Serialize)]
pub struct OAuth2AccessToken {
@ -73,11 +74,22 @@ impl OAuth2AccessTokenLookup {
}
}
#[derive(Debug, Error)]
#[error("failed to lookup access token")]
pub struct AccessTokenLookupError(#[from] sqlx::Error);
impl AccessTokenLookupError {
#[must_use]
pub fn not_found(&self) -> bool {
matches!(self.0, sqlx::Error::RowNotFound)
}
}
pub async fn lookup_access_token(
executor: impl Executor<'_, Database = Postgres>,
token: &str,
) -> anyhow::Result<OAuth2AccessTokenLookup> {
sqlx::query_as!(
) -> Result<OAuth2AccessTokenLookup, AccessTokenLookupError> {
let res = sqlx::query_as!(
OAuth2AccessTokenLookup,
r#"
SELECT
@ -99,8 +111,9 @@ pub async fn lookup_access_token(
token,
)
.fetch_one(executor)
.await
.context("could not introspect oauth2 access token")
.await?;
Ok(res)
}
pub async fn revoke_access_token(