1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

storage: wrap the postgres repository in a struct

This commit is contained in:
Quentin Gliech
2023-01-13 18:03:37 +01:00
parent 488a666a8d
commit 195203823a
44 changed files with 505 additions and 548 deletions

1
Cargo.lock generated
View File

@ -2673,7 +2673,6 @@ dependencies = [
"serde_json", "serde_json",
"serde_urlencoded", "serde_urlencoded",
"serde_with", "serde_with",
"sqlx",
"thiserror", "thiserror",
"tokio", "tokio",
"tower", "tower",

View File

@ -21,7 +21,6 @@ serde = "1.0.152"
serde_with = "2.1.0" serde_with = "2.1.0"
serde_urlencoded = "0.7.1" serde_urlencoded = "0.7.1"
serde_json = "1.0.91" serde_json = "1.0.91"
sqlx = "0.6.2"
thiserror = "1.0.38" thiserror = "1.0.38"
tokio = "1.23.0" tokio = "1.23.0"
tower = { version = "0.4.13", features = ["util"] } tower = { version = "0.4.13", features = ["util"] }

View File

@ -31,10 +31,9 @@ use mas_http::HttpServiceExt;
use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_jose::{jwk::PublicJsonWebKeySet, jwt::Jwt}; use mas_jose::{jwk::PublicJsonWebKeySet, jwt::Jwt};
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_storage::{oauth2::OAuth2ClientRepository, DatabaseError, Repository}; use mas_storage::{oauth2::OAuth2ClientRepository, Repository};
use serde::{de::DeserializeOwned, Deserialize}; use serde::{de::DeserializeOwned, Deserialize};
use serde_json::Value; use serde_json::Value;
use sqlx::PgConnection;
use thiserror::Error; use thiserror::Error;
use tower::{Service, ServiceExt}; use tower::{Service, ServiceExt};
@ -73,7 +72,10 @@ pub enum Credentials {
} }
impl Credentials { impl Credentials {
pub async fn fetch(&self, conn: &mut PgConnection) -> Result<Option<Client>, DatabaseError> { pub async fn fetch<'r, R>(&self, repo: &'r mut R) -> Result<Option<Client>, R::Error>
where
R: Repository,
{
let client_id = match self { let client_id = match self {
Credentials::None { client_id } Credentials::None { client_id }
| Credentials::ClientSecretBasic { client_id, .. } | Credentials::ClientSecretBasic { client_id, .. }
@ -81,7 +83,7 @@ impl Credentials {
| Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id, | Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id,
}; };
conn.oauth2_client().find_by_client_id(client_id).await repo.oauth2_client().find_by_client_id(client_id).await
} }
#[tracing::instrument(skip_all, err)] #[tracing::instrument(skip_all, err)]

View File

@ -14,9 +14,8 @@
use axum_extra::extract::cookie::{Cookie, PrivateCookieJar}; use axum_extra::extract::cookie::{Cookie, PrivateCookieJar};
use mas_data_model::BrowserSession; use mas_data_model::BrowserSession;
use mas_storage::{user::BrowserSessionRepository, DatabaseError, Repository}; use mas_storage::{user::BrowserSessionRepository, Repository};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::PgConnection;
use ulid::Ulid; use ulid::Ulid;
use crate::CookieExt; use crate::CookieExt;
@ -44,17 +43,17 @@ impl SessionInfo {
} }
/// Load the [`BrowserSession`] from database /// Load the [`BrowserSession`] from database
pub async fn load_session( pub async fn load_session<R: Repository>(
&self, &self,
conn: &mut PgConnection, repo: &mut R,
) -> Result<Option<BrowserSession>, DatabaseError> { ) -> Result<Option<BrowserSession>, R::Error> {
let session_id = if let Some(id) = self.current { let session_id = if let Some(id) = self.current {
id id
} else { } else {
return Ok(None); return Ok(None);
}; };
let maybe_session = conn let maybe_session = repo
.browser_session() .browser_session()
.lookup(session_id) .lookup(session_id)
.await? .await?

View File

@ -30,10 +30,9 @@ use http::{header::WWW_AUTHENTICATE, HeaderMap, HeaderValue, Request, StatusCode
use mas_data_model::Session; use mas_data_model::Session;
use mas_storage::{ use mas_storage::{
oauth2::{OAuth2AccessTokenRepository, OAuth2SessionRepository}, oauth2::{OAuth2AccessTokenRepository, OAuth2SessionRepository},
DatabaseError, Repository, Repository,
}; };
use serde::{de::DeserializeOwned, Deserialize}; use serde::{de::DeserializeOwned, Deserialize};
use sqlx::PgConnection;
use thiserror::Error; use thiserror::Error;
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@ -53,22 +52,23 @@ enum AccessToken {
} }
impl AccessToken { impl AccessToken {
async fn fetch( async fn fetch<R: Repository>(
&self, &self,
conn: &mut PgConnection, repo: &mut R,
) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError> { ) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError<R::Error>>
{
let token = match self { let token = match self {
AccessToken::Form(t) | AccessToken::Header(t) => t, AccessToken::Form(t) | AccessToken::Header(t) => t,
AccessToken::None => return Err(AuthorizationVerificationError::MissingToken), AccessToken::None => return Err(AuthorizationVerificationError::MissingToken),
}; };
let token = conn let token = repo
.oauth2_access_token() .oauth2_access_token()
.find_by_token(token.as_str()) .find_by_token(token.as_str())
.await? .await?
.ok_or(AuthorizationVerificationError::InvalidToken)?; .ok_or(AuthorizationVerificationError::InvalidToken)?;
let session = conn let session = repo
.oauth2_session() .oauth2_session()
.lookup(token.session_id) .lookup(token.session_id)
.await? .await?
@ -86,17 +86,17 @@ pub struct UserAuthorization<F = ()> {
impl<F: Send> UserAuthorization<F> { impl<F: Send> UserAuthorization<F> {
// TODO: take scopes to validate as parameter // TODO: take scopes to validate as parameter
pub async fn protected_form( pub async fn protected_form<R: Repository>(
self, self,
conn: &mut PgConnection, repo: &mut R,
now: DateTime<Utc>, now: DateTime<Utc>,
) -> Result<(Session, F), AuthorizationVerificationError> { ) -> Result<(Session, F), AuthorizationVerificationError<R::Error>> {
let form = match self.form { let form = match self.form {
Some(f) => f, Some(f) => f,
None => return Err(AuthorizationVerificationError::MissingForm), None => return Err(AuthorizationVerificationError::MissingForm),
}; };
let (token, session) = self.access_token.fetch(conn).await?; let (token, session) = self.access_token.fetch(repo).await?;
if !token.is_valid(now) || !session.is_valid() { if !token.is_valid(now) || !session.is_valid() {
return Err(AuthorizationVerificationError::InvalidToken); return Err(AuthorizationVerificationError::InvalidToken);
@ -106,12 +106,12 @@ impl<F: Send> UserAuthorization<F> {
} }
// TODO: take scopes to validate as parameter // TODO: take scopes to validate as parameter
pub async fn protected( pub async fn protected<R: Repository>(
self, self,
conn: &mut PgConnection, repo: &mut R,
now: DateTime<Utc>, now: DateTime<Utc>,
) -> Result<Session, AuthorizationVerificationError> { ) -> Result<Session, AuthorizationVerificationError<R::Error>> {
let (token, session) = self.access_token.fetch(conn).await?; let (token, session) = self.access_token.fetch(repo).await?;
if !token.is_valid(now) || !session.is_valid() { if !token.is_valid(now) || !session.is_valid() {
return Err(AuthorizationVerificationError::InvalidToken); return Err(AuthorizationVerificationError::InvalidToken);
@ -129,7 +129,7 @@ pub enum UserAuthorizationError {
} }
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum AuthorizationVerificationError { pub enum AuthorizationVerificationError<E> {
#[error("missing token")] #[error("missing token")]
MissingToken, MissingToken,
@ -140,7 +140,7 @@ pub enum AuthorizationVerificationError {
MissingForm, MissingForm,
#[error(transparent)] #[error(transparent)]
Internal(#[from] DatabaseError), Internal(#[from] E),
} }
enum BearerError { enum BearerError {
@ -248,7 +248,10 @@ impl IntoResponse for UserAuthorizationError {
} }
} }
impl IntoResponse for AuthorizationVerificationError { impl<E> IntoResponse for AuthorizationVerificationError<E>
where
E: ToString,
{
fn into_response(self) -> Response { fn into_response(self) -> Response {
match self { match self {
Self::MissingForm | Self::MissingToken => { Self::MissingForm | Self::MissingToken => {

View File

@ -21,7 +21,7 @@ use mas_storage::{
oauth2::OAuth2ClientRepository, oauth2::OAuth2ClientRepository,
upstream_oauth2::UpstreamOAuthProviderRepository, upstream_oauth2::UpstreamOAuthProviderRepository,
user::{UserEmailRepository, UserPasswordRepository, UserRepository}, user::{UserEmailRepository, UserPasswordRepository, UserRepository},
Clock, Repository, Clock, PgRepository, Repository,
}; };
use oauth2_types::scope::Scope; use oauth2_types::scope::Scope;
use rand::SeedableRng; use rand::SeedableRng;
@ -202,8 +202,8 @@ impl Options {
let pool = database_from_config(&database_config).await?; let pool = database_from_config(&database_config).await?;
let password_manager = password_manager_from_config(&passwords_config).await?; let password_manager = password_manager_from_config(&passwords_config).await?;
let mut txn = pool.begin().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let user = txn let user = repo
.user() .user()
.find_by_username(username) .find_by_username(username)
.await? .await?
@ -213,12 +213,12 @@ impl Options {
let (version, hashed_password) = password_manager.hash(&mut rng, password).await?; let (version, hashed_password) = password_manager.hash(&mut rng, password).await?;
txn.user_password() repo.user_password()
.add(&mut rng, &clock, &user, version, hashed_password, None) .add(&mut rng, &clock, &user, version, hashed_password, None)
.await?; .await?;
info!(%user.id, %user.username, "Password changed"); info!(%user.id, %user.username, "Password changed");
txn.commit().await?; repo.save().await?;
Ok(()) Ok(())
} }
@ -233,22 +233,22 @@ impl Options {
let config: DatabaseConfig = root.load_config()?; let config: DatabaseConfig = root.load_config()?;
let pool = database_from_config(&config).await?; let pool = database_from_config(&config).await?;
let mut txn = pool.begin().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let user = txn let user = repo
.user() .user()
.find_by_username(username) .find_by_username(username)
.await? .await?
.context("User not found")?; .context("User not found")?;
let email = txn let email = repo
.user_email() .user_email()
.find(&user, email) .find(&user, email)
.await? .await?
.context("Email not found")?; .context("Email not found")?;
let email = txn.user_email().mark_as_verified(&clock, email).await?; let email = repo.user_email().mark_as_verified(&clock, email).await?;
txn.commit().await?; repo.save().await?;
info!(?email, "Email marked as verified"); info!(?email, "Email marked as verified");
Ok(()) Ok(())
@ -261,12 +261,12 @@ impl Options {
let pool = database_from_config(&config.database).await?; let pool = database_from_config(&config.database).await?;
let encrypter = config.secrets.encrypter(); let encrypter = config.secrets.encrypter();
let mut txn = pool.begin().await?; let mut repo = PgRepository::from_pool(&pool).await?;
for client in config.clients.iter() { for client in config.clients.iter() {
let client_id = client.client_id; let client_id = client.client_id;
let existing = txn.oauth2_client().lookup(client_id).await?.is_some(); let existing = repo.oauth2_client().lookup(client_id).await?.is_some();
if !update && existing { if !update && existing {
warn!(%client_id, "Skipping already imported client. Run with --update to update existing clients."); warn!(%client_id, "Skipping already imported client. Run with --update to update existing clients.");
continue; continue;
@ -288,7 +288,7 @@ impl Options {
.map(|client_secret| encrypter.encryt_to_string(client_secret.as_bytes())) .map(|client_secret| encrypter.encryt_to_string(client_secret.as_bytes()))
.transpose()?; .transpose()?;
txn.oauth2_client() repo.oauth2_client()
.add_from_config( .add_from_config(
&mut rng, &mut rng,
&clock, &clock,
@ -302,7 +302,7 @@ impl Options {
.await?; .await?;
} }
txn.commit().await?; repo.save().await?;
Ok(()) Ok(())
} }
@ -326,7 +326,7 @@ impl Options {
let encrypter = config.secrets.encrypter(); let encrypter = config.secrets.encrypter();
let pool = database_from_config(&config.database).await?; let pool = database_from_config(&config.database).await?;
let url_builder = UrlBuilder::new(config.http.public_base); let url_builder = UrlBuilder::new(config.http.public_base);
let mut conn = pool.acquire().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let requires_client_secret = token_endpoint_auth_method.requires_client_secret(); let requires_client_secret = token_endpoint_auth_method.requires_client_secret();
@ -347,7 +347,7 @@ impl Options {
.map(|client_secret| encrypter.encryt_to_string(client_secret.as_bytes())) .map(|client_secret| encrypter.encryt_to_string(client_secret.as_bytes()))
.transpose()?; .transpose()?;
let provider = conn let provider = repo
.upstream_oauth_provider() .upstream_oauth_provider()
.add( .add(
&mut rng, &mut rng,

View File

@ -32,9 +32,9 @@ use async_graphql::{
}; };
use mas_storage::{ use mas_storage::{
oauth2::OAuth2ClientRepository, oauth2::OAuth2ClientRepository,
upstream_oauth2::UpstreamOAuthProviderRepository, upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository},
user::{BrowserSessionRepository, UserEmailRepository}, user::{BrowserSessionRepository, UserEmailRepository},
Repository, UpstreamOAuthLinkRepository, PgRepository, Repository,
}; };
use model::CreationEvent; use model::CreationEvent;
use sqlx::PgPool; use sqlx::PgPool;
@ -93,10 +93,9 @@ impl RootQuery {
id: ID, id: ID,
) -> Result<Option<OAuth2Client>, async_graphql::Error> { ) -> Result<Option<OAuth2Client>, async_graphql::Error> {
let id = NodeType::OAuth2Client.extract_ulid(&id)?; let id = NodeType::OAuth2Client.extract_ulid(&id)?;
let database = ctx.data::<PgPool>()?; let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let mut conn = database.acquire().await?;
let client = conn.oauth2_client().lookup(id).await?; let client = repo.oauth2_client().lookup(id).await?;
Ok(client.map(OAuth2Client)) Ok(client.map(OAuth2Client))
} }
@ -124,13 +123,12 @@ impl RootQuery {
) -> Result<Option<BrowserSession>, async_graphql::Error> { ) -> Result<Option<BrowserSession>, async_graphql::Error> {
let id = NodeType::BrowserSession.extract_ulid(&id)?; let id = NodeType::BrowserSession.extract_ulid(&id)?;
let session = ctx.data_opt::<mas_data_model::BrowserSession>().cloned(); let session = ctx.data_opt::<mas_data_model::BrowserSession>().cloned();
let database = ctx.data::<PgPool>()?; let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let mut conn = database.acquire().await?;
let Some(session) = session else { return Ok(None) }; let Some(session) = session else { return Ok(None) };
let current_user = session.user; let current_user = session.user;
let browser_session = conn.browser_session().lookup(id).await?; let browser_session = repo.browser_session().lookup(id).await?;
let ret = browser_session.and_then(|browser_session| { let ret = browser_session.and_then(|browser_session| {
if browser_session.user.id == current_user.id { if browser_session.user.id == current_user.id {
@ -151,13 +149,12 @@ impl RootQuery {
) -> Result<Option<UserEmail>, async_graphql::Error> { ) -> Result<Option<UserEmail>, async_graphql::Error> {
let id = NodeType::UserEmail.extract_ulid(&id)?; let id = NodeType::UserEmail.extract_ulid(&id)?;
let session = ctx.data_opt::<mas_data_model::BrowserSession>().cloned(); let session = ctx.data_opt::<mas_data_model::BrowserSession>().cloned();
let database = ctx.data::<PgPool>()?; let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let mut conn = database.acquire().await?;
let Some(session) = session else { return Ok(None) }; let Some(session) = session else { return Ok(None) };
let current_user = session.user; let current_user = session.user;
let user_email = conn let user_email = repo
.user_email() .user_email()
.lookup(id) .lookup(id)
.await? .await?
@ -174,13 +171,12 @@ impl RootQuery {
) -> Result<Option<UpstreamOAuth2Link>, async_graphql::Error> { ) -> Result<Option<UpstreamOAuth2Link>, async_graphql::Error> {
let id = NodeType::UpstreamOAuth2Link.extract_ulid(&id)?; let id = NodeType::UpstreamOAuth2Link.extract_ulid(&id)?;
let session = ctx.data_opt::<mas_data_model::BrowserSession>().cloned(); let session = ctx.data_opt::<mas_data_model::BrowserSession>().cloned();
let database = ctx.data::<PgPool>()?; let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let mut conn = database.acquire().await?;
let Some(session) = session else { return Ok(None) }; let Some(session) = session else { return Ok(None) };
let current_user = session.user; let current_user = session.user;
let link = conn.upstream_oauth_link().lookup(id).await?; let link = repo.upstream_oauth_link().lookup(id).await?;
// Ensure that the link belongs to the current user // Ensure that the link belongs to the current user
let link = link.filter(|link| link.user_id == Some(current_user.id)); let link = link.filter(|link| link.user_id == Some(current_user.id));
@ -195,10 +191,9 @@ impl RootQuery {
id: ID, id: ID,
) -> Result<Option<UpstreamOAuth2Provider>, async_graphql::Error> { ) -> Result<Option<UpstreamOAuth2Provider>, async_graphql::Error> {
let id = NodeType::UpstreamOAuth2Provider.extract_ulid(&id)?; let id = NodeType::UpstreamOAuth2Provider.extract_ulid(&id)?;
let database = ctx.data::<PgPool>()?; let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let mut conn = database.acquire().await?;
let provider = conn.upstream_oauth_provider().lookup(id).await?; let provider = repo.upstream_oauth_provider().lookup(id).await?;
Ok(provider.map(UpstreamOAuth2Provider::new)) Ok(provider.map(UpstreamOAuth2Provider::new))
} }
@ -215,7 +210,7 @@ impl RootQuery {
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>, #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, UpstreamOAuth2Provider>, async_graphql::Error> { ) -> Result<Connection<Cursor, UpstreamOAuth2Provider>, async_graphql::Error> {
let database = ctx.data::<PgPool>()?; let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
query( query(
after, after,
@ -223,7 +218,6 @@ impl RootQuery {
first, first,
last, last,
|after, before, first, last| async move { |after, before, first, last| async move {
let mut conn = database.acquire().await?;
let after_id = after let after_id = after
.map(|x: OpaqueCursor<NodeCursor>| { .map(|x: OpaqueCursor<NodeCursor>| {
x.extract_for_type(NodeType::UpstreamOAuth2Provider) x.extract_for_type(NodeType::UpstreamOAuth2Provider)
@ -235,7 +229,7 @@ impl RootQuery {
}) })
.transpose()?; .transpose()?;
let page = conn let page = repo
.upstream_oauth_provider() .upstream_oauth_provider()
.list_paginated(before_id, after_id, first, last) .list_paginated(before_id, after_id, first, last)
.await?; .await?;

View File

@ -15,7 +15,9 @@
use anyhow::Context as _; use anyhow::Context as _;
use async_graphql::{Context, Description, Object, ID}; use async_graphql::{Context, Description, Object, ID};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use mas_storage::{compat::CompatSessionRepository, user::UserRepository, Repository}; use mas_storage::{
compat::CompatSessionRepository, user::UserRepository, PgRepository, Repository,
};
use sqlx::PgPool; use sqlx::PgPool;
use url::Url; use url::Url;
@ -35,8 +37,8 @@ impl CompatSession {
/// The user authorized for this session. /// The user authorized for this session.
async fn user(&self, ctx: &Context<'_>) -> Result<User, async_graphql::Error> { async fn user(&self, ctx: &Context<'_>) -> Result<User, async_graphql::Error> {
let mut conn = ctx.data::<PgPool>()?.acquire().await?; let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let user = conn let user = repo
.user() .user()
.lookup(self.0.user_id) .lookup(self.0.user_id)
.await? .await?
@ -100,8 +102,8 @@ impl CompatSsoLogin {
) -> Result<Option<CompatSession>, async_graphql::Error> { ) -> Result<Option<CompatSession>, async_graphql::Error> {
let Some(session_id) = self.0.session_id() else { return Ok(None) }; let Some(session_id) = self.0.session_id() else { return Ok(None) };
let mut conn = ctx.data::<PgPool>()?.acquire().await?; let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let session = conn let session = repo
.compat_session() .compat_session()
.lookup(session_id) .lookup(session_id)
.await? .await?

View File

@ -14,7 +14,9 @@
use anyhow::Context as _; use anyhow::Context as _;
use async_graphql::{Context, Description, Object, ID}; use async_graphql::{Context, Description, Object, ID};
use mas_storage::{oauth2::OAuth2ClientRepository, user::BrowserSessionRepository, Repository}; use mas_storage::{
oauth2::OAuth2ClientRepository, user::BrowserSessionRepository, PgRepository, Repository,
};
use oauth2_types::scope::Scope; use oauth2_types::scope::Scope;
use sqlx::PgPool; use sqlx::PgPool;
use ulid::Ulid; use ulid::Ulid;
@ -36,8 +38,8 @@ impl OAuth2Session {
/// OAuth 2.0 client used by this session. /// OAuth 2.0 client used by this session.
pub async fn client(&self, ctx: &Context<'_>) -> Result<OAuth2Client, async_graphql::Error> { pub async fn client(&self, ctx: &Context<'_>) -> Result<OAuth2Client, async_graphql::Error> {
let mut conn = ctx.data::<PgPool>()?.acquire().await?; let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let client = conn let client = repo
.oauth2_client() .oauth2_client()
.lookup(self.0.client_id) .lookup(self.0.client_id)
.await? .await?
@ -56,8 +58,8 @@ impl OAuth2Session {
&self, &self,
ctx: &Context<'_>, ctx: &Context<'_>,
) -> Result<BrowserSession, async_graphql::Error> { ) -> Result<BrowserSession, async_graphql::Error> {
let mut conn = ctx.data::<PgPool>()?.acquire().await?; let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let browser_session = conn let browser_session = repo
.browser_session() .browser_session()
.lookup(self.0.user_session_id) .lookup(self.0.user_session_id)
.await? .await?
@ -68,8 +70,8 @@ impl OAuth2Session {
/// User authorized for this session. /// User authorized for this session.
pub async fn user(&self, ctx: &Context<'_>) -> Result<User, async_graphql::Error> { pub async fn user(&self, ctx: &Context<'_>) -> Result<User, async_graphql::Error> {
let mut conn = ctx.data::<PgPool>()?.acquire().await?; let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let browser_session = conn let browser_session = repo
.browser_session() .browser_session()
.lookup(self.0.user_session_id) .lookup(self.0.user_session_id)
.await? .await?
@ -138,8 +140,8 @@ impl OAuth2Consent {
/// OAuth 2.0 client for which the user granted access. /// OAuth 2.0 client for which the user granted access.
pub async fn client(&self, ctx: &Context<'_>) -> Result<OAuth2Client, async_graphql::Error> { pub async fn client(&self, ctx: &Context<'_>) -> Result<OAuth2Client, async_graphql::Error> {
let mut conn = ctx.data::<PgPool>()?.acquire().await?; let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let client = conn let client = repo
.oauth2_client() .oauth2_client()
.lookup(self.client_id) .lookup(self.client_id)
.await? .await?

View File

@ -16,7 +16,8 @@ use anyhow::Context as _;
use async_graphql::{Context, Object, ID}; use async_graphql::{Context, Object, ID};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use mas_storage::{ use mas_storage::{
upstream_oauth2::UpstreamOAuthProviderRepository, user::UserRepository, Repository, upstream_oauth2::UpstreamOAuthProviderRepository, user::UserRepository, PgRepository,
Repository,
}; };
use sqlx::PgPool; use sqlx::PgPool;
@ -102,9 +103,8 @@ impl UpstreamOAuth2Link {
provider.clone() provider.clone()
} else { } else {
// Fetch on-the-fly // Fetch on-the-fly
let database = ctx.data::<PgPool>()?; let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let mut conn = database.acquire().await?; repo.upstream_oauth_provider()
conn.upstream_oauth_provider()
.lookup(self.link.provider_id) .lookup(self.link.provider_id)
.await? .await?
.context("Upstream OAuth 2.0 provider not found")? .context("Upstream OAuth 2.0 provider not found")?
@ -120,9 +120,8 @@ impl UpstreamOAuth2Link {
user.clone() user.clone()
} else if let Some(user_id) = &self.link.user_id { } else if let Some(user_id) = &self.link.user_id {
// Fetch on-the-fly // Fetch on-the-fly
let database = ctx.data::<PgPool>()?; let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let mut conn = database.acquire().await?; repo.user()
conn.user()
.lookup(*user_id) .lookup(*user_id)
.await? .await?
.context("User not found")? .context("User not found")?

View File

@ -20,8 +20,9 @@ use chrono::{DateTime, Utc};
use mas_storage::{ use mas_storage::{
compat::CompatSsoLoginRepository, compat::CompatSsoLoginRepository,
oauth2::OAuth2SessionRepository, oauth2::OAuth2SessionRepository,
upstream_oauth2::UpstreamOAuthLinkRepository,
user::{BrowserSessionRepository, UserEmailRepository}, user::{BrowserSessionRepository, UserEmailRepository},
Repository, UpstreamOAuthLinkRepository, PgRepository, Repository,
}; };
use sqlx::PgPool; use sqlx::PgPool;
@ -63,10 +64,9 @@ impl User {
&self, &self,
ctx: &Context<'_>, ctx: &Context<'_>,
) -> Result<Option<UserEmail>, async_graphql::Error> { ) -> Result<Option<UserEmail>, async_graphql::Error> {
let database = ctx.data::<PgPool>()?; let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let mut conn = database.acquire().await?;
Ok(conn.user_email().get_primary(&self.0).await?.map(UserEmail)) Ok(repo.user_email().get_primary(&self.0).await?.map(UserEmail))
} }
/// Get the list of compatibility SSO logins, chronologically sorted /// Get the list of compatibility SSO logins, chronologically sorted
@ -81,7 +81,7 @@ impl User {
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>, #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, CompatSsoLogin>, async_graphql::Error> { ) -> Result<Connection<Cursor, CompatSsoLogin>, async_graphql::Error> {
let database = ctx.data::<PgPool>()?; let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
query( query(
after, after,
@ -89,7 +89,6 @@ impl User {
first, first,
last, last,
|after, before, first, last| async move { |after, before, first, last| async move {
let mut conn = database.acquire().await?;
let after_id = after let after_id = after
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::CompatSsoLogin)) .map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::CompatSsoLogin))
.transpose()?; .transpose()?;
@ -97,7 +96,7 @@ impl User {
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::CompatSsoLogin)) .map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::CompatSsoLogin))
.transpose()?; .transpose()?;
let page = conn let page = repo
.compat_sso_login() .compat_sso_login()
.list_paginated(&self.0, before_id, after_id, first, last) .list_paginated(&self.0, before_id, after_id, first, last)
.await?; .await?;
@ -128,7 +127,7 @@ impl User {
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>, #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, BrowserSession>, async_graphql::Error> { ) -> Result<Connection<Cursor, BrowserSession>, async_graphql::Error> {
let database = ctx.data::<PgPool>()?; let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
query( query(
after, after,
@ -136,7 +135,6 @@ impl User {
first, first,
last, last,
|after, before, first, last| async move { |after, before, first, last| async move {
let mut conn = database.acquire().await?;
let after_id = after let after_id = after
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::BrowserSession)) .map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::BrowserSession))
.transpose()?; .transpose()?;
@ -144,7 +142,7 @@ impl User {
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::BrowserSession)) .map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::BrowserSession))
.transpose()?; .transpose()?;
let page = conn let page = repo
.browser_session() .browser_session()
.list_active_paginated(&self.0, before_id, after_id, first, last) .list_active_paginated(&self.0, before_id, after_id, first, last)
.await?; .await?;
@ -175,7 +173,7 @@ impl User {
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>, #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, UserEmail, UserEmailsPagination>, async_graphql::Error> { ) -> Result<Connection<Cursor, UserEmail, UserEmailsPagination>, async_graphql::Error> {
let database = ctx.data::<PgPool>()?; let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
query( query(
after, after,
@ -183,7 +181,6 @@ impl User {
first, first,
last, last,
|after, before, first, last| async move { |after, before, first, last| async move {
let mut conn = database.acquire().await?;
let after_id = after let after_id = after
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::UserEmail)) .map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::UserEmail))
.transpose()?; .transpose()?;
@ -191,7 +188,7 @@ impl User {
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::UserEmail)) .map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::UserEmail))
.transpose()?; .transpose()?;
let page = conn let page = repo
.user_email() .user_email()
.list_paginated(&self.0, before_id, after_id, first, last) .list_paginated(&self.0, before_id, after_id, first, last)
.await?; .await?;
@ -226,7 +223,7 @@ impl User {
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>, #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, OAuth2Session>, async_graphql::Error> { ) -> Result<Connection<Cursor, OAuth2Session>, async_graphql::Error> {
let database = ctx.data::<PgPool>()?; let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
query( query(
after, after,
@ -234,7 +231,6 @@ impl User {
first, first,
last, last,
|after, before, first, last| async move { |after, before, first, last| async move {
let mut conn = database.acquire().await?;
let after_id = after let after_id = after
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::OAuth2Session)) .map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::OAuth2Session))
.transpose()?; .transpose()?;
@ -242,7 +238,7 @@ impl User {
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::OAuth2Session)) .map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::OAuth2Session))
.transpose()?; .transpose()?;
let page = conn let page = repo
.oauth2_session() .oauth2_session()
.list_paginated(&self.0, before_id, after_id, first, last) .list_paginated(&self.0, before_id, after_id, first, last)
.await?; .await?;
@ -273,7 +269,7 @@ impl User {
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>, #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, UpstreamOAuth2Link>, async_graphql::Error> { ) -> Result<Connection<Cursor, UpstreamOAuth2Link>, async_graphql::Error> {
let database = ctx.data::<PgPool>()?; let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
query( query(
after, after,
@ -281,7 +277,6 @@ impl User {
first, first,
last, last,
|after, before, first, last| async move { |after, before, first, last| async move {
let mut conn = database.acquire().await?;
let after_id = after let after_id = after
.map(|x: OpaqueCursor<NodeCursor>| { .map(|x: OpaqueCursor<NodeCursor>| {
x.extract_for_type(NodeType::UpstreamOAuth2Link) x.extract_for_type(NodeType::UpstreamOAuth2Link)
@ -293,7 +288,7 @@ impl User {
}) })
.transpose()?; .transpose()?;
let page = conn let page = repo
.upstream_oauth_link() .upstream_oauth_link()
.list_paginated(&self.0, before_id, after_id, first, last) .list_paginated(&self.0, before_id, after_id, first, last)
.await?; .await?;
@ -347,8 +342,8 @@ pub struct UserEmailsPagination(mas_data_model::User);
impl UserEmailsPagination { impl UserEmailsPagination {
/// Identifies the total count of items in the connection. /// Identifies the total count of items in the connection.
async fn total_count(&self, ctx: &Context<'_>) -> Result<usize, async_graphql::Error> { async fn total_count(&self, ctx: &Context<'_>) -> Result<usize, async_graphql::Error> {
let mut conn = ctx.data::<PgPool>()?.acquire().await?; let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let count = conn.user_email().count(&self.0).await?; let count = repo.user_email().count(&self.0).await?;
Ok(count) Ok(count)
} }
} }

View File

@ -22,11 +22,11 @@ use mas_storage::{
CompatSsoLoginRepository, CompatSsoLoginRepository,
}, },
user::{UserPasswordRepository, UserRepository}, user::{UserPasswordRepository, UserRepository},
Clock, Repository, Clock, PgRepository, Repository,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_with::{serde_as, skip_serializing_none, DurationMilliSeconds}; use serde_with::{serde_as, skip_serializing_none, DurationMilliSeconds};
use sqlx::{PgPool, Postgres, Transaction}; use sqlx::PgPool;
use thiserror::Error; use thiserror::Error;
use zeroize::Zeroizing; use zeroize::Zeroizing;
@ -199,14 +199,14 @@ pub(crate) async fn post(
Json(input): Json<RequestBody>, Json(input): Json<RequestBody>,
) -> Result<impl IntoResponse, RouteError> { ) -> Result<impl IntoResponse, RouteError> {
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
let mut txn = pool.begin().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let (session, user) = match input.credentials { let (session, user) = match input.credentials {
Credentials::Password { Credentials::Password {
identifier: Identifier::User { user }, identifier: Identifier::User { user },
password, password,
} => user_password_login(&password_manager, &mut txn, user, password).await?, } => user_password_login(&password_manager, &mut repo, user, password).await?,
Credentials::Token { token } => token_login(&mut txn, &clock, &token).await?, Credentials::Token { token } => token_login(&mut repo, &clock, &token).await?,
_ => { _ => {
return Err(RouteError::Unsupported); return Err(RouteError::Unsupported);
@ -224,14 +224,14 @@ pub(crate) async fn post(
}; };
let access_token = TokenType::CompatAccessToken.generate(&mut rng); let access_token = TokenType::CompatAccessToken.generate(&mut rng);
let access_token = txn let access_token = repo
.compat_access_token() .compat_access_token()
.add(&mut rng, &clock, &session, access_token, expires_in) .add(&mut rng, &clock, &session, access_token, expires_in)
.await?; .await?;
let refresh_token = if input.refresh_token { let refresh_token = if input.refresh_token {
let refresh_token = TokenType::CompatRefreshToken.generate(&mut rng); let refresh_token = TokenType::CompatRefreshToken.generate(&mut rng);
let refresh_token = txn let refresh_token = repo
.compat_refresh_token() .compat_refresh_token()
.add(&mut rng, &clock, &session, &access_token, refresh_token) .add(&mut rng, &clock, &session, &access_token, refresh_token)
.await?; .await?;
@ -240,7 +240,7 @@ pub(crate) async fn post(
None None
}; };
txn.commit().await?; repo.save().await?;
Ok(Json(ResponseBody { Ok(Json(ResponseBody {
access_token: access_token.token, access_token: access_token.token,
@ -252,11 +252,11 @@ pub(crate) async fn post(
} }
async fn token_login( async fn token_login(
txn: &mut Transaction<'_, Postgres>, repo: &mut PgRepository,
clock: &Clock, clock: &Clock,
token: &str, token: &str,
) -> Result<(CompatSession, User), RouteError> { ) -> Result<(CompatSession, User), RouteError> {
let login = txn let login = repo
.compat_sso_login() .compat_sso_login()
.find_by_token(token) .find_by_token(token)
.await? .await?
@ -300,40 +300,40 @@ async fn token_login(
} }
}; };
let session = txn let session = repo
.compat_session() .compat_session()
.lookup(session_id) .lookup(session_id)
.await? .await?
.ok_or(RouteError::SessionNotFound)?; .ok_or(RouteError::SessionNotFound)?;
let user = txn let user = repo
.user() .user()
.lookup(session.user_id) .lookup(session.user_id)
.await? .await?
.ok_or(RouteError::UserNotFound)?; .ok_or(RouteError::UserNotFound)?;
txn.compat_sso_login().exchange(clock, login).await?; repo.compat_sso_login().exchange(clock, login).await?;
Ok((session, user)) Ok((session, user))
} }
async fn user_password_login( async fn user_password_login(
password_manager: &PasswordManager, password_manager: &PasswordManager,
txn: &mut Transaction<'_, Postgres>, repo: &mut PgRepository,
username: String, username: String,
password: String, password: String,
) -> Result<(CompatSession, User), RouteError> { ) -> Result<(CompatSession, User), RouteError> {
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
// Find the user // Find the user
let user = txn let user = repo
.user() .user()
.find_by_username(&username) .find_by_username(&username)
.await? .await?
.ok_or(RouteError::UserNotFound)?; .ok_or(RouteError::UserNotFound)?;
// Lookup its password // Lookup its password
let user_password = txn let user_password = repo
.user_password() .user_password()
.active(&user) .active(&user)
.await? .await?
@ -354,7 +354,7 @@ async fn user_password_login(
if let Some((version, hashed_password)) = new_password_hash { if let Some((version, hashed_password)) = new_password_hash {
// Save the upgraded password if needed // Save the upgraded password if needed
txn.user_password() repo.user_password()
.add( .add(
&mut rng, &mut rng,
&clock, &clock,
@ -368,7 +368,7 @@ async fn user_password_login(
// Now that the user credentials have been verified, start a new compat session // Now that the user credentials have been verified, start a new compat session
let device = Device::generate(&mut rng); let device = Device::generate(&mut rng);
let session = txn let session = repo
.compat_session() .compat_session()
.add(&mut rng, &clock, &user, device) .add(&mut rng, &clock, &user, device)
.await?; .await?;

View File

@ -31,7 +31,7 @@ use mas_keystore::Encrypter;
use mas_router::{CompatLoginSsoAction, PostAuthAction, Route}; use mas_router::{CompatLoginSsoAction, PostAuthAction, Route};
use mas_storage::{ use mas_storage::{
compat::{CompatSessionRepository, CompatSsoLoginRepository}, compat::{CompatSessionRepository, CompatSsoLoginRepository},
Repository, PgRepository, Repository,
}; };
use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates}; use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -60,12 +60,12 @@ pub async fn get(
Query(params): Query<Params>, Query(params): Query<Params>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
let mut conn = pool.acquire().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng);
let maybe_session = session_info.load_session(&mut conn).await?; let maybe_session = session_info.load_session(&mut repo).await?;
let session = if let Some(session) = maybe_session { let session = if let Some(session) = maybe_session {
session session
@ -90,7 +90,7 @@ pub async fn get(
return Ok((cookie_jar, destination.go()).into_response()); return Ok((cookie_jar, destination.go()).into_response());
} }
let login = conn let login = repo
.compat_sso_login() .compat_sso_login()
.lookup(id) .lookup(id)
.await? .await?
@ -124,12 +124,12 @@ pub async fn post(
Form(form): Form<ProtectedForm<()>>, Form(form): Form<ProtectedForm<()>>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
let mut txn = pool.begin().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
cookie_jar.verify_form(clock.now(), form)?; cookie_jar.verify_form(clock.now(), form)?;
let maybe_session = session_info.load_session(&mut txn).await?; let maybe_session = session_info.load_session(&mut repo).await?;
let session = if let Some(session) = maybe_session { let session = if let Some(session) = maybe_session {
session session
@ -154,7 +154,7 @@ pub async fn post(
return Ok((cookie_jar, destination.go()).into_response()); return Ok((cookie_jar, destination.go()).into_response());
} }
let login = txn let login = repo
.compat_sso_login() .compat_sso_login()
.lookup(id) .lookup(id)
.await? .await?
@ -188,16 +188,16 @@ pub async fn post(
}; };
let device = Device::generate(&mut rng); let device = Device::generate(&mut rng);
let compat_session = txn let compat_session = repo
.compat_session() .compat_session()
.add(&mut rng, &clock, &session.user, device) .add(&mut rng, &clock, &session.user, device)
.await?; .await?;
txn.compat_sso_login() repo.compat_sso_login()
.fulfill(&clock, login, &compat_session) .fulfill(&clock, login, &compat_session)
.await?; .await?;
txn.commit().await?; repo.save().await?;
Ok((cookie_jar, Redirect::to(redirect_uri.as_str())).into_response()) Ok((cookie_jar, Redirect::to(redirect_uri.as_str())).into_response())
} }

View File

@ -19,7 +19,7 @@ use axum::{
}; };
use hyper::StatusCode; use hyper::StatusCode;
use mas_router::{CompatLoginSsoAction, CompatLoginSsoComplete, UrlBuilder}; use mas_router::{CompatLoginSsoAction, CompatLoginSsoComplete, UrlBuilder};
use mas_storage::{compat::CompatSsoLoginRepository, Repository}; use mas_storage::{compat::CompatSsoLoginRepository, PgRepository, Repository};
use rand::distributions::{Alphanumeric, DistString}; use rand::distributions::{Alphanumeric, DistString};
use serde::Deserialize; use serde::Deserialize;
use serde_with::serde; use serde_with::serde;
@ -80,8 +80,8 @@ pub async fn get(
} }
let token = Alphanumeric.sample_string(&mut rng, 32); let token = Alphanumeric.sample_string(&mut rng, 32);
let mut conn = pool.acquire().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let login = conn let login = repo
.compat_sso_login() .compat_sso_login()
.add(&mut rng, &clock, token, redirect_url) .add(&mut rng, &clock, token, redirect_url)
.await?; .await?;

View File

@ -18,7 +18,7 @@ use hyper::StatusCode;
use mas_data_model::TokenType; use mas_data_model::TokenType;
use mas_storage::{ use mas_storage::{
compat::{CompatAccessTokenRepository, CompatSessionRepository}, compat::{CompatAccessTokenRepository, CompatSessionRepository},
Clock, Repository, Clock, PgRepository, Repository,
}; };
use sqlx::PgPool; use sqlx::PgPool;
use thiserror::Error; use thiserror::Error;
@ -72,7 +72,7 @@ pub(crate) async fn post(
maybe_authorization: Option<TypedHeader<Authorization<Bearer>>>, maybe_authorization: Option<TypedHeader<Authorization<Bearer>>>,
) -> Result<impl IntoResponse, RouteError> { ) -> Result<impl IntoResponse, RouteError> {
let clock = Clock::default(); let clock = Clock::default();
let mut txn = pool.begin().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let TypedHeader(authorization) = maybe_authorization.ok_or(RouteError::MissingAuthorization)?; let TypedHeader(authorization) = maybe_authorization.ok_or(RouteError::MissingAuthorization)?;
@ -83,23 +83,23 @@ pub(crate) async fn post(
return Err(RouteError::InvalidAuthorization); return Err(RouteError::InvalidAuthorization);
} }
let token = txn let token = repo
.compat_access_token() .compat_access_token()
.find_by_token(token) .find_by_token(token)
.await? .await?
.filter(|t| t.is_valid(clock.now())) .filter(|t| t.is_valid(clock.now()))
.ok_or(RouteError::InvalidAuthorization)?; .ok_or(RouteError::InvalidAuthorization)?;
let session = txn let session = repo
.compat_session() .compat_session()
.lookup(token.session_id) .lookup(token.session_id)
.await? .await?
.filter(|s| s.is_valid()) .filter(|s| s.is_valid())
.ok_or(RouteError::InvalidAuthorization)?; .ok_or(RouteError::InvalidAuthorization)?;
txn.compat_session().finish(&clock, session).await?; repo.compat_session().finish(&clock, session).await?;
txn.commit().await?; repo.save().await?;
Ok(Json(serde_json::json!({}))) Ok(Json(serde_json::json!({})))
} }

View File

@ -18,7 +18,7 @@ use hyper::StatusCode;
use mas_data_model::{TokenFormatError, TokenType}; use mas_data_model::{TokenFormatError, TokenType};
use mas_storage::{ use mas_storage::{
compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository}, compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository},
Repository, PgRepository, Repository,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_with::{serde_as, DurationMilliSeconds}; use serde_with::{serde_as, DurationMilliSeconds};
@ -92,7 +92,7 @@ pub(crate) async fn post(
Json(input): Json<RequestBody>, Json(input): Json<RequestBody>,
) -> Result<impl IntoResponse, RouteError> { ) -> Result<impl IntoResponse, RouteError> {
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
let mut txn = pool.begin().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let token_type = TokenType::check(&input.refresh_token)?; let token_type = TokenType::check(&input.refresh_token)?;
@ -100,7 +100,7 @@ pub(crate) async fn post(
return Err(RouteError::InvalidToken); return Err(RouteError::InvalidToken);
} }
let refresh_token = txn let refresh_token = repo
.compat_refresh_token() .compat_refresh_token()
.find_by_token(&input.refresh_token) .find_by_token(&input.refresh_token)
.await? .await?
@ -110,7 +110,7 @@ pub(crate) async fn post(
return Err(RouteError::RefreshTokenConsumed); return Err(RouteError::RefreshTokenConsumed);
} }
let session = txn let session = repo
.compat_session() .compat_session()
.lookup(refresh_token.session_id) .lookup(refresh_token.session_id)
.await? .await?
@ -120,7 +120,7 @@ pub(crate) async fn post(
return Err(RouteError::InvalidSession); return Err(RouteError::InvalidSession);
} }
let access_token = txn let access_token = repo
.compat_access_token() .compat_access_token()
.lookup(refresh_token.access_token_id) .lookup(refresh_token.access_token_id)
.await? .await?
@ -130,7 +130,7 @@ pub(crate) async fn post(
let new_access_token_str = TokenType::CompatAccessToken.generate(&mut rng); let new_access_token_str = TokenType::CompatAccessToken.generate(&mut rng);
let expires_in = Duration::minutes(5); let expires_in = Duration::minutes(5);
let new_access_token = txn let new_access_token = repo
.compat_access_token() .compat_access_token()
.add( .add(
&mut rng, &mut rng,
@ -140,7 +140,7 @@ pub(crate) async fn post(
Some(expires_in), Some(expires_in),
) )
.await?; .await?;
let new_refresh_token = txn let new_refresh_token = repo
.compat_refresh_token() .compat_refresh_token()
.add( .add(
&mut rng, &mut rng,
@ -151,17 +151,17 @@ pub(crate) async fn post(
) )
.await?; .await?;
txn.compat_refresh_token() repo.compat_refresh_token()
.consume(&clock, refresh_token) .consume(&clock, refresh_token)
.await?; .await?;
if let Some(access_token) = access_token { if let Some(access_token) = access_token {
txn.compat_access_token() repo.compat_access_token()
.expire(&clock, access_token) .expire(&clock, access_token)
.await?; .await?;
} }
txn.commit().await?; repo.save().await?;
Ok(Json(ResponseBody { Ok(Json(ResponseBody {
access_token: new_access_token.token, access_token: new_access_token.token,

View File

@ -28,6 +28,7 @@ use hyper::header::CACHE_CONTROL;
use mas_axum_utils::{FancyError, SessionInfoExt}; use mas_axum_utils::{FancyError, SessionInfoExt};
use mas_graphql::Schema; use mas_graphql::Schema;
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_storage::PgRepository;
use sqlx::PgPool; use sqlx::PgPool;
use tracing::{info_span, Instrument}; use tracing::{info_span, Instrument};
@ -67,8 +68,9 @@ pub async fn post(
let content_type = content_type.map(|TypedHeader(h)| h.to_string()); let content_type = content_type.map(|TypedHeader(h)| h.to_string());
let (session_info, _cookie_jar) = cookie_jar.session_info(); let (session_info, _cookie_jar) = cookie_jar.session_info();
let mut conn = pool.acquire().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let maybe_session = session_info.load_session(&mut conn).await?; let maybe_session = session_info.load_session(&mut repo).await?;
repo.cancel().await?;
let mut request = async_graphql::http::receive_batch_body( let mut request = async_graphql::http::receive_batch_body(
content_type, content_type,
@ -117,8 +119,9 @@ pub async fn get(
RawQuery(query): RawQuery, RawQuery(query): RawQuery,
) -> Result<impl IntoResponse, FancyError> { ) -> Result<impl IntoResponse, FancyError> {
let (session_info, _cookie_jar) = cookie_jar.session_info(); let (session_info, _cookie_jar) = cookie_jar.session_info();
let mut conn = pool.acquire().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let maybe_session = session_info.load_session(&mut conn).await?; let maybe_session = session_info.load_session(&mut repo).await?;
repo.cancel().await?;
let mut request = async_graphql::http::parse_query_string(&query.unwrap_or_default())?; let mut request = async_graphql::http::parse_query_string(&query.unwrap_or_default())?;

View File

@ -27,11 +27,11 @@ use mas_policy::PolicyFactory;
use mas_router::{PostAuthAction, Route}; use mas_router::{PostAuthAction, Route};
use mas_storage::{ use mas_storage::{
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2SessionRepository}, oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2SessionRepository},
Repository, PgRepository, Repository,
}; };
use mas_templates::Templates; use mas_templates::Templates;
use oauth2_types::requests::{AccessTokenResponse, AuthorizationResponse}; use oauth2_types::requests::{AccessTokenResponse, AuthorizationResponse};
use sqlx::{PgPool, Postgres, Transaction}; use sqlx::PgPool;
use thiserror::Error; use thiserror::Error;
use ulid::Ulid; use ulid::Ulid;
@ -84,13 +84,13 @@ pub(crate) async fn get(
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Path(grant_id): Path<Ulid>, Path(grant_id): Path<Ulid>,
) -> Result<Response, RouteError> { ) -> Result<Response, RouteError> {
let mut txn = pool.begin().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut txn).await?; let maybe_session = session_info.load_session(&mut repo).await?;
let grant = txn let grant = repo
.oauth2_authorization_grant() .oauth2_authorization_grant()
.lookup(grant_id) .lookup(grant_id)
.await? .await?
@ -107,7 +107,7 @@ pub(crate) async fn get(
return Ok((cookie_jar, mas_router::Login::and_then(continue_grant).go()).into_response()); return Ok((cookie_jar, mas_router::Login::and_then(continue_grant).go()).into_response());
}; };
match complete(grant, session, &policy_factory, txn).await { match complete(grant, session, &policy_factory, repo).await {
Ok(params) => { Ok(params) => {
let res = callback_destination.go(&templates, params).await?; let res = callback_destination.go(&templates, params).await?;
Ok((cookie_jar, res).into_response()) Ok((cookie_jar, res).into_response())
@ -159,7 +159,7 @@ pub(crate) async fn complete(
grant: AuthorizationGrant, grant: AuthorizationGrant,
browser_session: BrowserSession, browser_session: BrowserSession,
policy_factory: &PolicyFactory, policy_factory: &PolicyFactory,
mut txn: Transaction<'_, Postgres>, mut repo: PgRepository,
) -> Result<AuthorizationResponse<Option<AccessTokenResponse>>, GrantCompletionError> { ) -> Result<AuthorizationResponse<Option<AccessTokenResponse>>, GrantCompletionError> {
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
@ -170,7 +170,7 @@ pub(crate) async fn complete(
// Check if the authentication is fresh enough // Check if the authentication is fresh enough
if !browser_session.was_authenticated_after(grant.max_auth_time()) { if !browser_session.was_authenticated_after(grant.max_auth_time()) {
txn.commit().await?; repo.save().await?;
return Err(GrantCompletionError::RequiresReauth); return Err(GrantCompletionError::RequiresReauth);
} }
@ -184,13 +184,13 @@ pub(crate) async fn complete(
return Err(GrantCompletionError::PolicyViolation); return Err(GrantCompletionError::PolicyViolation);
} }
let client = txn let client = repo
.oauth2_client() .oauth2_client()
.lookup(grant.client_id) .lookup(grant.client_id)
.await? .await?
.ok_or(GrantCompletionError::NoSuchClient)?; .ok_or(GrantCompletionError::NoSuchClient)?;
let current_consent = txn let current_consent = repo
.oauth2_client() .oauth2_client()
.get_consent_for_user(&client, &browser_session.user) .get_consent_for_user(&client, &browser_session.user)
.await?; .await?;
@ -202,17 +202,17 @@ pub(crate) async fn complete(
// Check if the client lacks consent *or* if consent was explicitely asked // Check if the client lacks consent *or* if consent was explicitely asked
if lacks_consent || grant.requires_consent { if lacks_consent || grant.requires_consent {
txn.commit().await?; repo.save().await?;
return Err(GrantCompletionError::RequiresConsent); return Err(GrantCompletionError::RequiresConsent);
} }
// All good, let's start the session // All good, let's start the session
let session = txn let session = repo
.oauth2_session() .oauth2_session()
.create_from_grant(&mut rng, &clock, &grant, &browser_session) .create_from_grant(&mut rng, &clock, &grant, &browser_session)
.await?; .await?;
let grant = txn let grant = repo
.oauth2_authorization_grant() .oauth2_authorization_grant()
.fulfill(&clock, &session, grant) .fulfill(&clock, &session, grant)
.await?; .await?;
@ -233,6 +233,6 @@ pub(crate) async fn complete(
)); ));
} }
txn.commit().await?; repo.save().await?;
Ok(params) Ok(params)
} }

View File

@ -27,7 +27,7 @@ use mas_policy::PolicyFactory;
use mas_router::{PostAuthAction, Route}; use mas_router::{PostAuthAction, Route};
use mas_storage::{ use mas_storage::{
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository}, oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository},
Repository, PgRepository, Repository,
}; };
use mas_templates::Templates; use mas_templates::Templates;
use oauth2_types::{ use oauth2_types::{
@ -139,10 +139,10 @@ pub(crate) async fn get(
Form(params): Form<Params>, Form(params): Form<Params>,
) -> Result<Response, RouteError> { ) -> Result<Response, RouteError> {
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
let mut txn = pool.begin().await?; let mut repo = PgRepository::from_pool(&pool).await?;
// First, figure out what client it is // First, figure out what client it is
let client = txn let client = repo
.oauth2_client() .oauth2_client()
.find_by_client_id(&params.auth.client_id) .find_by_client_id(&params.auth.client_id)
.await? .await?
@ -170,7 +170,7 @@ pub(crate) async fn get(
let templates = templates.clone(); let templates = templates.clone();
let callback_destination = callback_destination.clone(); let callback_destination = callback_destination.clone();
async move { async move {
let maybe_session = session_info.load_session(&mut txn).await?; let maybe_session = session_info.load_session(&mut repo).await?;
let prompt = params.auth.prompt.as_deref().unwrap_or_default(); let prompt = params.auth.prompt.as_deref().unwrap_or_default();
// Check if the request/request_uri/registration params are used. If so, reply // Check if the request/request_uri/registration params are used. If so, reply
@ -275,7 +275,7 @@ pub(crate) async fn get(
let requires_consent = prompt.contains(&Prompt::Consent); let requires_consent = prompt.contains(&Prompt::Consent);
let grant = txn let grant = repo
.oauth2_authorization_grant() .oauth2_authorization_grant()
.add( .add(
&mut rng, &mut rng,
@ -302,7 +302,7 @@ pub(crate) async fn get(
} }
None if prompt.contains(&Prompt::Create) => { None if prompt.contains(&Prompt::Create) => {
// Client asked for a registration, show the registration prompt // Client asked for a registration, show the registration prompt
txn.commit().await?; repo.save().await?;
mas_router::Register::and_then(continue_grant) mas_router::Register::and_then(continue_grant)
.go() .go()
@ -310,7 +310,7 @@ pub(crate) async fn get(
} }
None => { None => {
// Other cases where we don't have a session, ask for a login // Other cases where we don't have a session, ask for a login
txn.commit().await?; repo.save().await?;
mas_router::Login::and_then(continue_grant) mas_router::Login::and_then(continue_grant)
.go() .go()
@ -323,7 +323,7 @@ pub(crate) async fn get(
|| prompt.contains(&Prompt::SelectAccount) => || prompt.contains(&Prompt::SelectAccount) =>
{ {
// TODO: better pages here // TODO: better pages here
txn.commit().await?; repo.save().await?;
mas_router::Reauth::and_then(continue_grant) mas_router::Reauth::and_then(continue_grant)
.go() .go()
@ -333,7 +333,7 @@ pub(crate) async fn get(
// Else, we immediately try to complete the authorization grant // Else, we immediately try to complete the authorization grant
Some(user_session) if prompt.contains(&Prompt::None) => { Some(user_session) if prompt.contains(&Prompt::None) => {
// With prompt=none, we should get back to the client immediately // With prompt=none, we should get back to the client immediately
match self::complete::complete(grant, user_session, &policy_factory, txn).await match self::complete::complete(grant, user_session, &policy_factory, repo).await
{ {
Ok(params) => callback_destination.go(&templates, params).await?, Ok(params) => callback_destination.go(&templates, params).await?,
Err(GrantCompletionError::RequiresConsent) => { Err(GrantCompletionError::RequiresConsent) => {
@ -372,7 +372,7 @@ pub(crate) async fn get(
Some(user_session) => { Some(user_session) => {
let grant_id = grant.id; let grant_id = grant.id;
// Else, we show the relevant reauth/consent page if necessary // Else, we show the relevant reauth/consent page if necessary
match self::complete::complete(grant, user_session, &policy_factory, txn).await match self::complete::complete(grant, user_session, &policy_factory, repo).await
{ {
Ok(params) => callback_destination.go(&templates, params).await?, Ok(params) => callback_destination.go(&templates, params).await?,
Err( Err(

View File

@ -30,7 +30,7 @@ use mas_policy::PolicyFactory;
use mas_router::{PostAuthAction, Route}; use mas_router::{PostAuthAction, Route};
use mas_storage::{ use mas_storage::{
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository}, oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository},
Repository, PgRepository, Repository,
}; };
use mas_templates::{ConsentContext, PolicyViolationContext, TemplateContext, Templates}; use mas_templates::{ConsentContext, PolicyViolationContext, TemplateContext, Templates};
use sqlx::PgPool; use sqlx::PgPool;
@ -81,13 +81,13 @@ pub(crate) async fn get(
Path(grant_id): Path<Ulid>, Path(grant_id): Path<Ulid>,
) -> Result<Response, RouteError> { ) -> Result<Response, RouteError> {
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
let mut conn = pool.acquire().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut conn).await?; let maybe_session = session_info.load_session(&mut repo).await?;
let grant = conn let grant = repo
.oauth2_authorization_grant() .oauth2_authorization_grant()
.lookup(grant_id) .lookup(grant_id)
.await? .await?
@ -136,15 +136,15 @@ pub(crate) async fn post(
Form(form): Form<ProtectedForm<()>>, Form(form): Form<ProtectedForm<()>>,
) -> Result<Response, RouteError> { ) -> Result<Response, RouteError> {
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
let mut txn = pool.begin().await?; let mut repo = PgRepository::from_pool(&pool).await?;
cookie_jar.verify_form(clock.now(), form)?; cookie_jar.verify_form(clock.now(), form)?;
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut txn).await?; let maybe_session = session_info.load_session(&mut repo).await?;
let grant = txn let grant = repo
.oauth2_authorization_grant() .oauth2_authorization_grant()
.lookup(grant_id) .lookup(grant_id)
.await? .await?
@ -167,7 +167,7 @@ pub(crate) async fn post(
return Err(RouteError::PolicyViolation); return Err(RouteError::PolicyViolation);
} }
let client = txn let client = repo
.oauth2_client() .oauth2_client()
.lookup(grant.client_id) .lookup(grant.client_id)
.await? .await?
@ -180,7 +180,7 @@ pub(crate) async fn post(
.filter(|s| !s.starts_with("urn:matrix:org.matrix.msc2967.client:device:")) .filter(|s| !s.starts_with("urn:matrix:org.matrix.msc2967.client:device:"))
.cloned() .cloned()
.collect(); .collect();
txn.oauth2_client() repo.oauth2_client()
.give_consent_for_user( .give_consent_for_user(
&mut rng, &mut rng,
&clock, &clock,
@ -190,9 +190,11 @@ pub(crate) async fn post(
) )
.await?; .await?;
txn.oauth2_authorization_grant().give_consent(grant).await?; repo.oauth2_authorization_grant()
.give_consent(grant)
.await?;
txn.commit().await?; repo.save().await?;
Ok((cookie_jar, next.go_next()).into_response()) Ok((cookie_jar, next.go_next()).into_response())
} }

View File

@ -25,7 +25,7 @@ use mas_storage::{
compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository}, compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository},
oauth2::{OAuth2AccessTokenRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository}, oauth2::{OAuth2AccessTokenRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository},
user::{BrowserSessionRepository, UserRepository}, user::{BrowserSessionRepository, UserRepository},
Clock, Repository, Clock, PgRepository, Repository,
}; };
use oauth2_types::{ use oauth2_types::{
errors::{ClientError, ClientErrorCode}, errors::{ClientError, ClientErrorCode},
@ -130,12 +130,13 @@ pub(crate) async fn post(
client_authorization: ClientAuthorization<IntrospectionRequest>, client_authorization: ClientAuthorization<IntrospectionRequest>,
) -> Result<impl IntoResponse, RouteError> { ) -> Result<impl IntoResponse, RouteError> {
let clock = Clock::default(); let clock = Clock::default();
let mut conn = pool.acquire().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let client = client_authorization let client = client_authorization
.credentials .credentials
.fetch(&mut conn) .fetch(&mut repo)
.await? .await
.unwrap()
.ok_or(RouteError::ClientNotFound)?; .ok_or(RouteError::ClientNotFound)?;
let method = match &client.token_endpoint_auth_method { let method = match &client.token_endpoint_auth_method {
@ -166,14 +167,14 @@ pub(crate) async fn post(
let reply = match token_type { let reply = match token_type {
TokenType::AccessToken => { TokenType::AccessToken => {
let token = conn let token = repo
.oauth2_access_token() .oauth2_access_token()
.find_by_token(token) .find_by_token(token)
.await? .await?
.filter(|t| t.is_valid(clock.now())) .filter(|t| t.is_valid(clock.now()))
.ok_or(RouteError::UnknownToken)?; .ok_or(RouteError::UnknownToken)?;
let session = conn let session = repo
.oauth2_session() .oauth2_session()
.lookup(token.session_id) .lookup(token.session_id)
.await? .await?
@ -181,7 +182,7 @@ pub(crate) async fn post(
// XXX: is that the right error to bubble up? // XXX: is that the right error to bubble up?
.ok_or(RouteError::UnknownToken)?; .ok_or(RouteError::UnknownToken)?;
let browser_session = conn let browser_session = repo
.browser_session() .browser_session()
.lookup(session.user_session_id) .lookup(session.user_session_id)
.await? .await?
@ -205,14 +206,14 @@ pub(crate) async fn post(
} }
TokenType::RefreshToken => { TokenType::RefreshToken => {
let token = conn let token = repo
.oauth2_refresh_token() .oauth2_refresh_token()
.find_by_token(token) .find_by_token(token)
.await? .await?
.filter(|t| t.is_valid()) .filter(|t| t.is_valid())
.ok_or(RouteError::UnknownToken)?; .ok_or(RouteError::UnknownToken)?;
let session = conn let session = repo
.oauth2_session() .oauth2_session()
.lookup(token.session_id) .lookup(token.session_id)
.await? .await?
@ -220,7 +221,7 @@ pub(crate) async fn post(
// XXX: is that the right error to bubble up? // XXX: is that the right error to bubble up?
.ok_or(RouteError::UnknownToken)?; .ok_or(RouteError::UnknownToken)?;
let browser_session = conn let browser_session = repo
.browser_session() .browser_session()
.lookup(session.user_session_id) .lookup(session.user_session_id)
.await? .await?
@ -244,21 +245,21 @@ pub(crate) async fn post(
} }
TokenType::CompatAccessToken => { TokenType::CompatAccessToken => {
let access_token = conn let access_token = repo
.compat_access_token() .compat_access_token()
.find_by_token(token) .find_by_token(token)
.await? .await?
.filter(|t| t.is_valid(clock.now())) .filter(|t| t.is_valid(clock.now()))
.ok_or(RouteError::UnknownToken)?; .ok_or(RouteError::UnknownToken)?;
let session = conn let session = repo
.compat_session() .compat_session()
.lookup(access_token.session_id) .lookup(access_token.session_id)
.await? .await?
.filter(|s| s.is_valid()) .filter(|s| s.is_valid())
.ok_or(RouteError::UnknownToken)?; .ok_or(RouteError::UnknownToken)?;
let user = conn let user = repo
.user() .user()
.lookup(session.user_id) .lookup(session.user_id)
.await? .await?
@ -285,21 +286,21 @@ pub(crate) async fn post(
} }
TokenType::CompatRefreshToken => { TokenType::CompatRefreshToken => {
let refresh_token = conn let refresh_token = repo
.compat_refresh_token() .compat_refresh_token()
.find_by_token(token) .find_by_token(token)
.await? .await?
.filter(|t| t.is_valid()) .filter(|t| t.is_valid())
.ok_or(RouteError::UnknownToken)?; .ok_or(RouteError::UnknownToken)?;
let session = conn let session = repo
.compat_session() .compat_session()
.lookup(refresh_token.session_id) .lookup(refresh_token.session_id)
.await? .await?
.filter(|s| s.is_valid()) .filter(|s| s.is_valid())
.ok_or(RouteError::UnknownToken)?; .ok_or(RouteError::UnknownToken)?;
let user = conn let user = repo
.user() .user()
.lookup(session.user_id) .lookup(session.user_id)
.await? .await?

View File

@ -19,7 +19,7 @@ use hyper::StatusCode;
use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_policy::{PolicyFactory, Violation}; use mas_policy::{PolicyFactory, Violation};
use mas_storage::{oauth2::OAuth2ClientRepository, Repository}; use mas_storage::{oauth2::OAuth2ClientRepository, PgRepository, Repository};
use oauth2_types::{ use oauth2_types::{
errors::{ClientError, ClientErrorCode}, errors::{ClientError, ClientErrorCode},
registration::{ registration::{
@ -124,8 +124,7 @@ pub(crate) async fn post(
return Err(RouteError::PolicyDenied(res.violations)); return Err(RouteError::PolicyDenied(res.violations));
} }
// Grab a txn let mut repo = PgRepository::from_pool(&pool).await?;
let mut txn = pool.begin().await?;
let (client_secret, encrypted_client_secret) = match metadata.token_endpoint_auth_method { let (client_secret, encrypted_client_secret) = match metadata.token_endpoint_auth_method {
Some( Some(
@ -141,7 +140,7 @@ pub(crate) async fn post(
_ => (None, None), _ => (None, None),
}; };
let client = txn let client = repo
.oauth2_client() .oauth2_client()
.add( .add(
&mut rng, &mut rng,
@ -170,7 +169,7 @@ pub(crate) async fn post(
) )
.await?; .await?;
txn.commit().await?; repo.save().await?;
let response = ClientRegistrationResponse { let response = ClientRegistrationResponse {
client_id: client.client_id, client_id: client.client_id,

View File

@ -37,7 +37,7 @@ use mas_storage::{
OAuth2RefreshTokenRepository, OAuth2SessionRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository,
}, },
user::BrowserSessionRepository, user::BrowserSessionRepository,
Repository, PgRepository, Repository,
}; };
use oauth2_types::{ use oauth2_types::{
errors::{ClientError, ClientErrorCode}, errors::{ClientError, ClientErrorCode},
@ -49,7 +49,7 @@ use oauth2_types::{
}; };
use serde::Serialize; use serde::Serialize;
use serde_with::{serde_as, skip_serializing_none}; use serde_with::{serde_as, skip_serializing_none};
use sqlx::{PgPool, Postgres, Transaction}; use sqlx::PgPool;
use thiserror::Error; use thiserror::Error;
use tracing::debug; use tracing::debug;
use url::Url; use url::Url;
@ -166,11 +166,11 @@ pub(crate) async fn post(
State(encrypter): State<Encrypter>, State(encrypter): State<Encrypter>,
client_authorization: ClientAuthorization<AccessTokenRequest>, client_authorization: ClientAuthorization<AccessTokenRequest>,
) -> Result<impl IntoResponse, RouteError> { ) -> Result<impl IntoResponse, RouteError> {
let mut txn = pool.begin().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let client = client_authorization let client = client_authorization
.credentials .credentials
.fetch(&mut txn) .fetch(&mut repo)
.await? .await?
.ok_or(RouteError::ClientNotFound)?; .ok_or(RouteError::ClientNotFound)?;
@ -188,10 +188,10 @@ pub(crate) async fn post(
let reply = match form { let reply = match form {
AccessTokenRequest::AuthorizationCode(grant) => { AccessTokenRequest::AuthorizationCode(grant) => {
authorization_code_grant(&grant, &client, &key_store, &url_builder, txn).await? authorization_code_grant(&grant, &client, &key_store, &url_builder, repo).await?
} }
AccessTokenRequest::RefreshToken(grant) => { AccessTokenRequest::RefreshToken(grant) => {
refresh_token_grant(&grant, &client, txn).await? refresh_token_grant(&grant, &client, repo).await?
} }
_ => { _ => {
return Err(RouteError::InvalidGrant); return Err(RouteError::InvalidGrant);
@ -211,11 +211,11 @@ async fn authorization_code_grant(
client: &Client, client: &Client,
key_store: &Keystore, key_store: &Keystore,
url_builder: &UrlBuilder, url_builder: &UrlBuilder,
mut txn: Transaction<'_, Postgres>, mut repo: PgRepository,
) -> Result<AccessTokenResponse, RouteError> { ) -> Result<AccessTokenResponse, RouteError> {
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
let authz_grant = txn let authz_grant = repo
.oauth2_authorization_grant() .oauth2_authorization_grant()
.find_by_code(&grant.code) .find_by_code(&grant.code)
.await? .await?
@ -238,13 +238,13 @@ async fn authorization_code_grant(
// Ending the session if the token was already exchanged more than 20s ago // Ending the session if the token was already exchanged more than 20s ago
if now - exchanged_at > Duration::seconds(20) { if now - exchanged_at > Duration::seconds(20) {
debug!("Ending potentially compromised session"); debug!("Ending potentially compromised session");
let session = txn let session = repo
.oauth2_session() .oauth2_session()
.lookup(session_id) .lookup(session_id)
.await? .await?
.ok_or(RouteError::NoSuchOAuthSession)?; .ok_or(RouteError::NoSuchOAuthSession)?;
txn.oauth2_session().finish(&clock, session).await?; repo.oauth2_session().finish(&clock, session).await?;
txn.commit().await?; repo.save().await?;
} }
return Err(RouteError::InvalidGrant); return Err(RouteError::InvalidGrant);
@ -266,7 +266,7 @@ async fn authorization_code_grant(
} }
}; };
let session = txn let session = repo
.oauth2_session() .oauth2_session()
.lookup(session_id) .lookup(session_id)
.await? .await?
@ -289,7 +289,7 @@ async fn authorization_code_grant(
} }
}; };
let browser_session = txn let browser_session = repo
.browser_session() .browser_session()
.lookup(session.user_session_id) .lookup(session.user_session_id)
.await? .await?
@ -299,12 +299,12 @@ async fn authorization_code_grant(
let access_token_str = TokenType::AccessToken.generate(&mut rng); let access_token_str = TokenType::AccessToken.generate(&mut rng);
let refresh_token_str = TokenType::RefreshToken.generate(&mut rng); let refresh_token_str = TokenType::RefreshToken.generate(&mut rng);
let access_token = txn let access_token = repo
.oauth2_access_token() .oauth2_access_token()
.add(&mut rng, &clock, &session, access_token_str, ttl) .add(&mut rng, &clock, &session, access_token_str, ttl)
.await?; .await?;
let refresh_token = txn let refresh_token = repo
.oauth2_refresh_token() .oauth2_refresh_token()
.add(&mut rng, &clock, &session, &access_token, refresh_token_str) .add(&mut rng, &clock, &session, &access_token, refresh_token_str)
.await?; .await?;
@ -355,11 +355,11 @@ async fn authorization_code_grant(
params = params.with_id_token(id_token); params = params.with_id_token(id_token);
} }
txn.oauth2_authorization_grant() repo.oauth2_authorization_grant()
.exchange(&clock, authz_grant) .exchange(&clock, authz_grant)
.await?; .await?;
txn.commit().await?; repo.save().await?;
Ok(params) Ok(params)
} }
@ -367,17 +367,17 @@ async fn authorization_code_grant(
async fn refresh_token_grant( async fn refresh_token_grant(
grant: &RefreshTokenGrant, grant: &RefreshTokenGrant,
client: &Client, client: &Client,
mut txn: Transaction<'_, Postgres>, mut repo: PgRepository,
) -> Result<AccessTokenResponse, RouteError> { ) -> Result<AccessTokenResponse, RouteError> {
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
let refresh_token = txn let refresh_token = repo
.oauth2_refresh_token() .oauth2_refresh_token()
.find_by_token(&grant.refresh_token) .find_by_token(&grant.refresh_token)
.await? .await?
.ok_or(RouteError::InvalidGrant)?; .ok_or(RouteError::InvalidGrant)?;
let session = txn let session = repo
.oauth2_session() .oauth2_session()
.lookup(refresh_token.session_id) .lookup(refresh_token.session_id)
.await? .await?
@ -396,12 +396,12 @@ async fn refresh_token_grant(
let access_token_str = TokenType::AccessToken.generate(&mut rng); let access_token_str = TokenType::AccessToken.generate(&mut rng);
let refresh_token_str = TokenType::RefreshToken.generate(&mut rng); let refresh_token_str = TokenType::RefreshToken.generate(&mut rng);
let new_access_token = txn let new_access_token = repo
.oauth2_access_token() .oauth2_access_token()
.add(&mut rng, &clock, &session, access_token_str.clone(), ttl) .add(&mut rng, &clock, &session, access_token_str.clone(), ttl)
.await?; .await?;
let new_refresh_token = txn let new_refresh_token = repo
.oauth2_refresh_token() .oauth2_refresh_token()
.add( .add(
&mut rng, &mut rng,
@ -412,14 +412,14 @@ async fn refresh_token_grant(
) )
.await?; .await?;
let refresh_token = txn let refresh_token = repo
.oauth2_refresh_token() .oauth2_refresh_token()
.consume(&clock, refresh_token) .consume(&clock, refresh_token)
.await?; .await?;
if let Some(access_token_id) = refresh_token.access_token_id { if let Some(access_token_id) = refresh_token.access_token_id {
if let Some(access_token) = txn.oauth2_access_token().lookup(access_token_id).await? { if let Some(access_token) = repo.oauth2_access_token().lookup(access_token_id).await? {
txn.oauth2_access_token() repo.oauth2_access_token()
.revoke(&clock, access_token) .revoke(&clock, access_token)
.await?; .await?;
} }
@ -430,7 +430,7 @@ async fn refresh_token_grant(
.with_refresh_token(new_refresh_token.refresh_token) .with_refresh_token(new_refresh_token.refresh_token)
.with_scope(session.scope); .with_scope(session.scope);
txn.commit().await?; repo.save().await?;
Ok(params) Ok(params)
} }

View File

@ -31,7 +31,7 @@ use mas_router::UrlBuilder;
use mas_storage::{ use mas_storage::{
oauth2::OAuth2ClientRepository, oauth2::OAuth2ClientRepository,
user::{BrowserSessionRepository, UserEmailRepository}, user::{BrowserSessionRepository, UserEmailRepository},
Repository, DatabaseError, PgRepository, Repository,
}; };
use oauth2_types::scope; use oauth2_types::scope;
use serde::Serialize; use serde::Serialize;
@ -64,7 +64,7 @@ pub enum RouteError {
Internal(Box<dyn std::error::Error + Send + Sync + 'static>), Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("failed to authenticate")] #[error("failed to authenticate")]
AuthorizationVerificationError(#[from] AuthorizationVerificationError), AuthorizationVerificationError(#[from] AuthorizationVerificationError<DatabaseError>),
#[error("no suitable key found for signing")] #[error("no suitable key found for signing")]
InvalidSigningKey, InvalidSigningKey,
@ -102,11 +102,11 @@ pub async fn get(
user_authorization: UserAuthorization, user_authorization: UserAuthorization,
) -> Result<Response, RouteError> { ) -> Result<Response, RouteError> {
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
let mut conn = pool.acquire().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let session = user_authorization.protected(&mut conn, clock.now()).await?; let session = user_authorization.protected(&mut repo, clock.now()).await?;
let browser_session = conn let browser_session = repo
.browser_session() .browser_session()
.lookup(session.user_session_id) .lookup(session.user_session_id)
.await? .await?
@ -115,7 +115,7 @@ pub async fn get(
let user = browser_session.user; let user = browser_session.user;
let user_email = if session.scope.contains(&scope::EMAIL) { let user_email = if session.scope.contains(&scope::EMAIL) {
conn.user_email().get_primary(&user).await? repo.user_email().get_primary(&user).await?
} else { } else {
None None
}; };
@ -127,7 +127,7 @@ pub async fn get(
email: user_email.map(|u| u.email), email: user_email.map(|u| u.email),
}; };
let client = conn let client = repo
.oauth2_client() .oauth2_client()
.lookup(session.client_id) .lookup(session.client_id)
.await? .await?

View File

@ -24,7 +24,7 @@ use mas_oidc_client::requests::authorization_code::AuthorizationRequestData;
use mas_router::UrlBuilder; use mas_router::UrlBuilder;
use mas_storage::{ use mas_storage::{
upstream_oauth2::{UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository}, upstream_oauth2::{UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository},
Repository, PgRepository, Repository,
}; };
use sqlx::PgPool; use sqlx::PgPool;
use thiserror::Error; use thiserror::Error;
@ -67,9 +67,9 @@ pub(crate) async fn get(
) -> Result<impl IntoResponse, RouteError> { ) -> Result<impl IntoResponse, RouteError> {
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
let mut txn = pool.begin().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let provider = txn let provider = repo
.upstream_oauth_provider() .upstream_oauth_provider()
.lookup(provider_id) .lookup(provider_id)
.await? .await?
@ -100,7 +100,7 @@ pub(crate) async fn get(
&mut rng, &mut rng,
)?; )?;
let session = txn let session = repo
.upstream_oauth_session() .upstream_oauth_session()
.add( .add(
&mut rng, &mut rng,
@ -116,7 +116,7 @@ pub(crate) async fn get(
.add(session.id, provider.id, data.state, query.post_auth_action) .add(session.id, provider.id, data.state, query.post_auth_action)
.save(cookie_jar, clock.now()); .save(cookie_jar, clock.now());
txn.commit().await?; repo.save().await?;
Ok((cookie_jar, Redirect::temporary(url.as_str()))) Ok((cookie_jar, Redirect::temporary(url.as_str())))
} }

View File

@ -26,8 +26,11 @@ use mas_oidc_client::requests::{
}; };
use mas_router::{Route, UrlBuilder}; use mas_router::{Route, UrlBuilder};
use mas_storage::{ use mas_storage::{
upstream_oauth2::{UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository}, upstream_oauth2::{
Repository, UpstreamOAuthLinkRepository, UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
UpstreamOAuthSessionRepository,
},
PgRepository, Repository,
}; };
use oauth2_types::errors::ClientErrorCode; use oauth2_types::errors::ClientErrorCode;
use serde::Deserialize; use serde::Deserialize;
@ -129,9 +132,9 @@ pub(crate) async fn get(
) -> Result<impl IntoResponse, RouteError> { ) -> Result<impl IntoResponse, RouteError> {
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
let mut txn = pool.begin().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let provider = txn let provider = repo
.upstream_oauth_provider() .upstream_oauth_provider()
.lookup(provider_id) .lookup(provider_id)
.await? .await?
@ -142,7 +145,7 @@ pub(crate) async fn get(
.find_session(provider_id, &params.state) .find_session(provider_id, &params.state)
.map_err(|_| RouteError::MissingCookie)?; .map_err(|_| RouteError::MissingCookie)?;
let session = txn let session = repo
.upstream_oauth_session() .upstream_oauth_session()
.lookup(session_id) .lookup(session_id)
.await? .await?
@ -244,7 +247,7 @@ pub(crate) async fn get(
let subject = mas_jose::claims::SUB.extract_required(&mut id_token)?; let subject = mas_jose::claims::SUB.extract_required(&mut id_token)?;
// Look for an existing link // Look for an existing link
let maybe_link = txn let maybe_link = repo
.upstream_oauth_link() .upstream_oauth_link()
.find_by_subject(&provider, &subject) .find_by_subject(&provider, &subject)
.await?; .await?;
@ -252,12 +255,12 @@ pub(crate) async fn get(
let link = if let Some(link) = maybe_link { let link = if let Some(link) = maybe_link {
link link
} else { } else {
txn.upstream_oauth_link() repo.upstream_oauth_link()
.add(&mut rng, &clock, &provider, subject) .add(&mut rng, &clock, &provider, subject)
.await? .await?
}; };
let session = txn let session = repo
.upstream_oauth_session() .upstream_oauth_session()
.complete_with_link(&clock, session, &link, response.id_token) .complete_with_link(&clock, session, &link, response.id_token)
.await?; .await?;
@ -266,7 +269,7 @@ pub(crate) async fn get(
.add_link_to_session(session.id, link.id)? .add_link_to_session(session.id, link.id)?
.save(cookie_jar, clock.now()); .save(cookie_jar, clock.now());
txn.commit().await?; repo.save().await?;
Ok(( Ok((
cookie_jar, cookie_jar,

View File

@ -25,9 +25,9 @@ use mas_axum_utils::{
}; };
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_storage::{ use mas_storage::{
upstream_oauth2::UpstreamOAuthSessionRepository, upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository},
user::{BrowserSessionRepository, UserRepository}, user::{BrowserSessionRepository, UserRepository},
Repository, UpstreamOAuthLinkRepository, PgRepository, Repository,
}; };
use mas_templates::{ use mas_templates::{
EmptyContext, TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister, EmptyContext, TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister,
@ -99,7 +99,7 @@ pub(crate) async fn get(
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Path(link_id): Path<Ulid>, Path(link_id): Path<Ulid>,
) -> Result<impl IntoResponse, RouteError> { ) -> Result<impl IntoResponse, RouteError> {
let mut txn = pool.begin().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar); let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
@ -107,13 +107,13 @@ pub(crate) async fn get(
.lookup_link(link_id) .lookup_link(link_id)
.map_err(|_| RouteError::MissingCookie)?; .map_err(|_| RouteError::MissingCookie)?;
let link = txn let link = repo
.upstream_oauth_link() .upstream_oauth_link()
.lookup(link_id) .lookup(link_id)
.await? .await?
.ok_or(RouteError::LinkNotFound)?; .ok_or(RouteError::LinkNotFound)?;
let upstream_session = txn let upstream_session = repo
.upstream_oauth_session() .upstream_oauth_session()
.lookup(session_id) .lookup(session_id)
.await? .await?
@ -131,24 +131,24 @@ pub(crate) async fn get(
let (user_session_info, cookie_jar) = cookie_jar.session_info(); let (user_session_info, cookie_jar) = cookie_jar.session_info();
let (csrf_token, mut cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); let (csrf_token, mut cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng);
let maybe_user_session = user_session_info.load_session(&mut txn).await?; let maybe_user_session = user_session_info.load_session(&mut repo).await?;
let render = match (maybe_user_session, link.user_id) { let render = match (maybe_user_session, link.user_id) {
(Some(session), Some(user_id)) if session.user.id == user_id => { (Some(session), Some(user_id)) if session.user.id == user_id => {
// Session already linked, and link matches the currently logged // Session already linked, and link matches the currently logged
// user. Mark the session as consumed and renew the authentication. // user. Mark the session as consumed and renew the authentication.
txn.upstream_oauth_session() repo.upstream_oauth_session()
.consume(&clock, upstream_session) .consume(&clock, upstream_session)
.await?; .await?;
let session = txn let session = repo
.browser_session() .browser_session()
.authenticate_with_upstream(&mut rng, &clock, session, &link) .authenticate_with_upstream(&mut rng, &clock, session, &link)
.await?; .await?;
cookie_jar = cookie_jar.set_session(&session); cookie_jar = cookie_jar.set_session(&session);
txn.commit().await?; repo.save().await?;
let ctx = EmptyContext let ctx = EmptyContext
.with_session(session) .with_session(session)
@ -163,7 +163,7 @@ pub(crate) async fn get(
// Session already linked, but link doesn't match the currently // Session already linked, but link doesn't match the currently
// logged user. Suggest logging out of the current user // logged user. Suggest logging out of the current user
// and logging in with the new one // and logging in with the new one
let user = txn let user = repo
.user() .user()
.lookup(user_id) .lookup(user_id)
.await? .await?
@ -187,7 +187,7 @@ pub(crate) async fn get(
(None, Some(user_id)) => { (None, Some(user_id)) => {
// Session linked, but user not logged in: do the login // Session linked, but user not logged in: do the login
let user = txn let user = repo
.user() .user()
.lookup(user_id) .lookup(user_id)
.await? .await?
@ -216,8 +216,8 @@ pub(crate) async fn post(
Path(link_id): Path<Ulid>, Path(link_id): Path<Ulid>,
Form(form): Form<ProtectedForm<FormData>>, Form(form): Form<ProtectedForm<FormData>>,
) -> Result<impl IntoResponse, RouteError> { ) -> Result<impl IntoResponse, RouteError> {
let mut txn = pool.begin().await?;
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
let mut repo = PgRepository::from_pool(&pool).await?;
let form = cookie_jar.verify_form(clock.now(), form)?; let form = cookie_jar.verify_form(clock.now(), form)?;
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar); let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
@ -229,13 +229,13 @@ pub(crate) async fn post(
post_auth_action: post_auth_action.cloned(), post_auth_action: post_auth_action.cloned(),
}; };
let link = txn let link = repo
.upstream_oauth_link() .upstream_oauth_link()
.lookup(link_id) .lookup(link_id)
.await? .await?
.ok_or(RouteError::LinkNotFound)?; .ok_or(RouteError::LinkNotFound)?;
let upstream_session = txn let upstream_session = repo
.upstream_oauth_session() .upstream_oauth_session()
.lookup(session_id) .lookup(session_id)
.await? .await?
@ -252,11 +252,11 @@ pub(crate) async fn post(
} }
let (user_session_info, cookie_jar) = cookie_jar.session_info(); let (user_session_info, cookie_jar) = cookie_jar.session_info();
let maybe_user_session = user_session_info.load_session(&mut txn).await?; let maybe_user_session = user_session_info.load_session(&mut repo).await?;
let session = match (maybe_user_session, link.user_id, form) { let session = match (maybe_user_session, link.user_id, form) {
(Some(session), None, FormData::Link) => { (Some(session), None, FormData::Link) => {
txn.upstream_oauth_link() repo.upstream_oauth_link()
.associate_to_user(&link, &session.user) .associate_to_user(&link, &session.user)
.await?; .await?;
@ -264,32 +264,32 @@ pub(crate) async fn post(
} }
(None, Some(user_id), FormData::Login) => { (None, Some(user_id), FormData::Login) => {
let user = txn let user = repo
.user() .user()
.lookup(user_id) .lookup(user_id)
.await? .await?
.ok_or(RouteError::UserNotFound)?; .ok_or(RouteError::UserNotFound)?;
txn.browser_session().add(&mut rng, &clock, &user).await? repo.browser_session().add(&mut rng, &clock, &user).await?
} }
(None, None, FormData::Register { username }) => { (None, None, FormData::Register { username }) => {
let user = txn.user().add(&mut rng, &clock, username).await?; let user = repo.user().add(&mut rng, &clock, username).await?;
txn.upstream_oauth_link() repo.upstream_oauth_link()
.associate_to_user(&link, &user) .associate_to_user(&link, &user)
.await?; .await?;
txn.browser_session().add(&mut rng, &clock, &user).await? repo.browser_session().add(&mut rng, &clock, &user).await?
} }
_ => return Err(RouteError::InvalidFormAction), _ => return Err(RouteError::InvalidFormAction),
}; };
txn.upstream_oauth_session() repo.upstream_oauth_session()
.consume(&clock, upstream_session) .consume(&clock, upstream_session)
.await?; .await?;
let session = txn let session = repo
.browser_session() .browser_session()
.authenticate_with_upstream(&mut rng, &clock, session, &link) .authenticate_with_upstream(&mut rng, &clock, session, &link)
.await?; .await?;
@ -299,7 +299,7 @@ pub(crate) async fn post(
.save(cookie_jar, clock.now()); .save(cookie_jar, clock.now());
let cookie_jar = cookie_jar.set_session(&session); let cookie_jar = cookie_jar.set_session(&session);
txn.commit().await?; repo.save().await?;
Ok((cookie_jar, post_auth_action.go_next())) Ok((cookie_jar, post_auth_action.go_next()))
} }

View File

@ -24,7 +24,7 @@ use mas_axum_utils::{
use mas_email::Mailer; use mas_email::Mailer;
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_router::Route; use mas_router::Route;
use mas_storage::{user::UserEmailRepository, Repository}; use mas_storage::{user::UserEmailRepository, PgRepository, Repository};
use mas_templates::{EmailAddContext, TemplateContext, Templates}; use mas_templates::{EmailAddContext, TemplateContext, Templates};
use serde::Deserialize; use serde::Deserialize;
use sqlx::PgPool; use sqlx::PgPool;
@ -43,12 +43,12 @@ pub(crate) async fn get(
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
let mut conn = pool.begin().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng);
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut conn).await?; let maybe_session = session_info.load_session(&mut repo).await?;
let session = if let Some(session) = maybe_session { let session = if let Some(session) = maybe_session {
session session
@ -74,12 +74,12 @@ pub(crate) async fn post(
Form(form): Form<ProtectedForm<EmailForm>>, Form(form): Form<ProtectedForm<EmailForm>>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
let mut txn = pool.begin().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let form = cookie_jar.verify_form(clock.now(), form)?; let form = cookie_jar.verify_form(clock.now(), form)?;
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut txn).await?; let maybe_session = session_info.load_session(&mut repo).await?;
let session = if let Some(session) = maybe_session { let session = if let Some(session) = maybe_session {
session session
@ -88,7 +88,7 @@ pub(crate) async fn post(
return Ok((cookie_jar, login.go()).into_response()); return Ok((cookie_jar, login.go()).into_response());
}; };
let user_email = txn let user_email = repo
.user_email() .user_email()
.add(&mut rng, &clock, &session.user, form.email) .add(&mut rng, &clock, &session.user, form.email)
.await?; .await?;
@ -101,7 +101,7 @@ pub(crate) async fn post(
}; };
start_email_verification( start_email_verification(
&mailer, &mailer,
&mut txn, &mut repo,
&mut rng, &mut rng,
&clock, &clock,
&session.user, &session.user,
@ -109,7 +109,7 @@ pub(crate) async fn post(
) )
.await?; .await?;
txn.commit().await?; repo.save().await?;
Ok((cookie_jar, next.go()).into_response()) Ok((cookie_jar, next.go()).into_response())
} }

View File

@ -28,11 +28,11 @@ use mas_data_model::{BrowserSession, User, UserEmail};
use mas_email::Mailer; use mas_email::Mailer;
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_router::Route; use mas_router::Route;
use mas_storage::{user::UserEmailRepository, Clock, Repository}; use mas_storage::{user::UserEmailRepository, Clock, PgRepository, Repository};
use mas_templates::{AccountEmailsContext, EmailVerificationContext, TemplateContext, Templates}; use mas_templates::{AccountEmailsContext, EmailVerificationContext, TemplateContext, Templates};
use rand::{distributions::Uniform, Rng}; use rand::{distributions::Uniform, Rng};
use serde::Deserialize; use serde::Deserialize;
use sqlx::{PgConnection, PgPool}; use sqlx::PgPool;
use tracing::info; use tracing::info;
pub mod add; pub mod add;
@ -54,14 +54,14 @@ pub(crate) async fn get(
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
let mut conn = pool.acquire().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut conn).await?; let maybe_session = session_info.load_session(&mut repo).await?;
if let Some(session) = maybe_session { if let Some(session) = maybe_session {
render(&mut rng, &clock, templates, session, cookie_jar, &mut conn).await render(&mut rng, &clock, templates, session, cookie_jar, &mut repo).await
} else { } else {
let login = mas_router::Login::default(); let login = mas_router::Login::default();
Ok((cookie_jar, login.go()).into_response()) Ok((cookie_jar, login.go()).into_response())
@ -74,11 +74,11 @@ async fn render(
templates: Templates, templates: Templates,
session: BrowserSession, session: BrowserSession,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
conn: &mut PgConnection, repo: &mut impl Repository,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), rng); let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), rng);
let emails = conn.user_email().all(&session.user).await?; let emails = repo.user_email().all(&session.user).await?;
let ctx = AccountEmailsContext::new(emails) let ctx = AccountEmailsContext::new(emails)
.with_session(session) .with_session(session)
@ -91,7 +91,7 @@ async fn render(
async fn start_email_verification( async fn start_email_verification(
mailer: &Mailer, mailer: &Mailer,
conn: &mut PgConnection, repo: &mut impl Repository,
mut rng: impl Rng + Send, mut rng: impl Rng + Send,
clock: &Clock, clock: &Clock,
user: &User, user: &User,
@ -103,7 +103,7 @@ async fn start_email_verification(
let address: Address = user_email.email.parse()?; let address: Address = user_email.email.parse()?;
let verification = conn let verification = repo
.user_email() .user_email()
.add_verification_code(&mut rng, clock, &user_email, Duration::hours(8), code) .add_verification_code(&mut rng, clock, &user_email, Duration::hours(8), code)
.await?; .await?;
@ -130,11 +130,11 @@ pub(crate) async fn post(
Form(form): Form<ProtectedForm<ManagementForm>>, Form(form): Form<ProtectedForm<ManagementForm>>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
let mut txn = pool.begin().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut txn).await?; let maybe_session = session_info.load_session(&mut repo).await?;
let mut session = if let Some(session) = maybe_session { let mut session = if let Some(session) = maybe_session {
session session
@ -147,21 +147,21 @@ pub(crate) async fn post(
match form { match form {
ManagementForm::Add { email } => { ManagementForm::Add { email } => {
let email = txn let email = repo
.user_email() .user_email()
.add(&mut rng, &clock, &session.user, email) .add(&mut rng, &clock, &session.user, email)
.await?; .await?;
let next = mas_router::AccountVerifyEmail::new(email.id); let next = mas_router::AccountVerifyEmail::new(email.id);
start_email_verification(&mailer, &mut txn, &mut rng, &clock, &session.user, email) start_email_verification(&mailer, &mut repo, &mut rng, &clock, &session.user, email)
.await?; .await?;
txn.commit().await?; repo.save().await?;
return Ok((cookie_jar, next.go()).into_response()); return Ok((cookie_jar, next.go()).into_response());
} }
ManagementForm::ResendConfirmation { id } => { ManagementForm::ResendConfirmation { id } => {
let id = id.parse()?; let id = id.parse()?;
let email = txn let email = repo
.user_email() .user_email()
.lookup(id) .lookup(id)
.await? .await?
@ -172,15 +172,15 @@ pub(crate) async fn post(
} }
let next = mas_router::AccountVerifyEmail::new(email.id); let next = mas_router::AccountVerifyEmail::new(email.id);
start_email_verification(&mailer, &mut txn, &mut rng, &clock, &session.user, email) start_email_verification(&mailer, &mut repo, &mut rng, &clock, &session.user, email)
.await?; .await?;
txn.commit().await?; repo.save().await?;
return Ok((cookie_jar, next.go()).into_response()); return Ok((cookie_jar, next.go()).into_response());
} }
ManagementForm::Remove { id } => { ManagementForm::Remove { id } => {
let id = id.parse()?; let id = id.parse()?;
let email = txn let email = repo
.user_email() .user_email()
.lookup(id) .lookup(id)
.await? .await?
@ -190,11 +190,11 @@ pub(crate) async fn post(
return Err(anyhow!("Email not found").into()); return Err(anyhow!("Email not found").into());
} }
txn.user_email().remove(email).await?; repo.user_email().remove(email).await?;
} }
ManagementForm::SetPrimary { id } => { ManagementForm::SetPrimary { id } => {
let id = id.parse()?; let id = id.parse()?;
let email = txn let email = repo
.user_email() .user_email()
.lookup(id) .lookup(id)
.await? .await?
@ -204,7 +204,7 @@ pub(crate) async fn post(
return Err(anyhow!("Email not found").into()); return Err(anyhow!("Email not found").into());
} }
txn.user_email().set_as_primary(&email).await?; repo.user_email().set_as_primary(&email).await?;
session.user.primary_user_email_id = Some(email.id); session.user.primary_user_email_id = Some(email.id);
} }
}; };
@ -215,11 +215,11 @@ pub(crate) async fn post(
templates.clone(), templates.clone(),
session, session,
cookie_jar, cookie_jar,
&mut txn, &mut repo,
) )
.await?; .await?;
txn.commit().await?; repo.save().await?;
Ok(reply) Ok(reply)
} }

View File

@ -24,7 +24,7 @@ use mas_axum_utils::{
}; };
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_router::Route; use mas_router::Route;
use mas_storage::{user::UserEmailRepository, Clock, Repository}; use mas_storage::{user::UserEmailRepository, Clock, PgRepository, Repository};
use mas_templates::{EmailVerificationPageContext, TemplateContext, Templates}; use mas_templates::{EmailVerificationPageContext, TemplateContext, Templates};
use serde::Deserialize; use serde::Deserialize;
use sqlx::PgPool; use sqlx::PgPool;
@ -45,12 +45,12 @@ pub(crate) async fn get(
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
let mut conn = pool.acquire().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng);
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut conn).await?; let maybe_session = session_info.load_session(&mut repo).await?;
let session = if let Some(session) = maybe_session { let session = if let Some(session) = maybe_session {
session session
@ -59,7 +59,7 @@ pub(crate) async fn get(
return Ok((cookie_jar, login.go()).into_response()); return Ok((cookie_jar, login.go()).into_response());
}; };
let user_email = conn let user_email = repo
.user_email() .user_email()
.lookup(id) .lookup(id)
.await? .await?
@ -89,12 +89,12 @@ pub(crate) async fn post(
Form(form): Form<ProtectedForm<CodeForm>>, Form(form): Form<ProtectedForm<CodeForm>>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let clock = Clock::default(); let clock = Clock::default();
let mut txn = pool.begin().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let form = cookie_jar.verify_form(clock.now(), form)?; let form = cookie_jar.verify_form(clock.now(), form)?;
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut txn).await?; let maybe_session = session_info.load_session(&mut repo).await?;
let session = if let Some(session) = maybe_session { let session = if let Some(session) = maybe_session {
session session
@ -103,33 +103,33 @@ pub(crate) async fn post(
return Ok((cookie_jar, login.go()).into_response()); return Ok((cookie_jar, login.go()).into_response());
}; };
let user_email = txn let user_email = repo
.user_email() .user_email()
.lookup(id) .lookup(id)
.await? .await?
.filter(|u| u.user_id == session.user.id) .filter(|u| u.user_id == session.user.id)
.context("Could not find user email")?; .context("Could not find user email")?;
let verification = txn let verification = repo
.user_email() .user_email()
.find_verification_code(&clock, &user_email, &form.code) .find_verification_code(&clock, &user_email, &form.code)
.await? .await?
.context("Invalid code")?; .context("Invalid code")?;
// TODO: display nice errors if the code was already consumed or expired // TODO: display nice errors if the code was already consumed or expired
txn.user_email() repo.user_email()
.consume_verification_code(&clock, verification) .consume_verification_code(&clock, verification)
.await?; .await?;
if session.user.primary_user_email_id.is_none() { if session.user.primary_user_email_id.is_none() {
txn.user_email().set_as_primary(&user_email).await?; repo.user_email().set_as_primary(&user_email).await?;
} }
txn.user_email() repo.user_email()
.mark_as_verified(&clock, user_email) .mark_as_verified(&clock, user_email)
.await?; .await?;
txn.commit().await?; repo.save().await?;
let destination = query.go_next_or_default(&mas_router::AccountEmails); let destination = query.go_next_or_default(&mas_router::AccountEmails);
Ok((cookie_jar, destination).into_response()) Ok((cookie_jar, destination).into_response())

View File

@ -25,7 +25,7 @@ use mas_keystore::Encrypter;
use mas_router::Route; use mas_router::Route;
use mas_storage::{ use mas_storage::{
user::{BrowserSessionRepository, UserEmailRepository}, user::{BrowserSessionRepository, UserEmailRepository},
Repository, PgRepository, Repository,
}; };
use mas_templates::{AccountContext, TemplateContext, Templates}; use mas_templates::{AccountContext, TemplateContext, Templates};
use sqlx::PgPool; use sqlx::PgPool;
@ -36,12 +36,12 @@ pub(crate) async fn get(
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
let mut conn = pool.acquire().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng);
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut conn).await?; let maybe_session = session_info.load_session(&mut repo).await?;
let session = if let Some(session) = maybe_session { let session = if let Some(session) = maybe_session {
session session
@ -50,9 +50,9 @@ pub(crate) async fn get(
return Ok((cookie_jar, login.go()).into_response()); return Ok((cookie_jar, login.go()).into_response());
}; };
let active_sessions = conn.browser_session().count_active(&session.user).await?; let active_sessions = repo.browser_session().count_active(&session.user).await?;
let emails = conn.user_email().all(&session.user).await?; let emails = repo.user_email().all(&session.user).await?;
let ctx = AccountContext::new(active_sessions, emails) let ctx = AccountContext::new(active_sessions, emails)
.with_session(session) .with_session(session)

View File

@ -27,7 +27,7 @@ use mas_keystore::Encrypter;
use mas_router::Route; use mas_router::Route;
use mas_storage::{ use mas_storage::{
user::{BrowserSessionRepository, UserPasswordRepository}, user::{BrowserSessionRepository, UserPasswordRepository},
Clock, Repository, Clock, PgRepository, Repository,
}; };
use mas_templates::{EmptyContext, TemplateContext, Templates}; use mas_templates::{EmptyContext, TemplateContext, Templates};
use rand::Rng; use rand::Rng;
@ -50,11 +50,11 @@ pub(crate) async fn get(
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
let mut conn = pool.acquire().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut conn).await?; let maybe_session = session_info.load_session(&mut repo).await?;
if let Some(session) = maybe_session { if let Some(session) = maybe_session {
render(&mut rng, &clock, templates, session, cookie_jar).await render(&mut rng, &clock, templates, session, cookie_jar).await
@ -90,13 +90,13 @@ pub(crate) async fn post(
Form(form): Form<ProtectedForm<ChangeForm>>, Form(form): Form<ProtectedForm<ChangeForm>>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
let mut txn = pool.begin().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let form = cookie_jar.verify_form(clock.now(), form)?; let form = cookie_jar.verify_form(clock.now(), form)?;
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut txn).await?; let maybe_session = session_info.load_session(&mut repo).await?;
let session = if let Some(session) = maybe_session { let session = if let Some(session) = maybe_session {
session session
@ -105,7 +105,7 @@ pub(crate) async fn post(
return Ok((cookie_jar, login.go()).into_response()); return Ok((cookie_jar, login.go()).into_response());
}; };
let user_password = txn let user_password = repo
.user_password() .user_password()
.active(&session.user) .active(&session.user)
.await? .await?
@ -129,7 +129,7 @@ pub(crate) async fn post(
} }
let (version, hashed_password) = password_manager.hash(&mut rng, new_password).await?; let (version, hashed_password) = password_manager.hash(&mut rng, new_password).await?;
let user_password = txn let user_password = repo
.user_password() .user_password()
.add( .add(
&mut rng, &mut rng,
@ -141,14 +141,14 @@ pub(crate) async fn post(
) )
.await?; .await?;
let session = txn let session = repo
.browser_session() .browser_session()
.authenticate_with_password(&mut rng, &clock, session, &user_password) .authenticate_with_password(&mut rng, &clock, session, &user_password)
.await?; .await?;
let reply = render(&mut rng, &clock, templates.clone(), session, cookie_jar).await?; let reply = render(&mut rng, &clock, templates.clone(), session, cookie_jar).await?;
txn.commit().await?; repo.save().await?;
Ok(reply) Ok(reply)
} }

View File

@ -20,6 +20,7 @@ use axum_extra::extract::PrivateCookieJar;
use mas_axum_utils::{csrf::CsrfExt, FancyError, SessionInfoExt}; use mas_axum_utils::{csrf::CsrfExt, FancyError, SessionInfoExt};
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_router::UrlBuilder; use mas_router::UrlBuilder;
use mas_storage::PgRepository;
use mas_templates::{IndexContext, TemplateContext, Templates}; use mas_templates::{IndexContext, TemplateContext, Templates};
use sqlx::PgPool; use sqlx::PgPool;
@ -30,11 +31,11 @@ pub async fn get(
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<impl IntoResponse, FancyError> { ) -> Result<impl IntoResponse, FancyError> {
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
let mut conn = pool.acquire().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng);
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let session = session_info.load_session(&mut conn).await?; let session = session_info.load_session(&mut repo).await?;
let ctx = IndexContext::new(url_builder.oidc_discovery()) let ctx = IndexContext::new(url_builder.oidc_discovery())
.maybe_with_session(session) .maybe_with_session(session)

View File

@ -26,14 +26,14 @@ use mas_keystore::Encrypter;
use mas_storage::{ use mas_storage::{
upstream_oauth2::UpstreamOAuthProviderRepository, upstream_oauth2::UpstreamOAuthProviderRepository,
user::{BrowserSessionRepository, UserPasswordRepository, UserRepository}, user::{BrowserSessionRepository, UserPasswordRepository, UserRepository},
Clock, Repository, Clock, PgRepository, Repository,
}; };
use mas_templates::{ use mas_templates::{
FieldError, FormError, LoginContext, LoginFormField, TemplateContext, Templates, ToFormState, FieldError, FormError, LoginContext, LoginFormField, TemplateContext, Templates, ToFormState,
}; };
use rand::{CryptoRng, Rng}; use rand::{CryptoRng, Rng};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::{PgConnection, PgPool}; use sqlx::PgPool;
use zeroize::Zeroizing; use zeroize::Zeroizing;
use super::shared::OptionalPostAuthAction; use super::shared::OptionalPostAuthAction;
@ -56,23 +56,23 @@ pub(crate) async fn get(
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
let mut conn = pool.acquire().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng);
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut conn).await?; let maybe_session = session_info.load_session(&mut repo).await?;
if maybe_session.is_some() { if maybe_session.is_some() {
let reply = query.go_next(); let reply = query.go_next();
Ok((cookie_jar, reply).into_response()) Ok((cookie_jar, reply).into_response())
} else { } else {
let providers = conn.upstream_oauth_provider().all().await?; let providers = repo.upstream_oauth_provider().all().await?;
let content = render( let content = render(
LoginContext::default().with_upstrem_providers(providers), LoginContext::default().with_upstrem_providers(providers),
query, query,
csrf_token, csrf_token,
&mut conn, &mut repo,
&templates, &templates,
) )
.await?; .await?;
@ -90,7 +90,7 @@ pub(crate) async fn post(
Form(form): Form<ProtectedForm<LoginForm>>, Form(form): Form<ProtectedForm<LoginForm>>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
let mut conn = pool.acquire().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let form = cookie_jar.verify_form(clock.now(), form)?; let form = cookie_jar.verify_form(clock.now(), form)?;
@ -112,14 +112,14 @@ pub(crate) async fn post(
}; };
if !state.is_valid() { if !state.is_valid() {
let providers = conn.upstream_oauth_provider().all().await?; let providers = repo.upstream_oauth_provider().all().await?;
let content = render( let content = render(
LoginContext::default() LoginContext::default()
.with_form_state(state) .with_form_state(state)
.with_upstrem_providers(providers), .with_upstrem_providers(providers),
query, query,
csrf_token, csrf_token,
&mut conn, &mut repo,
&templates, &templates,
) )
.await?; .await?;
@ -129,7 +129,7 @@ pub(crate) async fn post(
match login( match login(
password_manager, password_manager,
&mut conn, &mut repo,
rng, rng,
&clock, &clock,
&form.username, &form.username,
@ -138,6 +138,8 @@ pub(crate) async fn post(
.await .await
{ {
Ok(session_info) => { Ok(session_info) => {
repo.save().await?;
let cookie_jar = cookie_jar.set_session(&session_info); let cookie_jar = cookie_jar.set_session(&session_info);
let reply = query.go_next(); let reply = query.go_next();
Ok((cookie_jar, reply).into_response()) Ok((cookie_jar, reply).into_response())
@ -149,7 +151,7 @@ pub(crate) async fn post(
LoginContext::default().with_form_state(state), LoginContext::default().with_form_state(state),
query, query,
csrf_token, csrf_token,
&mut conn, &mut repo,
&templates, &templates,
) )
.await?; .await?;
@ -162,7 +164,7 @@ pub(crate) async fn post(
// TODO: move that logic elsewhere? // TODO: move that logic elsewhere?
async fn login( async fn login(
password_manager: PasswordManager, password_manager: PasswordManager,
conn: &mut PgConnection, repo: &mut impl Repository,
mut rng: impl Rng + CryptoRng + Send, mut rng: impl Rng + CryptoRng + Send,
clock: &Clock, clock: &Clock,
username: &str, username: &str,
@ -170,7 +172,7 @@ async fn login(
) -> Result<BrowserSession, FormError> { ) -> Result<BrowserSession, FormError> {
// XXX: we're loosing the error context here // XXX: we're loosing the error context here
// First, lookup the user // First, lookup the user
let user = conn let user = repo
.user() .user()
.find_by_username(username) .find_by_username(username)
.await .await
@ -178,7 +180,7 @@ async fn login(
.ok_or(FormError::InvalidCredentials)?; .ok_or(FormError::InvalidCredentials)?;
// And its password // And its password
let user_password = conn let user_password = repo
.user_password() .user_password()
.active(&user) .active(&user)
.await .await
@ -200,7 +202,7 @@ async fn login(
let user_password = if let Some((version, new_password_hash)) = new_password_hash { let user_password = if let Some((version, new_password_hash)) = new_password_hash {
// Save the upgraded password // Save the upgraded password
conn.user_password() repo.user_password()
.add( .add(
&mut rng, &mut rng,
clock, clock,
@ -216,14 +218,14 @@ async fn login(
}; };
// Start a new session // Start a new session
let user_session = conn let user_session = repo
.browser_session() .browser_session()
.add(&mut rng, clock, &user) .add(&mut rng, clock, &user)
.await .await
.map_err(|_| FormError::Internal)?; .map_err(|_| FormError::Internal)?;
// And mark it as authenticated by the password // And mark it as authenticated by the password
let user_session = conn let user_session = repo
.browser_session() .browser_session()
.authenticate_with_password(&mut rng, clock, user_session, &user_password) .authenticate_with_password(&mut rng, clock, user_session, &user_password)
.await .await
@ -236,10 +238,10 @@ async fn render(
ctx: LoginContext, ctx: LoginContext,
action: OptionalPostAuthAction, action: OptionalPostAuthAction,
csrf_token: CsrfToken, csrf_token: CsrfToken,
conn: &mut PgConnection, repo: &mut impl Repository,
templates: &Templates, templates: &Templates,
) -> Result<String, FancyError> { ) -> Result<String, FancyError> {
let next = action.load_context(conn).await?; let next = action.load_context(repo).await?;
let ctx = if let Some(next) = next { let ctx = if let Some(next) = next {
ctx.with_post_action(next) ctx.with_post_action(next)
} else { } else {

View File

@ -23,7 +23,7 @@ use mas_axum_utils::{
}; };
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_router::{PostAuthAction, Route}; use mas_router::{PostAuthAction, Route};
use mas_storage::{user::BrowserSessionRepository, Clock, Repository}; use mas_storage::{user::BrowserSessionRepository, Clock, PgRepository, Repository};
use sqlx::PgPool; use sqlx::PgPool;
pub(crate) async fn post( pub(crate) async fn post(
@ -32,20 +32,20 @@ pub(crate) async fn post(
Form(form): Form<ProtectedForm<Option<PostAuthAction>>>, Form(form): Form<ProtectedForm<Option<PostAuthAction>>>,
) -> Result<impl IntoResponse, FancyError> { ) -> Result<impl IntoResponse, FancyError> {
let clock = Clock::default(); let clock = Clock::default();
let mut txn = pool.begin().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let form = cookie_jar.verify_form(clock.now(), form)?; let form = cookie_jar.verify_form(clock.now(), form)?;
let (session_info, mut cookie_jar) = cookie_jar.session_info(); let (session_info, mut cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut txn).await?; let maybe_session = session_info.load_session(&mut repo).await?;
if let Some(session) = maybe_session { if let Some(session) = maybe_session {
txn.browser_session().finish(&clock, session).await?; repo.browser_session().finish(&clock, session).await?;
cookie_jar = cookie_jar.update_session_info(&session_info.mark_session_ended()); cookie_jar = cookie_jar.update_session_info(&session_info.mark_session_ended());
} }
txn.commit().await?; repo.save().await?;
let destination = if let Some(action) = form { let destination = if let Some(action) = form {
action.go_next() action.go_next()

View File

@ -26,7 +26,7 @@ use mas_keystore::Encrypter;
use mas_router::Route; use mas_router::Route;
use mas_storage::{ use mas_storage::{
user::{BrowserSessionRepository, UserPasswordRepository}, user::{BrowserSessionRepository, UserPasswordRepository},
Repository, PgRepository, Repository,
}; };
use mas_templates::{ReauthContext, TemplateContext, Templates}; use mas_templates::{ReauthContext, TemplateContext, Templates};
use serde::Deserialize; use serde::Deserialize;
@ -48,12 +48,12 @@ pub(crate) async fn get(
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
let mut conn = pool.acquire().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng);
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut conn).await?; let maybe_session = session_info.load_session(&mut repo).await?;
let session = if let Some(session) = maybe_session { let session = if let Some(session) = maybe_session {
session session
@ -65,7 +65,7 @@ pub(crate) async fn get(
}; };
let ctx = ReauthContext::default(); let ctx = ReauthContext::default();
let next = query.load_context(&mut conn).await?; let next = query.load_context(&mut repo).await?;
let ctx = if let Some(next) = next { let ctx = if let Some(next) = next {
ctx.with_post_action(next) ctx.with_post_action(next)
} else { } else {
@ -86,13 +86,13 @@ pub(crate) async fn post(
Form(form): Form<ProtectedForm<ReauthForm>>, Form(form): Form<ProtectedForm<ReauthForm>>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
let mut txn = pool.begin().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let form = cookie_jar.verify_form(clock.now(), form)?; let form = cookie_jar.verify_form(clock.now(), form)?;
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut txn).await?; let maybe_session = session_info.load_session(&mut repo).await?;
let session = if let Some(session) = maybe_session { let session = if let Some(session) = maybe_session {
session session
@ -104,7 +104,7 @@ pub(crate) async fn post(
}; };
// Load the user password // Load the user password
let user_password = txn let user_password = repo
.user_password() .user_password()
.active(&session.user) .active(&session.user)
.await? .await?
@ -125,7 +125,7 @@ pub(crate) async fn post(
let user_password = if let Some((version, new_password_hash)) = new_password_hash { let user_password = if let Some((version, new_password_hash)) = new_password_hash {
// Save the upgraded password // Save the upgraded password
txn.user_password() repo.user_password()
.add( .add(
&mut rng, &mut rng,
&clock, &clock,
@ -140,13 +140,13 @@ pub(crate) async fn post(
}; };
// Mark the session as authenticated by the password // Mark the session as authenticated by the password
let session = txn let session = repo
.browser_session() .browser_session()
.authenticate_with_password(&mut rng, &clock, session, &user_password) .authenticate_with_password(&mut rng, &clock, session, &user_password)
.await?; .await?;
let cookie_jar = cookie_jar.set_session(&session); let cookie_jar = cookie_jar.set_session(&session);
txn.commit().await?; repo.save().await?;
let reply = query.go_next(); let reply = query.go_next();
Ok((cookie_jar, reply).into_response()) Ok((cookie_jar, reply).into_response())

View File

@ -33,7 +33,7 @@ use mas_policy::PolicyFactory;
use mas_router::Route; use mas_router::Route;
use mas_storage::{ use mas_storage::{
user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository},
Repository, PgRepository, Repository,
}; };
use mas_templates::{ use mas_templates::{
EmailVerificationContext, FieldError, FormError, RegisterContext, RegisterFormField, EmailVerificationContext, FieldError, FormError, RegisterContext, RegisterFormField,
@ -41,7 +41,7 @@ use mas_templates::{
}; };
use rand::{distributions::Uniform, Rng}; use rand::{distributions::Uniform, Rng};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::{PgConnection, PgPool}; use sqlx::PgPool;
use zeroize::Zeroizing; use zeroize::Zeroizing;
use super::shared::OptionalPostAuthAction; use super::shared::OptionalPostAuthAction;
@ -66,12 +66,12 @@ pub(crate) async fn get(
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
let mut conn = pool.acquire().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng);
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut conn).await?; let maybe_session = session_info.load_session(&mut repo).await?;
if maybe_session.is_some() { if maybe_session.is_some() {
let reply = query.go_next(); let reply = query.go_next();
@ -81,7 +81,7 @@ pub(crate) async fn get(
RegisterContext::default(), RegisterContext::default(),
query, query,
csrf_token, csrf_token,
&mut conn, &mut repo,
&templates, &templates,
) )
.await?; .await?;
@ -102,7 +102,7 @@ pub(crate) async fn post(
Form(form): Form<ProtectedForm<RegisterForm>>, Form(form): Form<ProtectedForm<RegisterForm>>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
let mut txn = pool.begin().await?; let mut repo = PgRepository::from_pool(&pool).await?;
let form = cookie_jar.verify_form(clock.now(), form)?; let form = cookie_jar.verify_form(clock.now(), form)?;
@ -114,7 +114,7 @@ pub(crate) async fn post(
if form.username.is_empty() { if form.username.is_empty() {
state.add_error_on_field(RegisterFormField::Username, FieldError::Required); state.add_error_on_field(RegisterFormField::Username, FieldError::Required);
} else if txn.user().exists(&form.username).await? { } else if repo.user().exists(&form.username).await? {
state.add_error_on_field(RegisterFormField::Username, FieldError::Exists); state.add_error_on_field(RegisterFormField::Username, FieldError::Exists);
} }
@ -177,7 +177,7 @@ pub(crate) async fn post(
RegisterContext::default().with_form_state(state), RegisterContext::default().with_form_state(state),
query, query,
csrf_token, csrf_token,
&mut txn, &mut repo,
&templates, &templates,
) )
.await?; .await?;
@ -185,15 +185,15 @@ pub(crate) async fn post(
return Ok((cookie_jar, Html(content)).into_response()); return Ok((cookie_jar, Html(content)).into_response());
} }
let user = txn.user().add(&mut rng, &clock, form.username).await?; let user = repo.user().add(&mut rng, &clock, form.username).await?;
let password = Zeroizing::new(form.password.into_bytes()); let password = Zeroizing::new(form.password.into_bytes());
let (version, hashed_password) = password_manager.hash(&mut rng, password).await?; let (version, hashed_password) = password_manager.hash(&mut rng, password).await?;
let user_password = txn let user_password = repo
.user_password() .user_password()
.add(&mut rng, &clock, &user, version, hashed_password, None) .add(&mut rng, &clock, &user, version, hashed_password, None)
.await?; .await?;
let user_email = txn let user_email = repo
.user_email() .user_email()
.add(&mut rng, &clock, &user, form.email) .add(&mut rng, &clock, &user, form.email)
.await?; .await?;
@ -205,7 +205,7 @@ pub(crate) async fn post(
let address: Address = user_email.email.parse()?; let address: Address = user_email.email.parse()?;
let verification = txn let verification = repo
.user_email() .user_email()
.add_verification_code(&mut rng, &clock, &user_email, Duration::hours(8), code) .add_verification_code(&mut rng, &clock, &user_email, Duration::hours(8), code)
.await?; .await?;
@ -219,14 +219,14 @@ pub(crate) async fn post(
let next = mas_router::AccountVerifyEmail::new(user_email.id).and_maybe(query.post_auth_action); let next = mas_router::AccountVerifyEmail::new(user_email.id).and_maybe(query.post_auth_action);
let session = txn.browser_session().add(&mut rng, &clock, &user).await?; let session = repo.browser_session().add(&mut rng, &clock, &user).await?;
let session = txn let session = repo
.browser_session() .browser_session()
.authenticate_with_password(&mut rng, &clock, session, &user_password) .authenticate_with_password(&mut rng, &clock, session, &user_password)
.await?; .await?;
txn.commit().await?; repo.save().await?;
let cookie_jar = cookie_jar.set_session(&session); let cookie_jar = cookie_jar.set_session(&session);
Ok((cookie_jar, next.go()).into_response()) Ok((cookie_jar, next.go()).into_response())
@ -236,10 +236,10 @@ async fn render(
ctx: RegisterContext, ctx: RegisterContext,
action: OptionalPostAuthAction, action: OptionalPostAuthAction,
csrf_token: CsrfToken, csrf_token: CsrfToken,
conn: &mut PgConnection, repo: &mut impl Repository,
templates: &Templates, templates: &Templates,
) -> Result<String, FancyError> { ) -> Result<String, FancyError> {
let next = action.load_context(conn).await?; let next = action.load_context(repo).await?;
let ctx = if let Some(next) = next { let ctx = if let Some(next) = next {
ctx.with_post_action(next) ctx.with_post_action(next)
} else { } else {

View File

@ -15,12 +15,13 @@
use anyhow::Context; use anyhow::Context;
use mas_router::{PostAuthAction, Route}; use mas_router::{PostAuthAction, Route};
use mas_storage::{ use mas_storage::{
compat::CompatSsoLoginRepository, oauth2::OAuth2AuthorizationGrantRepository, compat::CompatSsoLoginRepository,
upstream_oauth2::UpstreamOAuthProviderRepository, Repository, UpstreamOAuthLinkRepository, oauth2::OAuth2AuthorizationGrantRepository,
upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository},
Repository,
}; };
use mas_templates::{PostAuthContext, PostAuthContextInner}; use mas_templates::{PostAuthContext, PostAuthContextInner};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::PgConnection;
#[derive(Serialize, Deserialize, Default, Debug, Clone)] #[derive(Serialize, Deserialize, Default, Debug, Clone)]
pub(crate) struct OptionalPostAuthAction { pub(crate) struct OptionalPostAuthAction {
@ -39,14 +40,14 @@ impl OptionalPostAuthAction {
self.go_next_or_default(&mas_router::Index) self.go_next_or_default(&mas_router::Index)
} }
pub async fn load_context( pub async fn load_context<R: Repository>(
&self, &self,
conn: &mut PgConnection, repo: &mut R,
) -> anyhow::Result<Option<PostAuthContext>> { ) -> anyhow::Result<Option<PostAuthContext>> {
let Some(action) = self.post_auth_action.clone() else { return Ok(None) }; let Some(action) = self.post_auth_action.clone() else { return Ok(None) };
let ctx = match action { let ctx = match action {
PostAuthAction::ContinueAuthorizationGrant { id } => { PostAuthAction::ContinueAuthorizationGrant { id } => {
let grant = conn let grant = repo
.oauth2_authorization_grant() .oauth2_authorization_grant()
.lookup(id) .lookup(id)
.await? .await?
@ -56,7 +57,7 @@ impl OptionalPostAuthAction {
} }
PostAuthAction::ContinueCompatSsoLogin { id } => { PostAuthAction::ContinueCompatSsoLogin { id } => {
let login = conn let login = repo
.compat_sso_login() .compat_sso_login()
.lookup(id) .lookup(id)
.await? .await?
@ -68,13 +69,13 @@ impl OptionalPostAuthAction {
PostAuthAction::ChangePassword => PostAuthContextInner::ChangePassword, PostAuthAction::ChangePassword => PostAuthContextInner::ChangePassword,
PostAuthAction::LinkUpstream { id } => { PostAuthAction::LinkUpstream { id } => {
let link = conn let link = repo
.upstream_oauth_link() .upstream_oauth_link()
.lookup(id) .lookup(id)
.await? .await?
.context("Failed to load upstream OAuth 2.0 link")?; .context("Failed to load upstream OAuth 2.0 link")?;
let provider = conn let provider = repo
.upstream_oauth_provider() .upstream_oauth_provider()
.lookup(link.provider_id) .lookup(link.provider_id)
.await? .await?

View File

@ -183,7 +183,7 @@ pub(crate) mod tracing;
pub mod upstream_oauth2; pub mod upstream_oauth2;
pub mod user; pub mod user;
pub use self::{repository::Repository, upstream_oauth2::UpstreamOAuthLinkRepository}; pub use self::repository::{PgRepository, Repository};
/// Embedded migrations, allowing them to run on startup /// Embedded migrations, allowing them to run on startup
pub static MIGRATOR: Migrator = sqlx::migrate!(); pub static MIGRATOR: Migrator = sqlx::migrate!();

View File

@ -32,7 +32,7 @@ use crate::{
}; };
#[async_trait] #[async_trait]
pub trait OAuth2AuthorizationGrantRepository { pub trait OAuth2AuthorizationGrantRepository: Send + Sync {
type Error; type Error;
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]

View File

@ -27,7 +27,7 @@ use crate::{
}; };
#[async_trait] #[async_trait]
pub trait OAuth2SessionRepository { pub trait OAuth2SessionRepository: Send + Sync {
type Error; type Error;
async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error>; async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error>;

View File

@ -12,89 +12,100 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use sqlx::{PgConnection, Postgres, Transaction}; use sqlx::{PgPool, Postgres, Transaction};
use crate::{ use crate::{
compat::{ compat::{
PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository, PgCompatSessionRepository, CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
PgCompatSsoLoginRepository, CompatSsoLoginRepository, PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository,
PgCompatSessionRepository, PgCompatSsoLoginRepository,
}, },
oauth2::{ oauth2::{
PgOAuth2AccessTokenRepository, PgOAuth2AuthorizationGrantRepository, OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository,
PgOAuth2ClientRepository, PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository, PgOAuth2AccessTokenRepository,
PgOAuth2AuthorizationGrantRepository, PgOAuth2ClientRepository,
PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository,
}, },
upstream_oauth2::{ upstream_oauth2::{
PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository, PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository,
PgUpstreamOAuthSessionRepository, PgUpstreamOAuthSessionRepository, UpstreamOAuthLinkRepository,
UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository,
}, },
user::{ user::{
PgBrowserSessionRepository, PgUserEmailRepository, PgUserPasswordRepository, BrowserSessionRepository, PgBrowserSessionRepository, PgUserEmailRepository,
PgUserRepository, PgUserPasswordRepository, PgUserRepository, UserEmailRepository, UserPasswordRepository,
UserRepository,
}, },
DatabaseError,
}; };
pub trait Repository { pub trait Repository: Send {
type UpstreamOAuthLinkRepository<'c> type Error: std::error::Error + Send + Sync + 'static;
type UpstreamOAuthLinkRepository<'c>: UpstreamOAuthLinkRepository<Error = Self::Error> + 'c
where where
Self: 'c; Self: 'c;
type UpstreamOAuthProviderRepository<'c> type UpstreamOAuthProviderRepository<'c>: UpstreamOAuthProviderRepository<Error = Self::Error>
+ 'c
where where
Self: 'c; Self: 'c;
type UpstreamOAuthSessionRepository<'c> type UpstreamOAuthSessionRepository<'c>: UpstreamOAuthSessionRepository<Error = Self::Error>
+ 'c
where where
Self: 'c; Self: 'c;
type UserRepository<'c> type UserRepository<'c>: UserRepository<Error = Self::Error> + 'c
where where
Self: 'c; Self: 'c;
type UserEmailRepository<'c> type UserEmailRepository<'c>: UserEmailRepository<Error = Self::Error> + 'c
where where
Self: 'c; Self: 'c;
type UserPasswordRepository<'c> type UserPasswordRepository<'c>: UserPasswordRepository<Error = Self::Error> + 'c
where where
Self: 'c; Self: 'c;
type BrowserSessionRepository<'c> type BrowserSessionRepository<'c>: BrowserSessionRepository<Error = Self::Error> + 'c
where where
Self: 'c; Self: 'c;
type OAuth2ClientRepository<'c> type OAuth2ClientRepository<'c>: OAuth2ClientRepository<Error = Self::Error> + 'c
where where
Self: 'c; Self: 'c;
type OAuth2AuthorizationGrantRepository<'c> type OAuth2AuthorizationGrantRepository<'c>: OAuth2AuthorizationGrantRepository<Error = Self::Error>
+ 'c
where where
Self: 'c; Self: 'c;
type OAuth2SessionRepository<'c> type OAuth2SessionRepository<'c>: OAuth2SessionRepository<Error = Self::Error> + 'c
where where
Self: 'c; Self: 'c;
type OAuth2AccessTokenRepository<'c> type OAuth2AccessTokenRepository<'c>: OAuth2AccessTokenRepository<Error = Self::Error> + 'c
where where
Self: 'c; Self: 'c;
type OAuth2RefreshTokenRepository<'c> type OAuth2RefreshTokenRepository<'c>: OAuth2RefreshTokenRepository<Error = Self::Error> + 'c
where where
Self: 'c; Self: 'c;
type CompatSessionRepository<'c> type CompatSessionRepository<'c>: CompatSessionRepository<Error = Self::Error> + 'c
where where
Self: 'c; Self: 'c;
type CompatSsoLoginRepository<'c> type CompatSsoLoginRepository<'c>: CompatSsoLoginRepository<Error = Self::Error> + 'c
where where
Self: 'c; Self: 'c;
type CompatAccessTokenRepository<'c> type CompatAccessTokenRepository<'c>: CompatAccessTokenRepository<Error = Self::Error> + 'c
where where
Self: 'c; Self: 'c;
type CompatRefreshTokenRepository<'c> type CompatRefreshTokenRepository<'c>: CompatRefreshTokenRepository<Error = Self::Error> + 'c
where where
Self: 'c; Self: 'c;
@ -116,7 +127,30 @@ pub trait Repository {
fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_>; fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_>;
} }
impl Repository for PgConnection { pub struct PgRepository {
txn: Transaction<'static, Postgres>,
}
impl PgRepository {
pub async fn from_pool(pool: &PgPool) -> Result<Self, DatabaseError> {
let txn = pool.begin().await?;
Ok(PgRepository { txn })
}
pub async fn save(self) -> Result<(), DatabaseError> {
self.txn.commit().await?;
Ok(())
}
pub async fn cancel(self) -> Result<(), DatabaseError> {
self.txn.rollback().await?;
Ok(())
}
}
impl Repository for PgRepository {
type Error = DatabaseError;
type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c; type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c;
type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c; type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c;
type UpstreamOAuthSessionRepository<'c> = PgUpstreamOAuthSessionRepository<'c> where Self: 'c; type UpstreamOAuthSessionRepository<'c> = PgUpstreamOAuthSessionRepository<'c> where Self: 'c;
@ -135,149 +169,66 @@ impl Repository for PgConnection {
type CompatRefreshTokenRepository<'c> = PgCompatRefreshTokenRepository<'c> where Self: 'c; type CompatRefreshTokenRepository<'c> = PgCompatRefreshTokenRepository<'c> where Self: 'c;
fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> {
PgUpstreamOAuthLinkRepository::new(self) PgUpstreamOAuthLinkRepository::new(&mut self.txn)
} }
fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> { fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> {
PgUpstreamOAuthProviderRepository::new(self) PgUpstreamOAuthProviderRepository::new(&mut self.txn)
} }
fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_> { fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_> {
PgUpstreamOAuthSessionRepository::new(self) PgUpstreamOAuthSessionRepository::new(&mut self.txn)
} }
fn user(&mut self) -> Self::UserRepository<'_> { fn user(&mut self) -> Self::UserRepository<'_> {
PgUserRepository::new(self) PgUserRepository::new(&mut self.txn)
} }
fn user_email(&mut self) -> Self::UserEmailRepository<'_> { fn user_email(&mut self) -> Self::UserEmailRepository<'_> {
PgUserEmailRepository::new(self) PgUserEmailRepository::new(&mut self.txn)
} }
fn user_password(&mut self) -> Self::UserPasswordRepository<'_> { fn user_password(&mut self) -> Self::UserPasswordRepository<'_> {
PgUserPasswordRepository::new(self) PgUserPasswordRepository::new(&mut self.txn)
} }
fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_> { fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_> {
PgBrowserSessionRepository::new(self) PgBrowserSessionRepository::new(&mut self.txn)
} }
fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_> { fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_> {
PgOAuth2ClientRepository::new(self) PgOAuth2ClientRepository::new(&mut self.txn)
} }
fn oauth2_authorization_grant(&mut self) -> Self::OAuth2AuthorizationGrantRepository<'_> { fn oauth2_authorization_grant(&mut self) -> Self::OAuth2AuthorizationGrantRepository<'_> {
PgOAuth2AuthorizationGrantRepository::new(self) PgOAuth2AuthorizationGrantRepository::new(&mut self.txn)
} }
fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> { fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> {
PgOAuth2SessionRepository::new(self) PgOAuth2SessionRepository::new(&mut self.txn)
} }
fn oauth2_access_token(&mut self) -> Self::OAuth2AccessTokenRepository<'_> { fn oauth2_access_token(&mut self) -> Self::OAuth2AccessTokenRepository<'_> {
PgOAuth2AccessTokenRepository::new(self) PgOAuth2AccessTokenRepository::new(&mut self.txn)
} }
fn oauth2_refresh_token(&mut self) -> Self::OAuth2RefreshTokenRepository<'_> { fn oauth2_refresh_token(&mut self) -> Self::OAuth2RefreshTokenRepository<'_> {
PgOAuth2RefreshTokenRepository::new(self) PgOAuth2RefreshTokenRepository::new(&mut self.txn)
} }
fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> { fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> {
PgCompatSessionRepository::new(self) PgCompatSessionRepository::new(&mut self.txn)
} }
fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_> { fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_> {
PgCompatSsoLoginRepository::new(self) PgCompatSsoLoginRepository::new(&mut self.txn)
} }
fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_> { fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_> {
PgCompatAccessTokenRepository::new(self) PgCompatAccessTokenRepository::new(&mut self.txn)
} }
fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_> { fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_> {
PgCompatRefreshTokenRepository::new(self) PgCompatRefreshTokenRepository::new(&mut self.txn)
}
}
impl<'t> Repository for Transaction<'t, Postgres> {
type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c;
type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c;
type UpstreamOAuthSessionRepository<'c> = PgUpstreamOAuthSessionRepository<'c> where Self: 'c;
type UserRepository<'c> = PgUserRepository<'c> where Self: 'c;
type UserEmailRepository<'c> = PgUserEmailRepository<'c> where Self: 'c;
type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c;
type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c;
type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c;
type OAuth2AuthorizationGrantRepository<'c> = PgOAuth2AuthorizationGrantRepository<'c> where Self: 'c;
type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c;
type OAuth2AccessTokenRepository<'c> = PgOAuth2AccessTokenRepository<'c> where Self: 'c;
type OAuth2RefreshTokenRepository<'c> = PgOAuth2RefreshTokenRepository<'c> where Self: 'c;
type CompatSessionRepository<'c> = PgCompatSessionRepository<'c> where Self: 'c;
type CompatSsoLoginRepository<'c> = PgCompatSsoLoginRepository<'c> where Self: 'c;
type CompatAccessTokenRepository<'c> = PgCompatAccessTokenRepository<'c> where Self: 'c;
type CompatRefreshTokenRepository<'c> = PgCompatRefreshTokenRepository<'c> where Self: 'c;
fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> {
PgUpstreamOAuthLinkRepository::new(self)
}
fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> {
PgUpstreamOAuthProviderRepository::new(self)
}
fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_> {
PgUpstreamOAuthSessionRepository::new(self)
}
fn user(&mut self) -> Self::UserRepository<'_> {
PgUserRepository::new(self)
}
fn user_email(&mut self) -> Self::UserEmailRepository<'_> {
PgUserEmailRepository::new(self)
}
fn user_password(&mut self) -> Self::UserPasswordRepository<'_> {
PgUserPasswordRepository::new(self)
}
fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_> {
PgBrowserSessionRepository::new(self)
}
fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_> {
PgOAuth2ClientRepository::new(self)
}
fn oauth2_authorization_grant(&mut self) -> Self::OAuth2AuthorizationGrantRepository<'_> {
PgOAuth2AuthorizationGrantRepository::new(self)
}
fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> {
PgOAuth2SessionRepository::new(self)
}
fn oauth2_access_token(&mut self) -> Self::OAuth2AccessTokenRepository<'_> {
PgOAuth2AccessTokenRepository::new(self)
}
fn oauth2_refresh_token(&mut self) -> Self::OAuth2RefreshTokenRepository<'_> {
PgOAuth2RefreshTokenRepository::new(self)
}
fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> {
PgCompatSessionRepository::new(self)
}
fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_> {
PgCompatSsoLoginRepository::new(self)
}
fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_> {
PgCompatAccessTokenRepository::new(self)
}
fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_> {
PgCompatRefreshTokenRepository::new(self)
} }
} }

View File

@ -29,20 +29,20 @@ mod tests {
use sqlx::PgPool; use sqlx::PgPool;
use super::*; use super::*;
use crate::{Clock, Repository}; use crate::{Clock, PgRepository, Repository};
#[sqlx::test(migrator = "crate::MIGRATOR")] #[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_repository(pool: PgPool) -> Result<(), Box<dyn std::error::Error>> { async fn test_repository(pool: PgPool) -> Result<(), Box<dyn std::error::Error>> {
let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42); let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
let clock = Clock::default(); let clock = Clock::default();
let mut conn = pool.acquire().await?; let mut repo = PgRepository::from_pool(&pool).await?;
// The provider list should be empty at the start // The provider list should be empty at the start
let all_providers = conn.upstream_oauth_provider().all().await?; let all_providers = repo.upstream_oauth_provider().all().await?;
assert!(all_providers.is_empty()); assert!(all_providers.is_empty());
// Let's add a provider // Let's add a provider
let provider = conn let provider = repo
.upstream_oauth_provider() .upstream_oauth_provider()
.add( .add(
&mut rng, &mut rng,
@ -57,7 +57,7 @@ mod tests {
.await?; .await?;
// Look it up in the database // Look it up in the database
let provider = conn let provider = repo
.upstream_oauth_provider() .upstream_oauth_provider()
.lookup(provider.id) .lookup(provider.id)
.await? .await?
@ -66,7 +66,7 @@ mod tests {
assert_eq!(provider.client_id, "client-id"); assert_eq!(provider.client_id, "client-id");
// Start a session // Start a session
let session = conn let session = repo
.upstream_oauth_session() .upstream_oauth_session()
.add( .add(
&mut rng, &mut rng,
@ -79,7 +79,7 @@ mod tests {
.await?; .await?;
// Look it up in the database // Look it up in the database
let session = conn let session = repo
.upstream_oauth_session() .upstream_oauth_session()
.lookup(session.id) .lookup(session.id)
.await? .await?
@ -91,19 +91,19 @@ mod tests {
assert!(!session.is_consumed()); assert!(!session.is_consumed());
// Create a link // Create a link
let link = conn let link = repo
.upstream_oauth_link() .upstream_oauth_link()
.add(&mut rng, &clock, &provider, "a-subject".to_owned()) .add(&mut rng, &clock, &provider, "a-subject".to_owned())
.await?; .await?;
// We can look it up by its ID // We can look it up by its ID
conn.upstream_oauth_link() repo.upstream_oauth_link()
.lookup(link.id) .lookup(link.id)
.await? .await?
.expect("link to be found in database"); .expect("link to be found in database");
// or by its subject // or by its subject
let link = conn let link = repo
.upstream_oauth_link() .upstream_oauth_link()
.find_by_subject(&provider, "a-subject") .find_by_subject(&provider, "a-subject")
.await? .await?
@ -111,7 +111,7 @@ mod tests {
assert_eq!(link.subject, "a-subject"); assert_eq!(link.subject, "a-subject");
assert_eq!(link.provider_id, provider.id); assert_eq!(link.provider_id, provider.id);
let session = conn let session = repo
.upstream_oauth_session() .upstream_oauth_session()
.complete_with_link(&clock, session, &link, None) .complete_with_link(&clock, session, &link, None)
.await?; .await?;
@ -119,7 +119,7 @@ mod tests {
assert!(!session.is_consumed()); assert!(!session.is_consumed());
assert_eq!(session.link_id(), Some(link.id)); assert_eq!(session.link_id(), Some(link.id));
let session = conn let session = repo
.upstream_oauth_session() .upstream_oauth_session()
.consume(&clock, session) .consume(&clock, session)
.await?; .await?;

View File

@ -14,7 +14,7 @@
//! Database-related tasks //! Database-related tasks
use mas_storage::{oauth2::OAuth2AccessTokenRepository, Clock, Repository}; use mas_storage::{oauth2::OAuth2AccessTokenRepository, Clock, PgRepository, Repository};
use sqlx::{Pool, Postgres}; use sqlx::{Pool, Postgres};
use tracing::{debug, error, info}; use tracing::{debug, error, info};
@ -33,8 +33,8 @@ impl std::fmt::Debug for CleanupExpired {
impl Task for CleanupExpired { impl Task for CleanupExpired {
async fn run(&self) { async fn run(&self) {
let res = async move { let res = async move {
let mut conn = self.0.acquire().await?; let mut repo = PgRepository::from_pool(&self.0).await?;
conn.oauth2_access_token().cleanup_expired(&self.1).await repo.oauth2_access_token().cleanup_expired(&self.1).await
} }
.await; .await;