You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
Handle auth errors on the userinfo endpoint
This commit is contained in:
@ -22,6 +22,10 @@ pub struct WrappedError(anyhow::Error);
|
||||
|
||||
impl warp::reject::Reject for WrappedError {}
|
||||
|
||||
pub fn wrapped_error<T: Into<anyhow::Error>>(e: T) -> impl Reject {
|
||||
WrappedError(e.into())
|
||||
}
|
||||
|
||||
pub trait WrapError<T> {
|
||||
fn wrap_error(self) -> Result<T, Rejection>;
|
||||
}
|
||||
|
@ -14,16 +14,60 @@
|
||||
|
||||
use chrono::Utc;
|
||||
use headers::{authorization::Bearer, Authorization};
|
||||
use hyper::StatusCode;
|
||||
use sqlx::{pool::PoolConnection, PgPool, Postgres};
|
||||
use warp::{Filter, Rejection};
|
||||
|
||||
use super::{database::with_connection, headers::with_typed_header};
|
||||
use crate::{
|
||||
errors::WrapError,
|
||||
storage::oauth2::access_token::{lookup_access_token, OAuth2AccessTokenLookup},
|
||||
tokens,
|
||||
use thiserror::Error;
|
||||
use warp::{
|
||||
reject::{MissingHeader, Reject},
|
||||
reply::{with_header, with_status},
|
||||
Filter, Rejection, Reply,
|
||||
};
|
||||
|
||||
use super::{
|
||||
database::with_connection,
|
||||
headers::{with_typed_header, InvalidTypedHeader},
|
||||
};
|
||||
use crate::{
|
||||
errors::wrapped_error,
|
||||
storage::oauth2::access_token::{
|
||||
lookup_access_token, AccessTokenLookupError, OAuth2AccessTokenLookup,
|
||||
},
|
||||
tokens::{self, TokenFormatError, TokenType},
|
||||
};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
enum AuthenticationError {
|
||||
#[error("invalid token format")]
|
||||
TokenFormat(#[from] TokenFormatError),
|
||||
|
||||
#[error("invalid token type {0:?}, expected an access token")]
|
||||
WrongTokenType(TokenType),
|
||||
|
||||
#[error("unknown token")]
|
||||
TokenNotFound(#[source] AccessTokenLookupError),
|
||||
|
||||
#[error("token is not active")]
|
||||
TokenInactive,
|
||||
|
||||
#[error("token expired")]
|
||||
TokenExpired,
|
||||
|
||||
#[error("missing authorization header")]
|
||||
MissingAuthorizationHeader,
|
||||
|
||||
#[error("invalid authorization header")]
|
||||
InvalidAuthorizationHeader,
|
||||
}
|
||||
|
||||
impl Reject for AuthenticationError {}
|
||||
|
||||
/// Authenticate a request using an access token as a bearer authorization
|
||||
///
|
||||
/// # Rejections
|
||||
///
|
||||
/// This can reject with either a [`AuthenticationError`] or with a generic
|
||||
/// wrapped sqlx error.
|
||||
#[must_use]
|
||||
pub fn with_authentication(
|
||||
pool: &PgPool,
|
||||
) -> impl Filter<Extract = (OAuth2AccessTokenLookup,), Error = Rejection> + Clone + Send + Sync + 'static
|
||||
@ -31,6 +75,8 @@ pub fn with_authentication(
|
||||
with_connection(pool)
|
||||
.and(with_typed_header())
|
||||
.and_then(authenticate)
|
||||
.recover(recover)
|
||||
.unify()
|
||||
}
|
||||
|
||||
async fn authenticate(
|
||||
@ -38,18 +84,58 @@ async fn authenticate(
|
||||
auth: Authorization<Bearer>,
|
||||
) -> Result<OAuth2AccessTokenLookup, Rejection> {
|
||||
let token = auth.0.token();
|
||||
let token_type = tokens::check(token).wrap_error()?;
|
||||
let token_type = tokens::check(token).map_err(AuthenticationError::TokenFormat)?;
|
||||
|
||||
if token_type != tokens::TokenType::AccessToken {
|
||||
return Err(anyhow::anyhow!("wrong token type")).wrap_error();
|
||||
return Err(AuthenticationError::WrongTokenType(token_type).into());
|
||||
}
|
||||
|
||||
let token = lookup_access_token(&mut conn, token).await.wrap_error()?;
|
||||
let exp = token.exp();
|
||||
let token = lookup_access_token(&mut conn, token).await.map_err(|e| {
|
||||
if e.not_found() {
|
||||
// This error happens if the token was not found and should be recovered
|
||||
warp::reject::custom(AuthenticationError::TokenNotFound(e))
|
||||
} else {
|
||||
// This is a generic database error that we want to propagate
|
||||
warp::reject::custom(wrapped_error(e))
|
||||
}
|
||||
})?;
|
||||
|
||||
// Check it is active and did not expire
|
||||
if !token.active || exp < Utc::now() {
|
||||
return Err(anyhow::anyhow!("token expired")).wrap_error();
|
||||
if !token.active {
|
||||
return Err(AuthenticationError::TokenInactive.into());
|
||||
}
|
||||
|
||||
if token.exp() < Utc::now() {
|
||||
return Err(AuthenticationError::TokenExpired.into());
|
||||
}
|
||||
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
/// Transform the rejections from the [`with_typed_header`] filter
|
||||
async fn recover(rejection: Rejection) -> Result<OAuth2AccessTokenLookup, Rejection> {
|
||||
if rejection.find::<MissingHeader>().is_some() {
|
||||
return Err(warp::reject::custom(
|
||||
AuthenticationError::MissingAuthorizationHeader,
|
||||
));
|
||||
}
|
||||
|
||||
if rejection.find::<InvalidTypedHeader>().is_some() {
|
||||
return Err(warp::reject::custom(
|
||||
AuthenticationError::InvalidAuthorizationHeader,
|
||||
));
|
||||
}
|
||||
|
||||
Err(rejection)
|
||||
}
|
||||
|
||||
pub async fn recover_unauthorized(rejection: Rejection) -> Result<impl Reply, Rejection> {
|
||||
if rejection.find::<AuthenticationError>().is_some() {
|
||||
// TODO: have the issuer/realm here
|
||||
let reply = "invalid token";
|
||||
let reply = with_status(reply, StatusCode::UNAUTHORIZED);
|
||||
let reply = with_header(reply, "WWW-Authenticate", r#"Bearer error="invalid_token""#);
|
||||
return Ok(reply);
|
||||
}
|
||||
|
||||
Err(rejection)
|
||||
}
|
||||
|
@ -13,10 +13,10 @@
|
||||
// limitations under the License.
|
||||
//
|
||||
use headers::{Header, HeaderMapExt, HeaderValue};
|
||||
use warp::{Filter, Rejection, Reply};
|
||||
|
||||
use crate::errors::WrapError;
|
||||
use thiserror::Error;
|
||||
use warp::{reject::Reject, Filter, Rejection, Reply};
|
||||
|
||||
/// Add a typed header to a reply
|
||||
pub fn typed_header<R, H>(header: H, reply: R) -> WithTypedHeader<R, H> {
|
||||
WithTypedHeader { reply, header }
|
||||
}
|
||||
@ -38,6 +38,18 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("could not decode header {1}")]
|
||||
pub struct InvalidTypedHeader(#[source] headers::Error, &'static str);
|
||||
|
||||
impl Reject for InvalidTypedHeader {}
|
||||
|
||||
/// Extract a typed header from the request
|
||||
///
|
||||
/// # Rejections
|
||||
///
|
||||
/// This can reject with either a [`warp::reject::MissingHeader`] or a
|
||||
/// [`InvalidTypedHeader`].
|
||||
pub fn with_typed_header<T: Header + Send + 'static>(
|
||||
) -> impl Filter<Extract = (T,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
warp::header::value(T::name().as_str()).and_then(decode_typed_header)
|
||||
@ -45,6 +57,6 @@ pub fn with_typed_header<T: Header + Send + 'static>(
|
||||
|
||||
async fn decode_typed_header<T: Header>(header: HeaderValue) -> Result<T, Rejection> {
|
||||
let mut it = std::iter::once(&header);
|
||||
let decoded = T::decode(&mut it).wrap_error()?;
|
||||
let decoded = T::decode(&mut it).map_err(|e| InvalidTypedHeader(e, T::name().as_str()))?;
|
||||
Ok(decoded)
|
||||
}
|
||||
|
@ -17,7 +17,8 @@ use sqlx::PgPool;
|
||||
use warp::{Filter, Rejection, Reply};
|
||||
|
||||
use crate::{
|
||||
config::OAuth2Config, filters::authenticate::with_authentication,
|
||||
config::OAuth2Config,
|
||||
filters::authenticate::{recover_unauthorized, with_authentication},
|
||||
storage::oauth2::access_token::OAuth2AccessTokenLookup,
|
||||
};
|
||||
|
||||
@ -34,6 +35,7 @@ pub(super) fn filter(
|
||||
.and(warp::get().or(warp::post()).unify())
|
||||
.and(with_authentication(pool))
|
||||
.and_then(userinfo)
|
||||
.recover(recover_unauthorized)
|
||||
}
|
||||
|
||||
async fn userinfo(token: OAuth2AccessTokenLookup) -> Result<impl Reply, Rejection> {
|
||||
|
@ -18,6 +18,7 @@ use anyhow::Context;
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use serde::Serialize;
|
||||
use sqlx::{Executor, FromRow, Postgres};
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(FromRow, Serialize)]
|
||||
pub struct OAuth2AccessToken {
|
||||
@ -73,11 +74,22 @@ impl OAuth2AccessTokenLookup {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("failed to lookup access token")]
|
||||
pub struct AccessTokenLookupError(#[from] sqlx::Error);
|
||||
|
||||
impl AccessTokenLookupError {
|
||||
#[must_use]
|
||||
pub fn not_found(&self) -> bool {
|
||||
matches!(self.0, sqlx::Error::RowNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn lookup_access_token(
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
token: &str,
|
||||
) -> anyhow::Result<OAuth2AccessTokenLookup> {
|
||||
sqlx::query_as!(
|
||||
) -> Result<OAuth2AccessTokenLookup, AccessTokenLookupError> {
|
||||
let res = sqlx::query_as!(
|
||||
OAuth2AccessTokenLookup,
|
||||
r#"
|
||||
SELECT
|
||||
@ -99,8 +111,9 @@ pub async fn lookup_access_token(
|
||||
token,
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.context("could not introspect oauth2 access token")
|
||||
.await?;
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
pub async fn revoke_access_token(
|
||||
|
Reference in New Issue
Block a user