diff --git a/crates/axum-utils/src/user_authorization.rs b/crates/axum-utils/src/user_authorization.rs index c8aeb4cf..6d4e592e 100644 --- a/crates/axum-utils/src/user_authorization.rs +++ b/crates/axum-utils/src/user_authorization.rs @@ -56,10 +56,7 @@ impl AccessToken { &self, conn: &mut PgConnection, ) -> Result< - ( - mas_data_model::AccessToken, - Session, - ), + (mas_data_model::AccessToken, Session), AuthorizationVerificationError, > { let token = match self { diff --git a/crates/data-model/src/tokens.rs b/crates/data-model/src/tokens.rs index 85c68837..cf56a46b 100644 --- a/crates/data-model/src/tokens.rs +++ b/crates/data-model/src/tokens.rs @@ -17,47 +17,23 @@ use crc::{Crc, CRC_32_ISO_HDLC}; use mas_iana::oauth::OAuthTokenTypeHint; use rand::{distributions::Alphanumeric, Rng}; use thiserror::Error; - -use crate::traits::{StorageBackend, StorageBackendMarker}; +use ulid::Ulid; #[derive(Debug, Clone, PartialEq, Eq)] -pub struct AccessToken { - pub data: T::AccessTokenData, +pub struct AccessToken { + pub id: Ulid, pub jti: String, pub access_token: String, pub created_at: DateTime, pub expires_at: DateTime, } -impl From> for AccessToken<()> { - fn from(t: AccessToken) -> Self { - AccessToken { - data: (), - jti: t.jti, - access_token: t.access_token, - expires_at: t.expires_at, - created_at: t.created_at, - } - } -} - -#[derive(Debug, Clone, PartialEq)] -pub struct RefreshToken { - pub data: T::RefreshTokenData, +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct RefreshToken { + pub id: Ulid, pub refresh_token: String, pub created_at: DateTime, - pub access_token: Option>, -} - -impl From> for RefreshToken<()> { - fn from(t: RefreshToken) -> Self { - RefreshToken { - data: (), - refresh_token: t.refresh_token, - created_at: t.created_at, - access_token: t.access_token.map(Into::into), - } - } + pub access_token: Option, } /// Type of token to generate or validate diff --git a/crates/storage/src/oauth2/access_token.rs b/crates/storage/src/oauth2/access_token.rs index b26bc1ad..3cdb0732 100644 --- a/crates/storage/src/oauth2/access_token.rs +++ b/crates/storage/src/oauth2/access_token.rs @@ -41,7 +41,7 @@ pub async fn add_access_token( session: &Session, access_token: String, expires_after: Duration, -) -> Result, anyhow::Error> { +) -> Result { let created_at = clock.now(); let expires_at = created_at + expires_after; let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); @@ -66,7 +66,7 @@ pub async fn add_access_token( .context("could not insert oauth2 access token")?; Ok(AccessToken { - data: id, + id, access_token, jti: id.to_string(), created_at, @@ -113,7 +113,7 @@ impl LookupError for AccessTokenLookupError { pub async fn lookup_active_access_token( conn: &mut PgConnection, token: &str, -) -> Result<(AccessToken, Session), AccessTokenLookupError> { +) -> Result<(AccessToken, Session), AccessTokenLookupError> { let res = sqlx::query_as!( OAuth2AccessTokenLookup, r#" @@ -162,7 +162,7 @@ pub async fn lookup_active_access_token( let id = Ulid::from(res.oauth2_access_token_id); let access_token = AccessToken { - data: id, + id, jti: id.to_string(), access_token: res.oauth2_access_token, created_at: res.oauth2_access_token_created_at, @@ -228,13 +228,13 @@ pub async fn lookup_active_access_token( #[tracing::instrument( skip_all, - fields(access_token.id = %access_token.data), + fields(%access_token.id), err(Debug), )] pub async fn revoke_access_token( executor: impl PgExecutor<'_>, clock: &Clock, - access_token: AccessToken, + access_token: AccessToken, ) -> anyhow::Result<()> { let revoked_at = clock.now(); let res = sqlx::query!( @@ -243,7 +243,7 @@ pub async fn revoke_access_token( SET revoked_at = $2 WHERE oauth2_access_token_id = $1 "#, - Uuid::from(access_token.data), + Uuid::from(access_token.id), revoked_at, ) .execute(executor) diff --git a/crates/storage/src/oauth2/refresh_token.rs b/crates/storage/src/oauth2/refresh_token.rs index 0bedf5b8..a91285ca 100644 --- a/crates/storage/src/oauth2/refresh_token.rs +++ b/crates/storage/src/oauth2/refresh_token.rs @@ -42,9 +42,9 @@ pub async fn add_refresh_token( mut rng: impl Rng + Send, clock: &Clock, session: &Session, - access_token: AccessToken, + access_token: AccessToken, refresh_token: String, -) -> anyhow::Result> { +) -> anyhow::Result { let created_at = clock.now(); let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); tracing::Span::current().record("refresh_token.id", tracing::field::display(id)); @@ -59,7 +59,7 @@ pub async fn add_refresh_token( "#, Uuid::from(id), Uuid::from(session.data), - Uuid::from(access_token.data), + Uuid::from(access_token.id), refresh_token, created_at, ) @@ -68,7 +68,7 @@ pub async fn add_refresh_token( .context("could not insert oauth2 refresh token")?; Ok(RefreshToken { - data: id, + id, refresh_token, access_token: Some(access_token), created_at, @@ -117,8 +117,7 @@ impl LookupError for RefreshTokenLookupError { pub async fn lookup_active_refresh_token( conn: &mut PgConnection, token: &str, -) -> Result<(RefreshToken, Session), RefreshTokenLookupError> -{ +) -> Result<(RefreshToken, Session), RefreshTokenLookupError> { let res = sqlx::query_as!( OAuth2RefreshTokenLookup, r#" @@ -181,7 +180,7 @@ pub async fn lookup_active_refresh_token( (Some(id), Some(access_token), Some(created_at), Some(expires_at)) => { let id = Ulid::from(id); Some(AccessToken { - data: id, + id, jti: id.to_string(), access_token, created_at, @@ -192,7 +191,7 @@ pub async fn lookup_active_refresh_token( }; let refresh_token = RefreshToken { - data: res.oauth2_refresh_token_id.into(), + id: res.oauth2_refresh_token_id.into(), refresh_token: res.oauth2_refresh_token, created_at: res.oauth2_refresh_token_created_at, access_token, @@ -261,14 +260,14 @@ pub async fn lookup_active_refresh_token( #[tracing::instrument( skip_all, fields( - refresh_token.id = %refresh_token.data, + %refresh_token.id, ), err(Debug), )] pub async fn consume_refresh_token( executor: impl PgExecutor<'_>, clock: &Clock, - refresh_token: &RefreshToken, + refresh_token: &RefreshToken, ) -> Result<(), anyhow::Error> { let consumed_at = clock.now(); let res = sqlx::query!( @@ -277,7 +276,7 @@ pub async fn consume_refresh_token( SET consumed_at = $2 WHERE oauth2_refresh_token_id = $1 "#, - Uuid::from(refresh_token.data), + Uuid::from(refresh_token.id), consumed_at, ) .execute(executor)