diff --git a/Cargo.lock b/Cargo.lock index bc0b37ba..b5e1e65b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3104,6 +3104,7 @@ dependencies = [ "mas-jose", "oauth2-types", "rand 0.8.5", + "rand_chacha 0.3.1", "serde", "serde_json", "sqlx", diff --git a/crates/handlers/src/upstream_oauth2/authorize.rs b/crates/handlers/src/upstream_oauth2/authorize.rs index 5e69f416..178eba1a 100644 --- a/crates/handlers/src/upstream_oauth2/authorize.rs +++ b/crates/handlers/src/upstream_oauth2/authorize.rs @@ -22,7 +22,10 @@ use mas_axum_utils::http_client_factory::HttpClientFactory; use mas_keystore::Encrypter; use mas_oidc_client::requests::authorization_code::AuthorizationRequestData; use mas_router::UrlBuilder; -use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, Repository}; +use mas_storage::{ + upstream_oauth2::{UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository}, + Repository, +}; use sqlx::PgPool; use thiserror::Error; use ulid::Ulid; @@ -97,16 +100,17 @@ pub(crate) async fn get( &mut rng, )?; - let session = mas_storage::upstream_oauth2::add_session( - &mut txn, - &mut rng, - &clock, - &provider, - data.state.clone(), - data.code_challenge_verifier, - data.nonce, - ) - .await?; + let session = txn + .upstream_oauth_session() + .add( + &mut rng, + &clock, + &provider, + data.state.clone(), + data.code_challenge_verifier, + data.nonce, + ) + .await?; let cookie_jar = UpstreamSessionsCookie::load(&cookie_jar) .add(session.id, provider.id, data.state, query.post_auth_action) diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index 6158f941..295f7307 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -26,7 +26,7 @@ use mas_oidc_client::requests::{ }; use mas_router::{Route, UrlBuilder}; use mas_storage::{ - upstream_oauth2::{complete_session, lookup_session}, + upstream_oauth2::{UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository}, Repository, UpstreamOAuthLinkRepository, }; use oauth2_types::errors::ClientErrorCode; @@ -65,6 +65,9 @@ pub(crate) enum RouteError { #[error("Session not found")] SessionNotFound, + #[error("Provider not found")] + ProviderNotFound, + #[error("Provider mismatch")] ProviderMismatch, @@ -105,6 +108,7 @@ impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { match self { + Self::ProviderNotFound => (StatusCode::NOT_FOUND, "Provider not found").into_response(), Self::SessionNotFound => (StatusCode::NOT_FOUND, "Session not found").into_response(), Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(), e => (StatusCode::BAD_REQUEST, e.to_string()).into_response(), @@ -127,16 +131,24 @@ pub(crate) async fn get( let mut txn = pool.begin().await?; + let provider = txn + .upstream_oauth_provider() + .lookup(provider_id) + .await? + .ok_or(RouteError::ProviderNotFound)?; + let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar); let (session_id, _post_auth_action) = sessions_cookie .find_session(provider_id, ¶ms.state) .map_err(|_| RouteError::MissingCookie)?; - let (provider, session) = lookup_session(&mut txn, session_id) + let session = txn + .upstream_oauth_session() + .lookup(session_id) .await? .ok_or(RouteError::SessionNotFound)?; - if provider.id != provider_id { + if provider.id != session.provider_id { // The provider in the session cookie should match the one from the URL return Err(RouteError::ProviderMismatch); } @@ -245,7 +257,11 @@ pub(crate) async fn get( .await? }; - let session = complete_session(&mut txn, &clock, session, &link, response.id_token).await?; + let session = txn + .upstream_oauth_session() + .complete_with_link(&clock, session, &link, response.id_token) + .await?; + let cookie_jar = sessions_cookie .add_link_to_session(session.id, link.id)? .save(cookie_jar, clock.now()); diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index 4a109ba6..c01d9799 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -25,7 +25,7 @@ use mas_axum_utils::{ }; use mas_keystore::Encrypter; use mas_storage::{ - upstream_oauth2::{consume_session, lookup_session_on_link}, + upstream_oauth2::UpstreamOAuthSessionRepository, user::{add_user, authenticate_session_with_upstream, lookup_user, start_session}, Repository, UpstreamOAuthLinkRepository, }; @@ -109,12 +109,18 @@ pub(crate) async fn get( .await? .ok_or(RouteError::LinkNotFound)?; - // This checks that we're in a browser session which is allowed to consume this - // link: the upstream auth session should have been started in this browser. - let upstream_session = lookup_session_on_link(&mut txn, &link, session_id) + let upstream_session = txn + .upstream_oauth_session() + .lookup(session_id) .await? .ok_or(RouteError::SessionNotFound)?; + // This checks that we're in a browser session which is allowed to consume this + // link: the upstream auth session should have been started in this browser. + if upstream_session.link_id != Some(link.id) { + return Err(RouteError::SessionNotFound); + } + if upstream_session.consumed() { return Err(RouteError::SessionConsumed); } @@ -127,7 +133,10 @@ pub(crate) async fn get( (Some(mut session), Some(user_id)) if session.user.id == user_id => { // Session already linked, and link matches the currently logged // user. Mark the session as consumed and renew the authentication. - consume_session(&mut txn, &clock, upstream_session).await?; + txn.upstream_oauth_session() + .consume(&clock, upstream_session) + .await?; + authenticate_session_with_upstream(&mut txn, &mut rng, &clock, &mut session, &link) .await?; @@ -212,12 +221,18 @@ pub(crate) async fn post( .await? .ok_or(RouteError::LinkNotFound)?; - // This checks that we're in a browser session which is allowed to consume this - // link: the upstream auth session should have been started in this browser. - let upstream_session = lookup_session_on_link(&mut txn, &link, session_id) + let upstream_session = txn + .upstream_oauth_session() + .lookup(session_id) .await? .ok_or(RouteError::SessionNotFound)?; + // This checks that we're in a browser session which is allowed to consume this + // link: the upstream auth session should have been started in this browser. + if upstream_session.link_id != Some(link.id) { + return Err(RouteError::SessionNotFound); + } + if upstream_session.consumed() { return Err(RouteError::SessionConsumed); } @@ -251,7 +266,10 @@ pub(crate) async fn post( _ => return Err(RouteError::InvalidFormAction), }; - consume_session(&mut txn, &clock, upstream_session).await?; + txn.upstream_oauth_session() + .consume(&clock, upstream_session) + .await?; + authenticate_session_with_upstream(&mut txn, &mut rng, &clock, &mut session, &link).await?; let cookie_jar = sessions_cookie diff --git a/crates/storage/Cargo.toml b/crates/storage/Cargo.toml index 71240129..fb6c0fdc 100644 --- a/crates/storage/Cargo.toml +++ b/crates/storage/Cargo.toml @@ -14,8 +14,8 @@ serde_json = "1.0.91" thiserror = "1.0.38" tracing = "0.1.37" -# Password hashing rand = "0.8.5" +rand_chacha = "0.3.1" url = { version = "2.3.1", features = ["serde"] } uuid = "1.2.2" ulid = { version = "1.0.0", features = ["uuid", "serde"] } diff --git a/crates/storage/sqlx-data.json b/crates/storage/sqlx-data.json index 52b0118f..8167fef3 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage/sqlx-data.json @@ -521,81 +521,6 @@ }, "query": "\n INSERT INTO users (user_id, username, created_at)\n VALUES ($1, $2, $3)\n " }, - "2ca7b990c11e84db62fb7887a2bc3410ec1eee2f6a0ec124db36575111970ca9": { - "describe": { - "columns": [ - { - "name": "upstream_oauth_authorization_session_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "upstream_oauth_provider_id", - "ordinal": 1, - "type_info": "Uuid" - }, - { - "name": "upstream_oauth_link_id", - "ordinal": 2, - "type_info": "Uuid" - }, - { - "name": "state", - "ordinal": 3, - "type_info": "Text" - }, - { - "name": "code_challenge_verifier", - "ordinal": 4, - "type_info": "Text" - }, - { - "name": "nonce", - "ordinal": 5, - "type_info": "Text" - }, - { - "name": "id_token", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "created_at", - "ordinal": 7, - "type_info": "Timestamptz" - }, - { - "name": "completed_at", - "ordinal": 8, - "type_info": "Timestamptz" - }, - { - "name": "consumed_at", - "ordinal": 9, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - true, - false, - true, - false, - true, - false, - true, - true - ], - "parameters": { - "Left": [ - "Uuid", - "Uuid" - ] - } - }, - "query": "\n SELECT\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n upstream_oauth_link_id,\n state,\n code_challenge_verifier,\n nonce,\n id_token,\n created_at,\n completed_at,\n consumed_at\n FROM upstream_oauth_authorization_sessions\n WHERE upstream_oauth_authorization_session_id = $1\n AND upstream_oauth_link_id = $2\n " - }, "2e581d57db471b96091860cd0252361d16332deeffabab0dace405ead55324be": { "describe": { "columns": [ @@ -708,21 +633,6 @@ }, "query": "\n UPDATE compat_sso_logins\n SET\n exchanged_at = $2\n WHERE\n compat_sso_login_id = $1\n " }, - "2fb8f1aef96571a6f3f6260d7836de99ff24ba1947747e08b0e8d64097507442": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Timestamptz", - "Text", - "Uuid" - ] - } - }, - "query": "\n UPDATE upstream_oauth_authorization_sessions\n SET upstream_oauth_link_id = $1,\n completed_at = $2,\n id_token = $3\n WHERE upstream_oauth_authorization_session_id = $4\n " - }, "360466ff599c67c9af2ac75399c0b536a22c1178972a0172b707bcc81d47357b": { "describe": { "columns": [], @@ -1388,7 +1298,24 @@ }, "query": "\n UPDATE user_sessions\n SET finished_at = $1\n WHERE user_session_id = $2\n " }, - "65c7600f1af07cb6ea49d89ae6fbca5374a57c5a866c8aadd7b75ed1d2d1d0cd": { + "64e6ea47c2e877c1ebe4338d64d9ad8a6c1c777d1daea024b8ca2e7f0dd75b0f": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Text", + "Text", + "Text", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO upstream_oauth_authorization_sessions (\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n state,\n code_challenge_verifier,\n nonce,\n created_at,\n completed_at,\n consumed_at,\n id_token\n ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL)\n " + }, + "67ab838035946ddc15b43dd2f79d10b233d07e863b3a5c776c5db97cff263c8c": { "describe": { "columns": [ { @@ -1440,41 +1367,6 @@ "name": "consumed_at", "ordinal": 9, "type_info": "Timestamptz" - }, - { - "name": "provider_issuer", - "ordinal": 10, - "type_info": "Text" - }, - { - "name": "provider_scope", - "ordinal": 11, - "type_info": "Text" - }, - { - "name": "provider_client_id", - "ordinal": 12, - "type_info": "Text" - }, - { - "name": "provider_encrypted_client_secret", - "ordinal": 13, - "type_info": "Text" - }, - { - "name": "provider_token_endpoint_auth_method", - "ordinal": 14, - "type_info": "Text" - }, - { - "name": "provider_token_endpoint_signing_alg", - "ordinal": 15, - "type_info": "Text" - }, - { - "name": "provider_created_at", - "ordinal": 16, - "type_info": "Timestamptz" } ], "nullable": [ @@ -1487,14 +1379,7 @@ true, false, true, - true, - false, - false, - false, - true, - false, - true, - false + true ], "parameters": { "Left": [ @@ -1502,7 +1387,20 @@ ] } }, - "query": "\n SELECT\n ua.upstream_oauth_authorization_session_id,\n ua.upstream_oauth_provider_id,\n ua.upstream_oauth_link_id,\n ua.state,\n ua.code_challenge_verifier,\n ua.nonce,\n ua.id_token,\n ua.created_at,\n ua.completed_at,\n ua.consumed_at,\n up.issuer AS \"provider_issuer\",\n up.scope AS \"provider_scope\",\n up.client_id AS \"provider_client_id\",\n up.encrypted_client_secret AS \"provider_encrypted_client_secret\",\n up.token_endpoint_auth_method AS \"provider_token_endpoint_auth_method\",\n up.token_endpoint_signing_alg AS \"provider_token_endpoint_signing_alg\",\n up.created_at AS \"provider_created_at\"\n FROM upstream_oauth_authorization_sessions ua\n INNER JOIN upstream_oauth_providers up\n USING (upstream_oauth_provider_id)\n WHERE upstream_oauth_authorization_session_id = $1\n " + "query": "\n SELECT\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n upstream_oauth_link_id,\n state,\n code_challenge_verifier,\n nonce,\n id_token,\n created_at,\n completed_at,\n consumed_at\n FROM upstream_oauth_authorization_sessions\n WHERE upstream_oauth_authorization_session_id = $1\n " + }, + "689ffbfc5137ec788e89062ad679bbe6b23a8861c09a7246dc1659c28f12bf8d": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Timestamptz", + "Uuid" + ] + } + }, + "query": "\n UPDATE upstream_oauth_authorization_sessions\n SET consumed_at = $1\n WHERE upstream_oauth_authorization_session_id = $2\n " }, "6bf0da5ba3dd07b499193a2e0ddeea6e712f9df8f7f28874ff56a952a9f10e54": { "describe": { @@ -2420,6 +2318,21 @@ }, "query": "\n SELECT\n ue.user_email_id,\n ue.email AS \"user_email\",\n ue.created_at AS \"user_email_created_at\",\n ue.confirmed_at AS \"user_email_confirmed_at\"\n FROM user_emails ue\n\n WHERE ue.user_id = $1\n AND ue.user_email_id = $2\n " }, + "b9875a270f7e753e48075ccae233df6e24a91775ceb877735508c1d5b2300d64": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz", + "Text", + "Uuid" + ] + } + }, + "query": "\n UPDATE upstream_oauth_authorization_sessions\n SET upstream_oauth_link_id = $1,\n completed_at = $2,\n id_token = $3\n WHERE upstream_oauth_authorization_session_id = $4\n " + }, "bc768c63a7737818967bc28560de714bbbd262bdf3ab73d297263bb73dcd9f5e": { "describe": { "columns": [], @@ -2702,19 +2615,6 @@ }, "query": "\n DELETE FROM user_emails\n WHERE user_emails.user_email_id = $1\n " }, - "e30562e9637d3a723a91adca6336a8d083657ce6d7fe9551fcd6a9d672453d3c": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Timestamptz", - "Uuid" - ] - } - }, - "query": "\n UPDATE upstream_oauth_authorization_sessions\n SET consumed_at = $1\n WHERE upstream_oauth_authorization_session_id = $2\n " - }, "e446e37d48c8838ef2e0d0fd82f8f7b04893c84ad46747cdf193ebd83755ceb2": { "describe": { "columns": [], @@ -2773,22 +2673,5 @@ } }, "query": "\n SELECT\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n FROM upstream_oauth_links\n WHERE upstream_oauth_provider_id = $1\n AND subject = $2\n " - }, - "fb71ac6539039313fd90b29ac943330e54c7b62b2778727726e2f60a554f9c5a": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Text", - "Text", - "Text", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO upstream_oauth_authorization_sessions (\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n state,\n code_challenge_verifier,\n nonce,\n created_at,\n completed_at,\n consumed_at,\n id_token\n ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL)\n " } } \ No newline at end of file diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs index c1d259fc..9e6ca807 100644 --- a/crates/storage/src/repository.rs +++ b/crates/storage/src/repository.rs @@ -14,7 +14,10 @@ use sqlx::{PgConnection, Postgres, Transaction}; -use crate::upstream_oauth2::{PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository}; +use crate::upstream_oauth2::{ + PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository, + PgUpstreamOAuthSessionRepository, +}; pub trait Repository { type UpstreamOAuthLinkRepository<'c> @@ -25,13 +28,19 @@ pub trait Repository { where Self: 'c; + type UpstreamOAuthSessionRepository<'c> + where + Self: 'c; + fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_>; fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_>; + fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_>; } impl Repository for PgConnection { type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c; type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c; + type UpstreamOAuthSessionRepository<'c> = PgUpstreamOAuthSessionRepository<'c> where Self: 'c; fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { PgUpstreamOAuthLinkRepository::new(self) @@ -40,11 +49,16 @@ impl Repository for PgConnection { fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> { PgUpstreamOAuthProviderRepository::new(self) } + + fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_> { + PgUpstreamOAuthSessionRepository::new(self) + } } impl<'t> Repository for Transaction<'t, Postgres> { type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c; type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c; + type UpstreamOAuthSessionRepository<'c> = PgUpstreamOAuthSessionRepository<'c> where Self: 'c; fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { PgUpstreamOAuthLinkRepository::new(self) @@ -53,4 +67,8 @@ impl<'t> Repository for Transaction<'t, Postgres> { fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> { PgUpstreamOAuthProviderRepository::new(self) } + + fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_> { + PgUpstreamOAuthSessionRepository::new(self) + } } diff --git a/crates/storage/src/upstream_oauth2/mod.rs b/crates/storage/src/upstream_oauth2/mod.rs index d29b5e71..1abcd1d0 100644 --- a/crates/storage/src/upstream_oauth2/mod.rs +++ b/crates/storage/src/upstream_oauth2/mod.rs @@ -19,7 +19,111 @@ mod session; pub use self::{ link::{PgUpstreamOAuthLinkRepository, UpstreamOAuthLinkRepository}, provider::{PgUpstreamOAuthProviderRepository, UpstreamOAuthProviderRepository}, - session::{ - add_session, complete_session, consume_session, lookup_session, lookup_session_on_link, - }, + session::{PgUpstreamOAuthSessionRepository, UpstreamOAuthSessionRepository}, }; + +#[cfg(test)] +mod tests { + use oauth2_types::scope::{Scope, OPENID}; + use rand::SeedableRng; + use sqlx::PgPool; + + use super::*; + use crate::{Clock, Repository}; + + #[sqlx::test(migrator = "crate::MIGRATOR")] + async fn test_repository(pool: PgPool) -> Result<(), Box> { + let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42); + let clock = Clock::default(); + let mut conn = pool.acquire().await?; + + // The provider list should be empty at the start + let all_providers = conn.upstream_oauth_provider().all().await?; + assert!(all_providers.is_empty()); + + // Let's add a provider + let provider = conn + .upstream_oauth_provider() + .add( + &mut rng, + &clock, + "https://example.com/".to_owned(), + Scope::from_iter([OPENID]), + mas_iana::oauth::OAuthClientAuthenticationMethod::None, + None, + "client-id".to_owned(), + None, + ) + .await?; + + // Look it up in the database + let provider = conn + .upstream_oauth_provider() + .lookup(provider.id) + .await? + .expect("provider to be found in the database"); + assert_eq!(provider.issuer, "https://example.com/"); + assert_eq!(provider.client_id, "client-id"); + + // Start a session + let session = conn + .upstream_oauth_session() + .add( + &mut rng, + &clock, + &provider, + "some-state".to_owned(), + None, + "some-nonce".to_owned(), + ) + .await?; + + // Look it up in the database + let session = conn + .upstream_oauth_session() + .lookup(session.id) + .await? + .expect("session to be found in the database"); + assert_eq!(session.provider_id, provider.id); + assert_eq!(session.link_id, None); + assert!(!session.completed()); + assert!(!session.consumed()); + + // Create a link + let link = conn + .upstream_oauth_link() + .add(&mut rng, &clock, &provider, "a-subject".to_owned()) + .await?; + + // We can look it up by its ID + conn.upstream_oauth_link() + .lookup(link.id) + .await? + .expect("link to be found in database"); + + // or by its subject + let link = conn + .upstream_oauth_link() + .find_by_subject(&provider, "a-subject") + .await? + .expect("link to be found in database"); + assert_eq!(link.subject, "a-subject"); + assert_eq!(link.provider_id, provider.id); + + let session = conn + .upstream_oauth_session() + .complete_with_link(&clock, session, &link, None) + .await?; + assert!(session.completed()); + assert!(!session.consumed()); + assert_eq!(session.link_id, Some(link.id)); + + let session = conn + .upstream_oauth_session() + .consume(&clock, session) + .await?; + assert!(session.consumed()); + + Ok(()) + } +} diff --git a/crates/storage/src/upstream_oauth2/session.rs b/crates/storage/src/upstream_oauth2/session.rs index 5e013f24..f8dffcf3 100644 --- a/crates/storage/src/upstream_oauth2/session.rs +++ b/crates/storage/src/upstream_oauth2/session.rs @@ -12,261 +12,62 @@ // See the License for the specific language governing permissions and // limitations under the License. +use async_trait::async_trait; use chrono::{DateTime, Utc}; use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider}; -use rand::Rng; -use sqlx::PgExecutor; +use rand::RngCore; +use sqlx::PgConnection; use ulid::Ulid; use uuid::Uuid; -use crate::{Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; +use crate::{Clock, DatabaseError, LookupResultExt}; -struct SessionAndProviderLookup { - upstream_oauth_authorization_session_id: Uuid, - upstream_oauth_provider_id: Uuid, - upstream_oauth_link_id: Option, - state: String, - code_challenge_verifier: Option, - nonce: String, - id_token: Option, - created_at: DateTime, - completed_at: Option>, - consumed_at: Option>, - provider_issuer: String, - provider_scope: String, - provider_client_id: String, - provider_encrypted_client_secret: Option, - provider_token_endpoint_auth_method: String, - provider_token_endpoint_signing_alg: Option, - provider_created_at: DateTime, +#[async_trait] +pub trait UpstreamOAuthSessionRepository: Send + Sync { + type Error; + + /// Lookup a session by its ID + async fn lookup( + &mut self, + id: Ulid, + ) -> Result, Self::Error>; + + /// Add a session to the database + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + upstream_oauth_provider: &UpstreamOAuthProvider, + state: String, + code_challenge_verifier: Option, + nonce: String, + ) -> Result; + + /// Mark a session as completed and associate the given link + async fn complete_with_link( + &mut self, + clock: &Clock, + upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, + upstream_oauth_link: &UpstreamOAuthLink, + id_token: Option, + ) -> Result; + + /// Mark a session as consumed + async fn consume( + &mut self, + clock: &Clock, + upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, + ) -> Result; } -/// Lookup a session and its provider by its ID -#[tracing::instrument( - skip_all, - fields(upstream_oauth_authorization_session.id = %id), - err, -)] -pub async fn lookup_session( - executor: impl PgExecutor<'_>, - id: Ulid, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - SessionAndProviderLookup, - r#" - SELECT - ua.upstream_oauth_authorization_session_id, - ua.upstream_oauth_provider_id, - ua.upstream_oauth_link_id, - ua.state, - ua.code_challenge_verifier, - ua.nonce, - ua.id_token, - ua.created_at, - ua.completed_at, - ua.consumed_at, - up.issuer AS "provider_issuer", - up.scope AS "provider_scope", - up.client_id AS "provider_client_id", - up.encrypted_client_secret AS "provider_encrypted_client_secret", - up.token_endpoint_auth_method AS "provider_token_endpoint_auth_method", - up.token_endpoint_signing_alg AS "provider_token_endpoint_signing_alg", - up.created_at AS "provider_created_at" - FROM upstream_oauth_authorization_sessions ua - INNER JOIN upstream_oauth_providers up - USING (upstream_oauth_provider_id) - WHERE upstream_oauth_authorization_session_id = $1 - "#, - Uuid::from(id), - ) - .fetch_one(executor) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - let id = res.upstream_oauth_provider_id.into(); - let provider = UpstreamOAuthProvider { - id, - issuer: res.provider_issuer, - scope: res.provider_scope.parse().map_err(|e| { - DatabaseInconsistencyError::on("upstream_oauth_providers") - .column("scope") - .row(id) - .source(e) - })?, - client_id: res.provider_client_id, - encrypted_client_secret: res.provider_encrypted_client_secret, - token_endpoint_auth_method: res.provider_token_endpoint_auth_method.parse().map_err( - |e| { - DatabaseInconsistencyError::on("upstream_oauth_providers") - .column("token_endpoint_auth_method") - .row(id) - .source(e) - }, - )?, - token_endpoint_signing_alg: res - .provider_token_endpoint_signing_alg - .map(|x| x.parse()) - .transpose() - .map_err(|e| { - DatabaseInconsistencyError::on("upstream_oauth_providers") - .column("token_endpoint_signing_alg") - .row(id) - .source(e) - })?, - created_at: res.provider_created_at, - }; - - let session = UpstreamOAuthAuthorizationSession { - id: res.upstream_oauth_authorization_session_id.into(), - provider_id: provider.id, - link_id: res.upstream_oauth_link_id.map(Ulid::from), - state: res.state, - code_challenge_verifier: res.code_challenge_verifier, - nonce: res.nonce, - id_token: res.id_token, - created_at: res.created_at, - completed_at: res.completed_at, - consumed_at: res.consumed_at, - }; - - Ok(Some((provider, session))) +pub struct PgUpstreamOAuthSessionRepository<'c> { + conn: &'c mut PgConnection, } -/// Add a session to the database -#[tracing::instrument( - skip_all, - fields( - %upstream_oauth_provider.id, - %upstream_oauth_provider.issuer, - %upstream_oauth_provider.client_id, - upstream_oauth_authorization_session.id, - ), - err, -)] -pub async fn add_session( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - upstream_oauth_provider: &UpstreamOAuthProvider, - state: String, - code_challenge_verifier: Option, - nonce: String, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record( - "upstream_oauth_authorization_session.id", - tracing::field::display(id), - ); - - sqlx::query!( - r#" - INSERT INTO upstream_oauth_authorization_sessions ( - upstream_oauth_authorization_session_id, - upstream_oauth_provider_id, - state, - code_challenge_verifier, - nonce, - created_at, - completed_at, - consumed_at, - id_token - ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL) - "#, - Uuid::from(id), - Uuid::from(upstream_oauth_provider.id), - &state, - code_challenge_verifier.as_deref(), - nonce, - created_at, - ) - .execute(executor) - .await?; - - Ok(UpstreamOAuthAuthorizationSession { - id, - provider_id: upstream_oauth_provider.id, - link_id: None, - state, - code_challenge_verifier, - nonce, - id_token: None, - created_at, - completed_at: None, - consumed_at: None, - }) -} - -/// Mark a session as completed and associate the given link -#[tracing::instrument( - skip_all, - fields( - %upstream_oauth_authorization_session.id, - %upstream_oauth_link.id, - ), - err, -)] -pub async fn complete_session( - executor: impl PgExecutor<'_>, - clock: &Clock, - mut upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, - upstream_oauth_link: &UpstreamOAuthLink, - id_token: Option, -) -> Result { - let completed_at = clock.now(); - sqlx::query!( - r#" - UPDATE upstream_oauth_authorization_sessions - SET upstream_oauth_link_id = $1, - completed_at = $2, - id_token = $3 - WHERE upstream_oauth_authorization_session_id = $4 - "#, - Uuid::from(upstream_oauth_link.id), - completed_at, - id_token, - Uuid::from(upstream_oauth_authorization_session.id), - ) - .execute(executor) - .await?; - - upstream_oauth_authorization_session.completed_at = Some(completed_at); - upstream_oauth_authorization_session.id_token = id_token; - - Ok(upstream_oauth_authorization_session) -} - -/// Mark a session as consumed -#[tracing::instrument( - skip_all, - fields( - %upstream_oauth_authorization_session.id, - ), - err, -)] -pub async fn consume_session( - executor: impl PgExecutor<'_>, - clock: &Clock, - mut upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, -) -> Result { - let consumed_at = clock.now(); - sqlx::query!( - r#" - UPDATE upstream_oauth_authorization_sessions - SET consumed_at = $1 - WHERE upstream_oauth_authorization_session_id = $2 - "#, - consumed_at, - Uuid::from(upstream_oauth_authorization_session.id), - ) - .execute(executor) - .await?; - - upstream_oauth_authorization_session.consumed_at = Some(consumed_at); - - Ok(upstream_oauth_authorization_session) +impl<'c> PgUpstreamOAuthSessionRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } } struct SessionLookup { @@ -282,57 +83,191 @@ struct SessionLookup { consumed_at: Option>, } -/// Lookup a session, which belongs to a link, by its ID -#[tracing::instrument( - skip_all, - fields( - upstream_oauth_authorization_session.id = %id, - %upstream_oauth_link.id, - ), - err, -)] -pub async fn lookup_session_on_link( - executor: impl PgExecutor<'_>, - upstream_oauth_link: &UpstreamOAuthLink, - id: Ulid, -) -> Result, sqlx::Error> { - let res = sqlx::query_as!( - SessionLookup, - r#" - SELECT - upstream_oauth_authorization_session_id, - upstream_oauth_provider_id, - upstream_oauth_link_id, - state, - code_challenge_verifier, - nonce, - id_token, - created_at, - completed_at, - consumed_at - FROM upstream_oauth_authorization_sessions - WHERE upstream_oauth_authorization_session_id = $1 - AND upstream_oauth_link_id = $2 - "#, - Uuid::from(id), - Uuid::from(upstream_oauth_link.id), - ) - .fetch_one(executor) - .await - .to_option()?; +#[async_trait] +impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> { + type Error = DatabaseError; - let Some(res) = res else { return Ok(None) }; + #[tracing::instrument( + skip_all, + fields(upstream_oauth_provider.id = %id), + err, + )] + async fn lookup( + &mut self, + id: Ulid, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + SessionLookup, + r#" + SELECT + upstream_oauth_authorization_session_id, + upstream_oauth_provider_id, + upstream_oauth_link_id, + state, + code_challenge_verifier, + nonce, + id_token, + created_at, + completed_at, + consumed_at + FROM upstream_oauth_authorization_sessions + WHERE upstream_oauth_authorization_session_id = $1 + "#, + Uuid::from(id), + ) + .fetch_one(&mut *self.conn) + .await + .to_option()?; - Ok(Some(UpstreamOAuthAuthorizationSession { - id: res.upstream_oauth_authorization_session_id.into(), - provider_id: res.upstream_oauth_provider_id.into(), - link_id: res.upstream_oauth_link_id.map(Ulid::from), - state: res.state, - code_challenge_verifier: res.code_challenge_verifier, - nonce: res.nonce, - id_token: res.id_token, - created_at: res.created_at, - completed_at: res.completed_at, - consumed_at: res.consumed_at, - })) + let Some(res) = res else { return Ok(None) }; + + let session = UpstreamOAuthAuthorizationSession { + id: res.upstream_oauth_authorization_session_id.into(), + provider_id: res.upstream_oauth_provider_id.into(), + link_id: res.upstream_oauth_link_id.map(Ulid::from), + state: res.state, + code_challenge_verifier: res.code_challenge_verifier, + nonce: res.nonce, + id_token: res.id_token, + created_at: res.created_at, + completed_at: res.completed_at, + consumed_at: res.consumed_at, + }; + + Ok(Some(session)) + } + + #[tracing::instrument( + skip_all, + fields( + %upstream_oauth_provider.id, + %upstream_oauth_provider.issuer, + %upstream_oauth_provider.client_id, + upstream_oauth_authorization_session.id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + upstream_oauth_provider: &UpstreamOAuthProvider, + state: String, + code_challenge_verifier: Option, + nonce: String, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record( + "upstream_oauth_authorization_session.id", + tracing::field::display(id), + ); + + sqlx::query!( + r#" + INSERT INTO upstream_oauth_authorization_sessions ( + upstream_oauth_authorization_session_id, + upstream_oauth_provider_id, + state, + code_challenge_verifier, + nonce, + created_at, + completed_at, + consumed_at, + id_token + ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL) + "#, + Uuid::from(id), + Uuid::from(upstream_oauth_provider.id), + &state, + code_challenge_verifier.as_deref(), + nonce, + created_at, + ) + .execute(&mut *self.conn) + .await?; + + Ok(UpstreamOAuthAuthorizationSession { + id, + provider_id: upstream_oauth_provider.id, + link_id: None, + state, + code_challenge_verifier, + nonce, + id_token: None, + created_at, + completed_at: None, + consumed_at: None, + }) + } + + #[tracing::instrument( + skip_all, + fields( + %upstream_oauth_authorization_session.id, + %upstream_oauth_link.id, + ), + err, + )] + async fn complete_with_link( + &mut self, + clock: &Clock, + mut upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, + upstream_oauth_link: &UpstreamOAuthLink, + id_token: Option, + ) -> Result { + let completed_at = clock.now(); + sqlx::query!( + r#" + UPDATE upstream_oauth_authorization_sessions + SET upstream_oauth_link_id = $1, + completed_at = $2, + id_token = $3 + WHERE upstream_oauth_authorization_session_id = $4 + "#, + Uuid::from(upstream_oauth_link.id), + completed_at, + id_token, + Uuid::from(upstream_oauth_authorization_session.id), + ) + .execute(&mut *self.conn) + .await?; + + upstream_oauth_authorization_session.completed_at = Some(completed_at); + upstream_oauth_authorization_session.id_token = id_token; + upstream_oauth_authorization_session.link_id = Some(upstream_oauth_link.id); + + Ok(upstream_oauth_authorization_session) + } + + /// Mark a session as consumed + #[tracing::instrument( + skip_all, + fields( + %upstream_oauth_authorization_session.id, + ), + err, + )] + async fn consume( + &mut self, + clock: &Clock, + mut upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, + ) -> Result { + let consumed_at = clock.now(); + sqlx::query!( + r#" + UPDATE upstream_oauth_authorization_sessions + SET consumed_at = $1 + WHERE upstream_oauth_authorization_session_id = $2 + "#, + consumed_at, + Uuid::from(upstream_oauth_authorization_session.id), + ) + .execute(&mut *self.conn) + .await?; + + upstream_oauth_authorization_session.consumed_at = Some(consumed_at); + + Ok(upstream_oauth_authorization_session) + } }