1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Box the repository everywhere

This commit is contained in:
Quentin Gliech
2023-01-20 17:49:16 +01:00
parent f4c64c2171
commit a9facab131
49 changed files with 296 additions and 296 deletions

5
Cargo.lock generated
View File

@ -2804,11 +2804,10 @@ dependencies = [
"chrono",
"mas-data-model",
"mas-storage",
"mas-storage-pg",
"oauth2-types",
"serde",
"sqlx",
"thiserror",
"tokio",
"tracing",
"ulid",
"url",
@ -3101,6 +3100,7 @@ version = "0.1.0"
dependencies = [
"async-trait",
"chrono",
"futures-util",
"mas-data-model",
"mas-iana",
"mas-jose",
@ -3117,6 +3117,7 @@ version = "0.1.0"
dependencies = [
"async-trait",
"chrono",
"futures-util",
"mas-data-model",
"mas-iana",
"mas-jose",

View File

@ -72,10 +72,10 @@ pub enum Credentials {
}
impl Credentials {
pub async fn fetch<'r, R>(&self, repo: &'r mut R) -> Result<Option<Client>, R::Error>
where
R: Repository,
{
pub async fn fetch<E>(
&self,
repo: &mut (impl Repository<Error = E> + ?Sized),
) -> Result<Option<Client>, E> {
let client_id = match self {
Credentials::None { client_id }
| Credentials::ClientSecretBasic { client_id, .. }

View File

@ -43,10 +43,10 @@ impl SessionInfo {
}
/// Load the [`BrowserSession`] from database
pub async fn load_session<R: Repository>(
pub async fn load_session<E>(
&self,
repo: &mut R,
) -> Result<Option<BrowserSession>, R::Error> {
repo: &mut (impl Repository<Error = E> + ?Sized),
) -> Result<Option<BrowserSession>, E> {
let session_id = if let Some(id) = self.current {
id
} else {

View File

@ -51,11 +51,10 @@ enum AccessToken {
}
impl AccessToken {
async fn fetch<R: Repository>(
async fn fetch<E>(
&self,
repo: &mut R,
) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError<R::Error>>
{
repo: &mut (impl Repository<Error = E> + ?Sized),
) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError<E>> {
let token = match self {
AccessToken::Form(t) | AccessToken::Header(t) => t,
AccessToken::None => return Err(AuthorizationVerificationError::MissingToken),
@ -85,11 +84,11 @@ pub struct UserAuthorization<F = ()> {
impl<F: Send> UserAuthorization<F> {
// TODO: take scopes to validate as parameter
pub async fn protected_form<R: Repository, C: Clock>(
pub async fn protected_form<E>(
self,
repo: &mut R,
clock: &C,
) -> Result<(Session, F), AuthorizationVerificationError<R::Error>> {
repo: &mut (impl Repository<Error = E> + ?Sized),
clock: &impl Clock,
) -> Result<(Session, F), AuthorizationVerificationError<E>> {
let form = match self.form {
Some(f) => f,
None => return Err(AuthorizationVerificationError::MissingForm),
@ -105,11 +104,11 @@ impl<F: Send> UserAuthorization<F> {
}
// TODO: take scopes to validate as parameter
pub async fn protected<R: Repository, C: Clock>(
pub async fn protected<E>(
self,
repo: &mut R,
clock: &C,
) -> Result<Session, AuthorizationVerificationError<R::Error>> {
repo: &mut (impl Repository<Error = E> + ?Sized),
clock: &impl Clock,
) -> Result<Session, AuthorizationVerificationError<E>> {
let (token, session) = self.access_token.fetch(repo).await?;
if !token.is_valid(clock.now()) || !session.is_valid() {

View File

@ -203,7 +203,7 @@ impl Options {
let pool = database_from_config(&database_config).await?;
let password_manager = password_manager_from_config(&passwords_config).await?;
let mut repo = PgRepository::from_pool(&pool).await?;
let mut repo = PgRepository::from_pool(&pool).await?.boxed();
let user = repo
.user()
.find_by_username(username)
@ -234,7 +234,7 @@ impl Options {
let config: DatabaseConfig = root.load_config()?;
let pool = database_from_config(&config).await?;
let mut repo = PgRepository::from_pool(&pool).await?;
let mut repo = PgRepository::from_pool(&pool).await?.boxed();
let user = repo
.user()
@ -262,7 +262,7 @@ impl Options {
let pool = database_from_config(&config.database).await?;
let encrypter = config.secrets.encrypter();
let mut repo = PgRepository::from_pool(&pool).await?;
let mut repo = PgRepository::from_pool(&pool).await?.boxed();
for client in config.clients.iter() {
let client_id = client.client_id;

View File

@ -102,7 +102,7 @@ impl Options {
watch_templates(&templates).await?;
}
let graphql_schema = mas_handlers::graphql_schema(&pool);
let graphql_schema = mas_handlers::graphql_schema();
// Maximum 50 outgoing HTTP requests at a time
let http_client_factory = HttpClientFactory::new(50);

View File

@ -10,7 +10,7 @@ anyhow = "1.0.68"
async-graphql = { version = "5.0.4", features = ["chrono", "url"] }
chrono = "0.4.23"
serde = { version = "1.0.152", features = ["derive"] }
sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres"] }
tokio = { version = "1.23.0", features = ["sync"] }
thiserror = "1.0.38"
tracing = "0.1.37"
ulid = "1.0.0"
@ -19,7 +19,6 @@ url = "2.3.1"
oauth2-types = { path = "../oauth2-types" }
mas-data-model = { path = "../data-model" }
mas-storage = { path = "../storage" }
mas-storage-pg = { path = "../storage-pg" }
[[bin]]
name = "schema"

View File

@ -34,11 +34,10 @@ use mas_storage::{
oauth2::OAuth2ClientRepository,
upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository},
user::{BrowserSessionRepository, UserEmailRepository},
Pagination, Repository,
BoxRepository, Pagination,
};
use mas_storage_pg::PgRepository;
use model::CreationEvent;
use sqlx::PgPool;
use tokio::sync::Mutex;
use self::model::{
BrowserSession, Cursor, Node, NodeCursor, NodeType, OAuth2Client, UpstreamOAuth2Link,
@ -94,7 +93,7 @@ impl RootQuery {
id: ID,
) -> Result<Option<OAuth2Client>, async_graphql::Error> {
let id = NodeType::OAuth2Client.extract_ulid(&id)?;
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let client = repo.oauth2_client().lookup(id).await?;
@ -124,7 +123,7 @@ impl RootQuery {
) -> Result<Option<BrowserSession>, async_graphql::Error> {
let id = NodeType::BrowserSession.extract_ulid(&id)?;
let session = ctx.data_opt::<mas_data_model::BrowserSession>().cloned();
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let Some(session) = session else { return Ok(None) };
let current_user = session.user;
@ -150,7 +149,7 @@ impl RootQuery {
) -> Result<Option<UserEmail>, async_graphql::Error> {
let id = NodeType::UserEmail.extract_ulid(&id)?;
let session = ctx.data_opt::<mas_data_model::BrowserSession>().cloned();
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let Some(session) = session else { return Ok(None) };
let current_user = session.user;
@ -172,7 +171,7 @@ impl RootQuery {
) -> Result<Option<UpstreamOAuth2Link>, async_graphql::Error> {
let id = NodeType::UpstreamOAuth2Link.extract_ulid(&id)?;
let session = ctx.data_opt::<mas_data_model::BrowserSession>().cloned();
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let Some(session) = session else { return Ok(None) };
let current_user = session.user;
@ -192,7 +191,7 @@ impl RootQuery {
id: ID,
) -> Result<Option<UpstreamOAuth2Provider>, async_graphql::Error> {
let id = NodeType::UpstreamOAuth2Provider.extract_ulid(&id)?;
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let provider = repo.upstream_oauth_provider().lookup(id).await?;
@ -211,7 +210,7 @@ impl RootQuery {
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, UpstreamOAuth2Provider>, async_graphql::Error> {
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
query(
after,

View File

@ -15,9 +15,8 @@
use anyhow::Context as _;
use async_graphql::{Context, Description, Object, ID};
use chrono::{DateTime, Utc};
use mas_storage::{compat::CompatSessionRepository, user::UserRepository, Repository};
use mas_storage_pg::PgRepository;
use sqlx::PgPool;
use mas_storage::{compat::CompatSessionRepository, user::UserRepository, BoxRepository};
use tokio::sync::Mutex;
use url::Url;
use super::{NodeType, User};
@ -36,7 +35,7 @@ impl CompatSession {
/// The user authorized for this session.
async fn user(&self, ctx: &Context<'_>) -> Result<User, async_graphql::Error> {
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let user = repo
.user()
.lookup(self.0.user_id)
@ -101,7 +100,7 @@ impl CompatSsoLogin {
) -> Result<Option<CompatSession>, async_graphql::Error> {
let Some(session_id) = self.0.session_id() else { return Ok(None) };
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let session = repo
.compat_session()
.lookup(session_id)

View File

@ -14,10 +14,9 @@
use anyhow::Context as _;
use async_graphql::{Context, Description, Object, ID};
use mas_storage::{oauth2::OAuth2ClientRepository, user::BrowserSessionRepository, Repository};
use mas_storage_pg::PgRepository;
use mas_storage::{oauth2::OAuth2ClientRepository, user::BrowserSessionRepository, BoxRepository};
use oauth2_types::scope::Scope;
use sqlx::PgPool;
use tokio::sync::Mutex;
use ulid::Ulid;
use url::Url;
@ -37,7 +36,7 @@ impl OAuth2Session {
/// OAuth 2.0 client used by this session.
pub async fn client(&self, ctx: &Context<'_>) -> Result<OAuth2Client, async_graphql::Error> {
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let client = repo
.oauth2_client()
.lookup(self.0.client_id)
@ -57,7 +56,7 @@ impl OAuth2Session {
&self,
ctx: &Context<'_>,
) -> Result<BrowserSession, async_graphql::Error> {
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let browser_session = repo
.browser_session()
.lookup(self.0.user_session_id)
@ -69,7 +68,7 @@ impl OAuth2Session {
/// User authorized for this session.
pub async fn user(&self, ctx: &Context<'_>) -> Result<User, async_graphql::Error> {
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let browser_session = repo
.browser_session()
.lookup(self.0.user_session_id)
@ -139,7 +138,7 @@ impl OAuth2Consent {
/// OAuth 2.0 client for which the user granted access.
pub async fn client(&self, ctx: &Context<'_>) -> Result<OAuth2Client, async_graphql::Error> {
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let client = repo
.oauth2_client()
.lookup(self.client_id)

View File

@ -16,10 +16,9 @@ use anyhow::Context as _;
use async_graphql::{Context, Object, ID};
use chrono::{DateTime, Utc};
use mas_storage::{
upstream_oauth2::UpstreamOAuthProviderRepository, user::UserRepository, Repository,
upstream_oauth2::UpstreamOAuthProviderRepository, user::UserRepository, BoxRepository,
};
use mas_storage_pg::PgRepository;
use sqlx::PgPool;
use tokio::sync::Mutex;
use super::{NodeType, User};
@ -103,7 +102,7 @@ impl UpstreamOAuth2Link {
provider.clone()
} else {
// Fetch on-the-fly
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let provider = repo
.upstream_oauth_provider()
.lookup(self.link.provider_id)
@ -122,7 +121,7 @@ impl UpstreamOAuth2Link {
user.clone()
} else if let Some(user_id) = &self.link.user_id {
// Fetch on-the-fly
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let user = repo
.user()
.lookup(*user_id)

View File

@ -22,10 +22,9 @@ use mas_storage::{
oauth2::OAuth2SessionRepository,
upstream_oauth2::UpstreamOAuthLinkRepository,
user::{BrowserSessionRepository, UserEmailRepository},
Pagination, Repository,
BoxRepository, Pagination,
};
use mas_storage_pg::PgRepository;
use sqlx::PgPool;
use tokio::sync::Mutex;
use super::{
compat_sessions::CompatSsoLogin, BrowserSession, Cursor, NodeCursor, NodeType, OAuth2Session,
@ -65,10 +64,9 @@ impl User {
&self,
ctx: &Context<'_>,
) -> Result<Option<UserEmail>, async_graphql::Error> {
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let mut user_email_repo = repo.user_email();
Ok(user_email_repo.get_primary(&self.0).await?.map(UserEmail))
}
@ -84,7 +82,7 @@ impl User {
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, CompatSsoLogin>, async_graphql::Error> {
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
query(
after,
@ -131,7 +129,7 @@ impl User {
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, BrowserSession>, async_graphql::Error> {
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
query(
after,
@ -178,7 +176,7 @@ impl User {
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, UserEmail, UserEmailsPagination>, async_graphql::Error> {
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
query(
after,
@ -229,7 +227,7 @@ impl User {
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, OAuth2Session>, async_graphql::Error> {
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
query(
after,
@ -276,7 +274,7 @@ impl User {
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, UpstreamOAuth2Link>, async_graphql::Error> {
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
query(
after,
@ -350,7 +348,7 @@ pub struct UserEmailsPagination(mas_data_model::User);
impl UserEmailsPagination {
/// Identifies the total count of items in the connection.
async fn total_count(&self, ctx: &Context<'_>) -> Result<usize, async_graphql::Error> {
let mut repo = PgRepository::from_pool(ctx.data::<PgPool>()?).await?;
let mut repo = ctx.data::<Mutex<BoxRepository>>()?.lock().await;
let count = repo.user_email().count(&self.0).await?;
Ok(count)
}

View File

@ -25,7 +25,7 @@ use mas_email::Mailer;
use mas_keystore::{Encrypter, Keystore};
use mas_policy::PolicyFactory;
use mas_router::UrlBuilder;
use mas_storage::{BoxClock, BoxRng, SystemClock};
use mas_storage::{BoxClock, BoxRepository, BoxRng, Repository, SystemClock};
use mas_storage_pg::PgRepository;
use mas_templates::Templates;
use rand::SeedableRng;
@ -156,7 +156,7 @@ impl IntoResponse for RepositoryError {
}
#[async_trait]
impl FromRequestParts<AppState> for PgRepository {
impl FromRequestParts<AppState> for BoxRepository {
type Rejection = RepositoryError;
async fn from_request_parts(
@ -164,6 +164,8 @@ impl FromRequestParts<AppState> for PgRepository {
state: &AppState,
) -> Result<Self, Self::Rejection> {
let repo = PgRepository::from_pool(&state.pool).await?;
Ok(repo)
Ok(repo
.map_err(mas_storage::RepositoryError::from_error)
.boxed())
}
}

View File

@ -22,9 +22,8 @@ use mas_storage::{
CompatSsoLoginRepository,
},
user::{UserPasswordRepository, UserRepository},
BoxClock, BoxRng, Clock, Repository,
BoxClock, BoxRepository, BoxRng, Clock,
};
use mas_storage_pg::PgRepository;
use rand::{CryptoRng, RngCore};
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, skip_serializing_none, DurationMilliSeconds};
@ -154,7 +153,7 @@ pub enum RouteError {
InvalidLoginToken,
}
impl_from_error_for_route!(mas_storage_pg::DatabaseError);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
@ -196,7 +195,7 @@ pub(crate) async fn post(
mut rng: BoxRng,
clock: BoxClock,
State(password_manager): State<PasswordManager>,
mut repo: PgRepository,
mut repo: BoxRepository,
State(homeserver): State<MatrixHomeserver>,
Json(input): Json<RequestBody>,
) -> Result<impl IntoResponse, RouteError> {
@ -262,7 +261,7 @@ pub(crate) async fn post(
}
async fn token_login(
repo: &mut PgRepository,
repo: &mut BoxRepository,
clock: &dyn Clock,
token: &str,
) -> Result<(CompatSession, User), RouteError> {
@ -331,7 +330,7 @@ async fn user_password_login(
mut rng: &mut (impl RngCore + CryptoRng + Send),
clock: &impl Clock,
password_manager: &PasswordManager,
repo: &mut PgRepository,
repo: &mut BoxRepository,
username: String,
password: String,
) -> Result<(CompatSession, User), RouteError> {

View File

@ -31,9 +31,8 @@ use mas_keystore::Encrypter;
use mas_router::{CompatLoginSsoAction, PostAuthAction, Route};
use mas_storage::{
compat::{CompatSessionRepository, CompatSsoLoginRepository},
BoxClock, BoxRng, Clock, Repository,
BoxClock, BoxRepository, BoxRng, Clock,
};
use mas_storage_pg::PgRepository;
use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates};
use serde::{Deserialize, Serialize};
use ulid::Ulid;
@ -55,7 +54,7 @@ pub struct Params {
pub async fn get(
mut rng: BoxRng,
clock: BoxClock,
mut repo: PgRepository,
mut repo: BoxRepository,
State(templates): State<Templates>,
cookie_jar: PrivateCookieJar<Encrypter>,
Path(id): Path<Ulid>,
@ -64,7 +63,7 @@ pub async fn get(
let (session_info, cookie_jar) = cookie_jar.session_info();
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let maybe_session = session_info.load_session(&mut repo).await?;
let maybe_session = session_info.load_session(&mut *repo).await?;
let session = if let Some(session) = maybe_session {
session
@ -117,7 +116,7 @@ pub async fn get(
pub async fn post(
mut rng: BoxRng,
clock: BoxClock,
mut repo: PgRepository,
mut repo: BoxRepository,
State(templates): State<Templates>,
cookie_jar: PrivateCookieJar<Encrypter>,
Path(id): Path<Ulid>,
@ -127,7 +126,7 @@ pub async fn post(
let (session_info, cookie_jar) = cookie_jar.session_info();
cookie_jar.verify_form(&clock, form)?;
let maybe_session = session_info.load_session(&mut repo).await?;
let maybe_session = session_info.load_session(&mut *repo).await?;
let session = if let Some(session) = maybe_session {
session

View File

@ -19,8 +19,7 @@ use axum::{
};
use hyper::StatusCode;
use mas_router::{CompatLoginSsoAction, CompatLoginSsoComplete, UrlBuilder};
use mas_storage::{compat::CompatSsoLoginRepository, BoxClock, BoxRng, Repository};
use mas_storage_pg::PgRepository;
use mas_storage::{compat::CompatSsoLoginRepository, BoxClock, BoxRepository, BoxRng};
use rand::distributions::{Alphanumeric, DistString};
use serde::Deserialize;
use serde_with::serde;
@ -48,7 +47,7 @@ pub enum RouteError {
InvalidRedirectUrl,
}
impl_from_error_for_route!(mas_storage_pg::DatabaseError);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
@ -59,7 +58,7 @@ impl IntoResponse for RouteError {
pub async fn get(
mut rng: BoxRng,
clock: BoxClock,
mut repo: PgRepository,
mut repo: BoxRepository,
State(url_builder): State<UrlBuilder>,
Query(params): Query<Params>,
) -> Result<impl IntoResponse, RouteError> {

View File

@ -18,9 +18,8 @@ use hyper::StatusCode;
use mas_data_model::TokenType;
use mas_storage::{
compat::{CompatAccessTokenRepository, CompatSessionRepository},
BoxClock, Clock, Repository,
BoxClock, BoxRepository, Clock,
};
use mas_storage_pg::PgRepository;
use thiserror::Error;
use super::MatrixError;
@ -41,7 +40,7 @@ pub enum RouteError {
InvalidAuthorization,
}
impl_from_error_for_route!(mas_storage_pg::DatabaseError);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
@ -68,7 +67,7 @@ impl IntoResponse for RouteError {
pub(crate) async fn post(
clock: BoxClock,
mut repo: PgRepository,
mut repo: BoxRepository,
maybe_authorization: Option<TypedHeader<Authorization<Bearer>>>,
) -> Result<impl IntoResponse, RouteError> {
let TypedHeader(authorization) = maybe_authorization.ok_or(RouteError::MissingAuthorization)?;

View File

@ -18,9 +18,8 @@ use hyper::StatusCode;
use mas_data_model::{TokenFormatError, TokenType};
use mas_storage::{
compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository},
BoxClock, BoxRng, Clock, Repository,
BoxClock, BoxRepository, BoxRng, Clock,
};
use mas_storage_pg::PgRepository;
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, DurationMilliSeconds};
use thiserror::Error;
@ -69,7 +68,7 @@ impl IntoResponse for RouteError {
}
}
impl_from_error_for_route!(mas_storage_pg::DatabaseError);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl From<TokenFormatError> for RouteError {
fn from(_e: TokenFormatError) -> Self {
@ -89,7 +88,7 @@ pub struct ResponseBody {
pub(crate) async fn post(
mut rng: BoxRng,
clock: BoxClock,
mut repo: PgRepository,
mut repo: BoxRepository,
Json(input): Json<RequestBody>,
) -> Result<impl IntoResponse, RouteError> {
let token_type = TokenType::check(&input.refresh_token)?;

View File

@ -22,20 +22,19 @@ use axum::{
Json, TypedHeader,
};
use axum_extra::extract::PrivateCookieJar;
use futures_util::{StreamExt, TryStreamExt};
use futures_util::TryStreamExt;
use headers::{ContentType, HeaderValue};
use hyper::header::CACHE_CONTROL;
use mas_axum_utils::{FancyError, SessionInfoExt};
use mas_graphql::Schema;
use mas_keystore::Encrypter;
use mas_storage_pg::PgRepository;
use sqlx::PgPool;
use mas_storage::BoxRepository;
use tokio::sync::Mutex;
use tracing::{info_span, Instrument};
#[must_use]
pub fn schema(pool: &PgPool) -> Schema {
pub fn schema() -> Schema {
mas_graphql::schema_builder()
.data(pool.clone())
.extension(Tracing)
.extension(ApolloTracing)
.finish()
@ -59,8 +58,8 @@ fn span_for_graphql_request(request: &async_graphql::Request) -> tracing::Span {
}
pub async fn post(
State(pool): State<PgPool>,
State(schema): State<Schema>,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
content_type: Option<TypedHeader<ContentType>>,
body: BodyStream,
@ -68,62 +67,46 @@ pub async fn post(
let content_type = content_type.map(|TypedHeader(h)| h.to_string());
let (session_info, _cookie_jar) = cookie_jar.session_info();
let mut repo = PgRepository::from_pool(&pool).await?;
let maybe_session = session_info.load_session(&mut repo).await?;
repo.cancel().await?;
let maybe_session = session_info.load_session(&mut *repo).await?;
let mut request = async_graphql::http::receive_batch_body(
let mut request = async_graphql::http::receive_body(
content_type,
body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
.into_async_read(),
MultipartOptions::default(),
)
.await?; // XXX: this should probably return another error response?
.await? // XXX: this should probably return another error response?
.data(Mutex::new(repo));
if let Some(session) = maybe_session {
request = request.data(session);
}
let response = match request {
async_graphql::BatchRequest::Single(request) => {
let span = span_for_graphql_request(&request);
let response = schema.execute(request).instrument(span).await;
async_graphql::BatchResponse::Single(response)
}
async_graphql::BatchRequest::Batch(requests) => async_graphql::BatchResponse::Batch(
futures_util::stream::iter(requests.into_iter())
.then(|request| {
let span = span_for_graphql_request(&request);
schema.execute(request).instrument(span)
})
.collect()
.await,
),
};
let cache_control = response
.cache_control()
.cache_control
.value()
.and_then(|v| HeaderValue::from_str(&v).ok())
.map(|h| [(CACHE_CONTROL, h)]);
let headers = response.http_headers();
let headers = response.http_headers.clone();
Ok((headers, cache_control, Json(response)))
}
pub async fn get(
State(pool): State<PgPool>,
State(schema): State<Schema>,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
RawQuery(query): RawQuery,
) -> Result<impl IntoResponse, FancyError> {
let (session_info, _cookie_jar) = cookie_jar.session_info();
let mut repo = PgRepository::from_pool(&pool).await?;
let maybe_session = session_info.load_session(&mut repo).await?;
repo.cancel().await?;
let maybe_session = session_info.load_session(&mut *repo).await?;
let mut request = async_graphql::http::parse_query_string(&query.unwrap_or_default())?;
let mut request =
async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(Mutex::new(repo));
if let Some(session) = maybe_session {
request = request.data(session);

View File

@ -43,8 +43,7 @@ use mas_http::CorsLayerExt;
use mas_keystore::{Encrypter, Keystore};
use mas_policy::PolicyFactory;
use mas_router::{Route, UrlBuilder};
use mas_storage::{BoxClock, BoxRng};
use mas_storage_pg::PgRepository;
use mas_storage::{BoxClock, BoxRepository, BoxRng};
use mas_templates::{ErrorContext, Templates};
use passwords::PasswordManager;
use sqlx::PgPool;
@ -98,7 +97,7 @@ where
<B as HttpBody>::Error: std::error::Error + Send + Sync,
S: Clone + Send + Sync + 'static,
mas_graphql::Schema: FromRef<S>,
PgPool: FromRef<S>,
BoxRepository: FromRequestParts<S>,
Encrypter: FromRef<S>,
{
let mut router = Router::new().route(
@ -158,7 +157,7 @@ where
Keystore: FromRef<S>,
UrlBuilder: FromRef<S>,
Arc<PolicyFactory>: FromRef<S>,
PgRepository: FromRequestParts<S>,
BoxRepository: FromRequestParts<S>,
Encrypter: FromRef<S>,
HttpClientFactory: FromRef<S>,
BoxClock: FromRequestParts<S>,
@ -213,7 +212,7 @@ where
<B as HttpBody>::Error: std::error::Error + Send + Sync,
S: Clone + Send + Sync + 'static,
UrlBuilder: FromRef<S>,
PgRepository: FromRequestParts<S>,
BoxRepository: FromRequestParts<S>,
MatrixHomeserver: FromRef<S>,
PasswordManager: FromRef<S>,
BoxClock: FromRequestParts<S>,
@ -258,7 +257,7 @@ where
S: Clone + Send + Sync + 'static,
UrlBuilder: FromRef<S>,
Arc<PolicyFactory>: FromRef<S>,
PgRepository: FromRequestParts<S>,
BoxRepository: FromRequestParts<S>,
Encrypter: FromRef<S>,
Templates: FromRef<S>,
Mailer: FromRef<S>,
@ -401,7 +400,7 @@ async fn test_state(pool: sqlx::PgPool) -> Result<AppState, anyhow::Error> {
let policy_factory = Arc::new(policy_factory);
let graphql_schema = graphql_schema(&pool);
let graphql_schema = graphql_schema();
let http_client_factory = HttpClientFactory::new(10);

View File

@ -27,9 +27,8 @@ use mas_policy::PolicyFactory;
use mas_router::{PostAuthAction, Route};
use mas_storage::{
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2SessionRepository},
BoxClock, BoxRng, Repository,
BoxClock, BoxRepository, BoxRng,
};
use mas_storage_pg::PgRepository;
use mas_templates::Templates;
use oauth2_types::requests::{AccessTokenResponse, AuthorizationResponse};
use thiserror::Error;
@ -69,7 +68,7 @@ impl IntoResponse for RouteError {
}
}
impl_from_error_for_route!(mas_storage_pg::DatabaseError);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::InstanciateError);
impl_from_error_for_route!(mas_policy::EvaluationError);
@ -81,13 +80,13 @@ pub(crate) async fn get(
clock: BoxClock,
State(policy_factory): State<Arc<PolicyFactory>>,
State(templates): State<Templates>,
mut repo: PgRepository,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
Path(grant_id): Path<Ulid>,
) -> Result<Response, RouteError> {
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?;
let maybe_session = session_info.load_session(&mut *repo).await?;
let grant = repo
.oauth2_authorization_grant()
@ -147,7 +146,7 @@ pub enum GrantCompletionError {
NoSuchClient,
}
impl_from_error_for_route!(GrantCompletionError: mas_storage_pg::DatabaseError);
impl_from_error_for_route!(GrantCompletionError: mas_storage::RepositoryError);
impl_from_error_for_route!(GrantCompletionError: super::callback::IntoCallbackDestinationError);
impl_from_error_for_route!(GrantCompletionError: mas_policy::LoadError);
impl_from_error_for_route!(GrantCompletionError: mas_policy::InstanciateError);
@ -159,7 +158,7 @@ pub(crate) async fn complete(
grant: AuthorizationGrant,
browser_session: BrowserSession,
policy_factory: &PolicyFactory,
mut repo: PgRepository,
mut repo: BoxRepository,
) -> Result<AuthorizationResponse<Option<AccessTokenResponse>>, GrantCompletionError> {
// Verify that the grant is in a pending stage
if !grant.stage.is_pending() {

View File

@ -27,9 +27,8 @@ use mas_policy::PolicyFactory;
use mas_router::{PostAuthAction, Route};
use mas_storage::{
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository},
BoxClock, BoxRng, Repository,
BoxClock, BoxRepository, BoxRng,
};
use mas_storage_pg::PgRepository;
use mas_templates::Templates;
use oauth2_types::{
errors::{ClientError, ClientErrorCode},
@ -90,7 +89,7 @@ impl IntoResponse for RouteError {
}
}
impl_from_error_for_route!(mas_storage_pg::DatabaseError);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(self::callback::CallbackDestinationError);
impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::InstanciateError);
@ -135,7 +134,7 @@ pub(crate) async fn get(
clock: BoxClock,
State(policy_factory): State<Arc<PolicyFactory>>,
State(templates): State<Templates>,
mut repo: PgRepository,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
Form(params): Form<Params>,
) -> Result<Response, RouteError> {
@ -168,7 +167,7 @@ pub(crate) async fn get(
let templates = templates.clone();
let callback_destination = callback_destination.clone();
async move {
let maybe_session = session_info.load_session(&mut repo).await?;
let maybe_session = session_info.load_session(&mut *repo).await?;
let prompt = params.auth.prompt.as_deref().unwrap_or_default();
// Check if the request/request_uri/registration params are used. If so, reply

View File

@ -30,9 +30,8 @@ use mas_policy::PolicyFactory;
use mas_router::{PostAuthAction, Route};
use mas_storage::{
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository},
BoxClock, BoxRng, Repository,
BoxClock, BoxRepository, BoxRng,
};
use mas_storage_pg::PgRepository;
use mas_templates::{ConsentContext, PolicyViolationContext, TemplateContext, Templates};
use thiserror::Error;
use ulid::Ulid;
@ -61,7 +60,7 @@ pub enum RouteError {
}
impl_from_error_for_route!(mas_templates::TemplateError);
impl_from_error_for_route!(mas_storage_pg::DatabaseError);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::InstanciateError);
impl_from_error_for_route!(mas_policy::EvaluationError);
@ -77,13 +76,13 @@ pub(crate) async fn get(
clock: BoxClock,
State(policy_factory): State<Arc<PolicyFactory>>,
State(templates): State<Templates>,
mut repo: PgRepository,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
Path(grant_id): Path<Ulid>,
) -> Result<Response, RouteError> {
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?;
let maybe_session = session_info.load_session(&mut *repo).await?;
let grant = repo
.oauth2_authorization_grant()
@ -130,7 +129,7 @@ pub(crate) async fn post(
mut rng: BoxRng,
clock: BoxClock,
State(policy_factory): State<Arc<PolicyFactory>>,
mut repo: PgRepository,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
Path(grant_id): Path<Ulid>,
Form(form): Form<ProtectedForm<()>>,
@ -139,7 +138,7 @@ pub(crate) async fn post(
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?;
let maybe_session = session_info.load_session(&mut *repo).await?;
let grant = repo
.oauth2_authorization_grant()

View File

@ -25,9 +25,8 @@ use mas_storage::{
compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository},
oauth2::{OAuth2AccessTokenRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository},
user::{BrowserSessionRepository, UserRepository},
BoxClock, Clock, Repository,
BoxClock, BoxRepository, Clock,
};
use mas_storage_pg::PgRepository;
use oauth2_types::{
errors::{ClientError, ClientErrorCode},
requests::{IntrospectionRequest, IntrospectionResponse},
@ -96,7 +95,7 @@ impl IntoResponse for RouteError {
}
}
impl_from_error_for_route!(mas_storage_pg::DatabaseError);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl From<TokenFormatError> for RouteError {
fn from(_e: TokenFormatError) -> Self {
@ -125,13 +124,13 @@ const API_SCOPE: ScopeToken = ScopeToken::from_static("urn:matrix:org.matrix.msc
pub(crate) async fn post(
clock: BoxClock,
State(http_client_factory): State<HttpClientFactory>,
mut repo: PgRepository,
mut repo: BoxRepository,
State(encrypter): State<Encrypter>,
client_authorization: ClientAuthorization<IntrospectionRequest>,
) -> Result<impl IntoResponse, RouteError> {
let client = client_authorization
.credentials
.fetch(&mut repo)
.fetch(&mut *repo)
.await
.unwrap()
.ok_or(RouteError::ClientNotFound)?;

View File

@ -19,8 +19,7 @@ use hyper::StatusCode;
use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_keystore::Encrypter;
use mas_policy::{PolicyFactory, Violation};
use mas_storage::{oauth2::OAuth2ClientRepository, BoxClock, BoxRng, Repository};
use mas_storage_pg::PgRepository;
use mas_storage::{oauth2::OAuth2ClientRepository, BoxClock, BoxRepository, BoxRng};
use oauth2_types::{
errors::{ClientError, ClientErrorCode},
registration::{
@ -48,7 +47,7 @@ pub(crate) enum RouteError {
PolicyDenied(Vec<Violation>),
}
impl_from_error_for_route!(mas_storage_pg::DatabaseError);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::InstanciateError);
impl_from_error_for_route!(mas_policy::EvaluationError);
@ -108,7 +107,7 @@ impl IntoResponse for RouteError {
pub(crate) async fn post(
mut rng: BoxRng,
clock: BoxClock,
mut repo: PgRepository,
mut repo: BoxRepository,
State(policy_factory): State<Arc<PolicyFactory>>,
State(encrypter): State<Encrypter>,
Json(body): Json<ClientMetadata>,

View File

@ -37,9 +37,8 @@ use mas_storage::{
OAuth2RefreshTokenRepository, OAuth2SessionRepository,
},
user::BrowserSessionRepository,
BoxClock, BoxRng, Clock, Repository,
BoxClock, BoxRepository, BoxRng, Clock,
};
use mas_storage_pg::PgRepository;
use oauth2_types::{
errors::{ClientError, ClientErrorCode},
pkce::CodeChallengeError,
@ -150,7 +149,7 @@ impl IntoResponse for RouteError {
}
}
impl_from_error_for_route!(mas_storage_pg::DatabaseError);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_keystore::WrongAlgorithmError);
impl_from_error_for_route!(mas_jose::claims::ClaimError);
impl_from_error_for_route!(mas_jose::claims::TokenHashError);
@ -163,13 +162,13 @@ pub(crate) async fn post(
State(http_client_factory): State<HttpClientFactory>,
State(key_store): State<Keystore>,
State(url_builder): State<UrlBuilder>,
mut repo: PgRepository,
mut repo: BoxRepository,
State(encrypter): State<Encrypter>,
client_authorization: ClientAuthorization<AccessTokenRequest>,
) -> Result<impl IntoResponse, RouteError> {
let client = client_authorization
.credentials
.fetch(&mut repo)
.fetch(&mut *repo)
.await?
.ok_or(RouteError::ClientNotFound)?;
@ -185,7 +184,7 @@ pub(crate) async fn post(
let form = client_authorization.form.ok_or(RouteError::BadRequest)?;
let reply = match form {
let (reply, repo) = match form {
AccessTokenRequest::AuthorizationCode(grant) => {
authorization_code_grant(
&mut rng,
@ -206,6 +205,8 @@ pub(crate) async fn post(
}
};
repo.save().await?;
let mut headers = HeaderMap::new();
headers.typed_insert(CacheControl::new().with_no_store());
headers.typed_insert(Pragma::no_cache());
@ -221,8 +222,8 @@ async fn authorization_code_grant(
client: &Client,
key_store: &Keystore,
url_builder: &UrlBuilder,
mut repo: PgRepository,
) -> Result<AccessTokenResponse, RouteError> {
mut repo: BoxRepository,
) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
let authz_grant = repo
.oauth2_authorization_grant()
.find_by_code(&grant.code)
@ -367,9 +368,7 @@ async fn authorization_code_grant(
.exchange(clock, authz_grant)
.await?;
repo.save().await?;
Ok(params)
Ok((params, repo))
}
async fn refresh_token_grant(
@ -377,8 +376,8 @@ async fn refresh_token_grant(
clock: &impl Clock,
grant: &RefreshTokenGrant,
client: &Client,
mut repo: PgRepository,
) -> Result<AccessTokenResponse, RouteError> {
mut repo: BoxRepository,
) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
let refresh_token = repo
.oauth2_refresh_token()
.find_by_token(&grant.refresh_token)
@ -439,7 +438,5 @@ async fn refresh_token_grant(
.with_refresh_token(new_refresh_token.refresh_token)
.with_scope(session.scope);
repo.save().await?;
Ok(params)
Ok((params, repo))
}

View File

@ -31,9 +31,8 @@ use mas_router::UrlBuilder;
use mas_storage::{
oauth2::OAuth2ClientRepository,
user::{BrowserSessionRepository, UserEmailRepository},
BoxClock, BoxRng, Repository,
BoxClock, BoxRepository, BoxRng,
};
use mas_storage_pg::PgRepository;
use oauth2_types::scope;
use serde::Serialize;
use serde_with::skip_serializing_none;
@ -65,7 +64,7 @@ pub enum RouteError {
#[error("failed to authenticate")]
AuthorizationVerificationError(
#[from] AuthorizationVerificationError<mas_storage_pg::DatabaseError>,
#[from] AuthorizationVerificationError<mas_storage::RepositoryError>,
),
#[error("no suitable key found for signing")]
@ -78,7 +77,7 @@ pub enum RouteError {
NoSuchBrowserSession,
}
impl_from_error_for_route!(mas_storage_pg::DatabaseError);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_keystore::WrongAlgorithmError);
impl_from_error_for_route!(mas_jose::jwt::JwtSignatureError);
@ -100,11 +99,11 @@ pub async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(url_builder): State<UrlBuilder>,
mut repo: PgRepository,
mut repo: BoxRepository,
State(key_store): State<Keystore>,
user_authorization: UserAuthorization,
) -> Result<Response, RouteError> {
let session = user_authorization.protected(&mut repo, &clock).await?;
let session = user_authorization.protected(&mut *repo, &clock).await?;
let browser_session = repo
.browser_session()

View File

@ -24,9 +24,8 @@ use mas_oidc_client::requests::authorization_code::AuthorizationRequestData;
use mas_router::UrlBuilder;
use mas_storage::{
upstream_oauth2::{UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository},
BoxClock, BoxRng, Repository,
BoxClock, BoxRepository, BoxRng,
};
use mas_storage_pg::PgRepository;
use thiserror::Error;
use ulid::Ulid;
@ -45,7 +44,7 @@ pub(crate) enum RouteError {
impl_from_error_for_route!(mas_http::ClientInitError);
impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError);
impl_from_error_for_route!(mas_oidc_client::error::AuthorizationError);
impl_from_error_for_route!(mas_storage_pg::DatabaseError);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
@ -60,7 +59,7 @@ pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(http_client_factory): State<HttpClientFactory>,
mut repo: PgRepository,
mut repo: BoxRepository,
State(url_builder): State<UrlBuilder>,
cookie_jar: PrivateCookieJar<Encrypter>,
Path(provider_id): Path<Ulid>,

View File

@ -30,9 +30,8 @@ use mas_storage::{
UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
UpstreamOAuthSessionRepository,
},
BoxClock, BoxRng, Clock, Repository,
BoxClock, BoxRepository, BoxRng, Clock,
};
use mas_storage_pg::PgRepository;
use oauth2_types::errors::ClientErrorCode;
use serde::Deserialize;
use thiserror::Error;
@ -99,7 +98,7 @@ pub(crate) enum RouteError {
Internal(Box<dyn std::error::Error>),
}
impl_from_error_for_route!(mas_storage_pg::DatabaseError);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_http::ClientInitError);
impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError);
impl_from_error_for_route!(mas_oidc_client::error::JwksError);
@ -123,7 +122,7 @@ pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(http_client_factory): State<HttpClientFactory>,
mut repo: PgRepository,
mut repo: BoxRepository,
State(url_builder): State<UrlBuilder>,
State(encrypter): State<Encrypter>,
State(keystore): State<Keystore>,

View File

@ -27,9 +27,8 @@ use mas_keystore::Encrypter;
use mas_storage::{
upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository},
user::{BrowserSessionRepository, UserRepository},
BoxClock, BoxRng, Repository,
BoxClock, BoxRepository, BoxRng,
};
use mas_storage_pg::PgRepository;
use mas_templates::{
EmptyContext, TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister,
UpstreamSuggestLink,
@ -72,7 +71,7 @@ pub(crate) enum RouteError {
impl_from_error_for_route!(mas_templates::TemplateError);
impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError);
impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
impl_from_error_for_route!(mas_storage_pg::DatabaseError);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
@ -95,7 +94,7 @@ pub(crate) enum FormData {
pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
mut repo: PgRepository,
mut repo: BoxRepository,
State(templates): State<Templates>,
cookie_jar: PrivateCookieJar<Encrypter>,
Path(link_id): Path<Ulid>,
@ -129,7 +128,7 @@ pub(crate) async fn get(
let (user_session_info, cookie_jar) = cookie_jar.session_info();
let (csrf_token, mut cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let maybe_user_session = user_session_info.load_session(&mut repo).await?;
let maybe_user_session = user_session_info.load_session(&mut *repo).await?;
let render = match (maybe_user_session, link.user_id) {
(Some(session), Some(user_id)) if session.user.id == user_id => {
@ -211,7 +210,7 @@ pub(crate) async fn get(
pub(crate) async fn post(
mut rng: BoxRng,
clock: BoxClock,
mut repo: PgRepository,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
Path(link_id): Path<Ulid>,
Form(form): Form<ProtectedForm<FormData>>,
@ -250,7 +249,7 @@ pub(crate) async fn post(
}
let (user_session_info, cookie_jar) = cookie_jar.session_info();
let maybe_user_session = user_session_info.load_session(&mut repo).await?;
let maybe_user_session = user_session_info.load_session(&mut *repo).await?;
let session = match (maybe_user_session, link.user_id, form) {
(Some(session), None, FormData::Link) => {

View File

@ -24,8 +24,7 @@ use mas_axum_utils::{
use mas_email::Mailer;
use mas_keystore::Encrypter;
use mas_router::Route;
use mas_storage::{user::UserEmailRepository, BoxClock, BoxRng, Repository};
use mas_storage_pg::PgRepository;
use mas_storage::{user::UserEmailRepository, BoxClock, BoxRepository, BoxRng};
use mas_templates::{EmailAddContext, TemplateContext, Templates};
use serde::Deserialize;
@ -41,13 +40,13 @@ pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(templates): State<Templates>,
mut repo: PgRepository,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> {
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?;
let maybe_session = session_info.load_session(&mut *repo).await?;
let session = if let Some(session) = maybe_session {
session
@ -68,7 +67,7 @@ pub(crate) async fn get(
pub(crate) async fn post(
mut rng: BoxRng,
clock: BoxClock,
mut repo: PgRepository,
mut repo: BoxRepository,
State(mailer): State<Mailer>,
cookie_jar: PrivateCookieJar<Encrypter>,
Query(query): Query<OptionalPostAuthAction>,
@ -77,7 +76,7 @@ pub(crate) async fn post(
let form = cookie_jar.verify_form(&clock, form)?;
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?;
let maybe_session = session_info.load_session(&mut *repo).await?;
let session = if let Some(session) = maybe_session {
session
@ -99,7 +98,7 @@ pub(crate) async fn post(
};
start_email_verification(
&mailer,
&mut repo,
&mut *repo,
&mut rng,
&clock,
&session.user,

View File

@ -28,8 +28,7 @@ use mas_data_model::{BrowserSession, User, UserEmail};
use mas_email::Mailer;
use mas_keystore::Encrypter;
use mas_router::Route;
use mas_storage::{user::UserEmailRepository, BoxClock, BoxRng, Clock, Repository};
use mas_storage_pg::PgRepository;
use mas_storage::{user::UserEmailRepository, BoxClock, BoxRepository, BoxRng, Clock, Repository};
use mas_templates::{AccountEmailsContext, EmailVerificationContext, TemplateContext, Templates};
use rand::{distributions::Uniform, Rng};
use serde::Deserialize;
@ -51,28 +50,28 @@ pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(templates): State<Templates>,
mut repo: PgRepository,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> {
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?;
let maybe_session = session_info.load_session(&mut *repo).await?;
if let Some(session) = maybe_session {
render(&mut rng, &clock, templates, session, cookie_jar, &mut repo).await
render(&mut rng, &clock, templates, session, cookie_jar, &mut *repo).await
} else {
let login = mas_router::Login::default();
Ok((cookie_jar, login.go()).into_response())
}
}
async fn render(
async fn render<E: std::error::Error>(
rng: impl Rng + Send,
clock: &impl Clock,
templates: Templates,
session: BrowserSession,
cookie_jar: PrivateCookieJar<Encrypter>,
repo: &mut impl Repository,
repo: &mut (impl Repository<Error = E> + ?Sized),
) -> Result<Response, FancyError> {
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock, rng);
@ -87,9 +86,9 @@ async fn render(
Ok((cookie_jar, Html(content)).into_response())
}
async fn start_email_verification(
async fn start_email_verification<E: std::error::Error + Send + Sync + 'static>(
mailer: &Mailer,
repo: &mut impl Repository,
repo: &mut (impl Repository<Error = E> + ?Sized),
mut rng: impl Rng + Send,
clock: &impl Clock,
user: &User,
@ -124,14 +123,14 @@ pub(crate) async fn post(
mut rng: BoxRng,
clock: BoxClock,
State(templates): State<Templates>,
mut repo: PgRepository,
mut repo: BoxRepository,
State(mailer): State<Mailer>,
cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<ManagementForm>>,
) -> Result<Response, FancyError> {
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?;
let maybe_session = session_info.load_session(&mut *repo).await?;
let mut session = if let Some(session) = maybe_session {
session
@ -150,7 +149,7 @@ pub(crate) async fn post(
.await?;
let next = mas_router::AccountVerifyEmail::new(email.id);
start_email_verification(&mailer, &mut repo, &mut rng, &clock, &session.user, email)
start_email_verification(&mailer, &mut *repo, &mut rng, &clock, &session.user, email)
.await?;
repo.save().await?;
return Ok((cookie_jar, next.go()).into_response());
@ -169,7 +168,7 @@ pub(crate) async fn post(
}
let next = mas_router::AccountVerifyEmail::new(email.id);
start_email_verification(&mailer, &mut repo, &mut rng, &clock, &session.user, email)
start_email_verification(&mailer, &mut *repo, &mut rng, &clock, &session.user, email)
.await?;
repo.save().await?;
return Ok((cookie_jar, next.go()).into_response());
@ -212,7 +211,7 @@ pub(crate) async fn post(
templates.clone(),
session,
cookie_jar,
&mut repo,
&mut *repo,
)
.await?;

View File

@ -24,8 +24,7 @@ use mas_axum_utils::{
};
use mas_keystore::Encrypter;
use mas_router::Route;
use mas_storage::{user::UserEmailRepository, BoxClock, BoxRng, Repository};
use mas_storage_pg::PgRepository;
use mas_storage::{user::UserEmailRepository, BoxClock, BoxRepository, BoxRng};
use mas_templates::{EmailVerificationPageContext, TemplateContext, Templates};
use serde::Deserialize;
use ulid::Ulid;
@ -41,7 +40,7 @@ pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(templates): State<Templates>,
mut repo: PgRepository,
mut repo: BoxRepository,
Query(query): Query<OptionalPostAuthAction>,
Path(id): Path<Ulid>,
cookie_jar: PrivateCookieJar<Encrypter>,
@ -49,7 +48,7 @@ pub(crate) async fn get(
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?;
let maybe_session = session_info.load_session(&mut *repo).await?;
let session = if let Some(session) = maybe_session {
session
@ -82,7 +81,7 @@ pub(crate) async fn get(
pub(crate) async fn post(
clock: BoxClock,
mut repo: PgRepository,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
Query(query): Query<OptionalPostAuthAction>,
Path(id): Path<Ulid>,
@ -91,7 +90,7 @@ pub(crate) async fn post(
let form = cookie_jar.verify_form(&clock, form)?;
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?;
let maybe_session = session_info.load_session(&mut *repo).await?;
let session = if let Some(session) = maybe_session {
session

View File

@ -25,22 +25,21 @@ use mas_keystore::Encrypter;
use mas_router::Route;
use mas_storage::{
user::{BrowserSessionRepository, UserEmailRepository},
BoxClock, BoxRng, Repository,
BoxClock, BoxRepository, BoxRng,
};
use mas_storage_pg::PgRepository;
use mas_templates::{AccountContext, TemplateContext, Templates};
pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(templates): State<Templates>,
mut repo: PgRepository,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> {
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?;
let maybe_session = session_info.load_session(&mut *repo).await?;
let session = if let Some(session) = maybe_session {
session

View File

@ -27,9 +27,8 @@ use mas_keystore::Encrypter;
use mas_router::Route;
use mas_storage::{
user::{BrowserSessionRepository, UserPasswordRepository},
BoxClock, BoxRng, Clock, Repository,
BoxClock, BoxRepository, BoxRng, Clock,
};
use mas_storage_pg::PgRepository;
use mas_templates::{EmptyContext, TemplateContext, Templates};
use rand::Rng;
use serde::Deserialize;
@ -48,12 +47,12 @@ pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(templates): State<Templates>,
mut repo: PgRepository,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> {
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?;
let maybe_session = session_info.load_session(&mut *repo).await?;
if let Some(session) = maybe_session {
render(&mut rng, &clock, templates, session, cookie_jar).await
@ -86,7 +85,7 @@ pub(crate) async fn post(
clock: BoxClock,
State(password_manager): State<PasswordManager>,
State(templates): State<Templates>,
mut repo: PgRepository,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<ChangeForm>>,
) -> Result<Response, FancyError> {
@ -94,7 +93,7 @@ pub(crate) async fn post(
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?;
let maybe_session = session_info.load_session(&mut *repo).await?;
let session = if let Some(session) = maybe_session {
session

View File

@ -20,8 +20,7 @@ use axum_extra::extract::PrivateCookieJar;
use mas_axum_utils::{csrf::CsrfExt, FancyError, SessionInfoExt};
use mas_keystore::Encrypter;
use mas_router::UrlBuilder;
use mas_storage::{BoxClock, BoxRng};
use mas_storage_pg::PgRepository;
use mas_storage::{BoxClock, BoxRepository, BoxRng};
use mas_templates::{IndexContext, TemplateContext, Templates};
pub async fn get(
@ -29,12 +28,12 @@ pub async fn get(
clock: BoxClock,
State(templates): State<Templates>,
State(url_builder): State<UrlBuilder>,
mut repo: PgRepository,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<impl IntoResponse, FancyError> {
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (session_info, cookie_jar) = cookie_jar.session_info();
let session = session_info.load_session(&mut repo).await?;
let session = session_info.load_session(&mut *repo).await?;
let ctx = IndexContext::new(url_builder.oidc_discovery())
.maybe_with_session(session)

View File

@ -26,9 +26,8 @@ use mas_keystore::Encrypter;
use mas_storage::{
upstream_oauth2::UpstreamOAuthProviderRepository,
user::{BrowserSessionRepository, UserPasswordRepository, UserRepository},
BoxClock, BoxRng, Clock, Repository,
BoxClock, BoxRepository, BoxRng, Clock, Repository,
};
use mas_storage_pg::PgRepository;
use mas_templates::{
FieldError, FormError, LoginContext, LoginFormField, TemplateContext, Templates, ToFormState,
};
@ -53,14 +52,14 @@ pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(templates): State<Templates>,
mut repo: PgRepository,
mut repo: BoxRepository,
Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> {
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?;
let maybe_session = session_info.load_session(&mut *repo).await?;
if maybe_session.is_some() {
let reply = query.go_next();
@ -71,7 +70,7 @@ pub(crate) async fn get(
LoginContext::default().with_upstrem_providers(providers),
query,
csrf_token,
&mut repo,
&mut *repo,
&templates,
)
.await?;
@ -85,7 +84,7 @@ pub(crate) async fn post(
clock: BoxClock,
State(password_manager): State<PasswordManager>,
State(templates): State<Templates>,
mut repo: PgRepository,
mut repo: BoxRepository,
Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<LoginForm>>,
@ -117,7 +116,7 @@ pub(crate) async fn post(
.with_upstrem_providers(providers),
query,
csrf_token,
&mut repo,
&mut *repo,
&templates,
)
.await?;
@ -127,7 +126,7 @@ pub(crate) async fn post(
match login(
password_manager,
&mut repo,
&mut *repo,
rng,
&clock,
&form.username,
@ -149,7 +148,7 @@ pub(crate) async fn post(
LoginContext::default().with_form_state(state),
query,
csrf_token,
&mut repo,
&mut *repo,
&templates,
)
.await?;
@ -162,7 +161,7 @@ pub(crate) async fn post(
// TODO: move that logic elsewhere?
async fn login(
password_manager: PasswordManager,
repo: &mut impl Repository,
repo: &mut (impl Repository + ?Sized),
mut rng: impl Rng + CryptoRng + Send,
clock: &impl Clock,
username: &str,
@ -236,7 +235,7 @@ async fn render(
ctx: LoginContext,
action: OptionalPostAuthAction,
csrf_token: CsrfToken,
repo: &mut impl Repository,
repo: &mut (impl Repository + ?Sized),
templates: &Templates,
) -> Result<String, FancyError> {
let next = action.load_context(repo).await?;

View File

@ -20,12 +20,11 @@ use mas_axum_utils::{
};
use mas_keystore::Encrypter;
use mas_router::{PostAuthAction, Route};
use mas_storage::{user::BrowserSessionRepository, BoxClock, Repository};
use mas_storage_pg::PgRepository;
use mas_storage::{user::BrowserSessionRepository, BoxClock, BoxRepository};
pub(crate) async fn post(
clock: BoxClock,
mut repo: PgRepository,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<Option<PostAuthAction>>>,
) -> Result<impl IntoResponse, FancyError> {
@ -33,7 +32,7 @@ pub(crate) async fn post(
let (session_info, mut cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?;
let maybe_session = session_info.load_session(&mut *repo).await?;
if let Some(session) = maybe_session {
repo.browser_session().finish(&clock, session).await?;

View File

@ -26,9 +26,8 @@ use mas_keystore::Encrypter;
use mas_router::Route;
use mas_storage::{
user::{BrowserSessionRepository, UserPasswordRepository},
BoxClock, BoxRng, Repository,
BoxClock, BoxRepository, BoxRng,
};
use mas_storage_pg::PgRepository;
use mas_templates::{ReauthContext, TemplateContext, Templates};
use serde::Deserialize;
use zeroize::Zeroizing;
@ -45,14 +44,14 @@ pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(templates): State<Templates>,
mut repo: PgRepository,
mut repo: BoxRepository,
Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> {
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?;
let maybe_session = session_info.load_session(&mut *repo).await?;
let session = if let Some(session) = maybe_session {
session
@ -64,7 +63,7 @@ pub(crate) async fn get(
};
let ctx = ReauthContext::default();
let next = query.load_context(&mut repo).await?;
let next = query.load_context(&mut *repo).await?;
let ctx = if let Some(next) = next {
ctx.with_post_action(next)
} else {
@ -81,7 +80,7 @@ pub(crate) async fn post(
mut rng: BoxRng,
clock: BoxClock,
State(password_manager): State<PasswordManager>,
mut repo: PgRepository,
mut repo: BoxRepository,
Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<ReauthForm>>,
@ -90,7 +89,7 @@ pub(crate) async fn post(
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?;
let maybe_session = session_info.load_session(&mut *repo).await?;
let session = if let Some(session) = maybe_session {
session

View File

@ -33,9 +33,8 @@ use mas_policy::PolicyFactory;
use mas_router::Route;
use mas_storage::{
user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository},
BoxClock, BoxRng, Repository,
BoxClock, BoxRepository, BoxRng, Repository,
};
use mas_storage_pg::PgRepository;
use mas_templates::{
EmailVerificationContext, FieldError, FormError, RegisterContext, RegisterFormField,
TemplateContext, Templates, ToFormState,
@ -63,14 +62,14 @@ pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(templates): State<Templates>,
mut repo: PgRepository,
mut repo: BoxRepository,
Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> {
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?;
let maybe_session = session_info.load_session(&mut *repo).await?;
if maybe_session.is_some() {
let reply = query.go_next();
@ -80,7 +79,7 @@ pub(crate) async fn get(
RegisterContext::default(),
query,
csrf_token,
&mut repo,
&mut *repo,
&templates,
)
.await?;
@ -97,7 +96,7 @@ pub(crate) async fn post(
State(mailer): State<Mailer>,
State(policy_factory): State<Arc<PolicyFactory>>,
State(templates): State<Templates>,
mut repo: PgRepository,
mut repo: BoxRepository,
Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<RegisterForm>>,
@ -175,7 +174,7 @@ pub(crate) async fn post(
RegisterContext::default().with_form_state(state),
query,
csrf_token,
&mut repo,
&mut *repo,
&templates,
)
.await?;
@ -234,7 +233,7 @@ async fn render(
ctx: RegisterContext,
action: OptionalPostAuthAction,
csrf_token: CsrfToken,
repo: &mut impl Repository,
repo: &mut (impl Repository + ?Sized),
templates: &Templates,
) -> Result<String, FancyError> {
let next = action.load_context(repo).await?;

View File

@ -40,9 +40,9 @@ impl OptionalPostAuthAction {
self.go_next_or_default(&mas_router::Index)
}
pub async fn load_context<R: Repository>(
&self,
repo: &mut R,
pub async fn load_context<'a>(
&'a self,
repo: &'a mut (impl Repository + ?Sized),
) -> anyhow::Result<Option<PostAuthContext>> {
let Some(action) = self.post_auth_action.clone() else { return Ok(None) };
let ctx = match action {

View File

@ -13,6 +13,7 @@ serde = { version = "1.0.152", features = ["derive"] }
serde_json = "1.0.91"
thiserror = "1.0.38"
tracing = "0.1.37"
futures-util = "0.3.25"
rand = "0.8.5"
rand_chacha = "0.3.1"

View File

@ -103,7 +103,7 @@ mod tests {
const SECOND_TOKEN: &str = "second_access_token";
let mut rng = ChaChaRng::seed_from_u64(42);
let clock = MockClock::default();
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
// Create a user
let user = repo
@ -139,7 +139,7 @@ mod tests {
repo.save().await.unwrap();
{
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
// Adding the same token a second time should conflict
assert!(repo
.compat_access_token()
@ -156,7 +156,7 @@ mod tests {
}
// Grab a new repo
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
// Looking up via ID works
let token_lookup = repo
@ -223,7 +223,7 @@ mod tests {
const REFRESH_TOKEN: &str = "refresh_token";
let mut rng = ChaChaRng::seed_from_u64(42);
let clock = MockClock::default();
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
// Create a user
let user = repo

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use futures_util::{future::BoxFuture, FutureExt, TryFutureExt};
use mas_storage::{
compat::{
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
@ -59,21 +60,19 @@ impl PgRepository {
let txn = pool.begin().await?;
Ok(PgRepository { txn })
}
pub async fn save(self) -> Result<(), DatabaseError> {
self.txn.commit().await?;
Ok(())
}
pub async fn cancel(self) -> Result<(), DatabaseError> {
self.txn.rollback().await?;
Ok(())
}
}
impl Repository for PgRepository {
type Error = DatabaseError;
fn save(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
self.txn.commit().map_err(DatabaseError::from).boxed()
}
fn cancel(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
self.txn.rollback().map_err(DatabaseError::from).boxed()
}
fn upstream_oauth_link<'c>(
&'c mut self,
) -> Box<dyn UpstreamOAuthLinkRepository<Error = Self::Error> + 'c> {

View File

@ -29,7 +29,7 @@ use crate::PgRepository;
async fn test_user_repo(pool: PgPool) {
const USERNAME: &str = "john";
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
let mut rng = ChaChaRng::seed_from_u64(42);
let clock = MockClock::default();
@ -77,7 +77,7 @@ async fn test_user_email_repo(pool: PgPool) {
const CODE2: &str = "543210";
const EMAIL: &str = "john@example.com";
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
let mut rng = ChaChaRng::seed_from_u64(42);
let clock = MockClock::default();
@ -259,7 +259,7 @@ async fn test_user_password_repo(pool: PgPool) {
const FIRST_PASSWORD_HASH: &str = "doesntmatter";
const SECOND_PASSWORD_HASH: &str = "alsodoesntmatter";
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
let mut rng = ChaChaRng::seed_from_u64(42);
let clock = MockClock::default();

View File

@ -9,6 +9,7 @@ license = "Apache-2.0"
async-trait = "0.1.60"
chrono = "0.4.23"
thiserror = "1.0.38"
futures-util = "0.3.25"
rand_core = "0.6.4"
url = "2.3.1"

View File

@ -28,21 +28,21 @@
clippy::module_name_repetitions
)]
use rand_core::CryptoRngCore;
pub mod clock;
pub mod pagination;
pub(crate) mod repository;
pub mod compat;
pub mod oauth2;
pub mod pagination;
pub(crate) mod repository;
pub mod upstream_oauth2;
pub mod user;
use rand_core::CryptoRngCore;
pub use self::{
clock::{Clock, SystemClock},
pagination::{Page, Pagination},
repository::Repository,
repository::{BoxRepository, Repository, RepositoryError},
};
pub struct MapErr<Repository, Mapper> {
@ -86,7 +86,6 @@ macro_rules! repository_impl {
where
R: $repo_trait,
F: FnMut(<R as $repo_trait>::Error) -> E + ::std::marker::Send + ::std::marker::Sync,
E: ::std::error::Error + ::std::marker::Send + ::std::marker::Sync,
{
type Error = E;

View File

@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use futures_util::{future::BoxFuture, FutureExt, TryFutureExt};
use thiserror::Error;
use crate::{
compat::{
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
@ -32,6 +35,23 @@ use crate::{
pub trait Repository: Send {
type Error: std::error::Error + Send + Sync + 'static;
fn map_err<Mapper>(self, mapper: Mapper) -> MapErr<Self, Mapper>
where
Self: Sized,
{
MapErr::new(self, mapper)
}
fn boxed(self) -> BoxRepository<Self::Error>
where
Self: Sized + Sync + 'static,
{
Box::new(self)
}
fn save(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>>;
fn cancel(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>>;
fn upstream_oauth_link<'c>(
&'c mut self,
) -> Box<dyn UpstreamOAuthLinkRepository<Error = Self::Error> + 'c>;
@ -91,14 +111,44 @@ pub trait Repository: Send {
) -> Box<dyn CompatRefreshTokenRepository<Error = Self::Error> + 'c>;
}
/// An opaque, type-erased error
#[derive(Debug, Error)]
#[error(transparent)]
pub struct RepositoryError {
source: Box<dyn std::error::Error + Send + Sync + 'static>,
}
impl RepositoryError {
pub fn from_error<E>(value: E) -> Self
where
E: std::error::Error + Send + Sync + 'static,
{
Self {
source: Box::new(value),
}
}
}
pub type BoxRepository<E = RepositoryError> =
Box<dyn Repository<Error = E> + Send + Sync + 'static>;
impl<R, F, E> Repository for crate::MapErr<R, F>
where
R: Repository,
F: FnMut(R::Error) -> E + Send + Sync,
R::Error: 'static,
F: FnMut(R::Error) -> E + Send + Sync + 'static,
E: std::error::Error + Send + Sync + 'static,
{
type Error = E;
fn save(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
Box::new(self.inner).save().map_err(self.mapper).boxed()
}
fn cancel(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
Box::new(self.inner).cancel().map_err(self.mapper).boxed()
}
fn upstream_oauth_link<'c>(
&'c mut self,
) -> Box<dyn UpstreamOAuthLinkRepository<Error = Self::Error> + 'c> {

View File

@ -21,7 +21,7 @@ use crate::{pagination::Page, repository_impl, Clock, Pagination};
#[async_trait]
pub trait UpstreamOAuthLinkRepository: Send + Sync {
type Error: std::error::Error + Send + Sync;
type Error;
/// Lookup an upstream OAuth link by its ID
async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthLink>, Self::Error>;