You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
Split the storage trait from the implementation
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -13,18 +13,11 @@
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{UpstreamOAuthLink, UpstreamOAuthProvider, User};
|
||||
use rand::RngCore;
|
||||
use sqlx::{PgConnection, QueryBuilder};
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
pagination::{Page, QueryBuilderExt},
|
||||
tracing::ExecuteExt,
|
||||
Clock, DatabaseError, LookupResultExt, Pagination,
|
||||
};
|
||||
use crate::{pagination::Page, Clock, Pagination};
|
||||
|
||||
#[async_trait]
|
||||
pub trait UpstreamOAuthLinkRepository: Send + Sync {
|
||||
@ -63,241 +56,3 @@ pub trait UpstreamOAuthLinkRepository: Send + Sync {
|
||||
pagination: Pagination,
|
||||
) -> Result<Page<UpstreamOAuthLink>, Self::Error>;
|
||||
}
|
||||
|
||||
pub struct PgUpstreamOAuthLinkRepository<'c> {
|
||||
conn: &'c mut PgConnection,
|
||||
}
|
||||
|
||||
impl<'c> PgUpstreamOAuthLinkRepository<'c> {
|
||||
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
Self { conn }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(sqlx::FromRow)]
|
||||
struct LinkLookup {
|
||||
upstream_oauth_link_id: Uuid,
|
||||
upstream_oauth_provider_id: Uuid,
|
||||
user_id: Option<Uuid>,
|
||||
subject: String,
|
||||
created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl From<LinkLookup> for UpstreamOAuthLink {
|
||||
fn from(value: LinkLookup) -> Self {
|
||||
UpstreamOAuthLink {
|
||||
id: Ulid::from(value.upstream_oauth_link_id),
|
||||
provider_id: Ulid::from(value.upstream_oauth_provider_id),
|
||||
user_id: value.user_id.map(Ulid::from),
|
||||
subject: value.subject,
|
||||
created_at: value.created_at,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.upstream_oauth_link.lookup",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
upstream_oauth_link.id = %id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthLink>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
LinkLookup,
|
||||
r#"
|
||||
SELECT
|
||||
upstream_oauth_link_id,
|
||||
upstream_oauth_provider_id,
|
||||
user_id,
|
||||
subject,
|
||||
created_at
|
||||
FROM upstream_oauth_links
|
||||
WHERE upstream_oauth_link_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?
|
||||
.map(Into::into);
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.upstream_oauth_link.find_by_subject",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
upstream_oauth_link.subject = subject,
|
||||
%upstream_oauth_provider.id,
|
||||
%upstream_oauth_provider.issuer,
|
||||
%upstream_oauth_provider.client_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn find_by_subject(
|
||||
&mut self,
|
||||
upstream_oauth_provider: &UpstreamOAuthProvider,
|
||||
subject: &str,
|
||||
) -> Result<Option<UpstreamOAuthLink>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
LinkLookup,
|
||||
r#"
|
||||
SELECT
|
||||
upstream_oauth_link_id,
|
||||
upstream_oauth_provider_id,
|
||||
user_id,
|
||||
subject,
|
||||
created_at
|
||||
FROM upstream_oauth_links
|
||||
WHERE upstream_oauth_provider_id = $1
|
||||
AND subject = $2
|
||||
"#,
|
||||
Uuid::from(upstream_oauth_provider.id),
|
||||
subject,
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?
|
||||
.map(Into::into);
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.upstream_oauth_link.add",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
upstream_oauth_link.id,
|
||||
upstream_oauth_link.subject = subject,
|
||||
%upstream_oauth_provider.id,
|
||||
%upstream_oauth_provider.issuer,
|
||||
%upstream_oauth_provider.client_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn add(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
upstream_oauth_provider: &UpstreamOAuthProvider,
|
||||
subject: String,
|
||||
) -> Result<UpstreamOAuthLink, Self::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record("upstream_oauth_link.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO upstream_oauth_links (
|
||||
upstream_oauth_link_id,
|
||||
upstream_oauth_provider_id,
|
||||
user_id,
|
||||
subject,
|
||||
created_at
|
||||
) VALUES ($1, $2, NULL, $3, $4)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(upstream_oauth_provider.id),
|
||||
&subject,
|
||||
created_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(UpstreamOAuthLink {
|
||||
id,
|
||||
provider_id: upstream_oauth_provider.id,
|
||||
user_id: None,
|
||||
subject,
|
||||
created_at,
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.upstream_oauth_link.associate_to_user",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%upstream_oauth_link.id,
|
||||
%upstream_oauth_link.subject,
|
||||
%user.id,
|
||||
%user.username,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn associate_to_user(
|
||||
&mut self,
|
||||
upstream_oauth_link: &UpstreamOAuthLink,
|
||||
user: &User,
|
||||
) -> Result<(), Self::Error> {
|
||||
sqlx::query!(
|
||||
r#"
|
||||
UPDATE upstream_oauth_links
|
||||
SET user_id = $1
|
||||
WHERE upstream_oauth_link_id = $2
|
||||
"#,
|
||||
Uuid::from(user.id),
|
||||
Uuid::from(upstream_oauth_link.id),
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.upstream_oauth_link.list_paginated",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user.id,
|
||||
%user.username,
|
||||
),
|
||||
err
|
||||
)]
|
||||
async fn list_paginated(
|
||||
&mut self,
|
||||
user: &User,
|
||||
pagination: Pagination,
|
||||
) -> Result<Page<UpstreamOAuthLink>, Self::Error> {
|
||||
let mut query = QueryBuilder::new(
|
||||
r#"
|
||||
SELECT
|
||||
upstream_oauth_link_id,
|
||||
upstream_oauth_provider_id,
|
||||
user_id,
|
||||
subject,
|
||||
created_at
|
||||
FROM upstream_oauth_links
|
||||
"#,
|
||||
);
|
||||
|
||||
query
|
||||
.push(" WHERE user_id = ")
|
||||
.push_bind(Uuid::from(user.id))
|
||||
.generate_pagination("upstream_oauth_link_id", pagination);
|
||||
|
||||
let edges: Vec<LinkLookup> = query
|
||||
.build_query_as()
|
||||
.traced()
|
||||
.fetch_all(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
let page = pagination.process(edges).map(UpstreamOAuthLink::from);
|
||||
Ok(page)
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -17,249 +17,6 @@ mod provider;
|
||||
mod session;
|
||||
|
||||
pub use self::{
|
||||
link::{PgUpstreamOAuthLinkRepository, UpstreamOAuthLinkRepository},
|
||||
provider::{PgUpstreamOAuthProviderRepository, UpstreamOAuthProviderRepository},
|
||||
session::{PgUpstreamOAuthSessionRepository, UpstreamOAuthSessionRepository},
|
||||
link::UpstreamOAuthLinkRepository, provider::UpstreamOAuthProviderRepository,
|
||||
session::UpstreamOAuthSessionRepository,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use chrono::Duration;
|
||||
use oauth2_types::scope::{Scope, OPENID};
|
||||
use rand::SeedableRng;
|
||||
use sqlx::PgPool;
|
||||
|
||||
use super::*;
|
||||
use crate::{user::UserRepository, Clock, Pagination, PgRepository, Repository};
|
||||
|
||||
#[sqlx::test(migrator = "crate::MIGRATOR")]
|
||||
async fn test_repository(pool: PgPool) {
|
||||
let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
|
||||
let clock = Clock::mock();
|
||||
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
|
||||
|
||||
// The provider list should be empty at the start
|
||||
let all_providers = repo.upstream_oauth_provider().all().await.unwrap();
|
||||
assert!(all_providers.is_empty());
|
||||
|
||||
// Let's add a provider
|
||||
let provider = repo
|
||||
.upstream_oauth_provider()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
"https://example.com/".to_owned(),
|
||||
Scope::from_iter([OPENID]),
|
||||
mas_iana::oauth::OAuthClientAuthenticationMethod::None,
|
||||
None,
|
||||
"client-id".to_owned(),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Look it up in the database
|
||||
let provider = repo
|
||||
.upstream_oauth_provider()
|
||||
.lookup(provider.id)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("provider to be found in the database");
|
||||
assert_eq!(provider.issuer, "https://example.com/");
|
||||
assert_eq!(provider.client_id, "client-id");
|
||||
|
||||
// Start a session
|
||||
let session = repo
|
||||
.upstream_oauth_session()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&provider,
|
||||
"some-state".to_owned(),
|
||||
None,
|
||||
"some-nonce".to_owned(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Look it up in the database
|
||||
let session = repo
|
||||
.upstream_oauth_session()
|
||||
.lookup(session.id)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("session to be found in the database");
|
||||
assert_eq!(session.provider_id, provider.id);
|
||||
assert_eq!(session.link_id(), None);
|
||||
assert!(session.is_pending());
|
||||
assert!(!session.is_completed());
|
||||
assert!(!session.is_consumed());
|
||||
|
||||
// Create a link
|
||||
let link = repo
|
||||
.upstream_oauth_link()
|
||||
.add(&mut rng, &clock, &provider, "a-subject".to_owned())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// We can look it up by its ID
|
||||
repo.upstream_oauth_link()
|
||||
.lookup(link.id)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("link to be found in database");
|
||||
|
||||
// or by its subject
|
||||
let link = repo
|
||||
.upstream_oauth_link()
|
||||
.find_by_subject(&provider, "a-subject")
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("link to be found in database");
|
||||
assert_eq!(link.subject, "a-subject");
|
||||
assert_eq!(link.provider_id, provider.id);
|
||||
|
||||
let session = repo
|
||||
.upstream_oauth_session()
|
||||
.complete_with_link(&clock, session, &link, None)
|
||||
.await
|
||||
.unwrap();
|
||||
// Reload the session
|
||||
let session = repo
|
||||
.upstream_oauth_session()
|
||||
.lookup(session.id)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("session to be found in the database");
|
||||
assert!(session.is_completed());
|
||||
assert!(!session.is_consumed());
|
||||
assert_eq!(session.link_id(), Some(link.id));
|
||||
|
||||
let session = repo
|
||||
.upstream_oauth_session()
|
||||
.consume(&clock, session)
|
||||
.await
|
||||
.unwrap();
|
||||
// Reload the session
|
||||
let session = repo
|
||||
.upstream_oauth_session()
|
||||
.lookup(session.id)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("session to be found in the database");
|
||||
assert!(session.is_consumed());
|
||||
|
||||
let user = repo
|
||||
.user()
|
||||
.add(&mut rng, &clock, "john".to_owned())
|
||||
.await
|
||||
.unwrap();
|
||||
repo.upstream_oauth_link()
|
||||
.associate_to_user(&link, &user)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let links = repo
|
||||
.upstream_oauth_link()
|
||||
.list_paginated(&user, Pagination::first(10))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!links.has_previous_page);
|
||||
assert!(!links.has_next_page);
|
||||
assert_eq!(links.edges.len(), 1);
|
||||
assert_eq!(links.edges[0].id, link.id);
|
||||
assert_eq!(links.edges[0].user_id, Some(user.id));
|
||||
}
|
||||
|
||||
#[sqlx::test(migrator = "crate::MIGRATOR")]
|
||||
async fn test_provider_repository_pagination(pool: PgPool) {
|
||||
const ISSUER: &str = "https://example.com/";
|
||||
let scope = Scope::from_iter([OPENID]);
|
||||
|
||||
let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
|
||||
let clock = Clock::mock();
|
||||
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
|
||||
|
||||
let mut ids = Vec::with_capacity(20);
|
||||
// Create 20 providers
|
||||
for idx in 0..20 {
|
||||
let client_id = format!("client-{idx}");
|
||||
let provider = repo
|
||||
.upstream_oauth_provider()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
ISSUER.to_owned(),
|
||||
scope.clone(),
|
||||
mas_iana::oauth::OAuthClientAuthenticationMethod::None,
|
||||
None,
|
||||
client_id,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
ids.push(provider.id);
|
||||
clock.advance(Duration::seconds(10));
|
||||
}
|
||||
|
||||
// Lookup the first 10 items
|
||||
let page = repo
|
||||
.upstream_oauth_provider()
|
||||
.list_paginated(Pagination::first(10))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// It returned the first 10 items
|
||||
assert!(page.has_next_page);
|
||||
let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
|
||||
assert_eq!(&edge_ids, &ids[..10]);
|
||||
|
||||
// Lookup the next 10 items
|
||||
let page = repo
|
||||
.upstream_oauth_provider()
|
||||
.list_paginated(Pagination::first(10).after(ids[9]))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// It returned the next 10 items
|
||||
assert!(!page.has_next_page);
|
||||
let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
|
||||
assert_eq!(&edge_ids, &ids[10..]);
|
||||
|
||||
// Lookup the last 10 items
|
||||
let page = repo
|
||||
.upstream_oauth_provider()
|
||||
.list_paginated(Pagination::last(10))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// It returned the last 10 items
|
||||
assert!(page.has_previous_page);
|
||||
let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
|
||||
assert_eq!(&edge_ids, &ids[10..]);
|
||||
|
||||
// Lookup the previous 10 items
|
||||
let page = repo
|
||||
.upstream_oauth_provider()
|
||||
.list_paginated(Pagination::last(10).before(ids[10]))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// It returned the previous 10 items
|
||||
assert!(!page.has_previous_page);
|
||||
let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
|
||||
assert_eq!(&edge_ids, &ids[..10]);
|
||||
|
||||
// Lookup 10 items between two IDs
|
||||
let page = repo
|
||||
.upstream_oauth_provider()
|
||||
.list_paginated(Pagination::first(10).after(ids[5]).before(ids[8]))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// It returned the items in between
|
||||
assert!(!page.has_next_page);
|
||||
let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
|
||||
assert_eq!(&edge_ids, &ids[6..8]);
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -13,20 +13,13 @@
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::UpstreamOAuthProvider;
|
||||
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
|
||||
use oauth2_types::scope::Scope;
|
||||
use rand::RngCore;
|
||||
use sqlx::{PgConnection, QueryBuilder};
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
pagination::{Page, QueryBuilderExt},
|
||||
tracing::ExecuteExt,
|
||||
Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Pagination,
|
||||
};
|
||||
use crate::{pagination::Page, Clock, Pagination};
|
||||
|
||||
#[async_trait]
|
||||
pub trait UpstreamOAuthProviderRepository: Send + Sync {
|
||||
@ -58,247 +51,3 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync {
|
||||
/// Get all upstream OAuth providers
|
||||
async fn all(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
|
||||
}
|
||||
|
||||
pub struct PgUpstreamOAuthProviderRepository<'c> {
|
||||
conn: &'c mut PgConnection,
|
||||
}
|
||||
|
||||
impl<'c> PgUpstreamOAuthProviderRepository<'c> {
|
||||
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
Self { conn }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(sqlx::FromRow)]
|
||||
struct ProviderLookup {
|
||||
upstream_oauth_provider_id: Uuid,
|
||||
issuer: String,
|
||||
scope: String,
|
||||
client_id: String,
|
||||
encrypted_client_secret: Option<String>,
|
||||
token_endpoint_signing_alg: Option<String>,
|
||||
token_endpoint_auth_method: String,
|
||||
created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
|
||||
type Error = DatabaseInconsistencyError;
|
||||
fn try_from(value: ProviderLookup) -> Result<Self, Self::Error> {
|
||||
let id = value.upstream_oauth_provider_id.into();
|
||||
let scope = value.scope.parse().map_err(|e| {
|
||||
DatabaseInconsistencyError::on("upstream_oauth_providers")
|
||||
.column("scope")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
let token_endpoint_auth_method = value.token_endpoint_auth_method.parse().map_err(|e| {
|
||||
DatabaseInconsistencyError::on("upstream_oauth_providers")
|
||||
.column("token_endpoint_auth_method")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
let token_endpoint_signing_alg = value
|
||||
.token_endpoint_signing_alg
|
||||
.map(|x| x.parse())
|
||||
.transpose()
|
||||
.map_err(|e| {
|
||||
DatabaseInconsistencyError::on("upstream_oauth_providers")
|
||||
.column("token_endpoint_signing_alg")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
Ok(UpstreamOAuthProvider {
|
||||
id,
|
||||
issuer: value.issuer,
|
||||
scope,
|
||||
client_id: value.client_id,
|
||||
encrypted_client_secret: value.encrypted_client_secret,
|
||||
token_endpoint_auth_method,
|
||||
token_endpoint_signing_alg,
|
||||
created_at: value.created_at,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.upstream_oauth_provider.lookup",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
upstream_oauth_provider.id = %id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
ProviderLookup,
|
||||
r#"
|
||||
SELECT
|
||||
upstream_oauth_provider_id,
|
||||
issuer,
|
||||
scope,
|
||||
client_id,
|
||||
encrypted_client_secret,
|
||||
token_endpoint_signing_alg,
|
||||
token_endpoint_auth_method,
|
||||
created_at
|
||||
FROM upstream_oauth_providers
|
||||
WHERE upstream_oauth_provider_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let res = res
|
||||
.map(UpstreamOAuthProvider::try_from)
|
||||
.transpose()
|
||||
.map_err(DatabaseError::from)?;
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.upstream_oauth_provider.add",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
upstream_oauth_provider.id,
|
||||
upstream_oauth_provider.issuer = %issuer,
|
||||
upstream_oauth_provider.client_id = %client_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn add(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
issuer: String,
|
||||
scope: Scope,
|
||||
token_endpoint_auth_method: OAuthClientAuthenticationMethod,
|
||||
token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
|
||||
client_id: String,
|
||||
encrypted_client_secret: Option<String>,
|
||||
) -> Result<UpstreamOAuthProvider, Self::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record("upstream_oauth_provider.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO upstream_oauth_providers (
|
||||
upstream_oauth_provider_id,
|
||||
issuer,
|
||||
scope,
|
||||
token_endpoint_auth_method,
|
||||
token_endpoint_signing_alg,
|
||||
client_id,
|
||||
encrypted_client_secret,
|
||||
created_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
&issuer,
|
||||
scope.to_string(),
|
||||
token_endpoint_auth_method.to_string(),
|
||||
token_endpoint_signing_alg.as_ref().map(ToString::to_string),
|
||||
&client_id,
|
||||
encrypted_client_secret.as_deref(),
|
||||
created_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(UpstreamOAuthProvider {
|
||||
id,
|
||||
issuer,
|
||||
scope,
|
||||
client_id,
|
||||
encrypted_client_secret,
|
||||
token_endpoint_signing_alg,
|
||||
token_endpoint_auth_method,
|
||||
created_at,
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.upstream_oauth_provider.list_paginated",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn list_paginated(
|
||||
&mut self,
|
||||
pagination: Pagination,
|
||||
) -> Result<Page<UpstreamOAuthProvider>, Self::Error> {
|
||||
let mut query = QueryBuilder::new(
|
||||
r#"
|
||||
SELECT
|
||||
upstream_oauth_provider_id,
|
||||
issuer,
|
||||
scope,
|
||||
client_id,
|
||||
encrypted_client_secret,
|
||||
token_endpoint_signing_alg,
|
||||
token_endpoint_auth_method,
|
||||
created_at
|
||||
FROM upstream_oauth_providers
|
||||
WHERE 1 = 1
|
||||
"#,
|
||||
);
|
||||
|
||||
query.generate_pagination("upstream_oauth_provider_id", pagination);
|
||||
|
||||
let edges: Vec<ProviderLookup> = query
|
||||
.build_query_as()
|
||||
.traced()
|
||||
.fetch_all(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
let page = pagination.process(edges).try_map(TryInto::try_into)?;
|
||||
Ok(page)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.upstream_oauth_provider.all",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn all(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
ProviderLookup,
|
||||
r#"
|
||||
SELECT
|
||||
upstream_oauth_provider_id,
|
||||
issuer,
|
||||
scope,
|
||||
client_id,
|
||||
encrypted_client_secret,
|
||||
token_endpoint_signing_alg,
|
||||
token_endpoint_auth_method,
|
||||
created_at
|
||||
FROM upstream_oauth_providers
|
||||
"#,
|
||||
)
|
||||
.traced()
|
||||
.fetch_all(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
let res: Result<Vec<_>, _> = res.into_iter().map(TryInto::try_into).collect();
|
||||
Ok(res?)
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -13,19 +13,11 @@
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{
|
||||
UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState, UpstreamOAuthLink,
|
||||
UpstreamOAuthProvider,
|
||||
};
|
||||
use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider};
|
||||
use rand::RngCore;
|
||||
use sqlx::PgConnection;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
|
||||
};
|
||||
use crate::Clock;
|
||||
|
||||
#[async_trait]
|
||||
pub trait UpstreamOAuthSessionRepository: Send + Sync {
|
||||
@ -64,262 +56,3 @@ pub trait UpstreamOAuthSessionRepository: Send + Sync {
|
||||
upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
|
||||
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
|
||||
}
|
||||
|
||||
pub struct PgUpstreamOAuthSessionRepository<'c> {
|
||||
conn: &'c mut PgConnection,
|
||||
}
|
||||
|
||||
impl<'c> PgUpstreamOAuthSessionRepository<'c> {
|
||||
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
Self { conn }
|
||||
}
|
||||
}
|
||||
|
||||
struct SessionLookup {
|
||||
upstream_oauth_authorization_session_id: Uuid,
|
||||
upstream_oauth_provider_id: Uuid,
|
||||
upstream_oauth_link_id: Option<Uuid>,
|
||||
state: String,
|
||||
code_challenge_verifier: Option<String>,
|
||||
nonce: String,
|
||||
id_token: Option<String>,
|
||||
created_at: DateTime<Utc>,
|
||||
completed_at: Option<DateTime<Utc>>,
|
||||
consumed_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
impl TryFrom<SessionLookup> for UpstreamOAuthAuthorizationSession {
|
||||
type Error = DatabaseInconsistencyError;
|
||||
|
||||
fn try_from(value: SessionLookup) -> Result<Self, Self::Error> {
|
||||
let id = value.upstream_oauth_authorization_session_id.into();
|
||||
let state = match (
|
||||
value.upstream_oauth_link_id,
|
||||
value.id_token,
|
||||
value.completed_at,
|
||||
value.consumed_at,
|
||||
) {
|
||||
(None, None, None, None) => UpstreamOAuthAuthorizationSessionState::Pending,
|
||||
(Some(link_id), id_token, Some(completed_at), None) => {
|
||||
UpstreamOAuthAuthorizationSessionState::Completed {
|
||||
completed_at,
|
||||
link_id: link_id.into(),
|
||||
id_token,
|
||||
}
|
||||
}
|
||||
(Some(link_id), id_token, Some(completed_at), Some(consumed_at)) => {
|
||||
UpstreamOAuthAuthorizationSessionState::Consumed {
|
||||
completed_at,
|
||||
link_id: link_id.into(),
|
||||
id_token,
|
||||
consumed_at,
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
return Err(
|
||||
DatabaseInconsistencyError::on("upstream_oauth_authorization_sessions").row(id),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
id,
|
||||
provider_id: value.upstream_oauth_provider_id.into(),
|
||||
state_str: value.state,
|
||||
nonce: value.nonce,
|
||||
code_challenge_verifier: value.code_challenge_verifier,
|
||||
created_at: value.created_at,
|
||||
state,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.upstream_oauth_authorization_session.lookup",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
upstream_oauth_provider.id = %id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn lookup(
|
||||
&mut self,
|
||||
id: Ulid,
|
||||
) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
SessionLookup,
|
||||
r#"
|
||||
SELECT
|
||||
upstream_oauth_authorization_session_id,
|
||||
upstream_oauth_provider_id,
|
||||
upstream_oauth_link_id,
|
||||
state,
|
||||
code_challenge_verifier,
|
||||
nonce,
|
||||
id_token,
|
||||
created_at,
|
||||
completed_at,
|
||||
consumed_at
|
||||
FROM upstream_oauth_authorization_sessions
|
||||
WHERE upstream_oauth_authorization_session_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.try_into()?))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.upstream_oauth_authorization_session.add",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%upstream_oauth_provider.id,
|
||||
%upstream_oauth_provider.issuer,
|
||||
%upstream_oauth_provider.client_id,
|
||||
upstream_oauth_authorization_session.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn add(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
upstream_oauth_provider: &UpstreamOAuthProvider,
|
||||
state_str: String,
|
||||
code_challenge_verifier: Option<String>,
|
||||
nonce: String,
|
||||
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record(
|
||||
"upstream_oauth_authorization_session.id",
|
||||
tracing::field::display(id),
|
||||
);
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO upstream_oauth_authorization_sessions (
|
||||
upstream_oauth_authorization_session_id,
|
||||
upstream_oauth_provider_id,
|
||||
state,
|
||||
code_challenge_verifier,
|
||||
nonce,
|
||||
created_at,
|
||||
completed_at,
|
||||
consumed_at,
|
||||
id_token
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(upstream_oauth_provider.id),
|
||||
&state_str,
|
||||
code_challenge_verifier.as_deref(),
|
||||
nonce,
|
||||
created_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(UpstreamOAuthAuthorizationSession {
|
||||
id,
|
||||
state: UpstreamOAuthAuthorizationSessionState::default(),
|
||||
provider_id: upstream_oauth_provider.id,
|
||||
state_str,
|
||||
code_challenge_verifier,
|
||||
nonce,
|
||||
created_at,
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.upstream_oauth_authorization_session.complete_with_link",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%upstream_oauth_authorization_session.id,
|
||||
%upstream_oauth_link.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn complete_with_link(
|
||||
&mut self,
|
||||
clock: &Clock,
|
||||
upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
|
||||
upstream_oauth_link: &UpstreamOAuthLink,
|
||||
id_token: Option<String>,
|
||||
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
|
||||
let completed_at = clock.now();
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
UPDATE upstream_oauth_authorization_sessions
|
||||
SET upstream_oauth_link_id = $1,
|
||||
completed_at = $2,
|
||||
id_token = $3
|
||||
WHERE upstream_oauth_authorization_session_id = $4
|
||||
"#,
|
||||
Uuid::from(upstream_oauth_link.id),
|
||||
completed_at,
|
||||
id_token,
|
||||
Uuid::from(upstream_oauth_authorization_session.id),
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
let upstream_oauth_authorization_session = upstream_oauth_authorization_session
|
||||
.complete(completed_at, upstream_oauth_link, id_token)
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
|
||||
Ok(upstream_oauth_authorization_session)
|
||||
}
|
||||
|
||||
/// Mark a session as consumed
|
||||
#[tracing::instrument(
|
||||
name = "db.upstream_oauth_authorization_session.consume",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%upstream_oauth_authorization_session.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn consume(
|
||||
&mut self,
|
||||
clock: &Clock,
|
||||
upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
|
||||
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
|
||||
let consumed_at = clock.now();
|
||||
sqlx::query!(
|
||||
r#"
|
||||
UPDATE upstream_oauth_authorization_sessions
|
||||
SET consumed_at = $1
|
||||
WHERE upstream_oauth_authorization_session_id = $2
|
||||
"#,
|
||||
consumed_at,
|
||||
Uuid::from(upstream_oauth_authorization_session.id),
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
let upstream_oauth_authorization_session = upstream_oauth_authorization_session
|
||||
.consume(consumed_at)
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
|
||||
Ok(upstream_oauth_authorization_session)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user