diff --git a/crates/axum-utils/src/user_authorization.rs b/crates/axum-utils/src/user_authorization.rs index 923ef34d..a76e1e9a 100644 --- a/crates/axum-utils/src/user_authorization.rs +++ b/crates/axum-utils/src/user_authorization.rs @@ -24,10 +24,14 @@ use axum::{ response::{IntoResponse, Response}, BoxError, }; +use chrono::{DateTime, Utc}; use headers::{authorization::Bearer, Authorization, Header, HeaderMapExt, HeaderName}; use http::{header::WWW_AUTHENTICATE, HeaderMap, HeaderValue, Request, StatusCode}; use mas_data_model::Session; -use mas_storage::{oauth2::access_token::lookup_active_access_token, DatabaseError}; +use mas_storage::{ + oauth2::{access_token::find_access_token, OAuth2SessionRepository}, + DatabaseError, Repository, +}; use serde::{de::DeserializeOwned, Deserialize}; use sqlx::PgConnection; use thiserror::Error; @@ -49,7 +53,7 @@ enum AccessToken { } impl AccessToken { - pub async fn fetch( + async fn fetch( &self, conn: &mut PgConnection, ) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError> { @@ -58,7 +62,13 @@ impl AccessToken { AccessToken::None => return Err(AuthorizationVerificationError::MissingToken), }; - let (token, session) = lookup_active_access_token(conn, token.as_str()) + let token = find_access_token(conn, token.as_str()) + .await? + .ok_or(AuthorizationVerificationError::InvalidToken)?; + + let session = conn + .oauth2_session() + .lookup(token.session_id) .await? .ok_or(AuthorizationVerificationError::InvalidToken)?; @@ -77,13 +87,18 @@ impl UserAuthorization { pub async fn protected_form( self, conn: &mut PgConnection, + now: DateTime, ) -> Result<(Session, F), AuthorizationVerificationError> { let form = match self.form { Some(f) => f, None => return Err(AuthorizationVerificationError::MissingForm), }; - let (_token, session) = self.access_token.fetch(conn).await?; + let (token, session) = self.access_token.fetch(conn).await?; + + if !token.is_valid(now) || !session.is_valid() { + return Err(AuthorizationVerificationError::InvalidToken); + } Ok((session, form)) } @@ -92,8 +107,13 @@ impl UserAuthorization { pub async fn protected( self, conn: &mut PgConnection, + now: DateTime, ) -> Result { - let (_token, session) = self.access_token.fetch(conn).await?; + let (token, session) = self.access_token.fetch(conn).await?; + + if !token.is_valid(now) || !session.is_valid() { + return Err(AuthorizationVerificationError::InvalidToken); + } Ok(session) } diff --git a/crates/data-model/src/lib.rs b/crates/data-model/src/lib.rs index 8454f05d..bde11fbe 100644 --- a/crates/data-model/src/lib.rs +++ b/crates/data-model/src/lib.rs @@ -44,7 +44,9 @@ pub use self::{ AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, InvalidRedirectUriError, JwksOrJwksUri, Pkce, Session, SessionState, }, - tokens::{AccessToken, RefreshToken, TokenFormatError, TokenType}, + tokens::{ + AccessToken, AccessTokenState, RefreshToken, RefreshTokenState, TokenFormatError, TokenType, + }, upstream_oauth2::{ UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState, UpstreamOAuthLink, UpstreamOAuthProvider, diff --git a/crates/data-model/src/tokens.rs b/crates/data-model/src/tokens.rs index 7b058820..120f293e 100644 --- a/crates/data-model/src/tokens.rs +++ b/crates/data-model/src/tokens.rs @@ -19,23 +19,133 @@ use rand::{distributions::Alphanumeric, Rng}; use thiserror::Error; use ulid::Ulid; +use crate::InvalidTransitionError; + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub enum AccessTokenState { + #[default] + Valid, + Revoked { + revoked_at: DateTime, + }, +} + +impl AccessTokenState { + fn revoke(self, revoked_at: DateTime) -> Result { + match self { + Self::Valid => Ok(Self::Revoked { revoked_at }), + Self::Revoked { .. } => Err(InvalidTransitionError), + } + } + + /// Returns `true` if the refresh token state is [`Valid`]. + /// + /// [`Valid`]: RefreshTokenState::Valid + #[must_use] + pub fn is_valid(&self) -> bool { + matches!(self, Self::Valid) + } + + /// Returns `true` if the refresh token state is [`Revoked`]. + /// + /// [`Revoked`]: RefreshTokenState::Revoked + #[must_use] + pub fn is_revoked(&self) -> bool { + matches!(self, Self::Revoked { .. }) + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct AccessToken { pub id: Ulid, - pub jti: String, + pub state: AccessTokenState, + pub session_id: Ulid, pub access_token: String, pub created_at: DateTime, pub expires_at: DateTime, } +impl AccessToken { + #[must_use] + pub fn jti(&self) -> String { + self.id.to_string() + } + + #[must_use] + pub fn is_valid(&self, now: DateTime) -> bool { + self.state.is_valid() && self.expires_at > now + } + + pub fn revoke(mut self, revoked_at: DateTime) -> Result { + self.state = self.state.revoke(revoked_at)?; + Ok(self) + } +} + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub enum RefreshTokenState { + #[default] + Valid, + Consumed { + consumed_at: DateTime, + }, +} + +impl RefreshTokenState { + fn consume(self, consumed_at: DateTime) -> Result { + match self { + Self::Valid => Ok(Self::Consumed { consumed_at }), + Self::Consumed { .. } => Err(InvalidTransitionError), + } + } + + /// Returns `true` if the refresh token state is [`Valid`]. + /// + /// [`Valid`]: RefreshTokenState::Valid + #[must_use] + pub fn is_valid(&self) -> bool { + matches!(self, Self::Valid) + } + + /// Returns `true` if the refresh token state is [`Consumed`]. + /// + /// [`Consumed`]: RefreshTokenState::Consumed + #[must_use] + pub fn is_consumed(&self) -> bool { + matches!(self, Self::Consumed { .. }) + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct RefreshToken { pub id: Ulid, + pub state: RefreshTokenState, pub refresh_token: String, + pub session_id: Ulid, pub created_at: DateTime, pub access_token_id: Option, } +impl std::ops::Deref for RefreshToken { + type Target = RefreshTokenState; + + fn deref(&self) -> &Self::Target { + &self.state + } +} + +impl RefreshToken { + #[must_use] + pub fn jti(&self) -> String { + self.id.to_string() + } + + pub fn consume(mut self, consumed_at: DateTime) -> Result { + self.state = self.state.consume(consumed_at)?; + Ok(self) + } +} + /// Type of token to generate or validate #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum TokenType { diff --git a/crates/handlers/src/oauth2/introspection.rs b/crates/handlers/src/oauth2/introspection.rs index 2cf34c97..3dec02db 100644 --- a/crates/handlers/src/oauth2/introspection.rs +++ b/crates/handlers/src/oauth2/introspection.rs @@ -24,7 +24,8 @@ use mas_keystore::Encrypter; use mas_storage::{ compat::{find_compat_access_token, find_compat_refresh_token, lookup_compat_session}, oauth2::{ - access_token::lookup_active_access_token, refresh_token::lookup_active_refresh_token, + access_token::find_access_token, refresh_token::lookup_refresh_token, + OAuth2SessionRepository, }, user::{BrowserSessionRepository, UserRepository}, Clock, Repository, @@ -168,8 +169,17 @@ pub(crate) async fn post( let reply = match token_type { TokenType::AccessToken => { - let (token, session) = lookup_active_access_token(&mut conn, token) + let token = find_access_token(&mut conn, token) .await? + .filter(|t| t.is_valid(clock.now())) + .ok_or(RouteError::UnknownToken)?; + + let session = conn + .oauth2_session() + .lookup(token.session_id) + .await? + .filter(|s| s.is_valid()) + // XXX: is that the right error to bubble up? .ok_or(RouteError::UnknownToken)?; let browser_session = conn @@ -191,13 +201,22 @@ pub(crate) async fn post( sub: Some(browser_session.user.sub), aud: None, iss: None, - jti: None, + jti: Some(token.jti()), } } TokenType::RefreshToken => { - let (token, session) = lookup_active_refresh_token(&mut conn, token) + let token = lookup_refresh_token(&mut conn, token) .await? + .filter(|t| t.is_valid()) + .ok_or(RouteError::UnknownToken)?; + + let session = conn + .oauth2_session() + .lookup(token.session_id) + .await? + .filter(|s| s.is_valid()) + // XXX: is that the right error to bubble up? .ok_or(RouteError::UnknownToken)?; let browser_session = conn @@ -219,7 +238,7 @@ pub(crate) async fn post( sub: Some(browser_session.user.sub), aud: None, iss: None, - jti: None, + jti: Some(token.jti()), } } diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index eb0e20dd..75ddb4a6 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -33,9 +33,9 @@ use mas_keystore::{Encrypter, Keystore}; use mas_router::UrlBuilder; use mas_storage::{ oauth2::{ - access_token::{add_access_token, revoke_access_token}, + access_token::{add_access_token, lookup_access_token, revoke_access_token}, authorization_grant::{exchange_grant, lookup_grant_by_code}, - refresh_token::{add_refresh_token, consume_refresh_token, lookup_active_refresh_token}, + refresh_token::{add_refresh_token, consume_refresh_token, lookup_refresh_token}, OAuth2SessionRepository, }, user::BrowserSessionRepository, @@ -374,10 +374,20 @@ async fn refresh_token_grant( ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let (refresh_token, session) = lookup_active_refresh_token(&mut txn, &grant.refresh_token) + let refresh_token = lookup_refresh_token(&mut txn, &grant.refresh_token) .await? .ok_or(RouteError::InvalidGrant)?; + let session = txn + .oauth2_session() + .lookup(refresh_token.session_id) + .await? + .ok_or(RouteError::NoSuchOAuthSession)?; + + if !refresh_token.is_valid() || !session.is_valid() { + return Err(RouteError::InvalidGrant); + } + if client.id != session.client_id { // As per https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 return Err(RouteError::InvalidGrant); @@ -407,10 +417,12 @@ async fn refresh_token_grant( ) .await?; - consume_refresh_token(&mut txn, &clock, &refresh_token).await?; + let refresh_token = consume_refresh_token(&mut txn, &clock, refresh_token).await?; if let Some(access_token_id) = refresh_token.access_token_id { - revoke_access_token(&mut txn, &clock, access_token_id).await?; + if let Some(access_token) = lookup_access_token(&mut txn, access_token_id).await? { + revoke_access_token(&mut txn, &clock, access_token).await?; + } } let params = AccessTokenResponse::new(access_token_str) diff --git a/crates/handlers/src/oauth2/userinfo.rs b/crates/handlers/src/oauth2/userinfo.rs index d2b2b615..49b6c5f1 100644 --- a/crates/handlers/src/oauth2/userinfo.rs +++ b/crates/handlers/src/oauth2/userinfo.rs @@ -101,10 +101,10 @@ pub async fn get( State(key_store): State, user_authorization: UserAuthorization, ) -> Result { - let (_clock, mut rng) = crate::clock_and_rng(); + let (clock, mut rng) = crate::clock_and_rng(); let mut conn = pool.acquire().await?; - let session = user_authorization.protected(&mut conn).await?; + let session = user_authorization.protected(&mut conn, clock.now()).await?; let browser_session = conn .browser_session() diff --git a/crates/storage/sqlx-data.json b/crates/storage/sqlx-data.json index d31c59a3..5324fa2a 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage/sqlx-data.json @@ -583,74 +583,6 @@ }, "query": "\n DELETE FROM oauth2_access_tokens\n WHERE expires_at < $1\n " }, - "5f0e2aec0d7766d3674af3e68417921fec7068e83845e218a4a00d86487557f9": { - "describe": { - "columns": [ - { - "name": "oauth2_access_token_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "oauth2_access_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "oauth2_access_token_created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_access_token_expires_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_session_created_at", - "ordinal": 4, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_session_id!", - "ordinal": 5, - "type_info": "Uuid" - }, - { - "name": "oauth2_client_id!", - "ordinal": 6, - "type_info": "Uuid" - }, - { - "name": "scope!", - "ordinal": 7, - "type_info": "Text" - }, - { - "name": "user_session_id!", - "ordinal": 8, - "type_info": "Uuid" - } - ], - "nullable": [ - false, - false, - false, - false, - false, - false, - false, - false, - false - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT at.oauth2_access_token_id\n , at.access_token AS \"oauth2_access_token\"\n , at.created_at AS \"oauth2_access_token_created_at\"\n , at.expires_at AS \"oauth2_access_token_expires_at\"\n , os.created_at AS \"oauth2_session_created_at\"\n , os.oauth2_session_id AS \"oauth2_session_id!\"\n , os.oauth2_client_id AS \"oauth2_client_id!\"\n , os.scope AS \"scope!\"\n , os.user_session_id AS \"user_session_id!\"\n\n FROM oauth2_access_tokens at\n INNER JOIN oauth2_sessions os\n USING (oauth2_session_id)\n\n WHERE at.access_token = $1\n AND at.revoked_at IS NULL\n AND os.finished_at IS NULL\n " - }, "5f6b7e38ef9bc3b39deabba277d0255fb8cfb2adaa65f47b78a8fac11d8c91c3": { "describe": { "columns": [], @@ -1612,6 +1544,56 @@ }, "query": "\n SELECT oauth2_authorization_grant_id\n , created_at AS oauth2_authorization_grant_created_at\n , cancelled_at AS oauth2_authorization_grant_cancelled_at\n , fulfilled_at AS oauth2_authorization_grant_fulfilled_at\n , exchanged_at AS oauth2_authorization_grant_exchanged_at\n , scope AS oauth2_authorization_grant_scope\n , state AS oauth2_authorization_grant_state\n , redirect_uri AS oauth2_authorization_grant_redirect_uri\n , response_mode AS oauth2_authorization_grant_response_mode\n , nonce AS oauth2_authorization_grant_nonce\n , max_age AS oauth2_authorization_grant_max_age\n , oauth2_client_id AS oauth2_client_id\n , authorization_code AS oauth2_authorization_grant_code\n , response_type_code AS oauth2_authorization_grant_response_type_code\n , response_type_id_token AS oauth2_authorization_grant_response_type_id_token\n , code_challenge AS oauth2_authorization_grant_code_challenge\n , code_challenge_method AS oauth2_authorization_grant_code_challenge_method\n , requires_consent AS oauth2_authorization_grant_requires_consent\n , oauth2_session_id AS \"oauth2_session_id?\"\n FROM\n oauth2_authorization_grants\n\n WHERE oauth2_authorization_grant_id = $1\n " }, + "b20e846843cf88810fbc0f4b0fa3159117f035841758d682d90c614c374f6059": { + "describe": { + "columns": [ + { + "name": "oauth2_access_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "access_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "expires_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "revoked_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_session_id", + "ordinal": 5, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + false, + true, + false + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT oauth2_access_token_id\n , access_token\n , created_at\n , expires_at\n , revoked_at\n , oauth2_session_id\n\n FROM oauth2_access_tokens\n\n WHERE oauth2_access_token_id = $1\n " + }, "b26ae7dd28f8a756b55a76e80cdedd7be9ba26435ea4a914421483f8ed832537": { "describe": { "columns": [], @@ -1984,6 +1966,106 @@ }, "query": "\n INSERT INTO compat_sso_logins\n (compat_sso_login_id, login_token, redirect_uri, created_at)\n VALUES ($1, $2, $3, $4)\n " }, + "d1f1aac41bb2f0d194f9b3d846663c267865d0d22dd5fa8a668daf29dca88d36": { + "describe": { + "columns": [ + { + "name": "oauth2_refresh_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "refresh_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "consumed_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_access_token_id", + "ordinal": 4, + "type_info": "Uuid" + }, + { + "name": "oauth2_session_id", + "ordinal": 5, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + true, + true, + false + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT oauth2_refresh_token_id\n , refresh_token\n , created_at\n , consumed_at\n , oauth2_access_token_id\n , oauth2_session_id\n FROM oauth2_refresh_tokens\n\n WHERE refresh_token = $1\n " + }, + "d2b1af24f88b2f05eb219f7cbdcfa9680bafe9f77fa1772097875b3718bd1aff": { + "describe": { + "columns": [ + { + "name": "oauth2_access_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "access_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "expires_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "revoked_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_session_id", + "ordinal": 5, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + false, + true, + false + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT oauth2_access_token_id\n , access_token\n , created_at\n , expires_at\n , revoked_at\n , oauth2_session_id\n\n FROM oauth2_access_tokens\n\n WHERE access_token = $1\n " + }, "d8677b3b6ee594c230fad98c1aa1c6e3d983375bf5b701c7b52468e7f906abf9": { "describe": { "columns": [], @@ -2129,74 +2211,6 @@ }, "query": "\n UPDATE user_sessions\n SET finished_at = $1\n WHERE user_session_id = $2\n " }, - "e25b8071b59075c4be9fac283410ec4acf771fdf06076ef7bbb11bf086c4bc03": { - "describe": { - "columns": [ - { - "name": "oauth2_refresh_token_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "oauth2_refresh_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "oauth2_refresh_token_created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_access_token_id?", - "ordinal": 3, - "type_info": "Uuid" - }, - { - "name": "oauth2_session_created_at", - "ordinal": 4, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_session_id!", - "ordinal": 5, - "type_info": "Uuid" - }, - { - "name": "oauth2_client_id!", - "ordinal": 6, - "type_info": "Uuid" - }, - { - "name": "oauth2_session_scope!", - "ordinal": 7, - "type_info": "Text" - }, - { - "name": "user_session_id!", - "ordinal": 8, - "type_info": "Uuid" - } - ], - "nullable": [ - false, - false, - false, - true, - false, - false, - false, - false, - false - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT rt.oauth2_refresh_token_id\n , rt.refresh_token AS oauth2_refresh_token\n , rt.created_at AS oauth2_refresh_token_created_at\n , rt.oauth2_access_token_id AS \"oauth2_access_token_id?\"\n , os.created_at AS \"oauth2_session_created_at\"\n , os.oauth2_session_id AS \"oauth2_session_id!\"\n , os.oauth2_client_id AS \"oauth2_client_id!\"\n , os.scope AS \"oauth2_session_scope!\"\n , os.user_session_id AS \"user_session_id!\"\n FROM oauth2_refresh_tokens rt\n INNER JOIN oauth2_sessions os\n USING (oauth2_session_id)\n\n WHERE rt.refresh_token = $1\n AND rt.consumed_at IS NULL\n AND rt.revoked_at IS NULL\n AND os.finished_at IS NULL\n " - }, "e6dc63984aced9e19c20e90e9cd75d6f6d7ade64f782697715ac4da077b2e1fc": { "describe": { "columns": [ diff --git a/crates/storage/src/oauth2/access_token.rs b/crates/storage/src/oauth2/access_token.rs index 58c13b19..8389dff4 100644 --- a/crates/storage/src/oauth2/access_token.rs +++ b/crates/storage/src/oauth2/access_token.rs @@ -13,13 +13,13 @@ // limitations under the License. use chrono::{DateTime, Duration, Utc}; -use mas_data_model::{AccessToken, Session, SessionState}; +use mas_data_model::{AccessToken, AccessTokenState, Session}; use rand::Rng; use sqlx::{PgConnection, PgExecutor}; use ulid::Ulid; use uuid::Uuid; -use crate::{Clock, DatabaseError, DatabaseInconsistencyError}; +use crate::{Clock, DatabaseError, LookupResultExt}; #[tracing::instrument( skip_all, @@ -63,8 +63,9 @@ pub async fn add_access_token( Ok(AccessToken { id, + state: AccessTokenState::default(), access_token, - jti: id.to_string(), + session_id: session.id, created_at, expires_at, }) @@ -73,74 +74,59 @@ pub async fn add_access_token( #[derive(Debug)] pub struct OAuth2AccessTokenLookup { oauth2_access_token_id: Uuid, - oauth2_access_token: String, - oauth2_access_token_created_at: DateTime, - oauth2_access_token_expires_at: DateTime, - oauth2_session_created_at: DateTime, oauth2_session_id: Uuid, - oauth2_client_id: Uuid, - scope: String, - user_session_id: Uuid, + access_token: String, + created_at: DateTime, + expires_at: DateTime, + revoked_at: Option>, } -#[allow(clippy::too_many_lines)] -pub async fn lookup_active_access_token( +impl From for AccessToken { + fn from(value: OAuth2AccessTokenLookup) -> Self { + let state = match value.revoked_at { + None => AccessTokenState::Valid, + Some(revoked_at) => AccessTokenState::Revoked { revoked_at }, + }; + + Self { + id: value.oauth2_access_token_id.into(), + state, + session_id: value.oauth2_session_id.into(), + access_token: value.access_token, + created_at: value.created_at, + expires_at: value.expires_at, + } + } +} + +#[tracing::instrument(skip_all, err)] +pub async fn find_access_token( conn: &mut PgConnection, token: &str, -) -> Result, DatabaseError> { +) -> Result, DatabaseError> { let res = sqlx::query_as!( OAuth2AccessTokenLookup, r#" - SELECT at.oauth2_access_token_id - , at.access_token AS "oauth2_access_token" - , at.created_at AS "oauth2_access_token_created_at" - , at.expires_at AS "oauth2_access_token_expires_at" - , os.created_at AS "oauth2_session_created_at" - , os.oauth2_session_id AS "oauth2_session_id!" - , os.oauth2_client_id AS "oauth2_client_id!" - , os.scope AS "scope!" - , os.user_session_id AS "user_session_id!" + SELECT oauth2_access_token_id + , access_token + , created_at + , expires_at + , revoked_at + , oauth2_session_id - FROM oauth2_access_tokens at - INNER JOIN oauth2_sessions os - USING (oauth2_session_id) + FROM oauth2_access_tokens - WHERE at.access_token = $1 - AND at.revoked_at IS NULL - AND os.finished_at IS NULL + WHERE access_token = $1 "#, token, ) .fetch_one(&mut *conn) - .await?; + .await + .to_option()?; - let access_token_id = Ulid::from(res.oauth2_access_token_id); - let access_token = AccessToken { - id: access_token_id, - jti: access_token_id.to_string(), - access_token: res.oauth2_access_token, - created_at: res.oauth2_access_token_created_at, - expires_at: res.oauth2_access_token_expires_at, - }; + let Some(res) = res else { return Ok(None) }; - let session_id = res.oauth2_session_id.into(); - let scope = res.scope.parse().map_err(|e| { - DatabaseInconsistencyError::on("oauth2_sessions") - .column("scope") - .row(session_id) - .source(e) - })?; - - let session = Session { - id: session_id, - state: SessionState::Valid, - created_at: res.oauth2_session_created_at, - client_id: res.oauth2_client_id.into(), - user_session_id: res.user_session_id.into(), - scope, - }; - - Ok(Some((access_token, session))) + Ok(Some(res.into())) } #[tracing::instrument( @@ -148,11 +134,48 @@ pub async fn lookup_active_access_token( fields(access_token.id = %access_token_id), err, )] +pub async fn lookup_access_token( + conn: &mut PgConnection, + access_token_id: Ulid, +) -> Result, DatabaseError> { + let res = sqlx::query_as!( + OAuth2AccessTokenLookup, + r#" + SELECT oauth2_access_token_id + , access_token + , created_at + , expires_at + , revoked_at + , oauth2_session_id + + FROM oauth2_access_tokens + + WHERE oauth2_access_token_id = $1 + "#, + Uuid::from(access_token_id), + ) + .fetch_one(&mut *conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) +} + +#[tracing::instrument( + skip_all, + fields( + %access_token.id, + session.id = %access_token.session_id, + ), + err, +)] pub async fn revoke_access_token( executor: impl PgExecutor<'_>, clock: &Clock, - access_token_id: Ulid, -) -> Result<(), DatabaseError> { + access_token: AccessToken, +) -> Result { let revoked_at = clock.now(); let res = sqlx::query!( r#" @@ -160,13 +183,17 @@ pub async fn revoke_access_token( SET revoked_at = $2 WHERE oauth2_access_token_id = $1 "#, - Uuid::from(access_token_id), + Uuid::from(access_token.id), revoked_at, ) .execute(executor) .await?; - DatabaseError::ensure_affected_rows(&res, 1) + DatabaseError::ensure_affected_rows(&res, 1)?; + + access_token + .revoke(revoked_at) + .map_err(DatabaseError::to_invalid_operation) } pub async fn cleanup_expired( diff --git a/crates/storage/src/oauth2/refresh_token.rs b/crates/storage/src/oauth2/refresh_token.rs index f49b38e8..29f6ab34 100644 --- a/crates/storage/src/oauth2/refresh_token.rs +++ b/crates/storage/src/oauth2/refresh_token.rs @@ -13,13 +13,13 @@ // limitations under the License. use chrono::{DateTime, Utc}; -use mas_data_model::{AccessToken, RefreshToken, Session, SessionState}; +use mas_data_model::{AccessToken, RefreshToken, RefreshTokenState, Session}; use rand::Rng; use sqlx::{PgConnection, PgExecutor}; use ulid::Ulid; use uuid::Uuid; -use crate::{Clock, DatabaseError, DatabaseInconsistencyError}; +use crate::{Clock, DatabaseError}; #[tracing::instrument( skip_all, @@ -62,6 +62,8 @@ pub async fn add_refresh_token( Ok(RefreshToken { id, + state: RefreshTokenState::default(), + session_id: session.id, refresh_token, access_token_id: Some(access_token.id), created_at, @@ -70,73 +72,52 @@ pub async fn add_refresh_token( struct OAuth2RefreshTokenLookup { oauth2_refresh_token_id: Uuid, - oauth2_refresh_token: String, - oauth2_refresh_token_created_at: DateTime, + refresh_token: String, + created_at: DateTime, + consumed_at: Option>, oauth2_access_token_id: Option, - oauth2_session_created_at: DateTime, oauth2_session_id: Uuid, - oauth2_client_id: Uuid, - oauth2_session_scope: String, - user_session_id: Uuid, } #[tracing::instrument(skip_all, err)] #[allow(clippy::too_many_lines)] -pub async fn lookup_active_refresh_token( +pub async fn lookup_refresh_token( conn: &mut PgConnection, token: &str, -) -> Result, DatabaseError> { +) -> Result, DatabaseError> { let res = sqlx::query_as!( OAuth2RefreshTokenLookup, r#" - SELECT rt.oauth2_refresh_token_id - , rt.refresh_token AS oauth2_refresh_token - , rt.created_at AS oauth2_refresh_token_created_at - , rt.oauth2_access_token_id AS "oauth2_access_token_id?" - , os.created_at AS "oauth2_session_created_at" - , os.oauth2_session_id AS "oauth2_session_id!" - , os.oauth2_client_id AS "oauth2_client_id!" - , os.scope AS "oauth2_session_scope!" - , os.user_session_id AS "user_session_id!" - FROM oauth2_refresh_tokens rt - INNER JOIN oauth2_sessions os - USING (oauth2_session_id) + SELECT oauth2_refresh_token_id + , refresh_token + , created_at + , consumed_at + , oauth2_access_token_id + , oauth2_session_id + FROM oauth2_refresh_tokens - WHERE rt.refresh_token = $1 - AND rt.consumed_at IS NULL - AND rt.revoked_at IS NULL - AND os.finished_at IS NULL + WHERE refresh_token = $1 "#, token, ) .fetch_one(&mut *conn) .await?; + let state = match res.consumed_at { + None => RefreshTokenState::Valid, + Some(consumed_at) => RefreshTokenState::Consumed { consumed_at }, + }; + let refresh_token = RefreshToken { id: res.oauth2_refresh_token_id.into(), - refresh_token: res.oauth2_refresh_token, - created_at: res.oauth2_refresh_token_created_at, + state, + session_id: res.oauth2_session_id.into(), + refresh_token: res.refresh_token, + created_at: res.created_at, access_token_id: res.oauth2_access_token_id.map(Ulid::from), }; - let session_id = res.oauth2_session_id.into(); - let scope = res.oauth2_session_scope.parse().map_err(|e| { - DatabaseInconsistencyError::on("oauth2_sessions") - .column("scope") - .row(session_id) - .source(e) - })?; - - let session = Session { - id: session_id, - state: SessionState::Valid, - created_at: res.oauth2_session_created_at, - client_id: res.oauth2_client_id.into(), - user_session_id: res.user_session_id.into(), - scope, - }; - - Ok(Some((refresh_token, session))) + Ok(Some(refresh_token)) } #[tracing::instrument( @@ -149,8 +130,8 @@ pub async fn lookup_active_refresh_token( pub async fn consume_refresh_token( executor: impl PgExecutor<'_>, clock: &Clock, - refresh_token: &RefreshToken, -) -> Result<(), DatabaseError> { + refresh_token: RefreshToken, +) -> Result { let consumed_at = clock.now(); let res = sqlx::query!( r#" @@ -164,5 +145,9 @@ pub async fn consume_refresh_token( .execute(executor) .await?; - DatabaseError::ensure_affected_rows(&res, 1) + DatabaseError::ensure_affected_rows(&res, 1)?; + + refresh_token + .consume(consumed_at) + .map_err(DatabaseError::to_invalid_operation) }