1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Schedule jobs through the repository

This commit is contained in:
Quentin Gliech
2023-03-31 15:51:09 +02:00
parent cdd535ddc4
commit 1f748f7d1e
18 changed files with 305 additions and 84 deletions

12
Cargo.lock generated
View File

@ -3170,9 +3170,6 @@ name = "mas-handlers"
version = "0.1.0"
dependencies = [
"anyhow",
"apalis-core",
"apalis-cron",
"apalis-sql",
"argon2",
"async-graphql",
"axum",
@ -3189,7 +3186,6 @@ dependencies = [
"lettre",
"mas-axum-utils",
"mas-data-model",
"mas-email",
"mas-graphql",
"mas-http",
"mas-iana",
@ -3200,7 +3196,6 @@ dependencies = [
"mas-router",
"mas-storage",
"mas-storage-pg",
"mas-tasks",
"mas-templates",
"mime",
"oauth2-types",
@ -3454,6 +3449,7 @@ dependencies = [
name = "mas-storage"
version = "0.1.0"
dependencies = [
"apalis-core",
"async-trait",
"chrono",
"futures-util",
@ -3462,6 +3458,8 @@ dependencies = [
"mas-jose",
"oauth2-types",
"rand_core 0.6.4",
"serde",
"serde_json",
"thiserror",
"ulid",
"url",
@ -5332,9 +5330,9 @@ dependencies = [
[[package]]
name = "serde_json"
version = "1.0.94"
version = "1.0.95"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1c533a59c9d8a93a09c6ab31f0fd5e5f4dd1b8fc9434804029839884765d04ea"
checksum = "d721eca97ac802aa7777b701877c8004d950fc142651367300d21c1cc0194744"
dependencies = [
"itoa",
"ryu",

View File

@ -31,11 +31,6 @@ async-graphql = { version = "5.0.6", features = ["tracing", "apollo_tracing"] }
# Emails
lettre = { version = "0.10.3", default-features = false, features = ["builder"] }
# Job scheduling
apalis-core = "0.4.0-alpha.4"
apalis-cron = "0.4.0-alpha.4"
apalis-sql = { version = "0.4.0-alpha.4", features = ["postgres"] }
# Database access
sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres"] }
@ -64,7 +59,6 @@ ulid = "1.0.0"
mas-axum-utils = { path = "../axum-utils", default-features = false }
mas-data-model = { path = "../data-model" }
mas-email = { path = "../email" }
mas-graphql = { path = "../graphql" }
mas-http = { path = "../http", default-features = false }
mas-iana = { path = "../iana" }
@ -75,7 +69,6 @@ mas-policy = { path = "../policy" }
mas-router = { path = "../router" }
mas-storage = { path = "../storage" }
mas-storage-pg = { path = "../storage-pg" }
mas-tasks = { path = "../tasks" }
mas-templates = { path = "../templates" }
oauth2-types = { path = "../oauth2-types" }

View File

@ -107,12 +107,6 @@ impl FromRef<AppState> for PasswordManager {
}
}
impl<J: apalis_core::job::Job> FromRef<AppState> for apalis_sql::postgres::PostgresStorage<J> {
fn from_ref(input: &AppState) -> Self {
apalis_sql::postgres::PostgresStorage::new(input.pool.clone())
}
}
#[async_trait]
impl FromRequestParts<AppState> for BoxClock {
type Rejection = Infallible;

View File

@ -29,7 +29,6 @@
use std::{convert::Infallible, sync::Arc, time::Duration};
use apalis_sql::postgres::PostgresStorage;
use axum::{
body::{Bytes, HttpBody},
extract::{FromRef, FromRequestParts},
@ -44,7 +43,6 @@ use mas_keystore::{Encrypter, Keystore};
use mas_policy::PolicyFactory;
use mas_router::{Route, UrlBuilder};
use mas_storage::{BoxClock, BoxRepository, BoxRng};
use mas_tasks::VerifyEmailJob;
use mas_templates::{ErrorContext, Templates};
use passwords::PasswordManager;
use sqlx::PgPool;
@ -263,7 +261,6 @@ where
Keystore: FromRef<S>,
HttpClientFactory: FromRef<S>,
PasswordManager: FromRef<S>,
PostgresStorage<VerifyEmailJob>: FromRef<S>,
BoxClock: FromRequestParts<S>,
BoxRng: FromRequestParts<S>,
{

View File

@ -261,12 +261,6 @@ impl FromRef<TestState> for PasswordManager {
}
}
impl<J: apalis_core::job::Job> FromRef<TestState> for apalis_sql::postgres::PostgresStorage<J> {
fn from_ref(input: &TestState) -> Self {
apalis_sql::postgres::PostgresStorage::new(input.pool.clone())
}
}
#[async_trait]
impl FromRequestParts<TestState> for BoxClock {
type Rejection = Infallible;

View File

@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use apalis_core::storage::Storage;
use apalis_sql::postgres::PostgresStorage;
use axum::{
extract::{Form, Query, State},
response::{Html, IntoResponse, Response},
@ -25,8 +23,11 @@ use mas_axum_utils::{
};
use mas_keystore::Encrypter;
use mas_router::Route;
use mas_storage::{user::UserEmailRepository, BoxClock, BoxRepository, BoxRng};
use mas_tasks::VerifyEmailJob;
use mas_storage::{
job::{JobRepositoryExt, VerifyEmailJob},
user::UserEmailRepository,
BoxClock, BoxRepository, BoxRng,
};
use mas_templates::{EmailAddContext, TemplateContext, Templates};
use serde::Deserialize;
@ -69,7 +70,6 @@ pub(crate) async fn post(
mut rng: BoxRng,
clock: BoxClock,
mut repo: BoxRepository,
State(mut job_storage): State<PostgresStorage<VerifyEmailJob>>,
cookie_jar: PrivateCookieJar<Encrypter>,
Query(query): Query<OptionalPostAuthAction>,
Form(form): Form<ProtectedForm<EmailForm>>,
@ -96,10 +96,11 @@ pub(crate) async fn post(
next
};
repo.save().await?;
repo.job()
.schedule_job(VerifyEmailJob::new(&user_email))
.await?;
// XXX: this grabs a new connection from the pool, which is not ideal
job_storage.push(VerifyEmailJob::new(&user_email)).await?;
repo.save().await?;
Ok((cookie_jar, next.go()).into_response())
}

View File

@ -13,8 +13,6 @@
// limitations under the License.
use anyhow::{anyhow, Context};
use apalis_core::storage::Storage;
use apalis_sql::postgres::PostgresStorage;
use axum::{
extract::{Form, State},
response::{Html, IntoResponse, Response},
@ -28,9 +26,10 @@ use mas_data_model::BrowserSession;
use mas_keystore::Encrypter;
use mas_router::Route;
use mas_storage::{
user::UserEmailRepository, BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess,
job::{JobRepositoryExt, VerifyEmailJob},
user::UserEmailRepository,
BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess,
};
use mas_tasks::VerifyEmailJob;
use mas_templates::{AccountEmailsContext, TemplateContext, Templates};
use rand::Rng;
use serde::Deserialize;
@ -93,7 +92,6 @@ pub(crate) async fn post(
mut rng: BoxRng,
clock: BoxClock,
State(templates): State<Templates>,
State(mut job_storage): State<PostgresStorage<VerifyEmailJob>>,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<ManagementForm>>,
@ -111,37 +109,41 @@ pub(crate) async fn post(
match form {
ManagementForm::Add { email } => {
let email = repo
let user_email = repo
.user_email()
.add(&mut rng, &clock, &session.user, email)
.await?;
let next = mas_router::AccountVerifyEmail::new(email.id);
let next = mas_router::AccountVerifyEmail::new(user_email.id);
repo.job()
.schedule_job(VerifyEmailJob::new(&user_email))
.await?;
repo.save().await?;
// XXX: this grabs a new connection from the pool, which is not ideal
job_storage.push(VerifyEmailJob::new(&email)).await?;
return Ok((cookie_jar, next.go()).into_response());
}
ManagementForm::ResendConfirmation { id } => {
let id = id.parse()?;
let email = repo
let user_email = repo
.user_email()
.lookup(id)
.await?
.context("Email not found")?;
if email.user_id != session.user.id {
if user_email.user_id != session.user.id {
return Err(anyhow!("Email not found").into());
}
let next = mas_router::AccountVerifyEmail::new(email.id);
let next = mas_router::AccountVerifyEmail::new(user_email.id);
// XXX: this grabs a new connection from the pool, which is not ideal
job_storage.push(VerifyEmailJob::new(&email)).await?;
repo.job()
.schedule_job(VerifyEmailJob::new(&user_email))
.await?;
repo.save().await?;
return Ok((cookie_jar, next.go()).into_response());
}

View File

@ -14,8 +14,6 @@
use std::{str::FromStr, sync::Arc};
use apalis_core::storage::Storage;
use apalis_sql::postgres::PostgresStorage;
use axum::{
extract::{Form, Query, State},
response::{Html, IntoResponse, Response},
@ -30,10 +28,10 @@ use mas_keystore::Encrypter;
use mas_policy::PolicyFactory;
use mas_router::Route;
use mas_storage::{
job::{JobRepositoryExt, VerifyEmailJob},
user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository},
BoxClock, BoxRepository, BoxRng, RepositoryAccess,
};
use mas_tasks::VerifyEmailJob;
use mas_templates::{
FieldError, FormError, RegisterContext, RegisterFormField, TemplateContext, Templates,
ToFormState,
@ -96,7 +94,6 @@ pub(crate) async fn post(
State(policy_factory): State<Arc<PolicyFactory>>,
State(templates): State<Templates>,
mut repo: BoxRepository,
State(mut job_storage): State<PostgresStorage<VerifyEmailJob>>,
Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<RegisterForm>>,
@ -204,10 +201,11 @@ pub(crate) async fn post(
.authenticate_with_password(&mut rng, &clock, session, &user_password)
.await?;
repo.save().await?;
repo.job()
.schedule_job(VerifyEmailJob::new(&user_email))
.await?;
// XXX: this grabs a new connection from the pool, which is not ideal
job_storage.push(VerifyEmailJob::new(&user_email)).await?;
repo.save().await?;
let cookie_jar = cookie_jar.set_session(&session);
Ok((cookie_jar, next.go()).into_response())

View File

@ -1799,6 +1799,27 @@
},
"query": "\n UPDATE oauth2_sessions\n SET finished_at = $2\n WHERE oauth2_session_id = $1\n "
},
"b753790eecbbb4bcd87b9e9a1d1b0dd6c3b50e82ffbfee356e2cf755d72f00be": {
"describe": {
"columns": [
{
"name": "id!",
"ordinal": 0,
"type_info": "Text"
}
],
"nullable": [
null
],
"parameters": {
"Left": [
"Text",
"Json"
]
}
},
"query": "\n SELECT id as \"id!\"\n FROM apalis.push_job($1::text, $2::json, 'Pending', now(), 25)\n "
},
"b9875a270f7e753e48075ccae233df6e24a91775ceb877735508c1d5b2300d64": {
"describe": {
"columns": [],

View File

@ -0,0 +1,76 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! A module containing the PostgreSQL implementation of the [`JobRepository`].
use async_trait::async_trait;
use mas_storage::job::{JobId, JobRepository, JobSubmission};
use sqlx::PgConnection;
use crate::{errors::DatabaseInconsistencyError, DatabaseError, ExecuteExt};
/// An implementation of [`JobRepository`] for a PostgreSQL connection.
pub struct PgJobRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgJobRepository<'c> {
/// Create a new [`PgJobRepository`] from an active PostgreSQL connection.
#[must_use]
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
#[async_trait]
impl<'c> JobRepository for PgJobRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.job.schedule_submission",
skip_all,
fields(
db.statement,
job.id,
job.name,
),
err,
)]
async fn schedule_submission(
&mut self,
submission: JobSubmission,
) -> Result<JobId, Self::Error> {
// XXX: The apalis.push_job function is not unique, so we have to specify all
// the arguments
let res = sqlx::query_scalar!(
r#"
SELECT id as "id!"
FROM apalis.push_job($1::text, $2::json, 'Pending', now(), 25)
"#,
submission.name(),
submission.payload(),
)
.traced()
.fetch_one(&mut *self.conn)
.await?;
let id = res
.parse()
.map_err(|source| DatabaseInconsistencyError::on("apalis.push_job").source(source))?;
tracing::Span::current().record("job.id", tracing::field::display(&id));
Ok(id)
}
}

View File

@ -209,6 +209,7 @@ impl<T> LookupResultExt for Result<T, sqlx::Error> {
}
pub mod compat;
pub mod job;
pub mod oauth2;
pub mod upstream_oauth2;
pub mod user;

View File

@ -18,6 +18,7 @@ use mas_storage::{
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
CompatSsoLoginRepository,
},
job::JobRepository,
oauth2::{
OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository,
OAuth2RefreshTokenRepository, OAuth2SessionRepository,
@ -36,6 +37,7 @@ use crate::{
PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository, PgCompatSessionRepository,
PgCompatSsoLoginRepository,
},
job::PgJobRepository,
oauth2::{
PgOAuth2AccessTokenRepository, PgOAuth2AuthorizationGrantRepository,
PgOAuth2ClientRepository, PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository,
@ -178,4 +180,8 @@ impl RepositoryAccess for PgRepository {
) -> Box<dyn CompatRefreshTokenRepository<Error = Self::Error> + 'c> {
Box::new(PgCompatRefreshTokenRepository::new(&mut self.txn))
}
fn job<'c>(&'c mut self) -> Box<dyn JobRepository<Error = Self::Error> + 'c> {
Box::new(PgJobRepository::new(&mut self.txn))
}
}

View File

@ -11,7 +11,10 @@ chrono = "0.4.24"
thiserror = "1.0.39"
futures-util = "0.3.27"
apalis-core = { version = "0.4.0-alpha.4", features = ["tokio-comp"] }
rand_core = "0.6.4"
serde = "1.0.159"
serde_json = "1.0.95"
url = "2.3.1"
ulid = "1.0.0"

146
crates/storage/src/job.rs Normal file
View File

@ -0,0 +1,146 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Repository to schedule persistent jobs.
pub use apalis_core::job::{Job, JobId};
use async_trait::async_trait;
use serde::Serialize;
use serde_json::Value;
use crate::repository_impl;
/// A job submission to be scheduled through the repository.
pub struct JobSubmission {
name: &'static str,
payload: Value,
}
impl JobSubmission {
/// Create a new job submission out of a [`Job`].
///
/// # Panics
///
/// Panics if the job cannot be serialized.
#[must_use]
pub fn new<J: Job + Serialize>(job: J) -> Self {
Self {
name: J::NAME,
payload: serde_json::to_value(job).expect("failed to serialize job"),
}
}
/// The name of the job.
#[must_use]
pub fn name(&self) -> &'static str {
self.name
}
/// The payload of the job.
#[must_use]
pub fn payload(&self) -> &Value {
&self.payload
}
}
/// A [`JobRepository`] is used to schedule jobs to be executed by a worker.
#[async_trait]
pub trait JobRepository: Send + Sync {
/// The error type returned by the repository.
type Error;
/// Schedule a job submission to be executed at a later time.
///
/// # Parameters
///
/// * `submission` - The job to schedule.
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn schedule_submission(
&mut self,
submission: JobSubmission,
) -> Result<JobId, Self::Error>;
}
repository_impl!(JobRepository:
async fn schedule_submission(&mut self, submission: JobSubmission) -> Result<JobId, Self::Error>;
);
/// An extension trait for [`JobRepository`] to schedule jobs directly.
#[async_trait]
pub trait JobRepositoryExt {
/// The error type returned by the repository.
type Error;
/// Schedule a job to be executed at a later time.
///
/// # Parameters
///
/// * `job` - The job to schedule.
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn schedule_job<J: Job + Serialize>(&mut self, job: J) -> Result<JobId, Self::Error>;
}
#[async_trait]
impl<T> JobRepositoryExt for T
where
T: JobRepository + ?Sized,
{
type Error = T::Error;
async fn schedule_job<J: Job + Serialize>(&mut self, job: J) -> Result<JobId, Self::Error> {
self.schedule_submission(JobSubmission::new(job)).await
}
}
mod jobs {
// XXX: Move this somewhere else?
use apalis_core::job::Job;
use mas_data_model::UserEmail;
use serde::{Deserialize, Serialize};
use ulid::Ulid;
/// A job to verify an email address.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct VerifyEmailJob {
user_email_id: Ulid,
}
impl VerifyEmailJob {
/// Create a new job to verify an email address.
#[must_use]
pub fn new(user_email: &UserEmail) -> Self {
Self {
user_email_id: user_email.id,
}
}
/// The ID of the email address to verify.
#[must_use]
pub fn user_email_id(&self) -> Ulid {
self.user_email_id
}
}
impl Job for VerifyEmailJob {
const NAME: &'static str = "verify-email";
}
}
pub use self::jobs::VerifyEmailJob;

View File

@ -150,6 +150,7 @@ pub(crate) mod repository;
mod utils;
pub mod compat;
pub mod job;
pub mod oauth2;
pub mod upstream_oauth2;
pub mod user;

View File

@ -20,6 +20,7 @@ use crate::{
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
CompatSsoLoginRepository,
},
job::JobRepository,
oauth2::{
OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository,
OAuth2RefreshTokenRepository, OAuth2SessionRepository,
@ -192,6 +193,9 @@ pub trait RepositoryAccess: Send {
fn compat_refresh_token<'c>(
&'c mut self,
) -> Box<dyn CompatRefreshTokenRepository<Error = Self::Error> + 'c>;
/// Get a [`JobRepository`]
fn job<'c>(&'c mut self) -> Box<dyn JobRepository<Error = Self::Error> + 'c>;
}
/// Implementations of the [`RepositoryAccess`], [`RepositoryTransaction`] and
@ -205,6 +209,7 @@ mod impls {
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
CompatSsoLoginRepository,
},
job::JobRepository,
oauth2::{
OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository,
OAuth2ClientRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository,
@ -373,6 +378,10 @@ mod impls {
&mut self.mapper,
))
}
fn job<'c>(&'c mut self) -> Box<dyn JobRepository<Error = Self::Error> + 'c> {
Box::new(MapErr::new(self.inner.job(), &mut self.mapper))
}
}
impl<R: RepositoryAccess + ?Sized> RepositoryAccess for Box<R> {
@ -469,5 +478,9 @@ mod impls {
) -> Box<dyn CompatRefreshTokenRepository<Error = Self::Error> + 'c> {
(**self).compat_refresh_token()
}
fn job<'c>(&'c mut self) -> Box<dyn JobRepository<Error = Self::Error> + 'c> {
(**self).job()
}
}
}

View File

@ -17,39 +17,18 @@ use apalis_core::{
builder::{WorkerBuilder, WorkerFactory},
context::JobContext,
executor::TokioExecutor,
job::Job,
job_fn::job_fn,
monitor::Monitor,
storage::builder::WithStorage,
};
use chrono::Duration;
use mas_data_model::UserEmail;
use mas_email::{Address, EmailVerificationContext, Mailbox};
use mas_storage::job::VerifyEmailJob;
use rand::{distributions::Uniform, Rng};
use serde::{Deserialize, Serialize};
use tracing::info;
use ulid::Ulid;
use crate::{JobContextExt, State};
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct VerifyEmailJob {
user_email_id: Ulid,
}
impl VerifyEmailJob {
#[must_use]
pub fn new(user_email: &UserEmail) -> Self {
Self {
user_email_id: user_email.id,
}
}
}
impl Job for VerifyEmailJob {
const NAME: &'static str = "verify-email";
}
async fn verify_email(job: VerifyEmailJob, ctx: JobContext) -> Result<(), anyhow::Error> {
let state = ctx.state();
let mut repo = state.repository().await?;
@ -60,7 +39,7 @@ async fn verify_email(job: VerifyEmailJob, ctx: JobContext) -> Result<(), anyhow
// Lookup the user email
let user_email = repo
.user_email()
.lookup(job.user_email_id)
.lookup(job.user_email_id())
.await?
.context("User email not found")?;

View File

@ -28,8 +28,6 @@ use tracing::debug;
mod database;
mod email;
pub use self::email::VerifyEmailJob;
#[derive(Clone)]
struct State {
pool: Pool<Postgres>,