1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

storage: remaining oauth2 repositories

- authorization grants
 - access tokens
 - refresh tokens
This commit is contained in:
Quentin Gliech
2023-01-12 18:26:04 +01:00
parent 36396c0b45
commit 488a666a8d
17 changed files with 1700 additions and 1366 deletions

View File

@ -29,7 +29,7 @@ use headers::{authorization::Bearer, Authorization, Header, HeaderMapExt, Header
use http::{header::WWW_AUTHENTICATE, HeaderMap, HeaderValue, Request, StatusCode}; use http::{header::WWW_AUTHENTICATE, HeaderMap, HeaderValue, Request, StatusCode};
use mas_data_model::Session; use mas_data_model::Session;
use mas_storage::{ use mas_storage::{
oauth2::{access_token::find_access_token, OAuth2SessionRepository}, oauth2::{OAuth2AccessTokenRepository, OAuth2SessionRepository},
DatabaseError, Repository, DatabaseError, Repository,
}; };
use serde::{de::DeserializeOwned, Deserialize}; use serde::{de::DeserializeOwned, Deserialize};
@ -62,7 +62,9 @@ impl AccessToken {
AccessToken::None => return Err(AuthorizationVerificationError::MissingToken), AccessToken::None => return Err(AuthorizationVerificationError::MissingToken),
}; };
let token = find_access_token(conn, token.as_str()) let token = conn
.oauth2_access_token()
.find_by_token(token.as_str())
.await? .await?
.ok_or(AuthorizationVerificationError::InvalidToken)?; .ok_or(AuthorizationVerificationError::InvalidToken)?;

View File

@ -78,7 +78,7 @@ impl AuthorizationGrantStage {
Self::Pending Self::Pending
} }
pub fn fulfill( fn fulfill(
self, self,
fulfilled_at: DateTime<Utc>, fulfilled_at: DateTime<Utc>,
session: &Session, session: &Session,
@ -92,7 +92,7 @@ impl AuthorizationGrantStage {
} }
} }
pub fn exchange(self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> { fn exchange(self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
match self { match self {
Self::Fulfilled { Self::Fulfilled {
fulfilled_at, fulfilled_at,
@ -106,7 +106,7 @@ impl AuthorizationGrantStage {
} }
} }
pub fn cancel(self, cancelled_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> { fn cancel(self, cancelled_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
match self { match self {
Self::Pending => Ok(Self::Cancelled { cancelled_at }), Self::Pending => Ok(Self::Cancelled { cancelled_at }),
_ => Err(InvalidTransitionError), _ => Err(InvalidTransitionError),
@ -146,4 +146,24 @@ impl AuthorizationGrant {
let max_age: Option<i64> = self.max_age.map(|x| x.get().into()); let max_age: Option<i64> = self.max_age.map(|x| x.get().into());
self.created_at - Duration::seconds(max_age.unwrap_or(3600 * 24 * 365)) self.created_at - Duration::seconds(max_age.unwrap_or(3600 * 24 * 365))
} }
pub fn exchange(mut self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
self.stage = self.stage.exchange(exchanged_at)?;
Ok(self)
}
pub fn fulfill(
mut self,
fulfilled_at: DateTime<Utc>,
session: &Session,
) -> Result<Self, InvalidTransitionError> {
self.stage = self.stage.fulfill(fulfilled_at, session)?;
Ok(self)
}
// TODO: this is not used?
pub fn cancel(mut self, canceld_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
self.stage = self.stage.cancel(canceld_at)?;
Ok(self)
}
} }

View File

@ -26,11 +26,7 @@ use mas_keystore::Encrypter;
use mas_policy::PolicyFactory; use mas_policy::PolicyFactory;
use mas_router::{PostAuthAction, Route}; use mas_router::{PostAuthAction, Route};
use mas_storage::{ use mas_storage::{
oauth2::{ oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2SessionRepository},
authorization_grant::{fulfill_grant, get_grant_by_id},
consent::fetch_client_consent,
OAuth2ClientRepository, OAuth2SessionRepository,
},
Repository, Repository,
}; };
use mas_templates::Templates; use mas_templates::Templates;
@ -94,7 +90,9 @@ pub(crate) async fn get(
let maybe_session = session_info.load_session(&mut txn).await?; let maybe_session = session_info.load_session(&mut txn).await?;
let grant = get_grant_by_id(&mut txn, grant_id) let grant = txn
.oauth2_authorization_grant()
.lookup(grant_id)
.await? .await?
.ok_or(RouteError::NotFound)?; .ok_or(RouteError::NotFound)?;
@ -192,7 +190,10 @@ pub(crate) async fn complete(
.await? .await?
.ok_or(GrantCompletionError::NoSuchClient)?; .ok_or(GrantCompletionError::NoSuchClient)?;
let current_consent = fetch_client_consent(&mut txn, &browser_session.user, &client).await?; let current_consent = txn
.oauth2_client()
.get_consent_for_user(&client, &browser_session.user)
.await?;
let lacks_consent = grant let lacks_consent = grant
.scope .scope
@ -211,7 +212,10 @@ pub(crate) async fn complete(
.create_from_grant(&mut rng, &clock, &grant, &browser_session) .create_from_grant(&mut rng, &clock, &grant, &browser_session)
.await?; .await?;
let grant = fulfill_grant(&mut txn, grant, session.clone()).await?; let grant = txn
.oauth2_authorization_grant()
.fulfill(&clock, &session, grant)
.await?;
// Yep! Let's complete the auth now // Yep! Let's complete the auth now
let mut params = AuthorizationResponse::default(); let mut params = AuthorizationResponse::default();

View File

@ -26,7 +26,7 @@ use mas_keystore::Encrypter;
use mas_policy::PolicyFactory; use mas_policy::PolicyFactory;
use mas_router::{PostAuthAction, Route}; use mas_router::{PostAuthAction, Route};
use mas_storage::{ use mas_storage::{
oauth2::{authorization_grant::new_authorization_grant, OAuth2ClientRepository}, oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository},
Repository, Repository,
}; };
use mas_templates::Templates; use mas_templates::Templates;
@ -275,23 +275,23 @@ pub(crate) async fn get(
let requires_consent = prompt.contains(&Prompt::Consent); let requires_consent = prompt.contains(&Prompt::Consent);
let grant = new_authorization_grant( let grant = txn
&mut txn, .oauth2_authorization_grant()
&mut rng, .add(
&clock, &mut rng,
client, &clock,
redirect_uri.clone(), &client,
params.auth.scope, redirect_uri.clone(),
code, params.auth.scope,
params.auth.state.clone(), code,
params.auth.nonce, params.auth.state.clone(),
params.auth.max_age, params.auth.nonce,
None, params.auth.max_age,
response_mode, response_mode,
response_type.has_id_token(), response_type.has_id_token(),
requires_consent, requires_consent,
) )
.await?; .await?;
let continue_grant = PostAuthAction::continue_grant(grant.id); let continue_grant = PostAuthAction::continue_grant(grant.id);
let res = match maybe_session { let res = match maybe_session {

View File

@ -29,11 +29,7 @@ use mas_keystore::Encrypter;
use mas_policy::PolicyFactory; use mas_policy::PolicyFactory;
use mas_router::{PostAuthAction, Route}; use mas_router::{PostAuthAction, Route};
use mas_storage::{ use mas_storage::{
oauth2::{ oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository},
authorization_grant::{get_grant_by_id, give_consent_to_grant},
consent::insert_client_consent,
OAuth2ClientRepository,
},
Repository, Repository,
}; };
use mas_templates::{ConsentContext, PolicyViolationContext, TemplateContext, Templates}; use mas_templates::{ConsentContext, PolicyViolationContext, TemplateContext, Templates};
@ -91,7 +87,9 @@ pub(crate) async fn get(
let maybe_session = session_info.load_session(&mut conn).await?; let maybe_session = session_info.load_session(&mut conn).await?;
let grant = get_grant_by_id(&mut conn, grant_id) let grant = conn
.oauth2_authorization_grant()
.lookup(grant_id)
.await? .await?
.ok_or(RouteError::GrantNotFound)?; .ok_or(RouteError::GrantNotFound)?;
@ -146,7 +144,9 @@ pub(crate) async fn post(
let maybe_session = session_info.load_session(&mut txn).await?; let maybe_session = session_info.load_session(&mut txn).await?;
let grant = get_grant_by_id(&mut txn, grant_id) let grant = txn
.oauth2_authorization_grant()
.lookup(grant_id)
.await? .await?
.ok_or(RouteError::GrantNotFound)?; .ok_or(RouteError::GrantNotFound)?;
let next = PostAuthAction::continue_grant(grant_id); let next = PostAuthAction::continue_grant(grant_id);
@ -180,17 +180,17 @@ pub(crate) async fn post(
.filter(|s| !s.starts_with("urn:matrix:org.matrix.msc2967.client:device:")) .filter(|s| !s.starts_with("urn:matrix:org.matrix.msc2967.client:device:"))
.cloned() .cloned()
.collect(); .collect();
insert_client_consent( txn.oauth2_client()
&mut txn, .give_consent_for_user(
&mut rng, &mut rng,
&clock, &clock,
&session.user, &client,
&client, &session.user,
&scope_without_device, &scope_without_device,
) )
.await?; .await?;
let _grant = give_consent_to_grant(&mut txn, grant).await?; txn.oauth2_authorization_grant().give_consent(grant).await?;
txn.commit().await?; txn.commit().await?;

View File

@ -23,10 +23,7 @@ use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint};
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_storage::{ use mas_storage::{
compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository}, compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository},
oauth2::{ oauth2::{OAuth2AccessTokenRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository},
access_token::find_access_token, refresh_token::lookup_refresh_token,
OAuth2SessionRepository,
},
user::{BrowserSessionRepository, UserRepository}, user::{BrowserSessionRepository, UserRepository},
Clock, Repository, Clock, Repository,
}; };
@ -169,7 +166,9 @@ pub(crate) async fn post(
let reply = match token_type { let reply = match token_type {
TokenType::AccessToken => { TokenType::AccessToken => {
let token = find_access_token(&mut conn, token) let token = conn
.oauth2_access_token()
.find_by_token(token)
.await? .await?
.filter(|t| t.is_valid(clock.now())) .filter(|t| t.is_valid(clock.now()))
.ok_or(RouteError::UnknownToken)?; .ok_or(RouteError::UnknownToken)?;
@ -206,7 +205,9 @@ pub(crate) async fn post(
} }
TokenType::RefreshToken => { TokenType::RefreshToken => {
let token = lookup_refresh_token(&mut conn, token) let token = conn
.oauth2_refresh_token()
.find_by_token(token)
.await? .await?
.filter(|t| t.is_valid()) .filter(|t| t.is_valid())
.ok_or(RouteError::UnknownToken)?; .ok_or(RouteError::UnknownToken)?;

View File

@ -1,4 +1,4 @@
// Copyright 2021, 2022 The Matrix.org Foundation C.I.C. // Copyright 2021-2023 The Matrix.org Foundation C.I.C.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -33,10 +33,8 @@ use mas_keystore::{Encrypter, Keystore};
use mas_router::UrlBuilder; use mas_router::UrlBuilder;
use mas_storage::{ use mas_storage::{
oauth2::{ oauth2::{
access_token::{add_access_token, lookup_access_token, revoke_access_token}, OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository,
authorization_grant::{exchange_grant, lookup_grant_by_code}, OAuth2RefreshTokenRepository, OAuth2SessionRepository,
refresh_token::{add_refresh_token, consume_refresh_token, lookup_refresh_token},
OAuth2SessionRepository,
}, },
user::BrowserSessionRepository, user::BrowserSessionRepository,
Repository, Repository,
@ -217,9 +215,9 @@ async fn authorization_code_grant(
) -> Result<AccessTokenResponse, RouteError> { ) -> Result<AccessTokenResponse, RouteError> {
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
// TODO: there is a bunch of unnecessary cloning here let authz_grant = txn
// TODO: handle "not found" cases .oauth2_authorization_grant()
let authz_grant = lookup_grant_by_code(&mut txn, &grant.code) .find_by_code(&grant.code)
.await? .await?
.ok_or(RouteError::GrantNotFound)?; .ok_or(RouteError::GrantNotFound)?;
@ -301,18 +299,15 @@ async fn authorization_code_grant(
let access_token_str = TokenType::AccessToken.generate(&mut rng); let access_token_str = TokenType::AccessToken.generate(&mut rng);
let refresh_token_str = TokenType::RefreshToken.generate(&mut rng); let refresh_token_str = TokenType::RefreshToken.generate(&mut rng);
let access_token = let access_token = txn
add_access_token(&mut txn, &mut rng, &clock, &session, access_token_str, ttl).await?; .oauth2_access_token()
.add(&mut rng, &clock, &session, access_token_str, ttl)
.await?;
let refresh_token = add_refresh_token( let refresh_token = txn
&mut txn, .oauth2_refresh_token()
&mut rng, .add(&mut rng, &clock, &session, &access_token, refresh_token_str)
&clock, .await?;
&session,
&access_token,
refresh_token_str,
)
.await?;
let id_token = if session.scope.contains(&scope::OPENID) { let id_token = if session.scope.contains(&scope::OPENID) {
let mut claims = HashMap::new(); let mut claims = HashMap::new();
@ -360,7 +355,9 @@ async fn authorization_code_grant(
params = params.with_id_token(id_token); params = params.with_id_token(id_token);
} }
exchange_grant(&mut txn, &clock, authz_grant).await?; txn.oauth2_authorization_grant()
.exchange(&clock, authz_grant)
.await?;
txn.commit().await?; txn.commit().await?;
@ -374,7 +371,9 @@ async fn refresh_token_grant(
) -> Result<AccessTokenResponse, RouteError> { ) -> Result<AccessTokenResponse, RouteError> {
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
let refresh_token = lookup_refresh_token(&mut txn, &grant.refresh_token) let refresh_token = txn
.oauth2_refresh_token()
.find_by_token(&grant.refresh_token)
.await? .await?
.ok_or(RouteError::InvalidGrant)?; .ok_or(RouteError::InvalidGrant)?;
@ -397,31 +396,32 @@ async fn refresh_token_grant(
let access_token_str = TokenType::AccessToken.generate(&mut rng); let access_token_str = TokenType::AccessToken.generate(&mut rng);
let refresh_token_str = TokenType::RefreshToken.generate(&mut rng); let refresh_token_str = TokenType::RefreshToken.generate(&mut rng);
let new_access_token = add_access_token( let new_access_token = txn
&mut txn, .oauth2_access_token()
&mut rng, .add(&mut rng, &clock, &session, access_token_str.clone(), ttl)
&clock, .await?;
&session,
access_token_str.clone(),
ttl,
)
.await?;
let new_refresh_token = add_refresh_token( let new_refresh_token = txn
&mut txn, .oauth2_refresh_token()
&mut rng, .add(
&clock, &mut rng,
&session, &clock,
&new_access_token, &session,
refresh_token_str, &new_access_token,
) refresh_token_str,
.await?; )
.await?;
let refresh_token = consume_refresh_token(&mut txn, &clock, refresh_token).await?; let refresh_token = txn
.oauth2_refresh_token()
.consume(&clock, refresh_token)
.await?;
if let Some(access_token_id) = refresh_token.access_token_id { if let Some(access_token_id) = refresh_token.access_token_id {
if let Some(access_token) = lookup_access_token(&mut txn, access_token_id).await? { if let Some(access_token) = txn.oauth2_access_token().lookup(access_token_id).await? {
revoke_access_token(&mut txn, &clock, access_token).await?; txn.oauth2_access_token()
.revoke(&clock, access_token)
.await?;
} }
} }

View File

@ -15,7 +15,7 @@
use anyhow::Context; use anyhow::Context;
use mas_router::{PostAuthAction, Route}; use mas_router::{PostAuthAction, Route};
use mas_storage::{ use mas_storage::{
compat::CompatSsoLoginRepository, oauth2::authorization_grant::get_grant_by_id, compat::CompatSsoLoginRepository, oauth2::OAuth2AuthorizationGrantRepository,
upstream_oauth2::UpstreamOAuthProviderRepository, Repository, UpstreamOAuthLinkRepository, upstream_oauth2::UpstreamOAuthProviderRepository, Repository, UpstreamOAuthLinkRepository,
}; };
use mas_templates::{PostAuthContext, PostAuthContextInner}; use mas_templates::{PostAuthContext, PostAuthContextInner};
@ -46,7 +46,9 @@ impl OptionalPostAuthAction {
let Some(action) = self.post_auth_action.clone() else { return Ok(None) }; let Some(action) = self.post_auth_action.clone() else { return Ok(None) };
let ctx = match action { let ctx = match action {
PostAuthAction::ContinueAuthorizationGrant { id } => { PostAuthAction::ContinueAuthorizationGrant { id } => {
let grant = get_grant_by_id(conn, id) let grant = conn
.oauth2_authorization_grant()
.lookup(id)
.await? .await?
.context("Failed to load authorization grant")?; .context("Failed to load authorization grant")?;
let grant = Box::new(grant); let grant = Box::new(grant);

File diff suppressed because it is too large Load Diff

View File

@ -1,4 +1,4 @@
// Copyright 2021 The Matrix.org Foundation C.I.C. // Copyright 2021-2023 The Matrix.org Foundation C.I.C.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -12,67 +12,61 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc}; use chrono::{DateTime, Duration, Utc};
use mas_data_model::{AccessToken, AccessTokenState, Session}; use mas_data_model::{AccessToken, AccessTokenState, Session};
use rand::Rng; use rand::RngCore;
use sqlx::{PgConnection, PgExecutor}; use sqlx::PgConnection;
use ulid::Ulid; use ulid::Ulid;
use uuid::Uuid; use uuid::Uuid;
use crate::{Clock, DatabaseError, LookupResultExt}; use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt};
#[tracing::instrument( #[async_trait]
skip_all, pub trait OAuth2AccessTokenRepository: Send + Sync {
fields( type Error;
%session.id,
user_session.id = %session.user_session_id,
client.id = %session.client_id,
access_token.id,
),
err,
)]
pub async fn add_access_token(
executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
session: &Session,
access_token: String,
expires_after: Duration,
) -> Result<AccessToken, sqlx::Error> {
let created_at = clock.now();
let expires_at = created_at + expires_after;
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("access_token.id", tracing::field::display(id)); /// Lookup an access token by its ID
async fn lookup(&mut self, id: Ulid) -> Result<Option<AccessToken>, Self::Error>;
sqlx::query!( /// Find an access token by its token
r#" async fn find_by_token(
INSERT INTO oauth2_access_tokens &mut self,
(oauth2_access_token_id, oauth2_session_id, access_token, created_at, expires_at) access_token: &str,
VALUES ) -> Result<Option<AccessToken>, Self::Error>;
($1, $2, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(session.id),
&access_token,
created_at,
expires_at,
)
.execute(executor)
.await?;
Ok(AccessToken { /// Add a new access token to the database
id, async fn add(
state: AccessTokenState::default(), &mut self,
access_token, rng: &mut (dyn RngCore + Send),
session_id: session.id, clock: &Clock,
created_at, session: &Session,
expires_at, access_token: String,
}) expires_after: Duration,
) -> Result<AccessToken, Self::Error>;
/// Revoke an access token
async fn revoke(
&mut self,
clock: &Clock,
access_token: AccessToken,
) -> Result<AccessToken, Self::Error>;
/// Cleanup expired access tokens
async fn cleanup_expired(&mut self, clock: &Clock) -> Result<usize, Self::Error>;
} }
#[derive(Debug)] pub struct PgOAuth2AccessTokenRepository<'c> {
pub struct OAuth2AccessTokenLookup { conn: &'c mut PgConnection,
}
impl<'c> PgOAuth2AccessTokenRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
struct OAuth2AccessTokenLookup {
oauth2_access_token_id: Uuid, oauth2_access_token_id: Uuid,
oauth2_session_id: Uuid, oauth2_session_id: Uuid,
access_token: String, access_token: String,
@ -99,118 +93,164 @@ impl From<OAuth2AccessTokenLookup> for AccessToken {
} }
} }
#[tracing::instrument(skip_all, err)] #[async_trait]
pub async fn find_access_token( impl<'c> OAuth2AccessTokenRepository for PgOAuth2AccessTokenRepository<'c> {
conn: &mut PgConnection, type Error = DatabaseError;
token: &str,
) -> Result<Option<AccessToken>, DatabaseError> {
let res = sqlx::query_as!(
OAuth2AccessTokenLookup,
r#"
SELECT oauth2_access_token_id
, access_token
, created_at
, expires_at
, revoked_at
, oauth2_session_id
FROM oauth2_access_tokens async fn lookup(&mut self, id: Ulid) -> Result<Option<AccessToken>, Self::Error> {
let res = sqlx::query_as!(
OAuth2AccessTokenLookup,
r#"
SELECT oauth2_access_token_id
, access_token
, created_at
, expires_at
, revoked_at
, oauth2_session_id
WHERE access_token = $1 FROM oauth2_access_tokens
"#,
token,
)
.fetch_one(&mut *conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) }; WHERE oauth2_access_token_id = $1
"#,
Uuid::from(id),
)
.fetch_one(&mut *self.conn)
.await
.to_option()?;
Ok(Some(res.into())) let Some(res) = res else { return Ok(None) };
}
Ok(Some(res.into()))
#[tracing::instrument( }
skip_all,
fields(access_token.id = %access_token_id), #[tracing::instrument(
err, name = "db.oauth2_access_token.find_by_token",
)] skip_all,
pub async fn lookup_access_token( fields(
conn: &mut PgConnection, db.statement,
access_token_id: Ulid, ),
) -> Result<Option<AccessToken>, DatabaseError> { err,
let res = sqlx::query_as!( )]
OAuth2AccessTokenLookup, async fn find_by_token(
r#" &mut self,
SELECT oauth2_access_token_id access_token: &str,
, access_token ) -> Result<Option<AccessToken>, Self::Error> {
, created_at let res = sqlx::query_as!(
, expires_at OAuth2AccessTokenLookup,
, revoked_at r#"
, oauth2_session_id SELECT oauth2_access_token_id
, access_token
FROM oauth2_access_tokens , created_at
, expires_at
WHERE oauth2_access_token_id = $1 , revoked_at
"#, , oauth2_session_id
Uuid::from(access_token_id),
) FROM oauth2_access_tokens
.fetch_one(&mut *conn)
.await WHERE access_token = $1
.to_option()?; "#,
access_token,
let Some(res) = res else { return Ok(None) }; )
.fetch_one(&mut *self.conn)
Ok(Some(res.into())) .await
} .to_option()?;
#[tracing::instrument( let Some(res) = res else { return Ok(None) };
skip_all,
fields( Ok(Some(res.into()))
%access_token.id, }
session.id = %access_token.session_id,
), #[tracing::instrument(
err, name = "db.oauth2_access_token.add",
)] skip_all,
pub async fn revoke_access_token( fields(
executor: impl PgExecutor<'_>, db.statement,
clock: &Clock, %session.id,
access_token: AccessToken, user_session.id = %session.user_session_id,
) -> Result<AccessToken, DatabaseError> { client.id = %session.client_id,
let revoked_at = clock.now(); access_token.id,
let res = sqlx::query!( ),
r#" err,
UPDATE oauth2_access_tokens )]
SET revoked_at = $2 async fn add(
WHERE oauth2_access_token_id = $1 &mut self,
"#, rng: &mut (dyn RngCore + Send),
Uuid::from(access_token.id), clock: &Clock,
revoked_at, session: &Session,
) access_token: String,
.execute(executor) expires_after: Duration,
.await?; ) -> Result<AccessToken, Self::Error> {
let created_at = clock.now();
DatabaseError::ensure_affected_rows(&res, 1)?; let expires_at = created_at + expires_after;
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
access_token
.revoke(revoked_at) tracing::Span::current().record("access_token.id", tracing::field::display(id));
.map_err(DatabaseError::to_invalid_operation)
} sqlx::query!(
r#"
pub async fn cleanup_expired( INSERT INTO oauth2_access_tokens
executor: impl PgExecutor<'_>, (oauth2_access_token_id, oauth2_session_id, access_token, created_at, expires_at)
clock: &Clock, VALUES
) -> Result<u64, sqlx::Error> { ($1, $2, $3, $4, $5)
// Cleanup token which expired more than 15 minutes ago "#,
let threshold = clock.now() - Duration::minutes(15); Uuid::from(id),
let res = sqlx::query!( Uuid::from(session.id),
r#" &access_token,
DELETE FROM oauth2_access_tokens created_at,
WHERE expires_at < $1 expires_at,
"#, )
threshold, .traced()
) .execute(&mut *self.conn)
.execute(executor) .await?;
.await?;
Ok(AccessToken {
Ok(res.rows_affected()) id,
state: AccessTokenState::default(),
access_token,
session_id: session.id,
created_at,
expires_at,
})
}
async fn revoke(
&mut self,
clock: &Clock,
access_token: AccessToken,
) -> Result<AccessToken, Self::Error> {
let revoked_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE oauth2_access_tokens
SET revoked_at = $2
WHERE oauth2_access_token_id = $1
"#,
Uuid::from(access_token.id),
revoked_at,
)
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
access_token
.revoke(revoked_at)
.map_err(DatabaseError::to_invalid_operation)
}
async fn cleanup_expired(&mut self, clock: &Clock) -> Result<usize, Self::Error> {
// Cleanup token which expired more than 15 minutes ago
let threshold = clock.now() - Duration::minutes(15);
let res = sqlx::query!(
r#"
DELETE FROM oauth2_access_tokens
WHERE expires_at < $1
"#,
threshold,
)
.execute(&mut *self.conn)
.await?;
Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
}
} }

View File

@ -14,138 +14,97 @@
use std::num::NonZeroU32; use std::num::NonZeroU32;
use async_trait::async_trait;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use mas_data_model::{ use mas_data_model::{
AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Pkce, Session, AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Pkce, Session,
}; };
use mas_iana::oauth::PkceCodeChallengeMethod; use mas_iana::oauth::PkceCodeChallengeMethod;
use oauth2_types::{requests::ResponseMode, scope::Scope}; use oauth2_types::{requests::ResponseMode, scope::Scope};
use rand::Rng; use rand::RngCore;
use sqlx::{PgConnection, PgExecutor}; use sqlx::PgConnection;
use ulid::Ulid; use ulid::Ulid;
use url::Url; use url::Url;
use uuid::Uuid; use uuid::Uuid;
use crate::{Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; use crate::{
tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
};
#[tracing::instrument( #[async_trait]
skip_all, pub trait OAuth2AuthorizationGrantRepository {
fields( type Error;
%client.id,
grant.id,
),
err,
)]
#[allow(clippy::too_many_arguments)]
pub async fn new_authorization_grant(
executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
client: Client,
redirect_uri: Url,
scope: Scope,
code: Option<AuthorizationCode>,
state: Option<String>,
nonce: Option<String>,
max_age: Option<NonZeroU32>,
_acr_values: Option<String>,
response_mode: ResponseMode,
response_type_id_token: bool,
requires_consent: bool,
) -> Result<AuthorizationGrant, sqlx::Error> {
let code_challenge = code
.as_ref()
.and_then(|c| c.pkce.as_ref())
.map(|p| &p.challenge);
let code_challenge_method = code
.as_ref()
.and_then(|c| c.pkce.as_ref())
.map(|p| p.challenge_method.to_string());
// TODO: this conversion is a bit ugly
let max_age_i32 = max_age.map(|x| i32::try_from(u32::from(x)).unwrap_or(i32::MAX));
let code_str = code.as_ref().map(|c| &c.code);
let created_at = clock.now(); #[allow(clippy::too_many_arguments)]
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); async fn add(
tracing::Span::current().record("grant.id", tracing::field::display(id)); &mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
client: &Client,
redirect_uri: Url,
scope: Scope,
code: Option<AuthorizationCode>,
state: Option<String>,
nonce: Option<String>,
max_age: Option<NonZeroU32>,
response_mode: ResponseMode,
response_type_id_token: bool,
requires_consent: bool,
) -> Result<AuthorizationGrant, Self::Error>;
sqlx::query!( async fn lookup(&mut self, id: Ulid) -> Result<Option<AuthorizationGrant>, Self::Error>;
r#"
INSERT INTO oauth2_authorization_grants (
oauth2_authorization_grant_id,
oauth2_client_id,
redirect_uri,
scope,
state,
nonce,
max_age,
response_mode,
code_challenge,
code_challenge_method,
response_type_code,
response_type_id_token,
authorization_code,
requires_consent,
created_at
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
"#,
Uuid::from(id),
Uuid::from(client.id),
redirect_uri.to_string(),
scope.to_string(),
state,
nonce,
max_age_i32,
response_mode.to_string(),
code_challenge,
code_challenge_method,
code.is_some(),
response_type_id_token,
code_str,
requires_consent,
created_at,
)
.execute(executor)
.await?;
Ok(AuthorizationGrant { async fn find_by_code(&mut self, code: &str)
id, -> Result<Option<AuthorizationGrant>, Self::Error>;
stage: AuthorizationGrantStage::Pending,
code, async fn fulfill(
redirect_uri, &mut self,
client_id: client.id, clock: &Clock,
scope, session: &Session,
state, authorization_grant: AuthorizationGrant,
nonce, ) -> Result<AuthorizationGrant, Self::Error>;
max_age,
response_mode, async fn exchange(
created_at, &mut self,
response_type_id_token, clock: &Clock,
requires_consent, authorization_grant: AuthorizationGrant,
}) ) -> Result<AuthorizationGrant, Self::Error>;
async fn give_consent(
&mut self,
authorization_grant: AuthorizationGrant,
) -> Result<AuthorizationGrant, Self::Error>;
}
pub struct PgOAuth2AuthorizationGrantRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgOAuth2AuthorizationGrantRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
} }
#[allow(clippy::struct_excessive_bools)] #[allow(clippy::struct_excessive_bools)]
struct GrantLookup { struct GrantLookup {
oauth2_authorization_grant_id: Uuid, oauth2_authorization_grant_id: Uuid,
oauth2_authorization_grant_created_at: DateTime<Utc>, created_at: DateTime<Utc>,
oauth2_authorization_grant_cancelled_at: Option<DateTime<Utc>>, cancelled_at: Option<DateTime<Utc>>,
oauth2_authorization_grant_fulfilled_at: Option<DateTime<Utc>>, fulfilled_at: Option<DateTime<Utc>>,
oauth2_authorization_grant_exchanged_at: Option<DateTime<Utc>>, exchanged_at: Option<DateTime<Utc>>,
oauth2_authorization_grant_scope: String, scope: String,
oauth2_authorization_grant_state: Option<String>, state: Option<String>,
oauth2_authorization_grant_nonce: Option<String>, nonce: Option<String>,
oauth2_authorization_grant_redirect_uri: String, redirect_uri: String,
oauth2_authorization_grant_response_mode: String, response_mode: String,
oauth2_authorization_grant_max_age: Option<i32>, max_age: Option<i32>,
oauth2_authorization_grant_response_type_code: bool, response_type_code: bool,
oauth2_authorization_grant_response_type_id_token: bool, response_type_id_token: bool,
oauth2_authorization_grant_code: Option<String>, authorization_code: Option<String>,
oauth2_authorization_grant_code_challenge: Option<String>, code_challenge: Option<String>,
oauth2_authorization_grant_code_challenge_method: Option<String>, code_challenge_method: Option<String>,
oauth2_authorization_grant_requires_consent: bool, requires_consent: bool,
oauth2_client_id: Uuid, oauth2_client_id: Uuid,
oauth2_session_id: Option<Uuid>, oauth2_session_id: Option<Uuid>,
} }
@ -156,20 +115,17 @@ impl TryFrom<GrantLookup> for AuthorizationGrant {
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
fn try_from(value: GrantLookup) -> Result<Self, Self::Error> { fn try_from(value: GrantLookup) -> Result<Self, Self::Error> {
let id = value.oauth2_authorization_grant_id.into(); let id = value.oauth2_authorization_grant_id.into();
let scope: Scope = value let scope: Scope = value.scope.parse().map_err(|e| {
.oauth2_authorization_grant_scope DatabaseInconsistencyError::on("oauth2_authorization_grants")
.parse() .column("scope")
.map_err(|e| { .row(id)
DatabaseInconsistencyError::on("oauth2_authorization_grants") .source(e)
.column("scope") })?;
.row(id)
.source(e)
})?;
let stage = match ( let stage = match (
value.oauth2_authorization_grant_fulfilled_at, value.fulfilled_at,
value.oauth2_authorization_grant_exchanged_at, value.exchanged_at,
value.oauth2_authorization_grant_cancelled_at, value.cancelled_at,
value.oauth2_session_id, value.oauth2_session_id,
) { ) {
(None, None, None, None) => AuthorizationGrantStage::Pending, (None, None, None, None) => AuthorizationGrantStage::Pending,
@ -198,10 +154,7 @@ impl TryFrom<GrantLookup> for AuthorizationGrant {
} }
}; };
let pkce = match ( let pkce = match (value.code_challenge, value.code_challenge_method) {
value.oauth2_authorization_grant_code_challenge,
value.oauth2_authorization_grant_code_challenge_method,
) {
(Some(challenge), Some(challenge_method)) if challenge_method == "plain" => { (Some(challenge), Some(challenge_method)) if challenge_method == "plain" => {
Some(Pkce { Some(Pkce {
challenge_method: PkceCodeChallengeMethod::Plain, challenge_method: PkceCodeChallengeMethod::Plain,
@ -222,44 +175,35 @@ impl TryFrom<GrantLookup> for AuthorizationGrant {
} }
}; };
let code: Option<AuthorizationCode> = match ( let code: Option<AuthorizationCode> =
value.oauth2_authorization_grant_response_type_code, match (value.response_type_code, value.authorization_code, pkce) {
value.oauth2_authorization_grant_code, (false, None, None) => None,
pkce, (true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }),
) { _ => {
(false, None, None) => None, return Err(
(true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }), DatabaseInconsistencyError::on("oauth2_authorization_grants")
_ => { .column("authorization_code")
return Err( .row(id),
DatabaseInconsistencyError::on("oauth2_authorization_grants") );
.column("authorization_code") }
.row(id), };
);
}
};
let redirect_uri = value let redirect_uri = value.redirect_uri.parse().map_err(|e| {
.oauth2_authorization_grant_redirect_uri DatabaseInconsistencyError::on("oauth2_authorization_grants")
.parse() .column("redirect_uri")
.map_err(|e| { .row(id)
DatabaseInconsistencyError::on("oauth2_authorization_grants") .source(e)
.column("redirect_uri") })?;
.row(id)
.source(e)
})?;
let response_mode = value let response_mode = value.response_mode.parse().map_err(|e| {
.oauth2_authorization_grant_response_mode DatabaseInconsistencyError::on("oauth2_authorization_grants")
.parse() .column("response_mode")
.map_err(|e| { .row(id)
DatabaseInconsistencyError::on("oauth2_authorization_grants") .source(e)
.column("response_mode") })?;
.row(id)
.source(e)
})?;
let max_age = value let max_age = value
.oauth2_authorization_grant_max_age .max_age
.map(u32::try_from) .map(u32::try_from)
.transpose() .transpose()
.map_err(|e| { .map_err(|e| {
@ -283,209 +227,330 @@ impl TryFrom<GrantLookup> for AuthorizationGrant {
client_id: value.oauth2_client_id.into(), client_id: value.oauth2_client_id.into(),
code, code,
scope, scope,
state: value.oauth2_authorization_grant_state, state: value.state,
nonce: value.oauth2_authorization_grant_nonce, nonce: value.nonce,
max_age, max_age,
response_mode, response_mode,
redirect_uri, redirect_uri,
created_at: value.oauth2_authorization_grant_created_at, created_at: value.created_at,
response_type_id_token: value.oauth2_authorization_grant_response_type_id_token, response_type_id_token: value.response_type_id_token,
requires_consent: value.oauth2_authorization_grant_requires_consent, requires_consent: value.requires_consent,
}) })
} }
} }
#[tracing::instrument( #[async_trait]
skip_all, impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantRepository<'c> {
fields(grant.id = %id), type Error = DatabaseError;
err,
)]
pub async fn get_grant_by_id(
conn: &mut PgConnection,
id: Ulid,
) -> Result<Option<AuthorizationGrant>, DatabaseError> {
let res = sqlx::query_as!(
GrantLookup,
r#"
SELECT oauth2_authorization_grant_id
, created_at AS oauth2_authorization_grant_created_at
, cancelled_at AS oauth2_authorization_grant_cancelled_at
, fulfilled_at AS oauth2_authorization_grant_fulfilled_at
, exchanged_at AS oauth2_authorization_grant_exchanged_at
, scope AS oauth2_authorization_grant_scope
, state AS oauth2_authorization_grant_state
, redirect_uri AS oauth2_authorization_grant_redirect_uri
, response_mode AS oauth2_authorization_grant_response_mode
, nonce AS oauth2_authorization_grant_nonce
, max_age AS oauth2_authorization_grant_max_age
, oauth2_client_id AS oauth2_client_id
, authorization_code AS oauth2_authorization_grant_code
, response_type_code AS oauth2_authorization_grant_response_type_code
, response_type_id_token AS oauth2_authorization_grant_response_type_id_token
, code_challenge AS oauth2_authorization_grant_code_challenge
, code_challenge_method AS oauth2_authorization_grant_code_challenge_method
, requires_consent AS oauth2_authorization_grant_requires_consent
, oauth2_session_id AS "oauth2_session_id?"
FROM
oauth2_authorization_grants
WHERE oauth2_authorization_grant_id = $1 #[tracing::instrument(
"#, name = "db.oauth2_authorization_grant.add",
Uuid::from(id), skip_all,
) fields(
.fetch_one(&mut *conn) db.statement,
.await grant.id,
.to_option()?; grant.scope = %scope,
%client.id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
client: &Client,
redirect_uri: Url,
scope: Scope,
code: Option<AuthorizationCode>,
state: Option<String>,
nonce: Option<String>,
max_age: Option<NonZeroU32>,
response_mode: ResponseMode,
response_type_id_token: bool,
requires_consent: bool,
) -> Result<AuthorizationGrant, Self::Error> {
let code_challenge = code
.as_ref()
.and_then(|c| c.pkce.as_ref())
.map(|p| &p.challenge);
let code_challenge_method = code
.as_ref()
.and_then(|c| c.pkce.as_ref())
.map(|p| p.challenge_method.to_string());
// TODO: this conversion is a bit ugly
let max_age_i32 = max_age.map(|x| i32::try_from(u32::from(x)).unwrap_or(i32::MAX));
let code_str = code.as_ref().map(|c| &c.code);
let Some(res) = res else { return Ok(None) }; let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("grant.id", tracing::field::display(id));
Ok(Some(res.try_into()?)) sqlx::query!(
} r#"
INSERT INTO oauth2_authorization_grants (
#[tracing::instrument(skip_all, err)] oauth2_authorization_grant_id,
pub async fn lookup_grant_by_code( oauth2_client_id,
conn: &mut PgConnection, redirect_uri,
code: &str, scope,
) -> Result<Option<AuthorizationGrant>, DatabaseError> { state,
let res = sqlx::query_as!( nonce,
GrantLookup, max_age,
r#" response_mode,
SELECT oauth2_authorization_grant_id code_challenge,
, created_at AS oauth2_authorization_grant_created_at code_challenge_method,
, cancelled_at AS oauth2_authorization_grant_cancelled_at response_type_code,
, fulfilled_at AS oauth2_authorization_grant_fulfilled_at response_type_id_token,
, exchanged_at AS oauth2_authorization_grant_exchanged_at authorization_code,
, scope AS oauth2_authorization_grant_scope requires_consent,
, state AS oauth2_authorization_grant_state created_at
, redirect_uri AS oauth2_authorization_grant_redirect_uri )
, response_mode AS oauth2_authorization_grant_response_mode VALUES
, nonce AS oauth2_authorization_grant_nonce ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
, max_age AS oauth2_authorization_grant_max_age "#,
, oauth2_client_id AS oauth2_client_id Uuid::from(id),
, authorization_code AS oauth2_authorization_grant_code Uuid::from(client.id),
, response_type_code AS oauth2_authorization_grant_response_type_code redirect_uri.to_string(),
, response_type_id_token AS oauth2_authorization_grant_response_type_id_token scope.to_string(),
, code_challenge AS oauth2_authorization_grant_code_challenge state,
, code_challenge_method AS oauth2_authorization_grant_code_challenge_method nonce,
, requires_consent AS oauth2_authorization_grant_requires_consent max_age_i32,
, oauth2_session_id AS "oauth2_session_id?" response_mode.to_string(),
FROM code_challenge,
oauth2_authorization_grants code_challenge_method,
code.is_some(),
WHERE authorization_code = $1 response_type_id_token,
"#, code_str,
code, requires_consent,
) created_at,
.fetch_one(&mut *conn) )
.await .execute(&mut *self.conn)
.to_option()?; .await?;
let Some(res) = res else { return Ok(None) }; Ok(AuthorizationGrant {
id,
Ok(Some(res.try_into()?)) stage: AuthorizationGrantStage::Pending,
} code,
redirect_uri,
#[tracing::instrument( client_id: client.id,
skip_all, scope,
fields( state,
%grant.id, nonce,
client.id = %grant.client_id, max_age,
%session.id, response_mode,
user_session.id = %session.user_session_id, created_at,
), response_type_id_token,
err, requires_consent,
)] })
pub async fn fulfill_grant( }
executor: impl PgExecutor<'_>,
mut grant: AuthorizationGrant, #[tracing::instrument(
session: Session, name = "db.oauth2_authorization_grant.lookup",
) -> Result<AuthorizationGrant, DatabaseError> { skip_all,
let fulfilled_at = sqlx::query_scalar!( fields(
r#" db.statement,
UPDATE oauth2_authorization_grants AS og grant.id = %id,
SET ),
oauth2_session_id = os.oauth2_session_id, err,
fulfilled_at = os.created_at )]
FROM oauth2_sessions os async fn lookup(&mut self, id: Ulid) -> Result<Option<AuthorizationGrant>, Self::Error> {
WHERE let res = sqlx::query_as!(
og.oauth2_authorization_grant_id = $1 GrantLookup,
AND os.oauth2_session_id = $2 r#"
RETURNING fulfilled_at AS "fulfilled_at!: DateTime<Utc>" SELECT oauth2_authorization_grant_id
"#, , created_at
Uuid::from(grant.id), , cancelled_at
Uuid::from(session.id), , fulfilled_at
) , exchanged_at
.fetch_one(executor) , scope
.await?; , state
, redirect_uri
grant.stage = grant , response_mode
.stage , nonce
.fulfill(fulfilled_at, &session) , max_age
.map_err(DatabaseError::to_invalid_operation)?; , oauth2_client_id
, authorization_code
Ok(grant) , response_type_code
} , response_type_id_token
, code_challenge
#[tracing::instrument( , code_challenge_method
skip_all, , requires_consent
fields( , oauth2_session_id
%grant.id, FROM
client.id = %grant.client_id, oauth2_authorization_grants
),
err, WHERE oauth2_authorization_grant_id = $1
)] "#,
pub async fn give_consent_to_grant( Uuid::from(id),
executor: impl PgExecutor<'_>, )
mut grant: AuthorizationGrant, .fetch_one(&mut *self.conn)
) -> Result<AuthorizationGrant, sqlx::Error> { .await
sqlx::query!( .to_option()?;
r#"
UPDATE oauth2_authorization_grants AS og let Some(res) = res else { return Ok(None) };
SET
requires_consent = 'f' Ok(Some(res.try_into()?))
WHERE }
og.oauth2_authorization_grant_id = $1
"#, #[tracing::instrument(
Uuid::from(grant.id), name = "db.oauth2_authorization_grant.find_by_code",
) skip_all,
.execute(executor) fields(
.await?; db.statement,
),
grant.requires_consent = false; err,
)]
Ok(grant) async fn find_by_code(
} &mut self,
code: &str,
#[tracing::instrument( ) -> Result<Option<AuthorizationGrant>, Self::Error> {
skip_all, let res = sqlx::query_as!(
fields( GrantLookup,
%grant.id, r#"
client.id = %grant.client_id, SELECT oauth2_authorization_grant_id
), , created_at
err, , cancelled_at
)] , fulfilled_at
pub async fn exchange_grant( , exchanged_at
executor: impl PgExecutor<'_>, , scope
clock: &Clock, , state
mut grant: AuthorizationGrant, , redirect_uri
) -> Result<AuthorizationGrant, DatabaseError> { , response_mode
let exchanged_at = clock.now(); , nonce
sqlx::query!( , max_age
r#" , oauth2_client_id
UPDATE oauth2_authorization_grants , authorization_code
SET exchanged_at = $2 , response_type_code
WHERE oauth2_authorization_grant_id = $1 , response_type_id_token
"#, , code_challenge
Uuid::from(grant.id), , code_challenge_method
exchanged_at, , requires_consent
) , oauth2_session_id
.execute(executor) FROM
.await?; oauth2_authorization_grants
grant.stage = grant WHERE authorization_code = $1
.stage "#,
.exchange(exchanged_at) code,
.map_err(DatabaseError::to_invalid_operation)?; )
.traced()
Ok(grant) .fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
name = "db.oauth2_authorization_grant.fulfill",
skip_all,
fields(
db.statement,
%grant.id,
client.id = %grant.client_id,
%session.id,
user_session.id = %session.user_session_id,
),
err,
)]
async fn fulfill(
&mut self,
clock: &Clock,
session: &Session,
grant: AuthorizationGrant,
) -> Result<AuthorizationGrant, Self::Error> {
let fulfilled_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE oauth2_authorization_grants
SET fulfilled_at = $2
, oauth2_session_id = $3
WHERE oauth2_authorization_grant_id = $1
"#,
Uuid::from(grant.id),
fulfilled_at,
Uuid::from(session.id),
)
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
// XXX: check affected rows & new methods
let grant = grant
.fulfill(fulfilled_at, session)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(grant)
}
#[tracing::instrument(
name = "db.oauth2_authorization_grant.exchange",
skip_all,
fields(
db.statement,
%grant.id,
client.id = %grant.client_id,
),
err,
)]
async fn exchange(
&mut self,
clock: &Clock,
grant: AuthorizationGrant,
) -> Result<AuthorizationGrant, Self::Error> {
let exchanged_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE oauth2_authorization_grants
SET exchanged_at = $2
WHERE oauth2_authorization_grant_id = $1
"#,
Uuid::from(grant.id),
exchanged_at,
)
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
let grant = grant
.exchange(exchanged_at)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(grant)
}
#[tracing::instrument(
name = "db.oauth2_authorization_grant.give_consent",
skip_all,
fields(
db.statement,
%grant.id,
client.id = %grant.client_id,
),
err,
)]
async fn give_consent(
&mut self,
mut grant: AuthorizationGrant,
) -> Result<AuthorizationGrant, Self::Error> {
sqlx::query!(
r#"
UPDATE oauth2_authorization_grants AS og
SET
requires_consent = 'f'
WHERE
og.oauth2_authorization_grant_id = $1
"#,
Uuid::from(grant.id),
)
.execute(&mut *self.conn)
.await?;
grant.requires_consent = false;
Ok(grant)
}
} }

View File

@ -14,17 +14,21 @@
use std::{ use std::{
collections::{BTreeMap, BTreeSet}, collections::{BTreeMap, BTreeSet},
str::FromStr,
string::ToString, string::ToString,
}; };
use async_trait::async_trait; use async_trait::async_trait;
use mas_data_model::{Client, JwksOrJwksUri}; use mas_data_model::{Client, JwksOrJwksUri, User};
use mas_iana::{ use mas_iana::{
jose::JsonWebSignatureAlg, jose::JsonWebSignatureAlg,
oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod}, oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod},
}; };
use mas_jose::jwk::PublicJsonWebKeySet; use mas_jose::jwk::PublicJsonWebKeySet;
use oauth2_types::requests::GrantType; use oauth2_types::{
requests::GrantType,
scope::{Scope, ScopeToken},
};
use rand::{Rng, RngCore}; use rand::{Rng, RngCore};
use sqlx::PgConnection; use sqlx::PgConnection;
use tracing::{info_span, Instrument}; use tracing::{info_span, Instrument};
@ -87,6 +91,21 @@ pub trait OAuth2ClientRepository: Send + Sync {
jwks_uri: Option<Url>, jwks_uri: Option<Url>,
redirect_uris: Vec<Url>, redirect_uris: Vec<Url>,
) -> Result<Client, Self::Error>; ) -> Result<Client, Self::Error>;
async fn get_consent_for_user(
&mut self,
client: &Client,
user: &User,
) -> Result<Scope, Self::Error>;
async fn give_consent_for_user(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
client: &Client,
user: &User,
scope: &Scope,
) -> Result<(), Self::Error>;
} }
pub struct PgOAuth2ClientRepository<'c> { pub struct PgOAuth2ClientRepository<'c> {
@ -702,4 +721,94 @@ impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> {
initiate_login_uri: None, initiate_login_uri: None,
}) })
} }
#[tracing::instrument(
name = "db.oauth2_client.get_consent_for_user",
skip_all,
fields(
db.statement,
%user.id,
%client.id,
),
err,
)]
async fn get_consent_for_user(
&mut self,
client: &Client,
user: &User,
) -> Result<Scope, Self::Error> {
let scope_tokens: Vec<String> = sqlx::query_scalar!(
r#"
SELECT scope_token
FROM oauth2_consents
WHERE user_id = $1 AND oauth2_client_id = $2
"#,
Uuid::from(user.id),
Uuid::from(client.id),
)
.fetch_all(&mut *self.conn)
.await?;
let scope: Result<Scope, _> = scope_tokens
.into_iter()
.map(|s| ScopeToken::from_str(&s))
.collect();
let scope = scope.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_consents")
.column("scope_token")
.source(e)
})?;
Ok(scope)
}
#[tracing::instrument(
skip_all,
fields(
db.statement,
%user.id,
%client.id,
%scope,
),
err,
)]
async fn give_consent_for_user(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
client: &Client,
user: &User,
scope: &Scope,
) -> Result<(), Self::Error> {
let now = clock.now();
let (tokens, ids): (Vec<String>, Vec<Uuid>) = scope
.iter()
.map(|token| {
(
token.to_string(),
Uuid::from(Ulid::from_datetime_with_source(now.into(), rng)),
)
})
.unzip();
sqlx::query!(
r#"
INSERT INTO oauth2_consents
(oauth2_consent_id, user_id, oauth2_client_id, scope_token, created_at)
SELECT id, $2, $3, scope_token, $5 FROM UNNEST($1::uuid[], $4::text[]) u(id, scope_token)
ON CONFLICT (user_id, oauth2_client_id, scope_token) DO UPDATE SET refreshed_at = $5
"#,
&ids,
Uuid::from(user.id),
Uuid::from(client.id),
&tokens,
now,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(())
}
} }

View File

@ -1,110 +0,0 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::str::FromStr;
use mas_data_model::{Client, User};
use oauth2_types::scope::{Scope, ScopeToken};
use rand::Rng;
use sqlx::PgExecutor;
use ulid::Ulid;
use uuid::Uuid;
use crate::{Clock, DatabaseError, DatabaseInconsistencyError};
#[tracing::instrument(
skip_all,
fields(
%user.id,
%client.id,
),
err,
)]
pub async fn fetch_client_consent(
executor: impl PgExecutor<'_>,
user: &User,
client: &Client,
) -> Result<Scope, DatabaseError> {
let scope_tokens: Vec<String> = sqlx::query_scalar!(
r#"
SELECT scope_token
FROM oauth2_consents
WHERE user_id = $1 AND oauth2_client_id = $2
"#,
Uuid::from(user.id),
Uuid::from(client.id),
)
.fetch_all(executor)
.await?;
let scope: Result<Scope, _> = scope_tokens
.into_iter()
.map(|s| ScopeToken::from_str(&s))
.collect();
let scope = scope.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_consents")
.column("scope_token")
.source(e)
})?;
Ok(scope)
}
#[tracing::instrument(
skip_all,
fields(
%user.id,
%client.id,
%scope,
),
err,
)]
pub async fn insert_client_consent(
executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
user: &User,
client: &Client,
scope: &Scope,
) -> Result<(), sqlx::Error> {
let now = clock.now();
let (tokens, ids): (Vec<String>, Vec<Uuid>) = scope
.iter()
.map(|token| {
(
token.to_string(),
Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)),
)
})
.unzip();
sqlx::query!(
r#"
INSERT INTO oauth2_consents
(oauth2_consent_id, user_id, oauth2_client_id, scope_token, created_at)
SELECT id, $2, $3, scope_token, $5 FROM UNNEST($1::uuid[], $4::text[]) u(id, scope_token)
ON CONFLICT (user_id, oauth2_client_id, scope_token) DO UPDATE SET refreshed_at = $5
"#,
&ids,
Uuid::from(user.id),
Uuid::from(client.id),
&tokens,
now,
)
.execute(executor)
.await?;
Ok(())
}

View File

@ -12,14 +12,18 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
pub mod access_token; mod access_token;
pub mod authorization_grant; pub mod authorization_grant;
mod client; mod client;
pub mod consent; mod refresh_token;
pub mod refresh_token;
mod session; mod session;
pub use self::{ pub use self::{
access_token::{OAuth2AccessTokenRepository, PgOAuth2AccessTokenRepository},
authorization_grant::{
OAuth2AuthorizationGrantRepository, PgOAuth2AuthorizationGrantRepository,
},
client::{OAuth2ClientRepository, PgOAuth2ClientRepository}, client::{OAuth2ClientRepository, PgOAuth2ClientRepository},
refresh_token::{OAuth2RefreshTokenRepository, PgOAuth2RefreshTokenRepository},
session::{OAuth2SessionRepository, PgOAuth2SessionRepository}, session::{OAuth2SessionRepository, PgOAuth2SessionRepository},
}; };

View File

@ -12,62 +12,55 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use mas_data_model::{AccessToken, RefreshToken, RefreshTokenState, Session}; use mas_data_model::{AccessToken, RefreshToken, RefreshTokenState, Session};
use rand::Rng; use rand::RngCore;
use sqlx::{PgConnection, PgExecutor}; use sqlx::PgConnection;
use ulid::Ulid; use ulid::Ulid;
use uuid::Uuid; use uuid::Uuid;
use crate::{Clock, DatabaseError}; use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt};
#[tracing::instrument( #[async_trait]
skip_all, pub trait OAuth2RefreshTokenRepository: Send + Sync {
fields( type Error;
%session.id,
user_session.id = %session.user_session_id,
client.id = %session.client_id,
refresh_token.id,
),
err,
)]
pub async fn add_refresh_token(
executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
session: &Session,
access_token: &AccessToken,
refresh_token: String,
) -> Result<RefreshToken, sqlx::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("refresh_token.id", tracing::field::display(id));
sqlx::query!( /// Lookup a refresh token by its ID
r#" async fn lookup(&mut self, id: Ulid) -> Result<Option<RefreshToken>, Self::Error>;
INSERT INTO oauth2_refresh_tokens
(oauth2_refresh_token_id, oauth2_session_id, oauth2_access_token_id,
refresh_token, created_at)
VALUES
($1, $2, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(session.id),
Uuid::from(access_token.id),
refresh_token,
created_at,
)
.execute(executor)
.await?;
Ok(RefreshToken { /// Find a refresh token by its token
id, async fn find_by_token(
state: RefreshTokenState::default(), &mut self,
session_id: session.id, refresh_token: &str,
refresh_token, ) -> Result<Option<RefreshToken>, Self::Error>;
access_token_id: Some(access_token.id),
created_at, /// Add a new refresh token to the database
}) async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
session: &Session,
access_token: &AccessToken,
refresh_token: String,
) -> Result<RefreshToken, Self::Error>;
/// Consume a refresh token
async fn consume(
&mut self,
clock: &Clock,
refresh_token: RefreshToken,
) -> Result<RefreshToken, Self::Error>;
}
pub struct PgOAuth2RefreshTokenRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgOAuth2RefreshTokenRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
} }
struct OAuth2RefreshTokenLookup { struct OAuth2RefreshTokenLookup {
@ -79,75 +72,183 @@ struct OAuth2RefreshTokenLookup {
oauth2_session_id: Uuid, oauth2_session_id: Uuid,
} }
#[tracing::instrument(skip_all, err)] impl From<OAuth2RefreshTokenLookup> for RefreshToken {
#[allow(clippy::too_many_lines)] fn from(value: OAuth2RefreshTokenLookup) -> Self {
pub async fn lookup_refresh_token( let state = match value.consumed_at {
conn: &mut PgConnection, None => RefreshTokenState::Valid,
token: &str, Some(consumed_at) => RefreshTokenState::Consumed { consumed_at },
) -> Result<Option<RefreshToken>, DatabaseError> { };
let res = sqlx::query_as!(
OAuth2RefreshTokenLookup,
r#"
SELECT oauth2_refresh_token_id
, refresh_token
, created_at
, consumed_at
, oauth2_access_token_id
, oauth2_session_id
FROM oauth2_refresh_tokens
WHERE refresh_token = $1 RefreshToken {
"#, id: value.oauth2_refresh_token_id.into(),
token, state,
) session_id: value.oauth2_session_id.into(),
.fetch_one(&mut *conn) refresh_token: value.refresh_token,
.await?; created_at: value.created_at,
access_token_id: value.oauth2_access_token_id.map(Ulid::from),
let state = match res.consumed_at { }
None => RefreshTokenState::Valid, }
Some(consumed_at) => RefreshTokenState::Consumed { consumed_at },
};
let refresh_token = RefreshToken {
id: res.oauth2_refresh_token_id.into(),
state,
session_id: res.oauth2_session_id.into(),
refresh_token: res.refresh_token,
created_at: res.created_at,
access_token_id: res.oauth2_access_token_id.map(Ulid::from),
};
Ok(Some(refresh_token))
} }
#[tracing::instrument( #[async_trait]
skip_all, impl<'c> OAuth2RefreshTokenRepository for PgOAuth2RefreshTokenRepository<'c> {
fields( type Error = DatabaseError;
%refresh_token.id,
),
err,
)]
pub async fn consume_refresh_token(
executor: impl PgExecutor<'_>,
clock: &Clock,
refresh_token: RefreshToken,
) -> Result<RefreshToken, DatabaseError> {
let consumed_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE oauth2_refresh_tokens
SET consumed_at = $2
WHERE oauth2_refresh_token_id = $1
"#,
Uuid::from(refresh_token.id),
consumed_at,
)
.execute(executor)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?; #[tracing::instrument(
name = "db.oauth2_refresh_token.lookup",
skip_all,
fields(
db.statement,
refresh_token.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<RefreshToken>, Self::Error> {
let res = sqlx::query_as!(
OAuth2RefreshTokenLookup,
r#"
SELECT oauth2_refresh_token_id
, refresh_token
, created_at
, consumed_at
, oauth2_access_token_id
, oauth2_session_id
FROM oauth2_refresh_tokens
refresh_token WHERE oauth2_refresh_token_id = $1
.consume(consumed_at) "#,
.map_err(DatabaseError::to_invalid_operation) Uuid::from(id),
)
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.oauth2_refresh_token.find_by_token",
skip_all,
fields(
db.statement,
),
err,
)]
async fn find_by_token(
&mut self,
refresh_token: &str,
) -> Result<Option<RefreshToken>, Self::Error> {
let res = sqlx::query_as!(
OAuth2RefreshTokenLookup,
r#"
SELECT oauth2_refresh_token_id
, refresh_token
, created_at
, consumed_at
, oauth2_access_token_id
, oauth2_session_id
FROM oauth2_refresh_tokens
WHERE refresh_token = $1
"#,
refresh_token,
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.oauth2_refresh_token.add",
skip_all,
fields(
db.statement,
%session.id,
user_session.id = %session.user_session_id,
client.id = %session.client_id,
refresh_token.id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
session: &Session,
access_token: &AccessToken,
refresh_token: String,
) -> Result<RefreshToken, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("refresh_token.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO oauth2_refresh_tokens
(oauth2_refresh_token_id, oauth2_session_id, oauth2_access_token_id,
refresh_token, created_at)
VALUES
($1, $2, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(session.id),
Uuid::from(access_token.id),
refresh_token,
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(RefreshToken {
id,
state: RefreshTokenState::default(),
session_id: session.id,
refresh_token,
access_token_id: Some(access_token.id),
created_at,
})
}
#[tracing::instrument(
name = "db.oauth2_refresh_token.consume",
skip_all,
fields(
db.statement,
%refresh_token.id,
session.id = %refresh_token.session_id,
),
err,
)]
async fn consume(
&mut self,
clock: &Clock,
refresh_token: RefreshToken,
) -> Result<RefreshToken, Self::Error> {
let consumed_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE oauth2_refresh_tokens
SET consumed_at = $2
WHERE oauth2_refresh_token_id = $1
"#,
Uuid::from(refresh_token.id),
consumed_at,
)
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
refresh_token
.consume(consumed_at)
.map_err(DatabaseError::to_invalid_operation)
}
} }

View File

@ -19,7 +19,10 @@ use crate::{
PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository, PgCompatSessionRepository, PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository, PgCompatSessionRepository,
PgCompatSsoLoginRepository, PgCompatSsoLoginRepository,
}, },
oauth2::{PgOAuth2ClientRepository, PgOAuth2SessionRepository}, oauth2::{
PgOAuth2AccessTokenRepository, PgOAuth2AuthorizationGrantRepository,
PgOAuth2ClientRepository, PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository,
},
upstream_oauth2::{ upstream_oauth2::{
PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository, PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository,
PgUpstreamOAuthSessionRepository, PgUpstreamOAuthSessionRepository,
@ -63,10 +66,22 @@ pub trait Repository {
where where
Self: 'c; Self: 'c;
type OAuth2AuthorizationGrantRepository<'c>
where
Self: 'c;
type OAuth2SessionRepository<'c> type OAuth2SessionRepository<'c>
where where
Self: 'c; Self: 'c;
type OAuth2AccessTokenRepository<'c>
where
Self: 'c;
type OAuth2RefreshTokenRepository<'c>
where
Self: 'c;
type CompatSessionRepository<'c> type CompatSessionRepository<'c>
where where
Self: 'c; Self: 'c;
@ -91,7 +106,10 @@ pub trait Repository {
fn user_password(&mut self) -> Self::UserPasswordRepository<'_>; fn user_password(&mut self) -> Self::UserPasswordRepository<'_>;
fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_>; fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_>;
fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_>; fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_>;
fn oauth2_authorization_grant(&mut self) -> Self::OAuth2AuthorizationGrantRepository<'_>;
fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_>; fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_>;
fn oauth2_access_token(&mut self) -> Self::OAuth2AccessTokenRepository<'_>;
fn oauth2_refresh_token(&mut self) -> Self::OAuth2RefreshTokenRepository<'_>;
fn compat_session(&mut self) -> Self::CompatSessionRepository<'_>; fn compat_session(&mut self) -> Self::CompatSessionRepository<'_>;
fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_>; fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_>;
fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_>; fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_>;
@ -107,7 +125,10 @@ impl Repository for PgConnection {
type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c; type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c;
type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c; type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c;
type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c; type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c;
type OAuth2AuthorizationGrantRepository<'c> = PgOAuth2AuthorizationGrantRepository<'c> where Self: 'c;
type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c; type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c;
type OAuth2AccessTokenRepository<'c> = PgOAuth2AccessTokenRepository<'c> where Self: 'c;
type OAuth2RefreshTokenRepository<'c> = PgOAuth2RefreshTokenRepository<'c> where Self: 'c;
type CompatSessionRepository<'c> = PgCompatSessionRepository<'c> where Self: 'c; type CompatSessionRepository<'c> = PgCompatSessionRepository<'c> where Self: 'c;
type CompatSsoLoginRepository<'c> = PgCompatSsoLoginRepository<'c> where Self: 'c; type CompatSsoLoginRepository<'c> = PgCompatSsoLoginRepository<'c> where Self: 'c;
type CompatAccessTokenRepository<'c> = PgCompatAccessTokenRepository<'c> where Self: 'c; type CompatAccessTokenRepository<'c> = PgCompatAccessTokenRepository<'c> where Self: 'c;
@ -145,10 +166,22 @@ impl Repository for PgConnection {
PgOAuth2ClientRepository::new(self) PgOAuth2ClientRepository::new(self)
} }
fn oauth2_authorization_grant(&mut self) -> Self::OAuth2AuthorizationGrantRepository<'_> {
PgOAuth2AuthorizationGrantRepository::new(self)
}
fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> { fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> {
PgOAuth2SessionRepository::new(self) PgOAuth2SessionRepository::new(self)
} }
fn oauth2_access_token(&mut self) -> Self::OAuth2AccessTokenRepository<'_> {
PgOAuth2AccessTokenRepository::new(self)
}
fn oauth2_refresh_token(&mut self) -> Self::OAuth2RefreshTokenRepository<'_> {
PgOAuth2RefreshTokenRepository::new(self)
}
fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> { fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> {
PgCompatSessionRepository::new(self) PgCompatSessionRepository::new(self)
} }
@ -175,7 +208,10 @@ impl<'t> Repository for Transaction<'t, Postgres> {
type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c; type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c;
type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c; type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c;
type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c; type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c;
type OAuth2AuthorizationGrantRepository<'c> = PgOAuth2AuthorizationGrantRepository<'c> where Self: 'c;
type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c; type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c;
type OAuth2AccessTokenRepository<'c> = PgOAuth2AccessTokenRepository<'c> where Self: 'c;
type OAuth2RefreshTokenRepository<'c> = PgOAuth2RefreshTokenRepository<'c> where Self: 'c;
type CompatSessionRepository<'c> = PgCompatSessionRepository<'c> where Self: 'c; type CompatSessionRepository<'c> = PgCompatSessionRepository<'c> where Self: 'c;
type CompatSsoLoginRepository<'c> = PgCompatSsoLoginRepository<'c> where Self: 'c; type CompatSsoLoginRepository<'c> = PgCompatSsoLoginRepository<'c> where Self: 'c;
type CompatAccessTokenRepository<'c> = PgCompatAccessTokenRepository<'c> where Self: 'c; type CompatAccessTokenRepository<'c> = PgCompatAccessTokenRepository<'c> where Self: 'c;
@ -213,10 +249,22 @@ impl<'t> Repository for Transaction<'t, Postgres> {
PgOAuth2ClientRepository::new(self) PgOAuth2ClientRepository::new(self)
} }
fn oauth2_authorization_grant(&mut self) -> Self::OAuth2AuthorizationGrantRepository<'_> {
PgOAuth2AuthorizationGrantRepository::new(self)
}
fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> { fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> {
PgOAuth2SessionRepository::new(self) PgOAuth2SessionRepository::new(self)
} }
fn oauth2_access_token(&mut self) -> Self::OAuth2AccessTokenRepository<'_> {
PgOAuth2AccessTokenRepository::new(self)
}
fn oauth2_refresh_token(&mut self) -> Self::OAuth2RefreshTokenRepository<'_> {
PgOAuth2RefreshTokenRepository::new(self)
}
fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> { fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> {
PgCompatSessionRepository::new(self) PgCompatSessionRepository::new(self)
} }

View File

@ -14,7 +14,7 @@
//! Database-related tasks //! Database-related tasks
use mas_storage::Clock; use mas_storage::{oauth2::OAuth2AccessTokenRepository, Clock, Repository};
use sqlx::{Pool, Postgres}; use sqlx::{Pool, Postgres};
use tracing::{debug, error, info}; use tracing::{debug, error, info};
@ -32,7 +32,12 @@ impl std::fmt::Debug for CleanupExpired {
#[async_trait::async_trait] #[async_trait::async_trait]
impl Task for CleanupExpired { impl Task for CleanupExpired {
async fn run(&self) { async fn run(&self) {
let res = mas_storage::oauth2::access_token::cleanup_expired(&self.0, &self.1).await; let res = async move {
let mut conn = self.0.acquire().await?;
conn.oauth2_access_token().cleanup_expired(&self.1).await
}
.await;
match res { match res {
Ok(0) => { Ok(0) => {
debug!("no token to clean up"); debug!("no token to clean up");