1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-07 22:41:18 +03:00

handlers: add tests for the token endpoint

This also simplifies the way we issue tokens in tests
This commit is contained in:
Quentin Gliech
2023-02-22 18:46:15 +01:00
parent 03583d2936
commit 17471c651e
10 changed files with 670 additions and 229 deletions

View File

@ -22,20 +22,20 @@ use axum_extra::extract::PrivateCookieJar;
use hyper::StatusCode;
use mas_axum_utils::SessionInfoExt;
use mas_data_model::{AuthorizationGrant, BrowserSession, Client};
use mas_keystore::Encrypter;
use mas_keystore::{Encrypter, Keystore};
use mas_policy::PolicyFactory;
use mas_router::{PostAuthAction, Route};
use mas_router::{PostAuthAction, Route, UrlBuilder};
use mas_storage::{
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2SessionRepository},
BoxClock, BoxRepository, BoxRng,
};
use mas_templates::Templates;
use oauth2_types::requests::{AccessTokenResponse, AuthorizationResponse};
use oauth2_types::requests::AuthorizationResponse;
use thiserror::Error;
use ulid::Ulid;
use super::callback::CallbackDestination;
use crate::impl_from_error_for_route;
use crate::{impl_from_error_for_route, oauth2::generate_id_token};
#[derive(Debug, Error)]
pub enum RouteError {
@ -90,6 +90,8 @@ pub(crate) async fn get(
clock: BoxClock,
State(policy_factory): State<Arc<PolicyFactory>>,
State(templates): State<Templates>,
State(url_builder): State<UrlBuilder>,
State(key_store): State<Keystore>,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
Path(grant_id): Path<Ulid>,
@ -119,7 +121,19 @@ pub(crate) async fn get(
.await?
.ok_or(RouteError::NoSuchClient)?;
match complete(rng, clock, grant, client, session, &policy_factory, repo).await {
match complete(
rng,
clock,
repo,
key_store,
&policy_factory,
url_builder,
grant,
client,
session,
)
.await
{
Ok(params) => {
let res = callback_destination.go(&templates, params).await?;
Ok((cookie_jar, res).into_response())
@ -161,16 +175,19 @@ impl_from_error_for_route!(GrantCompletionError: super::callback::IntoCallbackDe
impl_from_error_for_route!(GrantCompletionError: mas_policy::LoadError);
impl_from_error_for_route!(GrantCompletionError: mas_policy::InstanciateError);
impl_from_error_for_route!(GrantCompletionError: mas_policy::EvaluationError);
impl_from_error_for_route!(GrantCompletionError: super::super::IdTokenSignatureError);
pub(crate) async fn complete(
mut rng: BoxRng,
clock: BoxClock,
mut repo: BoxRepository,
key_store: Keystore,
policy_factory: &PolicyFactory,
url_builder: UrlBuilder,
grant: AuthorizationGrant,
client: Client,
browser_session: BrowserSession,
policy_factory: &PolicyFactory,
mut repo: BoxRepository,
) -> Result<AuthorizationResponse<Option<AccessTokenResponse>>, GrantCompletionError> {
) -> Result<AuthorizationResponse, GrantCompletionError> {
// Verify that the grant is in a pending stage
if !grant.stage.is_pending() {
return Err(GrantCompletionError::NotPending);
@ -211,7 +228,13 @@ pub(crate) async fn complete(
// All good, let's start the session
let session = repo
.oauth2_session()
.create_from_grant(&mut rng, &clock, &grant, &browser_session)
.add(
&mut rng,
&clock,
&client,
&browser_session,
grant.scope.clone(),
)
.await?;
let grant = repo
@ -222,19 +245,25 @@ pub(crate) async fn complete(
// Yep! Let's complete the auth now
let mut params = AuthorizationResponse::default();
// Did they request an ID token?
if grant.response_type_id_token {
params.id_token = Some(generate_id_token(
&mut rng,
&clock,
&url_builder,
&key_store,
&client,
&grant,
&browser_session,
None,
)?);
}
// Did they request an auth code?
if let Some(code) = grant.code {
params.code = Some(code.code);
}
// Did they request an ID token?
if grant.response_type_id_token {
// TODO
return Err(GrantCompletionError::Internal(
"ID tokens are not implemented yet".into(),
));
}
repo.save().await?;
Ok(params)
}

View File

@ -22,9 +22,9 @@ use axum_extra::extract::PrivateCookieJar;
use hyper::StatusCode;
use mas_axum_utils::SessionInfoExt;
use mas_data_model::{AuthorizationCode, Pkce};
use mas_keystore::Encrypter;
use mas_keystore::{Encrypter, Keystore};
use mas_policy::PolicyFactory;
use mas_router::{PostAuthAction, Route};
use mas_router::{PostAuthAction, Route, UrlBuilder};
use mas_storage::{
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository},
BoxClock, BoxRepository, BoxRng,
@ -141,6 +141,8 @@ pub(crate) async fn get(
clock: BoxClock,
State(policy_factory): State<Arc<PolicyFactory>>,
State(templates): State<Templates>,
State(key_store): State<Keystore>,
State(url_builder): State<UrlBuilder>,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
Form(params): Form<Params>,
@ -340,11 +342,13 @@ pub(crate) async fn get(
match self::complete::complete(
rng,
clock,
repo,
key_store,
&policy_factory,
url_builder,
grant,
client,
user_session,
&policy_factory,
repo,
)
.await
{
@ -385,11 +389,13 @@ pub(crate) async fn get(
match self::complete::complete(
rng,
clock,
repo,
key_store,
&policy_factory,
url_builder,
grant,
client,
user_session,
&policy_factory,
repo,
)
.await
{

View File

@ -12,6 +12,23 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use chrono::Duration;
use mas_data_model::{
AccessToken, AuthorizationGrant, BrowserSession, Client, RefreshToken, Session, TokenType,
};
use mas_iana::jose::JsonWebSignatureAlg;
use mas_jose::{
claims::{self, hash_token},
constraints::Constrainable,
jwt::{JsonWebSignatureHeader, Jwt},
};
use mas_keystore::Keystore;
use mas_router::UrlBuilder;
use mas_storage::{Clock, RepositoryAccess};
use thiserror::Error;
pub mod authorization;
pub mod consent;
pub mod discovery;
@ -22,3 +39,87 @@ pub mod revoke;
pub mod token;
pub mod userinfo;
pub mod webfinger;
#[derive(Debug, Error)]
#[error(transparent)]
pub(crate) enum IdTokenSignatureError {
#[error("The signing key is invalid")]
InvalidSigningKey,
Claim(#[from] mas_jose::claims::ClaimError),
JwtSignature(#[from] mas_jose::jwt::JwtSignatureError),
WrongAlgorithm(#[from] mas_keystore::WrongAlgorithmError),
TokenHash(#[from] mas_jose::claims::TokenHashError),
}
pub(crate) fn generate_id_token(
rng: &mut (impl rand::RngCore + rand::CryptoRng),
clock: &impl Clock,
url_builder: &UrlBuilder,
key_store: &Keystore,
client: &Client,
grant: &AuthorizationGrant,
browser_session: &BrowserSession,
access_token: Option<&AccessToken>,
) -> Result<String, IdTokenSignatureError> {
let mut claims = HashMap::new();
let now = clock.now();
claims::ISS.insert(&mut claims, url_builder.oidc_issuer().to_string())?;
claims::SUB.insert(&mut claims, &browser_session.user.sub)?;
claims::AUD.insert(&mut claims, client.client_id.clone())?;
claims::IAT.insert(&mut claims, now)?;
claims::EXP.insert(&mut claims, now + Duration::hours(1))?;
if let Some(ref nonce) = grant.nonce {
claims::NONCE.insert(&mut claims, nonce.clone())?;
}
if let Some(ref last_authentication) = browser_session.last_authentication {
claims::AUTH_TIME.insert(&mut claims, last_authentication.created_at)?;
}
let alg = client
.id_token_signed_response_alg
.clone()
.unwrap_or(JsonWebSignatureAlg::Rs256);
let key = key_store
.signing_key_for_algorithm(&alg)
.ok_or(IdTokenSignatureError::InvalidSigningKey)?;
if let Some(access_token) = access_token {
claims::AT_HASH.insert(&mut claims, hash_token(&alg, &access_token.access_token)?)?;
}
if let Some(ref code) = grant.code {
claims::C_HASH.insert(&mut claims, hash_token(&alg, &code.code)?)?;
}
let signer = key.params().signing_key_for_alg(&alg)?;
let header = JsonWebSignatureHeader::new(alg)
.with_kid(key.kid().ok_or(IdTokenSignatureError::InvalidSigningKey)?);
let id_token = Jwt::sign_with_rng(rng, header, claims, &signer)?;
Ok(id_token.into_string())
}
pub(crate) async fn generate_token_pair<R: RepositoryAccess>(
rng: &mut (impl rand::RngCore + Send),
clock: &impl Clock,
repo: &mut R,
session: &Session,
ttl: Duration,
) -> Result<(AccessToken, RefreshToken), R::Error> {
let access_token_str = TokenType::AccessToken.generate(rng);
let refresh_token_str = TokenType::RefreshToken.generate(rng);
let access_token = repo
.oauth2_access_token()
.add(rng, clock, session, access_token_str.clone(), ttl)
.await?;
let refresh_token = repo
.oauth2_refresh_token()
.add(rng, clock, session, &access_token, refresh_token_str)
.await?;
Ok((access_token, refresh_token))
}

View File

@ -201,19 +201,23 @@ pub(crate) async fn post(
#[cfg(test)]
mod tests {
use chrono::Duration;
use hyper::Request;
use mas_data_model::AuthorizationCode;
use mas_data_model::{AccessToken, RefreshToken};
use mas_router::SimpleRoute;
use mas_storage::RepositoryAccess;
use oauth2_types::{
registration::ClientRegistrationResponse,
requests::{AccessTokenResponse, ResponseMode},
requests::AccessTokenResponse,
scope::{Scope, OPENID},
};
use sqlx::PgPool;
use super::*;
use crate::test_utils::{init_tracing, RequestBuilderExt, ResponseExt, TestState};
use crate::{
oauth2::generate_token_pair,
test_utils::{init_tracing, RequestBuilderExt, ResponseExt, TestState},
};
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_revoke_access_token(pool: PgPool) {
@ -262,64 +266,37 @@ mod tests {
.unwrap()
.unwrap();
// Start a grant
let grant = repo
.oauth2_authorization_grant()
let session = repo
.oauth2_session()
.add(
&mut state.rng(),
&state.clock,
&client,
"https://example.com/redirect".parse().unwrap(),
&browser_session,
Scope::from_iter([OPENID]),
Some(AuthorizationCode {
code: "thisisaverysecurecode".to_owned(),
pkce: None,
}),
Some("state".to_owned()),
Some("nonce".to_owned()),
None,
ResponseMode::Query,
false,
false,
)
.await
.unwrap();
let session = repo
.oauth2_session()
.create_from_grant(&mut state.rng(), &state.clock, &grant, &browser_session)
.await
.unwrap();
let grant = repo
.oauth2_authorization_grant()
.fulfill(&state.clock, &session, grant)
let (AccessToken { access_token, .. }, RefreshToken { refresh_token, .. }) =
generate_token_pair(
&mut state.rng(),
&state.clock,
&mut repo,
&session,
Duration::minutes(5),
)
.await
.unwrap();
repo.save().await.unwrap();
// Now call the token endpoint to get an access token.
let request =
Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
"grant_type": "authorization_code",
"code": grant.code.unwrap().code,
"redirect_uri": grant.redirect_uri,
"client_id": client_id,
"client_secret": client_secret,
}));
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let token: AccessTokenResponse = response.json();
// Check that the token is valid
assert!(state.is_access_token_valid(&token.access_token).await);
assert!(state.is_access_token_valid(&access_token).await);
// Now let's revoke the access token.
let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
"token": token.access_token,
"token": access_token,
"token_type_hint": "access_token",
"client_id": client_id,
"client_secret": client_secret,
@ -329,13 +306,13 @@ mod tests {
response.assert_status(StatusCode::OK);
// Check that the token is no longer valid
assert!(!state.is_access_token_valid(&token.access_token).await);
assert!(!state.is_access_token_valid(&access_token).await);
// Try using the refresh token to get a new access token, it should fail.
let request =
Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
"grant_type": "refresh_token",
"refresh_token": token.refresh_token,
"refresh_token": refresh_token,
"client_id": client_id,
"client_secret": client_secret,
}));
@ -345,62 +322,36 @@ mod tests {
// Now try with a new grant, and by revoking the refresh token instead
let mut repo = state.repository().await.unwrap();
let grant = repo
.oauth2_authorization_grant()
let session = repo
.oauth2_session()
.add(
&mut state.rng(),
&state.clock,
&client,
"https://example.com/redirect".parse().unwrap(),
&browser_session,
Scope::from_iter([OPENID]),
Some(AuthorizationCode {
code: "anotherverysecretcode".to_owned(),
pkce: None,
}),
Some("state".to_owned()),
Some("nonce".to_owned()),
None,
ResponseMode::Query,
false,
false,
)
.await
.unwrap();
let session = repo
.oauth2_session()
.create_from_grant(&mut state.rng(), &state.clock, &grant, &browser_session)
.await
.unwrap();
let grant = repo
.oauth2_authorization_grant()
.fulfill(&state.clock, &session, grant)
let (AccessToken { access_token, .. }, RefreshToken { refresh_token, .. }) =
generate_token_pair(
&mut state.rng(),
&state.clock,
&mut repo,
&session,
Duration::minutes(5),
)
.await
.unwrap();
repo.save().await.unwrap();
// Now call the token endpoint to get an access token.
let request =
Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
"grant_type": "authorization_code",
"code": grant.code.unwrap().code,
"redirect_uri": grant.redirect_uri,
"client_id": client_id,
"client_secret": client_secret,
}));
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let token: AccessTokenResponse = response.json();
// Use the refresh token to get a new access token.
let request =
Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
"grant_type": "refresh_token",
"refresh_token": token.refresh_token,
"refresh_token": refresh_token,
"client_id": client_id,
"client_secret": client_secret,
}));
@ -408,14 +359,19 @@ mod tests {
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let old_token = token;
let token: AccessTokenResponse = response.json();
assert!(state.is_access_token_valid(&token.access_token).await);
assert!(!state.is_access_token_valid(&old_token.access_token).await);
let old_access_token = access_token;
let old_refresh_token = refresh_token;
let AccessTokenResponse {
access_token,
refresh_token,
..
} = response.json();
assert!(state.is_access_token_valid(&access_token).await);
assert!(!state.is_access_token_valid(&old_access_token).await);
// Revoking the old access token shouldn't do anything.
let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
"token": old_token.access_token,
"token": old_access_token,
"token_type_hint": "access_token",
"client_id": client_id,
"client_secret": client_secret,
@ -424,11 +380,11 @@ mod tests {
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
assert!(state.is_access_token_valid(&token.access_token).await);
assert!(state.is_access_token_valid(&access_token).await);
// Revoking the old refresh token shouldn't do anything.
let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
"token": old_token.refresh_token,
"token": old_refresh_token,
"token_type_hint": "refresh_token",
"client_id": client_id,
"client_secret": client_secret,
@ -437,11 +393,11 @@ mod tests {
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
assert!(state.is_access_token_valid(&token.access_token).await);
assert!(state.is_access_token_valid(&access_token).await);
// Revoking the new refresh token should invalidate the session
let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
"token": token.refresh_token,
"token": refresh_token,
"token_type_hint": "refresh_token",
"client_id": client_id,
"client_secret": client_secret,
@ -450,6 +406,6 @@ mod tests {
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
assert!(!state.is_access_token_valid(&token.access_token).await);
assert!(!state.is_access_token_valid(&access_token).await);
}
}

View File

@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use axum::{extract::State, response::IntoResponse, Json};
use chrono::{DateTime, Duration, Utc};
use headers::{CacheControl, HeaderMap, HeaderMapExt, Pragma};
@ -22,13 +20,7 @@ use mas_axum_utils::{
client_authorization::{ClientAuthorization, CredentialsVerificationError},
http_client_factory::HttpClientFactory,
};
use mas_data_model::{AuthorizationGrantStage, Client, TokenType};
use mas_iana::jose::JsonWebSignatureAlg;
use mas_jose::{
claims::{self, hash_token},
constraints::Constrainable,
jwt::{JsonWebSignatureHeader, Jwt},
};
use mas_data_model::{AuthorizationGrantStage, Client};
use mas_keystore::{Encrypter, Keystore};
use mas_router::UrlBuilder;
use mas_storage::{
@ -53,6 +45,7 @@ use thiserror::Error;
use tracing::debug;
use url::Url;
use super::{generate_id_token, generate_token_pair};
use crate::impl_from_error_for_route;
#[serde_as]
@ -98,12 +91,12 @@ pub(crate) enum RouteError {
#[error("invalid grant")]
InvalidGrant,
#[error("unsupported grant type")]
UnsupportedGrantType,
#[error("unauthorized client")]
UnauthorizedClient,
#[error("no suitable key found for signing")]
InvalidSigningKey,
#[error("failed to load browser session")]
NoSuchBrowserSession,
@ -115,10 +108,7 @@ impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
sentry::capture_error(&self);
match self {
Self::Internal(_)
| Self::InvalidSigningKey
| Self::NoSuchBrowserSession
| Self::NoSuchOAuthSession => (
Self::Internal(_) | Self::NoSuchBrowserSession | Self::NoSuchOAuthSession => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(ClientError::from(ClientErrorCode::ServerError)),
),
@ -145,16 +135,17 @@ impl IntoResponse for RouteError {
StatusCode::BAD_REQUEST,
Json(ClientError::from(ClientErrorCode::InvalidGrant)),
),
Self::UnsupportedGrantType => (
StatusCode::BAD_REQUEST,
Json(ClientError::from(ClientErrorCode::UnsupportedGrantType)),
),
}
.into_response()
}
}
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_keystore::WrongAlgorithmError);
impl_from_error_for_route!(mas_jose::claims::ClaimError);
impl_from_error_for_route!(mas_jose::claims::TokenHashError);
impl_from_error_for_route!(mas_jose::jwt::JwtSignatureError);
impl_from_error_for_route!(super::IdTokenSignatureError);
#[tracing::instrument(
name = "handlers.oauth2.token.post",
@ -207,7 +198,7 @@ pub(crate) async fn post(
refresh_token_grant(&mut rng, &clock, &grant, &client, repo).await?
}
_ => {
return Err(RouteError::InvalidGrant);
return Err(RouteError::UnsupportedGrantType);
}
};
@ -220,7 +211,6 @@ pub(crate) async fn post(
Ok((headers, Json(reply)))
}
#[allow(clippy::too_many_lines)]
async fn authorization_code_grant(
mut rng: &mut BoxRng,
clock: &impl Clock,
@ -311,52 +301,20 @@ async fn authorization_code_grant(
.ok_or(RouteError::NoSuchBrowserSession)?;
let ttl = Duration::minutes(5);
let access_token_str = TokenType::AccessToken.generate(&mut rng);
let refresh_token_str = TokenType::RefreshToken.generate(&mut rng);
let access_token = repo
.oauth2_access_token()
.add(&mut rng, clock, &session, access_token_str, ttl)
.await?;
let refresh_token = repo
.oauth2_refresh_token()
.add(&mut rng, clock, &session, &access_token, refresh_token_str)
.await?;
let (access_token, refresh_token) =
generate_token_pair(&mut rng, clock, &mut repo, &session, ttl).await?;
let id_token = if session.scope.contains(&scope::OPENID) {
let mut claims = HashMap::new();
let now = clock.now();
claims::ISS.insert(&mut claims, url_builder.oidc_issuer().to_string())?;
claims::SUB.insert(&mut claims, &browser_session.user.sub)?;
claims::AUD.insert(&mut claims, client.client_id.clone())?;
claims::IAT.insert(&mut claims, now)?;
claims::EXP.insert(&mut claims, now + Duration::hours(1))?;
if let Some(ref nonce) = authz_grant.nonce {
claims::NONCE.insert(&mut claims, nonce.clone())?;
}
if let Some(ref last_authentication) = browser_session.last_authentication {
claims::AUTH_TIME.insert(&mut claims, last_authentication.created_at)?;
}
let alg = client
.id_token_signed_response_alg
.clone()
.unwrap_or(JsonWebSignatureAlg::Rs256);
let key = key_store
.signing_key_for_algorithm(&alg)
.ok_or(RouteError::InvalidSigningKey)?;
claims::AT_HASH.insert(&mut claims, hash_token(&alg, &access_token.access_token)?)?;
claims::C_HASH.insert(&mut claims, hash_token(&alg, &grant.code)?)?;
let signer = key.params().signing_key_for_alg(&alg)?;
let header = JsonWebSignatureHeader::new(alg)
.with_kid(key.kid().ok_or(RouteError::InvalidSigningKey)?);
let id_token = Jwt::sign_with_rng(&mut rng, header, claims, &signer)?;
Some(id_token.as_str().to_owned())
Some(generate_id_token(
&mut rng,
clock,
url_builder,
key_store,
client,
&authz_grant,
&browser_session,
Some(&access_token),
)?)
} else {
None
};
@ -378,7 +336,7 @@ async fn authorization_code_grant(
}
async fn refresh_token_grant(
mut rng: &mut BoxRng,
rng: &mut BoxRng,
clock: &impl Clock,
grant: &RefreshTokenGrant,
client: &Client,
@ -406,24 +364,8 @@ async fn refresh_token_grant(
}
let ttl = Duration::minutes(5);
let access_token_str = TokenType::AccessToken.generate(&mut rng);
let refresh_token_str = TokenType::RefreshToken.generate(&mut rng);
let new_access_token = repo
.oauth2_access_token()
.add(&mut rng, clock, &session, access_token_str.clone(), ttl)
.await?;
let new_refresh_token = repo
.oauth2_refresh_token()
.add(
&mut rng,
clock,
&session,
&new_access_token,
refresh_token_str,
)
.await?;
let (new_access_token, new_refresh_token) =
generate_token_pair(rng, clock, &mut repo, &session, ttl).await?;
let refresh_token = repo
.oauth2_refresh_token()
@ -439,10 +381,394 @@ async fn refresh_token_grant(
}
}
let params = AccessTokenResponse::new(access_token_str)
let params = AccessTokenResponse::new(new_access_token.access_token)
.with_expires_in(ttl)
.with_refresh_token(new_refresh_token.refresh_token)
.with_scope(session.scope);
Ok((params, repo))
}
#[cfg(test)]
mod tests {
use hyper::Request;
use mas_data_model::{AccessToken, AuthorizationCode, RefreshToken};
use mas_router::SimpleRoute;
use oauth2_types::{
registration::ClientRegistrationResponse,
requests::ResponseMode,
scope::{Scope, OPENID},
};
use sqlx::PgPool;
use super::*;
use crate::test_utils::{init_tracing, RequestBuilderExt, ResponseExt, TestState};
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_auth_code_grant(pool: PgPool) {
init_tracing();
let state = TestState::from_pool(pool).await.unwrap();
// Provision a client
let request =
Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
"client_uri": "https://example.com/",
"redirect_uris": ["https://example.com/callback"],
"contacts": ["contact@example.com"],
"token_endpoint_auth_method": "none",
"response_types": ["code"],
"grant_types": ["authorization_code"],
}));
let response = state.request(request).await;
response.assert_status(StatusCode::CREATED);
let ClientRegistrationResponse { client_id, .. } = response.json();
// Let's provision a user and create a session for them. This part is hard to
// test with just HTTP requests, so we'll use the repository directly.
let mut repo = state.repository().await.unwrap();
let user = repo
.user()
.add(&mut state.rng(), &state.clock, "alice".to_owned())
.await
.unwrap();
let browser_session = repo
.browser_session()
.add(&mut state.rng(), &state.clock, &user)
.await
.unwrap();
// Lookup the client in the database.
let client = repo
.oauth2_client()
.find_by_client_id(&client_id)
.await
.unwrap()
.unwrap();
// Start a grant
let code = "thisisaverysecurecode";
let grant = repo
.oauth2_authorization_grant()
.add(
&mut state.rng(),
&state.clock,
&client,
"https://example.com/redirect".parse().unwrap(),
Scope::from_iter([OPENID]),
Some(AuthorizationCode {
code: code.to_owned(),
pkce: None,
}),
Some("state".to_owned()),
Some("nonce".to_owned()),
None,
ResponseMode::Query,
false,
false,
)
.await
.unwrap();
let session = repo
.oauth2_session()
.add(
&mut state.rng(),
&state.clock,
&client,
&browser_session,
grant.scope.clone(),
)
.await
.unwrap();
// And fulfill it
let grant = repo
.oauth2_authorization_grant()
.fulfill(&state.clock, &session, grant)
.await
.unwrap();
repo.save().await.unwrap();
// Now call the token endpoint to get an access token.
let request =
Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
"grant_type": "authorization_code",
"code": code,
"redirect_uri": grant.redirect_uri,
"client_id": client.client_id,
}));
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let AccessTokenResponse { access_token, .. } = response.json();
// Check that the token is valid
assert!(state.is_access_token_valid(&access_token).await);
// Exchange it again, this it should fail
let request =
Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
"grant_type": "authorization_code",
"code": code,
"redirect_uri": grant.redirect_uri,
"client_id": client.client_id,
}));
let response = state.request(request).await;
response.assert_status(StatusCode::BAD_REQUEST);
let error: ClientError = response.json();
assert_eq!(error.error, ClientErrorCode::InvalidGrant);
// The token should still be valid
assert!(state.is_access_token_valid(&access_token).await);
// Now wait a bit
state.clock.advance(Duration::minutes(1));
// Exchange it again, this it should fail
let request =
Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
"grant_type": "authorization_code",
"code": code,
"redirect_uri": grant.redirect_uri,
"client_id": client.client_id,
}));
let response = state.request(request).await;
response.assert_status(StatusCode::BAD_REQUEST);
let error: ClientError = response.json();
assert_eq!(error.error, ClientErrorCode::InvalidGrant);
// And it should have revoked the token we got
assert!(!state.is_access_token_valid(&access_token).await);
// Try another one and wait for too long before exchanging it
let mut repo = state.repository().await.unwrap();
let code = "thisisanothercode";
let grant = repo
.oauth2_authorization_grant()
.add(
&mut state.rng(),
&state.clock,
&client,
"https://example.com/redirect".parse().unwrap(),
Scope::from_iter([OPENID]),
Some(AuthorizationCode {
code: code.to_owned(),
pkce: None,
}),
Some("state".to_owned()),
Some("nonce".to_owned()),
None,
ResponseMode::Query,
false,
false,
)
.await
.unwrap();
let session = repo
.oauth2_session()
.add(
&mut state.rng(),
&state.clock,
&client,
&browser_session,
grant.scope.clone(),
)
.await
.unwrap();
// And fulfill it
let grant = repo
.oauth2_authorization_grant()
.fulfill(&state.clock, &session, grant)
.await
.unwrap();
repo.save().await.unwrap();
// Now wait a bit
state.clock.advance(Duration::minutes(15));
// Exchange it, it should fail
let request =
Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
"grant_type": "authorization_code",
"code": code,
"redirect_uri": grant.redirect_uri,
"client_id": client.client_id,
}));
let response = state.request(request).await;
response.assert_status(StatusCode::BAD_REQUEST);
let ClientError { error, .. } = response.json();
assert_eq!(error, ClientErrorCode::InvalidGrant);
}
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_refresh_token_grant(pool: PgPool) {
init_tracing();
let state = TestState::from_pool(pool).await.unwrap();
// Provision a client
let request =
Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
"client_uri": "https://example.com/",
"redirect_uris": ["https://example.com/callback"],
"contacts": ["contact@example.com"],
"token_endpoint_auth_method": "none",
"response_types": ["code"],
"grant_types": ["authorization_code"],
}));
let response = state.request(request).await;
response.assert_status(StatusCode::CREATED);
let ClientRegistrationResponse { client_id, .. } = response.json();
// Let's provision a user and create a session for them. This part is hard to
// test with just HTTP requests, so we'll use the repository directly.
let mut repo = state.repository().await.unwrap();
let user = repo
.user()
.add(&mut state.rng(), &state.clock, "alice".to_owned())
.await
.unwrap();
let browser_session = repo
.browser_session()
.add(&mut state.rng(), &state.clock, &user)
.await
.unwrap();
// Lookup the client in the database.
let client = repo
.oauth2_client()
.find_by_client_id(&client_id)
.await
.unwrap()
.unwrap();
// Get a token pair
let session = repo
.oauth2_session()
.add(
&mut state.rng(),
&state.clock,
&client,
&browser_session,
Scope::from_iter([OPENID]),
)
.await
.unwrap();
let (AccessToken { access_token, .. }, RefreshToken { refresh_token, .. }) =
generate_token_pair(
&mut state.rng(),
&state.clock,
&mut repo,
&session,
Duration::minutes(5),
)
.await
.unwrap();
repo.save().await.unwrap();
// First check that the token is valid
assert!(state.is_access_token_valid(&access_token).await);
// Now call the token endpoint to get an access token.
let request =
Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
"grant_type": "refresh_token",
"refresh_token": refresh_token,
"client_id": client.client_id,
}));
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let old_access_token = access_token;
let old_refresh_token = refresh_token;
let response: AccessTokenResponse = response.json();
let access_token = response.access_token;
let refresh_token = response.refresh_token.expect("to have a refresh token");
// Check that the new token is valid
assert!(state.is_access_token_valid(&access_token).await);
// Check that the old token is no longer valid
assert!(!state.is_access_token_valid(&old_access_token).await);
// Call it again with the old token, it should fail
let request =
Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
"grant_type": "refresh_token",
"refresh_token": old_refresh_token,
"client_id": client.client_id,
}));
let response = state.request(request).await;
response.assert_status(StatusCode::BAD_REQUEST);
let ClientError { error, .. } = response.json();
assert_eq!(error, ClientErrorCode::InvalidGrant);
// Call it again with the new token, it should work
let request =
Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
"grant_type": "refresh_token",
"refresh_token": refresh_token,
"client_id": client.client_id,
}));
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let _: AccessTokenResponse = response.json();
}
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_unsupported_grant(pool: PgPool) {
init_tracing();
let state = TestState::from_pool(pool).await.unwrap();
// Provision a client
let request =
Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
"client_uri": "https://example.com/",
"redirect_uris": ["https://example.com/callback"],
"contacts": ["contact@example.com"],
"token_endpoint_auth_method": "client_secret_post",
"grant_types": ["client_credentials"],
"response_types": [],
}));
let response = state.request(request).await;
response.assert_status(StatusCode::CREATED);
let response: ClientRegistrationResponse = response.json();
let client_id = response.client_id;
let client_secret = response.client_secret.expect("to have a client secret");
// Call the token endpoint with an unsupported grant type
let request =
Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
"grant_type": "client_credentials",
"client_id": client_id,
"client_secret": client_secret,
}));
let response = state.request(request).await;
response.assert_status(StatusCode::BAD_REQUEST);
let ClientError { error, .. } = response.json();
assert_eq!(error, ClientErrorCode::UnsupportedGrantType);
}
}

View File

@ -19,7 +19,7 @@ use axum::{
body::HttpBody,
extract::{FromRef, FromRequestParts},
};
use headers::{Authorization, ContentType, HeaderMapExt, HeaderName, HeaderValue};
use headers::{Authorization, ContentType, HeaderMapExt, HeaderName};
use hyper::{header::CONTENT_TYPE, Request, Response, StatusCode};
use mas_axum_utils::http_client_factory::HttpClientFactory;
use mas_email::{MailTransport, Mailer};

View File

@ -325,20 +325,32 @@ impl fmt::Debug for AuthorizationRequest {
///
/// [Authorization Endpoint]: https://www.rfc-editor.org/rfc/rfc6749.html#section-3.1
#[skip_serializing_none]
#[serde_as]
#[derive(Serialize, Deserialize, Default, Clone)]
pub struct AuthorizationResponse<R> {
pub struct AuthorizationResponse {
/// The authorization code generated by the authorization server.
pub code: Option<String>,
/// Other fields of the response.
#[serde(flatten)]
pub response: R,
/// The access token to access the requested scope.
pub access_token: Option<String>,
/// The type of the access token.
pub token_type: Option<OAuthAccessTokenType>,
/// ID Token value associated with the authenticated session.
pub id_token: Option<String>,
/// The duration for which the access token is valid.
#[serde_as(as = "Option<DurationSeconds<i64>>")]
pub expires_in: Option<Duration>,
}
impl<R: fmt::Debug> fmt::Debug for AuthorizationResponse<R> {
impl fmt::Debug for AuthorizationResponse {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("AuthorizationResponse")
.field("response", &self.response)
.field("token_type", &self.token_type)
.field("id_token", &self.id_token)
.field("expires_in", &self.expires_in)
.finish_non_exhaustive()
}
}

View File

@ -203,10 +203,16 @@ mod tests {
let session = repo.oauth2_session().lookup(Ulid::nil()).await.unwrap();
assert_eq!(session, None);
// Create a session out of the grant
// Create an OAuth session
let session = repo
.oauth2_session()
.create_from_grant(&mut rng, &clock, &grant, &user_session)
.add(
&mut rng,
&clock,
&client,
&user_session,
grant.scope.clone(),
)
.await
.unwrap();

View File

@ -14,8 +14,9 @@
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{AuthorizationGrant, BrowserSession, Session, SessionState, User};
use mas_data_model::{BrowserSession, Client, Session, SessionState, User};
use mas_storage::{oauth2::OAuth2SessionRepository, Clock, Page, Pagination};
use oauth2_types::scope::Scope;
use rand::RngCore;
use sqlx::{PgConnection, QueryBuilder};
use ulid::Ulid;
@ -118,25 +119,25 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
}
#[tracing::instrument(
name = "db.oauth2_session.create_from_grant",
name = "db.oauth2_session.add",
skip_all,
fields(
db.statement,
%user_session.id,
user.id = %user_session.user.id,
%grant.id,
client.id = %grant.client_id,
%client.id,
session.id,
session.scope = %grant.scope,
session.scope = %scope,
),
err,
)]
async fn create_from_grant(
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &dyn Clock,
grant: &AuthorizationGrant,
client: &Client,
user_session: &BrowserSession,
scope: Scope,
) -> Result<Session, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
@ -155,8 +156,8 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
"#,
Uuid::from(id),
Uuid::from(user_session.id),
Uuid::from(grant.client_id),
grant.scope.to_string(),
Uuid::from(client.id),
scope.to_string(),
created_at,
)
.traced()
@ -168,8 +169,8 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
state: SessionState::Valid,
created_at,
user_session_id: user_session.id,
client_id: grant.client_id,
scope: grant.scope.clone(),
client_id: client.id,
scope,
})
}

View File

@ -13,7 +13,8 @@
// limitations under the License.
use async_trait::async_trait;
use mas_data_model::{AuthorizationGrant, BrowserSession, Session, User};
use mas_data_model::{BrowserSession, Client, Session, User};
use oauth2_types::scope::Scope;
use rand_core::RngCore;
use ulid::Ulid;
@ -39,7 +40,7 @@ pub trait OAuth2SessionRepository: Send + Sync {
/// Returns [`Self::Error`] if the underlying repository fails
async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error>;
/// Create a new [`Session`] from an [`AuthorizationGrant`]
/// Create a new [`Session`]
///
/// Returns the newly created [`Session`]
///
@ -47,19 +48,21 @@ pub trait OAuth2SessionRepository: Send + Sync {
///
/// * `rng`: The random number generator to use
/// * `clock`: The clock used to generate timestamps
/// * `grant`: The [`AuthorizationGrant`] to create the [`Session`] from
/// * `client`: The [`Client`] which created the [`Session`]
/// * `user_session`: The [`BrowserSession`] of the user which completed the
/// authorization
/// * `scope`: The [`Scope`] of the [`Session`]
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn create_from_grant(
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &dyn Clock,
grant: &AuthorizationGrant,
client: &Client,
user_session: &BrowserSession,
scope: Scope,
) -> Result<Session, Self::Error>;
/// Mark a [`Session`] as finished
@ -97,12 +100,13 @@ pub trait OAuth2SessionRepository: Send + Sync {
repository_impl!(OAuth2SessionRepository:
async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error>;
async fn create_from_grant(
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &dyn Clock,
grant: &AuthorizationGrant,
client: &Client,
user_session: &BrowserSession,
scope: Scope,
) -> Result<Session, Self::Error>;
async fn finish(&mut self, clock: &dyn Clock, session: Session)