1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-07 17:03:01 +03:00

storage: cleanup access/refresh token lookups

This commit is contained in:
Quentin Gliech
2023-01-11 12:14:52 +01:00
parent 920869b583
commit 9f0c9f1466
9 changed files with 452 additions and 263 deletions

View File

@@ -24,10 +24,14 @@ use axum::{
response::{IntoResponse, Response}, response::{IntoResponse, Response},
BoxError, BoxError,
}; };
use chrono::{DateTime, Utc};
use headers::{authorization::Bearer, Authorization, Header, HeaderMapExt, HeaderName}; use headers::{authorization::Bearer, Authorization, Header, HeaderMapExt, HeaderName};
use http::{header::WWW_AUTHENTICATE, HeaderMap, HeaderValue, Request, StatusCode}; use http::{header::WWW_AUTHENTICATE, HeaderMap, HeaderValue, Request, StatusCode};
use mas_data_model::Session; use mas_data_model::Session;
use mas_storage::{oauth2::access_token::lookup_active_access_token, DatabaseError}; use mas_storage::{
oauth2::{access_token::find_access_token, OAuth2SessionRepository},
DatabaseError, Repository,
};
use serde::{de::DeserializeOwned, Deserialize}; use serde::{de::DeserializeOwned, Deserialize};
use sqlx::PgConnection; use sqlx::PgConnection;
use thiserror::Error; use thiserror::Error;
@@ -49,7 +53,7 @@ enum AccessToken {
} }
impl AccessToken { impl AccessToken {
pub async fn fetch( async fn fetch(
&self, &self,
conn: &mut PgConnection, conn: &mut PgConnection,
) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError> { ) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError> {
@@ -58,7 +62,13 @@ impl AccessToken {
AccessToken::None => return Err(AuthorizationVerificationError::MissingToken), AccessToken::None => return Err(AuthorizationVerificationError::MissingToken),
}; };
let (token, session) = lookup_active_access_token(conn, token.as_str()) let token = find_access_token(conn, token.as_str())
.await?
.ok_or(AuthorizationVerificationError::InvalidToken)?;
let session = conn
.oauth2_session()
.lookup(token.session_id)
.await? .await?
.ok_or(AuthorizationVerificationError::InvalidToken)?; .ok_or(AuthorizationVerificationError::InvalidToken)?;
@@ -77,13 +87,18 @@ impl<F: Send> UserAuthorization<F> {
pub async fn protected_form( pub async fn protected_form(
self, self,
conn: &mut PgConnection, conn: &mut PgConnection,
now: DateTime<Utc>,
) -> Result<(Session, F), AuthorizationVerificationError> { ) -> Result<(Session, F), AuthorizationVerificationError> {
let form = match self.form { let form = match self.form {
Some(f) => f, Some(f) => f,
None => return Err(AuthorizationVerificationError::MissingForm), None => return Err(AuthorizationVerificationError::MissingForm),
}; };
let (_token, session) = self.access_token.fetch(conn).await?; let (token, session) = self.access_token.fetch(conn).await?;
if !token.is_valid(now) || !session.is_valid() {
return Err(AuthorizationVerificationError::InvalidToken);
}
Ok((session, form)) Ok((session, form))
} }
@@ -92,8 +107,13 @@ impl<F: Send> UserAuthorization<F> {
pub async fn protected( pub async fn protected(
self, self,
conn: &mut PgConnection, conn: &mut PgConnection,
now: DateTime<Utc>,
) -> Result<Session, AuthorizationVerificationError> { ) -> Result<Session, AuthorizationVerificationError> {
let (_token, session) = self.access_token.fetch(conn).await?; let (token, session) = self.access_token.fetch(conn).await?;
if !token.is_valid(now) || !session.is_valid() {
return Err(AuthorizationVerificationError::InvalidToken);
}
Ok(session) Ok(session)
} }

View File

@@ -44,7 +44,9 @@ pub use self::{
AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client,
InvalidRedirectUriError, JwksOrJwksUri, Pkce, Session, SessionState, InvalidRedirectUriError, JwksOrJwksUri, Pkce, Session, SessionState,
}, },
tokens::{AccessToken, RefreshToken, TokenFormatError, TokenType}, tokens::{
AccessToken, AccessTokenState, RefreshToken, RefreshTokenState, TokenFormatError, TokenType,
},
upstream_oauth2::{ upstream_oauth2::{
UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState, UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState,
UpstreamOAuthLink, UpstreamOAuthProvider, UpstreamOAuthLink, UpstreamOAuthProvider,

View File

@@ -19,23 +19,133 @@ use rand::{distributions::Alphanumeric, Rng};
use thiserror::Error; use thiserror::Error;
use ulid::Ulid; use ulid::Ulid;
use crate::InvalidTransitionError;
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum AccessTokenState {
#[default]
Valid,
Revoked {
revoked_at: DateTime<Utc>,
},
}
impl AccessTokenState {
fn revoke(self, revoked_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
match self {
Self::Valid => Ok(Self::Revoked { revoked_at }),
Self::Revoked { .. } => Err(InvalidTransitionError),
}
}
/// Returns `true` if the refresh token state is [`Valid`].
///
/// [`Valid`]: RefreshTokenState::Valid
#[must_use]
pub fn is_valid(&self) -> bool {
matches!(self, Self::Valid)
}
/// Returns `true` if the refresh token state is [`Revoked`].
///
/// [`Revoked`]: RefreshTokenState::Revoked
#[must_use]
pub fn is_revoked(&self) -> bool {
matches!(self, Self::Revoked { .. })
}
}
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct AccessToken { pub struct AccessToken {
pub id: Ulid, pub id: Ulid,
pub jti: String, pub state: AccessTokenState,
pub session_id: Ulid,
pub access_token: String, pub access_token: String,
pub created_at: DateTime<Utc>, pub created_at: DateTime<Utc>,
pub expires_at: DateTime<Utc>, pub expires_at: DateTime<Utc>,
} }
impl AccessToken {
#[must_use]
pub fn jti(&self) -> String {
self.id.to_string()
}
#[must_use]
pub fn is_valid(&self, now: DateTime<Utc>) -> bool {
self.state.is_valid() && self.expires_at > now
}
pub fn revoke(mut self, revoked_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
self.state = self.state.revoke(revoked_at)?;
Ok(self)
}
}
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum RefreshTokenState {
#[default]
Valid,
Consumed {
consumed_at: DateTime<Utc>,
},
}
impl RefreshTokenState {
fn consume(self, consumed_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
match self {
Self::Valid => Ok(Self::Consumed { consumed_at }),
Self::Consumed { .. } => Err(InvalidTransitionError),
}
}
/// Returns `true` if the refresh token state is [`Valid`].
///
/// [`Valid`]: RefreshTokenState::Valid
#[must_use]
pub fn is_valid(&self) -> bool {
matches!(self, Self::Valid)
}
/// Returns `true` if the refresh token state is [`Consumed`].
///
/// [`Consumed`]: RefreshTokenState::Consumed
#[must_use]
pub fn is_consumed(&self) -> bool {
matches!(self, Self::Consumed { .. })
}
}
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct RefreshToken { pub struct RefreshToken {
pub id: Ulid, pub id: Ulid,
pub state: RefreshTokenState,
pub refresh_token: String, pub refresh_token: String,
pub session_id: Ulid,
pub created_at: DateTime<Utc>, pub created_at: DateTime<Utc>,
pub access_token_id: Option<Ulid>, pub access_token_id: Option<Ulid>,
} }
impl std::ops::Deref for RefreshToken {
type Target = RefreshTokenState;
fn deref(&self) -> &Self::Target {
&self.state
}
}
impl RefreshToken {
#[must_use]
pub fn jti(&self) -> String {
self.id.to_string()
}
pub fn consume(mut self, consumed_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
self.state = self.state.consume(consumed_at)?;
Ok(self)
}
}
/// Type of token to generate or validate /// Type of token to generate or validate
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenType { pub enum TokenType {

View File

@@ -24,7 +24,8 @@ use mas_keystore::Encrypter;
use mas_storage::{ use mas_storage::{
compat::{find_compat_access_token, find_compat_refresh_token, lookup_compat_session}, compat::{find_compat_access_token, find_compat_refresh_token, lookup_compat_session},
oauth2::{ oauth2::{
access_token::lookup_active_access_token, refresh_token::lookup_active_refresh_token, access_token::find_access_token, refresh_token::lookup_refresh_token,
OAuth2SessionRepository,
}, },
user::{BrowserSessionRepository, UserRepository}, user::{BrowserSessionRepository, UserRepository},
Clock, Repository, Clock, Repository,
@@ -168,8 +169,17 @@ pub(crate) async fn post(
let reply = match token_type { let reply = match token_type {
TokenType::AccessToken => { TokenType::AccessToken => {
let (token, session) = lookup_active_access_token(&mut conn, token) let token = find_access_token(&mut conn, token)
.await? .await?
.filter(|t| t.is_valid(clock.now()))
.ok_or(RouteError::UnknownToken)?;
let session = conn
.oauth2_session()
.lookup(token.session_id)
.await?
.filter(|s| s.is_valid())
// XXX: is that the right error to bubble up?
.ok_or(RouteError::UnknownToken)?; .ok_or(RouteError::UnknownToken)?;
let browser_session = conn let browser_session = conn
@@ -191,13 +201,22 @@ pub(crate) async fn post(
sub: Some(browser_session.user.sub), sub: Some(browser_session.user.sub),
aud: None, aud: None,
iss: None, iss: None,
jti: None, jti: Some(token.jti()),
} }
} }
TokenType::RefreshToken => { TokenType::RefreshToken => {
let (token, session) = lookup_active_refresh_token(&mut conn, token) let token = lookup_refresh_token(&mut conn, token)
.await? .await?
.filter(|t| t.is_valid())
.ok_or(RouteError::UnknownToken)?;
let session = conn
.oauth2_session()
.lookup(token.session_id)
.await?
.filter(|s| s.is_valid())
// XXX: is that the right error to bubble up?
.ok_or(RouteError::UnknownToken)?; .ok_or(RouteError::UnknownToken)?;
let browser_session = conn let browser_session = conn
@@ -219,7 +238,7 @@ pub(crate) async fn post(
sub: Some(browser_session.user.sub), sub: Some(browser_session.user.sub),
aud: None, aud: None,
iss: None, iss: None,
jti: None, jti: Some(token.jti()),
} }
} }

View File

@@ -33,9 +33,9 @@ use mas_keystore::{Encrypter, Keystore};
use mas_router::UrlBuilder; use mas_router::UrlBuilder;
use mas_storage::{ use mas_storage::{
oauth2::{ oauth2::{
access_token::{add_access_token, revoke_access_token}, access_token::{add_access_token, lookup_access_token, revoke_access_token},
authorization_grant::{exchange_grant, lookup_grant_by_code}, authorization_grant::{exchange_grant, lookup_grant_by_code},
refresh_token::{add_refresh_token, consume_refresh_token, lookup_active_refresh_token}, refresh_token::{add_refresh_token, consume_refresh_token, lookup_refresh_token},
OAuth2SessionRepository, OAuth2SessionRepository,
}, },
user::BrowserSessionRepository, user::BrowserSessionRepository,
@@ -374,10 +374,20 @@ async fn refresh_token_grant(
) -> Result<AccessTokenResponse, RouteError> { ) -> Result<AccessTokenResponse, RouteError> {
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
let (refresh_token, session) = lookup_active_refresh_token(&mut txn, &grant.refresh_token) let refresh_token = lookup_refresh_token(&mut txn, &grant.refresh_token)
.await? .await?
.ok_or(RouteError::InvalidGrant)?; .ok_or(RouteError::InvalidGrant)?;
let session = txn
.oauth2_session()
.lookup(refresh_token.session_id)
.await?
.ok_or(RouteError::NoSuchOAuthSession)?;
if !refresh_token.is_valid() || !session.is_valid() {
return Err(RouteError::InvalidGrant);
}
if client.id != session.client_id { if client.id != session.client_id {
// As per https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 // As per https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
return Err(RouteError::InvalidGrant); return Err(RouteError::InvalidGrant);
@@ -407,10 +417,12 @@ async fn refresh_token_grant(
) )
.await?; .await?;
consume_refresh_token(&mut txn, &clock, &refresh_token).await?; let refresh_token = consume_refresh_token(&mut txn, &clock, refresh_token).await?;
if let Some(access_token_id) = refresh_token.access_token_id { if let Some(access_token_id) = refresh_token.access_token_id {
revoke_access_token(&mut txn, &clock, access_token_id).await?; if let Some(access_token) = lookup_access_token(&mut txn, access_token_id).await? {
revoke_access_token(&mut txn, &clock, access_token).await?;
}
} }
let params = AccessTokenResponse::new(access_token_str) let params = AccessTokenResponse::new(access_token_str)

View File

@@ -101,10 +101,10 @@ pub async fn get(
State(key_store): State<Keystore>, State(key_store): State<Keystore>,
user_authorization: UserAuthorization, user_authorization: UserAuthorization,
) -> Result<Response, RouteError> { ) -> Result<Response, RouteError> {
let (_clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
let mut conn = pool.acquire().await?; let mut conn = pool.acquire().await?;
let session = user_authorization.protected(&mut conn).await?; let session = user_authorization.protected(&mut conn, clock.now()).await?;
let browser_session = conn let browser_session = conn
.browser_session() .browser_session()

View File

@@ -583,74 +583,6 @@
}, },
"query": "\n DELETE FROM oauth2_access_tokens\n WHERE expires_at < $1\n " "query": "\n DELETE FROM oauth2_access_tokens\n WHERE expires_at < $1\n "
}, },
"5f0e2aec0d7766d3674af3e68417921fec7068e83845e218a4a00d86487557f9": {
"describe": {
"columns": [
{
"name": "oauth2_access_token_id",
"ordinal": 0,
"type_info": "Uuid"
},
{
"name": "oauth2_access_token",
"ordinal": 1,
"type_info": "Text"
},
{
"name": "oauth2_access_token_created_at",
"ordinal": 2,
"type_info": "Timestamptz"
},
{
"name": "oauth2_access_token_expires_at",
"ordinal": 3,
"type_info": "Timestamptz"
},
{
"name": "oauth2_session_created_at",
"ordinal": 4,
"type_info": "Timestamptz"
},
{
"name": "oauth2_session_id!",
"ordinal": 5,
"type_info": "Uuid"
},
{
"name": "oauth2_client_id!",
"ordinal": 6,
"type_info": "Uuid"
},
{
"name": "scope!",
"ordinal": 7,
"type_info": "Text"
},
{
"name": "user_session_id!",
"ordinal": 8,
"type_info": "Uuid"
}
],
"nullable": [
false,
false,
false,
false,
false,
false,
false,
false,
false
],
"parameters": {
"Left": [
"Text"
]
}
},
"query": "\n SELECT at.oauth2_access_token_id\n , at.access_token AS \"oauth2_access_token\"\n , at.created_at AS \"oauth2_access_token_created_at\"\n , at.expires_at AS \"oauth2_access_token_expires_at\"\n , os.created_at AS \"oauth2_session_created_at\"\n , os.oauth2_session_id AS \"oauth2_session_id!\"\n , os.oauth2_client_id AS \"oauth2_client_id!\"\n , os.scope AS \"scope!\"\n , os.user_session_id AS \"user_session_id!\"\n\n FROM oauth2_access_tokens at\n INNER JOIN oauth2_sessions os\n USING (oauth2_session_id)\n\n WHERE at.access_token = $1\n AND at.revoked_at IS NULL\n AND os.finished_at IS NULL\n "
},
"5f6b7e38ef9bc3b39deabba277d0255fb8cfb2adaa65f47b78a8fac11d8c91c3": { "5f6b7e38ef9bc3b39deabba277d0255fb8cfb2adaa65f47b78a8fac11d8c91c3": {
"describe": { "describe": {
"columns": [], "columns": [],
@@ -1612,6 +1544,56 @@
}, },
"query": "\n SELECT oauth2_authorization_grant_id\n , created_at AS oauth2_authorization_grant_created_at\n , cancelled_at AS oauth2_authorization_grant_cancelled_at\n , fulfilled_at AS oauth2_authorization_grant_fulfilled_at\n , exchanged_at AS oauth2_authorization_grant_exchanged_at\n , scope AS oauth2_authorization_grant_scope\n , state AS oauth2_authorization_grant_state\n , redirect_uri AS oauth2_authorization_grant_redirect_uri\n , response_mode AS oauth2_authorization_grant_response_mode\n , nonce AS oauth2_authorization_grant_nonce\n , max_age AS oauth2_authorization_grant_max_age\n , oauth2_client_id AS oauth2_client_id\n , authorization_code AS oauth2_authorization_grant_code\n , response_type_code AS oauth2_authorization_grant_response_type_code\n , response_type_id_token AS oauth2_authorization_grant_response_type_id_token\n , code_challenge AS oauth2_authorization_grant_code_challenge\n , code_challenge_method AS oauth2_authorization_grant_code_challenge_method\n , requires_consent AS oauth2_authorization_grant_requires_consent\n , oauth2_session_id AS \"oauth2_session_id?\"\n FROM\n oauth2_authorization_grants\n\n WHERE oauth2_authorization_grant_id = $1\n " "query": "\n SELECT oauth2_authorization_grant_id\n , created_at AS oauth2_authorization_grant_created_at\n , cancelled_at AS oauth2_authorization_grant_cancelled_at\n , fulfilled_at AS oauth2_authorization_grant_fulfilled_at\n , exchanged_at AS oauth2_authorization_grant_exchanged_at\n , scope AS oauth2_authorization_grant_scope\n , state AS oauth2_authorization_grant_state\n , redirect_uri AS oauth2_authorization_grant_redirect_uri\n , response_mode AS oauth2_authorization_grant_response_mode\n , nonce AS oauth2_authorization_grant_nonce\n , max_age AS oauth2_authorization_grant_max_age\n , oauth2_client_id AS oauth2_client_id\n , authorization_code AS oauth2_authorization_grant_code\n , response_type_code AS oauth2_authorization_grant_response_type_code\n , response_type_id_token AS oauth2_authorization_grant_response_type_id_token\n , code_challenge AS oauth2_authorization_grant_code_challenge\n , code_challenge_method AS oauth2_authorization_grant_code_challenge_method\n , requires_consent AS oauth2_authorization_grant_requires_consent\n , oauth2_session_id AS \"oauth2_session_id?\"\n FROM\n oauth2_authorization_grants\n\n WHERE oauth2_authorization_grant_id = $1\n "
}, },
"b20e846843cf88810fbc0f4b0fa3159117f035841758d682d90c614c374f6059": {
"describe": {
"columns": [
{
"name": "oauth2_access_token_id",
"ordinal": 0,
"type_info": "Uuid"
},
{
"name": "access_token",
"ordinal": 1,
"type_info": "Text"
},
{
"name": "created_at",
"ordinal": 2,
"type_info": "Timestamptz"
},
{
"name": "expires_at",
"ordinal": 3,
"type_info": "Timestamptz"
},
{
"name": "revoked_at",
"ordinal": 4,
"type_info": "Timestamptz"
},
{
"name": "oauth2_session_id",
"ordinal": 5,
"type_info": "Uuid"
}
],
"nullable": [
false,
false,
false,
false,
true,
false
],
"parameters": {
"Left": [
"Uuid"
]
}
},
"query": "\n SELECT oauth2_access_token_id\n , access_token\n , created_at\n , expires_at\n , revoked_at\n , oauth2_session_id\n\n FROM oauth2_access_tokens\n\n WHERE oauth2_access_token_id = $1\n "
},
"b26ae7dd28f8a756b55a76e80cdedd7be9ba26435ea4a914421483f8ed832537": { "b26ae7dd28f8a756b55a76e80cdedd7be9ba26435ea4a914421483f8ed832537": {
"describe": { "describe": {
"columns": [], "columns": [],
@@ -1984,6 +1966,106 @@
}, },
"query": "\n INSERT INTO compat_sso_logins\n (compat_sso_login_id, login_token, redirect_uri, created_at)\n VALUES ($1, $2, $3, $4)\n " "query": "\n INSERT INTO compat_sso_logins\n (compat_sso_login_id, login_token, redirect_uri, created_at)\n VALUES ($1, $2, $3, $4)\n "
}, },
"d1f1aac41bb2f0d194f9b3d846663c267865d0d22dd5fa8a668daf29dca88d36": {
"describe": {
"columns": [
{
"name": "oauth2_refresh_token_id",
"ordinal": 0,
"type_info": "Uuid"
},
{
"name": "refresh_token",
"ordinal": 1,
"type_info": "Text"
},
{
"name": "created_at",
"ordinal": 2,
"type_info": "Timestamptz"
},
{
"name": "consumed_at",
"ordinal": 3,
"type_info": "Timestamptz"
},
{
"name": "oauth2_access_token_id",
"ordinal": 4,
"type_info": "Uuid"
},
{
"name": "oauth2_session_id",
"ordinal": 5,
"type_info": "Uuid"
}
],
"nullable": [
false,
false,
false,
true,
true,
false
],
"parameters": {
"Left": [
"Text"
]
}
},
"query": "\n SELECT oauth2_refresh_token_id\n , refresh_token\n , created_at\n , consumed_at\n , oauth2_access_token_id\n , oauth2_session_id\n FROM oauth2_refresh_tokens\n\n WHERE refresh_token = $1\n "
},
"d2b1af24f88b2f05eb219f7cbdcfa9680bafe9f77fa1772097875b3718bd1aff": {
"describe": {
"columns": [
{
"name": "oauth2_access_token_id",
"ordinal": 0,
"type_info": "Uuid"
},
{
"name": "access_token",
"ordinal": 1,
"type_info": "Text"
},
{
"name": "created_at",
"ordinal": 2,
"type_info": "Timestamptz"
},
{
"name": "expires_at",
"ordinal": 3,
"type_info": "Timestamptz"
},
{
"name": "revoked_at",
"ordinal": 4,
"type_info": "Timestamptz"
},
{
"name": "oauth2_session_id",
"ordinal": 5,
"type_info": "Uuid"
}
],
"nullable": [
false,
false,
false,
false,
true,
false
],
"parameters": {
"Left": [
"Text"
]
}
},
"query": "\n SELECT oauth2_access_token_id\n , access_token\n , created_at\n , expires_at\n , revoked_at\n , oauth2_session_id\n\n FROM oauth2_access_tokens\n\n WHERE access_token = $1\n "
},
"d8677b3b6ee594c230fad98c1aa1c6e3d983375bf5b701c7b52468e7f906abf9": { "d8677b3b6ee594c230fad98c1aa1c6e3d983375bf5b701c7b52468e7f906abf9": {
"describe": { "describe": {
"columns": [], "columns": [],
@@ -2129,74 +2211,6 @@
}, },
"query": "\n UPDATE user_sessions\n SET finished_at = $1\n WHERE user_session_id = $2\n " "query": "\n UPDATE user_sessions\n SET finished_at = $1\n WHERE user_session_id = $2\n "
}, },
"e25b8071b59075c4be9fac283410ec4acf771fdf06076ef7bbb11bf086c4bc03": {
"describe": {
"columns": [
{
"name": "oauth2_refresh_token_id",
"ordinal": 0,
"type_info": "Uuid"
},
{
"name": "oauth2_refresh_token",
"ordinal": 1,
"type_info": "Text"
},
{
"name": "oauth2_refresh_token_created_at",
"ordinal": 2,
"type_info": "Timestamptz"
},
{
"name": "oauth2_access_token_id?",
"ordinal": 3,
"type_info": "Uuid"
},
{
"name": "oauth2_session_created_at",
"ordinal": 4,
"type_info": "Timestamptz"
},
{
"name": "oauth2_session_id!",
"ordinal": 5,
"type_info": "Uuid"
},
{
"name": "oauth2_client_id!",
"ordinal": 6,
"type_info": "Uuid"
},
{
"name": "oauth2_session_scope!",
"ordinal": 7,
"type_info": "Text"
},
{
"name": "user_session_id!",
"ordinal": 8,
"type_info": "Uuid"
}
],
"nullable": [
false,
false,
false,
true,
false,
false,
false,
false,
false
],
"parameters": {
"Left": [
"Text"
]
}
},
"query": "\n SELECT rt.oauth2_refresh_token_id\n , rt.refresh_token AS oauth2_refresh_token\n , rt.created_at AS oauth2_refresh_token_created_at\n , rt.oauth2_access_token_id AS \"oauth2_access_token_id?\"\n , os.created_at AS \"oauth2_session_created_at\"\n , os.oauth2_session_id AS \"oauth2_session_id!\"\n , os.oauth2_client_id AS \"oauth2_client_id!\"\n , os.scope AS \"oauth2_session_scope!\"\n , os.user_session_id AS \"user_session_id!\"\n FROM oauth2_refresh_tokens rt\n INNER JOIN oauth2_sessions os\n USING (oauth2_session_id)\n\n WHERE rt.refresh_token = $1\n AND rt.consumed_at IS NULL\n AND rt.revoked_at IS NULL\n AND os.finished_at IS NULL\n "
},
"e6dc63984aced9e19c20e90e9cd75d6f6d7ade64f782697715ac4da077b2e1fc": { "e6dc63984aced9e19c20e90e9cd75d6f6d7ade64f782697715ac4da077b2e1fc": {
"describe": { "describe": {
"columns": [ "columns": [

View File

@@ -13,13 +13,13 @@
// limitations under the License. // limitations under the License.
use chrono::{DateTime, Duration, Utc}; use chrono::{DateTime, Duration, Utc};
use mas_data_model::{AccessToken, Session, SessionState}; use mas_data_model::{AccessToken, AccessTokenState, Session};
use rand::Rng; use rand::Rng;
use sqlx::{PgConnection, PgExecutor}; use sqlx::{PgConnection, PgExecutor};
use ulid::Ulid; use ulid::Ulid;
use uuid::Uuid; use uuid::Uuid;
use crate::{Clock, DatabaseError, DatabaseInconsistencyError}; use crate::{Clock, DatabaseError, LookupResultExt};
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
@@ -63,8 +63,9 @@ pub async fn add_access_token(
Ok(AccessToken { Ok(AccessToken {
id, id,
state: AccessTokenState::default(),
access_token, access_token,
jti: id.to_string(), session_id: session.id,
created_at, created_at,
expires_at, expires_at,
}) })
@@ -73,74 +74,59 @@ pub async fn add_access_token(
#[derive(Debug)] #[derive(Debug)]
pub struct OAuth2AccessTokenLookup { pub struct OAuth2AccessTokenLookup {
oauth2_access_token_id: Uuid, oauth2_access_token_id: Uuid,
oauth2_access_token: String,
oauth2_access_token_created_at: DateTime<Utc>,
oauth2_access_token_expires_at: DateTime<Utc>,
oauth2_session_created_at: DateTime<Utc>,
oauth2_session_id: Uuid, oauth2_session_id: Uuid,
oauth2_client_id: Uuid, access_token: String,
scope: String, created_at: DateTime<Utc>,
user_session_id: Uuid, expires_at: DateTime<Utc>,
revoked_at: Option<DateTime<Utc>>,
} }
#[allow(clippy::too_many_lines)] impl From<OAuth2AccessTokenLookup> for AccessToken {
pub async fn lookup_active_access_token( fn from(value: OAuth2AccessTokenLookup) -> Self {
let state = match value.revoked_at {
None => AccessTokenState::Valid,
Some(revoked_at) => AccessTokenState::Revoked { revoked_at },
};
Self {
id: value.oauth2_access_token_id.into(),
state,
session_id: value.oauth2_session_id.into(),
access_token: value.access_token,
created_at: value.created_at,
expires_at: value.expires_at,
}
}
}
#[tracing::instrument(skip_all, err)]
pub async fn find_access_token(
conn: &mut PgConnection, conn: &mut PgConnection,
token: &str, token: &str,
) -> Result<Option<(AccessToken, Session)>, DatabaseError> { ) -> Result<Option<AccessToken>, DatabaseError> {
let res = sqlx::query_as!( let res = sqlx::query_as!(
OAuth2AccessTokenLookup, OAuth2AccessTokenLookup,
r#" r#"
SELECT at.oauth2_access_token_id SELECT oauth2_access_token_id
, at.access_token AS "oauth2_access_token" , access_token
, at.created_at AS "oauth2_access_token_created_at" , created_at
, at.expires_at AS "oauth2_access_token_expires_at" , expires_at
, os.created_at AS "oauth2_session_created_at" , revoked_at
, os.oauth2_session_id AS "oauth2_session_id!" , oauth2_session_id
, os.oauth2_client_id AS "oauth2_client_id!"
, os.scope AS "scope!"
, os.user_session_id AS "user_session_id!"
FROM oauth2_access_tokens at FROM oauth2_access_tokens
INNER JOIN oauth2_sessions os
USING (oauth2_session_id)
WHERE at.access_token = $1 WHERE access_token = $1
AND at.revoked_at IS NULL
AND os.finished_at IS NULL
"#, "#,
token, token,
) )
.fetch_one(&mut *conn) .fetch_one(&mut *conn)
.await?; .await
.to_option()?;
let access_token_id = Ulid::from(res.oauth2_access_token_id); let Some(res) = res else { return Ok(None) };
let access_token = AccessToken {
id: access_token_id,
jti: access_token_id.to_string(),
access_token: res.oauth2_access_token,
created_at: res.oauth2_access_token_created_at,
expires_at: res.oauth2_access_token_expires_at,
};
let session_id = res.oauth2_session_id.into(); Ok(Some(res.into()))
let scope = res.scope.parse().map_err(|e| {
DatabaseInconsistencyError::on("oauth2_sessions")
.column("scope")
.row(session_id)
.source(e)
})?;
let session = Session {
id: session_id,
state: SessionState::Valid,
created_at: res.oauth2_session_created_at,
client_id: res.oauth2_client_id.into(),
user_session_id: res.user_session_id.into(),
scope,
};
Ok(Some((access_token, session)))
} }
#[tracing::instrument( #[tracing::instrument(
@@ -148,11 +134,48 @@ pub async fn lookup_active_access_token(
fields(access_token.id = %access_token_id), fields(access_token.id = %access_token_id),
err, err,
)] )]
pub async fn lookup_access_token(
conn: &mut PgConnection,
access_token_id: Ulid,
) -> Result<Option<AccessToken>, DatabaseError> {
let res = sqlx::query_as!(
OAuth2AccessTokenLookup,
r#"
SELECT oauth2_access_token_id
, access_token
, created_at
, expires_at
, revoked_at
, oauth2_session_id
FROM oauth2_access_tokens
WHERE oauth2_access_token_id = $1
"#,
Uuid::from(access_token_id),
)
.fetch_one(&mut *conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
skip_all,
fields(
%access_token.id,
session.id = %access_token.session_id,
),
err,
)]
pub async fn revoke_access_token( pub async fn revoke_access_token(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
clock: &Clock, clock: &Clock,
access_token_id: Ulid, access_token: AccessToken,
) -> Result<(), DatabaseError> { ) -> Result<AccessToken, DatabaseError> {
let revoked_at = clock.now(); let revoked_at = clock.now();
let res = sqlx::query!( let res = sqlx::query!(
r#" r#"
@@ -160,13 +183,17 @@ pub async fn revoke_access_token(
SET revoked_at = $2 SET revoked_at = $2
WHERE oauth2_access_token_id = $1 WHERE oauth2_access_token_id = $1
"#, "#,
Uuid::from(access_token_id), Uuid::from(access_token.id),
revoked_at, revoked_at,
) )
.execute(executor) .execute(executor)
.await?; .await?;
DatabaseError::ensure_affected_rows(&res, 1) DatabaseError::ensure_affected_rows(&res, 1)?;
access_token
.revoke(revoked_at)
.map_err(DatabaseError::to_invalid_operation)
} }
pub async fn cleanup_expired( pub async fn cleanup_expired(

View File

@@ -13,13 +13,13 @@
// limitations under the License. // limitations under the License.
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use mas_data_model::{AccessToken, RefreshToken, Session, SessionState}; use mas_data_model::{AccessToken, RefreshToken, RefreshTokenState, Session};
use rand::Rng; use rand::Rng;
use sqlx::{PgConnection, PgExecutor}; use sqlx::{PgConnection, PgExecutor};
use ulid::Ulid; use ulid::Ulid;
use uuid::Uuid; use uuid::Uuid;
use crate::{Clock, DatabaseError, DatabaseInconsistencyError}; use crate::{Clock, DatabaseError};
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
@@ -62,6 +62,8 @@ pub async fn add_refresh_token(
Ok(RefreshToken { Ok(RefreshToken {
id, id,
state: RefreshTokenState::default(),
session_id: session.id,
refresh_token, refresh_token,
access_token_id: Some(access_token.id), access_token_id: Some(access_token.id),
created_at, created_at,
@@ -70,73 +72,52 @@ pub async fn add_refresh_token(
struct OAuth2RefreshTokenLookup { struct OAuth2RefreshTokenLookup {
oauth2_refresh_token_id: Uuid, oauth2_refresh_token_id: Uuid,
oauth2_refresh_token: String, refresh_token: String,
oauth2_refresh_token_created_at: DateTime<Utc>, created_at: DateTime<Utc>,
consumed_at: Option<DateTime<Utc>>,
oauth2_access_token_id: Option<Uuid>, oauth2_access_token_id: Option<Uuid>,
oauth2_session_created_at: DateTime<Utc>,
oauth2_session_id: Uuid, oauth2_session_id: Uuid,
oauth2_client_id: Uuid,
oauth2_session_scope: String,
user_session_id: Uuid,
} }
#[tracing::instrument(skip_all, err)] #[tracing::instrument(skip_all, err)]
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
pub async fn lookup_active_refresh_token( pub async fn lookup_refresh_token(
conn: &mut PgConnection, conn: &mut PgConnection,
token: &str, token: &str,
) -> Result<Option<(RefreshToken, Session)>, DatabaseError> { ) -> Result<Option<RefreshToken>, DatabaseError> {
let res = sqlx::query_as!( let res = sqlx::query_as!(
OAuth2RefreshTokenLookup, OAuth2RefreshTokenLookup,
r#" r#"
SELECT rt.oauth2_refresh_token_id SELECT oauth2_refresh_token_id
, rt.refresh_token AS oauth2_refresh_token , refresh_token
, rt.created_at AS oauth2_refresh_token_created_at , created_at
, rt.oauth2_access_token_id AS "oauth2_access_token_id?" , consumed_at
, os.created_at AS "oauth2_session_created_at" , oauth2_access_token_id
, os.oauth2_session_id AS "oauth2_session_id!" , oauth2_session_id
, os.oauth2_client_id AS "oauth2_client_id!" FROM oauth2_refresh_tokens
, os.scope AS "oauth2_session_scope!"
, os.user_session_id AS "user_session_id!"
FROM oauth2_refresh_tokens rt
INNER JOIN oauth2_sessions os
USING (oauth2_session_id)
WHERE rt.refresh_token = $1 WHERE refresh_token = $1
AND rt.consumed_at IS NULL
AND rt.revoked_at IS NULL
AND os.finished_at IS NULL
"#, "#,
token, token,
) )
.fetch_one(&mut *conn) .fetch_one(&mut *conn)
.await?; .await?;
let state = match res.consumed_at {
None => RefreshTokenState::Valid,
Some(consumed_at) => RefreshTokenState::Consumed { consumed_at },
};
let refresh_token = RefreshToken { let refresh_token = RefreshToken {
id: res.oauth2_refresh_token_id.into(), id: res.oauth2_refresh_token_id.into(),
refresh_token: res.oauth2_refresh_token, state,
created_at: res.oauth2_refresh_token_created_at, session_id: res.oauth2_session_id.into(),
refresh_token: res.refresh_token,
created_at: res.created_at,
access_token_id: res.oauth2_access_token_id.map(Ulid::from), access_token_id: res.oauth2_access_token_id.map(Ulid::from),
}; };
let session_id = res.oauth2_session_id.into(); Ok(Some(refresh_token))
let scope = res.oauth2_session_scope.parse().map_err(|e| {
DatabaseInconsistencyError::on("oauth2_sessions")
.column("scope")
.row(session_id)
.source(e)
})?;
let session = Session {
id: session_id,
state: SessionState::Valid,
created_at: res.oauth2_session_created_at,
client_id: res.oauth2_client_id.into(),
user_session_id: res.user_session_id.into(),
scope,
};
Ok(Some((refresh_token, session)))
} }
#[tracing::instrument( #[tracing::instrument(
@@ -149,8 +130,8 @@ pub async fn lookup_active_refresh_token(
pub async fn consume_refresh_token( pub async fn consume_refresh_token(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
clock: &Clock, clock: &Clock,
refresh_token: &RefreshToken, refresh_token: RefreshToken,
) -> Result<(), DatabaseError> { ) -> Result<RefreshToken, DatabaseError> {
let consumed_at = clock.now(); let consumed_at = clock.now();
let res = sqlx::query!( let res = sqlx::query!(
r#" r#"
@@ -164,5 +145,9 @@ pub async fn consume_refresh_token(
.execute(executor) .execute(executor)
.await?; .await?;
DatabaseError::ensure_affected_rows(&res, 1) DatabaseError::ensure_affected_rows(&res, 1)?;
refresh_token
.consume(consumed_at)
.map_err(DatabaseError::to_invalid_operation)
} }