You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-08-06 06:02:40 +03:00
Handle auth errors on the userinfo endpoint
This commit is contained in:
@@ -22,6 +22,10 @@ pub struct WrappedError(anyhow::Error);
|
|||||||
|
|
||||||
impl warp::reject::Reject for WrappedError {}
|
impl warp::reject::Reject for WrappedError {}
|
||||||
|
|
||||||
|
pub fn wrapped_error<T: Into<anyhow::Error>>(e: T) -> impl Reject {
|
||||||
|
WrappedError(e.into())
|
||||||
|
}
|
||||||
|
|
||||||
pub trait WrapError<T> {
|
pub trait WrapError<T> {
|
||||||
fn wrap_error(self) -> Result<T, Rejection>;
|
fn wrap_error(self) -> Result<T, Rejection>;
|
||||||
}
|
}
|
||||||
|
@@ -14,16 +14,60 @@
|
|||||||
|
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use headers::{authorization::Bearer, Authorization};
|
use headers::{authorization::Bearer, Authorization};
|
||||||
|
use hyper::StatusCode;
|
||||||
use sqlx::{pool::PoolConnection, PgPool, Postgres};
|
use sqlx::{pool::PoolConnection, PgPool, Postgres};
|
||||||
use warp::{Filter, Rejection};
|
use thiserror::Error;
|
||||||
|
use warp::{
|
||||||
use super::{database::with_connection, headers::with_typed_header};
|
reject::{MissingHeader, Reject},
|
||||||
use crate::{
|
reply::{with_header, with_status},
|
||||||
errors::WrapError,
|
Filter, Rejection, Reply,
|
||||||
storage::oauth2::access_token::{lookup_access_token, OAuth2AccessTokenLookup},
|
|
||||||
tokens,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
database::with_connection,
|
||||||
|
headers::{with_typed_header, InvalidTypedHeader},
|
||||||
|
};
|
||||||
|
use crate::{
|
||||||
|
errors::wrapped_error,
|
||||||
|
storage::oauth2::access_token::{
|
||||||
|
lookup_access_token, AccessTokenLookupError, OAuth2AccessTokenLookup,
|
||||||
|
},
|
||||||
|
tokens::{self, TokenFormatError, TokenType},
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
enum AuthenticationError {
|
||||||
|
#[error("invalid token format")]
|
||||||
|
TokenFormat(#[from] TokenFormatError),
|
||||||
|
|
||||||
|
#[error("invalid token type {0:?}, expected an access token")]
|
||||||
|
WrongTokenType(TokenType),
|
||||||
|
|
||||||
|
#[error("unknown token")]
|
||||||
|
TokenNotFound(#[source] AccessTokenLookupError),
|
||||||
|
|
||||||
|
#[error("token is not active")]
|
||||||
|
TokenInactive,
|
||||||
|
|
||||||
|
#[error("token expired")]
|
||||||
|
TokenExpired,
|
||||||
|
|
||||||
|
#[error("missing authorization header")]
|
||||||
|
MissingAuthorizationHeader,
|
||||||
|
|
||||||
|
#[error("invalid authorization header")]
|
||||||
|
InvalidAuthorizationHeader,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Reject for AuthenticationError {}
|
||||||
|
|
||||||
|
/// Authenticate a request using an access token as a bearer authorization
|
||||||
|
///
|
||||||
|
/// # Rejections
|
||||||
|
///
|
||||||
|
/// This can reject with either a [`AuthenticationError`] or with a generic
|
||||||
|
/// wrapped sqlx error.
|
||||||
|
#[must_use]
|
||||||
pub fn with_authentication(
|
pub fn with_authentication(
|
||||||
pool: &PgPool,
|
pool: &PgPool,
|
||||||
) -> impl Filter<Extract = (OAuth2AccessTokenLookup,), Error = Rejection> + Clone + Send + Sync + 'static
|
) -> impl Filter<Extract = (OAuth2AccessTokenLookup,), Error = Rejection> + Clone + Send + Sync + 'static
|
||||||
@@ -31,6 +75,8 @@ pub fn with_authentication(
|
|||||||
with_connection(pool)
|
with_connection(pool)
|
||||||
.and(with_typed_header())
|
.and(with_typed_header())
|
||||||
.and_then(authenticate)
|
.and_then(authenticate)
|
||||||
|
.recover(recover)
|
||||||
|
.unify()
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn authenticate(
|
async fn authenticate(
|
||||||
@@ -38,18 +84,58 @@ async fn authenticate(
|
|||||||
auth: Authorization<Bearer>,
|
auth: Authorization<Bearer>,
|
||||||
) -> Result<OAuth2AccessTokenLookup, Rejection> {
|
) -> Result<OAuth2AccessTokenLookup, Rejection> {
|
||||||
let token = auth.0.token();
|
let token = auth.0.token();
|
||||||
let token_type = tokens::check(token).wrap_error()?;
|
let token_type = tokens::check(token).map_err(AuthenticationError::TokenFormat)?;
|
||||||
|
|
||||||
if token_type != tokens::TokenType::AccessToken {
|
if token_type != tokens::TokenType::AccessToken {
|
||||||
return Err(anyhow::anyhow!("wrong token type")).wrap_error();
|
return Err(AuthenticationError::WrongTokenType(token_type).into());
|
||||||
}
|
}
|
||||||
|
|
||||||
let token = lookup_access_token(&mut conn, token).await.wrap_error()?;
|
let token = lookup_access_token(&mut conn, token).await.map_err(|e| {
|
||||||
let exp = token.exp();
|
if e.not_found() {
|
||||||
|
// This error happens if the token was not found and should be recovered
|
||||||
|
warp::reject::custom(AuthenticationError::TokenNotFound(e))
|
||||||
|
} else {
|
||||||
|
// This is a generic database error that we want to propagate
|
||||||
|
warp::reject::custom(wrapped_error(e))
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
|
||||||
// Check it is active and did not expire
|
if !token.active {
|
||||||
if !token.active || exp < Utc::now() {
|
return Err(AuthenticationError::TokenInactive.into());
|
||||||
return Err(anyhow::anyhow!("token expired")).wrap_error();
|
}
|
||||||
|
|
||||||
|
if token.exp() < Utc::now() {
|
||||||
|
return Err(AuthenticationError::TokenExpired.into());
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(token)
|
Ok(token)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Transform the rejections from the [`with_typed_header`] filter
|
||||||
|
async fn recover(rejection: Rejection) -> Result<OAuth2AccessTokenLookup, Rejection> {
|
||||||
|
if rejection.find::<MissingHeader>().is_some() {
|
||||||
|
return Err(warp::reject::custom(
|
||||||
|
AuthenticationError::MissingAuthorizationHeader,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
if rejection.find::<InvalidTypedHeader>().is_some() {
|
||||||
|
return Err(warp::reject::custom(
|
||||||
|
AuthenticationError::InvalidAuthorizationHeader,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(rejection)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn recover_unauthorized(rejection: Rejection) -> Result<impl Reply, Rejection> {
|
||||||
|
if rejection.find::<AuthenticationError>().is_some() {
|
||||||
|
// TODO: have the issuer/realm here
|
||||||
|
let reply = "invalid token";
|
||||||
|
let reply = with_status(reply, StatusCode::UNAUTHORIZED);
|
||||||
|
let reply = with_header(reply, "WWW-Authenticate", r#"Bearer error="invalid_token""#);
|
||||||
|
return Ok(reply);
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(rejection)
|
||||||
|
}
|
||||||
|
@@ -13,10 +13,10 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
//
|
//
|
||||||
use headers::{Header, HeaderMapExt, HeaderValue};
|
use headers::{Header, HeaderMapExt, HeaderValue};
|
||||||
use warp::{Filter, Rejection, Reply};
|
use thiserror::Error;
|
||||||
|
use warp::{reject::Reject, Filter, Rejection, Reply};
|
||||||
use crate::errors::WrapError;
|
|
||||||
|
|
||||||
|
/// Add a typed header to a reply
|
||||||
pub fn typed_header<R, H>(header: H, reply: R) -> WithTypedHeader<R, H> {
|
pub fn typed_header<R, H>(header: H, reply: R) -> WithTypedHeader<R, H> {
|
||||||
WithTypedHeader { reply, header }
|
WithTypedHeader { reply, header }
|
||||||
}
|
}
|
||||||
@@ -38,6 +38,18 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
#[error("could not decode header {1}")]
|
||||||
|
pub struct InvalidTypedHeader(#[source] headers::Error, &'static str);
|
||||||
|
|
||||||
|
impl Reject for InvalidTypedHeader {}
|
||||||
|
|
||||||
|
/// Extract a typed header from the request
|
||||||
|
///
|
||||||
|
/// # Rejections
|
||||||
|
///
|
||||||
|
/// This can reject with either a [`warp::reject::MissingHeader`] or a
|
||||||
|
/// [`InvalidTypedHeader`].
|
||||||
pub fn with_typed_header<T: Header + Send + 'static>(
|
pub fn with_typed_header<T: Header + Send + 'static>(
|
||||||
) -> impl Filter<Extract = (T,), Error = Rejection> + Clone + Send + Sync + 'static {
|
) -> impl Filter<Extract = (T,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||||
warp::header::value(T::name().as_str()).and_then(decode_typed_header)
|
warp::header::value(T::name().as_str()).and_then(decode_typed_header)
|
||||||
@@ -45,6 +57,6 @@ pub fn with_typed_header<T: Header + Send + 'static>(
|
|||||||
|
|
||||||
async fn decode_typed_header<T: Header>(header: HeaderValue) -> Result<T, Rejection> {
|
async fn decode_typed_header<T: Header>(header: HeaderValue) -> Result<T, Rejection> {
|
||||||
let mut it = std::iter::once(&header);
|
let mut it = std::iter::once(&header);
|
||||||
let decoded = T::decode(&mut it).wrap_error()?;
|
let decoded = T::decode(&mut it).map_err(|e| InvalidTypedHeader(e, T::name().as_str()))?;
|
||||||
Ok(decoded)
|
Ok(decoded)
|
||||||
}
|
}
|
||||||
|
@@ -17,7 +17,8 @@ use sqlx::PgPool;
|
|||||||
use warp::{Filter, Rejection, Reply};
|
use warp::{Filter, Rejection, Reply};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
config::OAuth2Config, filters::authenticate::with_authentication,
|
config::OAuth2Config,
|
||||||
|
filters::authenticate::{recover_unauthorized, with_authentication},
|
||||||
storage::oauth2::access_token::OAuth2AccessTokenLookup,
|
storage::oauth2::access_token::OAuth2AccessTokenLookup,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -34,6 +35,7 @@ pub(super) fn filter(
|
|||||||
.and(warp::get().or(warp::post()).unify())
|
.and(warp::get().or(warp::post()).unify())
|
||||||
.and(with_authentication(pool))
|
.and(with_authentication(pool))
|
||||||
.and_then(userinfo)
|
.and_then(userinfo)
|
||||||
|
.recover(recover_unauthorized)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn userinfo(token: OAuth2AccessTokenLookup) -> Result<impl Reply, Rejection> {
|
async fn userinfo(token: OAuth2AccessTokenLookup) -> Result<impl Reply, Rejection> {
|
||||||
|
@@ -18,6 +18,7 @@ use anyhow::Context;
|
|||||||
use chrono::{DateTime, Duration, Utc};
|
use chrono::{DateTime, Duration, Utc};
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use sqlx::{Executor, FromRow, Postgres};
|
use sqlx::{Executor, FromRow, Postgres};
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
#[derive(FromRow, Serialize)]
|
#[derive(FromRow, Serialize)]
|
||||||
pub struct OAuth2AccessToken {
|
pub struct OAuth2AccessToken {
|
||||||
@@ -73,11 +74,22 @@ impl OAuth2AccessTokenLookup {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
#[error("failed to lookup access token")]
|
||||||
|
pub struct AccessTokenLookupError(#[from] sqlx::Error);
|
||||||
|
|
||||||
|
impl AccessTokenLookupError {
|
||||||
|
#[must_use]
|
||||||
|
pub fn not_found(&self) -> bool {
|
||||||
|
matches!(self.0, sqlx::Error::RowNotFound)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn lookup_access_token(
|
pub async fn lookup_access_token(
|
||||||
executor: impl Executor<'_, Database = Postgres>,
|
executor: impl Executor<'_, Database = Postgres>,
|
||||||
token: &str,
|
token: &str,
|
||||||
) -> anyhow::Result<OAuth2AccessTokenLookup> {
|
) -> Result<OAuth2AccessTokenLookup, AccessTokenLookupError> {
|
||||||
sqlx::query_as!(
|
let res = sqlx::query_as!(
|
||||||
OAuth2AccessTokenLookup,
|
OAuth2AccessTokenLookup,
|
||||||
r#"
|
r#"
|
||||||
SELECT
|
SELECT
|
||||||
@@ -99,8 +111,9 @@ pub async fn lookup_access_token(
|
|||||||
token,
|
token,
|
||||||
)
|
)
|
||||||
.fetch_one(executor)
|
.fetch_one(executor)
|
||||||
.await
|
.await?;
|
||||||
.context("could not introspect oauth2 access token")
|
|
||||||
|
Ok(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn revoke_access_token(
|
pub async fn revoke_access_token(
|
||||||
|
Reference in New Issue
Block a user