You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-31 09:24:31 +03:00
storage: cleanup access/refresh token lookups
This commit is contained in:
@ -13,13 +13,13 @@
|
||||
// limitations under the License.
|
||||
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use mas_data_model::{AccessToken, Session, SessionState};
|
||||
use mas_data_model::{AccessToken, AccessTokenState, Session};
|
||||
use rand::Rng;
|
||||
use sqlx::{PgConnection, PgExecutor};
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{Clock, DatabaseError, DatabaseInconsistencyError};
|
||||
use crate::{Clock, DatabaseError, LookupResultExt};
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
@ -63,8 +63,9 @@ pub async fn add_access_token(
|
||||
|
||||
Ok(AccessToken {
|
||||
id,
|
||||
state: AccessTokenState::default(),
|
||||
access_token,
|
||||
jti: id.to_string(),
|
||||
session_id: session.id,
|
||||
created_at,
|
||||
expires_at,
|
||||
})
|
||||
@ -73,74 +74,59 @@ pub async fn add_access_token(
|
||||
#[derive(Debug)]
|
||||
pub struct OAuth2AccessTokenLookup {
|
||||
oauth2_access_token_id: Uuid,
|
||||
oauth2_access_token: String,
|
||||
oauth2_access_token_created_at: DateTime<Utc>,
|
||||
oauth2_access_token_expires_at: DateTime<Utc>,
|
||||
oauth2_session_created_at: DateTime<Utc>,
|
||||
oauth2_session_id: Uuid,
|
||||
oauth2_client_id: Uuid,
|
||||
scope: String,
|
||||
user_session_id: Uuid,
|
||||
access_token: String,
|
||||
created_at: DateTime<Utc>,
|
||||
expires_at: DateTime<Utc>,
|
||||
revoked_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub async fn lookup_active_access_token(
|
||||
impl From<OAuth2AccessTokenLookup> for AccessToken {
|
||||
fn from(value: OAuth2AccessTokenLookup) -> Self {
|
||||
let state = match value.revoked_at {
|
||||
None => AccessTokenState::Valid,
|
||||
Some(revoked_at) => AccessTokenState::Revoked { revoked_at },
|
||||
};
|
||||
|
||||
Self {
|
||||
id: value.oauth2_access_token_id.into(),
|
||||
state,
|
||||
session_id: value.oauth2_session_id.into(),
|
||||
access_token: value.access_token,
|
||||
created_at: value.created_at,
|
||||
expires_at: value.expires_at,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, err)]
|
||||
pub async fn find_access_token(
|
||||
conn: &mut PgConnection,
|
||||
token: &str,
|
||||
) -> Result<Option<(AccessToken, Session)>, DatabaseError> {
|
||||
) -> Result<Option<AccessToken>, DatabaseError> {
|
||||
let res = sqlx::query_as!(
|
||||
OAuth2AccessTokenLookup,
|
||||
r#"
|
||||
SELECT at.oauth2_access_token_id
|
||||
, at.access_token AS "oauth2_access_token"
|
||||
, at.created_at AS "oauth2_access_token_created_at"
|
||||
, at.expires_at AS "oauth2_access_token_expires_at"
|
||||
, os.created_at AS "oauth2_session_created_at"
|
||||
, os.oauth2_session_id AS "oauth2_session_id!"
|
||||
, os.oauth2_client_id AS "oauth2_client_id!"
|
||||
, os.scope AS "scope!"
|
||||
, os.user_session_id AS "user_session_id!"
|
||||
SELECT oauth2_access_token_id
|
||||
, access_token
|
||||
, created_at
|
||||
, expires_at
|
||||
, revoked_at
|
||||
, oauth2_session_id
|
||||
|
||||
FROM oauth2_access_tokens at
|
||||
INNER JOIN oauth2_sessions os
|
||||
USING (oauth2_session_id)
|
||||
FROM oauth2_access_tokens
|
||||
|
||||
WHERE at.access_token = $1
|
||||
AND at.revoked_at IS NULL
|
||||
AND os.finished_at IS NULL
|
||||
WHERE access_token = $1
|
||||
"#,
|
||||
token,
|
||||
)
|
||||
.fetch_one(&mut *conn)
|
||||
.await?;
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let access_token_id = Ulid::from(res.oauth2_access_token_id);
|
||||
let access_token = AccessToken {
|
||||
id: access_token_id,
|
||||
jti: access_token_id.to_string(),
|
||||
access_token: res.oauth2_access_token,
|
||||
created_at: res.oauth2_access_token_created_at,
|
||||
expires_at: res.oauth2_access_token_expires_at,
|
||||
};
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
let session_id = res.oauth2_session_id.into();
|
||||
let scope = res.scope.parse().map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_sessions")
|
||||
.column("scope")
|
||||
.row(session_id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let session = Session {
|
||||
id: session_id,
|
||||
state: SessionState::Valid,
|
||||
created_at: res.oauth2_session_created_at,
|
||||
client_id: res.oauth2_client_id.into(),
|
||||
user_session_id: res.user_session_id.into(),
|
||||
scope,
|
||||
};
|
||||
|
||||
Ok(Some((access_token, session)))
|
||||
Ok(Some(res.into()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
@ -148,11 +134,48 @@ pub async fn lookup_active_access_token(
|
||||
fields(access_token.id = %access_token_id),
|
||||
err,
|
||||
)]
|
||||
pub async fn lookup_access_token(
|
||||
conn: &mut PgConnection,
|
||||
access_token_id: Ulid,
|
||||
) -> Result<Option<AccessToken>, DatabaseError> {
|
||||
let res = sqlx::query_as!(
|
||||
OAuth2AccessTokenLookup,
|
||||
r#"
|
||||
SELECT oauth2_access_token_id
|
||||
, access_token
|
||||
, created_at
|
||||
, expires_at
|
||||
, revoked_at
|
||||
, oauth2_session_id
|
||||
|
||||
FROM oauth2_access_tokens
|
||||
|
||||
WHERE oauth2_access_token_id = $1
|
||||
"#,
|
||||
Uuid::from(access_token_id),
|
||||
)
|
||||
.fetch_one(&mut *conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.into()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
%access_token.id,
|
||||
session.id = %access_token.session_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn revoke_access_token(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
access_token_id: Ulid,
|
||||
) -> Result<(), DatabaseError> {
|
||||
access_token: AccessToken,
|
||||
) -> Result<AccessToken, DatabaseError> {
|
||||
let revoked_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
@ -160,13 +183,17 @@ pub async fn revoke_access_token(
|
||||
SET revoked_at = $2
|
||||
WHERE oauth2_access_token_id = $1
|
||||
"#,
|
||||
Uuid::from(access_token_id),
|
||||
Uuid::from(access_token.id),
|
||||
revoked_at,
|
||||
)
|
||||
.execute(executor)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
access_token
|
||||
.revoke(revoked_at)
|
||||
.map_err(DatabaseError::to_invalid_operation)
|
||||
}
|
||||
|
||||
pub async fn cleanup_expired(
|
||||
|
@ -13,13 +13,13 @@
|
||||
// limitations under the License.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{AccessToken, RefreshToken, Session, SessionState};
|
||||
use mas_data_model::{AccessToken, RefreshToken, RefreshTokenState, Session};
|
||||
use rand::Rng;
|
||||
use sqlx::{PgConnection, PgExecutor};
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{Clock, DatabaseError, DatabaseInconsistencyError};
|
||||
use crate::{Clock, DatabaseError};
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
@ -62,6 +62,8 @@ pub async fn add_refresh_token(
|
||||
|
||||
Ok(RefreshToken {
|
||||
id,
|
||||
state: RefreshTokenState::default(),
|
||||
session_id: session.id,
|
||||
refresh_token,
|
||||
access_token_id: Some(access_token.id),
|
||||
created_at,
|
||||
@ -70,73 +72,52 @@ pub async fn add_refresh_token(
|
||||
|
||||
struct OAuth2RefreshTokenLookup {
|
||||
oauth2_refresh_token_id: Uuid,
|
||||
oauth2_refresh_token: String,
|
||||
oauth2_refresh_token_created_at: DateTime<Utc>,
|
||||
refresh_token: String,
|
||||
created_at: DateTime<Utc>,
|
||||
consumed_at: Option<DateTime<Utc>>,
|
||||
oauth2_access_token_id: Option<Uuid>,
|
||||
oauth2_session_created_at: DateTime<Utc>,
|
||||
oauth2_session_id: Uuid,
|
||||
oauth2_client_id: Uuid,
|
||||
oauth2_session_scope: String,
|
||||
user_session_id: Uuid,
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, err)]
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub async fn lookup_active_refresh_token(
|
||||
pub async fn lookup_refresh_token(
|
||||
conn: &mut PgConnection,
|
||||
token: &str,
|
||||
) -> Result<Option<(RefreshToken, Session)>, DatabaseError> {
|
||||
) -> Result<Option<RefreshToken>, DatabaseError> {
|
||||
let res = sqlx::query_as!(
|
||||
OAuth2RefreshTokenLookup,
|
||||
r#"
|
||||
SELECT rt.oauth2_refresh_token_id
|
||||
, rt.refresh_token AS oauth2_refresh_token
|
||||
, rt.created_at AS oauth2_refresh_token_created_at
|
||||
, rt.oauth2_access_token_id AS "oauth2_access_token_id?"
|
||||
, os.created_at AS "oauth2_session_created_at"
|
||||
, os.oauth2_session_id AS "oauth2_session_id!"
|
||||
, os.oauth2_client_id AS "oauth2_client_id!"
|
||||
, os.scope AS "oauth2_session_scope!"
|
||||
, os.user_session_id AS "user_session_id!"
|
||||
FROM oauth2_refresh_tokens rt
|
||||
INNER JOIN oauth2_sessions os
|
||||
USING (oauth2_session_id)
|
||||
SELECT oauth2_refresh_token_id
|
||||
, refresh_token
|
||||
, created_at
|
||||
, consumed_at
|
||||
, oauth2_access_token_id
|
||||
, oauth2_session_id
|
||||
FROM oauth2_refresh_tokens
|
||||
|
||||
WHERE rt.refresh_token = $1
|
||||
AND rt.consumed_at IS NULL
|
||||
AND rt.revoked_at IS NULL
|
||||
AND os.finished_at IS NULL
|
||||
WHERE refresh_token = $1
|
||||
"#,
|
||||
token,
|
||||
)
|
||||
.fetch_one(&mut *conn)
|
||||
.await?;
|
||||
|
||||
let state = match res.consumed_at {
|
||||
None => RefreshTokenState::Valid,
|
||||
Some(consumed_at) => RefreshTokenState::Consumed { consumed_at },
|
||||
};
|
||||
|
||||
let refresh_token = RefreshToken {
|
||||
id: res.oauth2_refresh_token_id.into(),
|
||||
refresh_token: res.oauth2_refresh_token,
|
||||
created_at: res.oauth2_refresh_token_created_at,
|
||||
state,
|
||||
session_id: res.oauth2_session_id.into(),
|
||||
refresh_token: res.refresh_token,
|
||||
created_at: res.created_at,
|
||||
access_token_id: res.oauth2_access_token_id.map(Ulid::from),
|
||||
};
|
||||
|
||||
let session_id = res.oauth2_session_id.into();
|
||||
let scope = res.oauth2_session_scope.parse().map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_sessions")
|
||||
.column("scope")
|
||||
.row(session_id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let session = Session {
|
||||
id: session_id,
|
||||
state: SessionState::Valid,
|
||||
created_at: res.oauth2_session_created_at,
|
||||
client_id: res.oauth2_client_id.into(),
|
||||
user_session_id: res.user_session_id.into(),
|
||||
scope,
|
||||
};
|
||||
|
||||
Ok(Some((refresh_token, session)))
|
||||
Ok(Some(refresh_token))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
@ -149,8 +130,8 @@ pub async fn lookup_active_refresh_token(
|
||||
pub async fn consume_refresh_token(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
refresh_token: &RefreshToken,
|
||||
) -> Result<(), DatabaseError> {
|
||||
refresh_token: RefreshToken,
|
||||
) -> Result<RefreshToken, DatabaseError> {
|
||||
let consumed_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
@ -164,5 +145,9 @@ pub async fn consume_refresh_token(
|
||||
.execute(executor)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
refresh_token
|
||||
.consume(consumed_at)
|
||||
.map_err(DatabaseError::to_invalid_operation)
|
||||
}
|
||||
|
Reference in New Issue
Block a user