diff --git a/Cargo.lock b/Cargo.lock index 1ce2dd06..0bfd46ed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -117,9 +117,63 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.69" +version = "1.0.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "224afbd727c3d6e4b90103ece64b8d1b67fbb1973b1046c2281eed3f3803f800" +checksum = "7de8ce5e0f9f8d88245311066a578d72b7af3e7088f32783804676302df237e4" + +[[package]] +name = "apalis-core" +version = "0.4.0-alpha.4" +source = "git+https://github.com/geofmureithi/apalis.git?rev=ead6f840b92a3590a8bf46398eaf65ed55aa92dc#ead6f840b92a3590a8bf46398eaf65ed55aa92dc" +dependencies = [ + "async-stream", + "async-trait", + "chrono", + "futures 0.3.27", + "graceful-shutdown", + "http", + "log", + "pin-project-lite", + "serde", + "serde_json", + "strum", + "thiserror", + "tokio", + "tower", + "tracing", + "ulid", +] + +[[package]] +name = "apalis-cron" +version = "0.4.0-alpha.4" +source = "git+https://github.com/geofmureithi/apalis.git?rev=ead6f840b92a3590a8bf46398eaf65ed55aa92dc#ead6f840b92a3590a8bf46398eaf65ed55aa92dc" +dependencies = [ + "apalis-core", + "async-stream", + "chrono", + "cron", + "futures 0.3.27", + "tokio", + "tower", +] + +[[package]] +name = "apalis-sql" +version = "0.4.0-alpha.4" +source = "git+https://github.com/geofmureithi/apalis.git?rev=ead6f840b92a3590a8bf46398eaf65ed55aa92dc#ead6f840b92a3590a8bf46398eaf65ed55aa92dc" +dependencies = [ + "apalis-core", + "async-stream", + "async-trait", + "chrono", + "futures 0.3.27", + "futures-lite", + "serde", + "serde_json", + "sqlx", + "tokio", +] [[package]] name = "argon2" @@ -1433,6 +1487,17 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "cron" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d76219e9243e100d5a37676005f08379297f8addfebc247613299600625c734d" +dependencies = [ + "chrono", + "nom", + "once_cell", +] + [[package]] name = "crossbeam-channel" version = "0.5.7" @@ -1949,6 +2014,18 @@ dependencies = [ "miniz_oxide", ] +[[package]] +name = "flume" +version = "0.10.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1657b4441c3403d9f7b3409e47575237dac27b1b5726df654a6ecbf92f0f7577" +dependencies = [ + "futures-core", + "futures-sink", + "pin-project", + "spin 0.9.6", +] + [[package]] name = "fnv" version = "1.0.7" @@ -2231,6 +2308,17 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "graceful-shutdown" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3effbaf774a1da3462925bb182ccf975c284cf46edca5569ea93420a657af484" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + [[package]] name = "group" version = "0.13.0" @@ -2828,6 +2916,17 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "348108ab3fba42ec82ff6e9564fc4ca0247bdccdc68dd8af9764bbc79c3c8ffb" +[[package]] +name = "libsqlite3-sys" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "898745e570c7d0453cc1fbc4a701eb6c662ed54e8fec8b7d14be137ebeeb9d14" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + [[package]] name = "link-cplusplus" version = "1.0.8" @@ -2935,6 +3034,7 @@ name = "mas-cli" version = "0.1.0" dependencies = [ "anyhow", + "apalis-core", "atty", "axum", "camino", @@ -3070,6 +3170,9 @@ name = "mas-handlers" version = "0.1.0" dependencies = [ "anyhow", + "apalis-core", + "apalis-cron", + "apalis-sql", "argon2", "async-graphql", "axum", @@ -3097,6 +3200,7 @@ dependencies = [ "mas-router", "mas-storage", "mas-storage-pg", + "mas-tasks", "mas-templates", "mime", "oauth2-types", @@ -3391,14 +3495,27 @@ dependencies = [ name = "mas-tasks" version = "0.1.0" dependencies = [ + "anyhow", + "apalis-core", + "apalis-cron", + "apalis-sql", "async-trait", + "chrono", "futures-util", + "mas-data-model", + "mas-email", "mas-storage", "mas-storage-pg", + "rand 0.8.5", + "rand_chacha 0.3.1", + "serde", "sqlx", + "thiserror", "tokio", "tokio-stream", + "tower", "tracing", + "ulid", ] [[package]] @@ -5171,9 +5288,9 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.158" +version = "1.0.159" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "771d4d9c4163ee138805e12c710dd365e4f44be8be0503cb1bb9eb989425d9c9" +checksum = "3c04e8343c3daeec41f58990b9d77068df31209f2af111e059e9fe9646693065" dependencies = [ "serde_derive", ] @@ -5193,9 +5310,9 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.158" +version = "1.0.159" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e801c1712f48475582b7696ac71e0ca34ebb30e09338425384269d9717c62cad" +checksum = "4c614d17805b093df4b147b51339e7e44bf05ef59fba1e45d83500bcfb4d8585" dependencies = [ "proc-macro2 1.0.52", "quote 1.0.26", @@ -5420,6 +5537,9 @@ name = "spin" version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5d6e0250b93c8427a177b849d144a96d5acc57006149479403d7861ab721e34" +dependencies = [ + "lock_api", +] [[package]] name = "spki" @@ -5477,8 +5597,10 @@ dependencies = [ "dotenvy", "either", "event-listener", + "flume", "futures-channel", "futures-core", + "futures-executor", "futures-intrusive", "futures-util", "hashlink", @@ -5488,6 +5610,7 @@ dependencies = [ "indexmap", "itoa", "libc", + "libsqlite3-sys", "log", "md-5", "memchr", @@ -5597,6 +5720,28 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "strum" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "063e6045c0e62079840579a7e47a355ae92f60eb74daaf156fb1e84ba164e63f" +dependencies = [ + "strum_macros", +] + +[[package]] +name = "strum_macros" +version = "0.24.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e385be0d24f186b4ce2f9982191e7101bb737312ad61c1f2f984f34bcf85d59" +dependencies = [ + "heck", + "proc-macro2 1.0.52", + "quote 1.0.26", + "rustversion", + "syn 1.0.109", +] + [[package]] name = "subtle" version = "2.4.1" @@ -6401,6 +6546,12 @@ dependencies = [ "version_check", ] +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "version_check" version = "0.9.4" diff --git a/Cargo.toml b/Cargo.toml index 644ecbf4..cf7288f0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,3 +12,15 @@ opt-level = 3 [patch.crates-io.ulid] git = "https://github.com/dylanhart/ulid-rs.git" rev = "0b9295c2db2114cd87aa19abcc1fc00c16b272db" + +[patch.crates-io.apalis-core] +git = "https://github.com/geofmureithi/apalis.git" +rev = "ead6f840b92a3590a8bf46398eaf65ed55aa92dc" + +[patch.crates-io.apalis-sql] +git = "https://github.com/geofmureithi/apalis.git" +rev = "ead6f840b92a3590a8bf46398eaf65ed55aa92dc" + +[patch.crates-io.apalis-cron] +git = "https://github.com/geofmureithi/apalis.git" +rev = "ead6f840b92a3590a8bf46398eaf65ed55aa92dc" diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index 2cde9227..3b838be5 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" license = "Apache-2.0" [dependencies] +apalis-core = "0.4.0-alpha.4" anyhow = "1.0.69" atty = "0.2.14" axum = "0.6.11" diff --git a/crates/cli/src/commands/mod.rs b/crates/cli/src/commands/mod.rs index 9ea132d8..8054b171 100644 --- a/crates/cli/src/commands/mod.rs +++ b/crates/cli/src/commands/mod.rs @@ -23,6 +23,7 @@ mod debug; mod manage; mod server; mod templates; +mod worker; #[derive(Parser, Debug)] enum Subcommand { @@ -35,6 +36,9 @@ enum Subcommand { /// Runs the web server Server(self::server::Options), + /// Run the worker + Worker(self::worker::Options), + /// Manage the instance Manage(self::manage::Options), @@ -62,6 +66,7 @@ impl Options { Some(S::Config(c)) => c.run(self).await, Some(S::Database(c)) => c.run(self).await, Some(S::Server(c)) => c.run(self).await, + Some(S::Worker(c)) => c.run(self).await, Some(S::Manage(c)) => c.run(self).await, Some(S::Templates(c)) => c.run(self).await, Some(S::Debug(c)) => c.run(self).await, diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index 00230953..9ef72534 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -22,7 +22,6 @@ use mas_handlers::{AppState, HttpClientFactory, MatrixHomeserver}; use mas_listener::{server::Server, shutdown::ShutdownStream}; use mas_router::UrlBuilder; use mas_storage_pg::MIGRATOR; -use mas_tasks::TaskQueue; use tokio::signal::unix::SignalKind; use tracing::{info, info_span, warn, Instrument}; @@ -37,6 +36,10 @@ pub(super) struct Options { #[arg(long)] migrate: bool, + /// Do not start the task worker + #[arg(long)] + no_worker: bool, + /// Watch for changes for templates on the filesystem #[arg(short, long)] watch: bool, @@ -61,11 +64,6 @@ impl Options { .context("could not run migrations")?; } - info!("Starting task scheduler"); - let queue = TaskQueue::default(); - queue.recuring(Duration::from_secs(15), mas_tasks::cleanup_expired(&pool)); - queue.start(); - // Initialize the key store let key_store = config .secrets @@ -85,8 +83,15 @@ impl Options { // Load and compile the templates let templates = templates_from_config(&config.templates, &url_builder).await?; - let mailer = mailer_from_config(&config.email, &templates).await?; - mailer.test_connection().await?; + if !self.no_worker { + let mailer = mailer_from_config(&config.email, &templates).await?; + mailer.test_connection().await?; + + info!("Starting task worker"); + let monitor = mas_tasks::init(&pool, &mailer); + // TODO: grab the handle + tokio::spawn(monitor.run()); + } let homeserver = MatrixHomeserver::new(config.matrix.homeserver.clone()); @@ -113,7 +118,6 @@ impl Options { key_store, encrypter, url_builder, - mailer, homeserver, policy_factory, graphql_schema, diff --git a/crates/cli/src/commands/worker.rs b/crates/cli/src/commands/worker.rs new file mode 100644 index 00000000..23742dc1 --- /dev/null +++ b/crates/cli/src/commands/worker.rs @@ -0,0 +1,51 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use clap::Parser; +use mas_config::RootConfig; +use mas_router::UrlBuilder; +use tracing::{info_span, log::info}; + +use crate::util::{database_from_config, mailer_from_config, templates_from_config}; + +#[derive(Parser, Debug, Default)] +pub(super) struct Options {} + +impl Options { + pub async fn run(&self, root: &super::Options) -> anyhow::Result<()> { + let span = info_span!("cli.worker.init").entered(); + let config: RootConfig = root.load_config()?; + + // Connect to the database + info!("Conntecting to the database"); + let pool = database_from_config(&config.database).await?; + + let url_builder = UrlBuilder::new(config.http.public_base.clone()); + + // Load and compile the templates + let templates = templates_from_config(&config.templates, &url_builder).await?; + + let mailer = mailer_from_config(&config.email, &templates).await?; + mailer.test_connection().await?; + drop(config); + + info!("Starting task scheduler"); + let monitor = mas_tasks::init(&pool, &mailer); + + span.exit(); + + monitor.run().await?; + Ok(()) + } +} diff --git a/crates/email/src/lib.rs b/crates/email/src/lib.rs index 754d09df..4d87a6ae 100644 --- a/crates/email/src/lib.rs +++ b/crates/email/src/lib.rs @@ -26,7 +26,10 @@ mod mailer; mod transport; -pub use lettre::transport::smtp::authentication::Credentials as SmtpCredentials; +pub use lettre::{ + message::Mailbox, transport::smtp::authentication::Credentials as SmtpCredentials, Address, +}; +pub use mas_templates::EmailVerificationContext; pub use self::{ mailer::Mailer, diff --git a/crates/handlers/Cargo.toml b/crates/handlers/Cargo.toml index 231b28a4..7af36e6c 100644 --- a/crates/handlers/Cargo.toml +++ b/crates/handlers/Cargo.toml @@ -31,6 +31,11 @@ async-graphql = { version = "5.0.6", features = ["tracing", "apollo_tracing"] } # Emails lettre = { version = "0.10.3", default-features = false, features = ["builder"] } +# Job scheduling +apalis-core = "0.4.0-alpha.4" +apalis-cron = "0.4.0-alpha.4" +apalis-sql = { version = "0.4.0-alpha.4", features = ["postgres"] } + # Database access sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres"] } @@ -70,6 +75,7 @@ mas-policy = { path = "../policy" } mas-router = { path = "../router" } mas-storage = { path = "../storage" } mas-storage-pg = { path = "../storage-pg" } +mas-tasks = { path = "../tasks" } mas-templates = { path = "../templates" } oauth2-types = { path = "../oauth2-types" } diff --git a/crates/handlers/src/app_state.rs b/crates/handlers/src/app_state.rs index 2e826bad..e9a7b494 100644 --- a/crates/handlers/src/app_state.rs +++ b/crates/handlers/src/app_state.rs @@ -21,7 +21,6 @@ use axum::{ }; use hyper::StatusCode; use mas_axum_utils::http_client_factory::HttpClientFactory; -use mas_email::Mailer; use mas_keystore::{Encrypter, Keystore}; use mas_policy::PolicyFactory; use mas_router::UrlBuilder; @@ -41,7 +40,6 @@ pub struct AppState { pub key_store: Keystore, pub encrypter: Encrypter, pub url_builder: UrlBuilder, - pub mailer: Mailer, pub homeserver: MatrixHomeserver, pub policy_factory: Arc, pub graphql_schema: mas_graphql::Schema, @@ -85,12 +83,6 @@ impl FromRef for UrlBuilder { } } -impl FromRef for Mailer { - fn from_ref(input: &AppState) -> Self { - input.mailer.clone() - } -} - impl FromRef for MatrixHomeserver { fn from_ref(input: &AppState) -> Self { input.homeserver.clone() @@ -115,6 +107,12 @@ impl FromRef for PasswordManager { } } +impl FromRef for apalis_sql::postgres::PostgresStorage { + fn from_ref(input: &AppState) -> Self { + apalis_sql::postgres::PostgresStorage::new(input.pool.clone()) + } +} + #[async_trait] impl FromRequestParts for BoxClock { type Rejection = Infallible; diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index f1ea5d51..e69b7864 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -29,6 +29,7 @@ use std::{convert::Infallible, sync::Arc, time::Duration}; +use apalis_sql::postgres::PostgresStorage; use axum::{ body::{Bytes, HttpBody}, extract::{FromRef, FromRequestParts}, @@ -38,12 +39,12 @@ use axum::{ }; use headers::HeaderName; use hyper::header::{ACCEPT, ACCEPT_LANGUAGE, AUTHORIZATION, CONTENT_LANGUAGE, CONTENT_TYPE}; -use mas_email::Mailer; use mas_http::CorsLayerExt; use mas_keystore::{Encrypter, Keystore}; use mas_policy::PolicyFactory; use mas_router::{Route, UrlBuilder}; use mas_storage::{BoxClock, BoxRepository, BoxRng}; +use mas_tasks::VerifyEmailJob; use mas_templates::{ErrorContext, Templates}; use passwords::PasswordManager; use sqlx::PgPool; @@ -259,10 +260,10 @@ where BoxRepository: FromRequestParts, Encrypter: FromRef, Templates: FromRef, - Mailer: FromRef, Keystore: FromRef, HttpClientFactory: FromRef, PasswordManager: FromRef, + PostgresStorage: FromRef, BoxClock: FromRequestParts, BoxRng: FromRequestParts, { diff --git a/crates/handlers/src/test_utils.rs b/crates/handlers/src/test_utils.rs index d2479598..60063694 100644 --- a/crates/handlers/src/test_utils.rs +++ b/crates/handlers/src/test_utils.rs @@ -22,7 +22,6 @@ use axum::{ use headers::{Authorization, ContentType, HeaderMapExt, HeaderName}; use hyper::{header::CONTENT_TYPE, Request, Response, StatusCode}; use mas_axum_utils::http_client_factory::HttpClientFactory; -use mas_email::{MailTransport, Mailer}; use mas_keystore::{Encrypter, JsonWebKey, JsonWebKeySet, Keystore, PrivateKey}; use mas_policy::PolicyFactory; use mas_router::{SimpleRoute, UrlBuilder}; @@ -57,7 +56,6 @@ pub(crate) struct TestState { pub key_store: Keystore, pub encrypter: Encrypter, pub url_builder: UrlBuilder, - pub mailer: Mailer, pub homeserver: MatrixHomeserver, pub policy_factory: Arc, pub graphql_schema: mas_graphql::Schema, @@ -91,10 +89,6 @@ impl TestState { let password_manager = PasswordManager::new([(1, Hasher::argon2id(None))])?; - let transport = MailTransport::blackhole(); - let mailbox: lettre::message::Mailbox = "server@example.com".parse()?; - let mailer = Mailer::new(templates.clone(), transport, mailbox.clone(), mailbox); - let homeserver = MatrixHomeserver::new("example.com".to_owned()); let file = @@ -124,7 +118,6 @@ impl TestState { key_store, encrypter, url_builder, - mailer, homeserver, policy_factory, graphql_schema, @@ -244,12 +237,6 @@ impl FromRef for UrlBuilder { } } -impl FromRef for Mailer { - fn from_ref(input: &TestState) -> Self { - input.mailer.clone() - } -} - impl FromRef for MatrixHomeserver { fn from_ref(input: &TestState) -> Self { input.homeserver.clone() @@ -274,6 +261,12 @@ impl FromRef for PasswordManager { } } +impl FromRef for apalis_sql::postgres::PostgresStorage { + fn from_ref(input: &TestState) -> Self { + apalis_sql::postgres::PostgresStorage::new(input.pool.clone()) + } +} + #[async_trait] impl FromRequestParts for BoxClock { type Rejection = Infallible; diff --git a/crates/handlers/src/views/account/emails/add.rs b/crates/handlers/src/views/account/emails/add.rs index fe7ec162..244c4f5d 100644 --- a/crates/handlers/src/views/account/emails/add.rs +++ b/crates/handlers/src/views/account/emails/add.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use apalis_core::storage::Storage; +use apalis_sql::postgres::PostgresStorage; use axum::{ extract::{Form, Query, State}, response::{Html, IntoResponse, Response}, @@ -21,14 +23,13 @@ use mas_axum_utils::{ csrf::{CsrfExt, ProtectedForm}, FancyError, SessionInfoExt, }; -use mas_email::Mailer; use mas_keystore::Encrypter; use mas_router::Route; use mas_storage::{user::UserEmailRepository, BoxClock, BoxRepository, BoxRng}; +use mas_tasks::VerifyEmailJob; use mas_templates::{EmailAddContext, TemplateContext, Templates}; use serde::Deserialize; -use super::start_email_verification; use crate::views::shared::OptionalPostAuthAction; #[derive(Deserialize, Debug)] @@ -68,7 +69,7 @@ pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, mut repo: BoxRepository, - State(mailer): State, + State(mut job_storage): State>, cookie_jar: PrivateCookieJar, Query(query): Query, Form(form): Form>, @@ -94,17 +95,11 @@ pub(crate) async fn post( } else { next }; - start_email_verification( - &mailer, - &mut repo, - &mut rng, - &clock, - &session.user, - user_email, - ) - .await?; repo.save().await?; + // XXX: this grabs a new connection from the pool, which is not ideal + job_storage.push(VerifyEmailJob::new(&user_email)).await?; + Ok((cookie_jar, next.go()).into_response()) } diff --git a/crates/handlers/src/views/account/emails/mod.rs b/crates/handlers/src/views/account/emails/mod.rs index 22082b53..a55f96ae 100644 --- a/crates/handlers/src/views/account/emails/mod.rs +++ b/crates/handlers/src/views/account/emails/mod.rs @@ -13,28 +13,27 @@ // limitations under the License. use anyhow::{anyhow, Context}; +use apalis_core::storage::Storage; +use apalis_sql::postgres::PostgresStorage; use axum::{ extract::{Form, State}, response::{Html, IntoResponse, Response}, }; use axum_extra::extract::PrivateCookieJar; -use chrono::Duration; -use lettre::{message::Mailbox, Address}; use mas_axum_utils::{ csrf::{CsrfExt, ProtectedForm}, FancyError, SessionInfoExt, }; -use mas_data_model::{BrowserSession, User, UserEmail}; -use mas_email::Mailer; +use mas_data_model::BrowserSession; use mas_keystore::Encrypter; use mas_router::Route; use mas_storage::{ user::UserEmailRepository, BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess, }; -use mas_templates::{AccountEmailsContext, EmailVerificationContext, TemplateContext, Templates}; -use rand::{distributions::Uniform, Rng}; +use mas_tasks::VerifyEmailJob; +use mas_templates::{AccountEmailsContext, TemplateContext, Templates}; +use rand::Rng; use serde::Deserialize; -use tracing::info; pub mod add; pub mod verify; @@ -89,46 +88,13 @@ async fn render( Ok((cookie_jar, Html(content)).into_response()) } -async fn start_email_verification( - mailer: &Mailer, - repo: &mut impl RepositoryAccess, - mut rng: impl Rng + Send, - clock: &impl Clock, - user: &User, - user_email: UserEmail, -) -> anyhow::Result<()> { - // First, generate a code - let range = Uniform::::from(0..1_000_000); - let code = rng.sample(range).to_string(); - - let address: Address = user_email.email.parse()?; - - let verification = repo - .user_email() - .add_verification_code(&mut rng, clock, &user_email, Duration::hours(8), code) - .await?; - - // And send the verification email - let mailbox = Mailbox::new(Some(user.username.clone()), address); - - let context = EmailVerificationContext::new(user.clone(), verification.clone()); - - mailer.send_verification_email(mailbox, &context).await?; - - info!( - email.id = %user_email.id, - "Verification email sent" - ); - Ok(()) -} - #[tracing::instrument(name = "handlers.views.account_email_list.post", skip_all, err)] pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, State(templates): State, + State(mut job_storage): State>, mut repo: BoxRepository, - State(mailer): State, cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { @@ -151,9 +117,12 @@ pub(crate) async fn post( .await?; let next = mas_router::AccountVerifyEmail::new(email.id); - start_email_verification(&mailer, &mut repo, &mut rng, &clock, &session.user, email) - .await?; + repo.save().await?; + + // XXX: this grabs a new connection from the pool, which is not ideal + job_storage.push(VerifyEmailJob::new(&email)).await?; + return Ok((cookie_jar, next.go()).into_response()); } ManagementForm::ResendConfirmation { id } => { @@ -170,9 +139,10 @@ pub(crate) async fn post( } let next = mas_router::AccountVerifyEmail::new(email.id); - start_email_verification(&mailer, &mut repo, &mut rng, &clock, &session.user, email) - .await?; - repo.save().await?; + + // XXX: this grabs a new connection from the pool, which is not ideal + job_storage.push(VerifyEmailJob::new(&email)).await?; + return Ok((cookie_jar, next.go()).into_response()); } ManagementForm::Remove { id } => { diff --git a/crates/handlers/src/views/register.rs b/crates/handlers/src/views/register.rs index c9f33754..bc8bb671 100644 --- a/crates/handlers/src/views/register.rs +++ b/crates/handlers/src/views/register.rs @@ -14,18 +14,18 @@ use std::{str::FromStr, sync::Arc}; +use apalis_core::storage::Storage; +use apalis_sql::postgres::PostgresStorage; use axum::{ extract::{Form, Query, State}, response::{Html, IntoResponse, Response}, }; use axum_extra::extract::PrivateCookieJar; -use chrono::Duration; -use lettre::{message::Mailbox, Address}; +use lettre::Address; use mas_axum_utils::{ csrf::{CsrfExt, CsrfToken, ProtectedForm}, FancyError, SessionInfoExt, }; -use mas_email::Mailer; use mas_keystore::Encrypter; use mas_policy::PolicyFactory; use mas_router::Route; @@ -33,11 +33,11 @@ use mas_storage::{ user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, BoxClock, BoxRepository, BoxRng, RepositoryAccess, }; +use mas_tasks::VerifyEmailJob; use mas_templates::{ - EmailVerificationContext, FieldError, FormError, RegisterContext, RegisterFormField, - TemplateContext, Templates, ToFormState, + FieldError, FormError, RegisterContext, RegisterFormField, TemplateContext, Templates, + ToFormState, }; -use rand::{distributions::Uniform, Rng}; use serde::{Deserialize, Serialize}; use zeroize::Zeroizing; @@ -93,10 +93,10 @@ pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, State(password_manager): State, - State(mailer): State, State(policy_factory): State>, State(templates): State, mut repo: BoxRepository, + State(mut job_storage): State>, Query(query): Query, cookie_jar: PrivateCookieJar, Form(form): Form>, @@ -195,25 +195,6 @@ pub(crate) async fn post( .add(&mut rng, &clock, &user, form.email) .await?; - // First, generate a code - let range = Uniform::::from(0..1_000_000); - let code = rng.sample(range); - let code = format!("{code:06}"); - - let address: Address = user_email.email.parse()?; - - let verification = repo - .user_email() - .add_verification_code(&mut rng, &clock, &user_email, Duration::hours(8), code) - .await?; - - // And send the verification email - let mailbox = Mailbox::new(Some(user.username.clone()), address); - - let context = EmailVerificationContext::new(user.clone(), verification.clone()); - - mailer.send_verification_email(mailbox, &context).await?; - let next = mas_router::AccountVerifyEmail::new(user_email.id).and_maybe(query.post_auth_action); let session = repo.browser_session().add(&mut rng, &clock, &user).await?; @@ -225,6 +206,9 @@ pub(crate) async fn post( repo.save().await?; + // XXX: this grabs a new connection from the pool, which is not ideal + job_storage.push(VerifyEmailJob::new(&user_email)).await?; + let cookie_jar = cookie_jar.set_session(&session); Ok((cookie_jar, next.go()).into_response()) } diff --git a/crates/storage-pg/migrations/20220530084123_jobs_workers.sql b/crates/storage-pg/migrations/20220530084123_jobs_workers.sql new file mode 100644 index 00000000..133a89b4 --- /dev/null +++ b/crates/storage-pg/migrations/20220530084123_jobs_workers.sql @@ -0,0 +1,83 @@ + CREATE SCHEMA apalis; + + CREATE TABLE IF NOT EXISTS apalis.workers ( + id TEXT NOT NULL, + worker_type TEXT NOT NULL, + storage_name TEXT NOT NULL, + layers TEXT NOT NULL DEFAULT '', + last_seen timestamptz not null default now() + ); + + CREATE INDEX IF NOT EXISTS Idx ON apalis.workers(id); + + CREATE UNIQUE INDEX IF NOT EXISTS unique_worker_id ON apalis.workers (id); + + CREATE INDEX IF NOT EXISTS WTIdx ON apalis.workers(worker_type); + + CREATE INDEX IF NOT EXISTS LSIdx ON apalis.workers(last_seen); + + CREATE TABLE IF NOT EXISTS apalis.jobs ( + job JSONB NOT NULL, + id TEXT NOT NULL, + job_type TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'Pending', + attempts INTEGER NOT NULL DEFAULT 0, + max_attempts INTEGER NOT NULL DEFAULT 25, + run_at timestamptz NOT NULL default now(), + last_error TEXT, + lock_at timestamptz, + lock_by TEXT, + done_at timestamptz, + CONSTRAINT fk_worker_lock_by FOREIGN KEY(lock_by) REFERENCES apalis.workers(id) + ); + + CREATE INDEX IF NOT EXISTS TIdx ON apalis.jobs(id); + + CREATE INDEX IF NOT EXISTS SIdx ON apalis.jobs(status); + + CREATE UNIQUE INDEX IF NOT EXISTS unique_job_id ON apalis.jobs (id); + + CREATE INDEX IF NOT EXISTS LIdx ON apalis.jobs(lock_by); + + CREATE INDEX IF NOT EXISTS JTIdx ON apalis.jobs(job_type); + + CREATE OR replace FUNCTION apalis.get_job( + worker_id TEXT, + v_job_type TEXT + ) returns apalis.jobs AS $$ + DECLARE + v_job_id text; + v_job_row apalis.jobs; + BEGIN + SELECT id, job_type + INTO v_job_id, v_job_type + FROM apalis.jobs + WHERE status = 'Pending' + AND run_at < now() + AND job_type = v_job_type + ORDER BY run_at ASC limit 1 FOR UPDATE skip LOCKED; + + IF v_job_id IS NULL THEN + RETURN NULL; + END IF; + + UPDATE apalis.jobs + SET + status = 'Running', + lock_by = worker_id, + lock_at = now() + WHERE id = v_job_id + returning * INTO v_job_row; + RETURN v_job_row; + END; + $$ LANGUAGE plpgsql volatile; + + CREATE FUNCTION apalis.notify_new_jobs() returns trigger as $$ + BEGIN + perform pg_notify('apalis::job', 'insert'); + return new; + END; + $$ language plpgsql; + + CREATE TRIGGER notify_workers after insert on apalis.jobs for each statement execute procedure apalis.notify_new_jobs(); + diff --git a/crates/storage-pg/migrations/20220709210445_add_job_fn.sql b/crates/storage-pg/migrations/20220709210445_add_job_fn.sql new file mode 100644 index 00000000..26b56e1d --- /dev/null +++ b/crates/storage-pg/migrations/20220709210445_add_job_fn.sql @@ -0,0 +1,46 @@ +CREATE OR REPLACE FUNCTION apalis.push_job( + job_type text, + job json DEFAULT NULL :: json, + job_id text DEFAULT NULL :: text, + status text DEFAULT 'Pending' :: text, + run_at timestamptz DEFAULT NOW() :: timestamptz, + max_attempts integer DEFAULT 25 :: integer +) RETURNS apalis.jobs AS $$ + + DECLARE + v_job_row apalis.jobs; + v_job_id text; + + BEGIN + IF job_type is not NULL and length(job_type) > 512 THEN raise exception 'Job_type is too long (max length: 512).' USING errcode = 'APAJT'; + END IF; + + IF max_attempts < 1 THEN raise exception 'Job maximum attempts must be at least 1.' USING errcode = 'APAMA'; + end IF; + + SELECT + uuid_in( + md5(random() :: text || now() :: text) :: cstring + ) INTO v_job_id; + INSERT INTO + apalis.jobs + VALUES + ( + job, + v_job_id, + job_type, + status, + 0, + max_attempts, + run_at, + NULL, + NULL, + NULL, + NULL + ) + returning * INTO v_job_row; + RETURN v_job_row; +END; +$$ LANGUAGE plpgsql volatile; + + diff --git a/crates/storage-pg/migrations/20230330210841_replace_add_job_fn.sql b/crates/storage-pg/migrations/20230330210841_replace_add_job_fn.sql new file mode 100644 index 00000000..3562e96f --- /dev/null +++ b/crates/storage-pg/migrations/20230330210841_replace_add_job_fn.sql @@ -0,0 +1,106 @@ +CREATE EXTENSION IF NOT EXISTS pgcrypto; + +CREATE OR REPLACE FUNCTION generate_ulid() +RETURNS TEXT +AS $$ +DECLARE + -- Crockford's Base32 + encoding BYTEA = '0123456789ABCDEFGHJKMNPQRSTVWXYZ'; + timestamp BYTEA = E'\\000\\000\\000\\000\\000\\000'; + output TEXT = ''; + + unix_time BIGINT; + ulid BYTEA; +BEGIN + -- 6 timestamp bytes + unix_time = (EXTRACT(EPOCH FROM CLOCK_TIMESTAMP()) * 1000)::BIGINT; + timestamp = SET_BYTE(timestamp, 0, (unix_time >> 40)::BIT(8)::INTEGER); + timestamp = SET_BYTE(timestamp, 1, (unix_time >> 32)::BIT(8)::INTEGER); + timestamp = SET_BYTE(timestamp, 2, (unix_time >> 24)::BIT(8)::INTEGER); + timestamp = SET_BYTE(timestamp, 3, (unix_time >> 16)::BIT(8)::INTEGER); + timestamp = SET_BYTE(timestamp, 4, (unix_time >> 8)::BIT(8)::INTEGER); + timestamp = SET_BYTE(timestamp, 5, unix_time::BIT(8)::INTEGER); + + -- 10 entropy bytes + ulid = timestamp || gen_random_bytes(10); + + -- Encode the timestamp + output = output || CHR(GET_BYTE(encoding, (GET_BYTE(ulid, 0) & 224) >> 5)); + output = output || CHR(GET_BYTE(encoding, (GET_BYTE(ulid, 0) & 31))); + output = output || CHR(GET_BYTE(encoding, (GET_BYTE(ulid, 1) & 248) >> 3)); + output = output || CHR(GET_BYTE(encoding, ((GET_BYTE(ulid, 1) & 7) << 2) | ((GET_BYTE(ulid, 2) & 192) >> 6))); + output = output || CHR(GET_BYTE(encoding, (GET_BYTE(ulid, 2) & 62) >> 1)); + output = output || CHR(GET_BYTE(encoding, ((GET_BYTE(ulid, 2) & 1) << 4) | ((GET_BYTE(ulid, 3) & 240) >> 4))); + output = output || CHR(GET_BYTE(encoding, ((GET_BYTE(ulid, 3) & 15) << 1) | ((GET_BYTE(ulid, 4) & 128) >> 7))); + output = output || CHR(GET_BYTE(encoding, (GET_BYTE(ulid, 4) & 124) >> 2)); + output = output || CHR(GET_BYTE(encoding, ((GET_BYTE(ulid, 4) & 3) << 3) | ((GET_BYTE(ulid, 5) & 224) >> 5))); + output = output || CHR(GET_BYTE(encoding, (GET_BYTE(ulid, 5) & 31))); + + -- Encode the entropy + output = output || CHR(GET_BYTE(encoding, (GET_BYTE(ulid, 6) & 248) >> 3)); + output = output || CHR(GET_BYTE(encoding, ((GET_BYTE(ulid, 6) & 7) << 2) | ((GET_BYTE(ulid, 7) & 192) >> 6))); + output = output || CHR(GET_BYTE(encoding, (GET_BYTE(ulid, 7) & 62) >> 1)); + output = output || CHR(GET_BYTE(encoding, ((GET_BYTE(ulid, 7) & 1) << 4) | ((GET_BYTE(ulid, 8) & 240) >> 4))); + output = output || CHR(GET_BYTE(encoding, ((GET_BYTE(ulid, 8) & 15) << 1) | ((GET_BYTE(ulid, 9) & 128) >> 7))); + output = output || CHR(GET_BYTE(encoding, (GET_BYTE(ulid, 9) & 124) >> 2)); + output = output || CHR(GET_BYTE(encoding, ((GET_BYTE(ulid, 9) & 3) << 3) | ((GET_BYTE(ulid, 10) & 224) >> 5))); + output = output || CHR(GET_BYTE(encoding, (GET_BYTE(ulid, 10) & 31))); + output = output || CHR(GET_BYTE(encoding, (GET_BYTE(ulid, 11) & 248) >> 3)); + output = output || CHR(GET_BYTE(encoding, ((GET_BYTE(ulid, 11) & 7) << 2) | ((GET_BYTE(ulid, 12) & 192) >> 6))); + output = output || CHR(GET_BYTE(encoding, (GET_BYTE(ulid, 12) & 62) >> 1)); + output = output || CHR(GET_BYTE(encoding, ((GET_BYTE(ulid, 12) & 1) << 4) | ((GET_BYTE(ulid, 13) & 240) >> 4))); + output = output || CHR(GET_BYTE(encoding, ((GET_BYTE(ulid, 13) & 15) << 1) | ((GET_BYTE(ulid, 14) & 128) >> 7))); + output = output || CHR(GET_BYTE(encoding, (GET_BYTE(ulid, 14) & 124) >> 2)); + output = output || CHR(GET_BYTE(encoding, ((GET_BYTE(ulid, 14) & 3) << 3) | ((GET_BYTE(ulid, 15) & 224) >> 5))); + output = output || CHR(GET_BYTE(encoding, (GET_BYTE(ulid, 15) & 31))); + + RETURN output; +END +$$ +LANGUAGE plpgsql +VOLATILE; + + +CREATE OR REPLACE FUNCTION apalis.push_job( + job_type text, + job json DEFAULT NULL :: json, + status text DEFAULT 'Pending' :: text, + run_at timestamptz DEFAULT NOW() :: timestamptz, + max_attempts integer DEFAULT 25 :: integer +) RETURNS apalis.jobs AS $$ + + DECLARE + v_job_row apalis.jobs; + v_job_id text; + + BEGIN + IF job_type is not NULL and length(job_type) > 512 THEN raise exception 'Job_type is too long (max length: 512).' USING errcode = 'APAJT'; + END IF; + + IF max_attempts < 1 THEN raise exception 'Job maximum attempts must be at least 1.' USING errcode = 'APAMA'; + end IF; + + SELECT + CONCAT('JID-' || generate_ulid()) INTO v_job_id; + INSERT INTO + apalis.jobs + VALUES + ( + job, + v_job_id, + job_type, + status, + 0, + max_attempts, + run_at, + NULL, + NULL, + NULL, + NULL + ) + returning * INTO v_job_row; + RETURN v_job_row; +END; +$$ LANGUAGE plpgsql volatile; + + diff --git a/crates/tasks/Cargo.toml b/crates/tasks/Cargo.toml index 9e88b710..66758816 100644 --- a/crates/tasks/Cargo.toml +++ b/crates/tasks/Cargo.toml @@ -6,12 +6,25 @@ edition = "2021" license = "Apache-2.0" [dependencies] -tokio = "1.26.0" +anyhow = "1.0.70" +apalis-core = { version = "0.4.0-alpha.4", features = ["extensions", "tokio-comp"] } +apalis-cron = "0.4.0-alpha.4" +apalis-sql = { version = "0.4.0-alpha.4", features = ["postgres", "tokio-comp"] } async-trait = "0.1.66" -tokio-stream = "0.1.12" +chrono = "0.4.24" futures-util = "0.3.27" -tracing = "0.1.37" +rand = "0.8.5" +rand_chacha = "0.3.1" sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres"] } +thiserror = "1.0.30" +tokio = "1.26.0" +tokio-stream = "0.1.12" +tower = { version = "0.4.13", features = ["util"] } +tracing = "0.1.37" +ulid = "1.0.0" +serde = { version = "1.0.159", features = ["derive"] } mas-storage = { path = "../storage" } mas-storage-pg = { path = "../storage-pg" } +mas-email = { path = "../email" } +mas-data-model = { path = "../data-model" } diff --git a/crates/tasks/src/database.rs b/crates/tasks/src/database.rs index ae9dfac8..319c1083 100644 --- a/crates/tasks/src/database.rs +++ b/crates/tasks/src/database.rs @@ -1,4 +1,4 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. +// Copyright 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,50 +14,66 @@ //! Database-related tasks -use mas_storage::{oauth2::OAuth2AccessTokenRepository, Repository, RepositoryAccess, SystemClock}; -use mas_storage_pg::PgRepository; -use sqlx::{Pool, Postgres}; -use tracing::{debug, error, info}; +use std::str::FromStr; -use super::Task; +use apalis_core::{ + builder::{WorkerBuilder, WorkerFactory}, + context::JobContext, + executor::TokioExecutor, + job::Job, + job_fn::job_fn, + monitor::Monitor, +}; +use apalis_cron::CronStream; +use chrono::{DateTime, Utc}; +use mas_storage::{oauth2::OAuth2AccessTokenRepository, RepositoryAccess}; +use tracing::{debug, info}; -#[derive(Clone)] -struct CleanupExpired(Pool, SystemClock); +use crate::{JobContextExt, State}; -impl std::fmt::Debug for CleanupExpired { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("CleanupExpired").finish_non_exhaustive() +#[derive(Default, Clone)] +pub struct CleanupExpiredTokensJob { + scheduled: DateTime, +} + +impl From> for CleanupExpiredTokensJob { + fn from(scheduled: DateTime) -> Self { + Self { scheduled } } } -#[async_trait::async_trait] -impl Task for CleanupExpired { - async fn run(&self) { - let res = async move { - let mut repo = PgRepository::from_pool(&self.0).await?.boxed(); - let res = repo.oauth2_access_token().cleanup_expired(&self.1).await; - repo.save().await?; - res - } - .await; +impl Job for CleanupExpiredTokensJob { + const NAME: &'static str = "cleanup-expired-tokens"; +} - match res { - Ok(0) => { - debug!("no token to clean up"); - } - Ok(count) => { - info!(count, "cleaned up expired tokens"); - } - Err(error) => { - error!(?error, "failed to cleanup expired tokens"); - } - } +pub async fn cleanup_expired_tokens( + job: CleanupExpiredTokensJob, + ctx: JobContext, +) -> Result<(), Box> { + debug!("cleanup expired tokens job scheduled at {}", job.scheduled); + + let state = ctx.state(); + let clock = state.clock(); + let mut repo = state.repository().await?; + + let count = repo.oauth2_access_token().cleanup_expired(&clock).await?; + repo.save().await?; + + if count == 0 { + debug!("no token to clean up"); + } else { + info!(count, "cleaned up expired tokens"); } + + Ok(()) } -/// Cleanup expired tokens -#[must_use] -pub fn cleanup_expired(pool: &Pool) -> impl Task + Clone { - // XXX: the clock should come from somewhere else - CleanupExpired(pool.clone(), SystemClock::default()) +pub(crate) fn register(monitor: Monitor, state: &State) -> Monitor { + let schedule = apalis_cron::Schedule::from_str("*/15 * * * * *").unwrap(); + let worker = WorkerBuilder::new("cleanup-expired-tokens") + .stream(CronStream::new(schedule).to_stream()) + .layer(state.inject()) + .build(job_fn(cleanup_expired_tokens)); + + monitor.register(worker) } diff --git a/crates/tasks/src/email.rs b/crates/tasks/src/email.rs new file mode 100644 index 00000000..3a07fa49 --- /dev/null +++ b/crates/tasks/src/email.rs @@ -0,0 +1,111 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use anyhow::Context; +use apalis_core::{ + builder::{WorkerBuilder, WorkerFactory}, + context::JobContext, + executor::TokioExecutor, + job::Job, + job_fn::job_fn, + monitor::Monitor, + storage::builder::WithStorage, +}; +use chrono::Duration; +use mas_data_model::UserEmail; +use mas_email::{Address, EmailVerificationContext, Mailbox}; +use rand::{distributions::Uniform, Rng}; +use serde::{Deserialize, Serialize}; +use tracing::info; +use ulid::Ulid; + +use crate::{JobContextExt, State}; + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct VerifyEmailJob { + user_email_id: Ulid, +} + +impl VerifyEmailJob { + #[must_use] + pub fn new(user_email: &UserEmail) -> Self { + Self { + user_email_id: user_email.id, + } + } +} + +impl Job for VerifyEmailJob { + const NAME: &'static str = "verify-email"; +} + +async fn verify_email(job: VerifyEmailJob, ctx: JobContext) -> Result<(), anyhow::Error> { + let state = ctx.state(); + let mut repo = state.repository().await?; + let mut rng = state.rng(); + let mailer = state.mailer(); + let clock = state.clock(); + + // Lookup the user email + let user_email = repo + .user_email() + .lookup(job.user_email_id) + .await? + .context("User email not found")?; + + // Lookup the user associated with the email + let user = repo + .user() + .lookup(user_email.user_id) + .await? + .context("User not found")?; + + // Generate a verification code + let range = Uniform::::from(0..1_000_000); + let code = rng.sample(range); + let code = format!("{code:06}"); + + let address: Address = user_email.email.parse()?; + + // Save the verification code in the database + let verification = repo + .user_email() + .add_verification_code(&mut rng, &clock, &user_email, Duration::hours(8), code) + .await?; + + // And send the verification email + let mailbox = Mailbox::new(Some(user.username.clone()), address); + + let context = EmailVerificationContext::new(user.clone(), verification.clone()); + + mailer.send_verification_email(mailbox, &context).await?; + + info!( + email.id = %user_email.id, + "Verification email sent" + ); + + repo.save().await?; + + Ok(()) +} + +pub(crate) fn register(monitor: Monitor, state: &State) -> Monitor { + let storage = state.store(); + let worker = WorkerBuilder::new("verify-email") + .layer(state.inject()) + .with_storage(storage) + .build(job_fn(verify_email)); + monitor.register(worker) +} diff --git a/crates/tasks/src/lib.rs b/crates/tasks/src/lib.rs index a154afe7..5f500397 100644 --- a/crates/tasks/src/lib.rs +++ b/crates/tasks/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,113 +12,96 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! Generic, sequential task scheduler -//! -//! Tasks here are ran one after another to avoid having to unnecesarily lock -//! resources and avoid database conflicts. Tasks are not persisted, which is -//! considered "good enough" for now. - #![forbid(unsafe_code)] -#![deny( - clippy::all, - clippy::str_to_string, - missing_docs, - rustdoc::broken_intra_doc_links -)] +#![deny(clippy::all, clippy::str_to_string, rustdoc::broken_intra_doc_links)] #![warn(clippy::pedantic)] -use std::{collections::VecDeque, sync::Arc, time::Duration}; - -use futures_util::StreamExt; -use tokio::{ - sync::{Mutex, Notify}, - time::Interval, -}; -use tokio_stream::wrappers::IntervalStream; +use apalis_core::{executor::TokioExecutor, layers::extensions::Extension, monitor::Monitor}; +use apalis_sql::postgres::PostgresStorage; +use mas_email::Mailer; +use mas_storage::{BoxClock, BoxRepository, Repository, SystemClock}; +use mas_storage_pg::{DatabaseError, PgRepository}; +use rand::SeedableRng; +use sqlx::{Pool, Postgres}; use tracing::debug; mod database; +mod email; -pub use self::database::cleanup_expired; +pub use self::email::VerifyEmailJob; -/// A [`Task`] can be executed by a [`TaskQueue`] -#[async_trait::async_trait] -pub trait Task: std::fmt::Debug + Send + Sync + 'static { - /// Execute the [`Task`] - async fn run(&self); +#[derive(Clone)] +struct State { + pool: Pool, + mailer: Mailer, + clock: SystemClock, } -#[derive(Default)] -struct TaskQueueInner { - pending_tasks: Mutex>>, - notifier: Notify, -} - -impl TaskQueueInner { - async fn recuring(&self, interval: Interval, task: T) { - let mut stream = IntervalStream::new(interval); - - while (stream.next()).await.is_some() { - self.schedule(task.clone()).await; +impl State { + pub fn new(pool: Pool, clock: SystemClock, mailer: Mailer) -> Self { + Self { + pool, + mailer, + clock, } } - async fn schedule(&self, task: T) { - let task = Box::new(task); - self.pending_tasks.lock().await.push_back(task); - self.notifier.notify_one(); + pub fn inject(&self) -> Extension { + Extension(self.clone()) } - async fn tick(&self) { - loop { - let pending = { - let mut tasks = self.pending_tasks.lock().await; - tasks.pop_front() - }; - - if let Some(pending) = pending { - pending.run().await; - } else { - break; - } - } + pub fn pool(&self) -> &Pool { + &self.pool } - async fn run_forever(&self) { - loop { - self.notifier.notified().await; - self.tick().await; - } + pub fn clock(&self) -> BoxClock { + Box::new(self.clock.clone()) + } + + pub fn store(&self) -> PostgresStorage { + PostgresStorage::new(self.pool.clone()) + } + + pub fn mailer(&self) -> &Mailer { + &self.mailer + } + + pub fn rng(&self) -> rand_chacha::ChaChaRng { + let _ = self; + + // This is fine. + #[allow(clippy::disallowed_methods)] + rand_chacha::ChaChaRng::from_rng(rand::thread_rng()).expect("failed to seed rng") + } + + pub async fn repository(&self) -> Result { + let repo = PgRepository::from_pool(self.pool()) + .await? + .map_err(mas_storage::RepositoryError::from_error) + .boxed(); + + Ok(repo) } } -/// A [`TaskQueue`] executes tasks inserted in it in order -#[derive(Default)] -pub struct TaskQueue { - inner: Arc, +trait JobContextExt { + fn state(&self) -> State; } -impl TaskQueue { - /// Start the task queue to run forever - pub fn start(&self) { - let queue = self.inner.clone(); - tokio::task::spawn(async move { - queue.run_forever().await; - }); - } - - #[allow(dead_code)] - async fn schedule(&self, task: T) { - let queue = self.inner.clone(); - queue.schedule(task).await; - } - - /// Schedule a task in the queue at regular intervals - pub fn recuring(&self, every: Duration, task: impl Task + Clone + std::fmt::Debug) { - debug!(?task, period = every.as_secs(), "Scheduling recuring task"); - let queue = self.inner.clone(); - tokio::task::spawn(async move { - queue.recuring(tokio::time::interval(every), task).await; - }); +impl JobContextExt for apalis_core::context::JobContext { + fn state(&self) -> State { + self.data_opt::() + .expect("state not injected in job context") + .clone() } } + +#[must_use] +pub fn init(pool: &Pool, mailer: &Mailer) -> Monitor { + let state = State::new(pool.clone(), SystemClock::default(), mailer.clone()); + let monitor = Monitor::new(); + let monitor = self::database::register(monitor, &state); + let monitor = self::email::register(monitor, &state); + debug!(?monitor, "workers registered"); + monitor +}