1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

storage: cleanup access/refresh token lookups

This commit is contained in:
Quentin Gliech
2023-01-11 12:14:52 +01:00
parent 920869b583
commit 9f0c9f1466
9 changed files with 452 additions and 263 deletions

View File

@ -13,13 +13,13 @@
// limitations under the License.
use chrono::{DateTime, Duration, Utc};
use mas_data_model::{AccessToken, Session, SessionState};
use mas_data_model::{AccessToken, AccessTokenState, Session};
use rand::Rng;
use sqlx::{PgConnection, PgExecutor};
use ulid::Ulid;
use uuid::Uuid;
use crate::{Clock, DatabaseError, DatabaseInconsistencyError};
use crate::{Clock, DatabaseError, LookupResultExt};
#[tracing::instrument(
skip_all,
@ -63,8 +63,9 @@ pub async fn add_access_token(
Ok(AccessToken {
id,
state: AccessTokenState::default(),
access_token,
jti: id.to_string(),
session_id: session.id,
created_at,
expires_at,
})
@ -73,74 +74,59 @@ pub async fn add_access_token(
#[derive(Debug)]
pub struct OAuth2AccessTokenLookup {
oauth2_access_token_id: Uuid,
oauth2_access_token: String,
oauth2_access_token_created_at: DateTime<Utc>,
oauth2_access_token_expires_at: DateTime<Utc>,
oauth2_session_created_at: DateTime<Utc>,
oauth2_session_id: Uuid,
oauth2_client_id: Uuid,
scope: String,
user_session_id: Uuid,
access_token: String,
created_at: DateTime<Utc>,
expires_at: DateTime<Utc>,
revoked_at: Option<DateTime<Utc>>,
}
#[allow(clippy::too_many_lines)]
pub async fn lookup_active_access_token(
impl From<OAuth2AccessTokenLookup> for AccessToken {
fn from(value: OAuth2AccessTokenLookup) -> Self {
let state = match value.revoked_at {
None => AccessTokenState::Valid,
Some(revoked_at) => AccessTokenState::Revoked { revoked_at },
};
Self {
id: value.oauth2_access_token_id.into(),
state,
session_id: value.oauth2_session_id.into(),
access_token: value.access_token,
created_at: value.created_at,
expires_at: value.expires_at,
}
}
}
#[tracing::instrument(skip_all, err)]
pub async fn find_access_token(
conn: &mut PgConnection,
token: &str,
) -> Result<Option<(AccessToken, Session)>, DatabaseError> {
) -> Result<Option<AccessToken>, DatabaseError> {
let res = sqlx::query_as!(
OAuth2AccessTokenLookup,
r#"
SELECT at.oauth2_access_token_id
, at.access_token AS "oauth2_access_token"
, at.created_at AS "oauth2_access_token_created_at"
, at.expires_at AS "oauth2_access_token_expires_at"
, os.created_at AS "oauth2_session_created_at"
, os.oauth2_session_id AS "oauth2_session_id!"
, os.oauth2_client_id AS "oauth2_client_id!"
, os.scope AS "scope!"
, os.user_session_id AS "user_session_id!"
SELECT oauth2_access_token_id
, access_token
, created_at
, expires_at
, revoked_at
, oauth2_session_id
FROM oauth2_access_tokens at
INNER JOIN oauth2_sessions os
USING (oauth2_session_id)
FROM oauth2_access_tokens
WHERE at.access_token = $1
AND at.revoked_at IS NULL
AND os.finished_at IS NULL
WHERE access_token = $1
"#,
token,
)
.fetch_one(&mut *conn)
.await?;
.await
.to_option()?;
let access_token_id = Ulid::from(res.oauth2_access_token_id);
let access_token = AccessToken {
id: access_token_id,
jti: access_token_id.to_string(),
access_token: res.oauth2_access_token,
created_at: res.oauth2_access_token_created_at,
expires_at: res.oauth2_access_token_expires_at,
};
let Some(res) = res else { return Ok(None) };
let session_id = res.oauth2_session_id.into();
let scope = res.scope.parse().map_err(|e| {
DatabaseInconsistencyError::on("oauth2_sessions")
.column("scope")
.row(session_id)
.source(e)
})?;
let session = Session {
id: session_id,
state: SessionState::Valid,
created_at: res.oauth2_session_created_at,
client_id: res.oauth2_client_id.into(),
user_session_id: res.user_session_id.into(),
scope,
};
Ok(Some((access_token, session)))
Ok(Some(res.into()))
}
#[tracing::instrument(
@ -148,11 +134,48 @@ pub async fn lookup_active_access_token(
fields(access_token.id = %access_token_id),
err,
)]
pub async fn lookup_access_token(
conn: &mut PgConnection,
access_token_id: Ulid,
) -> Result<Option<AccessToken>, DatabaseError> {
let res = sqlx::query_as!(
OAuth2AccessTokenLookup,
r#"
SELECT oauth2_access_token_id
, access_token
, created_at
, expires_at
, revoked_at
, oauth2_session_id
FROM oauth2_access_tokens
WHERE oauth2_access_token_id = $1
"#,
Uuid::from(access_token_id),
)
.fetch_one(&mut *conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
skip_all,
fields(
%access_token.id,
session.id = %access_token.session_id,
),
err,
)]
pub async fn revoke_access_token(
executor: impl PgExecutor<'_>,
clock: &Clock,
access_token_id: Ulid,
) -> Result<(), DatabaseError> {
access_token: AccessToken,
) -> Result<AccessToken, DatabaseError> {
let revoked_at = clock.now();
let res = sqlx::query!(
r#"
@ -160,13 +183,17 @@ pub async fn revoke_access_token(
SET revoked_at = $2
WHERE oauth2_access_token_id = $1
"#,
Uuid::from(access_token_id),
Uuid::from(access_token.id),
revoked_at,
)
.execute(executor)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)
DatabaseError::ensure_affected_rows(&res, 1)?;
access_token
.revoke(revoked_at)
.map_err(DatabaseError::to_invalid_operation)
}
pub async fn cleanup_expired(

View File

@ -13,13 +13,13 @@
// limitations under the License.
use chrono::{DateTime, Utc};
use mas_data_model::{AccessToken, RefreshToken, Session, SessionState};
use mas_data_model::{AccessToken, RefreshToken, RefreshTokenState, Session};
use rand::Rng;
use sqlx::{PgConnection, PgExecutor};
use ulid::Ulid;
use uuid::Uuid;
use crate::{Clock, DatabaseError, DatabaseInconsistencyError};
use crate::{Clock, DatabaseError};
#[tracing::instrument(
skip_all,
@ -62,6 +62,8 @@ pub async fn add_refresh_token(
Ok(RefreshToken {
id,
state: RefreshTokenState::default(),
session_id: session.id,
refresh_token,
access_token_id: Some(access_token.id),
created_at,
@ -70,73 +72,52 @@ pub async fn add_refresh_token(
struct OAuth2RefreshTokenLookup {
oauth2_refresh_token_id: Uuid,
oauth2_refresh_token: String,
oauth2_refresh_token_created_at: DateTime<Utc>,
refresh_token: String,
created_at: DateTime<Utc>,
consumed_at: Option<DateTime<Utc>>,
oauth2_access_token_id: Option<Uuid>,
oauth2_session_created_at: DateTime<Utc>,
oauth2_session_id: Uuid,
oauth2_client_id: Uuid,
oauth2_session_scope: String,
user_session_id: Uuid,
}
#[tracing::instrument(skip_all, err)]
#[allow(clippy::too_many_lines)]
pub async fn lookup_active_refresh_token(
pub async fn lookup_refresh_token(
conn: &mut PgConnection,
token: &str,
) -> Result<Option<(RefreshToken, Session)>, DatabaseError> {
) -> Result<Option<RefreshToken>, DatabaseError> {
let res = sqlx::query_as!(
OAuth2RefreshTokenLookup,
r#"
SELECT rt.oauth2_refresh_token_id
, rt.refresh_token AS oauth2_refresh_token
, rt.created_at AS oauth2_refresh_token_created_at
, rt.oauth2_access_token_id AS "oauth2_access_token_id?"
, os.created_at AS "oauth2_session_created_at"
, os.oauth2_session_id AS "oauth2_session_id!"
, os.oauth2_client_id AS "oauth2_client_id!"
, os.scope AS "oauth2_session_scope!"
, os.user_session_id AS "user_session_id!"
FROM oauth2_refresh_tokens rt
INNER JOIN oauth2_sessions os
USING (oauth2_session_id)
SELECT oauth2_refresh_token_id
, refresh_token
, created_at
, consumed_at
, oauth2_access_token_id
, oauth2_session_id
FROM oauth2_refresh_tokens
WHERE rt.refresh_token = $1
AND rt.consumed_at IS NULL
AND rt.revoked_at IS NULL
AND os.finished_at IS NULL
WHERE refresh_token = $1
"#,
token,
)
.fetch_one(&mut *conn)
.await?;
let state = match res.consumed_at {
None => RefreshTokenState::Valid,
Some(consumed_at) => RefreshTokenState::Consumed { consumed_at },
};
let refresh_token = RefreshToken {
id: res.oauth2_refresh_token_id.into(),
refresh_token: res.oauth2_refresh_token,
created_at: res.oauth2_refresh_token_created_at,
state,
session_id: res.oauth2_session_id.into(),
refresh_token: res.refresh_token,
created_at: res.created_at,
access_token_id: res.oauth2_access_token_id.map(Ulid::from),
};
let session_id = res.oauth2_session_id.into();
let scope = res.oauth2_session_scope.parse().map_err(|e| {
DatabaseInconsistencyError::on("oauth2_sessions")
.column("scope")
.row(session_id)
.source(e)
})?;
let session = Session {
id: session_id,
state: SessionState::Valid,
created_at: res.oauth2_session_created_at,
client_id: res.oauth2_client_id.into(),
user_session_id: res.user_session_id.into(),
scope,
};
Ok(Some((refresh_token, session)))
Ok(Some(refresh_token))
}
#[tracing::instrument(
@ -149,8 +130,8 @@ pub async fn lookup_active_refresh_token(
pub async fn consume_refresh_token(
executor: impl PgExecutor<'_>,
clock: &Clock,
refresh_token: &RefreshToken,
) -> Result<(), DatabaseError> {
refresh_token: RefreshToken,
) -> Result<RefreshToken, DatabaseError> {
let consumed_at = clock.now();
let res = sqlx::query!(
r#"
@ -164,5 +145,9 @@ pub async fn consume_refresh_token(
.execute(executor)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)
DatabaseError::ensure_affected_rows(&res, 1)?;
refresh_token
.consume(consumed_at)
.map_err(DatabaseError::to_invalid_operation)
}