1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Implement the client credentials grant

This commit is contained in:
Quentin Gliech
2023-09-04 19:45:53 +02:00
parent 00fe5f902b
commit 542d0a6073
17 changed files with 498 additions and 127 deletions

View File

@ -218,7 +218,7 @@ async fn get_requester(
};
// If there is a user for this session, check that it is not locked
let user_valid = user.as_ref().map_or(false, User::is_valid);
let user_valid = user.as_ref().map_or(true, User::is_valid);
if !token.is_valid(clock.now()) || !session.is_valid() || !user_valid {
return Err(RouteError::InvalidToken);

View File

@ -16,8 +16,13 @@ use axum::http::Request;
use chrono::Duration;
use hyper::StatusCode;
use mas_data_model::{AccessToken, Client, TokenType, User};
use mas_router::SimpleRoute;
use mas_storage::{oauth2::OAuth2ClientRepository, RepositoryAccess};
use oauth2_types::scope::{Scope, ScopeToken, OPENID};
use oauth2_types::{
registration::ClientRegistrationResponse,
requests::AccessTokenResponse,
scope::{Scope, ScopeToken, OPENID},
};
use sqlx::PgPool;
use crate::test_utils::{init_tracing, RequestBuilderExt, ResponseExt, TestState};
@ -349,3 +354,106 @@ async fn test_oauth2_admin(pool: PgPool) {
})
);
}
/// Test that we can query the GraphQL endpoint with a token from a
/// client_credentials grant.
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_oauth2_client_credentials(pool: PgPool) {
init_tracing();
let state = TestState::from_pool(pool).await.unwrap();
// Provision a client
let request =
Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
"client_uri": "https://example.com/",
// XXX: we shouldn't have to specify the redirect URI here, but the policy denies it for now
"redirect_uris": ["https://example.com/callback"],
"contacts": ["contact@example.com"],
"token_endpoint_auth_method": "client_secret_post",
"grant_types": ["client_credentials"],
"response_types": [],
}));
let response = state.request(request).await;
response.assert_status(StatusCode::CREATED);
let response: ClientRegistrationResponse = response.json();
let client_id = response.client_id;
let client_secret = response.client_secret.expect("to have a client secret");
// Call the token endpoint with an empty scope
let request = Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
"grant_type": "client_credentials",
"client_id": client_id,
"client_secret": client_secret,
"scope": "urn:mas:graphql:*",
}));
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let AccessTokenResponse { access_token, .. } = response.json();
let request = Request::post("/graphql")
.bearer(&access_token)
.json(serde_json::json!({
"query": r#"
query {
viewer {
__typename
}
viewerSession {
__typename
}
}
"#,
}));
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let response: GraphQLResponse = response.json();
assert!(response.errors.is_empty());
assert_eq!(
response.data,
serde_json::json!({
"viewer": {
// There is no user associated with the client credentials grant
"__typename": "Anonymous",
},
"viewerSession": {
// But there is a session
"__typename": "Oauth2Session",
},
})
);
// Check that we can't do a query once the token is revoked
let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
"token": access_token,
"client_id": client_id,
"client_secret": client_secret,
}));
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
// Do the same request again
let request = Request::post("/graphql")
.bearer(&access_token)
.json(serde_json::json!({
"query": r#"
query {
viewer {
__typename
}
viewerSession {
__typename
}
}
"#,
}));
let response = state.request(request).await;
response.assert_status(StatusCode::UNAUTHORIZED);
}

View File

@ -20,8 +20,10 @@ use mas_axum_utils::{
client_authorization::{ClientAuthorization, CredentialsVerificationError},
http_client_factory::HttpClientFactory,
};
use mas_data_model::{AuthorizationGrantStage, Client, Device};
use mas_data_model::{AuthorizationGrantStage, Client, Device, TokenType};
use mas_keystore::{Encrypter, Keystore};
use mas_oidc_client::types::scope::ScopeToken;
use mas_policy::Policy;
use mas_router::UrlBuilder;
use mas_storage::{
job::{JobRepositoryExt, ProvisionDeviceJob},
@ -36,7 +38,8 @@ use oauth2_types::{
errors::{ClientError, ClientErrorCode},
pkce::CodeChallengeError,
requests::{
AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, RefreshTokenGrant,
AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, ClientCredentialsGrant,
GrantType, RefreshTokenGrant,
},
scope,
};
@ -92,6 +95,9 @@ pub(crate) enum RouteError {
#[error("invalid grant")]
InvalidGrant,
#[error("policy denied the request")]
DeniedByPolicy(Vec<mas_policy::Violation>),
#[error("unsupported grant type")]
UnsupportedGrantType,
@ -132,6 +138,18 @@ impl IntoResponse for RouteError {
StatusCode::UNAUTHORIZED,
Json(ClientError::from(ClientErrorCode::UnauthorizedClient)),
),
Self::DeniedByPolicy(violations) => (
StatusCode::FORBIDDEN,
Json(
ClientError::from(ClientErrorCode::InvalidScope).with_description(
violations
.into_iter()
.map(|violation| violation.msg)
.collect::<Vec<_>>()
.join(", "),
),
),
),
Self::InvalidGrant | Self::GrantNotFound => (
StatusCode::BAD_REQUEST,
Json(ClientError::from(ClientErrorCode::InvalidGrant)),
@ -146,6 +164,7 @@ impl IntoResponse for RouteError {
}
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_policy::EvaluationError);
impl_from_error_for_route!(super::IdTokenSignatureError);
#[tracing::instrument(
@ -163,6 +182,7 @@ pub(crate) async fn post(
mut repo: BoxRepository,
State(site_config): State<SiteConfig>,
State(encrypter): State<Encrypter>,
policy: Policy,
client_authorization: ClientAuthorization<AccessTokenRequest>,
) -> Result<impl IntoResponse, RouteError> {
let client = client_authorization
@ -200,6 +220,18 @@ pub(crate) async fn post(
AccessTokenRequest::RefreshToken(grant) => {
refresh_token_grant(&mut rng, &clock, &grant, &client, &site_config, repo).await?
}
AccessTokenRequest::ClientCredentials(grant) => {
client_credentials_grant(
&mut rng,
&clock,
&grant,
&client,
&site_config,
repo,
policy,
)
.await?
}
_ => {
return Err(RouteError::UnsupportedGrantType);
}
@ -420,6 +452,58 @@ async fn refresh_token_grant(
Ok((params, repo))
}
async fn client_credentials_grant(
rng: &mut BoxRng,
clock: &impl Clock,
grant: &ClientCredentialsGrant,
client: &Client,
site_config: &SiteConfig,
mut repo: BoxRepository,
mut policy: Policy,
) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
// Check that the client is allowed to use this grant type
if !client.grant_types.contains(&GrantType::ClientCredentials) {
return Err(RouteError::UnauthorizedClient);
}
// Default to an empty scope if none is provided
let scope = grant
.scope
.clone()
.unwrap_or_else(|| std::iter::empty::<ScopeToken>().collect());
// Make the request go through the policy engine
let res = policy
.evaluate_client_credentials_grant(&scope, client)
.await?;
if !res.valid() {
return Err(RouteError::DeniedByPolicy(res.violations));
}
// Start the session
let session = repo
.oauth2_session()
.add_from_client_credentials(rng, clock, client, scope)
.await?;
let ttl = site_config.access_token_ttl;
let access_token_str = TokenType::AccessToken.generate(rng);
let access_token = repo
.oauth2_access_token()
.add(rng, clock, &session, access_token_str, ttl)
.await?;
let mut params = AccessTokenResponse::new(access_token.access_token).with_expires_in(ttl);
if !session.scope.is_empty() {
// We only return the scope if it's not empty
params = params.with_scope(session.scope);
}
Ok((params, repo))
}
#[cfg(test)]
mod tests {
use hyper::Request;
@ -767,7 +851,7 @@ mod tests {
}
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_unsupported_grant(pool: PgPool) {
async fn test_client_credentials(pool: PgPool) {
init_tracing();
let state = TestState::from_pool(pool).await.unwrap();
@ -775,6 +859,7 @@ mod tests {
let request =
Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
"client_uri": "https://example.com/",
// XXX: we shouldn't have to specify the redirect URI here, but the policy denies it for now
"redirect_uris": ["https://example.com/callback"],
"contacts": ["contact@example.com"],
"token_endpoint_auth_method": "client_secret_post",
@ -789,7 +874,7 @@ mod tests {
let client_id = response.client_id;
let client_secret = response.client_secret.expect("to have a client secret");
// Call the token endpoint with an unsupported grant type
// Call the token endpoint with an empty scope
let request =
Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
"grant_type": "client_credentials",
@ -797,6 +882,137 @@ mod tests {
"client_secret": client_secret,
}));
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let response: AccessTokenResponse = response.json();
assert!(response.refresh_token.is_none());
assert!(response.expires_in.is_some());
assert!(response.scope.is_none());
// Revoke the token
let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
"token": response.access_token,
"client_id": client_id,
"client_secret": client_secret,
}));
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
// We should be allowed to ask for the GraphQL API scope
let request =
Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
"grant_type": "client_credentials",
"client_id": client_id,
"client_secret": client_secret,
"scope": "urn:mas:graphql:*"
}));
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let response: AccessTokenResponse = response.json();
assert!(response.refresh_token.is_none());
assert!(response.expires_in.is_some());
assert_eq!(response.scope, Some("urn:mas:graphql:*".parse().unwrap()));
// Revoke the token
let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
"token": response.access_token,
"client_id": client_id,
"client_secret": client_secret,
}));
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
// We should be NOT allowed to ask for the MAS admin scope
let request =
Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
"grant_type": "client_credentials",
"client_id": client_id,
"client_secret": client_secret,
"scope": "urn:mas:admin"
}));
let response = state.request(request).await;
response.assert_status(StatusCode::FORBIDDEN);
let ClientError { error, .. } = response.json();
assert_eq!(error, ClientErrorCode::InvalidScope);
// Now, if we add the client to the admin list in the policy, it should work
let state = {
let mut state = state;
state.policy_factory = crate::test_utils::policy_factory(serde_json::json!({
"admin_clients": [client_id]
}))
.await
.unwrap();
state
};
let request =
Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
"grant_type": "client_credentials",
"client_id": client_id,
"client_secret": client_secret,
"scope": "urn:mas:admin"
}));
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let response: AccessTokenResponse = response.json();
assert!(response.refresh_token.is_none());
assert!(response.expires_in.is_some());
assert_eq!(response.scope, Some("urn:mas:admin".parse().unwrap()));
// Revoke the token
let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
"token": response.access_token,
"client_id": client_id,
"client_secret": client_secret,
}));
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
}
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_unsupported_grant(pool: PgPool) {
init_tracing();
let state = TestState::from_pool(pool).await.unwrap();
// Provision a client
let request =
Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
"client_uri": "https://example.com/",
"redirect_uris": ["https://example.com/callback"],
"contacts": ["contact@example.com"],
"token_endpoint_auth_method": "client_secret_post",
"grant_types": ["password"],
"response_types": [],
}));
let response = state.request(request).await;
response.assert_status(StatusCode::CREATED);
let response: ClientRegistrationResponse = response.json();
let client_id = response.client_id;
let client_secret = response.client_secret.expect("to have a client secret");
// Call the token endpoint with an unsupported grant type
let request =
Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
"grant_type": "password",
"client_id": client_id,
"client_secret": client_secret,
"username": "john",
"password": "hunter2",
}));
let response = state.request(request).await;
response.assert_status(StatusCode::BAD_REQUEST);
let ClientError { error, .. } = response.json();

View File

@ -63,6 +63,28 @@ pub(crate) fn init_tracing() {
.try_init();
}
pub(crate) async fn policy_factory(
data: serde_json::Value,
) -> Result<Arc<PolicyFactory>, anyhow::Error> {
let workspace_root = camino::Utf8Path::new(env!("CARGO_MANIFEST_DIR"))
.join("..")
.join("..");
let file = tokio::fs::File::open(workspace_root.join("policies").join("policy.wasm")).await?;
let entrypoints = mas_policy::Entrypoints {
register: "register/violation".to_owned(),
client_registration: "client_registration/violation".to_owned(),
authorization_grant: "authorization_grant/violation".to_owned(),
email: "email/violation".to_owned(),
password: "password/violation".to_owned(),
};
let policy_factory = PolicyFactory::load(file, data, entrypoints).await?;
let policy_factory = Arc::new(policy_factory);
Ok(policy_factory)
}
#[derive(Clone)]
pub(crate) struct TestState {
pub pool: PgPool,
@ -116,23 +138,10 @@ impl TestState {
let homeserver = MatrixHomeserver::new("example.com".to_owned());
let file =
tokio::fs::File::open(workspace_root.join("policies").join("policy.wasm")).await?;
let entrypoints = mas_policy::Entrypoints {
register: "register/violation".to_owned(),
client_registration: "client_registration/violation".to_owned(),
authorization_grant: "authorization_grant/violation".to_owned(),
email: "email/violation".to_owned(),
password: "password/violation".to_owned(),
};
let policy_factory = PolicyFactory::load(file, serde_json::json!({}), entrypoints).await?;
let policy_factory = policy_factory(serde_json::json!({})).await?;
let homeserver_connection = MockHomeserverConnection::new("example.com");
let policy_factory = Arc::new(policy_factory);
let http_client_factory = HttpClientFactory::new(10);
let site_config = SiteConfig::default();