From fb7c6f4dd181ac11d7df9196662a5efbc3d65899 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 5 Jan 2023 16:49:19 +0100 Subject: [PATCH] storage: do less joins on authorization grants and refresh tokens --- .../src/oauth2/authorization_grant.rs | 12 +- crates/data-model/src/tokens.rs | 2 +- crates/handlers/src/oauth2/token.rs | 58 ++++---- crates/storage/src/oauth2/access_token.rs | 6 +- .../storage/src/oauth2/authorization_grant.rs | 134 +++++++----------- crates/storage/src/oauth2/refresh_token.rs | 36 +---- crates/storage/src/oauth2/session.rs | 41 +++++- 7 files changed, 140 insertions(+), 149 deletions(-) diff --git a/crates/data-model/src/oauth2/authorization_grant.rs b/crates/data-model/src/oauth2/authorization_grant.rs index cb85a265..a7222cda 100644 --- a/crates/data-model/src/oauth2/authorization_grant.rs +++ b/crates/data-model/src/oauth2/authorization_grant.rs @@ -63,11 +63,11 @@ pub enum AuthorizationGrantStage { #[default] Pending, Fulfilled { - session: Session, + session_id: Ulid, fulfilled_at: DateTime, }, Exchanged { - session: Session, + session_id: Ulid, fulfilled_at: DateTime, exchanged_at: DateTime, }, @@ -85,12 +85,12 @@ impl AuthorizationGrantStage { pub fn fulfill( self, fulfilled_at: DateTime, - session: Session, + session: &Session, ) -> Result { match self { Self::Pending => Ok(Self::Fulfilled { fulfilled_at, - session, + session_id: session.id, }), _ => Err(InvalidTransitionError), } @@ -100,11 +100,11 @@ impl AuthorizationGrantStage { match self { Self::Fulfilled { fulfilled_at, - session, + session_id, } => Ok(Self::Exchanged { fulfilled_at, exchanged_at, - session, + session_id, }), _ => Err(InvalidTransitionError), } diff --git a/crates/data-model/src/tokens.rs b/crates/data-model/src/tokens.rs index 93b29f6d..7b058820 100644 --- a/crates/data-model/src/tokens.rs +++ b/crates/data-model/src/tokens.rs @@ -33,7 +33,7 @@ pub struct RefreshToken { pub id: Ulid, pub refresh_token: String, pub created_at: DateTime, - pub access_token: Option, + pub access_token_id: Option, } /// Type of token to generate or validate diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index 391bdde2..eb0e20dd 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -109,12 +109,18 @@ pub(crate) enum RouteError { #[error("failed to load browser session")] NoSuchBrowserSession, + + #[error("failed to load oauth session")] + NoSuchOAuthSession, } impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { match self { - Self::Internal(_) | Self::InvalidSigningKey | Self::NoSuchBrowserSession => ( + Self::Internal(_) + | Self::InvalidSigningKey + | Self::NoSuchBrowserSession + | Self::NoSuchOAuthSession => ( StatusCode::INTERNAL_SERVER_ERROR, Json(ClientError::from(ClientErrorCode::ServerError)), ), @@ -219,7 +225,7 @@ async fn authorization_code_grant( let now = clock.now(); - let session = match authz_grant.stage { + let session_id = match authz_grant.stage { AuthorizationGrantStage::Cancelled { cancelled_at } => { debug!(%cancelled_at, "Authorization grant was cancelled"); return Err(RouteError::InvalidGrant); @@ -227,13 +233,18 @@ async fn authorization_code_grant( AuthorizationGrantStage::Exchanged { exchanged_at, fulfilled_at, - session, + session_id, } => { debug!(%exchanged_at, %fulfilled_at, "Authorization code was already exchanged"); // Ending the session if the token was already exchanged more than 20s ago if now - exchanged_at > Duration::seconds(20) { debug!("Ending potentially compromised session"); + let session = txn + .oauth2_session() + .lookup(session_id) + .await? + .ok_or(RouteError::NoSuchOAuthSession)?; txn.oauth2_session().finish(&clock, session).await?; txn.commit().await?; } @@ -245,7 +256,7 @@ async fn authorization_code_grant( return Err(RouteError::InvalidGrant); } AuthorizationGrantStage::Fulfilled { - ref session, + session_id, fulfilled_at, } => { if now - fulfilled_at > Duration::minutes(10) { @@ -253,10 +264,16 @@ async fn authorization_code_grant( return Err(RouteError::InvalidGrant); } - session + session_id } }; + let session = txn + .oauth2_session() + .lookup(session_id) + .await? + .ok_or(RouteError::NoSuchOAuthSession)?; + // This should never happen, since we looked up in the database using the code let code = authz_grant.code.as_ref().ok_or(RouteError::InvalidGrant)?; @@ -284,23 +301,16 @@ async fn authorization_code_grant( let access_token_str = TokenType::AccessToken.generate(&mut rng); let refresh_token_str = TokenType::RefreshToken.generate(&mut rng); - let access_token = add_access_token( - &mut txn, - &mut rng, - &clock, - session, - access_token_str.clone(), - ttl, - ) - .await?; + let access_token = + add_access_token(&mut txn, &mut rng, &clock, &session, access_token_str, ttl).await?; - let _refresh_token = add_refresh_token( + let refresh_token = add_refresh_token( &mut txn, &mut rng, &clock, - session, - access_token, - refresh_token_str.clone(), + &session, + &access_token, + refresh_token_str, ) .await?; @@ -328,7 +338,7 @@ async fn authorization_code_grant( .signing_key_for_algorithm(&alg) .ok_or(RouteError::InvalidSigningKey)?; - claims::AT_HASH.insert(&mut claims, hash_token(&alg, &access_token_str)?)?; + claims::AT_HASH.insert(&mut claims, hash_token(&alg, &access_token.access_token)?)?; claims::C_HASH.insert(&mut claims, hash_token(&alg, &grant.code)?)?; let signer = key.params().signing_key_for_alg(&alg)?; @@ -341,9 +351,9 @@ async fn authorization_code_grant( None }; - let mut params = AccessTokenResponse::new(access_token_str) + let mut params = AccessTokenResponse::new(access_token.access_token) .with_expires_in(ttl) - .with_refresh_token(refresh_token_str) + .with_refresh_token(refresh_token.refresh_token) .with_scope(session.scope.clone()); if let Some(id_token) = id_token { @@ -392,15 +402,15 @@ async fn refresh_token_grant( &mut rng, &clock, &session, - new_access_token, + &new_access_token, refresh_token_str, ) .await?; consume_refresh_token(&mut txn, &clock, &refresh_token).await?; - if let Some(access_token) = refresh_token.access_token { - revoke_access_token(&mut txn, &clock, access_token).await?; + if let Some(access_token_id) = refresh_token.access_token_id { + revoke_access_token(&mut txn, &clock, access_token_id).await?; } let params = AccessTokenResponse::new(access_token_str) diff --git a/crates/storage/src/oauth2/access_token.rs b/crates/storage/src/oauth2/access_token.rs index cadb93e0..cd4cafbf 100644 --- a/crates/storage/src/oauth2/access_token.rs +++ b/crates/storage/src/oauth2/access_token.rs @@ -142,13 +142,13 @@ pub async fn lookup_active_access_token( #[tracing::instrument( skip_all, - fields(%access_token.id), + fields(access_token.id = %access_token_id), err, )] pub async fn revoke_access_token( executor: impl PgExecutor<'_>, clock: &Clock, - access_token: AccessToken, + access_token_id: Ulid, ) -> Result<(), DatabaseError> { let revoked_at = clock.now(); let res = sqlx::query!( @@ -157,7 +157,7 @@ pub async fn revoke_access_token( SET revoked_at = $2 WHERE oauth2_access_token_id = $1 "#, - Uuid::from(access_token.id), + Uuid::from(access_token_id), revoked_at, ) .execute(executor) diff --git a/crates/storage/src/oauth2/authorization_grant.rs b/crates/storage/src/oauth2/authorization_grant.rs index 29577d59..33bd8b5d 100644 --- a/crates/storage/src/oauth2/authorization_grant.rs +++ b/crates/storage/src/oauth2/authorization_grant.rs @@ -149,7 +149,6 @@ struct GrantLookup { oauth2_authorization_grant_requires_consent: bool, oauth2_client_id: Uuid, oauth2_session_id: Option, - user_session_id: Option, } impl GrantLookup { @@ -176,45 +175,22 @@ impl GrantLookup { .row(id) })?; - let session = match (self.oauth2_session_id, self.user_session_id) { - (Some(session_id), Some(user_session_id)) => { - let scope = scope.clone(); - - let session = Session { - id: session_id.into(), - client_id: client.id, - user_session_id: user_session_id.into(), - scope, - finished_at: None, - }; - - Some(session) - } - (None, None) => None, - _ => { - return Err( - DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("oauth2_session_id") - .row(id) - .into(), - ) - } - }; - let stage = match ( self.oauth2_authorization_grant_fulfilled_at, self.oauth2_authorization_grant_exchanged_at, self.oauth2_authorization_grant_cancelled_at, - session, + self.oauth2_session_id, ) { (None, None, None, None) => AuthorizationGrantStage::Pending, - (Some(fulfilled_at), None, None, Some(session)) => AuthorizationGrantStage::Fulfilled { - session, - fulfilled_at, - }, - (Some(fulfilled_at), Some(exchanged_at), None, Some(session)) => { + (Some(fulfilled_at), None, None, Some(session_id)) => { + AuthorizationGrantStage::Fulfilled { + session_id: session_id.into(), + fulfilled_at, + } + } + (Some(fulfilled_at), Some(exchanged_at), None, Some(session_id)) => { AuthorizationGrantStage::Exchanged { - session, + session_id: session_id.into(), fulfilled_at, exchanged_at, } @@ -343,32 +319,29 @@ pub async fn get_grant_by_id( let res = sqlx::query_as!( GrantLookup, r#" - SELECT og.oauth2_authorization_grant_id - , og.created_at AS oauth2_authorization_grant_created_at - , og.cancelled_at AS oauth2_authorization_grant_cancelled_at - , og.fulfilled_at AS oauth2_authorization_grant_fulfilled_at - , og.exchanged_at AS oauth2_authorization_grant_exchanged_at - , og.scope AS oauth2_authorization_grant_scope - , og.state AS oauth2_authorization_grant_state - , og.redirect_uri AS oauth2_authorization_grant_redirect_uri - , og.response_mode AS oauth2_authorization_grant_response_mode - , og.nonce AS oauth2_authorization_grant_nonce - , og.max_age AS oauth2_authorization_grant_max_age - , og.oauth2_client_id AS oauth2_client_id - , og.authorization_code AS oauth2_authorization_grant_code - , og.response_type_code AS oauth2_authorization_grant_response_type_code - , og.response_type_id_token AS oauth2_authorization_grant_response_type_id_token - , og.code_challenge AS oauth2_authorization_grant_code_challenge - , og.code_challenge_method AS oauth2_authorization_grant_code_challenge_method - , og.requires_consent AS oauth2_authorization_grant_requires_consent - , os.oauth2_session_id AS "oauth2_session_id?" - , os.user_session_id AS "user_session_id?" + SELECT oauth2_authorization_grant_id + , created_at AS oauth2_authorization_grant_created_at + , cancelled_at AS oauth2_authorization_grant_cancelled_at + , fulfilled_at AS oauth2_authorization_grant_fulfilled_at + , exchanged_at AS oauth2_authorization_grant_exchanged_at + , scope AS oauth2_authorization_grant_scope + , state AS oauth2_authorization_grant_state + , redirect_uri AS oauth2_authorization_grant_redirect_uri + , response_mode AS oauth2_authorization_grant_response_mode + , nonce AS oauth2_authorization_grant_nonce + , max_age AS oauth2_authorization_grant_max_age + , oauth2_client_id AS oauth2_client_id + , authorization_code AS oauth2_authorization_grant_code + , response_type_code AS oauth2_authorization_grant_response_type_code + , response_type_id_token AS oauth2_authorization_grant_response_type_id_token + , code_challenge AS oauth2_authorization_grant_code_challenge + , code_challenge_method AS oauth2_authorization_grant_code_challenge_method + , requires_consent AS oauth2_authorization_grant_requires_consent + , oauth2_session_id AS "oauth2_session_id?" FROM - oauth2_authorization_grants og - LEFT JOIN oauth2_sessions os - USING (oauth2_session_id) + oauth2_authorization_grants - WHERE og.oauth2_authorization_grant_id = $1 + WHERE oauth2_authorization_grant_id = $1 "#, Uuid::from(id), ) @@ -391,32 +364,29 @@ pub async fn lookup_grant_by_code( let res = sqlx::query_as!( GrantLookup, r#" - SELECT og.oauth2_authorization_grant_id - , og.created_at AS oauth2_authorization_grant_created_at - , og.cancelled_at AS oauth2_authorization_grant_cancelled_at - , og.fulfilled_at AS oauth2_authorization_grant_fulfilled_at - , og.exchanged_at AS oauth2_authorization_grant_exchanged_at - , og.scope AS oauth2_authorization_grant_scope - , og.state AS oauth2_authorization_grant_state - , og.redirect_uri AS oauth2_authorization_grant_redirect_uri - , og.response_mode AS oauth2_authorization_grant_response_mode - , og.nonce AS oauth2_authorization_grant_nonce - , og.max_age AS oauth2_authorization_grant_max_age - , og.oauth2_client_id AS oauth2_client_id - , og.authorization_code AS oauth2_authorization_grant_code - , og.response_type_code AS oauth2_authorization_grant_response_type_code - , og.response_type_id_token AS oauth2_authorization_grant_response_type_id_token - , og.code_challenge AS oauth2_authorization_grant_code_challenge - , og.code_challenge_method AS oauth2_authorization_grant_code_challenge_method - , og.requires_consent AS oauth2_authorization_grant_requires_consent - , os.oauth2_session_id AS "oauth2_session_id?" - , os.user_session_id AS "user_session_id?" + SELECT oauth2_authorization_grant_id + , created_at AS oauth2_authorization_grant_created_at + , cancelled_at AS oauth2_authorization_grant_cancelled_at + , fulfilled_at AS oauth2_authorization_grant_fulfilled_at + , exchanged_at AS oauth2_authorization_grant_exchanged_at + , scope AS oauth2_authorization_grant_scope + , state AS oauth2_authorization_grant_state + , redirect_uri AS oauth2_authorization_grant_redirect_uri + , response_mode AS oauth2_authorization_grant_response_mode + , nonce AS oauth2_authorization_grant_nonce + , max_age AS oauth2_authorization_grant_max_age + , oauth2_client_id AS oauth2_client_id + , authorization_code AS oauth2_authorization_grant_code + , response_type_code AS oauth2_authorization_grant_response_type_code + , response_type_id_token AS oauth2_authorization_grant_response_type_id_token + , code_challenge AS oauth2_authorization_grant_code_challenge + , code_challenge_method AS oauth2_authorization_grant_code_challenge_method + , requires_consent AS oauth2_authorization_grant_requires_consent + , oauth2_session_id AS "oauth2_session_id?" FROM - oauth2_authorization_grants og - LEFT JOIN oauth2_sessions os - USING (oauth2_session_id) + oauth2_authorization_grants - WHERE og.authorization_code = $1 + WHERE authorization_code = $1 "#, code, ) @@ -466,7 +436,7 @@ pub async fn fulfill_grant( grant.stage = grant .stage - .fulfill(fulfilled_at, session) + .fulfill(fulfilled_at, &session) .map_err(DatabaseError::to_invalid_operation)?; Ok(grant) diff --git a/crates/storage/src/oauth2/refresh_token.rs b/crates/storage/src/oauth2/refresh_token.rs index 61ace6fa..e4c35c71 100644 --- a/crates/storage/src/oauth2/refresh_token.rs +++ b/crates/storage/src/oauth2/refresh_token.rs @@ -36,7 +36,7 @@ pub async fn add_refresh_token( mut rng: impl Rng + Send, clock: &Clock, session: &Session, - access_token: AccessToken, + access_token: &AccessToken, refresh_token: String, ) -> Result { let created_at = clock.now(); @@ -63,7 +63,7 @@ pub async fn add_refresh_token( Ok(RefreshToken { id, refresh_token, - access_token: Some(access_token), + access_token_id: Some(access_token.id), created_at, }) } @@ -73,9 +73,6 @@ struct OAuth2RefreshTokenLookup { oauth2_refresh_token: String, oauth2_refresh_token_created_at: DateTime, oauth2_access_token_id: Option, - oauth2_access_token: Option, - oauth2_access_token_created_at: Option>, - oauth2_access_token_expires_at: Option>, oauth2_session_id: Uuid, oauth2_client_id: Uuid, oauth2_session_scope: String, @@ -94,10 +91,7 @@ pub async fn lookup_active_refresh_token( SELECT rt.oauth2_refresh_token_id , rt.refresh_token AS oauth2_refresh_token , rt.created_at AS oauth2_refresh_token_created_at - , at.oauth2_access_token_id AS "oauth2_access_token_id?" - , at.access_token AS "oauth2_access_token?" - , at.created_at AS "oauth2_access_token_created_at?" - , at.expires_at AS "oauth2_access_token_expires_at?" + , rt.oauth2_access_token_id AS "oauth2_access_token_id?" , os.oauth2_session_id AS "oauth2_session_id!" , os.oauth2_client_id AS "oauth2_client_id!" , os.scope AS "oauth2_session_scope!" @@ -105,8 +99,6 @@ pub async fn lookup_active_refresh_token( FROM oauth2_refresh_tokens rt INNER JOIN oauth2_sessions os USING (oauth2_session_id) - LEFT JOIN oauth2_access_tokens at - USING (oauth2_access_token_id) WHERE rt.refresh_token = $1 AND rt.consumed_at IS NULL @@ -118,31 +110,11 @@ pub async fn lookup_active_refresh_token( .fetch_one(&mut *conn) .await?; - let access_token = match ( - res.oauth2_access_token_id, - res.oauth2_access_token, - res.oauth2_access_token_created_at, - res.oauth2_access_token_expires_at, - ) { - (None, None, None, None) => None, - (Some(id), Some(access_token), Some(created_at), Some(expires_at)) => { - let id = Ulid::from(id); - Some(AccessToken { - id, - jti: id.to_string(), - access_token, - created_at, - expires_at, - }) - } - _ => return Err(DatabaseInconsistencyError::on("oauth2_access_tokens").into()), - }; - let refresh_token = RefreshToken { id: res.oauth2_refresh_token_id.into(), refresh_token: res.oauth2_refresh_token, created_at: res.oauth2_refresh_token_created_at, - access_token, + access_token_id: res.oauth2_access_token_id.map(Ulid::from), }; let session_id = res.oauth2_session_id.into(); diff --git a/crates/storage/src/oauth2/session.rs b/crates/storage/src/oauth2/session.rs index 5841a1d9..7acaf843 100644 --- a/crates/storage/src/oauth2/session.rs +++ b/crates/storage/src/oauth2/session.rs @@ -23,13 +23,15 @@ use uuid::Uuid; use crate::{ pagination::{process_page, Page, QueryBuilderExt}, tracing::ExecuteExt, - Clock, DatabaseError, DatabaseInconsistencyError, + Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, }; #[async_trait] pub trait OAuth2SessionRepository { type Error; + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + async fn create_from_grant( &mut self, rng: &mut (dyn RngCore + Send), @@ -66,6 +68,8 @@ struct OAuthSessionLookup { user_session_id: Uuid, oauth2_client_id: Uuid, scope: String, + #[allow(dead_code)] + created_at: DateTime, finished_at: Option>, } @@ -95,6 +99,41 @@ impl TryFrom for Session { impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { type Error = DatabaseError; + #[tracing::instrument( + name = "db.oauth2_session.lookup", + skip_all, + fields( + db.statement, + session.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + OAuthSessionLookup, + r#" + SELECT oauth2_session_id + , user_session_id + , oauth2_client_id + , scope + , created_at + , finished_at + FROM oauth2_sessions + + WHERE oauth2_session_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(session) = res else { return Ok(None) }; + + Ok(Some(session.try_into()?)) + } + #[tracing::instrument( name = "db.oauth2_session.create_from_grant", skip_all,