You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2026-01-03 17:02:28 +03:00
more error handling in token endpoint
Also adds some OP metadatas to help with conformance
This commit is contained in:
@@ -139,12 +139,12 @@ mod tests {
|
||||
config.clients.push(OAuth2ClientConfig {
|
||||
client_id: "public".to_string(),
|
||||
client_secret: None,
|
||||
redirect_uris: None,
|
||||
redirect_uris: Vec::new(),
|
||||
});
|
||||
config.clients.push(OAuth2ClientConfig {
|
||||
client_id: "confidential".to_string(),
|
||||
client_secret: Some("secret".to_string()),
|
||||
redirect_uris: None,
|
||||
redirect_uris: Vec::new(),
|
||||
});
|
||||
config
|
||||
}
|
||||
|
||||
@@ -14,7 +14,10 @@
|
||||
|
||||
use std::collections::HashSet;
|
||||
|
||||
use oauth2_types::{oidc::Metadata, requests::ResponseMode};
|
||||
use oauth2_types::{
|
||||
oidc::Metadata,
|
||||
requests::{ClientAuthenticationMethod, GrantType, ResponseMode},
|
||||
};
|
||||
use warp::{Filter, Rejection, Reply};
|
||||
|
||||
use crate::config::OAuth2Config;
|
||||
@@ -44,6 +47,21 @@ pub(super) fn filter(
|
||||
s
|
||||
});
|
||||
|
||||
let grant_types_supported = Some({
|
||||
let mut s = HashSet::new();
|
||||
s.insert(GrantType::AuthorizationCode);
|
||||
s.insert(GrantType::RefreshToken);
|
||||
s
|
||||
});
|
||||
|
||||
let token_endpoint_auth_methods_supported = Some({
|
||||
let mut s = HashSet::new();
|
||||
s.insert(ClientAuthenticationMethod::ClientSecretBasic);
|
||||
s.insert(ClientAuthenticationMethod::ClientSecretPost);
|
||||
s.insert(ClientAuthenticationMethod::None);
|
||||
s
|
||||
});
|
||||
|
||||
let metadata = Metadata {
|
||||
authorization_endpoint: base.join("oauth2/authorize").ok(),
|
||||
token_endpoint: base.join("oauth2/token").ok(),
|
||||
@@ -55,7 +73,8 @@ pub(super) fn filter(
|
||||
scopes_supported: None,
|
||||
response_types_supported,
|
||||
response_modes_supported,
|
||||
grant_types_supported: None,
|
||||
grant_types_supported,
|
||||
token_endpoint_auth_methods_supported,
|
||||
code_challenge_methods_supported: None,
|
||||
};
|
||||
|
||||
|
||||
@@ -16,9 +16,10 @@ use anyhow::Context;
|
||||
use chrono::Duration;
|
||||
use data_encoding::BASE64URL_NOPAD;
|
||||
use headers::{CacheControl, Pragma};
|
||||
use hyper::StatusCode;
|
||||
use jwt_compact::{Claims, Header, TimeOptions};
|
||||
use oauth2_types::{
|
||||
errors::{InvalidGrant, OAuth2Error},
|
||||
errors::{InvalidGrant, OAuth2Error, OAuth2ErrorCode, UnauthorizedClient},
|
||||
requests::{
|
||||
AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, RefreshTokenGrant,
|
||||
},
|
||||
@@ -29,7 +30,11 @@ use serde_with::skip_serializing_none;
|
||||
use sha2::{Digest, Sha256};
|
||||
use sqlx::{pool::PoolConnection, Acquire, PgPool, Postgres};
|
||||
use url::Url;
|
||||
use warp::{Filter, Rejection, Reply};
|
||||
use warp::{
|
||||
reject::Reject,
|
||||
reply::{json, with_status},
|
||||
Filter, Rejection, Reply,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
config::{KeySet, OAuth2ClientConfig, OAuth2Config},
|
||||
@@ -62,6 +67,23 @@ struct CustomClaims {
|
||||
c_hash: String,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Error {
|
||||
json: serde_json::Value,
|
||||
status: StatusCode,
|
||||
}
|
||||
|
||||
impl Reject for Error {}
|
||||
|
||||
fn error<T, E>(e: E) -> Result<T, Rejection>
|
||||
where
|
||||
E: OAuth2ErrorCode + 'static,
|
||||
{
|
||||
let status = e.status();
|
||||
let json = serde_json::to_value(e.into_response()).wrap_error()?;
|
||||
Err(Error { json, status }.into())
|
||||
}
|
||||
|
||||
pub fn filter(
|
||||
pool: &PgPool,
|
||||
oauth2_config: &OAuth2Config,
|
||||
@@ -74,6 +96,15 @@ pub fn filter(
|
||||
.and(warp::any().map(move || issuer.clone()))
|
||||
.and(with_connection(pool))
|
||||
.and_then(token)
|
||||
.recover(recover)
|
||||
}
|
||||
|
||||
async fn recover(rejection: Rejection) -> Result<impl Reply, Rejection> {
|
||||
if let Some(Error { json, status }) = rejection.find::<Error>() {
|
||||
Ok(with_status(warp::reply::json(json), *status))
|
||||
} else {
|
||||
Err(rejection)
|
||||
}
|
||||
}
|
||||
|
||||
async fn token(
|
||||
@@ -87,15 +118,15 @@ async fn token(
|
||||
let reply = match req {
|
||||
AccessTokenRequest::AuthorizationCode(grant) => {
|
||||
let reply = authorization_code_grant(&grant, &client, &keys, issuer, &mut conn).await?;
|
||||
warp::reply::json(&reply)
|
||||
json(&reply)
|
||||
}
|
||||
AccessTokenRequest::RefreshToken(grant) => {
|
||||
let reply = refresh_token_grant(&grant, &client, &mut conn).await?;
|
||||
warp::reply::json(&reply)
|
||||
json(&reply)
|
||||
}
|
||||
_ => {
|
||||
let reply = InvalidGrant.into_response();
|
||||
warp::reply::json(&reply)
|
||||
json(&reply)
|
||||
}
|
||||
};
|
||||
|
||||
@@ -125,7 +156,7 @@ async fn authorization_code_grant(
|
||||
let mut txn = conn.begin().await.wrap_error()?;
|
||||
let code = lookup_code(&mut txn, &grant.code).await.wrap_error()?;
|
||||
if client.client_id != code.client_id {
|
||||
return Err(anyhow::anyhow!("invalid client")).wrap_error();
|
||||
return error(UnauthorizedClient);
|
||||
}
|
||||
|
||||
// TODO: verify PKCE
|
||||
@@ -194,7 +225,8 @@ async fn refresh_token_grant(
|
||||
.wrap_error()?;
|
||||
|
||||
if client.client_id != refresh_token_lookup.client_id {
|
||||
return Err(anyhow::anyhow!("invalid client")).wrap_error();
|
||||
// As per https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
|
||||
return error(InvalidGrant);
|
||||
}
|
||||
|
||||
let ttl = Duration::minutes(5);
|
||||
|
||||
@@ -49,11 +49,24 @@ pub trait OAuth2Error: std::fmt::Debug + Send + Sync {
|
||||
}
|
||||
}
|
||||
|
||||
trait OAuth2ErrorCode: OAuth2Error {
|
||||
pub trait OAuth2ErrorCode: OAuth2Error + 'static {
|
||||
/// The HTTP status code that must be returned by this error
|
||||
fn status(&self) -> StatusCode;
|
||||
}
|
||||
|
||||
impl OAuth2Error for &Box<dyn OAuth2ErrorCode> {
|
||||
fn error(&self) -> &'static str {
|
||||
self.as_ref().error()
|
||||
}
|
||||
fn description(&self) -> Option<String> {
|
||||
self.as_ref().description()
|
||||
}
|
||||
|
||||
fn uri(&self) -> Option<Url> {
|
||||
self.as_ref().uri()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ErrorResponse(Box<dyn OAuth2Error>);
|
||||
|
||||
@@ -112,7 +125,7 @@ impl Serialize for ErrorResponse {
|
||||
|
||||
macro_rules! oauth2_error_def {
|
||||
($name:ident) => {
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct $name;
|
||||
};
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ use url::Url;
|
||||
|
||||
use crate::{
|
||||
pkce::CodeChallengeMethod,
|
||||
requests::{GrantType, ResponseMode},
|
||||
requests::{ClientAuthenticationMethod, GrantType, ResponseMode},
|
||||
};
|
||||
|
||||
// TODO: https://datatracker.ietf.org/doc/html/rfc8414#section-2
|
||||
@@ -61,6 +61,10 @@ pub struct Metadata {
|
||||
/// this authorization server supports.
|
||||
pub grant_types_supported: Option<HashSet<GrantType>>,
|
||||
|
||||
/// JSON array containing a list of client authentication methods supported
|
||||
/// by this token endpoint.
|
||||
pub token_endpoint_auth_methods_supported: Option<HashSet<ClientAuthenticationMethod>>,
|
||||
|
||||
/// PKCE code challenge methods supported by this authorization server
|
||||
pub code_challenge_methods_supported: Option<HashSet<CodeChallengeMethod>>,
|
||||
|
||||
|
||||
@@ -70,6 +70,27 @@ pub enum ResponseMode {
|
||||
FormPost,
|
||||
}
|
||||
|
||||
#[derive(
|
||||
Debug,
|
||||
Hash,
|
||||
PartialEq,
|
||||
Eq,
|
||||
PartialOrd,
|
||||
Ord,
|
||||
Clone,
|
||||
Copy,
|
||||
Display,
|
||||
FromStr,
|
||||
Serialize,
|
||||
Deserialize,
|
||||
)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ClientAuthenticationMethod {
|
||||
None,
|
||||
ClientSecretPost,
|
||||
ClientSecretBasic,
|
||||
}
|
||||
|
||||
#[derive(
|
||||
Debug,
|
||||
Hash,
|
||||
|
||||
Reference in New Issue
Block a user