1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Split the storage trait from the implementation

This commit is contained in:
Quentin Gliech
2023-01-18 09:53:42 +01:00
parent b33a330b5f
commit 73a921cc30
95 changed files with 6294 additions and 5741 deletions

View File

@ -1,4 +1,4 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -13,18 +13,11 @@
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{UpstreamOAuthLink, UpstreamOAuthProvider, User};
use rand::RngCore;
use sqlx::{PgConnection, QueryBuilder};
use ulid::Ulid;
use uuid::Uuid;
use crate::{
pagination::{Page, QueryBuilderExt},
tracing::ExecuteExt,
Clock, DatabaseError, LookupResultExt, Pagination,
};
use crate::{pagination::Page, Clock, Pagination};
#[async_trait]
pub trait UpstreamOAuthLinkRepository: Send + Sync {
@ -63,241 +56,3 @@ pub trait UpstreamOAuthLinkRepository: Send + Sync {
pagination: Pagination,
) -> Result<Page<UpstreamOAuthLink>, Self::Error>;
}
pub struct PgUpstreamOAuthLinkRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgUpstreamOAuthLinkRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
#[derive(sqlx::FromRow)]
struct LinkLookup {
upstream_oauth_link_id: Uuid,
upstream_oauth_provider_id: Uuid,
user_id: Option<Uuid>,
subject: String,
created_at: DateTime<Utc>,
}
impl From<LinkLookup> for UpstreamOAuthLink {
fn from(value: LinkLookup) -> Self {
UpstreamOAuthLink {
id: Ulid::from(value.upstream_oauth_link_id),
provider_id: Ulid::from(value.upstream_oauth_provider_id),
user_id: value.user_id.map(Ulid::from),
subject: value.subject,
created_at: value.created_at,
}
}
}
#[async_trait]
impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.upstream_oauth_link.lookup",
skip_all,
fields(
db.statement,
upstream_oauth_link.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthLink>, Self::Error> {
let res = sqlx::query_as!(
LinkLookup,
r#"
SELECT
upstream_oauth_link_id,
upstream_oauth_provider_id,
user_id,
subject,
created_at
FROM upstream_oauth_links
WHERE upstream_oauth_link_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?
.map(Into::into);
Ok(res)
}
#[tracing::instrument(
name = "db.upstream_oauth_link.find_by_subject",
skip_all,
fields(
db.statement,
upstream_oauth_link.subject = subject,
%upstream_oauth_provider.id,
%upstream_oauth_provider.issuer,
%upstream_oauth_provider.client_id,
),
err,
)]
async fn find_by_subject(
&mut self,
upstream_oauth_provider: &UpstreamOAuthProvider,
subject: &str,
) -> Result<Option<UpstreamOAuthLink>, Self::Error> {
let res = sqlx::query_as!(
LinkLookup,
r#"
SELECT
upstream_oauth_link_id,
upstream_oauth_provider_id,
user_id,
subject,
created_at
FROM upstream_oauth_links
WHERE upstream_oauth_provider_id = $1
AND subject = $2
"#,
Uuid::from(upstream_oauth_provider.id),
subject,
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?
.map(Into::into);
Ok(res)
}
#[tracing::instrument(
name = "db.upstream_oauth_link.add",
skip_all,
fields(
db.statement,
upstream_oauth_link.id,
upstream_oauth_link.subject = subject,
%upstream_oauth_provider.id,
%upstream_oauth_provider.issuer,
%upstream_oauth_provider.client_id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
upstream_oauth_provider: &UpstreamOAuthProvider,
subject: String,
) -> Result<UpstreamOAuthLink, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("upstream_oauth_link.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO upstream_oauth_links (
upstream_oauth_link_id,
upstream_oauth_provider_id,
user_id,
subject,
created_at
) VALUES ($1, $2, NULL, $3, $4)
"#,
Uuid::from(id),
Uuid::from(upstream_oauth_provider.id),
&subject,
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(UpstreamOAuthLink {
id,
provider_id: upstream_oauth_provider.id,
user_id: None,
subject,
created_at,
})
}
#[tracing::instrument(
name = "db.upstream_oauth_link.associate_to_user",
skip_all,
fields(
db.statement,
%upstream_oauth_link.id,
%upstream_oauth_link.subject,
%user.id,
%user.username,
),
err,
)]
async fn associate_to_user(
&mut self,
upstream_oauth_link: &UpstreamOAuthLink,
user: &User,
) -> Result<(), Self::Error> {
sqlx::query!(
r#"
UPDATE upstream_oauth_links
SET user_id = $1
WHERE upstream_oauth_link_id = $2
"#,
Uuid::from(user.id),
Uuid::from(upstream_oauth_link.id),
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(())
}
#[tracing::instrument(
name = "db.upstream_oauth_link.list_paginated",
skip_all,
fields(
db.statement,
%user.id,
%user.username,
),
err
)]
async fn list_paginated(
&mut self,
user: &User,
pagination: Pagination,
) -> Result<Page<UpstreamOAuthLink>, Self::Error> {
let mut query = QueryBuilder::new(
r#"
SELECT
upstream_oauth_link_id,
upstream_oauth_provider_id,
user_id,
subject,
created_at
FROM upstream_oauth_links
"#,
);
query
.push(" WHERE user_id = ")
.push_bind(Uuid::from(user.id))
.generate_pagination("upstream_oauth_link_id", pagination);
let edges: Vec<LinkLookup> = query
.build_query_as()
.traced()
.fetch_all(&mut *self.conn)
.await?;
let page = pagination.process(edges).map(UpstreamOAuthLink::from);
Ok(page)
}
}

View File

@ -1,4 +1,4 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -17,249 +17,6 @@ mod provider;
mod session;
pub use self::{
link::{PgUpstreamOAuthLinkRepository, UpstreamOAuthLinkRepository},
provider::{PgUpstreamOAuthProviderRepository, UpstreamOAuthProviderRepository},
session::{PgUpstreamOAuthSessionRepository, UpstreamOAuthSessionRepository},
link::UpstreamOAuthLinkRepository, provider::UpstreamOAuthProviderRepository,
session::UpstreamOAuthSessionRepository,
};
#[cfg(test)]
mod tests {
use chrono::Duration;
use oauth2_types::scope::{Scope, OPENID};
use rand::SeedableRng;
use sqlx::PgPool;
use super::*;
use crate::{user::UserRepository, Clock, Pagination, PgRepository, Repository};
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_repository(pool: PgPool) {
let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
let clock = Clock::mock();
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
// The provider list should be empty at the start
let all_providers = repo.upstream_oauth_provider().all().await.unwrap();
assert!(all_providers.is_empty());
// Let's add a provider
let provider = repo
.upstream_oauth_provider()
.add(
&mut rng,
&clock,
"https://example.com/".to_owned(),
Scope::from_iter([OPENID]),
mas_iana::oauth::OAuthClientAuthenticationMethod::None,
None,
"client-id".to_owned(),
None,
)
.await
.unwrap();
// Look it up in the database
let provider = repo
.upstream_oauth_provider()
.lookup(provider.id)
.await
.unwrap()
.expect("provider to be found in the database");
assert_eq!(provider.issuer, "https://example.com/");
assert_eq!(provider.client_id, "client-id");
// Start a session
let session = repo
.upstream_oauth_session()
.add(
&mut rng,
&clock,
&provider,
"some-state".to_owned(),
None,
"some-nonce".to_owned(),
)
.await
.unwrap();
// Look it up in the database
let session = repo
.upstream_oauth_session()
.lookup(session.id)
.await
.unwrap()
.expect("session to be found in the database");
assert_eq!(session.provider_id, provider.id);
assert_eq!(session.link_id(), None);
assert!(session.is_pending());
assert!(!session.is_completed());
assert!(!session.is_consumed());
// Create a link
let link = repo
.upstream_oauth_link()
.add(&mut rng, &clock, &provider, "a-subject".to_owned())
.await
.unwrap();
// We can look it up by its ID
repo.upstream_oauth_link()
.lookup(link.id)
.await
.unwrap()
.expect("link to be found in database");
// or by its subject
let link = repo
.upstream_oauth_link()
.find_by_subject(&provider, "a-subject")
.await
.unwrap()
.expect("link to be found in database");
assert_eq!(link.subject, "a-subject");
assert_eq!(link.provider_id, provider.id);
let session = repo
.upstream_oauth_session()
.complete_with_link(&clock, session, &link, None)
.await
.unwrap();
// Reload the session
let session = repo
.upstream_oauth_session()
.lookup(session.id)
.await
.unwrap()
.expect("session to be found in the database");
assert!(session.is_completed());
assert!(!session.is_consumed());
assert_eq!(session.link_id(), Some(link.id));
let session = repo
.upstream_oauth_session()
.consume(&clock, session)
.await
.unwrap();
// Reload the session
let session = repo
.upstream_oauth_session()
.lookup(session.id)
.await
.unwrap()
.expect("session to be found in the database");
assert!(session.is_consumed());
let user = repo
.user()
.add(&mut rng, &clock, "john".to_owned())
.await
.unwrap();
repo.upstream_oauth_link()
.associate_to_user(&link, &user)
.await
.unwrap();
let links = repo
.upstream_oauth_link()
.list_paginated(&user, Pagination::first(10))
.await
.unwrap();
assert!(!links.has_previous_page);
assert!(!links.has_next_page);
assert_eq!(links.edges.len(), 1);
assert_eq!(links.edges[0].id, link.id);
assert_eq!(links.edges[0].user_id, Some(user.id));
}
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_provider_repository_pagination(pool: PgPool) {
const ISSUER: &str = "https://example.com/";
let scope = Scope::from_iter([OPENID]);
let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
let clock = Clock::mock();
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
let mut ids = Vec::with_capacity(20);
// Create 20 providers
for idx in 0..20 {
let client_id = format!("client-{idx}");
let provider = repo
.upstream_oauth_provider()
.add(
&mut rng,
&clock,
ISSUER.to_owned(),
scope.clone(),
mas_iana::oauth::OAuthClientAuthenticationMethod::None,
None,
client_id,
None,
)
.await
.unwrap();
ids.push(provider.id);
clock.advance(Duration::seconds(10));
}
// Lookup the first 10 items
let page = repo
.upstream_oauth_provider()
.list_paginated(Pagination::first(10))
.await
.unwrap();
// It returned the first 10 items
assert!(page.has_next_page);
let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
assert_eq!(&edge_ids, &ids[..10]);
// Lookup the next 10 items
let page = repo
.upstream_oauth_provider()
.list_paginated(Pagination::first(10).after(ids[9]))
.await
.unwrap();
// It returned the next 10 items
assert!(!page.has_next_page);
let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
assert_eq!(&edge_ids, &ids[10..]);
// Lookup the last 10 items
let page = repo
.upstream_oauth_provider()
.list_paginated(Pagination::last(10))
.await
.unwrap();
// It returned the last 10 items
assert!(page.has_previous_page);
let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
assert_eq!(&edge_ids, &ids[10..]);
// Lookup the previous 10 items
let page = repo
.upstream_oauth_provider()
.list_paginated(Pagination::last(10).before(ids[10]))
.await
.unwrap();
// It returned the previous 10 items
assert!(!page.has_previous_page);
let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
assert_eq!(&edge_ids, &ids[..10]);
// Lookup 10 items between two IDs
let page = repo
.upstream_oauth_provider()
.list_paginated(Pagination::first(10).after(ids[5]).before(ids[8]))
.await
.unwrap();
// It returned the items in between
assert!(!page.has_next_page);
let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
assert_eq!(&edge_ids, &ids[6..8]);
}
}

View File

@ -1,4 +1,4 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -13,20 +13,13 @@
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::UpstreamOAuthProvider;
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use oauth2_types::scope::Scope;
use rand::RngCore;
use sqlx::{PgConnection, QueryBuilder};
use ulid::Ulid;
use uuid::Uuid;
use crate::{
pagination::{Page, QueryBuilderExt},
tracing::ExecuteExt,
Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Pagination,
};
use crate::{pagination::Page, Clock, Pagination};
#[async_trait]
pub trait UpstreamOAuthProviderRepository: Send + Sync {
@ -58,247 +51,3 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync {
/// Get all upstream OAuth providers
async fn all(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
}
pub struct PgUpstreamOAuthProviderRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgUpstreamOAuthProviderRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
#[derive(sqlx::FromRow)]
struct ProviderLookup {
upstream_oauth_provider_id: Uuid,
issuer: String,
scope: String,
client_id: String,
encrypted_client_secret: Option<String>,
token_endpoint_signing_alg: Option<String>,
token_endpoint_auth_method: String,
created_at: DateTime<Utc>,
}
impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
type Error = DatabaseInconsistencyError;
fn try_from(value: ProviderLookup) -> Result<Self, Self::Error> {
let id = value.upstream_oauth_provider_id.into();
let scope = value.scope.parse().map_err(|e| {
DatabaseInconsistencyError::on("upstream_oauth_providers")
.column("scope")
.row(id)
.source(e)
})?;
let token_endpoint_auth_method = value.token_endpoint_auth_method.parse().map_err(|e| {
DatabaseInconsistencyError::on("upstream_oauth_providers")
.column("token_endpoint_auth_method")
.row(id)
.source(e)
})?;
let token_endpoint_signing_alg = value
.token_endpoint_signing_alg
.map(|x| x.parse())
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("upstream_oauth_providers")
.column("token_endpoint_signing_alg")
.row(id)
.source(e)
})?;
Ok(UpstreamOAuthProvider {
id,
issuer: value.issuer,
scope,
client_id: value.client_id,
encrypted_client_secret: value.encrypted_client_secret,
token_endpoint_auth_method,
token_endpoint_signing_alg,
created_at: value.created_at,
})
}
}
#[async_trait]
impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.upstream_oauth_provider.lookup",
skip_all,
fields(
db.statement,
upstream_oauth_provider.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error> {
let res = sqlx::query_as!(
ProviderLookup,
r#"
SELECT
upstream_oauth_provider_id,
issuer,
scope,
client_id,
encrypted_client_secret,
token_endpoint_signing_alg,
token_endpoint_auth_method,
created_at
FROM upstream_oauth_providers
WHERE upstream_oauth_provider_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let res = res
.map(UpstreamOAuthProvider::try_from)
.transpose()
.map_err(DatabaseError::from)?;
Ok(res)
}
#[tracing::instrument(
name = "db.upstream_oauth_provider.add",
skip_all,
fields(
db.statement,
upstream_oauth_provider.id,
upstream_oauth_provider.issuer = %issuer,
upstream_oauth_provider.client_id = %client_id,
),
err,
)]
#[allow(clippy::too_many_arguments)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
issuer: String,
scope: Scope,
token_endpoint_auth_method: OAuthClientAuthenticationMethod,
token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
client_id: String,
encrypted_client_secret: Option<String>,
) -> Result<UpstreamOAuthProvider, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("upstream_oauth_provider.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO upstream_oauth_providers (
upstream_oauth_provider_id,
issuer,
scope,
token_endpoint_auth_method,
token_endpoint_signing_alg,
client_id,
encrypted_client_secret,
created_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
"#,
Uuid::from(id),
&issuer,
scope.to_string(),
token_endpoint_auth_method.to_string(),
token_endpoint_signing_alg.as_ref().map(ToString::to_string),
&client_id,
encrypted_client_secret.as_deref(),
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(UpstreamOAuthProvider {
id,
issuer,
scope,
client_id,
encrypted_client_secret,
token_endpoint_signing_alg,
token_endpoint_auth_method,
created_at,
})
}
#[tracing::instrument(
name = "db.upstream_oauth_provider.list_paginated",
skip_all,
fields(
db.statement,
),
err,
)]
async fn list_paginated(
&mut self,
pagination: Pagination,
) -> Result<Page<UpstreamOAuthProvider>, Self::Error> {
let mut query = QueryBuilder::new(
r#"
SELECT
upstream_oauth_provider_id,
issuer,
scope,
client_id,
encrypted_client_secret,
token_endpoint_signing_alg,
token_endpoint_auth_method,
created_at
FROM upstream_oauth_providers
WHERE 1 = 1
"#,
);
query.generate_pagination("upstream_oauth_provider_id", pagination);
let edges: Vec<ProviderLookup> = query
.build_query_as()
.traced()
.fetch_all(&mut *self.conn)
.await?;
let page = pagination.process(edges).try_map(TryInto::try_into)?;
Ok(page)
}
#[tracing::instrument(
name = "db.upstream_oauth_provider.all",
skip_all,
fields(
db.statement,
),
err,
)]
async fn all(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error> {
let res = sqlx::query_as!(
ProviderLookup,
r#"
SELECT
upstream_oauth_provider_id,
issuer,
scope,
client_id,
encrypted_client_secret,
token_endpoint_signing_alg,
token_endpoint_auth_method,
created_at
FROM upstream_oauth_providers
"#,
)
.traced()
.fetch_all(&mut *self.conn)
.await?;
let res: Result<Vec<_>, _> = res.into_iter().map(TryInto::try_into).collect();
Ok(res?)
}
}

View File

@ -1,4 +1,4 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -13,19 +13,11 @@
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{
UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState, UpstreamOAuthLink,
UpstreamOAuthProvider,
};
use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{
tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
};
use crate::Clock;
#[async_trait]
pub trait UpstreamOAuthSessionRepository: Send + Sync {
@ -64,262 +56,3 @@ pub trait UpstreamOAuthSessionRepository: Send + Sync {
upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
}
pub struct PgUpstreamOAuthSessionRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgUpstreamOAuthSessionRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
struct SessionLookup {
upstream_oauth_authorization_session_id: Uuid,
upstream_oauth_provider_id: Uuid,
upstream_oauth_link_id: Option<Uuid>,
state: String,
code_challenge_verifier: Option<String>,
nonce: String,
id_token: Option<String>,
created_at: DateTime<Utc>,
completed_at: Option<DateTime<Utc>>,
consumed_at: Option<DateTime<Utc>>,
}
impl TryFrom<SessionLookup> for UpstreamOAuthAuthorizationSession {
type Error = DatabaseInconsistencyError;
fn try_from(value: SessionLookup) -> Result<Self, Self::Error> {
let id = value.upstream_oauth_authorization_session_id.into();
let state = match (
value.upstream_oauth_link_id,
value.id_token,
value.completed_at,
value.consumed_at,
) {
(None, None, None, None) => UpstreamOAuthAuthorizationSessionState::Pending,
(Some(link_id), id_token, Some(completed_at), None) => {
UpstreamOAuthAuthorizationSessionState::Completed {
completed_at,
link_id: link_id.into(),
id_token,
}
}
(Some(link_id), id_token, Some(completed_at), Some(consumed_at)) => {
UpstreamOAuthAuthorizationSessionState::Consumed {
completed_at,
link_id: link_id.into(),
id_token,
consumed_at,
}
}
_ => {
return Err(
DatabaseInconsistencyError::on("upstream_oauth_authorization_sessions").row(id),
)
}
};
Ok(Self {
id,
provider_id: value.upstream_oauth_provider_id.into(),
state_str: value.state,
nonce: value.nonce,
code_challenge_verifier: value.code_challenge_verifier,
created_at: value.created_at,
state,
})
}
}
#[async_trait]
impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.upstream_oauth_authorization_session.lookup",
skip_all,
fields(
db.statement,
upstream_oauth_provider.id = %id,
),
err,
)]
async fn lookup(
&mut self,
id: Ulid,
) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error> {
let res = sqlx::query_as!(
SessionLookup,
r#"
SELECT
upstream_oauth_authorization_session_id,
upstream_oauth_provider_id,
upstream_oauth_link_id,
state,
code_challenge_verifier,
nonce,
id_token,
created_at,
completed_at,
consumed_at
FROM upstream_oauth_authorization_sessions
WHERE upstream_oauth_authorization_session_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
name = "db.upstream_oauth_authorization_session.add",
skip_all,
fields(
db.statement,
%upstream_oauth_provider.id,
%upstream_oauth_provider.issuer,
%upstream_oauth_provider.client_id,
upstream_oauth_authorization_session.id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
upstream_oauth_provider: &UpstreamOAuthProvider,
state_str: String,
code_challenge_verifier: Option<String>,
nonce: String,
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record(
"upstream_oauth_authorization_session.id",
tracing::field::display(id),
);
sqlx::query!(
r#"
INSERT INTO upstream_oauth_authorization_sessions (
upstream_oauth_authorization_session_id,
upstream_oauth_provider_id,
state,
code_challenge_verifier,
nonce,
created_at,
completed_at,
consumed_at,
id_token
) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL)
"#,
Uuid::from(id),
Uuid::from(upstream_oauth_provider.id),
&state_str,
code_challenge_verifier.as_deref(),
nonce,
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(UpstreamOAuthAuthorizationSession {
id,
state: UpstreamOAuthAuthorizationSessionState::default(),
provider_id: upstream_oauth_provider.id,
state_str,
code_challenge_verifier,
nonce,
created_at,
})
}
#[tracing::instrument(
name = "db.upstream_oauth_authorization_session.complete_with_link",
skip_all,
fields(
db.statement,
%upstream_oauth_authorization_session.id,
%upstream_oauth_link.id,
),
err,
)]
async fn complete_with_link(
&mut self,
clock: &Clock,
upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
upstream_oauth_link: &UpstreamOAuthLink,
id_token: Option<String>,
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
let completed_at = clock.now();
sqlx::query!(
r#"
UPDATE upstream_oauth_authorization_sessions
SET upstream_oauth_link_id = $1,
completed_at = $2,
id_token = $3
WHERE upstream_oauth_authorization_session_id = $4
"#,
Uuid::from(upstream_oauth_link.id),
completed_at,
id_token,
Uuid::from(upstream_oauth_authorization_session.id),
)
.traced()
.execute(&mut *self.conn)
.await?;
let upstream_oauth_authorization_session = upstream_oauth_authorization_session
.complete(completed_at, upstream_oauth_link, id_token)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(upstream_oauth_authorization_session)
}
/// Mark a session as consumed
#[tracing::instrument(
name = "db.upstream_oauth_authorization_session.consume",
skip_all,
fields(
db.statement,
%upstream_oauth_authorization_session.id,
),
err,
)]
async fn consume(
&mut self,
clock: &Clock,
upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
let consumed_at = clock.now();
sqlx::query!(
r#"
UPDATE upstream_oauth_authorization_sessions
SET consumed_at = $1
WHERE upstream_oauth_authorization_session_id = $2
"#,
consumed_at,
Uuid::from(upstream_oauth_authorization_session.id),
)
.traced()
.execute(&mut *self.conn)
.await?;
let upstream_oauth_authorization_session = upstream_oauth_authorization_session
.consume(consumed_at)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(upstream_oauth_authorization_session)
}
}