From 68db56c2a26af373937b3c24e4fed7c39171168a Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 17 Jul 2023 17:36:37 +0200 Subject: [PATCH] Reimplementation of a postgres-backed storage with a shared PG listener --- Cargo.lock | 69 +-- crates/axum-utils/src/http_client_factory.rs | 2 +- crates/cli/src/commands/server.rs | 2 +- crates/cli/src/commands/worker.rs | 2 +- .../src/layers/form_urlencoded_request.rs | 4 +- crates/http/src/layers/json_request.rs | 4 +- crates/http/src/layers/json_response.rs | 4 +- crates/oidc-client/src/http_service/hyper.rs | 2 +- crates/tasks/Cargo.toml | 7 +- crates/tasks/src/email.rs | 4 +- crates/tasks/src/lib.rs | 27 +- crates/tasks/src/matrix.rs | 8 +- crates/tasks/src/storage/from_row.rs | 78 ++++ crates/tasks/src/storage/mod.rs | 22 + crates/tasks/src/storage/postgres.rs | 400 ++++++++++++++++++ 15 files changed, 540 insertions(+), 95 deletions(-) create mode 100644 crates/tasks/src/storage/from_row.rs create mode 100644 crates/tasks/src/storage/mod.rs create mode 100644 crates/tasks/src/storage/postgres.rs diff --git a/Cargo.lock b/Cargo.lock index 4a70dec9..d88c38cb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -216,25 +216,6 @@ dependencies = [ "tower", ] -[[package]] -name = "apalis-sql" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df8e0971620c65068a26b8fe06ebdec8a235388991ecdd64c267d9edae313244" -dependencies = [ - "apalis-core", - "async-stream", - "async-trait", - "chrono", - "debounced", - "futures 0.3.28", - "futures-lite", - "serde", - "serde_json", - "sqlx", - "tokio", -] - [[package]] name = "arbitrary" version = "1.3.0" @@ -1702,16 +1683,6 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eaa37046cc0f6c3cc6090fbdbf73ef0b8ef4cfcc37f6befc0020f63e8cf121e1" -[[package]] -name = "debounced" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3d8b0346b9fa0aa01a3fa4bcce48d62f8738e9c2956e92f275bbf6cf9d6fab5" -dependencies = [ - "futures-timer", - "futures-util", -] - [[package]] name = "debugid" version = "0.8.0" @@ -2011,18 +1982,6 @@ dependencies = [ "log", ] -[[package]] -name = "flume" -version = "0.10.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1657b4441c3403d9f7b3409e47575237dac27b1b5726df654a6ecbf92f0f7577" -dependencies = [ - "futures-core", - "futures-sink", - "pin-project", - "spin 0.9.8", -] - [[package]] name = "fnv" version = "1.0.7" @@ -2960,17 +2919,6 @@ version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f7012b1bbb0719e1097c47611d3898568c546d597c2e74d66f6087edd5233ff4" -[[package]] -name = "libsqlite3-sys" -version = "0.24.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "898745e570c7d0453cc1fbc4a701eb6c662ed54e8fec8b7d14be137ebeeb9d14" -dependencies = [ - "cc", - "pkg-config", - "vcpkg", -] - [[package]] name = "linked-hash-map" version = "0.5.6" @@ -3569,9 +3517,10 @@ dependencies = [ "anyhow", "apalis-core", "apalis-cron", - "apalis-sql", + "async-stream", "async-trait", "chrono", + "futures-lite", "mas-data-model", "mas-email", "mas-matrix", @@ -3582,8 +3531,10 @@ dependencies = [ "rand 0.8.5", "rand_chacha 0.3.1", "serde", + "serde_json", "sqlx", "thiserror", + "tokio", "tower", "tracing", "tracing-opentelemetry", @@ -5533,9 +5484,6 @@ name = "spin" version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" -dependencies = [ - "lock_api", -] [[package]] name = "spki" @@ -5593,10 +5541,8 @@ dependencies = [ "dotenvy", "either", "event-listener", - "flume", "futures-channel", "futures-core", - "futures-executor", "futures-intrusive", "futures-util", "hashlink", @@ -5606,7 +5552,6 @@ dependencies = [ "indexmap 1.9.3", "itoa", "libc", - "libsqlite3-sys", "log", "md-5", "memchr", @@ -6520,12 +6465,6 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a4d330786735ea358f3bc09eea4caa098569c1c93f342d9aca0514915022fe7e" -[[package]] -name = "vcpkg" -version = "0.2.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" - [[package]] name = "version_check" version = "0.9.4" diff --git a/crates/axum-utils/src/http_client_factory.rs b/crates/axum-utils/src/http_client_factory.rs index b2b562a0..e6644407 100644 --- a/crates/axum-utils/src/http_client_factory.rs +++ b/crates/axum-utils/src/http_client_factory.rs @@ -63,7 +63,7 @@ impl HttpClientFactory { let client = ( MapErrLayer::new(BoxError::from), MapRequestLayer::new(|req: http::Request<_>| req.map(Full::new)), - BodyToBytesResponseLayer::default(), + BodyToBytesResponseLayer, ) .layer(client); diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index d0d23cc5..c0f1c0e6 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -110,7 +110,7 @@ impl Options { config.matrix.secret.clone(), http_client_factory, ); - let monitor = mas_tasks::init(&worker_name, &pool, &mailer, conn); + let monitor = mas_tasks::init(&worker_name, &pool, &mailer, conn).await?; // TODO: grab the handle tokio::spawn(monitor.run()); } diff --git a/crates/cli/src/commands/worker.rs b/crates/cli/src/commands/worker.rs index e0940bb6..1c11256f 100644 --- a/crates/cli/src/commands/worker.rs +++ b/crates/cli/src/commands/worker.rs @@ -64,7 +64,7 @@ impl Options { let worker_name = Alphanumeric.sample_string(&mut rng, 10); info!(worker_name, "Starting task scheduler"); - let monitor = mas_tasks::init(&worker_name, &pool, &mailer, conn); + let monitor = mas_tasks::init(&worker_name, &pool, &mailer, conn).await?; span.exit(); diff --git a/crates/http/src/layers/form_urlencoded_request.rs b/crates/http/src/layers/form_urlencoded_request.rs index 1d414573..102abeb4 100644 --- a/crates/http/src/layers/form_urlencoded_request.rs +++ b/crates/http/src/layers/form_urlencoded_request.rs @@ -106,9 +106,7 @@ pub struct FormUrlencodedRequestLayer { impl Default for FormUrlencodedRequestLayer { fn default() -> Self { - Self { - _t: PhantomData::default(), - } + Self { _t: PhantomData } } } diff --git a/crates/http/src/layers/json_request.rs b/crates/http/src/layers/json_request.rs index 52d2fb3f..6631cb5f 100644 --- a/crates/http/src/layers/json_request.rs +++ b/crates/http/src/layers/json_request.rs @@ -106,9 +106,7 @@ pub struct JsonRequestLayer { impl Default for JsonRequestLayer { fn default() -> Self { - Self { - _t: PhantomData::default(), - } + Self { _t: PhantomData } } } diff --git a/crates/http/src/layers/json_response.rs b/crates/http/src/layers/json_response.rs index 171e3ecf..1dc90391 100644 --- a/crates/http/src/layers/json_response.rs +++ b/crates/http/src/layers/json_response.rs @@ -106,9 +106,7 @@ pub struct JsonResponseLayer { impl Default for JsonResponseLayer { fn default() -> Self { - Self { - _t: 
PhantomData::default(), - } + Self { _t: PhantomData } } } diff --git a/crates/oidc-client/src/http_service/hyper.rs b/crates/oidc-client/src/http_service/hyper.rs index 8be0fa11..d1c4f04b 100644 --- a/crates/oidc-client/src/http_service/hyper.rs +++ b/crates/oidc-client/src/http_service/hyper.rs @@ -57,7 +57,7 @@ pub fn hyper_service() -> HttpService { let client = ServiceBuilder::new() .map_err(BoxError::from) .map_request_body(Full::new) - .layer(BodyToBytesResponseLayer::default()) + .layer(BodyToBytesResponseLayer) .override_request_header(USER_AGENT, MAS_USER_AGENT.clone()) .concurrency_limit(10) .follow_redirects() diff --git a/crates/tasks/Cargo.toml b/crates/tasks/Cargo.toml index 017d3c18..97a472f1 100644 --- a/crates/tasks/Cargo.toml +++ b/crates/tasks/Cargo.toml @@ -7,15 +7,17 @@ license = "apache-2.0" [dependencies] anyhow = "1.0.71" -apalis-core = { version = "0.4.2", features = ["extensions", "tokio-comp"] } +apalis-core = { version = "0.4.2", features = ["extensions", "tokio-comp", "storage"] } apalis-cron = "0.4.2" -apalis-sql = { version = "0.4.2", features = ["postgres", "tokio-comp"] } +async-stream = "0.3.5" async-trait = "0.1.71" chrono = "0.4.26" +futures-lite = "1.13.0" rand = "0.8.5" rand_chacha = "0.3.1" sqlx = { version = "0.6.3", features = ["runtime-tokio-rustls", "postgres"] } thiserror = "1.0.41" +tokio = { version = "1.28.2", features = ["macros", "time"] } tower = "0.4.13" tracing = "0.1.37" tracing-opentelemetry = "0.19.0" @@ -23,6 +25,7 @@ opentelemetry = "0.19.0" ulid = "1.0.0" url = "2.4.0" serde = { version = "1.0.166", features = ["derive"] } +serde_json = "1.0.97" mas-data-model = { path = "../data-model" } mas-email = { path = "../email" } diff --git a/crates/tasks/src/email.rs b/crates/tasks/src/email.rs index dd229505..47fa4333 100644 --- a/crates/tasks/src/email.rs +++ b/crates/tasks/src/email.rs @@ -28,6 +28,7 @@ use rand::{distributions::Uniform, Rng}; use tracing::info; use crate::{ + storage::PostgresStorageFactory, utils::{metrics_layer, trace_layer}, JobContextExt, State, }; @@ -96,8 +97,9 @@ pub(crate) fn register( suffix: &str, monitor: Monitor, state: &State, + storage_factory: &PostgresStorageFactory, ) -> Monitor { - let storage = state.store(); + let storage = storage_factory.build(); let worker_name = format!("{job}-{suffix}", job = VerifyEmailJob::NAME); let worker = WorkerBuilder::new(worker_name) .layer(state.inject()) diff --git a/crates/tasks/src/lib.rs b/crates/tasks/src/lib.rs index a524bcdc..da599687 100644 --- a/crates/tasks/src/lib.rs +++ b/crates/tasks/src/lib.rs @@ -19,7 +19,6 @@ use std::sync::Arc; use apalis_core::{executor::TokioExecutor, layers::extensions::Extension, monitor::Monitor}; -use apalis_sql::postgres::PostgresStorage; use mas_email::Mailer; use mas_matrix::HomeserverConnection; use mas_storage::{BoxClock, BoxRepository, Repository, SystemClock}; @@ -28,9 +27,12 @@ use rand::SeedableRng; use sqlx::{Pool, Postgres}; use tracing::debug; +use crate::storage::PostgresStorageFactory; + mod database; mod email; mod matrix; +mod storage; mod utils; #[derive(Clone)] @@ -68,10 +70,6 @@ impl State { Box::new(self.clock.clone()) } - pub fn store(&self) -> PostgresStorage { - PostgresStorage::new(self.pool.clone()) - } - pub fn mailer(&self) -> &Mailer { &self.mailer } @@ -108,23 +106,30 @@ impl JobContextExt for apalis_core::context::JobContext { } } -#[must_use] -pub fn init( +/// Initialise the workers. +/// +/// # Errors +/// +/// This function can fail if the database connection fails. 
+pub async fn init(
     name: &str,
     pool: &Pool<Postgres>,
     mailer: &Mailer,
     homeserver: impl HomeserverConnection<Error = anyhow::Error> + 'static,
-) -> Monitor<TokioExecutor> {
+) -> Result<Monitor<TokioExecutor>, sqlx::Error> {
     let state = State::new(
         pool.clone(),
         SystemClock::default(),
         mailer.clone(),
         homeserver,
     );
+    let factory = PostgresStorageFactory::new(pool.clone());
     let monitor = Monitor::new().executor(TokioExecutor::new());
     let monitor = self::database::register(name, monitor, &state);
-    let monitor = self::email::register(name, monitor, &state);
-    let monitor = self::matrix::register(name, monitor, &state);
+    let monitor = self::email::register(name, monitor, &state, &factory);
+    let monitor = self::matrix::register(name, monitor, &state, &factory);
+    // TODO: we might want to grab the join handle here
+    factory.listen().await?;
     debug!(?monitor, "workers registered");
-    monitor
+    Ok(monitor)
 }
diff --git a/crates/tasks/src/matrix.rs b/crates/tasks/src/matrix.rs
index 649f297d..313a321e 100644
--- a/crates/tasks/src/matrix.rs
+++ b/crates/tasks/src/matrix.rs
@@ -30,6 +30,7 @@ use mas_storage::{
 use tracing::info;
 
 use crate::{
+    storage::PostgresStorageFactory,
     utils::{metrics_layer, trace_layer},
     JobContextExt, State,
 };
@@ -158,8 +159,9 @@ pub(crate) fn register(
     suffix: &str,
     monitor: Monitor<TokioExecutor>,
     state: &State,
+    storage_factory: &PostgresStorageFactory,
 ) -> Monitor<TokioExecutor> {
-    let storage = state.store();
+    let storage = storage_factory.build();
     let worker_name = format!("{job}-{suffix}", job = ProvisionUserJob::NAME);
     let provision_user_worker = WorkerBuilder::new(worker_name)
         .layer(state.inject())
@@ -168,7 +170,7 @@ pub(crate) fn register(
         .with_storage(storage)
         .build_fn(provision_user);
 
-    let storage = state.store();
+    let storage = storage_factory.build();
     let worker_name = format!("{job}-{suffix}", job = ProvisionDeviceJob::NAME);
     let provision_device_worker = WorkerBuilder::new(worker_name)
         .layer(state.inject())
@@ -177,7 +179,7 @@ pub(crate) fn register(
         .with_storage(storage)
         .build_fn(provision_device);
 
-    let storage = state.store();
+    let storage = storage_factory.build();
     let worker_name = format!("{job}-{suffix}", job = DeleteDeviceJob::NAME);
     let delete_device_worker = WorkerBuilder::new(worker_name)
         .layer(state.inject())
diff --git a/crates/tasks/src/storage/from_row.rs b/crates/tasks/src/storage/from_row.rs
new file mode 100644
index 00000000..d2620365
--- /dev/null
+++ b/crates/tasks/src/storage/from_row.rs
@@ -0,0 +1,78 @@
+// Copyright 2023 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::str::FromStr;
+
+use apalis_core::{context::JobContext, job::JobId, request::JobRequest, worker::WorkerId};
+use chrono::{DateTime, Utc};
+use serde_json::Value;
+use sqlx::Row;
+
+/// Wrapper for [`JobRequest`]
+pub(crate) struct SqlJobRequest<T>(JobRequest<T>);
+
+impl<T> From<SqlJobRequest<T>> for JobRequest<T> {
+    fn from(val: SqlJobRequest<T>) -> Self {
+        val.0
+    }
+}
+
+impl<'r, T: serde::de::DeserializeOwned> sqlx::FromRow<'r, sqlx::postgres::PgRow>
+    for SqlJobRequest<T>
+{
+    fn from_row(row: &'r sqlx::postgres::PgRow) -> Result<Self, sqlx::Error> {
+        let job: Value = row.try_get("job")?;
+        let id: JobId =
+            JobId::from_str(row.try_get("id")?).map_err(|e| sqlx::Error::ColumnDecode {
+                index: "id".to_owned(),
+                source: Box::new(e),
+            })?;
+        let mut context = JobContext::new(id);
+
+        let run_at = row.try_get("run_at")?;
+        context.set_run_at(run_at);
+
+        let attempts = row.try_get("attempts").unwrap_or(0);
+        context.set_attempts(attempts);
+
+        let max_attempts = row.try_get("max_attempts").unwrap_or(25);
+        context.set_max_attempts(max_attempts);
+
+        let done_at: Option<DateTime<Utc>> = row.try_get("done_at").unwrap_or_default();
+        context.set_done_at(done_at);
+
+        let lock_at: Option<DateTime<Utc>> = row.try_get("lock_at").unwrap_or_default();
+        context.set_lock_at(lock_at);
+
+        let last_error = row.try_get("last_error").unwrap_or_default();
+        context.set_last_error(last_error);
+
+        let status: String = row.try_get("status")?;
+        context.set_status(status.parse().map_err(|e| sqlx::Error::ColumnDecode {
+            index: "status".to_owned(),
+            source: Box::new(e),
+        })?);
+
+        let lock_by: Option<String> = row.try_get("lock_by").unwrap_or_default();
+        context.set_lock_by(lock_by.map(WorkerId::new));
+
+        Ok(SqlJobRequest(JobRequest::new_with_context(
+            serde_json::from_value(job).map_err(|e| sqlx::Error::ColumnDecode {
+                index: "job".to_owned(),
+                source: Box::new(e),
+            })?,
+            context,
+        )))
+    }
+}
diff --git a/crates/tasks/src/storage/mod.rs b/crates/tasks/src/storage/mod.rs
new file mode 100644
index 00000000..7884c083
--- /dev/null
+++ b/crates/tasks/src/storage/mod.rs
@@ -0,0 +1,22 @@
+// Copyright 2023 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//! Reimplementation of the [`apalis_sql::storage::PostgresStorage`] using a
+//! shared connection for the [`PgListener`]
+
+mod from_row;
+mod postgres;
+
+use self::from_row::SqlJobRequest;
+pub(crate) use self::postgres::StorageFactory as PostgresStorageFactory;
diff --git a/crates/tasks/src/storage/postgres.rs b/crates/tasks/src/storage/postgres.rs
new file mode 100644
index 00000000..2c86f423
--- /dev/null
+++ b/crates/tasks/src/storage/postgres.rs
@@ -0,0 +1,400 @@
+// Copyright 2023 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::{convert::TryInto, marker::PhantomData, ops::Add, time::Duration};
+
+use apalis_core::{
+    error::JobStreamError,
+    job::{Job, JobId, JobStreamResult},
+    request::JobRequest,
+    storage::{StorageError, StorageResult, StorageWorkerPulse},
+    utils::Timer,
+    worker::WorkerId,
+};
+use async_stream::try_stream;
+use chrono::{DateTime, Utc};
+use futures_lite::{Stream, StreamExt};
+use serde::{de::DeserializeOwned, Serialize};
+use sqlx::{postgres::PgListener, PgPool, Pool, Postgres, Row};
+use tokio::task::JoinHandle;
+
+use super::SqlJobRequest;
+
+pub struct StorageFactory {
+    pool: PgPool,
+    sender: tokio::sync::watch::Sender<()>,
+    receiver: tokio::sync::watch::Receiver<()>,
+}
+
+impl StorageFactory {
+    pub fn new(pool: Pool<Postgres>) -> Self {
+        let (sender, receiver) = tokio::sync::watch::channel(());
+        StorageFactory {
+            pool,
+            sender,
+            receiver,
+        }
+    }
+
+    pub async fn listen(self) -> Result<JoinHandle<()>, sqlx::Error> {
+        let mut listener = PgListener::connect_with(&self.pool).await?;
+        listener.listen("apalis::job").await?;
+
+        let handle = tokio::spawn(async move {
+            loop {
+                let notification = listener.recv().await.expect("Failed to poll notification");
+                self.sender.send(()).expect("Failed to send notification");
+                tracing::debug!(?notification, "Received notification");
+            }
+        });
+
+        Ok(handle)
+    }
+
+    pub fn build<T>(&self) -> Storage<T> {
+        Storage {
+            pool: self.pool.clone(),
+            notifier: self.receiver.clone(),
+            job_type: PhantomData,
+        }
+    }
+}
+
+/// Represents a [`apalis_core::storage::Storage`] that persists to Postgres
+#[derive(Debug)]
+pub struct Storage<T> {
+    pool: PgPool,
+    notifier: tokio::sync::watch::Receiver<()>,
+    job_type: PhantomData<T>,
+}
+
+impl<T> Clone for Storage<T> {
+    fn clone(&self) -> Self {
+        Storage {
+            pool: self.pool.clone(),
+            notifier: self.notifier.clone(),
+            job_type: PhantomData,
+        }
+    }
+}
+
+impl<T: DeserializeOwned + Send + Unpin + Job> Storage<T> {
+    fn stream_jobs(
+        &self,
+        worker_id: &WorkerId,
+        interval: Duration,
+        buffer_size: usize,
+    ) -> impl Stream<Item = Result<JobRequest<T>, JobStreamError>> {
+        let pool = self.pool.clone();
+        let mut notifier = self.notifier.clone();
+        let sleeper = apalis_core::utils::timer::TokioTimer;
+        let worker_id = worker_id.clone();
+        try_stream! {
+            loop {
+                // Wait for a notification or a timeout
+                let interval = sleeper.sleep(interval);
+                let res = tokio::select! {
+                    biased;
+                    changed = notifier.changed() => changed.map_err(|e| JobStreamError::BrokenPipe(Box::from(e))),
+                    _ = interval => Ok(()),
+                };
+                res?;
+
+                let tx = pool.clone();
+                let job_type = T::NAME;
+                let fetch_query = "SELECT * FROM apalis.get_jobs($1, $2, $3);";
+                let jobs: Vec<SqlJobRequest<T>> = sqlx::query_as(fetch_query)
+                    .bind(worker_id.to_string())
+                    .bind(job_type)
+                    // https://docs.rs/sqlx/latest/sqlx/postgres/types/index.html
+                    .bind(i32::try_from(buffer_size).map_err(|e| JobStreamError::BrokenPipe(Box::from(e)))?)
+                    .fetch_all(&tx)
+                    .await.map_err(|e| JobStreamError::BrokenPipe(Box::from(e)))?;
+                for job in jobs {
+                    yield job.into()
+                }
+            }
+        }
+    }
+
+    async fn keep_alive_at<Service>(
+        &mut self,
+        worker_id: &WorkerId,
+        last_seen: DateTime<Utc>,
+    ) -> StorageResult<()> {
+        let pool = self.pool.clone();
+
+        let worker_type = T::NAME;
+        let storage_name = std::any::type_name::<Self>();
+        let query = "INSERT INTO apalis.workers (id, worker_type, storage_name, layers, last_seen)
+                VALUES ($1, $2, $3, $4, $5)
+                ON CONFLICT (id) DO
+                   UPDATE SET last_seen = EXCLUDED.last_seen";
+        sqlx::query(query)
+            .bind(worker_id.to_string())
+            .bind(worker_type)
+            .bind(storage_name)
+            .bind(std::any::type_name::<Service>())
+            .bind(last_seen)
+            .execute(&pool)
+            .await
+            .map_err(|e| StorageError::Database(Box::from(e)))?;
+        Ok(())
+    }
+}
+
+#[async_trait::async_trait]
+impl<T> apalis_core::storage::Storage for Storage<T>
+where
+    T: Job + Serialize + DeserializeOwned + Send + 'static + Unpin + Sync,
+{
+    type Output = T;
+
+    /// Push a job to Postgres [Storage]
+    ///
+    /// # SQL Example
+    ///
+    /// ```sql
+    /// SELECT apalis.push_job(job_type::text, job::json);
+    /// ```
+    async fn push(&mut self, job: Self::Output) -> StorageResult<JobId> {
+        let id = JobId::new();
+        let query = "INSERT INTO apalis.jobs VALUES ($1, $2, $3, 'Pending', 0, 25, NOW() , NULL, NULL, NULL, NULL)";
+        let pool = self.pool.clone();
+        let job = serde_json::to_value(&job).map_err(|e| StorageError::Parse(Box::from(e)))?;
+        let job_type = T::NAME;
+        sqlx::query(query)
+            .bind(job)
+            .bind(id.to_string())
+            .bind(job_type)
+            .execute(&pool)
+            .await
+            .map_err(|e| StorageError::Database(Box::from(e)))?;
+        Ok(id)
+    }
+
+    async fn schedule(
+        &mut self,
+        job: Self::Output,
+        on: chrono::DateTime<Utc>,
+    ) -> StorageResult<JobId> {
+        let query =
+            "INSERT INTO apalis.jobs VALUES ($1, $2, $3, 'Pending', 0, 25, $4, NULL, NULL, NULL, NULL)";
+
+        let mut conn = self
+            .pool
+            .acquire()
+            .await
+            .map_err(|e| StorageError::Connection(Box::from(e)))?;
+
+        let id = JobId::new();
+        let job = serde_json::to_value(&job).map_err(|e| StorageError::Parse(Box::from(e)))?;
+        let job_type = T::NAME;
+        sqlx::query(query)
+            .bind(job)
+            .bind(id.to_string())
+            .bind(job_type)
+            .bind(on)
+            .execute(&mut conn)
+            .await
+            .map_err(|e| StorageError::Database(Box::from(e)))?;
+        Ok(id)
+    }
+
+    async fn fetch_by_id(&self, job_id: &JobId) -> StorageResult<Option<JobRequest<Self::Output>>> {
+        let pool = self.pool.clone();
+
+        let fetch_query = "SELECT * FROM apalis.jobs WHERE id = $1";
+        let res: Option<SqlJobRequest<T>> = sqlx::query_as(fetch_query)
+            .bind(job_id.to_string())
+            .fetch_optional(&pool)
+            .await
+            .map_err(|e| StorageError::Database(Box::from(e)))?;
+        Ok(res.map(Into::into))
+    }
+
+    async fn heartbeat(&mut self, pulse: StorageWorkerPulse) -> StorageResult<bool> {
+        let pool = self.pool.clone();
+        match pulse {
+            StorageWorkerPulse::EnqueueScheduled { count: _ } => {
+                // Ideally jobs are queued via run_at, so this is not necessary.
+                Ok(true)
+            }
+
+            // Worker has not been seen in 5 minutes, yet it still has running jobs
+            StorageWorkerPulse::ReenqueueOrphaned { count } => {
+                let job_type = T::NAME;
+                let mut tx = pool
+                    .acquire()
+                    .await
+                    .map_err(|e| StorageError::Database(Box::from(e)))?;
+                let query = "UPDATE apalis.jobs
+                        SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, last_error = 'Job was abandoned'
+                        WHERE id IN
+                            (SELECT jobs.id FROM apalis.jobs INNER JOIN apalis.workers ON lock_by = workers.id
+                                WHERE status = 'Running' AND workers.last_seen < NOW() - INTERVAL '5 minutes'
+                                AND workers.worker_type = $1 ORDER BY lock_at ASC LIMIT $2);";
+                sqlx::query(query)
+                    .bind(job_type)
+                    .bind(count)
+                    .execute(&mut tx)
+                    .await
+                    .map_err(|e| StorageError::Database(Box::from(e)))?;
+                Ok(true)
+            }
+
+            _ => unimplemented!(),
+        }
+    }
+
+    async fn kill(&mut self, worker_id: &WorkerId, job_id: &JobId) -> StorageResult<()> {
+        let pool = self.pool.clone();
+
+        let mut tx = pool
+            .acquire()
+            .await
+            .map_err(|e| StorageError::Database(Box::from(e)))?;
+        let query =
+            "UPDATE apalis.jobs SET status = 'Killed', done_at = now() WHERE id = $1 AND lock_by = $2";
+        sqlx::query(query)
+            .bind(job_id.to_string())
+            .bind(worker_id.to_string())
+            .execute(&mut tx)
+            .await
+            .map_err(|e| StorageError::Database(Box::from(e)))?;
+        Ok(())
+    }
+
+    /// Puts the job instantly back into the queue.
+    /// Another [Worker] may consume it.
+    async fn retry(&mut self, worker_id: &WorkerId, job_id: &JobId) -> StorageResult<()> {
+        let pool = self.pool.clone();
+
+        let mut tx = pool
+            .acquire()
+            .await
+            .map_err(|e| StorageError::Database(Box::from(e)))?;
+        let query =
+            "UPDATE apalis.jobs SET status = 'Pending', done_at = NULL, lock_by = NULL WHERE id = $1 AND lock_by = $2";
+        sqlx::query(query)
+            .bind(job_id.to_string())
+            .bind(worker_id.to_string())
+            .execute(&mut tx)
+            .await
+            .map_err(|e| StorageError::Database(Box::from(e)))?;
+        Ok(())
+    }
+
+    fn consume(
+        &mut self,
+        worker_id: &WorkerId,
+        interval: Duration,
+        buffer_size: usize,
+    ) -> JobStreamResult<T> {
+        Box::pin(
+            self.stream_jobs(worker_id, interval, buffer_size)
+                .map(|r| r.map(Some)),
+        )
+    }
+
+    async fn len(&self) -> StorageResult<i64> {
+        let pool = self.pool.clone();
+        let query = "SELECT COUNT(*) AS count FROM apalis.jobs WHERE status = 'Pending'";
+        let record = sqlx::query(query)
+            .fetch_one(&pool)
+            .await
+            .map_err(|e| StorageError::Database(Box::from(e)))?;
+        Ok(record
+            .try_get("count")
+            .map_err(|e| StorageError::Database(Box::from(e)))?)
+    }
+
+    async fn ack(&mut self, worker_id: &WorkerId, job_id: &JobId) -> StorageResult<()> {
+        let pool = self.pool.clone();
+        let query =
+            "UPDATE apalis.jobs SET status = 'Done', done_at = now() WHERE id = $1 AND lock_by = $2";
+        sqlx::query(query)
+            .bind(job_id.to_string())
+            .bind(worker_id.to_string())
+            .execute(&pool)
+            .await
+            .map_err(|e| StorageError::Database(Box::from(e)))?;
+        Ok(())
+    }
+
+    async fn reschedule(&mut self, job: &JobRequest<T>, wait: Duration) -> StorageResult<()> {
+        let pool = self.pool.clone();
+        let job_id = job.id();
+
+        let wait: i64 = wait
+            .as_secs()
+            .try_into()
+            .map_err(|e| StorageError::Database(Box::new(e)))?;
+        let wait = chrono::Duration::seconds(wait);
+        // TODO: should we use a clock here?
+        #[allow(clippy::disallowed_methods)]
+        let run_at = Utc::now().add(wait);
+
+        let mut tx = pool
+            .acquire()
+            .await
+            .map_err(|e| StorageError::Database(Box::from(e)))?;
+        let query =
+            "UPDATE apalis.jobs SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, run_at = $2 WHERE id = $1";
+        sqlx::query(query)
+            .bind(job_id.to_string())
+            .bind(run_at)
+            .execute(&mut tx)
+            .await
+            .map_err(|e| StorageError::Database(Box::from(e)))?;
+        Ok(())
+    }
+
+    async fn update_by_id(
+        &self,
+        job_id: &JobId,
+        job: &JobRequest<Self::Output>,
+    ) -> StorageResult<()> {
+        let pool = self.pool.clone();
+        let status = job.status().as_ref();
+        let attempts = job.attempts();
+        let done_at = *job.done_at();
+        let lock_by = job.lock_by().clone();
+        let lock_at = *job.lock_at();
+        let last_error = job.last_error().clone();
+
+        let mut tx = pool
+            .acquire()
+            .await
+            .map_err(|e| StorageError::Database(Box::from(e)))?;
+        let query =
+            "UPDATE apalis.jobs SET status = $1, attempts = $2, done_at = $3, lock_by = $4, lock_at = $5, last_error = $6 WHERE id = $7";
+        sqlx::query(query)
+            .bind(status.to_owned())
+            .bind(attempts)
+            .bind(done_at)
+            .bind(lock_by.as_ref().map(WorkerId::name))
+            .bind(lock_at)
+            .bind(last_error)
+            .bind(job_id.to_string())
+            .execute(&mut tx)
+            .await
+            .map_err(|e| StorageError::Database(Box::from(e)))?;
+        Ok(())
+    }
+
+    async fn keep_alive<Service>(&mut self, worker_id: &WorkerId) -> StorageResult<()> {
+        #[allow(clippy::disallowed_methods)]
+        let now = Utc::now();
+
+        self.keep_alive_at::<Service>(worker_id, now).await
+    }
+}
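
Usage sketch (reviewer note, not part of the patch): the snippet below shows how the pieces above fit together, distilled from `mas_tasks::init`. It only assumes the APIs this patch introduces (`PostgresStorageFactory::new`, `build`, `listen`) plus the `apalis_core::storage::Storage` trait they implement; `MyJob` and `example` are hypothetical placeholders, and since the factory is re-exported as `pub(crate)`, this is written as if from inside the `mas-tasks` crate.

```rust
use apalis_core::{job::Job, storage::Storage as _};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;

use crate::storage::PostgresStorageFactory;

#[derive(Serialize, Deserialize)]
struct MyJob;

impl Job for MyJob {
    const NAME: &'static str = "example.my-job";
}

async fn example(pool: PgPool) -> Result<(), Box<dyn std::error::Error>> {
    // One factory per process: it owns the watch channel fed by the single
    // shared `LISTEN "apalis::job"` connection.
    let factory = PostgresStorageFactory::new(pool);

    // Every queue gets a pool clone plus a watch receiver, instead of the
    // per-queue `PgListener` connection that `apalis-sql` would open.
    let mut storage = factory.build::<MyJob>();

    // Spawn the single listener task; notifications now wake every stream
    // built from this factory instead of waiting for the polling interval.
    let _handle = factory.listen().await?;

    // `push` inserts a 'Pending' row into apalis.jobs; the NOTIFY side is
    // assumed to come from the existing apalis database triggers.
    let job_id = storage.push(MyJob).await?;
    tracing::info!(%job_id, "job enqueued");
    Ok(())
}
```

One design point worth flagging for review: `Storage::stream_jobs` keeps the `interval` timer as a fallback in its `tokio::select!`, so a lost notification only delays a job until the next poll tick rather than stranding it.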