You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-08-09 04:22:45 +03:00
storage: cleanup access/refresh token lookups
This commit is contained in:
@@ -24,10 +24,14 @@ use axum::{
|
||||
response::{IntoResponse, Response},
|
||||
BoxError,
|
||||
};
|
||||
use chrono::{DateTime, Utc};
|
||||
use headers::{authorization::Bearer, Authorization, Header, HeaderMapExt, HeaderName};
|
||||
use http::{header::WWW_AUTHENTICATE, HeaderMap, HeaderValue, Request, StatusCode};
|
||||
use mas_data_model::Session;
|
||||
use mas_storage::{oauth2::access_token::lookup_active_access_token, DatabaseError};
|
||||
use mas_storage::{
|
||||
oauth2::{access_token::find_access_token, OAuth2SessionRepository},
|
||||
DatabaseError, Repository,
|
||||
};
|
||||
use serde::{de::DeserializeOwned, Deserialize};
|
||||
use sqlx::PgConnection;
|
||||
use thiserror::Error;
|
||||
@@ -49,7 +53,7 @@ enum AccessToken {
|
||||
}
|
||||
|
||||
impl AccessToken {
|
||||
pub async fn fetch(
|
||||
async fn fetch(
|
||||
&self,
|
||||
conn: &mut PgConnection,
|
||||
) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError> {
|
||||
@@ -58,7 +62,13 @@ impl AccessToken {
|
||||
AccessToken::None => return Err(AuthorizationVerificationError::MissingToken),
|
||||
};
|
||||
|
||||
let (token, session) = lookup_active_access_token(conn, token.as_str())
|
||||
let token = find_access_token(conn, token.as_str())
|
||||
.await?
|
||||
.ok_or(AuthorizationVerificationError::InvalidToken)?;
|
||||
|
||||
let session = conn
|
||||
.oauth2_session()
|
||||
.lookup(token.session_id)
|
||||
.await?
|
||||
.ok_or(AuthorizationVerificationError::InvalidToken)?;
|
||||
|
||||
@@ -77,13 +87,18 @@ impl<F: Send> UserAuthorization<F> {
|
||||
pub async fn protected_form(
|
||||
self,
|
||||
conn: &mut PgConnection,
|
||||
now: DateTime<Utc>,
|
||||
) -> Result<(Session, F), AuthorizationVerificationError> {
|
||||
let form = match self.form {
|
||||
Some(f) => f,
|
||||
None => return Err(AuthorizationVerificationError::MissingForm),
|
||||
};
|
||||
|
||||
let (_token, session) = self.access_token.fetch(conn).await?;
|
||||
let (token, session) = self.access_token.fetch(conn).await?;
|
||||
|
||||
if !token.is_valid(now) || !session.is_valid() {
|
||||
return Err(AuthorizationVerificationError::InvalidToken);
|
||||
}
|
||||
|
||||
Ok((session, form))
|
||||
}
|
||||
@@ -92,8 +107,13 @@ impl<F: Send> UserAuthorization<F> {
|
||||
pub async fn protected(
|
||||
self,
|
||||
conn: &mut PgConnection,
|
||||
now: DateTime<Utc>,
|
||||
) -> Result<Session, AuthorizationVerificationError> {
|
||||
let (_token, session) = self.access_token.fetch(conn).await?;
|
||||
let (token, session) = self.access_token.fetch(conn).await?;
|
||||
|
||||
if !token.is_valid(now) || !session.is_valid() {
|
||||
return Err(AuthorizationVerificationError::InvalidToken);
|
||||
}
|
||||
|
||||
Ok(session)
|
||||
}
|
||||
|
@@ -44,7 +44,9 @@ pub use self::{
|
||||
AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client,
|
||||
InvalidRedirectUriError, JwksOrJwksUri, Pkce, Session, SessionState,
|
||||
},
|
||||
tokens::{AccessToken, RefreshToken, TokenFormatError, TokenType},
|
||||
tokens::{
|
||||
AccessToken, AccessTokenState, RefreshToken, RefreshTokenState, TokenFormatError, TokenType,
|
||||
},
|
||||
upstream_oauth2::{
|
||||
UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState,
|
||||
UpstreamOAuthLink, UpstreamOAuthProvider,
|
||||
|
@@ -19,23 +19,133 @@ use rand::{distributions::Alphanumeric, Rng};
|
||||
use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
|
||||
use crate::InvalidTransitionError;
|
||||
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq)]
|
||||
pub enum AccessTokenState {
|
||||
#[default]
|
||||
Valid,
|
||||
Revoked {
|
||||
revoked_at: DateTime<Utc>,
|
||||
},
|
||||
}
|
||||
|
||||
impl AccessTokenState {
|
||||
fn revoke(self, revoked_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
|
||||
match self {
|
||||
Self::Valid => Ok(Self::Revoked { revoked_at }),
|
||||
Self::Revoked { .. } => Err(InvalidTransitionError),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `true` if the refresh token state is [`Valid`].
|
||||
///
|
||||
/// [`Valid`]: RefreshTokenState::Valid
|
||||
#[must_use]
|
||||
pub fn is_valid(&self) -> bool {
|
||||
matches!(self, Self::Valid)
|
||||
}
|
||||
|
||||
/// Returns `true` if the refresh token state is [`Revoked`].
|
||||
///
|
||||
/// [`Revoked`]: RefreshTokenState::Revoked
|
||||
#[must_use]
|
||||
pub fn is_revoked(&self) -> bool {
|
||||
matches!(self, Self::Revoked { .. })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct AccessToken {
|
||||
pub id: Ulid,
|
||||
pub jti: String,
|
||||
pub state: AccessTokenState,
|
||||
pub session_id: Ulid,
|
||||
pub access_token: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub expires_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl AccessToken {
|
||||
#[must_use]
|
||||
pub fn jti(&self) -> String {
|
||||
self.id.to_string()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn is_valid(&self, now: DateTime<Utc>) -> bool {
|
||||
self.state.is_valid() && self.expires_at > now
|
||||
}
|
||||
|
||||
pub fn revoke(mut self, revoked_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
|
||||
self.state = self.state.revoke(revoked_at)?;
|
||||
Ok(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq)]
|
||||
pub enum RefreshTokenState {
|
||||
#[default]
|
||||
Valid,
|
||||
Consumed {
|
||||
consumed_at: DateTime<Utc>,
|
||||
},
|
||||
}
|
||||
|
||||
impl RefreshTokenState {
|
||||
fn consume(self, consumed_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
|
||||
match self {
|
||||
Self::Valid => Ok(Self::Consumed { consumed_at }),
|
||||
Self::Consumed { .. } => Err(InvalidTransitionError),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `true` if the refresh token state is [`Valid`].
|
||||
///
|
||||
/// [`Valid`]: RefreshTokenState::Valid
|
||||
#[must_use]
|
||||
pub fn is_valid(&self) -> bool {
|
||||
matches!(self, Self::Valid)
|
||||
}
|
||||
|
||||
/// Returns `true` if the refresh token state is [`Consumed`].
|
||||
///
|
||||
/// [`Consumed`]: RefreshTokenState::Consumed
|
||||
#[must_use]
|
||||
pub fn is_consumed(&self) -> bool {
|
||||
matches!(self, Self::Consumed { .. })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct RefreshToken {
|
||||
pub id: Ulid,
|
||||
pub state: RefreshTokenState,
|
||||
pub refresh_token: String,
|
||||
pub session_id: Ulid,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub access_token_id: Option<Ulid>,
|
||||
}
|
||||
|
||||
impl std::ops::Deref for RefreshToken {
|
||||
type Target = RefreshTokenState;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.state
|
||||
}
|
||||
}
|
||||
|
||||
impl RefreshToken {
|
||||
#[must_use]
|
||||
pub fn jti(&self) -> String {
|
||||
self.id.to_string()
|
||||
}
|
||||
|
||||
pub fn consume(mut self, consumed_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
|
||||
self.state = self.state.consume(consumed_at)?;
|
||||
Ok(self)
|
||||
}
|
||||
}
|
||||
|
||||
/// Type of token to generate or validate
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum TokenType {
|
||||
|
@@ -24,7 +24,8 @@ use mas_keystore::Encrypter;
|
||||
use mas_storage::{
|
||||
compat::{find_compat_access_token, find_compat_refresh_token, lookup_compat_session},
|
||||
oauth2::{
|
||||
access_token::lookup_active_access_token, refresh_token::lookup_active_refresh_token,
|
||||
access_token::find_access_token, refresh_token::lookup_refresh_token,
|
||||
OAuth2SessionRepository,
|
||||
},
|
||||
user::{BrowserSessionRepository, UserRepository},
|
||||
Clock, Repository,
|
||||
@@ -168,8 +169,17 @@ pub(crate) async fn post(
|
||||
|
||||
let reply = match token_type {
|
||||
TokenType::AccessToken => {
|
||||
let (token, session) = lookup_active_access_token(&mut conn, token)
|
||||
let token = find_access_token(&mut conn, token)
|
||||
.await?
|
||||
.filter(|t| t.is_valid(clock.now()))
|
||||
.ok_or(RouteError::UnknownToken)?;
|
||||
|
||||
let session = conn
|
||||
.oauth2_session()
|
||||
.lookup(token.session_id)
|
||||
.await?
|
||||
.filter(|s| s.is_valid())
|
||||
// XXX: is that the right error to bubble up?
|
||||
.ok_or(RouteError::UnknownToken)?;
|
||||
|
||||
let browser_session = conn
|
||||
@@ -191,13 +201,22 @@ pub(crate) async fn post(
|
||||
sub: Some(browser_session.user.sub),
|
||||
aud: None,
|
||||
iss: None,
|
||||
jti: None,
|
||||
jti: Some(token.jti()),
|
||||
}
|
||||
}
|
||||
|
||||
TokenType::RefreshToken => {
|
||||
let (token, session) = lookup_active_refresh_token(&mut conn, token)
|
||||
let token = lookup_refresh_token(&mut conn, token)
|
||||
.await?
|
||||
.filter(|t| t.is_valid())
|
||||
.ok_or(RouteError::UnknownToken)?;
|
||||
|
||||
let session = conn
|
||||
.oauth2_session()
|
||||
.lookup(token.session_id)
|
||||
.await?
|
||||
.filter(|s| s.is_valid())
|
||||
// XXX: is that the right error to bubble up?
|
||||
.ok_or(RouteError::UnknownToken)?;
|
||||
|
||||
let browser_session = conn
|
||||
@@ -219,7 +238,7 @@ pub(crate) async fn post(
|
||||
sub: Some(browser_session.user.sub),
|
||||
aud: None,
|
||||
iss: None,
|
||||
jti: None,
|
||||
jti: Some(token.jti()),
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -33,9 +33,9 @@ use mas_keystore::{Encrypter, Keystore};
|
||||
use mas_router::UrlBuilder;
|
||||
use mas_storage::{
|
||||
oauth2::{
|
||||
access_token::{add_access_token, revoke_access_token},
|
||||
access_token::{add_access_token, lookup_access_token, revoke_access_token},
|
||||
authorization_grant::{exchange_grant, lookup_grant_by_code},
|
||||
refresh_token::{add_refresh_token, consume_refresh_token, lookup_active_refresh_token},
|
||||
refresh_token::{add_refresh_token, consume_refresh_token, lookup_refresh_token},
|
||||
OAuth2SessionRepository,
|
||||
},
|
||||
user::BrowserSessionRepository,
|
||||
@@ -374,10 +374,20 @@ async fn refresh_token_grant(
|
||||
) -> Result<AccessTokenResponse, RouteError> {
|
||||
let (clock, mut rng) = crate::clock_and_rng();
|
||||
|
||||
let (refresh_token, session) = lookup_active_refresh_token(&mut txn, &grant.refresh_token)
|
||||
let refresh_token = lookup_refresh_token(&mut txn, &grant.refresh_token)
|
||||
.await?
|
||||
.ok_or(RouteError::InvalidGrant)?;
|
||||
|
||||
let session = txn
|
||||
.oauth2_session()
|
||||
.lookup(refresh_token.session_id)
|
||||
.await?
|
||||
.ok_or(RouteError::NoSuchOAuthSession)?;
|
||||
|
||||
if !refresh_token.is_valid() || !session.is_valid() {
|
||||
return Err(RouteError::InvalidGrant);
|
||||
}
|
||||
|
||||
if client.id != session.client_id {
|
||||
// As per https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
|
||||
return Err(RouteError::InvalidGrant);
|
||||
@@ -407,10 +417,12 @@ async fn refresh_token_grant(
|
||||
)
|
||||
.await?;
|
||||
|
||||
consume_refresh_token(&mut txn, &clock, &refresh_token).await?;
|
||||
let refresh_token = consume_refresh_token(&mut txn, &clock, refresh_token).await?;
|
||||
|
||||
if let Some(access_token_id) = refresh_token.access_token_id {
|
||||
revoke_access_token(&mut txn, &clock, access_token_id).await?;
|
||||
if let Some(access_token) = lookup_access_token(&mut txn, access_token_id).await? {
|
||||
revoke_access_token(&mut txn, &clock, access_token).await?;
|
||||
}
|
||||
}
|
||||
|
||||
let params = AccessTokenResponse::new(access_token_str)
|
||||
|
@@ -101,10 +101,10 @@ pub async fn get(
|
||||
State(key_store): State<Keystore>,
|
||||
user_authorization: UserAuthorization,
|
||||
) -> Result<Response, RouteError> {
|
||||
let (_clock, mut rng) = crate::clock_and_rng();
|
||||
let (clock, mut rng) = crate::clock_and_rng();
|
||||
let mut conn = pool.acquire().await?;
|
||||
|
||||
let session = user_authorization.protected(&mut conn).await?;
|
||||
let session = user_authorization.protected(&mut conn, clock.now()).await?;
|
||||
|
||||
let browser_session = conn
|
||||
.browser_session()
|
||||
|
@@ -583,74 +583,6 @@
|
||||
},
|
||||
"query": "\n DELETE FROM oauth2_access_tokens\n WHERE expires_at < $1\n "
|
||||
},
|
||||
"5f0e2aec0d7766d3674af3e68417921fec7068e83845e218a4a00d86487557f9": {
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"name": "oauth2_access_token_id",
|
||||
"ordinal": 0,
|
||||
"type_info": "Uuid"
|
||||
},
|
||||
{
|
||||
"name": "oauth2_access_token",
|
||||
"ordinal": 1,
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"name": "oauth2_access_token_created_at",
|
||||
"ordinal": 2,
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"name": "oauth2_access_token_expires_at",
|
||||
"ordinal": 3,
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"name": "oauth2_session_created_at",
|
||||
"ordinal": 4,
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"name": "oauth2_session_id!",
|
||||
"ordinal": 5,
|
||||
"type_info": "Uuid"
|
||||
},
|
||||
{
|
||||
"name": "oauth2_client_id!",
|
||||
"ordinal": 6,
|
||||
"type_info": "Uuid"
|
||||
},
|
||||
{
|
||||
"name": "scope!",
|
||||
"ordinal": 7,
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"name": "user_session_id!",
|
||||
"ordinal": 8,
|
||||
"type_info": "Uuid"
|
||||
}
|
||||
],
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Text"
|
||||
]
|
||||
}
|
||||
},
|
||||
"query": "\n SELECT at.oauth2_access_token_id\n , at.access_token AS \"oauth2_access_token\"\n , at.created_at AS \"oauth2_access_token_created_at\"\n , at.expires_at AS \"oauth2_access_token_expires_at\"\n , os.created_at AS \"oauth2_session_created_at\"\n , os.oauth2_session_id AS \"oauth2_session_id!\"\n , os.oauth2_client_id AS \"oauth2_client_id!\"\n , os.scope AS \"scope!\"\n , os.user_session_id AS \"user_session_id!\"\n\n FROM oauth2_access_tokens at\n INNER JOIN oauth2_sessions os\n USING (oauth2_session_id)\n\n WHERE at.access_token = $1\n AND at.revoked_at IS NULL\n AND os.finished_at IS NULL\n "
|
||||
},
|
||||
"5f6b7e38ef9bc3b39deabba277d0255fb8cfb2adaa65f47b78a8fac11d8c91c3": {
|
||||
"describe": {
|
||||
"columns": [],
|
||||
@@ -1612,6 +1544,56 @@
|
||||
},
|
||||
"query": "\n SELECT oauth2_authorization_grant_id\n , created_at AS oauth2_authorization_grant_created_at\n , cancelled_at AS oauth2_authorization_grant_cancelled_at\n , fulfilled_at AS oauth2_authorization_grant_fulfilled_at\n , exchanged_at AS oauth2_authorization_grant_exchanged_at\n , scope AS oauth2_authorization_grant_scope\n , state AS oauth2_authorization_grant_state\n , redirect_uri AS oauth2_authorization_grant_redirect_uri\n , response_mode AS oauth2_authorization_grant_response_mode\n , nonce AS oauth2_authorization_grant_nonce\n , max_age AS oauth2_authorization_grant_max_age\n , oauth2_client_id AS oauth2_client_id\n , authorization_code AS oauth2_authorization_grant_code\n , response_type_code AS oauth2_authorization_grant_response_type_code\n , response_type_id_token AS oauth2_authorization_grant_response_type_id_token\n , code_challenge AS oauth2_authorization_grant_code_challenge\n , code_challenge_method AS oauth2_authorization_grant_code_challenge_method\n , requires_consent AS oauth2_authorization_grant_requires_consent\n , oauth2_session_id AS \"oauth2_session_id?\"\n FROM\n oauth2_authorization_grants\n\n WHERE oauth2_authorization_grant_id = $1\n "
|
||||
},
|
||||
"b20e846843cf88810fbc0f4b0fa3159117f035841758d682d90c614c374f6059": {
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"name": "oauth2_access_token_id",
|
||||
"ordinal": 0,
|
||||
"type_info": "Uuid"
|
||||
},
|
||||
{
|
||||
"name": "access_token",
|
||||
"ordinal": 1,
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"name": "created_at",
|
||||
"ordinal": 2,
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"name": "expires_at",
|
||||
"ordinal": 3,
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"name": "revoked_at",
|
||||
"ordinal": 4,
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"name": "oauth2_session_id",
|
||||
"ordinal": 5,
|
||||
"type_info": "Uuid"
|
||||
}
|
||||
],
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Uuid"
|
||||
]
|
||||
}
|
||||
},
|
||||
"query": "\n SELECT oauth2_access_token_id\n , access_token\n , created_at\n , expires_at\n , revoked_at\n , oauth2_session_id\n\n FROM oauth2_access_tokens\n\n WHERE oauth2_access_token_id = $1\n "
|
||||
},
|
||||
"b26ae7dd28f8a756b55a76e80cdedd7be9ba26435ea4a914421483f8ed832537": {
|
||||
"describe": {
|
||||
"columns": [],
|
||||
@@ -1984,6 +1966,106 @@
|
||||
},
|
||||
"query": "\n INSERT INTO compat_sso_logins\n (compat_sso_login_id, login_token, redirect_uri, created_at)\n VALUES ($1, $2, $3, $4)\n "
|
||||
},
|
||||
"d1f1aac41bb2f0d194f9b3d846663c267865d0d22dd5fa8a668daf29dca88d36": {
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"name": "oauth2_refresh_token_id",
|
||||
"ordinal": 0,
|
||||
"type_info": "Uuid"
|
||||
},
|
||||
{
|
||||
"name": "refresh_token",
|
||||
"ordinal": 1,
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"name": "created_at",
|
||||
"ordinal": 2,
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"name": "consumed_at",
|
||||
"ordinal": 3,
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"name": "oauth2_access_token_id",
|
||||
"ordinal": 4,
|
||||
"type_info": "Uuid"
|
||||
},
|
||||
{
|
||||
"name": "oauth2_session_id",
|
||||
"ordinal": 5,
|
||||
"type_info": "Uuid"
|
||||
}
|
||||
],
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Text"
|
||||
]
|
||||
}
|
||||
},
|
||||
"query": "\n SELECT oauth2_refresh_token_id\n , refresh_token\n , created_at\n , consumed_at\n , oauth2_access_token_id\n , oauth2_session_id\n FROM oauth2_refresh_tokens\n\n WHERE refresh_token = $1\n "
|
||||
},
|
||||
"d2b1af24f88b2f05eb219f7cbdcfa9680bafe9f77fa1772097875b3718bd1aff": {
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"name": "oauth2_access_token_id",
|
||||
"ordinal": 0,
|
||||
"type_info": "Uuid"
|
||||
},
|
||||
{
|
||||
"name": "access_token",
|
||||
"ordinal": 1,
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"name": "created_at",
|
||||
"ordinal": 2,
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"name": "expires_at",
|
||||
"ordinal": 3,
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"name": "revoked_at",
|
||||
"ordinal": 4,
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"name": "oauth2_session_id",
|
||||
"ordinal": 5,
|
||||
"type_info": "Uuid"
|
||||
}
|
||||
],
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Text"
|
||||
]
|
||||
}
|
||||
},
|
||||
"query": "\n SELECT oauth2_access_token_id\n , access_token\n , created_at\n , expires_at\n , revoked_at\n , oauth2_session_id\n\n FROM oauth2_access_tokens\n\n WHERE access_token = $1\n "
|
||||
},
|
||||
"d8677b3b6ee594c230fad98c1aa1c6e3d983375bf5b701c7b52468e7f906abf9": {
|
||||
"describe": {
|
||||
"columns": [],
|
||||
@@ -2129,74 +2211,6 @@
|
||||
},
|
||||
"query": "\n UPDATE user_sessions\n SET finished_at = $1\n WHERE user_session_id = $2\n "
|
||||
},
|
||||
"e25b8071b59075c4be9fac283410ec4acf771fdf06076ef7bbb11bf086c4bc03": {
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"name": "oauth2_refresh_token_id",
|
||||
"ordinal": 0,
|
||||
"type_info": "Uuid"
|
||||
},
|
||||
{
|
||||
"name": "oauth2_refresh_token",
|
||||
"ordinal": 1,
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"name": "oauth2_refresh_token_created_at",
|
||||
"ordinal": 2,
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"name": "oauth2_access_token_id?",
|
||||
"ordinal": 3,
|
||||
"type_info": "Uuid"
|
||||
},
|
||||
{
|
||||
"name": "oauth2_session_created_at",
|
||||
"ordinal": 4,
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"name": "oauth2_session_id!",
|
||||
"ordinal": 5,
|
||||
"type_info": "Uuid"
|
||||
},
|
||||
{
|
||||
"name": "oauth2_client_id!",
|
||||
"ordinal": 6,
|
||||
"type_info": "Uuid"
|
||||
},
|
||||
{
|
||||
"name": "oauth2_session_scope!",
|
||||
"ordinal": 7,
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"name": "user_session_id!",
|
||||
"ordinal": 8,
|
||||
"type_info": "Uuid"
|
||||
}
|
||||
],
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Text"
|
||||
]
|
||||
}
|
||||
},
|
||||
"query": "\n SELECT rt.oauth2_refresh_token_id\n , rt.refresh_token AS oauth2_refresh_token\n , rt.created_at AS oauth2_refresh_token_created_at\n , rt.oauth2_access_token_id AS \"oauth2_access_token_id?\"\n , os.created_at AS \"oauth2_session_created_at\"\n , os.oauth2_session_id AS \"oauth2_session_id!\"\n , os.oauth2_client_id AS \"oauth2_client_id!\"\n , os.scope AS \"oauth2_session_scope!\"\n , os.user_session_id AS \"user_session_id!\"\n FROM oauth2_refresh_tokens rt\n INNER JOIN oauth2_sessions os\n USING (oauth2_session_id)\n\n WHERE rt.refresh_token = $1\n AND rt.consumed_at IS NULL\n AND rt.revoked_at IS NULL\n AND os.finished_at IS NULL\n "
|
||||
},
|
||||
"e6dc63984aced9e19c20e90e9cd75d6f6d7ade64f782697715ac4da077b2e1fc": {
|
||||
"describe": {
|
||||
"columns": [
|
||||
|
@@ -13,13 +13,13 @@
|
||||
// limitations under the License.
|
||||
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use mas_data_model::{AccessToken, Session, SessionState};
|
||||
use mas_data_model::{AccessToken, AccessTokenState, Session};
|
||||
use rand::Rng;
|
||||
use sqlx::{PgConnection, PgExecutor};
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{Clock, DatabaseError, DatabaseInconsistencyError};
|
||||
use crate::{Clock, DatabaseError, LookupResultExt};
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
@@ -63,8 +63,9 @@ pub async fn add_access_token(
|
||||
|
||||
Ok(AccessToken {
|
||||
id,
|
||||
state: AccessTokenState::default(),
|
||||
access_token,
|
||||
jti: id.to_string(),
|
||||
session_id: session.id,
|
||||
created_at,
|
||||
expires_at,
|
||||
})
|
||||
@@ -73,74 +74,59 @@ pub async fn add_access_token(
|
||||
#[derive(Debug)]
|
||||
pub struct OAuth2AccessTokenLookup {
|
||||
oauth2_access_token_id: Uuid,
|
||||
oauth2_access_token: String,
|
||||
oauth2_access_token_created_at: DateTime<Utc>,
|
||||
oauth2_access_token_expires_at: DateTime<Utc>,
|
||||
oauth2_session_created_at: DateTime<Utc>,
|
||||
oauth2_session_id: Uuid,
|
||||
oauth2_client_id: Uuid,
|
||||
scope: String,
|
||||
user_session_id: Uuid,
|
||||
access_token: String,
|
||||
created_at: DateTime<Utc>,
|
||||
expires_at: DateTime<Utc>,
|
||||
revoked_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub async fn lookup_active_access_token(
|
||||
impl From<OAuth2AccessTokenLookup> for AccessToken {
|
||||
fn from(value: OAuth2AccessTokenLookup) -> Self {
|
||||
let state = match value.revoked_at {
|
||||
None => AccessTokenState::Valid,
|
||||
Some(revoked_at) => AccessTokenState::Revoked { revoked_at },
|
||||
};
|
||||
|
||||
Self {
|
||||
id: value.oauth2_access_token_id.into(),
|
||||
state,
|
||||
session_id: value.oauth2_session_id.into(),
|
||||
access_token: value.access_token,
|
||||
created_at: value.created_at,
|
||||
expires_at: value.expires_at,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, err)]
|
||||
pub async fn find_access_token(
|
||||
conn: &mut PgConnection,
|
||||
token: &str,
|
||||
) -> Result<Option<(AccessToken, Session)>, DatabaseError> {
|
||||
) -> Result<Option<AccessToken>, DatabaseError> {
|
||||
let res = sqlx::query_as!(
|
||||
OAuth2AccessTokenLookup,
|
||||
r#"
|
||||
SELECT at.oauth2_access_token_id
|
||||
, at.access_token AS "oauth2_access_token"
|
||||
, at.created_at AS "oauth2_access_token_created_at"
|
||||
, at.expires_at AS "oauth2_access_token_expires_at"
|
||||
, os.created_at AS "oauth2_session_created_at"
|
||||
, os.oauth2_session_id AS "oauth2_session_id!"
|
||||
, os.oauth2_client_id AS "oauth2_client_id!"
|
||||
, os.scope AS "scope!"
|
||||
, os.user_session_id AS "user_session_id!"
|
||||
SELECT oauth2_access_token_id
|
||||
, access_token
|
||||
, created_at
|
||||
, expires_at
|
||||
, revoked_at
|
||||
, oauth2_session_id
|
||||
|
||||
FROM oauth2_access_tokens at
|
||||
INNER JOIN oauth2_sessions os
|
||||
USING (oauth2_session_id)
|
||||
FROM oauth2_access_tokens
|
||||
|
||||
WHERE at.access_token = $1
|
||||
AND at.revoked_at IS NULL
|
||||
AND os.finished_at IS NULL
|
||||
WHERE access_token = $1
|
||||
"#,
|
||||
token,
|
||||
)
|
||||
.fetch_one(&mut *conn)
|
||||
.await?;
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let access_token_id = Ulid::from(res.oauth2_access_token_id);
|
||||
let access_token = AccessToken {
|
||||
id: access_token_id,
|
||||
jti: access_token_id.to_string(),
|
||||
access_token: res.oauth2_access_token,
|
||||
created_at: res.oauth2_access_token_created_at,
|
||||
expires_at: res.oauth2_access_token_expires_at,
|
||||
};
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
let session_id = res.oauth2_session_id.into();
|
||||
let scope = res.scope.parse().map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_sessions")
|
||||
.column("scope")
|
||||
.row(session_id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let session = Session {
|
||||
id: session_id,
|
||||
state: SessionState::Valid,
|
||||
created_at: res.oauth2_session_created_at,
|
||||
client_id: res.oauth2_client_id.into(),
|
||||
user_session_id: res.user_session_id.into(),
|
||||
scope,
|
||||
};
|
||||
|
||||
Ok(Some((access_token, session)))
|
||||
Ok(Some(res.into()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
@@ -148,11 +134,48 @@ pub async fn lookup_active_access_token(
|
||||
fields(access_token.id = %access_token_id),
|
||||
err,
|
||||
)]
|
||||
pub async fn lookup_access_token(
|
||||
conn: &mut PgConnection,
|
||||
access_token_id: Ulid,
|
||||
) -> Result<Option<AccessToken>, DatabaseError> {
|
||||
let res = sqlx::query_as!(
|
||||
OAuth2AccessTokenLookup,
|
||||
r#"
|
||||
SELECT oauth2_access_token_id
|
||||
, access_token
|
||||
, created_at
|
||||
, expires_at
|
||||
, revoked_at
|
||||
, oauth2_session_id
|
||||
|
||||
FROM oauth2_access_tokens
|
||||
|
||||
WHERE oauth2_access_token_id = $1
|
||||
"#,
|
||||
Uuid::from(access_token_id),
|
||||
)
|
||||
.fetch_one(&mut *conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.into()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
%access_token.id,
|
||||
session.id = %access_token.session_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn revoke_access_token(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
access_token_id: Ulid,
|
||||
) -> Result<(), DatabaseError> {
|
||||
access_token: AccessToken,
|
||||
) -> Result<AccessToken, DatabaseError> {
|
||||
let revoked_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
@@ -160,13 +183,17 @@ pub async fn revoke_access_token(
|
||||
SET revoked_at = $2
|
||||
WHERE oauth2_access_token_id = $1
|
||||
"#,
|
||||
Uuid::from(access_token_id),
|
||||
Uuid::from(access_token.id),
|
||||
revoked_at,
|
||||
)
|
||||
.execute(executor)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
access_token
|
||||
.revoke(revoked_at)
|
||||
.map_err(DatabaseError::to_invalid_operation)
|
||||
}
|
||||
|
||||
pub async fn cleanup_expired(
|
||||
|
@@ -13,13 +13,13 @@
|
||||
// limitations under the License.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{AccessToken, RefreshToken, Session, SessionState};
|
||||
use mas_data_model::{AccessToken, RefreshToken, RefreshTokenState, Session};
|
||||
use rand::Rng;
|
||||
use sqlx::{PgConnection, PgExecutor};
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{Clock, DatabaseError, DatabaseInconsistencyError};
|
||||
use crate::{Clock, DatabaseError};
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
@@ -62,6 +62,8 @@ pub async fn add_refresh_token(
|
||||
|
||||
Ok(RefreshToken {
|
||||
id,
|
||||
state: RefreshTokenState::default(),
|
||||
session_id: session.id,
|
||||
refresh_token,
|
||||
access_token_id: Some(access_token.id),
|
||||
created_at,
|
||||
@@ -70,73 +72,52 @@ pub async fn add_refresh_token(
|
||||
|
||||
struct OAuth2RefreshTokenLookup {
|
||||
oauth2_refresh_token_id: Uuid,
|
||||
oauth2_refresh_token: String,
|
||||
oauth2_refresh_token_created_at: DateTime<Utc>,
|
||||
refresh_token: String,
|
||||
created_at: DateTime<Utc>,
|
||||
consumed_at: Option<DateTime<Utc>>,
|
||||
oauth2_access_token_id: Option<Uuid>,
|
||||
oauth2_session_created_at: DateTime<Utc>,
|
||||
oauth2_session_id: Uuid,
|
||||
oauth2_client_id: Uuid,
|
||||
oauth2_session_scope: String,
|
||||
user_session_id: Uuid,
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, err)]
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub async fn lookup_active_refresh_token(
|
||||
pub async fn lookup_refresh_token(
|
||||
conn: &mut PgConnection,
|
||||
token: &str,
|
||||
) -> Result<Option<(RefreshToken, Session)>, DatabaseError> {
|
||||
) -> Result<Option<RefreshToken>, DatabaseError> {
|
||||
let res = sqlx::query_as!(
|
||||
OAuth2RefreshTokenLookup,
|
||||
r#"
|
||||
SELECT rt.oauth2_refresh_token_id
|
||||
, rt.refresh_token AS oauth2_refresh_token
|
||||
, rt.created_at AS oauth2_refresh_token_created_at
|
||||
, rt.oauth2_access_token_id AS "oauth2_access_token_id?"
|
||||
, os.created_at AS "oauth2_session_created_at"
|
||||
, os.oauth2_session_id AS "oauth2_session_id!"
|
||||
, os.oauth2_client_id AS "oauth2_client_id!"
|
||||
, os.scope AS "oauth2_session_scope!"
|
||||
, os.user_session_id AS "user_session_id!"
|
||||
FROM oauth2_refresh_tokens rt
|
||||
INNER JOIN oauth2_sessions os
|
||||
USING (oauth2_session_id)
|
||||
SELECT oauth2_refresh_token_id
|
||||
, refresh_token
|
||||
, created_at
|
||||
, consumed_at
|
||||
, oauth2_access_token_id
|
||||
, oauth2_session_id
|
||||
FROM oauth2_refresh_tokens
|
||||
|
||||
WHERE rt.refresh_token = $1
|
||||
AND rt.consumed_at IS NULL
|
||||
AND rt.revoked_at IS NULL
|
||||
AND os.finished_at IS NULL
|
||||
WHERE refresh_token = $1
|
||||
"#,
|
||||
token,
|
||||
)
|
||||
.fetch_one(&mut *conn)
|
||||
.await?;
|
||||
|
||||
let state = match res.consumed_at {
|
||||
None => RefreshTokenState::Valid,
|
||||
Some(consumed_at) => RefreshTokenState::Consumed { consumed_at },
|
||||
};
|
||||
|
||||
let refresh_token = RefreshToken {
|
||||
id: res.oauth2_refresh_token_id.into(),
|
||||
refresh_token: res.oauth2_refresh_token,
|
||||
created_at: res.oauth2_refresh_token_created_at,
|
||||
state,
|
||||
session_id: res.oauth2_session_id.into(),
|
||||
refresh_token: res.refresh_token,
|
||||
created_at: res.created_at,
|
||||
access_token_id: res.oauth2_access_token_id.map(Ulid::from),
|
||||
};
|
||||
|
||||
let session_id = res.oauth2_session_id.into();
|
||||
let scope = res.oauth2_session_scope.parse().map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_sessions")
|
||||
.column("scope")
|
||||
.row(session_id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let session = Session {
|
||||
id: session_id,
|
||||
state: SessionState::Valid,
|
||||
created_at: res.oauth2_session_created_at,
|
||||
client_id: res.oauth2_client_id.into(),
|
||||
user_session_id: res.user_session_id.into(),
|
||||
scope,
|
||||
};
|
||||
|
||||
Ok(Some((refresh_token, session)))
|
||||
Ok(Some(refresh_token))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
@@ -149,8 +130,8 @@ pub async fn lookup_active_refresh_token(
|
||||
pub async fn consume_refresh_token(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
refresh_token: &RefreshToken,
|
||||
) -> Result<(), DatabaseError> {
|
||||
refresh_token: RefreshToken,
|
||||
) -> Result<RefreshToken, DatabaseError> {
|
||||
let consumed_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
@@ -164,5 +145,9 @@ pub async fn consume_refresh_token(
|
||||
.execute(executor)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
refresh_token
|
||||
.consume(consumed_at)
|
||||
.map_err(DatabaseError::to_invalid_operation)
|
||||
}
|
||||
|
Reference in New Issue
Block a user