1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Box the repository everywhere

This commit is contained in:
Quentin Gliech
2023-01-20 17:49:16 +01:00
parent f4c64c2171
commit a9facab131
49 changed files with 296 additions and 296 deletions

5
Cargo.lock generated
View File

@ -2804,11 +2804,10 @@ dependencies = [
"chrono", "chrono",
"mas-data-model", "mas-data-model",
"mas-storage", "mas-storage",
"mas-storage-pg",
"oauth2-types", "oauth2-types",
"serde", "serde",
"sqlx",
"thiserror", "thiserror",
"tokio",
"tracing", "tracing",
"ulid", "ulid",
"url", "url",
@ -3101,6 +3100,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"chrono", "chrono",
"futures-util",
"mas-data-model", "mas-data-model",
"mas-iana", "mas-iana",
"mas-jose", "mas-jose",
@ -3117,6 +3117,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"chrono", "chrono",
"futures-util",
"mas-data-model", "mas-data-model",
"mas-iana", "mas-iana",
"mas-jose", "mas-jose",

View File

@ -72,10 +72,10 @@ pub enum Credentials {
} }
impl Credentials { impl Credentials {
pub async fn fetch<'r, R>(&self, repo: &'r mut R) -> Result<Option<Client>, R::Error> pub async fn fetch<E>(
where &self,
R: Repository, repo: &mut (impl Repository<Error = E> + ?Sized),
{ ) -> Result<Option<Client>, E> {
let client_id = match self { let client_id = match self {
Credentials::None { client_id } Credentials::None { client_id }
| Credentials::ClientSecretBasic { client_id, .. } | Credentials::ClientSecretBasic { client_id, .. }

View File

@ -43,10 +43,10 @@ impl SessionInfo {
} }
/// Load the [`BrowserSession`] from database /// Load the [`BrowserSession`] from database
pub async fn load_session<R: Repository>( pub async fn load_session<E>(
&self, &self,
repo: &mut R, repo: &mut (impl Repository<Error = E> + ?Sized),
) -> Result<Option<BrowserSession>, R::Error> { ) -> Result<Option<BrowserSession>, E> {
let session_id = if let Some(id) = self.current { let session_id = if let Some(id) = self.current {
id id
} else { } else {

View File

@ -51,11 +51,10 @@ enum AccessToken {
} }
impl AccessToken { impl AccessToken {
async fn fetch<R: Repository>( async fn fetch<E>(
&self, &self,
repo: &mut R, repo: &mut (impl Repository<Error = E> + ?Sized),
) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError<R::Error>> ) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError<E>> {
{
let token = match self { let token = match self {
AccessToken::Form(t) | AccessToken::Header(t) => t, AccessToken::Form(t) | AccessToken::Header(t) => t,
AccessToken::None => return Err(AuthorizationVerificationError::MissingToken), AccessToken::None => return Err(AuthorizationVerificationError::MissingToken),
@ -85,11 +84,11 @@ pub struct UserAuthorization<F = ()> {
impl<F: Send> UserAuthorization<F> { impl<F: Send> UserAuthorization<F> {
// TODO: take scopes to validate as parameter // TODO: take scopes to validate as parameter
pub async fn protected_form<R: Repository, C: Clock>( pub async fn protected_form<E>(
self, self,
repo: &mut R, repo: &mut (impl Repository<Error = E> + ?Sized),
clock: &C, clock: &impl Clock,
) -> Result<(Session, F), AuthorizationVerificationError<R::Error>> { ) -> Result<(Session, F), AuthorizationVerificationError<E>> {
let form = match self.form { let form = match self.form {
Some(f) => f, Some(f) => f,
None => return Err(AuthorizationVerificationError::MissingForm), None => return Err(AuthorizationVerificationError::MissingForm),
@ -105,11 +104,11 @@ impl<F: Send> UserAuthorization<F> {
} }
// TODO: take scopes to validate as parameter // TODO: take scopes to validate as parameter
pub async fn protected<R: Repository, C: Clock>( pub async fn protected<E>(
self, self,
repo: &mut R, repo: &mut (impl Repository<Error = E> + ?Sized),
clock: &C, clock: &impl Clock,
) -> Result<Session, AuthorizationVerificationError<R::Error>> { ) -> Result<Session, AuthorizationVerificationError<E>> {
let (token, session) = self.access_token.fetch(repo).await?; let (token, session) = self.access_token.fetch(repo).await?;
if !token.is_valid(clock.now()) || !session.is_valid() { if !token.is_valid(clock.now()) || !session.is_valid() {

View File

@ -203,7 +203,7 @@ impl Options {
let pool = database_from_config(&database_config).await?; let pool = database_from_config(&database_config).await?;
let password_manager = password_manager_from_config(&passwords_config).await?; let password_manager = password_manager_from_config(&passwords_config).await?;
let mut repo = PgRepository::from_pool(&pool).await?; let mut repo = PgRepository::from_pool(&pool).await?.boxed();
let user = repo let user = repo
.user() .user()
.find_by_username(username) .find_by_username(username)
@ -234,7 +234,7 @@ impl Options {
let config: DatabaseConfig = root.load_config()?; let config: DatabaseConfig = root.load_config()?;
let pool = database_from_config(&config).await?; let pool = database_from_config(&config).await?;
let mut repo = PgRepository::from_pool(&pool).await?; let mut repo = PgRepository::from_pool(&pool).await?.boxed();
let user = repo let user = repo
.user() .user()
@ -262,7 +262,7 @@ impl Options {
let pool = database_from_config(&config.database).await?; let pool = database_from_config(&config.database).await?;
let encrypter = config.secrets.encrypter(); let encrypter = config.secrets.encrypter();
let mut repo = PgRepository::from_pool(&pool).await?; let mut repo = PgRepository::from_pool(&pool).await?.boxed();
for client in config.clients.iter() { for client in config.clients.iter() {
let client_id = client.client_id; let client_id = client.client_id;

View File

@ -102,7 +102,7 @@ impl Options {
watch_templates(&templates).await?; watch_templates(&templates).await?;
} }
let graphql_schema = mas_handlers::graphql_schema(&pool); let graphql_schema = mas_handlers::graphql_schema();
// Maximum 50 outgoing HTTP requests at a time // Maximum 50 outgoing HTTP requests at a time
let http_client_factory = HttpClientFactory::new(50); let http_client_factory = HttpClientFactory::new(50);

View File

@ -10,7 +10,7 @@ anyhow = "1.0.68"
async-graphql = { version = "5.0.4", features = ["chrono", "url"] } async-graphql = { version = "5.0.4", features = ["chrono", "url"] }
chrono = "0.4.23" chrono = "0.4.23"
serde = { version = "1.0.152", features = ["derive"] } serde = { version = "1.0.152", features = ["derive"] }
sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres"] } tokio = { version = "1.23.0", features = ["sync"] }
thiserror = "1.0.38" thiserror = "1.0.38"
tracing = "0.1.37" tracing = "0.1.37"
ulid = "1.0.0" ulid = "1.0.0"
@ -19,7 +19,6 @@ url = "2.3.1"
oauth2-types = { path = "../oauth2-types" } oauth2-types = { path = "../oauth2-types" }
mas-data-model = { path = "../data-model" } mas-data-model = { path = "../data-model" }
mas-storage = { path = "../storage" } mas-storage = { path = "../storage" }
mas-storage-pg = { path = "../storage-pg" }
[[bin]] [[bin]]
name = "schema" name = "schema"

View File

@ -34,11 +34,10 @@ use mas_storage::{
oauth2::OAuth2ClientRepository, oauth2::OAuth2ClientRepository,
upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository}, upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository},
user::{BrowserSessionRepository, UserEmailRepository}, user::{BrowserSessionRepository, UserEmailRepository},
Pagination, Repository, BoxRepository, Pagination,
}; };
use mas_storage_pg::PgRepository;
use model::CreationEvent; use model::CreationEvent;
use sqlx::PgPool; use tokio::sync::Mutex;
use self::model::{ use self::model::{
BrowserSession, Cursor, Node, NodeCursor, NodeType, OAuth2Client, UpstreamOAuth2Link, BrowserSession, Cursor, Node, NodeCursor, NodeType, OAuth2Client, UpstreamOAuth2Link,
@ -94,7 +93,7 @@ impl RootQuery {
id: ID, id: ID,
) -> Result<Option<OAuth2Client>, async_graphql::Error> { ) -> Result<Option<OAuth2Client>, async_graphql::Error> {
let id = NodeType::OAuth2Client.extract_ulid(&id)?; let id = NodeType::OAuth2Client.extract_ulid(&id)?;
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?; let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let client = repo.oauth2_client().lookup(id).await?; let client = repo.oauth2_client().lookup(id).await?;
@ -124,7 +123,7 @@ impl RootQuery {
) -> Result<Option<BrowserSession>, async_graphql::Error> { ) -> Result<Option<BrowserSession>, async_graphql::Error> {
let id = NodeType::BrowserSession.extract_ulid(&id)?; let id = NodeType::BrowserSession.extract_ulid(&id)?;
let session = ctx.data_opt::<mas_data_model::BrowserSession>().cloned(); let session = ctx.data_opt::<mas_data_model::BrowserSession>().cloned();
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?; let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let Some(session) = session else { return Ok(None) }; let Some(session) = session else { return Ok(None) };
let current_user = session.user; let current_user = session.user;
@ -150,7 +149,7 @@ impl RootQuery {
) -> Result<Option<UserEmail>, async_graphql::Error> { ) -> Result<Option<UserEmail>, async_graphql::Error> {
let id = NodeType::UserEmail.extract_ulid(&id)?; let id = NodeType::UserEmail.extract_ulid(&id)?;
let session = ctx.data_opt::<mas_data_model::BrowserSession>().cloned(); let session = ctx.data_opt::<mas_data_model::BrowserSession>().cloned();
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?; let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let Some(session) = session else { return Ok(None) }; let Some(session) = session else { return Ok(None) };
let current_user = session.user; let current_user = session.user;
@ -172,7 +171,7 @@ impl RootQuery {
) -> Result<Option<UpstreamOAuth2Link>, async_graphql::Error> { ) -> Result<Option<UpstreamOAuth2Link>, async_graphql::Error> {
let id = NodeType::UpstreamOAuth2Link.extract_ulid(&id)?; let id = NodeType::UpstreamOAuth2Link.extract_ulid(&id)?;
let session = ctx.data_opt::<mas_data_model::BrowserSession>().cloned(); let session = ctx.data_opt::<mas_data_model::BrowserSession>().cloned();
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?; let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let Some(session) = session else { return Ok(None) }; let Some(session) = session else { return Ok(None) };
let current_user = session.user; let current_user = session.user;
@ -192,7 +191,7 @@ impl RootQuery {
id: ID, id: ID,
) -> Result<Option<UpstreamOAuth2Provider>, async_graphql::Error> { ) -> Result<Option<UpstreamOAuth2Provider>, async_graphql::Error> {
let id = NodeType::UpstreamOAuth2Provider.extract_ulid(&id)?; let id = NodeType::UpstreamOAuth2Provider.extract_ulid(&id)?;
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?; let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let provider = repo.upstream_oauth_provider().lookup(id).await?; let provider = repo.upstream_oauth_provider().lookup(id).await?;
@ -211,7 +210,7 @@ impl RootQuery {
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>, #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, UpstreamOAuth2Provider>, async_graphql::Error> { ) -> Result<Connection<Cursor, UpstreamOAuth2Provider>, async_graphql::Error> {
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?; let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
query( query(
after, after,

View File

@ -15,9 +15,8 @@
use anyhow::Context as _; use anyhow::Context as _;
use async_graphql::{Context, Description, Object, ID}; use async_graphql::{Context, Description, Object, ID};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use mas_storage::{compat::CompatSessionRepository, user::UserRepository, Repository}; use mas_storage::{compat::CompatSessionRepository, user::UserRepository, BoxRepository};
use mas_storage_pg::PgRepository; use tokio::sync::Mutex;
use sqlx::PgPool;
use url::Url; use url::Url;
use super::{NodeType, User}; use super::{NodeType, User};
@ -36,7 +35,7 @@ impl CompatSession {
/// The user authorized for this session. /// The user authorized for this session.
async fn user(&self, ctx: &Context<'_>) -> Result<User, async_graphql::Error> { async fn user(&self, ctx: &Context<'_>) -> Result<User, async_graphql::Error> {
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?; let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let user = repo let user = repo
.user() .user()
.lookup(self.0.user_id) .lookup(self.0.user_id)
@ -101,7 +100,7 @@ impl CompatSsoLogin {
) -> Result<Option<CompatSession>, async_graphql::Error> { ) -> Result<Option<CompatSession>, async_graphql::Error> {
let Some(session_id) = self.0.session_id() else { return Ok(None) }; let Some(session_id) = self.0.session_id() else { return Ok(None) };
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?; let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let session = repo let session = repo
.compat_session() .compat_session()
.lookup(session_id) .lookup(session_id)

View File

@ -14,10 +14,9 @@
use anyhow::Context as _; use anyhow::Context as _;
use async_graphql::{Context, Description, Object, ID}; use async_graphql::{Context, Description, Object, ID};
use mas_storage::{oauth2::OAuth2ClientRepository, user::BrowserSessionRepository, Repository}; use mas_storage::{oauth2::OAuth2ClientRepository, user::BrowserSessionRepository, BoxRepository};
use mas_storage_pg::PgRepository;
use oauth2_types::scope::Scope; use oauth2_types::scope::Scope;
use sqlx::PgPool; use tokio::sync::Mutex;
use ulid::Ulid; use ulid::Ulid;
use url::Url; use url::Url;
@ -37,7 +36,7 @@ impl OAuth2Session {
/// OAuth 2.0 client used by this session. /// OAuth 2.0 client used by this session.
pub async fn client(&self, ctx: &Context<'_>) -> Result<OAuth2Client, async_graphql::Error> { pub async fn client(&self, ctx: &Context<'_>) -> Result<OAuth2Client, async_graphql::Error> {
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?; let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let client = repo let client = repo
.oauth2_client() .oauth2_client()
.lookup(self.0.client_id) .lookup(self.0.client_id)
@ -57,7 +56,7 @@ impl OAuth2Session {
&self, &self,
ctx: &Context<'_>, ctx: &Context<'_>,
) -> Result<BrowserSession, async_graphql::Error> { ) -> Result<BrowserSession, async_graphql::Error> {
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?; let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let browser_session = repo let browser_session = repo
.browser_session() .browser_session()
.lookup(self.0.user_session_id) .lookup(self.0.user_session_id)
@ -69,7 +68,7 @@ impl OAuth2Session {
/// User authorized for this session. /// User authorized for this session.
pub async fn user(&self, ctx: &Context<'_>) -> Result<User, async_graphql::Error> { pub async fn user(&self, ctx: &Context<'_>) -> Result<User, async_graphql::Error> {
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?; let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let browser_session = repo let browser_session = repo
.browser_session() .browser_session()
.lookup(self.0.user_session_id) .lookup(self.0.user_session_id)
@ -139,7 +138,7 @@ impl OAuth2Consent {
/// OAuth 2.0 client for which the user granted access. /// OAuth 2.0 client for which the user granted access.
pub async fn client(&self, ctx: &Context<'_>) -> Result<OAuth2Client, async_graphql::Error> { pub async fn client(&self, ctx: &Context<'_>) -> Result<OAuth2Client, async_graphql::Error> {
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?; let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let client = repo let client = repo
.oauth2_client() .oauth2_client()
.lookup(self.client_id) .lookup(self.client_id)

View File

@ -16,10 +16,9 @@ use anyhow::Context as _;
use async_graphql::{Context, Object, ID}; use async_graphql::{Context, Object, ID};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use mas_storage::{ use mas_storage::{
upstream_oauth2::UpstreamOAuthProviderRepository, user::UserRepository, Repository, upstream_oauth2::UpstreamOAuthProviderRepository, user::UserRepository, BoxRepository,
}; };
use mas_storage_pg::PgRepository; use tokio::sync::Mutex;
use sqlx::PgPool;
use super::{NodeType, User}; use super::{NodeType, User};
@ -103,7 +102,7 @@ impl UpstreamOAuth2Link {
provider.clone() provider.clone()
} else { } else {
// Fetch on-the-fly // Fetch on-the-fly
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?; let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let provider = repo let provider = repo
.upstream_oauth_provider() .upstream_oauth_provider()
.lookup(self.link.provider_id) .lookup(self.link.provider_id)
@ -122,7 +121,7 @@ impl UpstreamOAuth2Link {
user.clone() user.clone()
} else if let Some(user_id) = &self.link.user_id { } else if let Some(user_id) = &self.link.user_id {
// Fetch on-the-fly // Fetch on-the-fly
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?; let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let user = repo let user = repo
.user() .user()
.lookup(*user_id) .lookup(*user_id)

View File

@ -22,10 +22,9 @@ use mas_storage::{
oauth2::OAuth2SessionRepository, oauth2::OAuth2SessionRepository,
upstream_oauth2::UpstreamOAuthLinkRepository, upstream_oauth2::UpstreamOAuthLinkRepository,
user::{BrowserSessionRepository, UserEmailRepository}, user::{BrowserSessionRepository, UserEmailRepository},
Pagination, Repository, BoxRepository, Pagination,
}; };
use mas_storage_pg::PgRepository; use tokio::sync::Mutex;
use sqlx::PgPool;
use super::{ use super::{
compat_sessions::CompatSsoLogin, BrowserSession, Cursor, NodeCursor, NodeType, OAuth2Session, compat_sessions::CompatSsoLogin, BrowserSession, Cursor, NodeCursor, NodeType, OAuth2Session,
@ -65,10 +64,9 @@ impl User {
&self, &self,
ctx: &Context<'_>, ctx: &Context<'_>,
) -> Result<Option<UserEmail>, async_graphql::Error> { ) -> Result<Option<UserEmail>, async_graphql::Error> {
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?; let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let mut user_email_repo = repo.user_email(); let mut user_email_repo = repo.user_email();
Ok(user_email_repo.get_primary(&self.0).await?.map(UserEmail)) Ok(user_email_repo.get_primary(&self.0).await?.map(UserEmail))
} }
@ -84,7 +82,7 @@ impl User {
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>, #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, CompatSsoLogin>, async_graphql::Error> { ) -> Result<Connection<Cursor, CompatSsoLogin>, async_graphql::Error> {
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?; let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
query( query(
after, after,
@ -131,7 +129,7 @@ impl User {
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>, #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, BrowserSession>, async_graphql::Error> { ) -> Result<Connection<Cursor, BrowserSession>, async_graphql::Error> {
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?; let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
query( query(
after, after,
@ -178,7 +176,7 @@ impl User {
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>, #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, UserEmail, UserEmailsPagination>, async_graphql::Error> { ) -> Result<Connection<Cursor, UserEmail, UserEmailsPagination>, async_graphql::Error> {
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?; let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
query( query(
after, after,
@ -229,7 +227,7 @@ impl User {
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>, #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, OAuth2Session>, async_graphql::Error> { ) -> Result<Connection<Cursor, OAuth2Session>, async_graphql::Error> {
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?; let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
query( query(
after, after,
@ -276,7 +274,7 @@ impl User {
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>, #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, UpstreamOAuth2Link>, async_graphql::Error> { ) -> Result<Connection<Cursor, UpstreamOAuth2Link>, async_graphql::Error> {
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?; let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
query( query(
after, after,
@ -350,7 +348,7 @@ pub struct UserEmailsPagination(mas_data_model::User);
impl UserEmailsPagination { impl UserEmailsPagination {
/// Identifies the total count of items in the connection. /// Identifies the total count of items in the connection.
async fn total_count(&self, ctx: &Context<'_>) -> Result<usize, async_graphql::Error> { async fn total_count(&self, ctx: &Context<'_>) -> Result<usize, async_graphql::Error> {
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?; let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let count = repo.user_email().count(&self.0).await?; let count = repo.user_email().count(&self.0).await?;
Ok(count) Ok(count)
} }

View File

@ -25,7 +25,7 @@ use mas_email::Mailer;
use mas_keystore::{Encrypter, Keystore}; use mas_keystore::{Encrypter, Keystore};
use mas_policy::PolicyFactory; use mas_policy::PolicyFactory;
use mas_router::UrlBuilder; use mas_router::UrlBuilder;
use mas_storage::{BoxClock, BoxRng, SystemClock}; use mas_storage::{BoxClock, BoxRepository, BoxRng, Repository, SystemClock};
use mas_storage_pg::PgRepository; use mas_storage_pg::PgRepository;
use mas_templates::Templates; use mas_templates::Templates;
use rand::SeedableRng; use rand::SeedableRng;
@ -156,7 +156,7 @@ impl IntoResponse for RepositoryError {
} }
#[async_trait] #[async_trait]
impl FromRequestParts<AppState> for PgRepository { impl FromRequestParts<AppState> for BoxRepository {
type Rejection = RepositoryError; type Rejection = RepositoryError;
async fn from_request_parts( async fn from_request_parts(
@ -164,6 +164,8 @@ impl FromRequestParts<AppState> for PgRepository {
state: &AppState, state: &AppState,
) -> Result<Self, Self::Rejection> { ) -> Result<Self, Self::Rejection> {
let repo = PgRepository::from_pool(&state.pool).await?; let repo = PgRepository::from_pool(&state.pool).await?;
Ok(repo) Ok(repo
.map_err(mas_storage::RepositoryError::from_error)
.boxed())
} }
} }

View File

@ -22,9 +22,8 @@ use mas_storage::{
CompatSsoLoginRepository, CompatSsoLoginRepository,
}, },
user::{UserPasswordRepository, UserRepository}, user::{UserPasswordRepository, UserRepository},
BoxClock, BoxRng, Clock, Repository, BoxClock, BoxRepository, BoxRng, Clock,
}; };
use mas_storage_pg::PgRepository;
use rand::{CryptoRng, RngCore}; use rand::{CryptoRng, RngCore};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_with::{serde_as, skip_serializing_none, DurationMilliSeconds}; use serde_with::{serde_as, skip_serializing_none, DurationMilliSeconds};
@ -154,7 +153,7 @@ pub enum RouteError {
InvalidLoginToken, InvalidLoginToken,
} }
impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError { impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response { fn into_response(self) -> axum::response::Response {
@ -196,7 +195,7 @@ pub(crate) async fn post(
mut rng: BoxRng, mut rng: BoxRng,
clock: BoxClock, clock: BoxClock,
State(password_manager): State<PasswordManager>, State(password_manager): State<PasswordManager>,
mut repo: PgRepository, mut repo: BoxRepository,
State(homeserver): State<MatrixHomeserver>, State(homeserver): State<MatrixHomeserver>,
Json(input): Json<RequestBody>, Json(input): Json<RequestBody>,
) -> Result<impl IntoResponse, RouteError> { ) -> Result<impl IntoResponse, RouteError> {
@ -262,7 +261,7 @@ pub(crate) async fn post(
} }
async fn token_login( async fn token_login(
repo: &mut PgRepository, repo: &mut BoxRepository,
clock: &dyn Clock, clock: &dyn Clock,
token: &str, token: &str,
) -> Result<(CompatSession, User), RouteError> { ) -> Result<(CompatSession, User), RouteError> {
@ -331,7 +330,7 @@ async fn user_password_login(
mut rng: &mut (impl RngCore + CryptoRng + Send), mut rng: &mut (impl RngCore + CryptoRng + Send),
clock: &impl Clock, clock: &impl Clock,
password_manager: &PasswordManager, password_manager: &PasswordManager,
repo: &mut PgRepository, repo: &mut BoxRepository,
username: String, username: String,
password: String, password: String,
) -> Result<(CompatSession, User), RouteError> { ) -> Result<(CompatSession, User), RouteError> {

View File

@ -31,9 +31,8 @@ use mas_keystore::Encrypter;
use mas_router::{CompatLoginSsoAction, PostAuthAction, Route}; use mas_router::{CompatLoginSsoAction, PostAuthAction, Route};
use mas_storage::{ use mas_storage::{
compat::{CompatSessionRepository, CompatSsoLoginRepository}, compat::{CompatSessionRepository, CompatSsoLoginRepository},
BoxClock, BoxRng, Clock, Repository, BoxClock, BoxRepository, BoxRng, Clock,
}; };
use mas_storage_pg::PgRepository;
use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates}; use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use ulid::Ulid; use ulid::Ulid;
@ -55,7 +54,7 @@ pub struct Params {
pub async fn get( pub async fn get(
mut rng: BoxRng, mut rng: BoxRng,
clock: BoxClock, clock: BoxClock,
mut repo: PgRepository, mut repo: BoxRepository,
State(templates): State<Templates>, State(templates): State<Templates>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Path(id): Path<Ulid>, Path(id): Path<Ulid>,
@ -64,7 +63,7 @@ pub async fn get(
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let maybe_session = session_info.load_session(&mut repo).await?; let maybe_session = session_info.load_session(&mut *repo).await?;
let session = if let Some(session) = maybe_session { let session = if let Some(session) = maybe_session {
session session
@ -117,7 +116,7 @@ pub async fn get(
pub async fn post( pub async fn post(
mut rng: BoxRng, mut rng: BoxRng,
clock: BoxClock, clock: BoxClock,
mut repo: PgRepository, mut repo: BoxRepository,
State(templates): State<Templates>, State(templates): State<Templates>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Path(id): Path<Ulid>, Path(id): Path<Ulid>,
@ -127,7 +126,7 @@ pub async fn post(
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
cookie_jar.verify_form(&clock, form)?; cookie_jar.verify_form(&clock, form)?;
let maybe_session = session_info.load_session(&mut repo).await?; let maybe_session = session_info.load_session(&mut *repo).await?;
let session = if let Some(session) = maybe_session { let session = if let Some(session) = maybe_session {
session session

View File

@ -19,8 +19,7 @@ use axum::{
}; };
use hyper::StatusCode; use hyper::StatusCode;
use mas_router::{CompatLoginSsoAction, CompatLoginSsoComplete, UrlBuilder}; use mas_router::{CompatLoginSsoAction, CompatLoginSsoComplete, UrlBuilder};
use mas_storage::{compat::CompatSsoLoginRepository, BoxClock, BoxRng, Repository}; use mas_storage::{compat::CompatSsoLoginRepository, BoxClock, BoxRepository, BoxRng};
use mas_storage_pg::PgRepository;
use rand::distributions::{Alphanumeric, DistString}; use rand::distributions::{Alphanumeric, DistString};
use serde::Deserialize; use serde::Deserialize;
use serde_with::serde; use serde_with::serde;
@ -48,7 +47,7 @@ pub enum RouteError {
InvalidRedirectUrl, InvalidRedirectUrl,
} }
impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError { impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response { fn into_response(self) -> axum::response::Response {
@ -59,7 +58,7 @@ impl IntoResponse for RouteError {
pub async fn get( pub async fn get(
mut rng: BoxRng, mut rng: BoxRng,
clock: BoxClock, clock: BoxClock,
mut repo: PgRepository, mut repo: BoxRepository,
State(url_builder): State<UrlBuilder>, State(url_builder): State<UrlBuilder>,
Query(params): Query<Params>, Query(params): Query<Params>,
) -> Result<impl IntoResponse, RouteError> { ) -> Result<impl IntoResponse, RouteError> {

View File

@ -18,9 +18,8 @@ use hyper::StatusCode;
use mas_data_model::TokenType; use mas_data_model::TokenType;
use mas_storage::{ use mas_storage::{
compat::{CompatAccessTokenRepository, CompatSessionRepository}, compat::{CompatAccessTokenRepository, CompatSessionRepository},
BoxClock, Clock, Repository, BoxClock, BoxRepository, Clock,
}; };
use mas_storage_pg::PgRepository;
use thiserror::Error; use thiserror::Error;
use super::MatrixError; use super::MatrixError;
@ -41,7 +40,7 @@ pub enum RouteError {
InvalidAuthorization, InvalidAuthorization,
} }
impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError { impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response { fn into_response(self) -> axum::response::Response {
@ -68,7 +67,7 @@ impl IntoResponse for RouteError {
pub(crate) async fn post( pub(crate) async fn post(
clock: BoxClock, clock: BoxClock,
mut repo: PgRepository, mut repo: BoxRepository,
maybe_authorization: Option<TypedHeader<Authorization<Bearer>>>, maybe_authorization: Option<TypedHeader<Authorization<Bearer>>>,
) -> Result<impl IntoResponse, RouteError> { ) -> Result<impl IntoResponse, RouteError> {
let TypedHeader(authorization) = maybe_authorization.ok_or(RouteError::MissingAuthorization)?; let TypedHeader(authorization) = maybe_authorization.ok_or(RouteError::MissingAuthorization)?;

View File

@ -18,9 +18,8 @@ use hyper::StatusCode;
use mas_data_model::{TokenFormatError, TokenType}; use mas_data_model::{TokenFormatError, TokenType};
use mas_storage::{ use mas_storage::{
compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository}, compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository},
BoxClock, BoxRng, Clock, Repository, BoxClock, BoxRepository, BoxRng, Clock,
}; };
use mas_storage_pg::PgRepository;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_with::{serde_as, DurationMilliSeconds}; use serde_with::{serde_as, DurationMilliSeconds};
use thiserror::Error; use thiserror::Error;
@ -69,7 +68,7 @@ impl IntoResponse for RouteError {
} }
} }
impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl_from_error_for_route!(mas_storage::RepositoryError);
impl From<TokenFormatError> for RouteError { impl From<TokenFormatError> for RouteError {
fn from(_e: TokenFormatError) -> Self { fn from(_e: TokenFormatError) -> Self {
@ -89,7 +88,7 @@ pub struct ResponseBody {
pub(crate) async fn post( pub(crate) async fn post(
mut rng: BoxRng, mut rng: BoxRng,
clock: BoxClock, clock: BoxClock,
mut repo: PgRepository, mut repo: BoxRepository,
Json(input): Json<RequestBody>, Json(input): Json<RequestBody>,
) -> Result<impl IntoResponse, RouteError> { ) -> Result<impl IntoResponse, RouteError> {
let token_type = TokenType::check(&input.refresh_token)?; let token_type = TokenType::check(&input.refresh_token)?;

View File

@ -22,20 +22,19 @@ use axum::{
Json, TypedHeader, Json, TypedHeader,
}; };
use axum_extra::extract::PrivateCookieJar; use axum_extra::extract::PrivateCookieJar;
use futures_util::{StreamExt, TryStreamExt}; use futures_util::TryStreamExt;
use headers::{ContentType, HeaderValue}; use headers::{ContentType, HeaderValue};
use hyper::header::CACHE_CONTROL; use hyper::header::CACHE_CONTROL;
use mas_axum_utils::{FancyError, SessionInfoExt}; use mas_axum_utils::{FancyError, SessionInfoExt};
use mas_graphql::Schema; use mas_graphql::Schema;
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_storage_pg::PgRepository; use mas_storage::BoxRepository;
use sqlx::PgPool; use tokio::sync::Mutex;
use tracing::{info_span, Instrument}; use tracing::{info_span, Instrument};
#[must_use] #[must_use]
pub fn schema(pool: &PgPool) -> Schema { pub fn schema() -> Schema {
mas_graphql::schema_builder() mas_graphql::schema_builder()
.data(pool.clone())
.extension(Tracing) .extension(Tracing)
.extension(ApolloTracing) .extension(ApolloTracing)
.finish() .finish()
@ -59,8 +58,8 @@ fn span_for_graphql_request(request: &async_graphql::Request) -> tracing::Span {
} }
pub async fn post( pub async fn post(
State(pool): State<PgPool>,
State(schema): State<Schema>, State(schema): State<Schema>,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
content_type: Option<TypedHeader<ContentType>>, content_type: Option<TypedHeader<ContentType>>,
body: BodyStream, body: BodyStream,
@ -68,62 +67,46 @@ pub async fn post(
let content_type = content_type.map(|TypedHeader(h)| h.to_string()); let content_type = content_type.map(|TypedHeader(h)| h.to_string());
let (session_info, _cookie_jar) = cookie_jar.session_info(); let (session_info, _cookie_jar) = cookie_jar.session_info();
let mut repo = PgRepository::from_pool(&pool).await?; let maybe_session = session_info.load_session(&mut *repo).await?;
let maybe_session = session_info.load_session(&mut repo).await?;
repo.cancel().await?;
let mut request = async_graphql::http::receive_batch_body( let mut request = async_graphql::http::receive_body(
content_type, content_type,
body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)) body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
.into_async_read(), .into_async_read(),
MultipartOptions::default(), MultipartOptions::default(),
) )
.await?; // XXX: this should probably return another error response? .await? // XXX: this should probably return another error response?
.data(Mutex::new(repo));
if let Some(session) = maybe_session { if let Some(session) = maybe_session {
request = request.data(session); request = request.data(session);
} }
let response = match request { let span = span_for_graphql_request(&request);
async_graphql::BatchRequest::Single(request) => { let response = schema.execute(request).instrument(span).await;
let span = span_for_graphql_request(&request);
let response = schema.execute(request).instrument(span).await;
async_graphql::BatchResponse::Single(response)
}
async_graphql::BatchRequest::Batch(requests) => async_graphql::BatchResponse::Batch(
futures_util::stream::iter(requests.into_iter())
.then(|request| {
let span = span_for_graphql_request(&request);
schema.execute(request).instrument(span)
})
.collect()
.await,
),
};
let cache_control = response let cache_control = response
.cache_control() .cache_control
.value() .value()
.and_then(|v| HeaderValue::from_str(&v).ok()) .and_then(|v| HeaderValue::from_str(&v).ok())
.map(|h| [(CACHE_CONTROL, h)]); .map(|h| [(CACHE_CONTROL, h)]);
let headers = response.http_headers(); let headers = response.http_headers.clone();
Ok((headers, cache_control, Json(response))) Ok((headers, cache_control, Json(response)))
} }
pub async fn get( pub async fn get(
State(pool): State<PgPool>,
State(schema): State<Schema>, State(schema): State<Schema>,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
RawQuery(query): RawQuery, RawQuery(query): RawQuery,
) -> Result<impl IntoResponse, FancyError> { ) -> Result<impl IntoResponse, FancyError> {
let (session_info, _cookie_jar) = cookie_jar.session_info(); let (session_info, _cookie_jar) = cookie_jar.session_info();
let mut repo = PgRepository::from_pool(&pool).await?; let maybe_session = session_info.load_session(&mut *repo).await?;
let maybe_session = session_info.load_session(&mut repo).await?;
repo.cancel().await?;
let mut request = async_graphql::http::parse_query_string(&query.unwrap_or_default())?; let mut request =
async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(Mutex::new(repo));
if let Some(session) = maybe_session { if let Some(session) = maybe_session {
request = request.data(session); request = request.data(session);

View File

@ -43,8 +43,7 @@ use mas_http::CorsLayerExt;
use mas_keystore::{Encrypter, Keystore}; use mas_keystore::{Encrypter, Keystore};
use mas_policy::PolicyFactory; use mas_policy::PolicyFactory;
use mas_router::{Route, UrlBuilder}; use mas_router::{Route, UrlBuilder};
use mas_storage::{BoxClock, BoxRng}; use mas_storage::{BoxClock, BoxRepository, BoxRng};
use mas_storage_pg::PgRepository;
use mas_templates::{ErrorContext, Templates}; use mas_templates::{ErrorContext, Templates};
use passwords::PasswordManager; use passwords::PasswordManager;
use sqlx::PgPool; use sqlx::PgPool;
@ -98,7 +97,7 @@ where
<B as HttpBody>::Error: std::error::Error + Send + Sync, <B as HttpBody>::Error: std::error::Error + Send + Sync,
S: Clone + Send + Sync + 'static, S: Clone + Send + Sync + 'static,
mas_graphql::Schema: FromRef<S>, mas_graphql::Schema: FromRef<S>,
PgPool: FromRef<S>, BoxRepository: FromRequestParts<S>,
Encrypter: FromRef<S>, Encrypter: FromRef<S>,
{ {
let mut router = Router::new().route( let mut router = Router::new().route(
@ -158,7 +157,7 @@ where
Keystore: FromRef<S>, Keystore: FromRef<S>,
UrlBuilder: FromRef<S>, UrlBuilder: FromRef<S>,
Arc<PolicyFactory>: FromRef<S>, Arc<PolicyFactory>: FromRef<S>,
PgRepository: FromRequestParts<S>, BoxRepository: FromRequestParts<S>,
Encrypter: FromRef<S>, Encrypter: FromRef<S>,
HttpClientFactory: FromRef<S>, HttpClientFactory: FromRef<S>,
BoxClock: FromRequestParts<S>, BoxClock: FromRequestParts<S>,
@ -213,7 +212,7 @@ where
<B as HttpBody>::Error: std::error::Error + Send + Sync, <B as HttpBody>::Error: std::error::Error + Send + Sync,
S: Clone + Send + Sync + 'static, S: Clone + Send + Sync + 'static,
UrlBuilder: FromRef<S>, UrlBuilder: FromRef<S>,
PgRepository: FromRequestParts<S>, BoxRepository: FromRequestParts<S>,
MatrixHomeserver: FromRef<S>, MatrixHomeserver: FromRef<S>,
PasswordManager: FromRef<S>, PasswordManager: FromRef<S>,
BoxClock: FromRequestParts<S>, BoxClock: FromRequestParts<S>,
@ -258,7 +257,7 @@ where
S: Clone + Send + Sync + 'static, S: Clone + Send + Sync + 'static,
UrlBuilder: FromRef<S>, UrlBuilder: FromRef<S>,
Arc<PolicyFactory>: FromRef<S>, Arc<PolicyFactory>: FromRef<S>,
PgRepository: FromRequestParts<S>, BoxRepository: FromRequestParts<S>,
Encrypter: FromRef<S>, Encrypter: FromRef<S>,
Templates: FromRef<S>, Templates: FromRef<S>,
Mailer: FromRef<S>, Mailer: FromRef<S>,
@ -401,7 +400,7 @@ async fn test_state(pool: sqlx::PgPool) -> Result<AppState, anyhow::Error> {
let policy_factory = Arc::new(policy_factory); let policy_factory = Arc::new(policy_factory);
let graphql_schema = graphql_schema(&pool); let graphql_schema = graphql_schema();
let http_client_factory = HttpClientFactory::new(10); let http_client_factory = HttpClientFactory::new(10);

View File

@ -27,9 +27,8 @@ use mas_policy::PolicyFactory;
use mas_router::{PostAuthAction, Route}; use mas_router::{PostAuthAction, Route};
use mas_storage::{ use mas_storage::{
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2SessionRepository}, oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2SessionRepository},
BoxClock, BoxRng, Repository, BoxClock, BoxRepository, BoxRng,
}; };
use mas_storage_pg::PgRepository;
use mas_templates::Templates; use mas_templates::Templates;
use oauth2_types::requests::{AccessTokenResponse, AuthorizationResponse}; use oauth2_types::requests::{AccessTokenResponse, AuthorizationResponse};
use thiserror::Error; use thiserror::Error;
@ -69,7 +68,7 @@ impl IntoResponse for RouteError {
} }
} }
impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_policy::LoadError); impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::InstanciateError); impl_from_error_for_route!(mas_policy::InstanciateError);
impl_from_error_for_route!(mas_policy::EvaluationError); impl_from_error_for_route!(mas_policy::EvaluationError);
@ -81,13 +80,13 @@ pub(crate) async fn get(
clock: BoxClock, clock: BoxClock,
State(policy_factory): State<Arc<PolicyFactory>>, State(policy_factory): State<Arc<PolicyFactory>>,
State(templates): State<Templates>, State(templates): State<Templates>,
mut repo: PgRepository, mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Path(grant_id): Path<Ulid>, Path(grant_id): Path<Ulid>,
) -> Result<Response, RouteError> { ) -> Result<Response, RouteError> {
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?; let maybe_session = session_info.load_session(&mut *repo).await?;
let grant = repo let grant = repo
.oauth2_authorization_grant() .oauth2_authorization_grant()
@ -147,7 +146,7 @@ pub enum GrantCompletionError {
NoSuchClient, NoSuchClient,
} }
impl_from_error_for_route!(GrantCompletionError: mas_storage_pg::DatabaseError); impl_from_error_for_route!(GrantCompletionError: mas_storage::RepositoryError);
impl_from_error_for_route!(GrantCompletionError: super::callback::IntoCallbackDestinationError); impl_from_error_for_route!(GrantCompletionError: super::callback::IntoCallbackDestinationError);
impl_from_error_for_route!(GrantCompletionError: mas_policy::LoadError); impl_from_error_for_route!(GrantCompletionError: mas_policy::LoadError);
impl_from_error_for_route!(GrantCompletionError: mas_policy::InstanciateError); impl_from_error_for_route!(GrantCompletionError: mas_policy::InstanciateError);
@ -159,7 +158,7 @@ pub(crate) async fn complete(
grant: AuthorizationGrant, grant: AuthorizationGrant,
browser_session: BrowserSession, browser_session: BrowserSession,
policy_factory: &PolicyFactory, policy_factory: &PolicyFactory,
mut repo: PgRepository, mut repo: BoxRepository,
) -> Result<AuthorizationResponse<Option<AccessTokenResponse>>, GrantCompletionError> { ) -> Result<AuthorizationResponse<Option<AccessTokenResponse>>, GrantCompletionError> {
// Verify that the grant is in a pending stage // Verify that the grant is in a pending stage
if !grant.stage.is_pending() { if !grant.stage.is_pending() {

View File

@ -27,9 +27,8 @@ use mas_policy::PolicyFactory;
use mas_router::{PostAuthAction, Route}; use mas_router::{PostAuthAction, Route};
use mas_storage::{ use mas_storage::{
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository}, oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository},
BoxClock, BoxRng, Repository, BoxClock, BoxRepository, BoxRng,
}; };
use mas_storage_pg::PgRepository;
use mas_templates::Templates; use mas_templates::Templates;
use oauth2_types::{ use oauth2_types::{
errors::{ClientError, ClientErrorCode}, errors::{ClientError, ClientErrorCode},
@ -90,7 +89,7 @@ impl IntoResponse for RouteError {
} }
} }
impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(self::callback::CallbackDestinationError); impl_from_error_for_route!(self::callback::CallbackDestinationError);
impl_from_error_for_route!(mas_policy::LoadError); impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::InstanciateError); impl_from_error_for_route!(mas_policy::InstanciateError);
@ -135,7 +134,7 @@ pub(crate) async fn get(
clock: BoxClock, clock: BoxClock,
State(policy_factory): State<Arc<PolicyFactory>>, State(policy_factory): State<Arc<PolicyFactory>>,
State(templates): State<Templates>, State(templates): State<Templates>,
mut repo: PgRepository, mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Form(params): Form<Params>, Form(params): Form<Params>,
) -> Result<Response, RouteError> { ) -> Result<Response, RouteError> {
@ -168,7 +167,7 @@ pub(crate) async fn get(
let templates = templates.clone(); let templates = templates.clone();
let callback_destination = callback_destination.clone(); let callback_destination = callback_destination.clone();
async move { async move {
let maybe_session = session_info.load_session(&mut repo).await?; let maybe_session = session_info.load_session(&mut *repo).await?;
let prompt = params.auth.prompt.as_deref().unwrap_or_default(); let prompt = params.auth.prompt.as_deref().unwrap_or_default();
// Check if the request/request_uri/registration params are used. If so, reply // Check if the request/request_uri/registration params are used. If so, reply

View File

@ -30,9 +30,8 @@ use mas_policy::PolicyFactory;
use mas_router::{PostAuthAction, Route}; use mas_router::{PostAuthAction, Route};
use mas_storage::{ use mas_storage::{
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository}, oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository},
BoxClock, BoxRng, Repository, BoxClock, BoxRepository, BoxRng,
}; };
use mas_storage_pg::PgRepository;
use mas_templates::{ConsentContext, PolicyViolationContext, TemplateContext, Templates}; use mas_templates::{ConsentContext, PolicyViolationContext, TemplateContext, Templates};
use thiserror::Error; use thiserror::Error;
use ulid::Ulid; use ulid::Ulid;
@ -61,7 +60,7 @@ pub enum RouteError {
} }
impl_from_error_for_route!(mas_templates::TemplateError); impl_from_error_for_route!(mas_templates::TemplateError);
impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_policy::LoadError); impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::InstanciateError); impl_from_error_for_route!(mas_policy::InstanciateError);
impl_from_error_for_route!(mas_policy::EvaluationError); impl_from_error_for_route!(mas_policy::EvaluationError);
@ -77,13 +76,13 @@ pub(crate) async fn get(
clock: BoxClock, clock: BoxClock,
State(policy_factory): State<Arc<PolicyFactory>>, State(policy_factory): State<Arc<PolicyFactory>>,
State(templates): State<Templates>, State(templates): State<Templates>,
mut repo: PgRepository, mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Path(grant_id): Path<Ulid>, Path(grant_id): Path<Ulid>,
) -> Result<Response, RouteError> { ) -> Result<Response, RouteError> {
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?; let maybe_session = session_info.load_session(&mut *repo).await?;
let grant = repo let grant = repo
.oauth2_authorization_grant() .oauth2_authorization_grant()
@ -130,7 +129,7 @@ pub(crate) async fn post(
mut rng: BoxRng, mut rng: BoxRng,
clock: BoxClock, clock: BoxClock,
State(policy_factory): State<Arc<PolicyFactory>>, State(policy_factory): State<Arc<PolicyFactory>>,
mut repo: PgRepository, mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Path(grant_id): Path<Ulid>, Path(grant_id): Path<Ulid>,
Form(form): Form<ProtectedForm<()>>, Form(form): Form<ProtectedForm<()>>,
@ -139,7 +138,7 @@ pub(crate) async fn post(
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?; let maybe_session = session_info.load_session(&mut *repo).await?;
let grant = repo let grant = repo
.oauth2_authorization_grant() .oauth2_authorization_grant()

View File

@ -25,9 +25,8 @@ use mas_storage::{
compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository}, compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository},
oauth2::{OAuth2AccessTokenRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository}, oauth2::{OAuth2AccessTokenRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository},
user::{BrowserSessionRepository, UserRepository}, user::{BrowserSessionRepository, UserRepository},
BoxClock, Clock, Repository, BoxClock, BoxRepository, Clock,
}; };
use mas_storage_pg::PgRepository;
use oauth2_types::{ use oauth2_types::{
errors::{ClientError, ClientErrorCode}, errors::{ClientError, ClientErrorCode},
requests::{IntrospectionRequest, IntrospectionResponse}, requests::{IntrospectionRequest, IntrospectionResponse},
@ -96,7 +95,7 @@ impl IntoResponse for RouteError {
} }
} }
impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl_from_error_for_route!(mas_storage::RepositoryError);
impl From<TokenFormatError> for RouteError { impl From<TokenFormatError> for RouteError {
fn from(_e: TokenFormatError) -> Self { fn from(_e: TokenFormatError) -> Self {
@ -125,13 +124,13 @@ const API_SCOPE: ScopeToken = ScopeToken::from_static("urn:matrix:org.matrix.msc
pub(crate) async fn post( pub(crate) async fn post(
clock: BoxClock, clock: BoxClock,
State(http_client_factory): State<HttpClientFactory>, State(http_client_factory): State<HttpClientFactory>,
mut repo: PgRepository, mut repo: BoxRepository,
State(encrypter): State<Encrypter>, State(encrypter): State<Encrypter>,
client_authorization: ClientAuthorization<IntrospectionRequest>, client_authorization: ClientAuthorization<IntrospectionRequest>,
) -> Result<impl IntoResponse, RouteError> { ) -> Result<impl IntoResponse, RouteError> {
let client = client_authorization let client = client_authorization
.credentials .credentials
.fetch(&mut repo) .fetch(&mut *repo)
.await .await
.unwrap() .unwrap()
.ok_or(RouteError::ClientNotFound)?; .ok_or(RouteError::ClientNotFound)?;

View File

@ -19,8 +19,7 @@ use hyper::StatusCode;
use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_policy::{PolicyFactory, Violation}; use mas_policy::{PolicyFactory, Violation};
use mas_storage::{oauth2::OAuth2ClientRepository, BoxClock, BoxRng, Repository}; use mas_storage::{oauth2::OAuth2ClientRepository, BoxClock, BoxRepository, BoxRng};
use mas_storage_pg::PgRepository;
use oauth2_types::{ use oauth2_types::{
errors::{ClientError, ClientErrorCode}, errors::{ClientError, ClientErrorCode},
registration::{ registration::{
@ -48,7 +47,7 @@ pub(crate) enum RouteError {
PolicyDenied(Vec<Violation>), PolicyDenied(Vec<Violation>),
} }
impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_policy::LoadError); impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::InstanciateError); impl_from_error_for_route!(mas_policy::InstanciateError);
impl_from_error_for_route!(mas_policy::EvaluationError); impl_from_error_for_route!(mas_policy::EvaluationError);
@ -108,7 +107,7 @@ impl IntoResponse for RouteError {
pub(crate) async fn post( pub(crate) async fn post(
mut rng: BoxRng, mut rng: BoxRng,
clock: BoxClock, clock: BoxClock,
mut repo: PgRepository, mut repo: BoxRepository,
State(policy_factory): State<Arc<PolicyFactory>>, State(policy_factory): State<Arc<PolicyFactory>>,
State(encrypter): State<Encrypter>, State(encrypter): State<Encrypter>,
Json(body): Json<ClientMetadata>, Json(body): Json<ClientMetadata>,

View File

@ -37,9 +37,8 @@ use mas_storage::{
OAuth2RefreshTokenRepository, OAuth2SessionRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository,
}, },
user::BrowserSessionRepository, user::BrowserSessionRepository,
BoxClock, BoxRng, Clock, Repository, BoxClock, BoxRepository, BoxRng, Clock,
}; };
use mas_storage_pg::PgRepository;
use oauth2_types::{ use oauth2_types::{
errors::{ClientError, ClientErrorCode}, errors::{ClientError, ClientErrorCode},
pkce::CodeChallengeError, pkce::CodeChallengeError,
@ -150,7 +149,7 @@ impl IntoResponse for RouteError {
} }
} }
impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_keystore::WrongAlgorithmError); impl_from_error_for_route!(mas_keystore::WrongAlgorithmError);
impl_from_error_for_route!(mas_jose::claims::ClaimError); impl_from_error_for_route!(mas_jose::claims::ClaimError);
impl_from_error_for_route!(mas_jose::claims::TokenHashError); impl_from_error_for_route!(mas_jose::claims::TokenHashError);
@ -163,13 +162,13 @@ pub(crate) async fn post(
State(http_client_factory): State<HttpClientFactory>, State(http_client_factory): State<HttpClientFactory>,
State(key_store): State<Keystore>, State(key_store): State<Keystore>,
State(url_builder): State<UrlBuilder>, State(url_builder): State<UrlBuilder>,
mut repo: PgRepository, mut repo: BoxRepository,
State(encrypter): State<Encrypter>, State(encrypter): State<Encrypter>,
client_authorization: ClientAuthorization<AccessTokenRequest>, client_authorization: ClientAuthorization<AccessTokenRequest>,
) -> Result<impl IntoResponse, RouteError> { ) -> Result<impl IntoResponse, RouteError> {
let client = client_authorization let client = client_authorization
.credentials .credentials
.fetch(&mut repo) .fetch(&mut *repo)
.await? .await?
.ok_or(RouteError::ClientNotFound)?; .ok_or(RouteError::ClientNotFound)?;
@ -185,7 +184,7 @@ pub(crate) async fn post(
let form = client_authorization.form.ok_or(RouteError::BadRequest)?; let form = client_authorization.form.ok_or(RouteError::BadRequest)?;
let reply = match form { let (reply, repo) = match form {
AccessTokenRequest::AuthorizationCode(grant) => { AccessTokenRequest::AuthorizationCode(grant) => {
authorization_code_grant( authorization_code_grant(
&mut rng, &mut rng,
@ -206,6 +205,8 @@ pub(crate) async fn post(
} }
}; };
repo.save().await?;
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.typed_insert(CacheControl::new().with_no_store()); headers.typed_insert(CacheControl::new().with_no_store());
headers.typed_insert(Pragma::no_cache()); headers.typed_insert(Pragma::no_cache());
@ -221,8 +222,8 @@ async fn authorization_code_grant(
client: &Client, client: &Client,
key_store: &Keystore, key_store: &Keystore,
url_builder: &UrlBuilder, url_builder: &UrlBuilder,
mut repo: PgRepository, mut repo: BoxRepository,
) -> Result<AccessTokenResponse, RouteError> { ) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
let authz_grant = repo let authz_grant = repo
.oauth2_authorization_grant() .oauth2_authorization_grant()
.find_by_code(&grant.code) .find_by_code(&grant.code)
@ -367,9 +368,7 @@ async fn authorization_code_grant(
.exchange(clock, authz_grant) .exchange(clock, authz_grant)
.await?; .await?;
repo.save().await?; Ok((params, repo))
Ok(params)
} }
async fn refresh_token_grant( async fn refresh_token_grant(
@ -377,8 +376,8 @@ async fn refresh_token_grant(
clock: &impl Clock, clock: &impl Clock,
grant: &RefreshTokenGrant, grant: &RefreshTokenGrant,
client: &Client, client: &Client,
mut repo: PgRepository, mut repo: BoxRepository,
) -> Result<AccessTokenResponse, RouteError> { ) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
let refresh_token = repo let refresh_token = repo
.oauth2_refresh_token() .oauth2_refresh_token()
.find_by_token(&grant.refresh_token) .find_by_token(&grant.refresh_token)
@ -439,7 +438,5 @@ async fn refresh_token_grant(
.with_refresh_token(new_refresh_token.refresh_token) .with_refresh_token(new_refresh_token.refresh_token)
.with_scope(session.scope); .with_scope(session.scope);
repo.save().await?; Ok((params, repo))
Ok(params)
} }

View File

@ -31,9 +31,8 @@ use mas_router::UrlBuilder;
use mas_storage::{ use mas_storage::{
oauth2::OAuth2ClientRepository, oauth2::OAuth2ClientRepository,
user::{BrowserSessionRepository, UserEmailRepository}, user::{BrowserSessionRepository, UserEmailRepository},
BoxClock, BoxRng, Repository, BoxClock, BoxRepository, BoxRng,
}; };
use mas_storage_pg::PgRepository;
use oauth2_types::scope; use oauth2_types::scope;
use serde::Serialize; use serde::Serialize;
use serde_with::skip_serializing_none; use serde_with::skip_serializing_none;
@ -65,7 +64,7 @@ pub enum RouteError {
#[error("failed to authenticate")] #[error("failed to authenticate")]
AuthorizationVerificationError( AuthorizationVerificationError(
#[from] AuthorizationVerificationError<mas_storage_pg::DatabaseError>, #[from] AuthorizationVerificationError<mas_storage::RepositoryError>,
), ),
#[error("no suitable key found for signing")] #[error("no suitable key found for signing")]
@ -78,7 +77,7 @@ pub enum RouteError {
NoSuchBrowserSession, NoSuchBrowserSession,
} }
impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_keystore::WrongAlgorithmError); impl_from_error_for_route!(mas_keystore::WrongAlgorithmError);
impl_from_error_for_route!(mas_jose::jwt::JwtSignatureError); impl_from_error_for_route!(mas_jose::jwt::JwtSignatureError);
@ -100,11 +99,11 @@ pub async fn get(
mut rng: BoxRng, mut rng: BoxRng,
clock: BoxClock, clock: BoxClock,
State(url_builder): State<UrlBuilder>, State(url_builder): State<UrlBuilder>,
mut repo: PgRepository, mut repo: BoxRepository,
State(key_store): State<Keystore>, State(key_store): State<Keystore>,
user_authorization: UserAuthorization, user_authorization: UserAuthorization,
) -> Result<Response, RouteError> { ) -> Result<Response, RouteError> {
let session = user_authorization.protected(&mut repo, &clock).await?; let session = user_authorization.protected(&mut *repo, &clock).await?;
let browser_session = repo let browser_session = repo
.browser_session() .browser_session()

View File

@ -24,9 +24,8 @@ use mas_oidc_client::requests::authorization_code::AuthorizationRequestData;
use mas_router::UrlBuilder; use mas_router::UrlBuilder;
use mas_storage::{ use mas_storage::{
upstream_oauth2::{UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository}, upstream_oauth2::{UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository},
BoxClock, BoxRng, Repository, BoxClock, BoxRepository, BoxRng,
}; };
use mas_storage_pg::PgRepository;
use thiserror::Error; use thiserror::Error;
use ulid::Ulid; use ulid::Ulid;
@ -45,7 +44,7 @@ pub(crate) enum RouteError {
impl_from_error_for_route!(mas_http::ClientInitError); impl_from_error_for_route!(mas_http::ClientInitError);
impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError); impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError);
impl_from_error_for_route!(mas_oidc_client::error::AuthorizationError); impl_from_error_for_route!(mas_oidc_client::error::AuthorizationError);
impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError { impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response { fn into_response(self) -> axum::response::Response {
@ -60,7 +59,7 @@ pub(crate) async fn get(
mut rng: BoxRng, mut rng: BoxRng,
clock: BoxClock, clock: BoxClock,
State(http_client_factory): State<HttpClientFactory>, State(http_client_factory): State<HttpClientFactory>,
mut repo: PgRepository, mut repo: BoxRepository,
State(url_builder): State<UrlBuilder>, State(url_builder): State<UrlBuilder>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Path(provider_id): Path<Ulid>, Path(provider_id): Path<Ulid>,

View File

@ -30,9 +30,8 @@ use mas_storage::{
UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
UpstreamOAuthSessionRepository, UpstreamOAuthSessionRepository,
}, },
BoxClock, BoxRng, Clock, Repository, BoxClock, BoxRepository, BoxRng, Clock,
}; };
use mas_storage_pg::PgRepository;
use oauth2_types::errors::ClientErrorCode; use oauth2_types::errors::ClientErrorCode;
use serde::Deserialize; use serde::Deserialize;
use thiserror::Error; use thiserror::Error;
@ -99,7 +98,7 @@ pub(crate) enum RouteError {
Internal(Box<dyn std::error::Error>), Internal(Box<dyn std::error::Error>),
} }
impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_http::ClientInitError); impl_from_error_for_route!(mas_http::ClientInitError);
impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError); impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError);
impl_from_error_for_route!(mas_oidc_client::error::JwksError); impl_from_error_for_route!(mas_oidc_client::error::JwksError);
@ -123,7 +122,7 @@ pub(crate) async fn get(
mut rng: BoxRng, mut rng: BoxRng,
clock: BoxClock, clock: BoxClock,
State(http_client_factory): State<HttpClientFactory>, State(http_client_factory): State<HttpClientFactory>,
mut repo: PgRepository, mut repo: BoxRepository,
State(url_builder): State<UrlBuilder>, State(url_builder): State<UrlBuilder>,
State(encrypter): State<Encrypter>, State(encrypter): State<Encrypter>,
State(keystore): State<Keystore>, State(keystore): State<Keystore>,

View File

@ -27,9 +27,8 @@ use mas_keystore::Encrypter;
use mas_storage::{ use mas_storage::{
upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository}, upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository},
user::{BrowserSessionRepository, UserRepository}, user::{BrowserSessionRepository, UserRepository},
BoxClock, BoxRng, Repository, BoxClock, BoxRepository, BoxRng,
}; };
use mas_storage_pg::PgRepository;
use mas_templates::{ use mas_templates::{
EmptyContext, TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister, EmptyContext, TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister,
UpstreamSuggestLink, UpstreamSuggestLink,
@ -72,7 +71,7 @@ pub(crate) enum RouteError {
impl_from_error_for_route!(mas_templates::TemplateError); impl_from_error_for_route!(mas_templates::TemplateError);
impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError); impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError);
impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound); impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError { impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response { fn into_response(self) -> axum::response::Response {
@ -95,7 +94,7 @@ pub(crate) enum FormData {
pub(crate) async fn get( pub(crate) async fn get(
mut rng: BoxRng, mut rng: BoxRng,
clock: BoxClock, clock: BoxClock,
mut repo: PgRepository, mut repo: BoxRepository,
State(templates): State<Templates>, State(templates): State<Templates>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Path(link_id): Path<Ulid>, Path(link_id): Path<Ulid>,
@ -129,7 +128,7 @@ pub(crate) async fn get(
let (user_session_info, cookie_jar) = cookie_jar.session_info(); let (user_session_info, cookie_jar) = cookie_jar.session_info();
let (csrf_token, mut cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (csrf_token, mut cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let maybe_user_session = user_session_info.load_session(&mut repo).await?; let maybe_user_session = user_session_info.load_session(&mut *repo).await?;
let render = match (maybe_user_session, link.user_id) { let render = match (maybe_user_session, link.user_id) {
(Some(session), Some(user_id)) if session.user.id == user_id => { (Some(session), Some(user_id)) if session.user.id == user_id => {
@ -211,7 +210,7 @@ pub(crate) async fn get(
pub(crate) async fn post( pub(crate) async fn post(
mut rng: BoxRng, mut rng: BoxRng,
clock: BoxClock, clock: BoxClock,
mut repo: PgRepository, mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Path(link_id): Path<Ulid>, Path(link_id): Path<Ulid>,
Form(form): Form<ProtectedForm<FormData>>, Form(form): Form<ProtectedForm<FormData>>,
@ -250,7 +249,7 @@ pub(crate) async fn post(
} }
let (user_session_info, cookie_jar) = cookie_jar.session_info(); let (user_session_info, cookie_jar) = cookie_jar.session_info();
let maybe_user_session = user_session_info.load_session(&mut repo).await?; let maybe_user_session = user_session_info.load_session(&mut *repo).await?;
let session = match (maybe_user_session, link.user_id, form) { let session = match (maybe_user_session, link.user_id, form) {
(Some(session), None, FormData::Link) => { (Some(session), None, FormData::Link) => {

View File

@ -24,8 +24,7 @@ use mas_axum_utils::{
use mas_email::Mailer; use mas_email::Mailer;
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_router::Route; use mas_router::Route;
use mas_storage::{user::UserEmailRepository, BoxClock, BoxRng, Repository}; use mas_storage::{user::UserEmailRepository, BoxClock, BoxRepository, BoxRng};
use mas_storage_pg::PgRepository;
use mas_templates::{EmailAddContext, TemplateContext, Templates}; use mas_templates::{EmailAddContext, TemplateContext, Templates};
use serde::Deserialize; use serde::Deserialize;
@ -41,13 +40,13 @@ pub(crate) async fn get(
mut rng: BoxRng, mut rng: BoxRng,
clock: BoxClock, clock: BoxClock,
State(templates): State<Templates>, State(templates): State<Templates>,
mut repo: PgRepository, mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?; let maybe_session = session_info.load_session(&mut *repo).await?;
let session = if let Some(session) = maybe_session { let session = if let Some(session) = maybe_session {
session session
@ -68,7 +67,7 @@ pub(crate) async fn get(
pub(crate) async fn post( pub(crate) async fn post(
mut rng: BoxRng, mut rng: BoxRng,
clock: BoxClock, clock: BoxClock,
mut repo: PgRepository, mut repo: BoxRepository,
State(mailer): State<Mailer>, State(mailer): State<Mailer>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Query(query): Query<OptionalPostAuthAction>, Query(query): Query<OptionalPostAuthAction>,
@ -77,7 +76,7 @@ pub(crate) async fn post(
let form = cookie_jar.verify_form(&clock, form)?; let form = cookie_jar.verify_form(&clock, form)?;
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?; let maybe_session = session_info.load_session(&mut *repo).await?;
let session = if let Some(session) = maybe_session { let session = if let Some(session) = maybe_session {
session session
@ -99,7 +98,7 @@ pub(crate) async fn post(
}; };
start_email_verification( start_email_verification(
&mailer, &mailer,
&mut repo, &mut *repo,
&mut rng, &mut rng,
&clock, &clock,
&session.user, &session.user,

View File

@ -28,8 +28,7 @@ use mas_data_model::{BrowserSession, User, UserEmail};
use mas_email::Mailer; use mas_email::Mailer;
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_router::Route; use mas_router::Route;
use mas_storage::{user::UserEmailRepository, BoxClock, BoxRng, Clock, Repository}; use mas_storage::{user::UserEmailRepository, BoxClock, BoxRepository, BoxRng, Clock, Repository};
use mas_storage_pg::PgRepository;
use mas_templates::{AccountEmailsContext, EmailVerificationContext, TemplateContext, Templates}; use mas_templates::{AccountEmailsContext, EmailVerificationContext, TemplateContext, Templates};
use rand::{distributions::Uniform, Rng}; use rand::{distributions::Uniform, Rng};
use serde::Deserialize; use serde::Deserialize;
@ -51,28 +50,28 @@ pub(crate) async fn get(
mut rng: BoxRng, mut rng: BoxRng,
clock: BoxClock, clock: BoxClock,
State(templates): State<Templates>, State(templates): State<Templates>,
mut repo: PgRepository, mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?; let maybe_session = session_info.load_session(&mut *repo).await?;
if let Some(session) = maybe_session { if let Some(session) = maybe_session {
render(&mut rng, &clock, templates, session, cookie_jar, &mut repo).await render(&mut rng, &clock, templates, session, cookie_jar, &mut *repo).await
} else { } else {
let login = mas_router::Login::default(); let login = mas_router::Login::default();
Ok((cookie_jar, login.go()).into_response()) Ok((cookie_jar, login.go()).into_response())
} }
} }
async fn render( async fn render<E: std::error::Error>(
rng: impl Rng + Send, rng: impl Rng + Send,
clock: &impl Clock, clock: &impl Clock,
templates: Templates, templates: Templates,
session: BrowserSession, session: BrowserSession,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
repo: &mut impl Repository, repo: &mut (impl Repository<Error = E> + ?Sized),
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock, rng); let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock, rng);
@ -87,9 +86,9 @@ async fn render(
Ok((cookie_jar, Html(content)).into_response()) Ok((cookie_jar, Html(content)).into_response())
} }
async fn start_email_verification( async fn start_email_verification<E: std::error::Error + Send + Sync + 'static>(
mailer: &Mailer, mailer: &Mailer,
repo: &mut impl Repository, repo: &mut (impl Repository<Error = E> + ?Sized),
mut rng: impl Rng + Send, mut rng: impl Rng + Send,
clock: &impl Clock, clock: &impl Clock,
user: &User, user: &User,
@ -124,14 +123,14 @@ pub(crate) async fn post(
mut rng: BoxRng, mut rng: BoxRng,
clock: BoxClock, clock: BoxClock,
State(templates): State<Templates>, State(templates): State<Templates>,
mut repo: PgRepository, mut repo: BoxRepository,
State(mailer): State<Mailer>, State(mailer): State<Mailer>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<ManagementForm>>, Form(form): Form<ProtectedForm<ManagementForm>>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?; let maybe_session = session_info.load_session(&mut *repo).await?;
let mut session = if let Some(session) = maybe_session { let mut session = if let Some(session) = maybe_session {
session session
@ -150,7 +149,7 @@ pub(crate) async fn post(
.await?; .await?;
let next = mas_router::AccountVerifyEmail::new(email.id); let next = mas_router::AccountVerifyEmail::new(email.id);
start_email_verification(&mailer, &mut repo, &mut rng, &clock, &session.user, email) start_email_verification(&mailer, &mut *repo, &mut rng, &clock, &session.user, email)
.await?; .await?;
repo.save().await?; repo.save().await?;
return Ok((cookie_jar, next.go()).into_response()); return Ok((cookie_jar, next.go()).into_response());
@ -169,7 +168,7 @@ pub(crate) async fn post(
} }
let next = mas_router::AccountVerifyEmail::new(email.id); let next = mas_router::AccountVerifyEmail::new(email.id);
start_email_verification(&mailer, &mut repo, &mut rng, &clock, &session.user, email) start_email_verification(&mailer, &mut *repo, &mut rng, &clock, &session.user, email)
.await?; .await?;
repo.save().await?; repo.save().await?;
return Ok((cookie_jar, next.go()).into_response()); return Ok((cookie_jar, next.go()).into_response());
@ -212,7 +211,7 @@ pub(crate) async fn post(
templates.clone(), templates.clone(),
session, session,
cookie_jar, cookie_jar,
&mut repo, &mut *repo,
) )
.await?; .await?;

View File

@ -24,8 +24,7 @@ use mas_axum_utils::{
}; };
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_router::Route; use mas_router::Route;
use mas_storage::{user::UserEmailRepository, BoxClock, BoxRng, Repository}; use mas_storage::{user::UserEmailRepository, BoxClock, BoxRepository, BoxRng};
use mas_storage_pg::PgRepository;
use mas_templates::{EmailVerificationPageContext, TemplateContext, Templates}; use mas_templates::{EmailVerificationPageContext, TemplateContext, Templates};
use serde::Deserialize; use serde::Deserialize;
use ulid::Ulid; use ulid::Ulid;
@ -41,7 +40,7 @@ pub(crate) async fn get(
mut rng: BoxRng, mut rng: BoxRng,
clock: BoxClock, clock: BoxClock,
State(templates): State<Templates>, State(templates): State<Templates>,
mut repo: PgRepository, mut repo: BoxRepository,
Query(query): Query<OptionalPostAuthAction>, Query(query): Query<OptionalPostAuthAction>,
Path(id): Path<Ulid>, Path(id): Path<Ulid>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
@ -49,7 +48,7 @@ pub(crate) async fn get(
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?; let maybe_session = session_info.load_session(&mut *repo).await?;
let session = if let Some(session) = maybe_session { let session = if let Some(session) = maybe_session {
session session
@ -82,7 +81,7 @@ pub(crate) async fn get(
pub(crate) async fn post( pub(crate) async fn post(
clock: BoxClock, clock: BoxClock,
mut repo: PgRepository, mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Query(query): Query<OptionalPostAuthAction>, Query(query): Query<OptionalPostAuthAction>,
Path(id): Path<Ulid>, Path(id): Path<Ulid>,
@ -91,7 +90,7 @@ pub(crate) async fn post(
let form = cookie_jar.verify_form(&clock, form)?; let form = cookie_jar.verify_form(&clock, form)?;
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?; let maybe_session = session_info.load_session(&mut *repo).await?;
let session = if let Some(session) = maybe_session { let session = if let Some(session) = maybe_session {
session session

View File

@ -25,22 +25,21 @@ use mas_keystore::Encrypter;
use mas_router::Route; use mas_router::Route;
use mas_storage::{ use mas_storage::{
user::{BrowserSessionRepository, UserEmailRepository}, user::{BrowserSessionRepository, UserEmailRepository},
BoxClock, BoxRng, Repository, BoxClock, BoxRepository, BoxRng,
}; };
use mas_storage_pg::PgRepository;
use mas_templates::{AccountContext, TemplateContext, Templates}; use mas_templates::{AccountContext, TemplateContext, Templates};
pub(crate) async fn get( pub(crate) async fn get(
mut rng: BoxRng, mut rng: BoxRng,
clock: BoxClock, clock: BoxClock,
State(templates): State<Templates>, State(templates): State<Templates>,
mut repo: PgRepository, mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?; let maybe_session = session_info.load_session(&mut *repo).await?;
let session = if let Some(session) = maybe_session { let session = if let Some(session) = maybe_session {
session session

View File

@ -27,9 +27,8 @@ use mas_keystore::Encrypter;
use mas_router::Route; use mas_router::Route;
use mas_storage::{ use mas_storage::{
user::{BrowserSessionRepository, UserPasswordRepository}, user::{BrowserSessionRepository, UserPasswordRepository},
BoxClock, BoxRng, Clock, Repository, BoxClock, BoxRepository, BoxRng, Clock,
}; };
use mas_storage_pg::PgRepository;
use mas_templates::{EmptyContext, TemplateContext, Templates}; use mas_templates::{EmptyContext, TemplateContext, Templates};
use rand::Rng; use rand::Rng;
use serde::Deserialize; use serde::Deserialize;
@ -48,12 +47,12 @@ pub(crate) async fn get(
mut rng: BoxRng, mut rng: BoxRng,
clock: BoxClock, clock: BoxClock,
State(templates): State<Templates>, State(templates): State<Templates>,
mut repo: PgRepository, mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?; let maybe_session = session_info.load_session(&mut *repo).await?;
if let Some(session) = maybe_session { if let Some(session) = maybe_session {
render(&mut rng, &clock, templates, session, cookie_jar).await render(&mut rng, &clock, templates, session, cookie_jar).await
@ -86,7 +85,7 @@ pub(crate) async fn post(
clock: BoxClock, clock: BoxClock,
State(password_manager): State<PasswordManager>, State(password_manager): State<PasswordManager>,
State(templates): State<Templates>, State(templates): State<Templates>,
mut repo: PgRepository, mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<ChangeForm>>, Form(form): Form<ProtectedForm<ChangeForm>>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
@ -94,7 +93,7 @@ pub(crate) async fn post(
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?; let maybe_session = session_info.load_session(&mut *repo).await?;
let session = if let Some(session) = maybe_session { let session = if let Some(session) = maybe_session {
session session

View File

@ -20,8 +20,7 @@ use axum_extra::extract::PrivateCookieJar;
use mas_axum_utils::{csrf::CsrfExt, FancyError, SessionInfoExt}; use mas_axum_utils::{csrf::CsrfExt, FancyError, SessionInfoExt};
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_router::UrlBuilder; use mas_router::UrlBuilder;
use mas_storage::{BoxClock, BoxRng}; use mas_storage::{BoxClock, BoxRepository, BoxRng};
use mas_storage_pg::PgRepository;
use mas_templates::{IndexContext, TemplateContext, Templates}; use mas_templates::{IndexContext, TemplateContext, Templates};
pub async fn get( pub async fn get(
@ -29,12 +28,12 @@ pub async fn get(
clock: BoxClock, clock: BoxClock,
State(templates): State<Templates>, State(templates): State<Templates>,
State(url_builder): State<UrlBuilder>, State(url_builder): State<UrlBuilder>,
mut repo: PgRepository, mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<impl IntoResponse, FancyError> { ) -> Result<impl IntoResponse, FancyError> {
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let session = session_info.load_session(&mut repo).await?; let session = session_info.load_session(&mut *repo).await?;
let ctx = IndexContext::new(url_builder.oidc_discovery()) let ctx = IndexContext::new(url_builder.oidc_discovery())
.maybe_with_session(session) .maybe_with_session(session)

View File

@ -26,9 +26,8 @@ use mas_keystore::Encrypter;
use mas_storage::{ use mas_storage::{
upstream_oauth2::UpstreamOAuthProviderRepository, upstream_oauth2::UpstreamOAuthProviderRepository,
user::{BrowserSessionRepository, UserPasswordRepository, UserRepository}, user::{BrowserSessionRepository, UserPasswordRepository, UserRepository},
BoxClock, BoxRng, Clock, Repository, BoxClock, BoxRepository, BoxRng, Clock, Repository,
}; };
use mas_storage_pg::PgRepository;
use mas_templates::{ use mas_templates::{
FieldError, FormError, LoginContext, LoginFormField, TemplateContext, Templates, ToFormState, FieldError, FormError, LoginContext, LoginFormField, TemplateContext, Templates, ToFormState,
}; };
@ -53,14 +52,14 @@ pub(crate) async fn get(
mut rng: BoxRng, mut rng: BoxRng,
clock: BoxClock, clock: BoxClock,
State(templates): State<Templates>, State(templates): State<Templates>,
mut repo: PgRepository, mut repo: BoxRepository,
Query(query): Query<OptionalPostAuthAction>, Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?; let maybe_session = session_info.load_session(&mut *repo).await?;
if maybe_session.is_some() { if maybe_session.is_some() {
let reply = query.go_next(); let reply = query.go_next();
@ -71,7 +70,7 @@ pub(crate) async fn get(
LoginContext::default().with_upstrem_providers(providers), LoginContext::default().with_upstrem_providers(providers),
query, query,
csrf_token, csrf_token,
&mut repo, &mut *repo,
&templates, &templates,
) )
.await?; .await?;
@ -85,7 +84,7 @@ pub(crate) async fn post(
clock: BoxClock, clock: BoxClock,
State(password_manager): State<PasswordManager>, State(password_manager): State<PasswordManager>,
State(templates): State<Templates>, State(templates): State<Templates>,
mut repo: PgRepository, mut repo: BoxRepository,
Query(query): Query<OptionalPostAuthAction>, Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<LoginForm>>, Form(form): Form<ProtectedForm<LoginForm>>,
@ -117,7 +116,7 @@ pub(crate) async fn post(
.with_upstrem_providers(providers), .with_upstrem_providers(providers),
query, query,
csrf_token, csrf_token,
&mut repo, &mut *repo,
&templates, &templates,
) )
.await?; .await?;
@ -127,7 +126,7 @@ pub(crate) async fn post(
match login( match login(
password_manager, password_manager,
&mut repo, &mut *repo,
rng, rng,
&clock, &clock,
&form.username, &form.username,
@ -149,7 +148,7 @@ pub(crate) async fn post(
LoginContext::default().with_form_state(state), LoginContext::default().with_form_state(state),
query, query,
csrf_token, csrf_token,
&mut repo, &mut *repo,
&templates, &templates,
) )
.await?; .await?;
@ -162,7 +161,7 @@ pub(crate) async fn post(
// TODO: move that logic elsewhere? // TODO: move that logic elsewhere?
async fn login( async fn login(
password_manager: PasswordManager, password_manager: PasswordManager,
repo: &mut impl Repository, repo: &mut (impl Repository + ?Sized),
mut rng: impl Rng + CryptoRng + Send, mut rng: impl Rng + CryptoRng + Send,
clock: &impl Clock, clock: &impl Clock,
username: &str, username: &str,
@ -236,7 +235,7 @@ async fn render(
ctx: LoginContext, ctx: LoginContext,
action: OptionalPostAuthAction, action: OptionalPostAuthAction,
csrf_token: CsrfToken, csrf_token: CsrfToken,
repo: &mut impl Repository, repo: &mut (impl Repository + ?Sized),
templates: &Templates, templates: &Templates,
) -> Result<String, FancyError> { ) -> Result<String, FancyError> {
let next = action.load_context(repo).await?; let next = action.load_context(repo).await?;

View File

@ -20,12 +20,11 @@ use mas_axum_utils::{
}; };
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_router::{PostAuthAction, Route}; use mas_router::{PostAuthAction, Route};
use mas_storage::{user::BrowserSessionRepository, BoxClock, Repository}; use mas_storage::{user::BrowserSessionRepository, BoxClock, BoxRepository};
use mas_storage_pg::PgRepository;
pub(crate) async fn post( pub(crate) async fn post(
clock: BoxClock, clock: BoxClock,
mut repo: PgRepository, mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<Option<PostAuthAction>>>, Form(form): Form<ProtectedForm<Option<PostAuthAction>>>,
) -> Result<impl IntoResponse, FancyError> { ) -> Result<impl IntoResponse, FancyError> {
@ -33,7 +32,7 @@ pub(crate) async fn post(
let (session_info, mut cookie_jar) = cookie_jar.session_info(); let (session_info, mut cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?; let maybe_session = session_info.load_session(&mut *repo).await?;
if let Some(session) = maybe_session { if let Some(session) = maybe_session {
repo.browser_session().finish(&clock, session).await?; repo.browser_session().finish(&clock, session).await?;

View File

@ -26,9 +26,8 @@ use mas_keystore::Encrypter;
use mas_router::Route; use mas_router::Route;
use mas_storage::{ use mas_storage::{
user::{BrowserSessionRepository, UserPasswordRepository}, user::{BrowserSessionRepository, UserPasswordRepository},
BoxClock, BoxRng, Repository, BoxClock, BoxRepository, BoxRng,
}; };
use mas_storage_pg::PgRepository;
use mas_templates::{ReauthContext, TemplateContext, Templates}; use mas_templates::{ReauthContext, TemplateContext, Templates};
use serde::Deserialize; use serde::Deserialize;
use zeroize::Zeroizing; use zeroize::Zeroizing;
@ -45,14 +44,14 @@ pub(crate) async fn get(
mut rng: BoxRng, mut rng: BoxRng,
clock: BoxClock, clock: BoxClock,
State(templates): State<Templates>, State(templates): State<Templates>,
mut repo: PgRepository, mut repo: BoxRepository,
Query(query): Query<OptionalPostAuthAction>, Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?; let maybe_session = session_info.load_session(&mut *repo).await?;
let session = if let Some(session) = maybe_session { let session = if let Some(session) = maybe_session {
session session
@ -64,7 +63,7 @@ pub(crate) async fn get(
}; };
let ctx = ReauthContext::default(); let ctx = ReauthContext::default();
let next = query.load_context(&mut repo).await?; let next = query.load_context(&mut *repo).await?;
let ctx = if let Some(next) = next { let ctx = if let Some(next) = next {
ctx.with_post_action(next) ctx.with_post_action(next)
} else { } else {
@ -81,7 +80,7 @@ pub(crate) async fn post(
mut rng: BoxRng, mut rng: BoxRng,
clock: BoxClock, clock: BoxClock,
State(password_manager): State<PasswordManager>, State(password_manager): State<PasswordManager>,
mut repo: PgRepository, mut repo: BoxRepository,
Query(query): Query<OptionalPostAuthAction>, Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<ReauthForm>>, Form(form): Form<ProtectedForm<ReauthForm>>,
@ -90,7 +89,7 @@ pub(crate) async fn post(
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?; let maybe_session = session_info.load_session(&mut *repo).await?;
let session = if let Some(session) = maybe_session { let session = if let Some(session) = maybe_session {
session session

View File

@ -33,9 +33,8 @@ use mas_policy::PolicyFactory;
use mas_router::Route; use mas_router::Route;
use mas_storage::{ use mas_storage::{
user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository},
BoxClock, BoxRng, Repository, BoxClock, BoxRepository, BoxRng, Repository,
}; };
use mas_storage_pg::PgRepository;
use mas_templates::{ use mas_templates::{
EmailVerificationContext, FieldError, FormError, RegisterContext, RegisterFormField, EmailVerificationContext, FieldError, FormError, RegisterContext, RegisterFormField,
TemplateContext, Templates, ToFormState, TemplateContext, Templates, ToFormState,
@ -63,14 +62,14 @@ pub(crate) async fn get(
mut rng: BoxRng, mut rng: BoxRng,
clock: BoxClock, clock: BoxClock,
State(templates): State<Templates>, State(templates): State<Templates>,
mut repo: PgRepository, mut repo: BoxRepository,
Query(query): Query<OptionalPostAuthAction>, Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?; let maybe_session = session_info.load_session(&mut *repo).await?;
if maybe_session.is_some() { if maybe_session.is_some() {
let reply = query.go_next(); let reply = query.go_next();
@ -80,7 +79,7 @@ pub(crate) async fn get(
RegisterContext::default(), RegisterContext::default(),
query, query,
csrf_token, csrf_token,
&mut repo, &mut *repo,
&templates, &templates,
) )
.await?; .await?;
@ -97,7 +96,7 @@ pub(crate) async fn post(
State(mailer): State<Mailer>, State(mailer): State<Mailer>,
State(policy_factory): State<Arc<PolicyFactory>>, State(policy_factory): State<Arc<PolicyFactory>>,
State(templates): State<Templates>, State(templates): State<Templates>,
mut repo: PgRepository, mut repo: BoxRepository,
Query(query): Query<OptionalPostAuthAction>, Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<RegisterForm>>, Form(form): Form<ProtectedForm<RegisterForm>>,
@ -175,7 +174,7 @@ pub(crate) async fn post(
RegisterContext::default().with_form_state(state), RegisterContext::default().with_form_state(state),
query, query,
csrf_token, csrf_token,
&mut repo, &mut *repo,
&templates, &templates,
) )
.await?; .await?;
@ -234,7 +233,7 @@ async fn render(
ctx: RegisterContext, ctx: RegisterContext,
action: OptionalPostAuthAction, action: OptionalPostAuthAction,
csrf_token: CsrfToken, csrf_token: CsrfToken,
repo: &mut impl Repository, repo: &mut (impl Repository + ?Sized),
templates: &Templates, templates: &Templates,
) -> Result<String, FancyError> { ) -> Result<String, FancyError> {
let next = action.load_context(repo).await?; let next = action.load_context(repo).await?;

View File

@ -40,9 +40,9 @@ impl OptionalPostAuthAction {
self.go_next_or_default(&mas_router::Index) self.go_next_or_default(&mas_router::Index)
} }
pub async fn load_context<R: Repository>( pub async fn load_context<'a>(
&self, &'a self,
repo: &mut R, repo: &'a mut (impl Repository + ?Sized),
) -> anyhow::Result<Option<PostAuthContext>> { ) -> anyhow::Result<Option<PostAuthContext>> {
let Some(action) = self.post_auth_action.clone() else { return Ok(None) }; let Some(action) = self.post_auth_action.clone() else { return Ok(None) };
let ctx = match action { let ctx = match action {

View File

@ -13,6 +13,7 @@ serde = { version = "1.0.152", features = ["derive"] }
serde_json = "1.0.91" serde_json = "1.0.91"
thiserror = "1.0.38" thiserror = "1.0.38"
tracing = "0.1.37" tracing = "0.1.37"
futures-util = "0.3.25"
rand = "0.8.5" rand = "0.8.5"
rand_chacha = "0.3.1" rand_chacha = "0.3.1"

View File

@ -103,7 +103,7 @@ mod tests {
const SECOND_TOKEN: &str = "second_access_token"; const SECOND_TOKEN: &str = "second_access_token";
let mut rng = ChaChaRng::seed_from_u64(42); let mut rng = ChaChaRng::seed_from_u64(42);
let clock = MockClock::default(); let clock = MockClock::default();
let mut repo = PgRepository::from_pool(&pool).await.unwrap(); let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
// Create a user // Create a user
let user = repo let user = repo
@ -139,7 +139,7 @@ mod tests {
repo.save().await.unwrap(); repo.save().await.unwrap();
{ {
let mut repo = PgRepository::from_pool(&pool).await.unwrap(); let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
// Adding the same token a second time should conflict // Adding the same token a second time should conflict
assert!(repo assert!(repo
.compat_access_token() .compat_access_token()
@ -156,7 +156,7 @@ mod tests {
} }
// Grab a new repo // Grab a new repo
let mut repo = PgRepository::from_pool(&pool).await.unwrap(); let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
// Looking up via ID works // Looking up via ID works
let token_lookup = repo let token_lookup = repo
@ -223,7 +223,7 @@ mod tests {
const REFRESH_TOKEN: &str = "refresh_token"; const REFRESH_TOKEN: &str = "refresh_token";
let mut rng = ChaChaRng::seed_from_u64(42); let mut rng = ChaChaRng::seed_from_u64(42);
let clock = MockClock::default(); let clock = MockClock::default();
let mut repo = PgRepository::from_pool(&pool).await.unwrap(); let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
// Create a user // Create a user
let user = repo let user = repo

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use futures_util::{future::BoxFuture, FutureExt, TryFutureExt};
use mas_storage::{ use mas_storage::{
compat::{ compat::{
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
@ -59,21 +60,19 @@ impl PgRepository {
let txn = pool.begin().await?; let txn = pool.begin().await?;
Ok(PgRepository { txn }) Ok(PgRepository { txn })
} }
pub async fn save(self) -> Result<(), DatabaseError> {
self.txn.commit().await?;
Ok(())
}
pub async fn cancel(self) -> Result<(), DatabaseError> {
self.txn.rollback().await?;
Ok(())
}
} }
impl Repository for PgRepository { impl Repository for PgRepository {
type Error = DatabaseError; type Error = DatabaseError;
fn save(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
self.txn.commit().map_err(DatabaseError::from).boxed()
}
fn cancel(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
self.txn.rollback().map_err(DatabaseError::from).boxed()
}
fn upstream_oauth_link<'c>( fn upstream_oauth_link<'c>(
&'c mut self, &'c mut self,
) -> Box<dyn UpstreamOAuthLinkRepository<Error = Self::Error> + 'c> { ) -> Box<dyn UpstreamOAuthLinkRepository<Error = Self::Error> + 'c> {

View File

@ -29,7 +29,7 @@ use crate::PgRepository;
async fn test_user_repo(pool: PgPool) { async fn test_user_repo(pool: PgPool) {
const USERNAME: &str = "john"; const USERNAME: &str = "john";
let mut repo = PgRepository::from_pool(&pool).await.unwrap(); let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
let mut rng = ChaChaRng::seed_from_u64(42); let mut rng = ChaChaRng::seed_from_u64(42);
let clock = MockClock::default(); let clock = MockClock::default();
@ -77,7 +77,7 @@ async fn test_user_email_repo(pool: PgPool) {
const CODE2: &str = "543210"; const CODE2: &str = "543210";
const EMAIL: &str = "john@example.com"; const EMAIL: &str = "john@example.com";
let mut repo = PgRepository::from_pool(&pool).await.unwrap(); let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
let mut rng = ChaChaRng::seed_from_u64(42); let mut rng = ChaChaRng::seed_from_u64(42);
let clock = MockClock::default(); let clock = MockClock::default();
@ -259,7 +259,7 @@ async fn test_user_password_repo(pool: PgPool) {
const FIRST_PASSWORD_HASH: &str = "doesntmatter"; const FIRST_PASSWORD_HASH: &str = "doesntmatter";
const SECOND_PASSWORD_HASH: &str = "alsodoesntmatter"; const SECOND_PASSWORD_HASH: &str = "alsodoesntmatter";
let mut repo = PgRepository::from_pool(&pool).await.unwrap(); let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
let mut rng = ChaChaRng::seed_from_u64(42); let mut rng = ChaChaRng::seed_from_u64(42);
let clock = MockClock::default(); let clock = MockClock::default();

View File

@ -9,6 +9,7 @@ license = "Apache-2.0"
async-trait = "0.1.60" async-trait = "0.1.60"
chrono = "0.4.23" chrono = "0.4.23"
thiserror = "1.0.38" thiserror = "1.0.38"
futures-util = "0.3.25"
rand_core = "0.6.4" rand_core = "0.6.4"
url = "2.3.1" url = "2.3.1"

View File

@ -28,21 +28,21 @@
clippy::module_name_repetitions clippy::module_name_repetitions
)] )]
use rand_core::CryptoRngCore;
pub mod clock; pub mod clock;
pub mod pagination;
pub(crate) mod repository;
pub mod compat; pub mod compat;
pub mod oauth2; pub mod oauth2;
pub mod pagination;
pub(crate) mod repository;
pub mod upstream_oauth2; pub mod upstream_oauth2;
pub mod user; pub mod user;
use rand_core::CryptoRngCore;
pub use self::{ pub use self::{
clock::{Clock, SystemClock}, clock::{Clock, SystemClock},
pagination::{Page, Pagination}, pagination::{Page, Pagination},
repository::Repository, repository::{BoxRepository, Repository, RepositoryError},
}; };
pub struct MapErr<Repository, Mapper> { pub struct MapErr<Repository, Mapper> {
@ -86,7 +86,6 @@ macro_rules! repository_impl {
where where
R: $repo_trait, R: $repo_trait,
F: FnMut(<R as $repo_trait>::Error) -> E + ::std::marker::Send + ::std::marker::Sync, F: FnMut(<R as $repo_trait>::Error) -> E + ::std::marker::Send + ::std::marker::Sync,
E: ::std::error::Error + ::std::marker::Send + ::std::marker::Sync,
{ {
type Error = E; type Error = E;

View File

@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use futures_util::{future::BoxFuture, FutureExt, TryFutureExt};
use thiserror::Error;
use crate::{ use crate::{
compat::{ compat::{
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
@ -32,6 +35,23 @@ use crate::{
pub trait Repository: Send { pub trait Repository: Send {
type Error: std::error::Error + Send + Sync + 'static; type Error: std::error::Error + Send + Sync + 'static;
fn map_err<Mapper>(self, mapper: Mapper) -> MapErr<Self, Mapper>
where
Self: Sized,
{
MapErr::new(self, mapper)
}
fn boxed(self) -> BoxRepository<Self::Error>
where
Self: Sized + Sync + 'static,
{
Box::new(self)
}
fn save(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>>;
fn cancel(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>>;
fn upstream_oauth_link<'c>( fn upstream_oauth_link<'c>(
&'c mut self, &'c mut self,
) -> Box<dyn UpstreamOAuthLinkRepository<Error = Self::Error> + 'c>; ) -> Box<dyn UpstreamOAuthLinkRepository<Error = Self::Error> + 'c>;
@ -91,14 +111,44 @@ pub trait Repository: Send {
) -> Box<dyn CompatRefreshTokenRepository<Error = Self::Error> + 'c>; ) -> Box<dyn CompatRefreshTokenRepository<Error = Self::Error> + 'c>;
} }
/// An opaque, type-erased error
#[derive(Debug, Error)]
#[error(transparent)]
pub struct RepositoryError {
source: Box<dyn std::error::Error + Send + Sync + 'static>,
}
impl RepositoryError {
pub fn from_error<E>(value: E) -> Self
where
E: std::error::Error + Send + Sync + 'static,
{
Self {
source: Box::new(value),
}
}
}
pub type BoxRepository<E = RepositoryError> =
Box<dyn Repository<Error = E> + Send + Sync + 'static>;
impl<R, F, E> Repository for crate::MapErr<R, F> impl<R, F, E> Repository for crate::MapErr<R, F>
where where
R: Repository, R: Repository,
F: FnMut(R::Error) -> E + Send + Sync, R::Error: 'static,
F: FnMut(R::Error) -> E + Send + Sync + 'static,
E: std::error::Error + Send + Sync + 'static, E: std::error::Error + Send + Sync + 'static,
{ {
type Error = E; type Error = E;
fn save(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
Box::new(self.inner).save().map_err(self.mapper).boxed()
}
fn cancel(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
Box::new(self.inner).cancel().map_err(self.mapper).boxed()
}
fn upstream_oauth_link<'c>( fn upstream_oauth_link<'c>(
&'c mut self, &'c mut self,
) -> Box<dyn UpstreamOAuthLinkRepository<Error = Self::Error> + 'c> { ) -> Box<dyn UpstreamOAuthLinkRepository<Error = Self::Error> + 'c> {

View File

@ -21,7 +21,7 @@ use crate::{pagination::Page, repository_impl, Clock, Pagination};
#[async_trait] #[async_trait]
pub trait UpstreamOAuthLinkRepository: Send + Sync { pub trait UpstreamOAuthLinkRepository: Send + Sync {
type Error: std::error::Error + Send + Sync; type Error;
/// Lookup an upstream OAuth link by its ID /// Lookup an upstream OAuth link by its ID
async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthLink>, Self::Error>; async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthLink>, Self::Error>;