1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Switch to event-listener for in-process job notifications

This commit is contained in:
Quentin Gliech
2023-07-18 10:06:43 +02:00
parent f6d4bfdb76
commit ab00002acd
5 changed files with 28 additions and 27 deletions

1
Cargo.lock generated
View File

@ -3579,6 +3579,7 @@ dependencies = [
"async-stream", "async-stream",
"async-trait", "async-trait",
"chrono", "chrono",
"event-listener",
"futures-lite", "futures-lite",
"mas-data-model", "mas-data-model",
"mas-email", "mas-email",

View File

@ -12,12 +12,13 @@ apalis-cron = "0.4.2"
async-stream = "0.3.5" async-stream = "0.3.5"
async-trait = "0.1.71" async-trait = "0.1.71"
chrono = "0.4.26" chrono = "0.4.26"
event-listener = "2.5.3"
futures-lite = "1.13.0" futures-lite = "1.13.0"
rand = "0.8.5" rand = "0.8.5"
rand_chacha = "0.3.1" rand_chacha = "0.3.1"
sqlx = { version = "0.7.1", features = ["runtime-tokio-rustls", "postgres"] } sqlx = { version = "0.7.1", features = ["runtime-tokio-rustls", "postgres"] }
thiserror = "1.0.43" thiserror = "1.0.43"
tokio = { version = "1.29.1", features = ["macros", "time"] } tokio = { version = "1.29.1", features = ["rt"] }
tower = "0.4.13" tower = "0.4.13"
tracing = "0.1.37" tracing = "0.1.37"
tracing-opentelemetry = "0.19.0" tracing-opentelemetry = "0.19.0"

View File

@ -105,7 +105,9 @@ pub(crate) fn register(
.layer(state.inject()) .layer(state.inject())
.layer(trace_layer()) .layer(trace_layer())
.layer(metrics_layer()) .layer(metrics_layer())
.with_storage(storage) .with_storage_config(storage, |c| {
c.fetch_interval(std::time::Duration::from_secs(1))
})
.build_fn(verify_email); .build_fn(verify_email);
monitor.register(worker) monitor.register(worker)
} }

View File

@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::time::Duration;
use anyhow::Context; use anyhow::Context;
use apalis_core::{ use apalis_core::{
builder::{WorkerBuilder, WorkerFactoryFn}, builder::{WorkerBuilder, WorkerFactoryFn},
@ -167,7 +169,7 @@ pub(crate) fn register(
.layer(state.inject()) .layer(state.inject())
.layer(trace_layer()) .layer(trace_layer())
.layer(metrics_layer()) .layer(metrics_layer())
.with_storage(storage) .with_storage_config(storage, |c| c.fetch_interval(Duration::from_secs(1)))
.build_fn(provision_user); .build_fn(provision_user);
let storage = storage_factory.build(); let storage = storage_factory.build();
@ -176,7 +178,7 @@ pub(crate) fn register(
.layer(state.inject()) .layer(state.inject())
.layer(trace_layer()) .layer(trace_layer())
.layer(metrics_layer()) .layer(metrics_layer())
.with_storage(storage) .with_storage_config(storage, |c| c.fetch_interval(Duration::from_secs(1)))
.build_fn(provision_device); .build_fn(provision_device);
let storage = storage_factory.build(); let storage = storage_factory.build();
@ -185,7 +187,7 @@ pub(crate) fn register(
.layer(state.inject()) .layer(state.inject())
.layer(trace_layer()) .layer(trace_layer())
.layer(metrics_layer()) .layer(metrics_layer())
.with_storage(storage) .with_storage_config(storage, |c| c.fetch_interval(Duration::from_secs(1)))
.build_fn(delete_device); .build_fn(delete_device);
monitor monitor

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::{convert::TryInto, marker::PhantomData, ops::Add, time::Duration}; use std::{convert::TryInto, marker::PhantomData, ops::Add, sync::Arc, time::Duration};
use apalis_core::{ use apalis_core::{
error::JobStreamError, error::JobStreamError,
@ -24,6 +24,7 @@ use apalis_core::{
}; };
use async_stream::try_stream; use async_stream::try_stream;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use event_listener::Event;
use futures_lite::{Stream, StreamExt}; use futures_lite::{Stream, StreamExt};
use serde::{de::DeserializeOwned, Serialize}; use serde::{de::DeserializeOwned, Serialize};
use sqlx::{postgres::PgListener, PgPool, Pool, Postgres, Row}; use sqlx::{postgres::PgListener, PgPool, Pool, Postgres, Row};
@ -33,17 +34,14 @@ use super::SqlJobRequest;
pub struct StorageFactory { pub struct StorageFactory {
pool: PgPool, pool: PgPool,
sender: tokio::sync::watch::Sender<()>, event: Arc<Event>,
receiver: tokio::sync::watch::Receiver<()>,
} }
impl StorageFactory { impl StorageFactory {
pub fn new(pool: Pool<Postgres>) -> Self { pub fn new(pool: Pool<Postgres>) -> Self {
let (sender, receiver) = tokio::sync::watch::channel(());
StorageFactory { StorageFactory {
pool, pool,
sender, event: Arc::new(Event::new()),
receiver,
} }
} }
@ -54,8 +52,8 @@ impl StorageFactory {
let handle = tokio::spawn(async move { let handle = tokio::spawn(async move {
loop { loop {
let notification = listener.recv().await.expect("Failed to poll notification"); let notification = listener.recv().await.expect("Failed to poll notification");
self.sender.send(()).expect("Failed to send notification"); self.event.notify(usize::MAX);
tracing::debug!(?notification, "Received notification"); tracing::debug!(?notification, "Broadcast notification");
} }
}); });
@ -65,7 +63,7 @@ impl StorageFactory {
pub fn build<T>(&self) -> Storage<T> { pub fn build<T>(&self) -> Storage<T> {
Storage { Storage {
pool: self.pool.clone(), pool: self.pool.clone(),
notifier: self.receiver.clone(), event: self.event.clone(),
job_type: PhantomData, job_type: PhantomData,
} }
} }
@ -75,7 +73,7 @@ impl StorageFactory {
#[derive(Debug)] #[derive(Debug)]
pub struct Storage<T> { pub struct Storage<T> {
pool: PgPool, pool: PgPool,
notifier: tokio::sync::watch::Receiver<()>, event: Arc<Event>,
job_type: PhantomData<T>, job_type: PhantomData<T>,
} }
@ -83,7 +81,7 @@ impl<T> Clone for Storage<T> {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Storage { Storage {
pool: self.pool.clone(), pool: self.pool.clone(),
notifier: self.notifier.clone(), event: self.event.clone(),
job_type: PhantomData, job_type: PhantomData,
} }
} }
@ -97,24 +95,21 @@ impl<T: DeserializeOwned + Send + Unpin + Job> Storage<T> {
buffer_size: usize, buffer_size: usize,
) -> impl Stream<Item = Result<JobRequest<T>, JobStreamError>> { ) -> impl Stream<Item = Result<JobRequest<T>, JobStreamError>> {
let pool = self.pool.clone(); let pool = self.pool.clone();
let mut notifier = self.notifier.clone();
let sleeper = apalis_core::utils::timer::TokioTimer; let sleeper = apalis_core::utils::timer::TokioTimer;
let worker_id = worker_id.clone(); let worker_id = worker_id.clone();
let event = self.event.clone();
try_stream! { try_stream! {
loop { loop {
// Wait for a notification or a timeout // Wait for a notification or a timeout
let listener = event.listen();
let interval = sleeper.sleep(interval); let interval = sleeper.sleep(interval);
let res = tokio::select! { futures_lite::future::race(interval, listener).await;
biased; changed = notifier.changed() => changed.map_err(|e| JobStreamError::BrokenPipe(Box::from(e))),
_ = interval => Ok(()),
};
res?;
let tx = pool.clone(); let tx = pool.clone();
let job_type = T::NAME; let job_type = T::NAME;
let fetch_query = "SELECT * FROM apalis.get_jobs($1, $2, $3);"; let fetch_query = "SELECT * FROM apalis.get_jobs($1, $2, $3);";
let jobs: Vec<SqlJobRequest<T>> = sqlx::query_as(fetch_query) let jobs: Vec<SqlJobRequest<T>> = sqlx::query_as(fetch_query)
.bind(worker_id.to_string()) .bind(worker_id.name())
.bind(job_type) .bind(job_type)
// https://docs.rs/sqlx/latest/sqlx/postgres/types/index.html // https://docs.rs/sqlx/latest/sqlx/postgres/types/index.html
.bind(i32::try_from(buffer_size).map_err(|e| JobStreamError::BrokenPipe(Box::from(e)))?) .bind(i32::try_from(buffer_size).map_err(|e| JobStreamError::BrokenPipe(Box::from(e)))?)
@ -141,7 +136,7 @@ impl<T: DeserializeOwned + Send + Unpin + Job> Storage<T> {
ON CONFLICT (id) DO ON CONFLICT (id) DO
UPDATE SET last_seen = EXCLUDED.last_seen"; UPDATE SET last_seen = EXCLUDED.last_seen";
sqlx::query(query) sqlx::query(query)
.bind(worker_id.to_string()) .bind(worker_id.name())
.bind(worker_type) .bind(worker_type)
.bind(storage_name) .bind(storage_name)
.bind(std::any::type_name::<Service>()) .bind(std::any::type_name::<Service>())
@ -272,7 +267,7 @@ where
"UPDATE apalis.jobs SET status = 'Killed', done_at = now() WHERE id = $1 AND lock_by = $2"; "UPDATE apalis.jobs SET status = 'Killed', done_at = now() WHERE id = $1 AND lock_by = $2";
sqlx::query(query) sqlx::query(query)
.bind(job_id.to_string()) .bind(job_id.to_string())
.bind(worker_id.to_string()) .bind(worker_id.name())
.execute(&mut *conn) .execute(&mut *conn)
.await .await
.map_err(|e| StorageError::Database(Box::from(e)))?; .map_err(|e| StorageError::Database(Box::from(e)))?;
@ -292,7 +287,7 @@ where
"UPDATE apalis.jobs SET status = 'Pending', done_at = NULL, lock_by = NULL WHERE id = $1 AND lock_by = $2"; "UPDATE apalis.jobs SET status = 'Pending', done_at = NULL, lock_by = NULL WHERE id = $1 AND lock_by = $2";
sqlx::query(query) sqlx::query(query)
.bind(job_id.to_string()) .bind(job_id.to_string())
.bind(worker_id.to_string()) .bind(worker_id.name())
.execute(&mut *conn) .execute(&mut *conn)
.await .await
.map_err(|e| StorageError::Database(Box::from(e)))?; .map_err(|e| StorageError::Database(Box::from(e)))?;
@ -327,7 +322,7 @@ where
"UPDATE apalis.jobs SET status = 'Done', done_at = now() WHERE id = $1 AND lock_by = $2"; "UPDATE apalis.jobs SET status = 'Done', done_at = now() WHERE id = $1 AND lock_by = $2";
sqlx::query(query) sqlx::query(query)
.bind(job_id.to_string()) .bind(job_id.to_string())
.bind(worker_id.to_string()) .bind(worker_id.name())
.execute(&pool) .execute(&pool)
.await .await
.map_err(|e| StorageError::Database(Box::from(e)))?; .map_err(|e| StorageError::Database(Box::from(e)))?;