1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

storage: wrap the postgres repository in a struct

This commit is contained in:
Quentin Gliech
2023-01-13 18:03:37 +01:00
parent 488a666a8d
commit 195203823a
44 changed files with 505 additions and 548 deletions

View File

@ -183,7 +183,7 @@ pub(crate) mod tracing;
pub mod upstream_oauth2;
pub mod user;
pub use self::{repository::Repository, upstream_oauth2::UpstreamOAuthLinkRepository};
pub use self::repository::{PgRepository, Repository};
/// Embedded migrations, allowing them to run on startup
pub static MIGRATOR: Migrator = sqlx::migrate!();

View File

@ -32,7 +32,7 @@ use crate::{
};
#[async_trait]
pub trait OAuth2AuthorizationGrantRepository {
pub trait OAuth2AuthorizationGrantRepository: Send + Sync {
type Error;
#[allow(clippy::too_many_arguments)]

View File

@ -27,7 +27,7 @@ use crate::{
};
#[async_trait]
pub trait OAuth2SessionRepository {
pub trait OAuth2SessionRepository: Send + Sync {
type Error;
async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error>;

View File

@ -12,89 +12,100 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use sqlx::{PgConnection, Postgres, Transaction};
use sqlx::{PgPool, Postgres, Transaction};
use crate::{
compat::{
PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository, PgCompatSessionRepository,
PgCompatSsoLoginRepository,
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
CompatSsoLoginRepository, PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository,
PgCompatSessionRepository, PgCompatSsoLoginRepository,
},
oauth2::{
PgOAuth2AccessTokenRepository, PgOAuth2AuthorizationGrantRepository,
PgOAuth2ClientRepository, PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository,
OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository,
OAuth2RefreshTokenRepository, OAuth2SessionRepository, PgOAuth2AccessTokenRepository,
PgOAuth2AuthorizationGrantRepository, PgOAuth2ClientRepository,
PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository,
},
upstream_oauth2::{
PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository,
PgUpstreamOAuthSessionRepository,
PgUpstreamOAuthSessionRepository, UpstreamOAuthLinkRepository,
UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository,
},
user::{
PgBrowserSessionRepository, PgUserEmailRepository, PgUserPasswordRepository,
PgUserRepository,
BrowserSessionRepository, PgBrowserSessionRepository, PgUserEmailRepository,
PgUserPasswordRepository, PgUserRepository, UserEmailRepository, UserPasswordRepository,
UserRepository,
},
DatabaseError,
};
pub trait Repository {
type UpstreamOAuthLinkRepository<'c>
pub trait Repository: Send {
type Error: std::error::Error + Send + Sync + 'static;
type UpstreamOAuthLinkRepository<'c>: UpstreamOAuthLinkRepository<Error = Self::Error> + 'c
where
Self: 'c;
type UpstreamOAuthProviderRepository<'c>
type UpstreamOAuthProviderRepository<'c>: UpstreamOAuthProviderRepository<Error = Self::Error>
+ 'c
where
Self: 'c;
type UpstreamOAuthSessionRepository<'c>
type UpstreamOAuthSessionRepository<'c>: UpstreamOAuthSessionRepository<Error = Self::Error>
+ 'c
where
Self: 'c;
type UserRepository<'c>
type UserRepository<'c>: UserRepository<Error = Self::Error> + 'c
where
Self: 'c;
type UserEmailRepository<'c>
type UserEmailRepository<'c>: UserEmailRepository<Error = Self::Error> + 'c
where
Self: 'c;
type UserPasswordRepository<'c>
type UserPasswordRepository<'c>: UserPasswordRepository<Error = Self::Error> + 'c
where
Self: 'c;
type BrowserSessionRepository<'c>
type BrowserSessionRepository<'c>: BrowserSessionRepository<Error = Self::Error> + 'c
where
Self: 'c;
type OAuth2ClientRepository<'c>
type OAuth2ClientRepository<'c>: OAuth2ClientRepository<Error = Self::Error> + 'c
where
Self: 'c;
type OAuth2AuthorizationGrantRepository<'c>
type OAuth2AuthorizationGrantRepository<'c>: OAuth2AuthorizationGrantRepository<Error = Self::Error>
+ 'c
where
Self: 'c;
type OAuth2SessionRepository<'c>
type OAuth2SessionRepository<'c>: OAuth2SessionRepository<Error = Self::Error> + 'c
where
Self: 'c;
type OAuth2AccessTokenRepository<'c>
type OAuth2AccessTokenRepository<'c>: OAuth2AccessTokenRepository<Error = Self::Error> + 'c
where
Self: 'c;
type OAuth2RefreshTokenRepository<'c>
type OAuth2RefreshTokenRepository<'c>: OAuth2RefreshTokenRepository<Error = Self::Error> + 'c
where
Self: 'c;
type CompatSessionRepository<'c>
type CompatSessionRepository<'c>: CompatSessionRepository<Error = Self::Error> + 'c
where
Self: 'c;
type CompatSsoLoginRepository<'c>
type CompatSsoLoginRepository<'c>: CompatSsoLoginRepository<Error = Self::Error> + 'c
where
Self: 'c;
type CompatAccessTokenRepository<'c>
type CompatAccessTokenRepository<'c>: CompatAccessTokenRepository<Error = Self::Error> + 'c
where
Self: 'c;
type CompatRefreshTokenRepository<'c>
type CompatRefreshTokenRepository<'c>: CompatRefreshTokenRepository<Error = Self::Error> + 'c
where
Self: 'c;
@ -116,7 +127,30 @@ pub trait Repository {
fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_>;
}
impl Repository for PgConnection {
pub struct PgRepository {
txn: Transaction<'static, Postgres>,
}
impl PgRepository {
pub async fn from_pool(pool: &PgPool) -> Result<Self, DatabaseError> {
let txn = pool.begin().await?;
Ok(PgRepository { txn })
}
pub async fn save(self) -> Result<(), DatabaseError> {
self.txn.commit().await?;
Ok(())
}
pub async fn cancel(self) -> Result<(), DatabaseError> {
self.txn.rollback().await?;
Ok(())
}
}
impl Repository for PgRepository {
type Error = DatabaseError;
type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c;
type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c;
type UpstreamOAuthSessionRepository<'c> = PgUpstreamOAuthSessionRepository<'c> where Self: 'c;
@ -135,149 +169,66 @@ impl Repository for PgConnection {
type CompatRefreshTokenRepository<'c> = PgCompatRefreshTokenRepository<'c> where Self: 'c;
fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> {
PgUpstreamOAuthLinkRepository::new(self)
PgUpstreamOAuthLinkRepository::new(&mut self.txn)
}
fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> {
PgUpstreamOAuthProviderRepository::new(self)
PgUpstreamOAuthProviderRepository::new(&mut self.txn)
}
fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_> {
PgUpstreamOAuthSessionRepository::new(self)
PgUpstreamOAuthSessionRepository::new(&mut self.txn)
}
fn user(&mut self) -> Self::UserRepository<'_> {
PgUserRepository::new(self)
PgUserRepository::new(&mut self.txn)
}
fn user_email(&mut self) -> Self::UserEmailRepository<'_> {
PgUserEmailRepository::new(self)
PgUserEmailRepository::new(&mut self.txn)
}
fn user_password(&mut self) -> Self::UserPasswordRepository<'_> {
PgUserPasswordRepository::new(self)
PgUserPasswordRepository::new(&mut self.txn)
}
fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_> {
PgBrowserSessionRepository::new(self)
PgBrowserSessionRepository::new(&mut self.txn)
}
fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_> {
PgOAuth2ClientRepository::new(self)
PgOAuth2ClientRepository::new(&mut self.txn)
}
fn oauth2_authorization_grant(&mut self) -> Self::OAuth2AuthorizationGrantRepository<'_> {
PgOAuth2AuthorizationGrantRepository::new(self)
PgOAuth2AuthorizationGrantRepository::new(&mut self.txn)
}
fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> {
PgOAuth2SessionRepository::new(self)
PgOAuth2SessionRepository::new(&mut self.txn)
}
fn oauth2_access_token(&mut self) -> Self::OAuth2AccessTokenRepository<'_> {
PgOAuth2AccessTokenRepository::new(self)
PgOAuth2AccessTokenRepository::new(&mut self.txn)
}
fn oauth2_refresh_token(&mut self) -> Self::OAuth2RefreshTokenRepository<'_> {
PgOAuth2RefreshTokenRepository::new(self)
PgOAuth2RefreshTokenRepository::new(&mut self.txn)
}
fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> {
PgCompatSessionRepository::new(self)
PgCompatSessionRepository::new(&mut self.txn)
}
fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_> {
PgCompatSsoLoginRepository::new(self)
PgCompatSsoLoginRepository::new(&mut self.txn)
}
fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_> {
PgCompatAccessTokenRepository::new(self)
PgCompatAccessTokenRepository::new(&mut self.txn)
}
fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_> {
PgCompatRefreshTokenRepository::new(self)
}
}
impl<'t> Repository for Transaction<'t, Postgres> {
type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c;
type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c;
type UpstreamOAuthSessionRepository<'c> = PgUpstreamOAuthSessionRepository<'c> where Self: 'c;
type UserRepository<'c> = PgUserRepository<'c> where Self: 'c;
type UserEmailRepository<'c> = PgUserEmailRepository<'c> where Self: 'c;
type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c;
type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c;
type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c;
type OAuth2AuthorizationGrantRepository<'c> = PgOAuth2AuthorizationGrantRepository<'c> where Self: 'c;
type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c;
type OAuth2AccessTokenRepository<'c> = PgOAuth2AccessTokenRepository<'c> where Self: 'c;
type OAuth2RefreshTokenRepository<'c> = PgOAuth2RefreshTokenRepository<'c> where Self: 'c;
type CompatSessionRepository<'c> = PgCompatSessionRepository<'c> where Self: 'c;
type CompatSsoLoginRepository<'c> = PgCompatSsoLoginRepository<'c> where Self: 'c;
type CompatAccessTokenRepository<'c> = PgCompatAccessTokenRepository<'c> where Self: 'c;
type CompatRefreshTokenRepository<'c> = PgCompatRefreshTokenRepository<'c> where Self: 'c;
fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> {
PgUpstreamOAuthLinkRepository::new(self)
}
fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> {
PgUpstreamOAuthProviderRepository::new(self)
}
fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_> {
PgUpstreamOAuthSessionRepository::new(self)
}
fn user(&mut self) -> Self::UserRepository<'_> {
PgUserRepository::new(self)
}
fn user_email(&mut self) -> Self::UserEmailRepository<'_> {
PgUserEmailRepository::new(self)
}
fn user_password(&mut self) -> Self::UserPasswordRepository<'_> {
PgUserPasswordRepository::new(self)
}
fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_> {
PgBrowserSessionRepository::new(self)
}
fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_> {
PgOAuth2ClientRepository::new(self)
}
fn oauth2_authorization_grant(&mut self) -> Self::OAuth2AuthorizationGrantRepository<'_> {
PgOAuth2AuthorizationGrantRepository::new(self)
}
fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> {
PgOAuth2SessionRepository::new(self)
}
fn oauth2_access_token(&mut self) -> Self::OAuth2AccessTokenRepository<'_> {
PgOAuth2AccessTokenRepository::new(self)
}
fn oauth2_refresh_token(&mut self) -> Self::OAuth2RefreshTokenRepository<'_> {
PgOAuth2RefreshTokenRepository::new(self)
}
fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> {
PgCompatSessionRepository::new(self)
}
fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_> {
PgCompatSsoLoginRepository::new(self)
}
fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_> {
PgCompatAccessTokenRepository::new(self)
}
fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_> {
PgCompatRefreshTokenRepository::new(self)
PgCompatRefreshTokenRepository::new(&mut self.txn)
}
}

View File

@ -29,20 +29,20 @@ mod tests {
use sqlx::PgPool;
use super::*;
use crate::{Clock, Repository};
use crate::{Clock, PgRepository, Repository};
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_repository(pool: PgPool) -> Result<(), Box<dyn std::error::Error>> {
let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
let clock = Clock::default();
let mut conn = pool.acquire().await?;
let mut repo = PgRepository::from_pool(&pool).await?;
// The provider list should be empty at the start
let all_providers = conn.upstream_oauth_provider().all().await?;
let all_providers = repo.upstream_oauth_provider().all().await?;
assert!(all_providers.is_empty());
// Let's add a provider
let provider = conn
let provider = repo
.upstream_oauth_provider()
.add(
&mut rng,
@ -57,7 +57,7 @@ mod tests {
.await?;
// Look it up in the database
let provider = conn
let provider = repo
.upstream_oauth_provider()
.lookup(provider.id)
.await?
@ -66,7 +66,7 @@ mod tests {
assert_eq!(provider.client_id, "client-id");
// Start a session
let session = conn
let session = repo
.upstream_oauth_session()
.add(
&mut rng,
@ -79,7 +79,7 @@ mod tests {
.await?;
// Look it up in the database
let session = conn
let session = repo
.upstream_oauth_session()
.lookup(session.id)
.await?
@ -91,19 +91,19 @@ mod tests {
assert!(!session.is_consumed());
// Create a link
let link = conn
let link = repo
.upstream_oauth_link()
.add(&mut rng, &clock, &provider, "a-subject".to_owned())
.await?;
// We can look it up by its ID
conn.upstream_oauth_link()
repo.upstream_oauth_link()
.lookup(link.id)
.await?
.expect("link to be found in database");
// or by its subject
let link = conn
let link = repo
.upstream_oauth_link()
.find_by_subject(&provider, "a-subject")
.await?
@ -111,7 +111,7 @@ mod tests {
assert_eq!(link.subject, "a-subject");
assert_eq!(link.provider_id, provider.id);
let session = conn
let session = repo
.upstream_oauth_session()
.complete_with_link(&clock, session, &link, None)
.await?;
@ -119,7 +119,7 @@ mod tests {
assert!(!session.is_consumed());
assert_eq!(session.link_id(), Some(link.id));
let session = conn
let session = repo
.upstream_oauth_session()
.consume(&clock, session)
.await?;