1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-06 06:02:40 +03:00

Handle auth errors on the userinfo endpoint

This commit is contained in:
Quentin Gliech
2021-09-17 16:20:10 +02:00
parent 463184bbb1
commit 59df55c2f9
5 changed files with 140 additions and 23 deletions

View File

@@ -22,6 +22,10 @@ pub struct WrappedError(anyhow::Error);
impl warp::reject::Reject for WrappedError {} impl warp::reject::Reject for WrappedError {}
pub fn wrapped_error<T: Into<anyhow::Error>>(e: T) -> impl Reject {
WrappedError(e.into())
}
pub trait WrapError<T> { pub trait WrapError<T> {
fn wrap_error(self) -> Result<T, Rejection>; fn wrap_error(self) -> Result<T, Rejection>;
} }

View File

@@ -14,16 +14,60 @@
use chrono::Utc; use chrono::Utc;
use headers::{authorization::Bearer, Authorization}; use headers::{authorization::Bearer, Authorization};
use hyper::StatusCode;
use sqlx::{pool::PoolConnection, PgPool, Postgres}; use sqlx::{pool::PoolConnection, PgPool, Postgres};
use warp::{Filter, Rejection}; use thiserror::Error;
use warp::{
use super::{database::with_connection, headers::with_typed_header}; reject::{MissingHeader, Reject},
use crate::{ reply::{with_header, with_status},
errors::WrapError, Filter, Rejection, Reply,
storage::oauth2::access_token::{lookup_access_token, OAuth2AccessTokenLookup},
tokens,
}; };
use super::{
database::with_connection,
headers::{with_typed_header, InvalidTypedHeader},
};
use crate::{
errors::wrapped_error,
storage::oauth2::access_token::{
lookup_access_token, AccessTokenLookupError, OAuth2AccessTokenLookup,
},
tokens::{self, TokenFormatError, TokenType},
};
#[derive(Debug, Error)]
enum AuthenticationError {
#[error("invalid token format")]
TokenFormat(#[from] TokenFormatError),
#[error("invalid token type {0:?}, expected an access token")]
WrongTokenType(TokenType),
#[error("unknown token")]
TokenNotFound(#[source] AccessTokenLookupError),
#[error("token is not active")]
TokenInactive,
#[error("token expired")]
TokenExpired,
#[error("missing authorization header")]
MissingAuthorizationHeader,
#[error("invalid authorization header")]
InvalidAuthorizationHeader,
}
impl Reject for AuthenticationError {}
/// Authenticate a request using an access token as a bearer authorization
///
/// # Rejections
///
/// This can reject with either a [`AuthenticationError`] or with a generic
/// wrapped sqlx error.
#[must_use]
pub fn with_authentication( pub fn with_authentication(
pool: &PgPool, pool: &PgPool,
) -> impl Filter<Extract = (OAuth2AccessTokenLookup,), Error = Rejection> + Clone + Send + Sync + 'static ) -> impl Filter<Extract = (OAuth2AccessTokenLookup,), Error = Rejection> + Clone + Send + Sync + 'static
@@ -31,6 +75,8 @@ pub fn with_authentication(
with_connection(pool) with_connection(pool)
.and(with_typed_header()) .and(with_typed_header())
.and_then(authenticate) .and_then(authenticate)
.recover(recover)
.unify()
} }
async fn authenticate( async fn authenticate(
@@ -38,18 +84,58 @@ async fn authenticate(
auth: Authorization<Bearer>, auth: Authorization<Bearer>,
) -> Result<OAuth2AccessTokenLookup, Rejection> { ) -> Result<OAuth2AccessTokenLookup, Rejection> {
let token = auth.0.token(); let token = auth.0.token();
let token_type = tokens::check(token).wrap_error()?; let token_type = tokens::check(token).map_err(AuthenticationError::TokenFormat)?;
if token_type != tokens::TokenType::AccessToken { if token_type != tokens::TokenType::AccessToken {
return Err(anyhow::anyhow!("wrong token type")).wrap_error(); return Err(AuthenticationError::WrongTokenType(token_type).into());
} }
let token = lookup_access_token(&mut conn, token).await.wrap_error()?; let token = lookup_access_token(&mut conn, token).await.map_err(|e| {
let exp = token.exp(); if e.not_found() {
// This error happens if the token was not found and should be recovered
warp::reject::custom(AuthenticationError::TokenNotFound(e))
} else {
// This is a generic database error that we want to propagate
warp::reject::custom(wrapped_error(e))
}
})?;
// Check it is active and did not expire if !token.active {
if !token.active || exp < Utc::now() { return Err(AuthenticationError::TokenInactive.into());
return Err(anyhow::anyhow!("token expired")).wrap_error(); }
if token.exp() < Utc::now() {
return Err(AuthenticationError::TokenExpired.into());
} }
Ok(token) Ok(token)
} }
/// Transform the rejections from the [`with_typed_header`] filter
async fn recover(rejection: Rejection) -> Result<OAuth2AccessTokenLookup, Rejection> {
if rejection.find::<MissingHeader>().is_some() {
return Err(warp::reject::custom(
AuthenticationError::MissingAuthorizationHeader,
));
}
if rejection.find::<InvalidTypedHeader>().is_some() {
return Err(warp::reject::custom(
AuthenticationError::InvalidAuthorizationHeader,
));
}
Err(rejection)
}
pub async fn recover_unauthorized(rejection: Rejection) -> Result<impl Reply, Rejection> {
if rejection.find::<AuthenticationError>().is_some() {
// TODO: have the issuer/realm here
let reply = "invalid token";
let reply = with_status(reply, StatusCode::UNAUTHORIZED);
let reply = with_header(reply, "WWW-Authenticate", r#"Bearer error="invalid_token""#);
return Ok(reply);
}
Err(rejection)
}

View File

@@ -13,10 +13,10 @@
// limitations under the License. // limitations under the License.
// //
use headers::{Header, HeaderMapExt, HeaderValue}; use headers::{Header, HeaderMapExt, HeaderValue};
use warp::{Filter, Rejection, Reply}; use thiserror::Error;
use warp::{reject::Reject, Filter, Rejection, Reply};
use crate::errors::WrapError;
/// Add a typed header to a reply
pub fn typed_header<R, H>(header: H, reply: R) -> WithTypedHeader<R, H> { pub fn typed_header<R, H>(header: H, reply: R) -> WithTypedHeader<R, H> {
WithTypedHeader { reply, header } WithTypedHeader { reply, header }
} }
@@ -38,6 +38,18 @@ where
} }
} }
#[derive(Debug, Error)]
#[error("could not decode header {1}")]
pub struct InvalidTypedHeader(#[source] headers::Error, &'static str);
impl Reject for InvalidTypedHeader {}
/// Extract a typed header from the request
///
/// # Rejections
///
/// This can reject with either a [`warp::reject::MissingHeader`] or a
/// [`InvalidTypedHeader`].
pub fn with_typed_header<T: Header + Send + 'static>( pub fn with_typed_header<T: Header + Send + 'static>(
) -> impl Filter<Extract = (T,), Error = Rejection> + Clone + Send + Sync + 'static { ) -> impl Filter<Extract = (T,), Error = Rejection> + Clone + Send + Sync + 'static {
warp::header::value(T::name().as_str()).and_then(decode_typed_header) warp::header::value(T::name().as_str()).and_then(decode_typed_header)
@@ -45,6 +57,6 @@ pub fn with_typed_header<T: Header + Send + 'static>(
async fn decode_typed_header<T: Header>(header: HeaderValue) -> Result<T, Rejection> { async fn decode_typed_header<T: Header>(header: HeaderValue) -> Result<T, Rejection> {
let mut it = std::iter::once(&header); let mut it = std::iter::once(&header);
let decoded = T::decode(&mut it).wrap_error()?; let decoded = T::decode(&mut it).map_err(|e| InvalidTypedHeader(e, T::name().as_str()))?;
Ok(decoded) Ok(decoded)
} }

View File

@@ -17,7 +17,8 @@ use sqlx::PgPool;
use warp::{Filter, Rejection, Reply}; use warp::{Filter, Rejection, Reply};
use crate::{ use crate::{
config::OAuth2Config, filters::authenticate::with_authentication, config::OAuth2Config,
filters::authenticate::{recover_unauthorized, with_authentication},
storage::oauth2::access_token::OAuth2AccessTokenLookup, storage::oauth2::access_token::OAuth2AccessTokenLookup,
}; };
@@ -34,6 +35,7 @@ pub(super) fn filter(
.and(warp::get().or(warp::post()).unify()) .and(warp::get().or(warp::post()).unify())
.and(with_authentication(pool)) .and(with_authentication(pool))
.and_then(userinfo) .and_then(userinfo)
.recover(recover_unauthorized)
} }
async fn userinfo(token: OAuth2AccessTokenLookup) -> Result<impl Reply, Rejection> { async fn userinfo(token: OAuth2AccessTokenLookup) -> Result<impl Reply, Rejection> {

View File

@@ -18,6 +18,7 @@ use anyhow::Context;
use chrono::{DateTime, Duration, Utc}; use chrono::{DateTime, Duration, Utc};
use serde::Serialize; use serde::Serialize;
use sqlx::{Executor, FromRow, Postgres}; use sqlx::{Executor, FromRow, Postgres};
use thiserror::Error;
#[derive(FromRow, Serialize)] #[derive(FromRow, Serialize)]
pub struct OAuth2AccessToken { pub struct OAuth2AccessToken {
@@ -73,11 +74,22 @@ impl OAuth2AccessTokenLookup {
} }
} }
#[derive(Debug, Error)]
#[error("failed to lookup access token")]
pub struct AccessTokenLookupError(#[from] sqlx::Error);
impl AccessTokenLookupError {
#[must_use]
pub fn not_found(&self) -> bool {
matches!(self.0, sqlx::Error::RowNotFound)
}
}
pub async fn lookup_access_token( pub async fn lookup_access_token(
executor: impl Executor<'_, Database = Postgres>, executor: impl Executor<'_, Database = Postgres>,
token: &str, token: &str,
) -> anyhow::Result<OAuth2AccessTokenLookup> { ) -> Result<OAuth2AccessTokenLookup, AccessTokenLookupError> {
sqlx::query_as!( let res = sqlx::query_as!(
OAuth2AccessTokenLookup, OAuth2AccessTokenLookup,
r#" r#"
SELECT SELECT
@@ -99,8 +111,9 @@ pub async fn lookup_access_token(
token, token,
) )
.fetch_one(executor) .fetch_one(executor)
.await .await?;
.context("could not introspect oauth2 access token")
Ok(res)
} }
pub async fn revoke_access_token( pub async fn revoke_access_token(