diff --git a/Cargo.lock b/Cargo.lock index 094e0e6a..7194e647 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3579,6 +3579,7 @@ dependencies = [ "async-stream", "async-trait", "chrono", + "event-listener", "futures-lite", "mas-data-model", "mas-email", diff --git a/crates/tasks/Cargo.toml b/crates/tasks/Cargo.toml index 6331cdd8..0f7449e6 100644 --- a/crates/tasks/Cargo.toml +++ b/crates/tasks/Cargo.toml @@ -12,12 +12,13 @@ apalis-cron = "0.4.2" async-stream = "0.3.5" async-trait = "0.1.71" chrono = "0.4.26" +event-listener = "2.5.3" futures-lite = "1.13.0" rand = "0.8.5" rand_chacha = "0.3.1" sqlx = { version = "0.7.1", features = ["runtime-tokio-rustls", "postgres"] } thiserror = "1.0.43" -tokio = { version = "1.29.1", features = ["macros", "time"] } +tokio = { version = "1.29.1", features = ["rt"] } tower = "0.4.13" tracing = "0.1.37" tracing-opentelemetry = "0.19.0" diff --git a/crates/tasks/src/email.rs b/crates/tasks/src/email.rs index 47fa4333..e1caf712 100644 --- a/crates/tasks/src/email.rs +++ b/crates/tasks/src/email.rs @@ -105,7 +105,9 @@ pub(crate) fn register( .layer(state.inject()) .layer(trace_layer()) .layer(metrics_layer()) - .with_storage(storage) + .with_storage_config(storage, |c| { + c.fetch_interval(std::time::Duration::from_secs(1)) + }) .build_fn(verify_email); monitor.register(worker) } diff --git a/crates/tasks/src/matrix.rs b/crates/tasks/src/matrix.rs index 313a321e..395dccbb 100644 --- a/crates/tasks/src/matrix.rs +++ b/crates/tasks/src/matrix.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::time::Duration; + use anyhow::Context; use apalis_core::{ builder::{WorkerBuilder, WorkerFactoryFn}, @@ -167,7 +169,7 @@ pub(crate) fn register( .layer(state.inject()) .layer(trace_layer()) .layer(metrics_layer()) - .with_storage(storage) + .with_storage_config(storage, |c| c.fetch_interval(Duration::from_secs(1))) .build_fn(provision_user); let storage = storage_factory.build(); @@ -176,7 +178,7 @@ pub(crate) fn register( .layer(state.inject()) .layer(trace_layer()) .layer(metrics_layer()) - .with_storage(storage) + .with_storage_config(storage, |c| c.fetch_interval(Duration::from_secs(1))) .build_fn(provision_device); let storage = storage_factory.build(); @@ -185,7 +187,7 @@ pub(crate) fn register( .layer(state.inject()) .layer(trace_layer()) .layer(metrics_layer()) - .with_storage(storage) + .with_storage_config(storage, |c| c.fetch_interval(Duration::from_secs(1))) .build_fn(delete_device); monitor diff --git a/crates/tasks/src/storage/postgres.rs b/crates/tasks/src/storage/postgres.rs index 1cb6566d..3d92cc60 100644 --- a/crates/tasks/src/storage/postgres.rs +++ b/crates/tasks/src/storage/postgres.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{convert::TryInto, marker::PhantomData, ops::Add, time::Duration}; +use std::{convert::TryInto, marker::PhantomData, ops::Add, sync::Arc, time::Duration}; use apalis_core::{ error::JobStreamError, @@ -24,6 +24,7 @@ use apalis_core::{ }; use async_stream::try_stream; use chrono::{DateTime, Utc}; +use event_listener::Event; use futures_lite::{Stream, StreamExt}; use serde::{de::DeserializeOwned, Serialize}; use sqlx::{postgres::PgListener, PgPool, Pool, Postgres, Row}; @@ -33,17 +34,14 @@ use super::SqlJobRequest; pub struct StorageFactory { pool: PgPool, - sender: tokio::sync::watch::Sender<()>, - receiver: tokio::sync::watch::Receiver<()>, + event: Arc, } impl StorageFactory { pub fn new(pool: Pool) -> Self { - let (sender, receiver) = tokio::sync::watch::channel(()); StorageFactory { pool, - sender, - receiver, + event: Arc::new(Event::new()), } } @@ -54,8 +52,8 @@ impl StorageFactory { let handle = tokio::spawn(async move { loop { let notification = listener.recv().await.expect("Failed to poll notification"); - self.sender.send(()).expect("Failed to send notification"); - tracing::debug!(?notification, "Received notification"); + self.event.notify(usize::MAX); + tracing::debug!(?notification, "Broadcast notification"); } }); @@ -65,7 +63,7 @@ impl StorageFactory { pub fn build(&self) -> Storage { Storage { pool: self.pool.clone(), - notifier: self.receiver.clone(), + event: self.event.clone(), job_type: PhantomData, } } @@ -75,7 +73,7 @@ impl StorageFactory { #[derive(Debug)] pub struct Storage { pool: PgPool, - notifier: tokio::sync::watch::Receiver<()>, + event: Arc, job_type: PhantomData, } @@ -83,7 +81,7 @@ impl Clone for Storage { fn clone(&self) -> Self { Storage { pool: self.pool.clone(), - notifier: self.notifier.clone(), + event: self.event.clone(), job_type: PhantomData, } } @@ -97,24 +95,21 @@ impl Storage { buffer_size: usize, ) -> impl Stream, JobStreamError>> { let pool = self.pool.clone(); - let mut notifier = self.notifier.clone(); let sleeper = apalis_core::utils::timer::TokioTimer; let worker_id = worker_id.clone(); + let event = self.event.clone(); try_stream! { loop { // Wait for a notification or a timeout + let listener = event.listen(); let interval = sleeper.sleep(interval); - let res = tokio::select! { - biased; changed = notifier.changed() => changed.map_err(|e| JobStreamError::BrokenPipe(Box::from(e))), - _ = interval => Ok(()), - }; - res?; + futures_lite::future::race(interval, listener).await; let tx = pool.clone(); let job_type = T::NAME; let fetch_query = "SELECT * FROM apalis.get_jobs($1, $2, $3);"; let jobs: Vec> = sqlx::query_as(fetch_query) - .bind(worker_id.to_string()) + .bind(worker_id.name()) .bind(job_type) // https://docs.rs/sqlx/latest/sqlx/postgres/types/index.html .bind(i32::try_from(buffer_size).map_err(|e| JobStreamError::BrokenPipe(Box::from(e)))?) @@ -141,7 +136,7 @@ impl Storage { ON CONFLICT (id) DO UPDATE SET last_seen = EXCLUDED.last_seen"; sqlx::query(query) - .bind(worker_id.to_string()) + .bind(worker_id.name()) .bind(worker_type) .bind(storage_name) .bind(std::any::type_name::()) @@ -272,7 +267,7 @@ where "UPDATE apalis.jobs SET status = 'Killed', done_at = now() WHERE id = $1 AND lock_by = $2"; sqlx::query(query) .bind(job_id.to_string()) - .bind(worker_id.to_string()) + .bind(worker_id.name()) .execute(&mut *conn) .await .map_err(|e| StorageError::Database(Box::from(e)))?; @@ -292,7 +287,7 @@ where "UPDATE apalis.jobs SET status = 'Pending', done_at = NULL, lock_by = NULL WHERE id = $1 AND lock_by = $2"; sqlx::query(query) .bind(job_id.to_string()) - .bind(worker_id.to_string()) + .bind(worker_id.name()) .execute(&mut *conn) .await .map_err(|e| StorageError::Database(Box::from(e)))?; @@ -327,7 +322,7 @@ where "UPDATE apalis.jobs SET status = 'Done', done_at = now() WHERE id = $1 AND lock_by = $2"; sqlx::query(query) .bind(job_id.to_string()) - .bind(worker_id.to_string()) + .bind(worker_id.name()) .execute(&pool) .await .map_err(|e| StorageError::Database(Box::from(e)))?;