1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2026-01-03 17:02:28 +03:00

more error handling in token endpoint

Also adds some OP metadatas to help with conformance
This commit is contained in:
Quentin Gliech
2021-09-11 00:53:21 +02:00
parent f8c51f67e8
commit bb11ab7af8
6 changed files with 103 additions and 14 deletions

View File

@@ -139,12 +139,12 @@ mod tests {
config.clients.push(OAuth2ClientConfig {
client_id: "public".to_string(),
client_secret: None,
redirect_uris: None,
redirect_uris: Vec::new(),
});
config.clients.push(OAuth2ClientConfig {
client_id: "confidential".to_string(),
client_secret: Some("secret".to_string()),
redirect_uris: None,
redirect_uris: Vec::new(),
});
config
}

View File

@@ -14,7 +14,10 @@
use std::collections::HashSet;
use oauth2_types::{oidc::Metadata, requests::ResponseMode};
use oauth2_types::{
oidc::Metadata,
requests::{ClientAuthenticationMethod, GrantType, ResponseMode},
};
use warp::{Filter, Rejection, Reply};
use crate::config::OAuth2Config;
@@ -44,6 +47,21 @@ pub(super) fn filter(
s
});
let grant_types_supported = Some({
let mut s = HashSet::new();
s.insert(GrantType::AuthorizationCode);
s.insert(GrantType::RefreshToken);
s
});
let token_endpoint_auth_methods_supported = Some({
let mut s = HashSet::new();
s.insert(ClientAuthenticationMethod::ClientSecretBasic);
s.insert(ClientAuthenticationMethod::ClientSecretPost);
s.insert(ClientAuthenticationMethod::None);
s
});
let metadata = Metadata {
authorization_endpoint: base.join("oauth2/authorize").ok(),
token_endpoint: base.join("oauth2/token").ok(),
@@ -55,7 +73,8 @@ pub(super) fn filter(
scopes_supported: None,
response_types_supported,
response_modes_supported,
grant_types_supported: None,
grant_types_supported,
token_endpoint_auth_methods_supported,
code_challenge_methods_supported: None,
};

View File

@@ -16,9 +16,10 @@ use anyhow::Context;
use chrono::Duration;
use data_encoding::BASE64URL_NOPAD;
use headers::{CacheControl, Pragma};
use hyper::StatusCode;
use jwt_compact::{Claims, Header, TimeOptions};
use oauth2_types::{
errors::{InvalidGrant, OAuth2Error},
errors::{InvalidGrant, OAuth2Error, OAuth2ErrorCode, UnauthorizedClient},
requests::{
AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, RefreshTokenGrant,
},
@@ -29,7 +30,11 @@ use serde_with::skip_serializing_none;
use sha2::{Digest, Sha256};
use sqlx::{pool::PoolConnection, Acquire, PgPool, Postgres};
use url::Url;
use warp::{Filter, Rejection, Reply};
use warp::{
reject::Reject,
reply::{json, with_status},
Filter, Rejection, Reply,
};
use crate::{
config::{KeySet, OAuth2ClientConfig, OAuth2Config},
@@ -62,6 +67,23 @@ struct CustomClaims {
c_hash: String,
}
#[derive(Debug)]
struct Error {
json: serde_json::Value,
status: StatusCode,
}
impl Reject for Error {}
fn error<T, E>(e: E) -> Result<T, Rejection>
where
E: OAuth2ErrorCode + 'static,
{
let status = e.status();
let json = serde_json::to_value(e.into_response()).wrap_error()?;
Err(Error { json, status }.into())
}
pub fn filter(
pool: &PgPool,
oauth2_config: &OAuth2Config,
@@ -74,6 +96,15 @@ pub fn filter(
.and(warp::any().map(move || issuer.clone()))
.and(with_connection(pool))
.and_then(token)
.recover(recover)
}
async fn recover(rejection: Rejection) -> Result<impl Reply, Rejection> {
if let Some(Error { json, status }) = rejection.find::<Error>() {
Ok(with_status(warp::reply::json(json), *status))
} else {
Err(rejection)
}
}
async fn token(
@@ -87,15 +118,15 @@ async fn token(
let reply = match req {
AccessTokenRequest::AuthorizationCode(grant) => {
let reply = authorization_code_grant(&grant, &client, &keys, issuer, &mut conn).await?;
warp::reply::json(&reply)
json(&reply)
}
AccessTokenRequest::RefreshToken(grant) => {
let reply = refresh_token_grant(&grant, &client, &mut conn).await?;
warp::reply::json(&reply)
json(&reply)
}
_ => {
let reply = InvalidGrant.into_response();
warp::reply::json(&reply)
json(&reply)
}
};
@@ -125,7 +156,7 @@ async fn authorization_code_grant(
let mut txn = conn.begin().await.wrap_error()?;
let code = lookup_code(&mut txn, &grant.code).await.wrap_error()?;
if client.client_id != code.client_id {
return Err(anyhow::anyhow!("invalid client")).wrap_error();
return error(UnauthorizedClient);
}
// TODO: verify PKCE
@@ -194,7 +225,8 @@ async fn refresh_token_grant(
.wrap_error()?;
if client.client_id != refresh_token_lookup.client_id {
return Err(anyhow::anyhow!("invalid client")).wrap_error();
// As per https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
return error(InvalidGrant);
}
let ttl = Duration::minutes(5);

View File

@@ -49,11 +49,24 @@ pub trait OAuth2Error: std::fmt::Debug + Send + Sync {
}
}
trait OAuth2ErrorCode: OAuth2Error {
pub trait OAuth2ErrorCode: OAuth2Error + 'static {
/// The HTTP status code that must be returned by this error
fn status(&self) -> StatusCode;
}
impl OAuth2Error for &Box<dyn OAuth2ErrorCode> {
fn error(&self) -> &'static str {
self.as_ref().error()
}
fn description(&self) -> Option<String> {
self.as_ref().description()
}
fn uri(&self) -> Option<Url> {
self.as_ref().uri()
}
}
#[derive(Debug)]
pub struct ErrorResponse(Box<dyn OAuth2Error>);
@@ -112,7 +125,7 @@ impl Serialize for ErrorResponse {
macro_rules! oauth2_error_def {
($name:ident) => {
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct $name;
};
}

View File

@@ -20,7 +20,7 @@ use url::Url;
use crate::{
pkce::CodeChallengeMethod,
requests::{GrantType, ResponseMode},
requests::{ClientAuthenticationMethod, GrantType, ResponseMode},
};
// TODO: https://datatracker.ietf.org/doc/html/rfc8414#section-2
@@ -61,6 +61,10 @@ pub struct Metadata {
/// this authorization server supports.
pub grant_types_supported: Option<HashSet<GrantType>>,
/// JSON array containing a list of client authentication methods supported
/// by this token endpoint.
pub token_endpoint_auth_methods_supported: Option<HashSet<ClientAuthenticationMethod>>,
/// PKCE code challenge methods supported by this authorization server
pub code_challenge_methods_supported: Option<HashSet<CodeChallengeMethod>>,

View File

@@ -70,6 +70,27 @@ pub enum ResponseMode {
FormPost,
}
#[derive(
Debug,
Hash,
PartialEq,
Eq,
PartialOrd,
Ord,
Clone,
Copy,
Display,
FromStr,
Serialize,
Deserialize,
)]
#[serde(rename_all = "snake_case")]
pub enum ClientAuthenticationMethod {
None,
ClientSecretPost,
ClientSecretBasic,
}
#[derive(
Debug,
Hash,