diff --git a/crates/axum-utils/src/client_authorization.rs b/crates/axum-utils/src/client_authorization.rs index 0e6771b4..6f212369 100644 --- a/crates/axum-utils/src/client_authorization.rs +++ b/crates/axum-utils/src/client_authorization.rs @@ -31,7 +31,7 @@ use mas_http::HttpServiceExt; use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_jose::{jwk::PublicJsonWebKeySet, jwt::Jwt}; use mas_keystore::Encrypter; -use mas_storage::{oauth2::client::OAuth2ClientRepository, DatabaseError, Repository}; +use mas_storage::{oauth2::OAuth2ClientRepository, DatabaseError, Repository}; use serde::{de::DeserializeOwned, Deserialize}; use serde_json::Value; use sqlx::PgConnection; diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index 62d62db8..d159f3e3 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -18,7 +18,7 @@ use mas_config::{DatabaseConfig, PasswordsConfig, RootConfig}; use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; use mas_router::UrlBuilder; use mas_storage::{ - oauth2::client::OAuth2ClientRepository, + oauth2::OAuth2ClientRepository, upstream_oauth2::UpstreamOAuthProviderRepository, user::{UserEmailRepository, UserPasswordRepository, UserRepository}, Clock, Repository, diff --git a/crates/data-model/src/oauth2/session.rs b/crates/data-model/src/oauth2/session.rs index 29454feb..aec48ac1 100644 --- a/crates/data-model/src/oauth2/session.rs +++ b/crates/data-model/src/oauth2/session.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use chrono::{DateTime, Utc}; use oauth2_types::scope::Scope; use serde::Serialize; use ulid::Ulid; @@ -22,4 +23,5 @@ pub struct Session { pub user_session_id: Ulid, pub client_id: Ulid, pub scope: Scope, + pub finished_at: Option>, } diff --git a/crates/graphql/src/lib.rs b/crates/graphql/src/lib.rs index ffa63396..d01be16c 100644 --- a/crates/graphql/src/lib.rs +++ b/crates/graphql/src/lib.rs @@ -31,7 +31,7 @@ use async_graphql::{ Context, Description, EmptyMutation, EmptySubscription, ID, }; use mas_storage::{ - oauth2::client::OAuth2ClientRepository, + oauth2::OAuth2ClientRepository, upstream_oauth2::UpstreamOAuthProviderRepository, user::{BrowserSessionRepository, UserEmailRepository}, Repository, UpstreamOAuthLinkRepository, diff --git a/crates/graphql/src/model/oauth.rs b/crates/graphql/src/model/oauth.rs index 8e418e6c..0ab2bc68 100644 --- a/crates/graphql/src/model/oauth.rs +++ b/crates/graphql/src/model/oauth.rs @@ -14,9 +14,7 @@ use anyhow::Context as _; use async_graphql::{Context, Description, Object, ID}; -use mas_storage::{ - oauth2::client::OAuth2ClientRepository, user::BrowserSessionRepository, Repository, -}; +use mas_storage::{oauth2::OAuth2ClientRepository, user::BrowserSessionRepository, Repository}; use oauth2_types::scope::Scope; use sqlx::PgPool; use ulid::Ulid; diff --git a/crates/graphql/src/model/users.rs b/crates/graphql/src/model/users.rs index dc40d6cd..2f241ced 100644 --- a/crates/graphql/src/model/users.rs +++ b/crates/graphql/src/model/users.rs @@ -18,6 +18,7 @@ use async_graphql::{ }; use chrono::{DateTime, Utc}; use mas_storage::{ + oauth2::OAuth2SessionRepository, user::{BrowserSessionRepository, UserEmailRepository}, Repository, UpstreamOAuthLinkRepository, }; @@ -241,14 +242,13 @@ impl User { .map(|x: OpaqueCursor| x.extract_for_type(NodeType::OAuth2Session)) .transpose()?; - let (has_previous_page, has_next_page, edges) = - mas_storage::oauth2::get_paginated_user_oauth_sessions( - &mut conn, &self.0, before_id, after_id, first, last, - ) + let page = conn + .oauth2_session() + .list_paginated(&self.0, before_id, after_id, first, last) .await?; - let mut connection = Connection::new(has_previous_page, has_next_page); - connection.edges.extend(edges.into_iter().map(|s| { + let mut connection = Connection::new(page.has_previous_page, page.has_next_page); + connection.edges.extend(page.edges.into_iter().map(|s| { Edge::new( OpaqueCursor(NodeCursor(NodeType::OAuth2Session, s.id)), OAuth2Session(s), diff --git a/crates/handlers/src/oauth2/authorization/complete.rs b/crates/handlers/src/oauth2/authorization/complete.rs index eb4b8889..01c89ff3 100644 --- a/crates/handlers/src/oauth2/authorization/complete.rs +++ b/crates/handlers/src/oauth2/authorization/complete.rs @@ -25,9 +25,13 @@ use mas_data_model::{AuthorizationGrant, BrowserSession}; use mas_keystore::Encrypter; use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; -use mas_storage::oauth2::{ - authorization_grant::{derive_session, fulfill_grant, get_grant_by_id}, - consent::fetch_client_consent, +use mas_storage::{ + oauth2::{ + authorization_grant::{fulfill_grant, get_grant_by_id}, + consent::fetch_client_consent, + OAuth2SessionRepository, + }, + Repository, }; use mas_templates::Templates; use oauth2_types::requests::{AccessTokenResponse, AuthorizationResponse}; @@ -193,7 +197,10 @@ pub(crate) async fn complete( } // All good, let's start the session - let session = derive_session(&mut txn, &mut rng, &clock, &grant, browser_session).await?; + let session = txn + .oauth2_session() + .create_from_grant(&mut rng, &clock, &grant, &browser_session) + .await?; let grant = fulfill_grant(&mut txn, grant, session.clone()).await?; diff --git a/crates/handlers/src/oauth2/authorization/mod.rs b/crates/handlers/src/oauth2/authorization/mod.rs index 36d15d2b..cfcd936e 100644 --- a/crates/handlers/src/oauth2/authorization/mod.rs +++ b/crates/handlers/src/oauth2/authorization/mod.rs @@ -26,7 +26,7 @@ use mas_keystore::Encrypter; use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; use mas_storage::{ - oauth2::{authorization_grant::new_authorization_grant, client::OAuth2ClientRepository}, + oauth2::{authorization_grant::new_authorization_grant, OAuth2ClientRepository}, Repository, }; use mas_templates::Templates; diff --git a/crates/handlers/src/oauth2/registration.rs b/crates/handlers/src/oauth2/registration.rs index b12194eb..a6ff6158 100644 --- a/crates/handlers/src/oauth2/registration.rs +++ b/crates/handlers/src/oauth2/registration.rs @@ -19,7 +19,7 @@ use hyper::StatusCode; use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_keystore::Encrypter; use mas_policy::{PolicyFactory, Violation}; -use mas_storage::{oauth2::client::OAuth2ClientRepository, Repository}; +use mas_storage::{oauth2::OAuth2ClientRepository, Repository}; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, registration::{ diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index 473dcab8..391bdde2 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -35,8 +35,8 @@ use mas_storage::{ oauth2::{ access_token::{add_access_token, revoke_access_token}, authorization_grant::{exchange_grant, lookup_grant_by_code}, - end_oauth_session, refresh_token::{add_refresh_token, consume_refresh_token, lookup_active_refresh_token}, + OAuth2SessionRepository, }, user::BrowserSessionRepository, Repository, @@ -234,7 +234,7 @@ async fn authorization_code_grant( // Ending the session if the token was already exchanged more than 20s ago if now - exchanged_at > Duration::seconds(20) { debug!("Ending potentially compromised session"); - end_oauth_session(&mut txn, &clock, session).await?; + txn.oauth2_session().finish(&clock, session).await?; txn.commit().await?; } diff --git a/crates/handlers/src/oauth2/userinfo.rs b/crates/handlers/src/oauth2/userinfo.rs index 699b049a..d2b2b615 100644 --- a/crates/handlers/src/oauth2/userinfo.rs +++ b/crates/handlers/src/oauth2/userinfo.rs @@ -29,7 +29,7 @@ use mas_jose::{ use mas_keystore::Keystore; use mas_router::UrlBuilder; use mas_storage::{ - oauth2::client::OAuth2ClientRepository, + oauth2::OAuth2ClientRepository, user::{BrowserSessionRepository, UserEmailRepository}, Repository, }; diff --git a/crates/storage/sqlx-data.json b/crates/storage/sqlx-data.json index 4784c030..740e9a3a 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage/sqlx-data.json @@ -629,6 +629,22 @@ }, "query": "\n UPDATE compat_sessions cs\n SET finished_at = $2\n FROM compat_access_tokens ca\n WHERE ca.access_token = $1\n AND ca.compat_session_id = cs.compat_session_id\n AND cs.finished_at IS NULL\n RETURNING cs.compat_session_id\n " }, + "583ae9a0db9cd55fa57a179339550f3dab1bfc76f35ad488e1560ea37f7ed029": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Uuid", + "Text", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO oauth2_sessions\n ( oauth2_session_id\n , user_session_id\n , oauth2_client_id\n , scope\n , created_at\n )\n VALUES ($1, $2, $3, $4, $5)\n " + }, "5b5d5c82da37c6f2d8affacfb02119965c04d1f2a9cc53dbf5bd4c12584969a0": { "describe": { "columns": [], @@ -1325,19 +1341,6 @@ }, "query": "\n UPDATE compat_refresh_tokens\n SET consumed_at = $2\n WHERE compat_refresh_token_id = $1\n " }, - "9c1ef3114bfe22884d893bb11dc6054421c28cce4bd828cfe6a4ad46c062481a": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Timestamptz" - ] - } - }, - "query": "\n UPDATE oauth2_sessions\n SET finished_at = $2\n WHERE oauth2_session_id = $1\n " - }, "a300fe99c95679c5664646a6a525c0491829e97db45f3234483872ed38436322": { "describe": { "columns": [ @@ -1469,6 +1472,19 @@ }, "query": "\n INSERT INTO user_email_confirmation_codes\n (user_email_confirmation_code_id, user_email_id, code, created_at, expires_at)\n VALUES ($1, $2, $3, $4, $5)\n " }, + "b700dc3f7d0f86f4904725d8357e34b7e457f857ed37c467c314142877fd5367": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n UPDATE oauth2_sessions\n SET finished_at = $2\n WHERE oauth2_session_id = $1\n " + }, "b9875a270f7e753e48075ccae233df6e24a91775ceb877735508c1d5b2300d64": { "describe": { "columns": [], @@ -1484,21 +1500,6 @@ }, "query": "\n UPDATE upstream_oauth_authorization_sessions\n SET upstream_oauth_link_id = $1,\n completed_at = $2,\n id_token = $3\n WHERE upstream_oauth_authorization_session_id = $4\n " }, - "bc768c63a7737818967bc28560de714bbbd262bdf3ab73d297263bb73dcd9f5e": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Timestamptz", - "Uuid" - ] - } - }, - "query": "\n INSERT INTO oauth2_sessions\n (oauth2_session_id, user_session_id, oauth2_client_id, scope, created_at)\n SELECT\n $1,\n $2,\n og.oauth2_client_id,\n og.scope,\n $3\n FROM\n oauth2_authorization_grants og\n WHERE\n og.oauth2_authorization_grant_id = $4\n " - }, "bd1f6daa5fa1b10250c01f8b3fbe451646a9ceeefa6f72b9c4e29b6d05f17641": { "describe": { "columns": [], diff --git a/crates/storage/src/oauth2/access_token.rs b/crates/storage/src/oauth2/access_token.rs index c85d1b48..cadb93e0 100644 --- a/crates/storage/src/oauth2/access_token.rs +++ b/crates/storage/src/oauth2/access_token.rs @@ -134,6 +134,7 @@ pub async fn lookup_active_access_token( client_id: res.oauth2_client_id.into(), user_session_id: res.user_session_id.into(), scope, + finished_at: None, }; Ok(Some((access_token, session))) diff --git a/crates/storage/src/oauth2/authorization_grant.rs b/crates/storage/src/oauth2/authorization_grant.rs index 3a18ef41..29577d59 100644 --- a/crates/storage/src/oauth2/authorization_grant.rs +++ b/crates/storage/src/oauth2/authorization_grant.rs @@ -16,8 +16,7 @@ use std::num::NonZeroU32; use chrono::{DateTime, Utc}; use mas_data_model::{ - AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, BrowserSession, Client, Pkce, - Session, + AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Pkce, Session, }; use mas_iana::oauth::PkceCodeChallengeMethod; use oauth2_types::{requests::ResponseMode, scope::Scope}; @@ -27,7 +26,7 @@ use ulid::Ulid; use url::Url; use uuid::Uuid; -use super::client::OAuth2ClientRepository; +use super::OAuth2ClientRepository; use crate::{Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Repository}; #[tracing::instrument( @@ -186,6 +185,7 @@ impl GrantLookup { client_id: client.id, user_session_id: user_session_id.into(), scope, + finished_at: None, }; Some(session) @@ -431,59 +431,6 @@ pub async fn lookup_grant_by_code( Ok(Some(grant)) } -#[tracing::instrument( - skip_all, - fields( - %grant.id, - client.id = %grant.client.id, - session.id, - user_session.id = %browser_session.id, - user.id = %browser_session.user.id, - ), - err, -)] -pub async fn derive_session( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - grant: &AuthorizationGrant, - browser_session: BrowserSession, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record("session.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO oauth2_sessions - (oauth2_session_id, user_session_id, oauth2_client_id, scope, created_at) - SELECT - $1, - $2, - og.oauth2_client_id, - og.scope, - $3 - FROM - oauth2_authorization_grants og - WHERE - og.oauth2_authorization_grant_id = $4 - "#, - Uuid::from(id), - Uuid::from(browser_session.id), - created_at, - Uuid::from(grant.id), - ) - .execute(executor) - .await?; - - Ok(Session { - id, - user_session_id: browser_session.id, - client_id: grant.client.id, - scope: grant.scope.clone(), - }) -} - #[tracing::instrument( skip_all, fields( diff --git a/crates/storage/src/oauth2/mod.rs b/crates/storage/src/oauth2/mod.rs index 66313139..b02216a6 100644 --- a/crates/storage/src/oauth2/mod.rs +++ b/crates/storage/src/oauth2/mod.rs @@ -12,129 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -use mas_data_model::{Session, User}; -use sqlx::{PgConnection, PgExecutor, QueryBuilder}; -use tracing::{info_span, Instrument}; -use ulid::Ulid; -use uuid::Uuid; - -use crate::{ - pagination::{process_page, QueryBuilderExt}, - Clock, DatabaseError, DatabaseInconsistencyError, -}; - pub mod access_token; pub mod authorization_grant; -pub mod client; +mod client; pub mod consent; pub mod refresh_token; -pub mod session; +mod session; -#[tracing::instrument( - skip_all, - fields( - %session.id, - user_session.id = %session.user_session_id, - client.id = %session.client_id, - ), - err, -)] -pub async fn end_oauth_session( - executor: impl PgExecutor<'_>, - clock: &Clock, - session: Session, -) -> Result<(), DatabaseError> { - let finished_at = clock.now(); - let res = sqlx::query!( - r#" - UPDATE oauth2_sessions - SET finished_at = $2 - WHERE oauth2_session_id = $1 - "#, - Uuid::from(session.id), - finished_at, - ) - .execute(executor) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1) -} - -#[derive(sqlx::FromRow)] -struct OAuthSessionLookup { - oauth2_session_id: Uuid, - user_session_id: Uuid, - oauth2_client_id: Uuid, - scope: String, -} - -#[tracing::instrument( - skip_all, - fields( - %user.id, - %user.username, - ), - err, -)] -pub async fn get_paginated_user_oauth_sessions( - conn: &mut PgConnection, - user: &User, - before: Option, - after: Option, - first: Option, - last: Option, -) -> Result<(bool, bool, Vec), DatabaseError> { - let mut query = QueryBuilder::new( - r#" - SELECT - os.oauth2_session_id, - os.user_session_id, - os.oauth2_client_id, - os.scope, - os.created_at, - os.finished_at - FROM oauth2_sessions os - LEFT JOIN user_sessions us - USING (user_session_id) - "#, - ); - - query - .push(" WHERE us.user_id = ") - .push_bind(Uuid::from(user.id)) - .generate_pagination("oauth2_session_id", before, after, first, last)?; - - let span = info_span!( - "Fetch paginated user oauth sessions", - db.statement = query.sql() - ); - let page: Vec = query - .build_query_as() - .fetch_all(&mut *conn) - .instrument(span) - .await?; - - let (has_previous_page, has_next_page, page) = process_page(page, first, last)?; - - let page: Result, DatabaseInconsistencyError> = page - .into_iter() - .map(|item| { - let id = Ulid::from(item.oauth2_session_id); - let scope = item.scope.parse().map_err(|e| { - DatabaseInconsistencyError::on("oauth2_sessions") - .column("scope") - .row(id) - .source(e) - })?; - - Ok(Session { - id: Ulid::from(item.oauth2_session_id), - client_id: item.oauth2_client_id.into(), - user_session_id: item.user_session_id.into(), - scope, - }) - }) - .collect(); - - Ok((has_previous_page, has_next_page, page?)) -} +pub use self::{ + client::{OAuth2ClientRepository, PgOAuth2ClientRepository}, + session::{OAuth2SessionRepository, PgOAuth2SessionRepository}, +}; diff --git a/crates/storage/src/oauth2/refresh_token.rs b/crates/storage/src/oauth2/refresh_token.rs index bf223a79..61ace6fa 100644 --- a/crates/storage/src/oauth2/refresh_token.rs +++ b/crates/storage/src/oauth2/refresh_token.rs @@ -158,6 +158,7 @@ pub async fn lookup_active_refresh_token( client_id: res.oauth2_client_id.into(), user_session_id: res.user_session_id.into(), scope, + finished_at: None, }; Ok(Some((refresh_token, session))) diff --git a/crates/storage/src/oauth2/session.rs b/crates/storage/src/oauth2/session.rs index 71efb5bd..5841a1d9 100644 --- a/crates/storage/src/oauth2/session.rs +++ b/crates/storage/src/oauth2/session.rs @@ -13,8 +13,231 @@ // limitations under the License. use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{AuthorizationGrant, BrowserSession, Session, User}; +use rand::RngCore; +use sqlx::{PgConnection, QueryBuilder}; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{ + pagination::{process_page, Page, QueryBuilderExt}, + tracing::ExecuteExt, + Clock, DatabaseError, DatabaseInconsistencyError, +}; #[async_trait] pub trait OAuth2SessionRepository { type Error; + + async fn create_from_grant( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + grant: &AuthorizationGrant, + user_session: &BrowserSession, + ) -> Result; + + async fn finish(&mut self, clock: &Clock, session: Session) -> Result; + + async fn list_paginated( + &mut self, + user: &User, + before: Option, + after: Option, + first: Option, + last: Option, + ) -> Result, Self::Error>; +} + +pub struct PgOAuth2SessionRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgOAuth2SessionRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +#[derive(sqlx::FromRow)] +struct OAuthSessionLookup { + oauth2_session_id: Uuid, + user_session_id: Uuid, + oauth2_client_id: Uuid, + scope: String, + finished_at: Option>, +} + +impl TryFrom for Session { + type Error = DatabaseInconsistencyError; + + fn try_from(value: OAuthSessionLookup) -> Result { + let id = Ulid::from(value.oauth2_session_id); + let scope = value.scope.parse().map_err(|e| { + DatabaseInconsistencyError::on("oauth2_sessions") + .column("scope") + .row(id) + .source(e) + })?; + + Ok(Session { + id, + client_id: value.oauth2_client_id.into(), + user_session_id: value.user_session_id.into(), + scope, + finished_at: value.finished_at, + }) + } +} + +#[async_trait] +impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.oauth2_session.create_from_grant", + skip_all, + fields( + db.statement, + %user_session.id, + user.id = %user_session.user.id, + %grant.id, + client.id = %grant.client.id, + session.id, + session.scope = %grant.scope, + ), + err, + )] + async fn create_from_grant( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + grant: &AuthorizationGrant, + user_session: &BrowserSession, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("session.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO oauth2_sessions + ( oauth2_session_id + , user_session_id + , oauth2_client_id + , scope + , created_at + ) + VALUES ($1, $2, $3, $4, $5) + "#, + Uuid::from(id), + Uuid::from(user_session.id), + Uuid::from(grant.client.id), + grant.scope.to_string(), + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(Session { + id, + user_session_id: user_session.id, + client_id: grant.client.id, + scope: grant.scope.clone(), + finished_at: None, + }) + } + + #[tracing::instrument( + name = "db.oauth2_session.finish", + skip_all, + fields( + db.statement, + %session.id, + %session.scope, + user_session.id = %session.user_session_id, + client.id = %session.client_id, + ), + err, + )] + async fn finish( + &mut self, + clock: &Clock, + mut session: Session, + ) -> Result { + let finished_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE oauth2_sessions + SET finished_at = $2 + WHERE oauth2_session_id = $1 + "#, + Uuid::from(session.id), + finished_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + session.finished_at = Some(finished_at); + + Ok(session) + } + + #[tracing::instrument( + name = "db.oauth2_session.list_paginated", + skip_all, + fields( + db.statement, + %user.id, + %user.username, + ), + err, + )] + async fn list_paginated( + &mut self, + user: &User, + before: Option, + after: Option, + first: Option, + last: Option, + ) -> Result, Self::Error> { + let mut query = QueryBuilder::new( + r#" + SELECT oauth2_session_id + , user_session_id + , oauth2_client_id + , scope + , created_at + , finished_at + FROM oauth2_sessions os + "#, + ); + + query + .push(" WHERE us.user_id = ") + .push_bind(Uuid::from(user.id)) + .generate_pagination("oauth2_session_id", before, after, first, last)?; + + let edges: Vec = query + .build_query_as() + .traced() + .fetch_all(&mut *self.conn) + .await?; + + let (has_previous_page, has_next_page, edges) = process_page(edges, first, last)?; + + let edges: Result, DatabaseInconsistencyError> = + edges.into_iter().map(Session::try_from).collect(); + + Ok(Page { + has_next_page, + has_previous_page, + edges: edges?, + }) + } } diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs index 4bca2253..8eda5701 100644 --- a/crates/storage/src/repository.rs +++ b/crates/storage/src/repository.rs @@ -15,7 +15,7 @@ use sqlx::{PgConnection, Postgres, Transaction}; use crate::{ - oauth2::client::PgOAuth2ClientRepository, + oauth2::{PgOAuth2ClientRepository, PgOAuth2SessionRepository}, upstream_oauth2::{ PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository, PgUpstreamOAuthSessionRepository, @@ -59,6 +59,10 @@ pub trait Repository { where Self: 'c; + type OAuth2SessionRepository<'c> + where + Self: 'c; + fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_>; fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_>; fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_>; @@ -67,6 +71,7 @@ pub trait Repository { fn user_password(&mut self) -> Self::UserPasswordRepository<'_>; fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_>; fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_>; + fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_>; } impl Repository for PgConnection { @@ -78,6 +83,7 @@ impl Repository for PgConnection { type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c; type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c; type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c; + type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c; fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { PgUpstreamOAuthLinkRepository::new(self) @@ -110,6 +116,10 @@ impl Repository for PgConnection { fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_> { PgOAuth2ClientRepository::new(self) } + + fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> { + PgOAuth2SessionRepository::new(self) + } } impl<'t> Repository for Transaction<'t, Postgres> { @@ -121,6 +131,7 @@ impl<'t> Repository for Transaction<'t, Postgres> { type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c; type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c; type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c; + type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c; fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { PgUpstreamOAuthLinkRepository::new(self) @@ -153,4 +164,8 @@ impl<'t> Repository for Transaction<'t, Postgres> { fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_> { PgOAuth2ClientRepository::new(self) } + + fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> { + PgOAuth2SessionRepository::new(self) + } }