1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-09 04:22:45 +03:00

storage: cleanup access/refresh token lookups

This commit is contained in:
Quentin Gliech
2023-01-11 12:14:52 +01:00
parent 920869b583
commit 9f0c9f1466
9 changed files with 452 additions and 263 deletions

View File

@@ -24,10 +24,14 @@ use axum::{
response::{IntoResponse, Response},
BoxError,
};
use chrono::{DateTime, Utc};
use headers::{authorization::Bearer, Authorization, Header, HeaderMapExt, HeaderName};
use http::{header::WWW_AUTHENTICATE, HeaderMap, HeaderValue, Request, StatusCode};
use mas_data_model::Session;
use mas_storage::{oauth2::access_token::lookup_active_access_token, DatabaseError};
use mas_storage::{
oauth2::{access_token::find_access_token, OAuth2SessionRepository},
DatabaseError, Repository,
};
use serde::{de::DeserializeOwned, Deserialize};
use sqlx::PgConnection;
use thiserror::Error;
@@ -49,7 +53,7 @@ enum AccessToken {
}
impl AccessToken {
pub async fn fetch(
async fn fetch(
&self,
conn: &mut PgConnection,
) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError> {
@@ -58,7 +62,13 @@ impl AccessToken {
AccessToken::None => return Err(AuthorizationVerificationError::MissingToken),
};
let (token, session) = lookup_active_access_token(conn, token.as_str())
let token = find_access_token(conn, token.as_str())
.await?
.ok_or(AuthorizationVerificationError::InvalidToken)?;
let session = conn
.oauth2_session()
.lookup(token.session_id)
.await?
.ok_or(AuthorizationVerificationError::InvalidToken)?;
@@ -77,13 +87,18 @@ impl<F: Send> UserAuthorization<F> {
pub async fn protected_form(
self,
conn: &mut PgConnection,
now: DateTime<Utc>,
) -> Result<(Session, F), AuthorizationVerificationError> {
let form = match self.form {
Some(f) => f,
None => return Err(AuthorizationVerificationError::MissingForm),
};
let (_token, session) = self.access_token.fetch(conn).await?;
let (token, session) = self.access_token.fetch(conn).await?;
if !token.is_valid(now) || !session.is_valid() {
return Err(AuthorizationVerificationError::InvalidToken);
}
Ok((session, form))
}
@@ -92,8 +107,13 @@ impl<F: Send> UserAuthorization<F> {
pub async fn protected(
self,
conn: &mut PgConnection,
now: DateTime<Utc>,
) -> Result<Session, AuthorizationVerificationError> {
let (_token, session) = self.access_token.fetch(conn).await?;
let (token, session) = self.access_token.fetch(conn).await?;
if !token.is_valid(now) || !session.is_valid() {
return Err(AuthorizationVerificationError::InvalidToken);
}
Ok(session)
}

View File

@@ -44,7 +44,9 @@ pub use self::{
AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client,
InvalidRedirectUriError, JwksOrJwksUri, Pkce, Session, SessionState,
},
tokens::{AccessToken, RefreshToken, TokenFormatError, TokenType},
tokens::{
AccessToken, AccessTokenState, RefreshToken, RefreshTokenState, TokenFormatError, TokenType,
},
upstream_oauth2::{
UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState,
UpstreamOAuthLink, UpstreamOAuthProvider,

View File

@@ -19,23 +19,133 @@ use rand::{distributions::Alphanumeric, Rng};
use thiserror::Error;
use ulid::Ulid;
use crate::InvalidTransitionError;
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum AccessTokenState {
#[default]
Valid,
Revoked {
revoked_at: DateTime<Utc>,
},
}
impl AccessTokenState {
fn revoke(self, revoked_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
match self {
Self::Valid => Ok(Self::Revoked { revoked_at }),
Self::Revoked { .. } => Err(InvalidTransitionError),
}
}
/// Returns `true` if the refresh token state is [`Valid`].
///
/// [`Valid`]: RefreshTokenState::Valid
#[must_use]
pub fn is_valid(&self) -> bool {
matches!(self, Self::Valid)
}
/// Returns `true` if the refresh token state is [`Revoked`].
///
/// [`Revoked`]: RefreshTokenState::Revoked
#[must_use]
pub fn is_revoked(&self) -> bool {
matches!(self, Self::Revoked { .. })
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AccessToken {
pub id: Ulid,
pub jti: String,
pub state: AccessTokenState,
pub session_id: Ulid,
pub access_token: String,
pub created_at: DateTime<Utc>,
pub expires_at: DateTime<Utc>,
}
impl AccessToken {
#[must_use]
pub fn jti(&self) -> String {
self.id.to_string()
}
#[must_use]
pub fn is_valid(&self, now: DateTime<Utc>) -> bool {
self.state.is_valid() && self.expires_at > now
}
pub fn revoke(mut self, revoked_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
self.state = self.state.revoke(revoked_at)?;
Ok(self)
}
}
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum RefreshTokenState {
#[default]
Valid,
Consumed {
consumed_at: DateTime<Utc>,
},
}
impl RefreshTokenState {
fn consume(self, consumed_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
match self {
Self::Valid => Ok(Self::Consumed { consumed_at }),
Self::Consumed { .. } => Err(InvalidTransitionError),
}
}
/// Returns `true` if the refresh token state is [`Valid`].
///
/// [`Valid`]: RefreshTokenState::Valid
#[must_use]
pub fn is_valid(&self) -> bool {
matches!(self, Self::Valid)
}
/// Returns `true` if the refresh token state is [`Consumed`].
///
/// [`Consumed`]: RefreshTokenState::Consumed
#[must_use]
pub fn is_consumed(&self) -> bool {
matches!(self, Self::Consumed { .. })
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RefreshToken {
pub id: Ulid,
pub state: RefreshTokenState,
pub refresh_token: String,
pub session_id: Ulid,
pub created_at: DateTime<Utc>,
pub access_token_id: Option<Ulid>,
}
impl std::ops::Deref for RefreshToken {
type Target = RefreshTokenState;
fn deref(&self) -> &Self::Target {
&self.state
}
}
impl RefreshToken {
#[must_use]
pub fn jti(&self) -> String {
self.id.to_string()
}
pub fn consume(mut self, consumed_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
self.state = self.state.consume(consumed_at)?;
Ok(self)
}
}
/// Type of token to generate or validate
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenType {

View File

@@ -24,7 +24,8 @@ use mas_keystore::Encrypter;
use mas_storage::{
compat::{find_compat_access_token, find_compat_refresh_token, lookup_compat_session},
oauth2::{
access_token::lookup_active_access_token, refresh_token::lookup_active_refresh_token,
access_token::find_access_token, refresh_token::lookup_refresh_token,
OAuth2SessionRepository,
},
user::{BrowserSessionRepository, UserRepository},
Clock, Repository,
@@ -168,8 +169,17 @@ pub(crate) async fn post(
let reply = match token_type {
TokenType::AccessToken => {
let (token, session) = lookup_active_access_token(&mut conn, token)
let token = find_access_token(&mut conn, token)
.await?
.filter(|t| t.is_valid(clock.now()))
.ok_or(RouteError::UnknownToken)?;
let session = conn
.oauth2_session()
.lookup(token.session_id)
.await?
.filter(|s| s.is_valid())
// XXX: is that the right error to bubble up?
.ok_or(RouteError::UnknownToken)?;
let browser_session = conn
@@ -191,13 +201,22 @@ pub(crate) async fn post(
sub: Some(browser_session.user.sub),
aud: None,
iss: None,
jti: None,
jti: Some(token.jti()),
}
}
TokenType::RefreshToken => {
let (token, session) = lookup_active_refresh_token(&mut conn, token)
let token = lookup_refresh_token(&mut conn, token)
.await?
.filter(|t| t.is_valid())
.ok_or(RouteError::UnknownToken)?;
let session = conn
.oauth2_session()
.lookup(token.session_id)
.await?
.filter(|s| s.is_valid())
// XXX: is that the right error to bubble up?
.ok_or(RouteError::UnknownToken)?;
let browser_session = conn
@@ -219,7 +238,7 @@ pub(crate) async fn post(
sub: Some(browser_session.user.sub),
aud: None,
iss: None,
jti: None,
jti: Some(token.jti()),
}
}

View File

@@ -33,9 +33,9 @@ use mas_keystore::{Encrypter, Keystore};
use mas_router::UrlBuilder;
use mas_storage::{
oauth2::{
access_token::{add_access_token, revoke_access_token},
access_token::{add_access_token, lookup_access_token, revoke_access_token},
authorization_grant::{exchange_grant, lookup_grant_by_code},
refresh_token::{add_refresh_token, consume_refresh_token, lookup_active_refresh_token},
refresh_token::{add_refresh_token, consume_refresh_token, lookup_refresh_token},
OAuth2SessionRepository,
},
user::BrowserSessionRepository,
@@ -374,10 +374,20 @@ async fn refresh_token_grant(
) -> Result<AccessTokenResponse, RouteError> {
let (clock, mut rng) = crate::clock_and_rng();
let (refresh_token, session) = lookup_active_refresh_token(&mut txn, &grant.refresh_token)
let refresh_token = lookup_refresh_token(&mut txn, &grant.refresh_token)
.await?
.ok_or(RouteError::InvalidGrant)?;
let session = txn
.oauth2_session()
.lookup(refresh_token.session_id)
.await?
.ok_or(RouteError::NoSuchOAuthSession)?;
if !refresh_token.is_valid() || !session.is_valid() {
return Err(RouteError::InvalidGrant);
}
if client.id != session.client_id {
// As per https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
return Err(RouteError::InvalidGrant);
@@ -407,10 +417,12 @@ async fn refresh_token_grant(
)
.await?;
consume_refresh_token(&mut txn, &clock, &refresh_token).await?;
let refresh_token = consume_refresh_token(&mut txn, &clock, refresh_token).await?;
if let Some(access_token_id) = refresh_token.access_token_id {
revoke_access_token(&mut txn, &clock, access_token_id).await?;
if let Some(access_token) = lookup_access_token(&mut txn, access_token_id).await? {
revoke_access_token(&mut txn, &clock, access_token).await?;
}
}
let params = AccessTokenResponse::new(access_token_str)

View File

@@ -101,10 +101,10 @@ pub async fn get(
State(key_store): State<Keystore>,
user_authorization: UserAuthorization,
) -> Result<Response, RouteError> {
let (_clock, mut rng) = crate::clock_and_rng();
let (clock, mut rng) = crate::clock_and_rng();
let mut conn = pool.acquire().await?;
let session = user_authorization.protected(&mut conn).await?;
let session = user_authorization.protected(&mut conn, clock.now()).await?;
let browser_session = conn
.browser_session()

View File

@@ -583,74 +583,6 @@
},
"query": "\n DELETE FROM oauth2_access_tokens\n WHERE expires_at < $1\n "
},
"5f0e2aec0d7766d3674af3e68417921fec7068e83845e218a4a00d86487557f9": {
"describe": {
"columns": [
{
"name": "oauth2_access_token_id",
"ordinal": 0,
"type_info": "Uuid"
},
{
"name": "oauth2_access_token",
"ordinal": 1,
"type_info": "Text"
},
{
"name": "oauth2_access_token_created_at",
"ordinal": 2,
"type_info": "Timestamptz"
},
{
"name": "oauth2_access_token_expires_at",
"ordinal": 3,
"type_info": "Timestamptz"
},
{
"name": "oauth2_session_created_at",
"ordinal": 4,
"type_info": "Timestamptz"
},
{
"name": "oauth2_session_id!",
"ordinal": 5,
"type_info": "Uuid"
},
{
"name": "oauth2_client_id!",
"ordinal": 6,
"type_info": "Uuid"
},
{
"name": "scope!",
"ordinal": 7,
"type_info": "Text"
},
{
"name": "user_session_id!",
"ordinal": 8,
"type_info": "Uuid"
}
],
"nullable": [
false,
false,
false,
false,
false,
false,
false,
false,
false
],
"parameters": {
"Left": [
"Text"
]
}
},
"query": "\n SELECT at.oauth2_access_token_id\n , at.access_token AS \"oauth2_access_token\"\n , at.created_at AS \"oauth2_access_token_created_at\"\n , at.expires_at AS \"oauth2_access_token_expires_at\"\n , os.created_at AS \"oauth2_session_created_at\"\n , os.oauth2_session_id AS \"oauth2_session_id!\"\n , os.oauth2_client_id AS \"oauth2_client_id!\"\n , os.scope AS \"scope!\"\n , os.user_session_id AS \"user_session_id!\"\n\n FROM oauth2_access_tokens at\n INNER JOIN oauth2_sessions os\n USING (oauth2_session_id)\n\n WHERE at.access_token = $1\n AND at.revoked_at IS NULL\n AND os.finished_at IS NULL\n "
},
"5f6b7e38ef9bc3b39deabba277d0255fb8cfb2adaa65f47b78a8fac11d8c91c3": {
"describe": {
"columns": [],
@@ -1612,6 +1544,56 @@
},
"query": "\n SELECT oauth2_authorization_grant_id\n , created_at AS oauth2_authorization_grant_created_at\n , cancelled_at AS oauth2_authorization_grant_cancelled_at\n , fulfilled_at AS oauth2_authorization_grant_fulfilled_at\n , exchanged_at AS oauth2_authorization_grant_exchanged_at\n , scope AS oauth2_authorization_grant_scope\n , state AS oauth2_authorization_grant_state\n , redirect_uri AS oauth2_authorization_grant_redirect_uri\n , response_mode AS oauth2_authorization_grant_response_mode\n , nonce AS oauth2_authorization_grant_nonce\n , max_age AS oauth2_authorization_grant_max_age\n , oauth2_client_id AS oauth2_client_id\n , authorization_code AS oauth2_authorization_grant_code\n , response_type_code AS oauth2_authorization_grant_response_type_code\n , response_type_id_token AS oauth2_authorization_grant_response_type_id_token\n , code_challenge AS oauth2_authorization_grant_code_challenge\n , code_challenge_method AS oauth2_authorization_grant_code_challenge_method\n , requires_consent AS oauth2_authorization_grant_requires_consent\n , oauth2_session_id AS \"oauth2_session_id?\"\n FROM\n oauth2_authorization_grants\n\n WHERE oauth2_authorization_grant_id = $1\n "
},
"b20e846843cf88810fbc0f4b0fa3159117f035841758d682d90c614c374f6059": {
"describe": {
"columns": [
{
"name": "oauth2_access_token_id",
"ordinal": 0,
"type_info": "Uuid"
},
{
"name": "access_token",
"ordinal": 1,
"type_info": "Text"
},
{
"name": "created_at",
"ordinal": 2,
"type_info": "Timestamptz"
},
{
"name": "expires_at",
"ordinal": 3,
"type_info": "Timestamptz"
},
{
"name": "revoked_at",
"ordinal": 4,
"type_info": "Timestamptz"
},
{
"name": "oauth2_session_id",
"ordinal": 5,
"type_info": "Uuid"
}
],
"nullable": [
false,
false,
false,
false,
true,
false
],
"parameters": {
"Left": [
"Uuid"
]
}
},
"query": "\n SELECT oauth2_access_token_id\n , access_token\n , created_at\n , expires_at\n , revoked_at\n , oauth2_session_id\n\n FROM oauth2_access_tokens\n\n WHERE oauth2_access_token_id = $1\n "
},
"b26ae7dd28f8a756b55a76e80cdedd7be9ba26435ea4a914421483f8ed832537": {
"describe": {
"columns": [],
@@ -1984,6 +1966,106 @@
},
"query": "\n INSERT INTO compat_sso_logins\n (compat_sso_login_id, login_token, redirect_uri, created_at)\n VALUES ($1, $2, $3, $4)\n "
},
"d1f1aac41bb2f0d194f9b3d846663c267865d0d22dd5fa8a668daf29dca88d36": {
"describe": {
"columns": [
{
"name": "oauth2_refresh_token_id",
"ordinal": 0,
"type_info": "Uuid"
},
{
"name": "refresh_token",
"ordinal": 1,
"type_info": "Text"
},
{
"name": "created_at",
"ordinal": 2,
"type_info": "Timestamptz"
},
{
"name": "consumed_at",
"ordinal": 3,
"type_info": "Timestamptz"
},
{
"name": "oauth2_access_token_id",
"ordinal": 4,
"type_info": "Uuid"
},
{
"name": "oauth2_session_id",
"ordinal": 5,
"type_info": "Uuid"
}
],
"nullable": [
false,
false,
false,
true,
true,
false
],
"parameters": {
"Left": [
"Text"
]
}
},
"query": "\n SELECT oauth2_refresh_token_id\n , refresh_token\n , created_at\n , consumed_at\n , oauth2_access_token_id\n , oauth2_session_id\n FROM oauth2_refresh_tokens\n\n WHERE refresh_token = $1\n "
},
"d2b1af24f88b2f05eb219f7cbdcfa9680bafe9f77fa1772097875b3718bd1aff": {
"describe": {
"columns": [
{
"name": "oauth2_access_token_id",
"ordinal": 0,
"type_info": "Uuid"
},
{
"name": "access_token",
"ordinal": 1,
"type_info": "Text"
},
{
"name": "created_at",
"ordinal": 2,
"type_info": "Timestamptz"
},
{
"name": "expires_at",
"ordinal": 3,
"type_info": "Timestamptz"
},
{
"name": "revoked_at",
"ordinal": 4,
"type_info": "Timestamptz"
},
{
"name": "oauth2_session_id",
"ordinal": 5,
"type_info": "Uuid"
}
],
"nullable": [
false,
false,
false,
false,
true,
false
],
"parameters": {
"Left": [
"Text"
]
}
},
"query": "\n SELECT oauth2_access_token_id\n , access_token\n , created_at\n , expires_at\n , revoked_at\n , oauth2_session_id\n\n FROM oauth2_access_tokens\n\n WHERE access_token = $1\n "
},
"d8677b3b6ee594c230fad98c1aa1c6e3d983375bf5b701c7b52468e7f906abf9": {
"describe": {
"columns": [],
@@ -2129,74 +2211,6 @@
},
"query": "\n UPDATE user_sessions\n SET finished_at = $1\n WHERE user_session_id = $2\n "
},
"e25b8071b59075c4be9fac283410ec4acf771fdf06076ef7bbb11bf086c4bc03": {
"describe": {
"columns": [
{
"name": "oauth2_refresh_token_id",
"ordinal": 0,
"type_info": "Uuid"
},
{
"name": "oauth2_refresh_token",
"ordinal": 1,
"type_info": "Text"
},
{
"name": "oauth2_refresh_token_created_at",
"ordinal": 2,
"type_info": "Timestamptz"
},
{
"name": "oauth2_access_token_id?",
"ordinal": 3,
"type_info": "Uuid"
},
{
"name": "oauth2_session_created_at",
"ordinal": 4,
"type_info": "Timestamptz"
},
{
"name": "oauth2_session_id!",
"ordinal": 5,
"type_info": "Uuid"
},
{
"name": "oauth2_client_id!",
"ordinal": 6,
"type_info": "Uuid"
},
{
"name": "oauth2_session_scope!",
"ordinal": 7,
"type_info": "Text"
},
{
"name": "user_session_id!",
"ordinal": 8,
"type_info": "Uuid"
}
],
"nullable": [
false,
false,
false,
true,
false,
false,
false,
false,
false
],
"parameters": {
"Left": [
"Text"
]
}
},
"query": "\n SELECT rt.oauth2_refresh_token_id\n , rt.refresh_token AS oauth2_refresh_token\n , rt.created_at AS oauth2_refresh_token_created_at\n , rt.oauth2_access_token_id AS \"oauth2_access_token_id?\"\n , os.created_at AS \"oauth2_session_created_at\"\n , os.oauth2_session_id AS \"oauth2_session_id!\"\n , os.oauth2_client_id AS \"oauth2_client_id!\"\n , os.scope AS \"oauth2_session_scope!\"\n , os.user_session_id AS \"user_session_id!\"\n FROM oauth2_refresh_tokens rt\n INNER JOIN oauth2_sessions os\n USING (oauth2_session_id)\n\n WHERE rt.refresh_token = $1\n AND rt.consumed_at IS NULL\n AND rt.revoked_at IS NULL\n AND os.finished_at IS NULL\n "
},
"e6dc63984aced9e19c20e90e9cd75d6f6d7ade64f782697715ac4da077b2e1fc": {
"describe": {
"columns": [

View File

@@ -13,13 +13,13 @@
// limitations under the License.
use chrono::{DateTime, Duration, Utc};
use mas_data_model::{AccessToken, Session, SessionState};
use mas_data_model::{AccessToken, AccessTokenState, Session};
use rand::Rng;
use sqlx::{PgConnection, PgExecutor};
use ulid::Ulid;
use uuid::Uuid;
use crate::{Clock, DatabaseError, DatabaseInconsistencyError};
use crate::{Clock, DatabaseError, LookupResultExt};
#[tracing::instrument(
skip_all,
@@ -63,8 +63,9 @@ pub async fn add_access_token(
Ok(AccessToken {
id,
state: AccessTokenState::default(),
access_token,
jti: id.to_string(),
session_id: session.id,
created_at,
expires_at,
})
@@ -73,74 +74,59 @@ pub async fn add_access_token(
#[derive(Debug)]
pub struct OAuth2AccessTokenLookup {
oauth2_access_token_id: Uuid,
oauth2_access_token: String,
oauth2_access_token_created_at: DateTime<Utc>,
oauth2_access_token_expires_at: DateTime<Utc>,
oauth2_session_created_at: DateTime<Utc>,
oauth2_session_id: Uuid,
oauth2_client_id: Uuid,
scope: String,
user_session_id: Uuid,
access_token: String,
created_at: DateTime<Utc>,
expires_at: DateTime<Utc>,
revoked_at: Option<DateTime<Utc>>,
}
#[allow(clippy::too_many_lines)]
pub async fn lookup_active_access_token(
impl From<OAuth2AccessTokenLookup> for AccessToken {
fn from(value: OAuth2AccessTokenLookup) -> Self {
let state = match value.revoked_at {
None => AccessTokenState::Valid,
Some(revoked_at) => AccessTokenState::Revoked { revoked_at },
};
Self {
id: value.oauth2_access_token_id.into(),
state,
session_id: value.oauth2_session_id.into(),
access_token: value.access_token,
created_at: value.created_at,
expires_at: value.expires_at,
}
}
}
#[tracing::instrument(skip_all, err)]
pub async fn find_access_token(
conn: &mut PgConnection,
token: &str,
) -> Result<Option<(AccessToken, Session)>, DatabaseError> {
) -> Result<Option<AccessToken>, DatabaseError> {
let res = sqlx::query_as!(
OAuth2AccessTokenLookup,
r#"
SELECT at.oauth2_access_token_id
, at.access_token AS "oauth2_access_token"
, at.created_at AS "oauth2_access_token_created_at"
, at.expires_at AS "oauth2_access_token_expires_at"
, os.created_at AS "oauth2_session_created_at"
, os.oauth2_session_id AS "oauth2_session_id!"
, os.oauth2_client_id AS "oauth2_client_id!"
, os.scope AS "scope!"
, os.user_session_id AS "user_session_id!"
SELECT oauth2_access_token_id
, access_token
, created_at
, expires_at
, revoked_at
, oauth2_session_id
FROM oauth2_access_tokens at
INNER JOIN oauth2_sessions os
USING (oauth2_session_id)
FROM oauth2_access_tokens
WHERE at.access_token = $1
AND at.revoked_at IS NULL
AND os.finished_at IS NULL
WHERE access_token = $1
"#,
token,
)
.fetch_one(&mut *conn)
.await?;
.await
.to_option()?;
let access_token_id = Ulid::from(res.oauth2_access_token_id);
let access_token = AccessToken {
id: access_token_id,
jti: access_token_id.to_string(),
access_token: res.oauth2_access_token,
created_at: res.oauth2_access_token_created_at,
expires_at: res.oauth2_access_token_expires_at,
};
let Some(res) = res else { return Ok(None) };
let session_id = res.oauth2_session_id.into();
let scope = res.scope.parse().map_err(|e| {
DatabaseInconsistencyError::on("oauth2_sessions")
.column("scope")
.row(session_id)
.source(e)
})?;
let session = Session {
id: session_id,
state: SessionState::Valid,
created_at: res.oauth2_session_created_at,
client_id: res.oauth2_client_id.into(),
user_session_id: res.user_session_id.into(),
scope,
};
Ok(Some((access_token, session)))
Ok(Some(res.into()))
}
#[tracing::instrument(
@@ -148,11 +134,48 @@ pub async fn lookup_active_access_token(
fields(access_token.id = %access_token_id),
err,
)]
pub async fn lookup_access_token(
conn: &mut PgConnection,
access_token_id: Ulid,
) -> Result<Option<AccessToken>, DatabaseError> {
let res = sqlx::query_as!(
OAuth2AccessTokenLookup,
r#"
SELECT oauth2_access_token_id
, access_token
, created_at
, expires_at
, revoked_at
, oauth2_session_id
FROM oauth2_access_tokens
WHERE oauth2_access_token_id = $1
"#,
Uuid::from(access_token_id),
)
.fetch_one(&mut *conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
skip_all,
fields(
%access_token.id,
session.id = %access_token.session_id,
),
err,
)]
pub async fn revoke_access_token(
executor: impl PgExecutor<'_>,
clock: &Clock,
access_token_id: Ulid,
) -> Result<(), DatabaseError> {
access_token: AccessToken,
) -> Result<AccessToken, DatabaseError> {
let revoked_at = clock.now();
let res = sqlx::query!(
r#"
@@ -160,13 +183,17 @@ pub async fn revoke_access_token(
SET revoked_at = $2
WHERE oauth2_access_token_id = $1
"#,
Uuid::from(access_token_id),
Uuid::from(access_token.id),
revoked_at,
)
.execute(executor)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)
DatabaseError::ensure_affected_rows(&res, 1)?;
access_token
.revoke(revoked_at)
.map_err(DatabaseError::to_invalid_operation)
}
pub async fn cleanup_expired(

View File

@@ -13,13 +13,13 @@
// limitations under the License.
use chrono::{DateTime, Utc};
use mas_data_model::{AccessToken, RefreshToken, Session, SessionState};
use mas_data_model::{AccessToken, RefreshToken, RefreshTokenState, Session};
use rand::Rng;
use sqlx::{PgConnection, PgExecutor};
use ulid::Ulid;
use uuid::Uuid;
use crate::{Clock, DatabaseError, DatabaseInconsistencyError};
use crate::{Clock, DatabaseError};
#[tracing::instrument(
skip_all,
@@ -62,6 +62,8 @@ pub async fn add_refresh_token(
Ok(RefreshToken {
id,
state: RefreshTokenState::default(),
session_id: session.id,
refresh_token,
access_token_id: Some(access_token.id),
created_at,
@@ -70,73 +72,52 @@ pub async fn add_refresh_token(
struct OAuth2RefreshTokenLookup {
oauth2_refresh_token_id: Uuid,
oauth2_refresh_token: String,
oauth2_refresh_token_created_at: DateTime<Utc>,
refresh_token: String,
created_at: DateTime<Utc>,
consumed_at: Option<DateTime<Utc>>,
oauth2_access_token_id: Option<Uuid>,
oauth2_session_created_at: DateTime<Utc>,
oauth2_session_id: Uuid,
oauth2_client_id: Uuid,
oauth2_session_scope: String,
user_session_id: Uuid,
}
#[tracing::instrument(skip_all, err)]
#[allow(clippy::too_many_lines)]
pub async fn lookup_active_refresh_token(
pub async fn lookup_refresh_token(
conn: &mut PgConnection,
token: &str,
) -> Result<Option<(RefreshToken, Session)>, DatabaseError> {
) -> Result<Option<RefreshToken>, DatabaseError> {
let res = sqlx::query_as!(
OAuth2RefreshTokenLookup,
r#"
SELECT rt.oauth2_refresh_token_id
, rt.refresh_token AS oauth2_refresh_token
, rt.created_at AS oauth2_refresh_token_created_at
, rt.oauth2_access_token_id AS "oauth2_access_token_id?"
, os.created_at AS "oauth2_session_created_at"
, os.oauth2_session_id AS "oauth2_session_id!"
, os.oauth2_client_id AS "oauth2_client_id!"
, os.scope AS "oauth2_session_scope!"
, os.user_session_id AS "user_session_id!"
FROM oauth2_refresh_tokens rt
INNER JOIN oauth2_sessions os
USING (oauth2_session_id)
SELECT oauth2_refresh_token_id
, refresh_token
, created_at
, consumed_at
, oauth2_access_token_id
, oauth2_session_id
FROM oauth2_refresh_tokens
WHERE rt.refresh_token = $1
AND rt.consumed_at IS NULL
AND rt.revoked_at IS NULL
AND os.finished_at IS NULL
WHERE refresh_token = $1
"#,
token,
)
.fetch_one(&mut *conn)
.await?;
let state = match res.consumed_at {
None => RefreshTokenState::Valid,
Some(consumed_at) => RefreshTokenState::Consumed { consumed_at },
};
let refresh_token = RefreshToken {
id: res.oauth2_refresh_token_id.into(),
refresh_token: res.oauth2_refresh_token,
created_at: res.oauth2_refresh_token_created_at,
state,
session_id: res.oauth2_session_id.into(),
refresh_token: res.refresh_token,
created_at: res.created_at,
access_token_id: res.oauth2_access_token_id.map(Ulid::from),
};
let session_id = res.oauth2_session_id.into();
let scope = res.oauth2_session_scope.parse().map_err(|e| {
DatabaseInconsistencyError::on("oauth2_sessions")
.column("scope")
.row(session_id)
.source(e)
})?;
let session = Session {
id: session_id,
state: SessionState::Valid,
created_at: res.oauth2_session_created_at,
client_id: res.oauth2_client_id.into(),
user_session_id: res.user_session_id.into(),
scope,
};
Ok(Some((refresh_token, session)))
Ok(Some(refresh_token))
}
#[tracing::instrument(
@@ -149,8 +130,8 @@ pub async fn lookup_active_refresh_token(
pub async fn consume_refresh_token(
executor: impl PgExecutor<'_>,
clock: &Clock,
refresh_token: &RefreshToken,
) -> Result<(), DatabaseError> {
refresh_token: RefreshToken,
) -> Result<RefreshToken, DatabaseError> {
let consumed_at = clock.now();
let res = sqlx::query!(
r#"
@@ -164,5 +145,9 @@ pub async fn consume_refresh_token(
.execute(executor)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)
DatabaseError::ensure_affected_rows(&res, 1)?;
refresh_token
.consume(consumed_at)
.map_err(DatabaseError::to_invalid_operation)
}