diff --git a/crates/handlers/src/oauth2/authorization/complete.rs b/crates/handlers/src/oauth2/authorization/complete.rs index b978a701..41912dbe 100644 --- a/crates/handlers/src/oauth2/authorization/complete.rs +++ b/crates/handlers/src/oauth2/authorization/complete.rs @@ -22,20 +22,20 @@ use axum_extra::extract::PrivateCookieJar; use hyper::StatusCode; use mas_axum_utils::SessionInfoExt; use mas_data_model::{AuthorizationGrant, BrowserSession, Client}; -use mas_keystore::Encrypter; +use mas_keystore::{Encrypter, Keystore}; use mas_policy::PolicyFactory; -use mas_router::{PostAuthAction, Route}; +use mas_router::{PostAuthAction, Route, UrlBuilder}; use mas_storage::{ oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2SessionRepository}, BoxClock, BoxRepository, BoxRng, }; use mas_templates::Templates; -use oauth2_types::requests::{AccessTokenResponse, AuthorizationResponse}; +use oauth2_types::requests::AuthorizationResponse; use thiserror::Error; use ulid::Ulid; use super::callback::CallbackDestination; -use crate::impl_from_error_for_route; +use crate::{impl_from_error_for_route, oauth2::generate_id_token}; #[derive(Debug, Error)] pub enum RouteError { @@ -90,6 +90,8 @@ pub(crate) async fn get( clock: BoxClock, State(policy_factory): State>, State(templates): State, + State(url_builder): State, + State(key_store): State, mut repo: BoxRepository, cookie_jar: PrivateCookieJar, Path(grant_id): Path, @@ -119,7 +121,19 @@ pub(crate) async fn get( .await? .ok_or(RouteError::NoSuchClient)?; - match complete(rng, clock, grant, client, session, &policy_factory, repo).await { + match complete( + rng, + clock, + repo, + key_store, + &policy_factory, + url_builder, + grant, + client, + session, + ) + .await + { Ok(params) => { let res = callback_destination.go(&templates, params).await?; Ok((cookie_jar, res).into_response()) @@ -161,16 +175,19 @@ impl_from_error_for_route!(GrantCompletionError: super::callback::IntoCallbackDe impl_from_error_for_route!(GrantCompletionError: mas_policy::LoadError); impl_from_error_for_route!(GrantCompletionError: mas_policy::InstanciateError); impl_from_error_for_route!(GrantCompletionError: mas_policy::EvaluationError); +impl_from_error_for_route!(GrantCompletionError: super::super::IdTokenSignatureError); pub(crate) async fn complete( mut rng: BoxRng, clock: BoxClock, + mut repo: BoxRepository, + key_store: Keystore, + policy_factory: &PolicyFactory, + url_builder: UrlBuilder, grant: AuthorizationGrant, client: Client, browser_session: BrowserSession, - policy_factory: &PolicyFactory, - mut repo: BoxRepository, -) -> Result>, GrantCompletionError> { +) -> Result { // Verify that the grant is in a pending stage if !grant.stage.is_pending() { return Err(GrantCompletionError::NotPending); @@ -211,7 +228,13 @@ pub(crate) async fn complete( // All good, let's start the session let session = repo .oauth2_session() - .create_from_grant(&mut rng, &clock, &grant, &browser_session) + .add( + &mut rng, + &clock, + &client, + &browser_session, + grant.scope.clone(), + ) .await?; let grant = repo @@ -222,19 +245,25 @@ pub(crate) async fn complete( // Yep! Let's complete the auth now let mut params = AuthorizationResponse::default(); + // Did they request an ID token? + if grant.response_type_id_token { + params.id_token = Some(generate_id_token( + &mut rng, + &clock, + &url_builder, + &key_store, + &client, + &grant, + &browser_session, + None, + )?); + } + // Did they request an auth code? if let Some(code) = grant.code { params.code = Some(code.code); } - // Did they request an ID token? - if grant.response_type_id_token { - // TODO - return Err(GrantCompletionError::Internal( - "ID tokens are not implemented yet".into(), - )); - } - repo.save().await?; Ok(params) } diff --git a/crates/handlers/src/oauth2/authorization/mod.rs b/crates/handlers/src/oauth2/authorization/mod.rs index 657fcaf4..1230362d 100644 --- a/crates/handlers/src/oauth2/authorization/mod.rs +++ b/crates/handlers/src/oauth2/authorization/mod.rs @@ -22,9 +22,9 @@ use axum_extra::extract::PrivateCookieJar; use hyper::StatusCode; use mas_axum_utils::SessionInfoExt; use mas_data_model::{AuthorizationCode, Pkce}; -use mas_keystore::Encrypter; +use mas_keystore::{Encrypter, Keystore}; use mas_policy::PolicyFactory; -use mas_router::{PostAuthAction, Route}; +use mas_router::{PostAuthAction, Route, UrlBuilder}; use mas_storage::{ oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository}, BoxClock, BoxRepository, BoxRng, @@ -141,6 +141,8 @@ pub(crate) async fn get( clock: BoxClock, State(policy_factory): State>, State(templates): State, + State(key_store): State, + State(url_builder): State, mut repo: BoxRepository, cookie_jar: PrivateCookieJar, Form(params): Form, @@ -340,11 +342,13 @@ pub(crate) async fn get( match self::complete::complete( rng, clock, + repo, + key_store, + &policy_factory, + url_builder, grant, client, user_session, - &policy_factory, - repo, ) .await { @@ -385,11 +389,13 @@ pub(crate) async fn get( match self::complete::complete( rng, clock, + repo, + key_store, + &policy_factory, + url_builder, grant, client, user_session, - &policy_factory, - repo, ) .await { diff --git a/crates/handlers/src/oauth2/mod.rs b/crates/handlers/src/oauth2/mod.rs index 98b79383..19edd0da 100644 --- a/crates/handlers/src/oauth2/mod.rs +++ b/crates/handlers/src/oauth2/mod.rs @@ -12,6 +12,23 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashMap; + +use chrono::Duration; +use mas_data_model::{ + AccessToken, AuthorizationGrant, BrowserSession, Client, RefreshToken, Session, TokenType, +}; +use mas_iana::jose::JsonWebSignatureAlg; +use mas_jose::{ + claims::{self, hash_token}, + constraints::Constrainable, + jwt::{JsonWebSignatureHeader, Jwt}, +}; +use mas_keystore::Keystore; +use mas_router::UrlBuilder; +use mas_storage::{Clock, RepositoryAccess}; +use thiserror::Error; + pub mod authorization; pub mod consent; pub mod discovery; @@ -22,3 +39,87 @@ pub mod revoke; pub mod token; pub mod userinfo; pub mod webfinger; + +#[derive(Debug, Error)] +#[error(transparent)] +pub(crate) enum IdTokenSignatureError { + #[error("The signing key is invalid")] + InvalidSigningKey, + Claim(#[from] mas_jose::claims::ClaimError), + JwtSignature(#[from] mas_jose::jwt::JwtSignatureError), + WrongAlgorithm(#[from] mas_keystore::WrongAlgorithmError), + TokenHash(#[from] mas_jose::claims::TokenHashError), +} + +pub(crate) fn generate_id_token( + rng: &mut (impl rand::RngCore + rand::CryptoRng), + clock: &impl Clock, + url_builder: &UrlBuilder, + key_store: &Keystore, + client: &Client, + grant: &AuthorizationGrant, + browser_session: &BrowserSession, + access_token: Option<&AccessToken>, +) -> Result { + let mut claims = HashMap::new(); + let now = clock.now(); + claims::ISS.insert(&mut claims, url_builder.oidc_issuer().to_string())?; + claims::SUB.insert(&mut claims, &browser_session.user.sub)?; + claims::AUD.insert(&mut claims, client.client_id.clone())?; + claims::IAT.insert(&mut claims, now)?; + claims::EXP.insert(&mut claims, now + Duration::hours(1))?; + + if let Some(ref nonce) = grant.nonce { + claims::NONCE.insert(&mut claims, nonce.clone())?; + } + + if let Some(ref last_authentication) = browser_session.last_authentication { + claims::AUTH_TIME.insert(&mut claims, last_authentication.created_at)?; + } + + let alg = client + .id_token_signed_response_alg + .clone() + .unwrap_or(JsonWebSignatureAlg::Rs256); + let key = key_store + .signing_key_for_algorithm(&alg) + .ok_or(IdTokenSignatureError::InvalidSigningKey)?; + + if let Some(access_token) = access_token { + claims::AT_HASH.insert(&mut claims, hash_token(&alg, &access_token.access_token)?)?; + } + + if let Some(ref code) = grant.code { + claims::C_HASH.insert(&mut claims, hash_token(&alg, &code.code)?)?; + } + + let signer = key.params().signing_key_for_alg(&alg)?; + let header = JsonWebSignatureHeader::new(alg) + .with_kid(key.kid().ok_or(IdTokenSignatureError::InvalidSigningKey)?); + let id_token = Jwt::sign_with_rng(rng, header, claims, &signer)?; + + Ok(id_token.into_string()) +} + +pub(crate) async fn generate_token_pair( + rng: &mut (impl rand::RngCore + Send), + clock: &impl Clock, + repo: &mut R, + session: &Session, + ttl: Duration, +) -> Result<(AccessToken, RefreshToken), R::Error> { + let access_token_str = TokenType::AccessToken.generate(rng); + let refresh_token_str = TokenType::RefreshToken.generate(rng); + + let access_token = repo + .oauth2_access_token() + .add(rng, clock, session, access_token_str.clone(), ttl) + .await?; + + let refresh_token = repo + .oauth2_refresh_token() + .add(rng, clock, session, &access_token, refresh_token_str) + .await?; + + Ok((access_token, refresh_token)) +} diff --git a/crates/handlers/src/oauth2/revoke.rs b/crates/handlers/src/oauth2/revoke.rs index c80febc2..d9a6748e 100644 --- a/crates/handlers/src/oauth2/revoke.rs +++ b/crates/handlers/src/oauth2/revoke.rs @@ -201,19 +201,23 @@ pub(crate) async fn post( #[cfg(test)] mod tests { + use chrono::Duration; use hyper::Request; - use mas_data_model::AuthorizationCode; + use mas_data_model::{AccessToken, RefreshToken}; use mas_router::SimpleRoute; use mas_storage::RepositoryAccess; use oauth2_types::{ registration::ClientRegistrationResponse, - requests::{AccessTokenResponse, ResponseMode}, + requests::AccessTokenResponse, scope::{Scope, OPENID}, }; use sqlx::PgPool; use super::*; - use crate::test_utils::{init_tracing, RequestBuilderExt, ResponseExt, TestState}; + use crate::{ + oauth2::generate_token_pair, + test_utils::{init_tracing, RequestBuilderExt, ResponseExt, TestState}, + }; #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] async fn test_revoke_access_token(pool: PgPool) { @@ -262,64 +266,37 @@ mod tests { .unwrap() .unwrap(); - // Start a grant - let grant = repo - .oauth2_authorization_grant() + let session = repo + .oauth2_session() .add( &mut state.rng(), &state.clock, &client, - "https://example.com/redirect".parse().unwrap(), + &browser_session, Scope::from_iter([OPENID]), - Some(AuthorizationCode { - code: "thisisaverysecurecode".to_owned(), - pkce: None, - }), - Some("state".to_owned()), - Some("nonce".to_owned()), - None, - ResponseMode::Query, - false, - false, ) .await .unwrap(); - let session = repo - .oauth2_session() - .create_from_grant(&mut state.rng(), &state.clock, &grant, &browser_session) - .await - .unwrap(); - - let grant = repo - .oauth2_authorization_grant() - .fulfill(&state.clock, &session, grant) + let (AccessToken { access_token, .. }, RefreshToken { refresh_token, .. }) = + generate_token_pair( + &mut state.rng(), + &state.clock, + &mut repo, + &session, + Duration::minutes(5), + ) .await .unwrap(); repo.save().await.unwrap(); - // Now call the token endpoint to get an access token. - let request = - Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({ - "grant_type": "authorization_code", - "code": grant.code.unwrap().code, - "redirect_uri": grant.redirect_uri, - "client_id": client_id, - "client_secret": client_secret, - })); - - let response = state.request(request).await; - response.assert_status(StatusCode::OK); - - let token: AccessTokenResponse = response.json(); - // Check that the token is valid - assert!(state.is_access_token_valid(&token.access_token).await); + assert!(state.is_access_token_valid(&access_token).await); // Now let's revoke the access token. let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({ - "token": token.access_token, + "token": access_token, "token_type_hint": "access_token", "client_id": client_id, "client_secret": client_secret, @@ -329,13 +306,13 @@ mod tests { response.assert_status(StatusCode::OK); // Check that the token is no longer valid - assert!(!state.is_access_token_valid(&token.access_token).await); + assert!(!state.is_access_token_valid(&access_token).await); // Try using the refresh token to get a new access token, it should fail. let request = Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({ "grant_type": "refresh_token", - "refresh_token": token.refresh_token, + "refresh_token": refresh_token, "client_id": client_id, "client_secret": client_secret, })); @@ -345,62 +322,36 @@ mod tests { // Now try with a new grant, and by revoking the refresh token instead let mut repo = state.repository().await.unwrap(); - let grant = repo - .oauth2_authorization_grant() + let session = repo + .oauth2_session() .add( &mut state.rng(), &state.clock, &client, - "https://example.com/redirect".parse().unwrap(), + &browser_session, Scope::from_iter([OPENID]), - Some(AuthorizationCode { - code: "anotherverysecretcode".to_owned(), - pkce: None, - }), - Some("state".to_owned()), - Some("nonce".to_owned()), - None, - ResponseMode::Query, - false, - false, ) .await .unwrap(); - let session = repo - .oauth2_session() - .create_from_grant(&mut state.rng(), &state.clock, &grant, &browser_session) - .await - .unwrap(); - - let grant = repo - .oauth2_authorization_grant() - .fulfill(&state.clock, &session, grant) + let (AccessToken { access_token, .. }, RefreshToken { refresh_token, .. }) = + generate_token_pair( + &mut state.rng(), + &state.clock, + &mut repo, + &session, + Duration::minutes(5), + ) .await .unwrap(); repo.save().await.unwrap(); - // Now call the token endpoint to get an access token. - let request = - Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({ - "grant_type": "authorization_code", - "code": grant.code.unwrap().code, - "redirect_uri": grant.redirect_uri, - "client_id": client_id, - "client_secret": client_secret, - })); - - let response = state.request(request).await; - response.assert_status(StatusCode::OK); - - let token: AccessTokenResponse = response.json(); - // Use the refresh token to get a new access token. let request = Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({ "grant_type": "refresh_token", - "refresh_token": token.refresh_token, + "refresh_token": refresh_token, "client_id": client_id, "client_secret": client_secret, })); @@ -408,14 +359,19 @@ mod tests { let response = state.request(request).await; response.assert_status(StatusCode::OK); - let old_token = token; - let token: AccessTokenResponse = response.json(); - assert!(state.is_access_token_valid(&token.access_token).await); - assert!(!state.is_access_token_valid(&old_token.access_token).await); + let old_access_token = access_token; + let old_refresh_token = refresh_token; + let AccessTokenResponse { + access_token, + refresh_token, + .. + } = response.json(); + assert!(state.is_access_token_valid(&access_token).await); + assert!(!state.is_access_token_valid(&old_access_token).await); // Revoking the old access token shouldn't do anything. let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({ - "token": old_token.access_token, + "token": old_access_token, "token_type_hint": "access_token", "client_id": client_id, "client_secret": client_secret, @@ -424,11 +380,11 @@ mod tests { let response = state.request(request).await; response.assert_status(StatusCode::OK); - assert!(state.is_access_token_valid(&token.access_token).await); + assert!(state.is_access_token_valid(&access_token).await); // Revoking the old refresh token shouldn't do anything. let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({ - "token": old_token.refresh_token, + "token": old_refresh_token, "token_type_hint": "refresh_token", "client_id": client_id, "client_secret": client_secret, @@ -437,11 +393,11 @@ mod tests { let response = state.request(request).await; response.assert_status(StatusCode::OK); - assert!(state.is_access_token_valid(&token.access_token).await); + assert!(state.is_access_token_valid(&access_token).await); // Revoking the new refresh token should invalidate the session let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({ - "token": token.refresh_token, + "token": refresh_token, "token_type_hint": "refresh_token", "client_id": client_id, "client_secret": client_secret, @@ -450,6 +406,6 @@ mod tests { let response = state.request(request).await; response.assert_status(StatusCode::OK); - assert!(!state.is_access_token_valid(&token.access_token).await); + assert!(!state.is_access_token_valid(&access_token).await); } } diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index 29ebd1ac..235f7500 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::HashMap; - use axum::{extract::State, response::IntoResponse, Json}; use chrono::{DateTime, Duration, Utc}; use headers::{CacheControl, HeaderMap, HeaderMapExt, Pragma}; @@ -22,13 +20,7 @@ use mas_axum_utils::{ client_authorization::{ClientAuthorization, CredentialsVerificationError}, http_client_factory::HttpClientFactory, }; -use mas_data_model::{AuthorizationGrantStage, Client, TokenType}; -use mas_iana::jose::JsonWebSignatureAlg; -use mas_jose::{ - claims::{self, hash_token}, - constraints::Constrainable, - jwt::{JsonWebSignatureHeader, Jwt}, -}; +use mas_data_model::{AuthorizationGrantStage, Client}; use mas_keystore::{Encrypter, Keystore}; use mas_router::UrlBuilder; use mas_storage::{ @@ -53,6 +45,7 @@ use thiserror::Error; use tracing::debug; use url::Url; +use super::{generate_id_token, generate_token_pair}; use crate::impl_from_error_for_route; #[serde_as] @@ -98,12 +91,12 @@ pub(crate) enum RouteError { #[error("invalid grant")] InvalidGrant, + #[error("unsupported grant type")] + UnsupportedGrantType, + #[error("unauthorized client")] UnauthorizedClient, - #[error("no suitable key found for signing")] - InvalidSigningKey, - #[error("failed to load browser session")] NoSuchBrowserSession, @@ -115,10 +108,7 @@ impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { sentry::capture_error(&self); match self { - Self::Internal(_) - | Self::InvalidSigningKey - | Self::NoSuchBrowserSession - | Self::NoSuchOAuthSession => ( + Self::Internal(_) | Self::NoSuchBrowserSession | Self::NoSuchOAuthSession => ( StatusCode::INTERNAL_SERVER_ERROR, Json(ClientError::from(ClientErrorCode::ServerError)), ), @@ -145,16 +135,17 @@ impl IntoResponse for RouteError { StatusCode::BAD_REQUEST, Json(ClientError::from(ClientErrorCode::InvalidGrant)), ), + Self::UnsupportedGrantType => ( + StatusCode::BAD_REQUEST, + Json(ClientError::from(ClientErrorCode::UnsupportedGrantType)), + ), } .into_response() } } impl_from_error_for_route!(mas_storage::RepositoryError); -impl_from_error_for_route!(mas_keystore::WrongAlgorithmError); -impl_from_error_for_route!(mas_jose::claims::ClaimError); -impl_from_error_for_route!(mas_jose::claims::TokenHashError); -impl_from_error_for_route!(mas_jose::jwt::JwtSignatureError); +impl_from_error_for_route!(super::IdTokenSignatureError); #[tracing::instrument( name = "handlers.oauth2.token.post", @@ -207,7 +198,7 @@ pub(crate) async fn post( refresh_token_grant(&mut rng, &clock, &grant, &client, repo).await? } _ => { - return Err(RouteError::InvalidGrant); + return Err(RouteError::UnsupportedGrantType); } }; @@ -220,7 +211,6 @@ pub(crate) async fn post( Ok((headers, Json(reply))) } -#[allow(clippy::too_many_lines)] async fn authorization_code_grant( mut rng: &mut BoxRng, clock: &impl Clock, @@ -311,52 +301,20 @@ async fn authorization_code_grant( .ok_or(RouteError::NoSuchBrowserSession)?; let ttl = Duration::minutes(5); - let access_token_str = TokenType::AccessToken.generate(&mut rng); - let refresh_token_str = TokenType::RefreshToken.generate(&mut rng); - - let access_token = repo - .oauth2_access_token() - .add(&mut rng, clock, &session, access_token_str, ttl) - .await?; - - let refresh_token = repo - .oauth2_refresh_token() - .add(&mut rng, clock, &session, &access_token, refresh_token_str) - .await?; + let (access_token, refresh_token) = + generate_token_pair(&mut rng, clock, &mut repo, &session, ttl).await?; let id_token = if session.scope.contains(&scope::OPENID) { - let mut claims = HashMap::new(); - let now = clock.now(); - claims::ISS.insert(&mut claims, url_builder.oidc_issuer().to_string())?; - claims::SUB.insert(&mut claims, &browser_session.user.sub)?; - claims::AUD.insert(&mut claims, client.client_id.clone())?; - claims::IAT.insert(&mut claims, now)?; - claims::EXP.insert(&mut claims, now + Duration::hours(1))?; - - if let Some(ref nonce) = authz_grant.nonce { - claims::NONCE.insert(&mut claims, nonce.clone())?; - } - if let Some(ref last_authentication) = browser_session.last_authentication { - claims::AUTH_TIME.insert(&mut claims, last_authentication.created_at)?; - } - - let alg = client - .id_token_signed_response_alg - .clone() - .unwrap_or(JsonWebSignatureAlg::Rs256); - let key = key_store - .signing_key_for_algorithm(&alg) - .ok_or(RouteError::InvalidSigningKey)?; - - claims::AT_HASH.insert(&mut claims, hash_token(&alg, &access_token.access_token)?)?; - claims::C_HASH.insert(&mut claims, hash_token(&alg, &grant.code)?)?; - - let signer = key.params().signing_key_for_alg(&alg)?; - let header = JsonWebSignatureHeader::new(alg) - .with_kid(key.kid().ok_or(RouteError::InvalidSigningKey)?); - let id_token = Jwt::sign_with_rng(&mut rng, header, claims, &signer)?; - - Some(id_token.as_str().to_owned()) + Some(generate_id_token( + &mut rng, + clock, + url_builder, + key_store, + client, + &authz_grant, + &browser_session, + Some(&access_token), + )?) } else { None }; @@ -378,7 +336,7 @@ async fn authorization_code_grant( } async fn refresh_token_grant( - mut rng: &mut BoxRng, + rng: &mut BoxRng, clock: &impl Clock, grant: &RefreshTokenGrant, client: &Client, @@ -406,24 +364,8 @@ async fn refresh_token_grant( } let ttl = Duration::minutes(5); - let access_token_str = TokenType::AccessToken.generate(&mut rng); - let refresh_token_str = TokenType::RefreshToken.generate(&mut rng); - - let new_access_token = repo - .oauth2_access_token() - .add(&mut rng, clock, &session, access_token_str.clone(), ttl) - .await?; - - let new_refresh_token = repo - .oauth2_refresh_token() - .add( - &mut rng, - clock, - &session, - &new_access_token, - refresh_token_str, - ) - .await?; + let (new_access_token, new_refresh_token) = + generate_token_pair(rng, clock, &mut repo, &session, ttl).await?; let refresh_token = repo .oauth2_refresh_token() @@ -439,10 +381,394 @@ async fn refresh_token_grant( } } - let params = AccessTokenResponse::new(access_token_str) + let params = AccessTokenResponse::new(new_access_token.access_token) .with_expires_in(ttl) .with_refresh_token(new_refresh_token.refresh_token) .with_scope(session.scope); Ok((params, repo)) } + +#[cfg(test)] +mod tests { + use hyper::Request; + use mas_data_model::{AccessToken, AuthorizationCode, RefreshToken}; + use mas_router::SimpleRoute; + use oauth2_types::{ + registration::ClientRegistrationResponse, + requests::ResponseMode, + scope::{Scope, OPENID}, + }; + use sqlx::PgPool; + + use super::*; + use crate::test_utils::{init_tracing, RequestBuilderExt, ResponseExt, TestState}; + + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_auth_code_grant(pool: PgPool) { + init_tracing(); + let state = TestState::from_pool(pool).await.unwrap(); + + // Provision a client + let request = + Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({ + "client_uri": "https://example.com/", + "redirect_uris": ["https://example.com/callback"], + "contacts": ["contact@example.com"], + "token_endpoint_auth_method": "none", + "response_types": ["code"], + "grant_types": ["authorization_code"], + })); + + let response = state.request(request).await; + response.assert_status(StatusCode::CREATED); + + let ClientRegistrationResponse { client_id, .. } = response.json(); + + // Let's provision a user and create a session for them. This part is hard to + // test with just HTTP requests, so we'll use the repository directly. + let mut repo = state.repository().await.unwrap(); + + let user = repo + .user() + .add(&mut state.rng(), &state.clock, "alice".to_owned()) + .await + .unwrap(); + + let browser_session = repo + .browser_session() + .add(&mut state.rng(), &state.clock, &user) + .await + .unwrap(); + + // Lookup the client in the database. + let client = repo + .oauth2_client() + .find_by_client_id(&client_id) + .await + .unwrap() + .unwrap(); + + // Start a grant + let code = "thisisaverysecurecode"; + let grant = repo + .oauth2_authorization_grant() + .add( + &mut state.rng(), + &state.clock, + &client, + "https://example.com/redirect".parse().unwrap(), + Scope::from_iter([OPENID]), + Some(AuthorizationCode { + code: code.to_owned(), + pkce: None, + }), + Some("state".to_owned()), + Some("nonce".to_owned()), + None, + ResponseMode::Query, + false, + false, + ) + .await + .unwrap(); + + let session = repo + .oauth2_session() + .add( + &mut state.rng(), + &state.clock, + &client, + &browser_session, + grant.scope.clone(), + ) + .await + .unwrap(); + + // And fulfill it + let grant = repo + .oauth2_authorization_grant() + .fulfill(&state.clock, &session, grant) + .await + .unwrap(); + + repo.save().await.unwrap(); + + // Now call the token endpoint to get an access token. + let request = + Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({ + "grant_type": "authorization_code", + "code": code, + "redirect_uri": grant.redirect_uri, + "client_id": client.client_id, + })); + + let response = state.request(request).await; + response.assert_status(StatusCode::OK); + + let AccessTokenResponse { access_token, .. } = response.json(); + + // Check that the token is valid + assert!(state.is_access_token_valid(&access_token).await); + + // Exchange it again, this it should fail + let request = + Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({ + "grant_type": "authorization_code", + "code": code, + "redirect_uri": grant.redirect_uri, + "client_id": client.client_id, + })); + + let response = state.request(request).await; + response.assert_status(StatusCode::BAD_REQUEST); + let error: ClientError = response.json(); + assert_eq!(error.error, ClientErrorCode::InvalidGrant); + + // The token should still be valid + assert!(state.is_access_token_valid(&access_token).await); + + // Now wait a bit + state.clock.advance(Duration::minutes(1)); + + // Exchange it again, this it should fail + let request = + Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({ + "grant_type": "authorization_code", + "code": code, + "redirect_uri": grant.redirect_uri, + "client_id": client.client_id, + })); + + let response = state.request(request).await; + response.assert_status(StatusCode::BAD_REQUEST); + let error: ClientError = response.json(); + assert_eq!(error.error, ClientErrorCode::InvalidGrant); + + // And it should have revoked the token we got + assert!(!state.is_access_token_valid(&access_token).await); + + // Try another one and wait for too long before exchanging it + let mut repo = state.repository().await.unwrap(); + let code = "thisisanothercode"; + let grant = repo + .oauth2_authorization_grant() + .add( + &mut state.rng(), + &state.clock, + &client, + "https://example.com/redirect".parse().unwrap(), + Scope::from_iter([OPENID]), + Some(AuthorizationCode { + code: code.to_owned(), + pkce: None, + }), + Some("state".to_owned()), + Some("nonce".to_owned()), + None, + ResponseMode::Query, + false, + false, + ) + .await + .unwrap(); + + let session = repo + .oauth2_session() + .add( + &mut state.rng(), + &state.clock, + &client, + &browser_session, + grant.scope.clone(), + ) + .await + .unwrap(); + + // And fulfill it + let grant = repo + .oauth2_authorization_grant() + .fulfill(&state.clock, &session, grant) + .await + .unwrap(); + + repo.save().await.unwrap(); + + // Now wait a bit + state.clock.advance(Duration::minutes(15)); + + // Exchange it, it should fail + let request = + Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({ + "grant_type": "authorization_code", + "code": code, + "redirect_uri": grant.redirect_uri, + "client_id": client.client_id, + })); + + let response = state.request(request).await; + response.assert_status(StatusCode::BAD_REQUEST); + let ClientError { error, .. } = response.json(); + assert_eq!(error, ClientErrorCode::InvalidGrant); + } + + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_refresh_token_grant(pool: PgPool) { + init_tracing(); + let state = TestState::from_pool(pool).await.unwrap(); + + // Provision a client + let request = + Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({ + "client_uri": "https://example.com/", + "redirect_uris": ["https://example.com/callback"], + "contacts": ["contact@example.com"], + "token_endpoint_auth_method": "none", + "response_types": ["code"], + "grant_types": ["authorization_code"], + })); + + let response = state.request(request).await; + response.assert_status(StatusCode::CREATED); + + let ClientRegistrationResponse { client_id, .. } = response.json(); + + // Let's provision a user and create a session for them. This part is hard to + // test with just HTTP requests, so we'll use the repository directly. + let mut repo = state.repository().await.unwrap(); + + let user = repo + .user() + .add(&mut state.rng(), &state.clock, "alice".to_owned()) + .await + .unwrap(); + + let browser_session = repo + .browser_session() + .add(&mut state.rng(), &state.clock, &user) + .await + .unwrap(); + + // Lookup the client in the database. + let client = repo + .oauth2_client() + .find_by_client_id(&client_id) + .await + .unwrap() + .unwrap(); + + // Get a token pair + let session = repo + .oauth2_session() + .add( + &mut state.rng(), + &state.clock, + &client, + &browser_session, + Scope::from_iter([OPENID]), + ) + .await + .unwrap(); + + let (AccessToken { access_token, .. }, RefreshToken { refresh_token, .. }) = + generate_token_pair( + &mut state.rng(), + &state.clock, + &mut repo, + &session, + Duration::minutes(5), + ) + .await + .unwrap(); + + repo.save().await.unwrap(); + + // First check that the token is valid + assert!(state.is_access_token_valid(&access_token).await); + + // Now call the token endpoint to get an access token. + let request = + Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({ + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "client_id": client.client_id, + })); + + let response = state.request(request).await; + response.assert_status(StatusCode::OK); + + let old_access_token = access_token; + let old_refresh_token = refresh_token; + let response: AccessTokenResponse = response.json(); + let access_token = response.access_token; + let refresh_token = response.refresh_token.expect("to have a refresh token"); + + // Check that the new token is valid + assert!(state.is_access_token_valid(&access_token).await); + + // Check that the old token is no longer valid + assert!(!state.is_access_token_valid(&old_access_token).await); + + // Call it again with the old token, it should fail + let request = + Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({ + "grant_type": "refresh_token", + "refresh_token": old_refresh_token, + "client_id": client.client_id, + })); + + let response = state.request(request).await; + response.assert_status(StatusCode::BAD_REQUEST); + let ClientError { error, .. } = response.json(); + assert_eq!(error, ClientErrorCode::InvalidGrant); + + // Call it again with the new token, it should work + let request = + Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({ + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "client_id": client.client_id, + })); + + let response = state.request(request).await; + response.assert_status(StatusCode::OK); + let _: AccessTokenResponse = response.json(); + } + + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_unsupported_grant(pool: PgPool) { + init_tracing(); + let state = TestState::from_pool(pool).await.unwrap(); + + // Provision a client + let request = + Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({ + "client_uri": "https://example.com/", + "redirect_uris": ["https://example.com/callback"], + "contacts": ["contact@example.com"], + "token_endpoint_auth_method": "client_secret_post", + "grant_types": ["client_credentials"], + "response_types": [], + })); + + let response = state.request(request).await; + response.assert_status(StatusCode::CREATED); + + let response: ClientRegistrationResponse = response.json(); + let client_id = response.client_id; + let client_secret = response.client_secret.expect("to have a client secret"); + + // Call the token endpoint with an unsupported grant type + let request = + Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({ + "grant_type": "client_credentials", + "client_id": client_id, + "client_secret": client_secret, + })); + + let response = state.request(request).await; + response.assert_status(StatusCode::BAD_REQUEST); + let ClientError { error, .. } = response.json(); + assert_eq!(error, ClientErrorCode::UnsupportedGrantType); + } +} diff --git a/crates/handlers/src/test_utils.rs b/crates/handlers/src/test_utils.rs index 2a2ef0b4..1c46aeaa 100644 --- a/crates/handlers/src/test_utils.rs +++ b/crates/handlers/src/test_utils.rs @@ -19,7 +19,7 @@ use axum::{ body::HttpBody, extract::{FromRef, FromRequestParts}, }; -use headers::{Authorization, ContentType, HeaderMapExt, HeaderName, HeaderValue}; +use headers::{Authorization, ContentType, HeaderMapExt, HeaderName}; use hyper::{header::CONTENT_TYPE, Request, Response, StatusCode}; use mas_axum_utils::http_client_factory::HttpClientFactory; use mas_email::{MailTransport, Mailer}; diff --git a/crates/oauth2-types/src/requests.rs b/crates/oauth2-types/src/requests.rs index cc977a72..fcac3a6f 100644 --- a/crates/oauth2-types/src/requests.rs +++ b/crates/oauth2-types/src/requests.rs @@ -325,20 +325,32 @@ impl fmt::Debug for AuthorizationRequest { /// /// [Authorization Endpoint]: https://www.rfc-editor.org/rfc/rfc6749.html#section-3.1 #[skip_serializing_none] +#[serde_as] #[derive(Serialize, Deserialize, Default, Clone)] -pub struct AuthorizationResponse { +pub struct AuthorizationResponse { /// The authorization code generated by the authorization server. pub code: Option, - /// Other fields of the response. - #[serde(flatten)] - pub response: R, + /// The access token to access the requested scope. + pub access_token: Option, + + /// The type of the access token. + pub token_type: Option, + + /// ID Token value associated with the authenticated session. + pub id_token: Option, + + /// The duration for which the access token is valid. + #[serde_as(as = "Option>")] + pub expires_in: Option, } -impl fmt::Debug for AuthorizationResponse { +impl fmt::Debug for AuthorizationResponse { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("AuthorizationResponse") - .field("response", &self.response) + .field("token_type", &self.token_type) + .field("id_token", &self.id_token) + .field("expires_in", &self.expires_in) .finish_non_exhaustive() } } diff --git a/crates/storage-pg/src/oauth2/mod.rs b/crates/storage-pg/src/oauth2/mod.rs index 120fca6c..cf4ae5b9 100644 --- a/crates/storage-pg/src/oauth2/mod.rs +++ b/crates/storage-pg/src/oauth2/mod.rs @@ -203,10 +203,16 @@ mod tests { let session = repo.oauth2_session().lookup(Ulid::nil()).await.unwrap(); assert_eq!(session, None); - // Create a session out of the grant + // Create an OAuth session let session = repo .oauth2_session() - .create_from_grant(&mut rng, &clock, &grant, &user_session) + .add( + &mut rng, + &clock, + &client, + &user_session, + grant.scope.clone(), + ) .await .unwrap(); diff --git a/crates/storage-pg/src/oauth2/session.rs b/crates/storage-pg/src/oauth2/session.rs index e6168310..df66b2c7 100644 --- a/crates/storage-pg/src/oauth2/session.rs +++ b/crates/storage-pg/src/oauth2/session.rs @@ -14,8 +14,9 @@ use async_trait::async_trait; use chrono::{DateTime, Utc}; -use mas_data_model::{AuthorizationGrant, BrowserSession, Session, SessionState, User}; +use mas_data_model::{BrowserSession, Client, Session, SessionState, User}; use mas_storage::{oauth2::OAuth2SessionRepository, Clock, Page, Pagination}; +use oauth2_types::scope::Scope; use rand::RngCore; use sqlx::{PgConnection, QueryBuilder}; use ulid::Ulid; @@ -118,25 +119,25 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { } #[tracing::instrument( - name = "db.oauth2_session.create_from_grant", + name = "db.oauth2_session.add", skip_all, fields( db.statement, %user_session.id, user.id = %user_session.user.id, - %grant.id, - client.id = %grant.client_id, + %client.id, session.id, - session.scope = %grant.scope, + session.scope = %scope, ), err, )] - async fn create_from_grant( + async fn add( &mut self, rng: &mut (dyn RngCore + Send), clock: &dyn Clock, - grant: &AuthorizationGrant, + client: &Client, user_session: &BrowserSession, + scope: Scope, ) -> Result { let created_at = clock.now(); let id = Ulid::from_datetime_with_source(created_at.into(), rng); @@ -155,8 +156,8 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { "#, Uuid::from(id), Uuid::from(user_session.id), - Uuid::from(grant.client_id), - grant.scope.to_string(), + Uuid::from(client.id), + scope.to_string(), created_at, ) .traced() @@ -168,8 +169,8 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { state: SessionState::Valid, created_at, user_session_id: user_session.id, - client_id: grant.client_id, - scope: grant.scope.clone(), + client_id: client.id, + scope, }) } diff --git a/crates/storage/src/oauth2/session.rs b/crates/storage/src/oauth2/session.rs index 880992a6..f2b64a38 100644 --- a/crates/storage/src/oauth2/session.rs +++ b/crates/storage/src/oauth2/session.rs @@ -13,7 +13,8 @@ // limitations under the License. use async_trait::async_trait; -use mas_data_model::{AuthorizationGrant, BrowserSession, Session, User}; +use mas_data_model::{BrowserSession, Client, Session, User}; +use oauth2_types::scope::Scope; use rand_core::RngCore; use ulid::Ulid; @@ -39,7 +40,7 @@ pub trait OAuth2SessionRepository: Send + Sync { /// Returns [`Self::Error`] if the underlying repository fails async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; - /// Create a new [`Session`] from an [`AuthorizationGrant`] + /// Create a new [`Session`] /// /// Returns the newly created [`Session`] /// @@ -47,19 +48,21 @@ pub trait OAuth2SessionRepository: Send + Sync { /// /// * `rng`: The random number generator to use /// * `clock`: The clock used to generate timestamps - /// * `grant`: The [`AuthorizationGrant`] to create the [`Session`] from + /// * `client`: The [`Client`] which created the [`Session`] /// * `user_session`: The [`BrowserSession`] of the user which completed the /// authorization + /// * `scope`: The [`Scope`] of the [`Session`] /// /// # Errors /// /// Returns [`Self::Error`] if the underlying repository fails - async fn create_from_grant( + async fn add( &mut self, rng: &mut (dyn RngCore + Send), clock: &dyn Clock, - grant: &AuthorizationGrant, + client: &Client, user_session: &BrowserSession, + scope: Scope, ) -> Result; /// Mark a [`Session`] as finished @@ -97,12 +100,13 @@ pub trait OAuth2SessionRepository: Send + Sync { repository_impl!(OAuth2SessionRepository: async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; - async fn create_from_grant( + async fn add( &mut self, rng: &mut (dyn RngCore + Send), clock: &dyn Clock, - grant: &AuthorizationGrant, + client: &Client, user_session: &BrowserSession, + scope: Scope, ) -> Result; async fn finish(&mut self, clock: &dyn Clock, session: Session)