You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-31 09:24:31 +03:00
Reimplementation of a postgres-backed storage with a shared PG listener
This commit is contained in:
@ -63,7 +63,7 @@ impl HttpClientFactory {
|
||||
let client = (
|
||||
MapErrLayer::new(BoxError::from),
|
||||
MapRequestLayer::new(|req: http::Request<_>| req.map(Full::new)),
|
||||
BodyToBytesResponseLayer::default(),
|
||||
BodyToBytesResponseLayer,
|
||||
)
|
||||
.layer(client);
|
||||
|
||||
|
@ -110,7 +110,7 @@ impl Options {
|
||||
config.matrix.secret.clone(),
|
||||
http_client_factory,
|
||||
);
|
||||
let monitor = mas_tasks::init(&worker_name, &pool, &mailer, conn);
|
||||
let monitor = mas_tasks::init(&worker_name, &pool, &mailer, conn).await?;
|
||||
// TODO: grab the handle
|
||||
tokio::spawn(monitor.run());
|
||||
}
|
||||
|
@ -64,7 +64,7 @@ impl Options {
|
||||
let worker_name = Alphanumeric.sample_string(&mut rng, 10);
|
||||
|
||||
info!(worker_name, "Starting task scheduler");
|
||||
let monitor = mas_tasks::init(&worker_name, &pool, &mailer, conn);
|
||||
let monitor = mas_tasks::init(&worker_name, &pool, &mailer, conn).await?;
|
||||
|
||||
span.exit();
|
||||
|
||||
|
@ -106,9 +106,7 @@ pub struct FormUrlencodedRequestLayer<T> {
|
||||
|
||||
impl<T> Default for FormUrlencodedRequestLayer<T> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
_t: PhantomData::default(),
|
||||
}
|
||||
Self { _t: PhantomData }
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -106,9 +106,7 @@ pub struct JsonRequestLayer<T> {
|
||||
|
||||
impl<T> Default for JsonRequestLayer<T> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
_t: PhantomData::default(),
|
||||
}
|
||||
Self { _t: PhantomData }
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -106,9 +106,7 @@ pub struct JsonResponseLayer<T> {
|
||||
|
||||
impl<T> Default for JsonResponseLayer<T> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
_t: PhantomData::default(),
|
||||
}
|
||||
Self { _t: PhantomData }
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -57,7 +57,7 @@ pub fn hyper_service() -> HttpService {
|
||||
let client = ServiceBuilder::new()
|
||||
.map_err(BoxError::from)
|
||||
.map_request_body(Full::new)
|
||||
.layer(BodyToBytesResponseLayer::default())
|
||||
.layer(BodyToBytesResponseLayer)
|
||||
.override_request_header(USER_AGENT, MAS_USER_AGENT.clone())
|
||||
.concurrency_limit(10)
|
||||
.follow_redirects()
|
||||
|
@ -7,15 +7,17 @@ license = "apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0.71"
|
||||
apalis-core = { version = "0.4.2", features = ["extensions", "tokio-comp"] }
|
||||
apalis-core = { version = "0.4.2", features = ["extensions", "tokio-comp", "storage"] }
|
||||
apalis-cron = "0.4.2"
|
||||
apalis-sql = { version = "0.4.2", features = ["postgres", "tokio-comp"] }
|
||||
async-stream = "0.3.5"
|
||||
async-trait = "0.1.71"
|
||||
chrono = "0.4.26"
|
||||
futures-lite = "1.13.0"
|
||||
rand = "0.8.5"
|
||||
rand_chacha = "0.3.1"
|
||||
sqlx = { version = "0.6.3", features = ["runtime-tokio-rustls", "postgres"] }
|
||||
thiserror = "1.0.41"
|
||||
tokio = { version = "1.28.2", features = ["macros", "time"] }
|
||||
tower = "0.4.13"
|
||||
tracing = "0.1.37"
|
||||
tracing-opentelemetry = "0.19.0"
|
||||
@ -23,6 +25,7 @@ opentelemetry = "0.19.0"
|
||||
ulid = "1.0.0"
|
||||
url = "2.4.0"
|
||||
serde = { version = "1.0.166", features = ["derive"] }
|
||||
serde_json = "1.0.97"
|
||||
|
||||
mas-data-model = { path = "../data-model" }
|
||||
mas-email = { path = "../email" }
|
||||
|
@ -28,6 +28,7 @@ use rand::{distributions::Uniform, Rng};
|
||||
use tracing::info;
|
||||
|
||||
use crate::{
|
||||
storage::PostgresStorageFactory,
|
||||
utils::{metrics_layer, trace_layer},
|
||||
JobContextExt, State,
|
||||
};
|
||||
@ -96,8 +97,9 @@ pub(crate) fn register(
|
||||
suffix: &str,
|
||||
monitor: Monitor<TokioExecutor>,
|
||||
state: &State,
|
||||
storage_factory: &PostgresStorageFactory,
|
||||
) -> Monitor<TokioExecutor> {
|
||||
let storage = state.store();
|
||||
let storage = storage_factory.build();
|
||||
let worker_name = format!("{job}-{suffix}", job = VerifyEmailJob::NAME);
|
||||
let worker = WorkerBuilder::new(worker_name)
|
||||
.layer(state.inject())
|
||||
|
@ -19,7 +19,6 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use apalis_core::{executor::TokioExecutor, layers::extensions::Extension, monitor::Monitor};
|
||||
use apalis_sql::postgres::PostgresStorage;
|
||||
use mas_email::Mailer;
|
||||
use mas_matrix::HomeserverConnection;
|
||||
use mas_storage::{BoxClock, BoxRepository, Repository, SystemClock};
|
||||
@ -28,9 +27,12 @@ use rand::SeedableRng;
|
||||
use sqlx::{Pool, Postgres};
|
||||
use tracing::debug;
|
||||
|
||||
use crate::storage::PostgresStorageFactory;
|
||||
|
||||
mod database;
|
||||
mod email;
|
||||
mod matrix;
|
||||
mod storage;
|
||||
mod utils;
|
||||
|
||||
#[derive(Clone)]
|
||||
@ -68,10 +70,6 @@ impl State {
|
||||
Box::new(self.clock.clone())
|
||||
}
|
||||
|
||||
pub fn store<J>(&self) -> PostgresStorage<J> {
|
||||
PostgresStorage::new(self.pool.clone())
|
||||
}
|
||||
|
||||
pub fn mailer(&self) -> &Mailer {
|
||||
&self.mailer
|
||||
}
|
||||
@ -108,23 +106,30 @@ impl JobContextExt for apalis_core::context::JobContext {
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn init(
|
||||
/// Initialise the workers.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// This function can fail if the database connection fails.
|
||||
pub async fn init(
|
||||
name: &str,
|
||||
pool: &Pool<Postgres>,
|
||||
mailer: &Mailer,
|
||||
homeserver: impl HomeserverConnection<Error = anyhow::Error> + 'static,
|
||||
) -> Monitor<TokioExecutor> {
|
||||
) -> Result<Monitor<TokioExecutor>, sqlx::Error> {
|
||||
let state = State::new(
|
||||
pool.clone(),
|
||||
SystemClock::default(),
|
||||
mailer.clone(),
|
||||
homeserver,
|
||||
);
|
||||
let factory = PostgresStorageFactory::new(pool.clone());
|
||||
let monitor = Monitor::new().executor(TokioExecutor::new());
|
||||
let monitor = self::database::register(name, monitor, &state);
|
||||
let monitor = self::email::register(name, monitor, &state);
|
||||
let monitor = self::matrix::register(name, monitor, &state);
|
||||
let monitor = self::email::register(name, monitor, &state, &factory);
|
||||
let monitor = self::matrix::register(name, monitor, &state, &factory);
|
||||
// TODO: we might want to grab the join handle here
|
||||
factory.listen().await?;
|
||||
debug!(?monitor, "workers registered");
|
||||
monitor
|
||||
Ok(monitor)
|
||||
}
|
||||
|
@ -30,6 +30,7 @@ use mas_storage::{
|
||||
use tracing::info;
|
||||
|
||||
use crate::{
|
||||
storage::PostgresStorageFactory,
|
||||
utils::{metrics_layer, trace_layer},
|
||||
JobContextExt, State,
|
||||
};
|
||||
@ -158,8 +159,9 @@ pub(crate) fn register(
|
||||
suffix: &str,
|
||||
monitor: Monitor<TokioExecutor>,
|
||||
state: &State,
|
||||
storage_factory: &PostgresStorageFactory,
|
||||
) -> Monitor<TokioExecutor> {
|
||||
let storage = state.store();
|
||||
let storage = storage_factory.build();
|
||||
let worker_name = format!("{job}-{suffix}", job = ProvisionUserJob::NAME);
|
||||
let provision_user_worker = WorkerBuilder::new(worker_name)
|
||||
.layer(state.inject())
|
||||
@ -168,7 +170,7 @@ pub(crate) fn register(
|
||||
.with_storage(storage)
|
||||
.build_fn(provision_user);
|
||||
|
||||
let storage = state.store();
|
||||
let storage = storage_factory.build();
|
||||
let worker_name = format!("{job}-{suffix}", job = ProvisionDeviceJob::NAME);
|
||||
let provision_device_worker = WorkerBuilder::new(worker_name)
|
||||
.layer(state.inject())
|
||||
@ -177,7 +179,7 @@ pub(crate) fn register(
|
||||
.with_storage(storage)
|
||||
.build_fn(provision_device);
|
||||
|
||||
let storage = state.store();
|
||||
let storage = storage_factory.build();
|
||||
let worker_name = format!("{job}-{suffix}", job = DeleteDeviceJob::NAME);
|
||||
let delete_device_worker = WorkerBuilder::new(worker_name)
|
||||
.layer(state.inject())
|
||||
|
78
crates/tasks/src/storage/from_row.rs
Normal file
78
crates/tasks/src/storage/from_row.rs
Normal file
@ -0,0 +1,78 @@
|
||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::str::FromStr;
|
||||
|
||||
use apalis_core::{context::JobContext, job::JobId, request::JobRequest, worker::WorkerId};
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde_json::Value;
|
||||
use sqlx::Row;
|
||||
|
||||
/// Wrapper for [`JobRequest`]
|
||||
pub(crate) struct SqlJobRequest<T>(JobRequest<T>);
|
||||
|
||||
impl<T> From<SqlJobRequest<T>> for JobRequest<T> {
|
||||
fn from(val: SqlJobRequest<T>) -> Self {
|
||||
val.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<'r, T: serde::de::DeserializeOwned> sqlx::FromRow<'r, sqlx::postgres::PgRow>
|
||||
for SqlJobRequest<T>
|
||||
{
|
||||
fn from_row(row: &'r sqlx::postgres::PgRow) -> Result<Self, sqlx::Error> {
|
||||
let job: Value = row.try_get("job")?;
|
||||
let id: JobId =
|
||||
JobId::from_str(row.try_get("id")?).map_err(|e| sqlx::Error::ColumnDecode {
|
||||
index: "id".to_owned(),
|
||||
source: Box::new(e),
|
||||
})?;
|
||||
let mut context = JobContext::new(id);
|
||||
|
||||
let run_at = row.try_get("run_at")?;
|
||||
context.set_run_at(run_at);
|
||||
|
||||
let attempts = row.try_get("attempts").unwrap_or(0);
|
||||
context.set_attempts(attempts);
|
||||
|
||||
let max_attempts = row.try_get("max_attempts").unwrap_or(25);
|
||||
context.set_max_attempts(max_attempts);
|
||||
|
||||
let done_at: Option<DateTime<Utc>> = row.try_get("done_at").unwrap_or_default();
|
||||
context.set_done_at(done_at);
|
||||
|
||||
let lock_at: Option<DateTime<Utc>> = row.try_get("lock_at").unwrap_or_default();
|
||||
context.set_lock_at(lock_at);
|
||||
|
||||
let last_error = row.try_get("last_error").unwrap_or_default();
|
||||
context.set_last_error(last_error);
|
||||
|
||||
let status: String = row.try_get("status")?;
|
||||
context.set_status(status.parse().map_err(|e| sqlx::Error::ColumnDecode {
|
||||
index: "job".to_owned(),
|
||||
source: Box::new(e),
|
||||
})?);
|
||||
|
||||
let lock_by: Option<String> = row.try_get("lock_by").unwrap_or_default();
|
||||
context.set_lock_by(lock_by.map(WorkerId::new));
|
||||
|
||||
Ok(SqlJobRequest(JobRequest::new_with_context(
|
||||
serde_json::from_value(job).map_err(|e| sqlx::Error::ColumnDecode {
|
||||
index: "job".to_owned(),
|
||||
source: Box::new(e),
|
||||
})?,
|
||||
context,
|
||||
)))
|
||||
}
|
||||
}
|
22
crates/tasks/src/storage/mod.rs
Normal file
22
crates/tasks/src/storage/mod.rs
Normal file
@ -0,0 +1,22 @@
|
||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Reimplementation of the [`apalis_sql::storage::PostgresStorage`] using a
|
||||
//! shared connection for the [`PgListener`]
|
||||
|
||||
mod from_row;
|
||||
mod postgres;
|
||||
|
||||
use self::from_row::SqlJobRequest;
|
||||
pub(crate) use self::postgres::StorageFactory as PostgresStorageFactory;
|
400
crates/tasks/src/storage/postgres.rs
Normal file
400
crates/tasks/src/storage/postgres.rs
Normal file
@ -0,0 +1,400 @@
|
||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{convert::TryInto, marker::PhantomData, ops::Add, time::Duration};
|
||||
|
||||
use apalis_core::{
|
||||
error::JobStreamError,
|
||||
job::{Job, JobId, JobStreamResult},
|
||||
request::JobRequest,
|
||||
storage::{StorageError, StorageResult, StorageWorkerPulse},
|
||||
utils::Timer,
|
||||
worker::WorkerId,
|
||||
};
|
||||
use async_stream::try_stream;
|
||||
use chrono::{DateTime, Utc};
|
||||
use futures_lite::{Stream, StreamExt};
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use sqlx::{postgres::PgListener, PgPool, Pool, Postgres, Row};
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
use super::SqlJobRequest;
|
||||
|
||||
pub struct StorageFactory {
|
||||
pool: PgPool,
|
||||
sender: tokio::sync::watch::Sender<()>,
|
||||
receiver: tokio::sync::watch::Receiver<()>,
|
||||
}
|
||||
|
||||
impl StorageFactory {
|
||||
pub fn new(pool: Pool<Postgres>) -> Self {
|
||||
let (sender, receiver) = tokio::sync::watch::channel(());
|
||||
StorageFactory {
|
||||
pool,
|
||||
sender,
|
||||
receiver,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn listen(self) -> Result<JoinHandle<()>, sqlx::Error> {
|
||||
let mut listener = PgListener::connect_with(&self.pool).await?;
|
||||
listener.listen("apalis::job").await?;
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
loop {
|
||||
let notification = listener.recv().await.expect("Failed to poll notification");
|
||||
self.sender.send(()).expect("Failed to send notification");
|
||||
tracing::debug!(?notification, "Received notification");
|
||||
}
|
||||
});
|
||||
|
||||
Ok(handle)
|
||||
}
|
||||
|
||||
pub fn build<T>(&self) -> Storage<T> {
|
||||
Storage {
|
||||
pool: self.pool.clone(),
|
||||
notifier: self.receiver.clone(),
|
||||
job_type: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a [`apalis_core::storage::Storage`] that persists to Postgres
|
||||
#[derive(Debug)]
|
||||
pub struct Storage<T> {
|
||||
pool: PgPool,
|
||||
notifier: tokio::sync::watch::Receiver<()>,
|
||||
job_type: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T> Clone for Storage<T> {
|
||||
fn clone(&self) -> Self {
|
||||
Storage {
|
||||
pool: self.pool.clone(),
|
||||
notifier: self.notifier.clone(),
|
||||
job_type: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: DeserializeOwned + Send + Unpin + Job> Storage<T> {
|
||||
fn stream_jobs(
|
||||
&self,
|
||||
worker_id: &WorkerId,
|
||||
interval: Duration,
|
||||
buffer_size: usize,
|
||||
) -> impl Stream<Item = Result<JobRequest<T>, JobStreamError>> {
|
||||
let pool = self.pool.clone();
|
||||
let mut notifier = self.notifier.clone();
|
||||
let sleeper = apalis_core::utils::timer::TokioTimer;
|
||||
let worker_id = worker_id.clone();
|
||||
try_stream! {
|
||||
loop {
|
||||
// Wait for a notification or a timeout
|
||||
let interval = sleeper.sleep(interval);
|
||||
let res = tokio::select! {
|
||||
biased; changed = notifier.changed() => changed.map_err(|e| JobStreamError::BrokenPipe(Box::from(e))),
|
||||
_ = interval => Ok(()),
|
||||
};
|
||||
res?;
|
||||
|
||||
let tx = pool.clone();
|
||||
let job_type = T::NAME;
|
||||
let fetch_query = "SELECT * FROM apalis.get_jobs($1, $2, $3);";
|
||||
let jobs: Vec<SqlJobRequest<T>> = sqlx::query_as(fetch_query)
|
||||
.bind(worker_id.to_string())
|
||||
.bind(job_type)
|
||||
// https://docs.rs/sqlx/latest/sqlx/postgres/types/index.html
|
||||
.bind(i32::try_from(buffer_size).map_err(|e| JobStreamError::BrokenPipe(Box::from(e)))?)
|
||||
.fetch_all(&tx)
|
||||
.await.map_err(|e| JobStreamError::BrokenPipe(Box::from(e)))?;
|
||||
for job in jobs {
|
||||
yield job.into()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn keep_alive_at<Service>(
|
||||
&mut self,
|
||||
worker_id: &WorkerId,
|
||||
last_seen: DateTime<Utc>,
|
||||
) -> StorageResult<()> {
|
||||
let pool = self.pool.clone();
|
||||
|
||||
let worker_type = T::NAME;
|
||||
let storage_name = std::any::type_name::<Self>();
|
||||
let query = "INSERT INTO apalis.workers (id, worker_type, storage_name, layers, last_seen)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (id) DO
|
||||
UPDATE SET last_seen = EXCLUDED.last_seen";
|
||||
sqlx::query(query)
|
||||
.bind(worker_id.to_string())
|
||||
.bind(worker_type)
|
||||
.bind(storage_name)
|
||||
.bind(std::any::type_name::<Service>())
|
||||
.bind(last_seen)
|
||||
.execute(&pool)
|
||||
.await
|
||||
.map_err(|e| StorageError::Database(Box::from(e)))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl<T> apalis_core::storage::Storage for Storage<T>
|
||||
where
|
||||
T: Job + Serialize + DeserializeOwned + Send + 'static + Unpin + Sync,
|
||||
{
|
||||
type Output = T;
|
||||
|
||||
/// Push a job to Postgres [Storage]
|
||||
///
|
||||
/// # SQL Example
|
||||
///
|
||||
/// ```sql
|
||||
/// SELECT apalis.push_job(job_type::text, job::json);
|
||||
/// ```
|
||||
async fn push(&mut self, job: Self::Output) -> StorageResult<JobId> {
|
||||
let id = JobId::new();
|
||||
let query = "INSERT INTO apalis.jobs VALUES ($1, $2, $3, 'Pending', 0, 25, NOW() , NULL, NULL, NULL, NULL)";
|
||||
let pool = self.pool.clone();
|
||||
let job = serde_json::to_value(&job).map_err(|e| StorageError::Parse(Box::from(e)))?;
|
||||
let job_type = T::NAME;
|
||||
sqlx::query(query)
|
||||
.bind(job)
|
||||
.bind(id.to_string())
|
||||
.bind(job_type)
|
||||
.execute(&pool)
|
||||
.await
|
||||
.map_err(|e| StorageError::Database(Box::from(e)))?;
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
async fn schedule(
|
||||
&mut self,
|
||||
job: Self::Output,
|
||||
on: chrono::DateTime<Utc>,
|
||||
) -> StorageResult<JobId> {
|
||||
let query =
|
||||
"INSERT INTO apalis.jobs VALUES ($1, $2, $3, 'Pending', 0, 25, $4, NULL, NULL, NULL, NULL)";
|
||||
|
||||
let mut conn = self
|
||||
.pool
|
||||
.acquire()
|
||||
.await
|
||||
.map_err(|e| StorageError::Connection(Box::from(e)))?;
|
||||
|
||||
let id = JobId::new();
|
||||
let job = serde_json::to_value(&job).map_err(|e| StorageError::Parse(Box::from(e)))?;
|
||||
let job_type = T::NAME;
|
||||
sqlx::query(query)
|
||||
.bind(job)
|
||||
.bind(id.to_string())
|
||||
.bind(job_type)
|
||||
.bind(on)
|
||||
.execute(&mut conn)
|
||||
.await
|
||||
.map_err(|e| StorageError::Database(Box::from(e)))?;
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
async fn fetch_by_id(&self, job_id: &JobId) -> StorageResult<Option<JobRequest<Self::Output>>> {
|
||||
let pool = self.pool.clone();
|
||||
|
||||
let fetch_query = "SELECT * FROM apalis.jobs WHERE id = $1";
|
||||
let res: Option<SqlJobRequest<T>> = sqlx::query_as(fetch_query)
|
||||
.bind(job_id.to_string())
|
||||
.fetch_optional(&pool)
|
||||
.await
|
||||
.map_err(|e| StorageError::Database(Box::from(e)))?;
|
||||
Ok(res.map(Into::into))
|
||||
}
|
||||
|
||||
async fn heartbeat(&mut self, pulse: StorageWorkerPulse) -> StorageResult<bool> {
|
||||
let pool = self.pool.clone();
|
||||
match pulse {
|
||||
StorageWorkerPulse::EnqueueScheduled { count: _ } => {
|
||||
// Ideally jobs are queue via run_at. So this is not necessary
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
// Worker not seen in 5 minutes yet has running jobs
|
||||
StorageWorkerPulse::ReenqueueOrphaned { count } => {
|
||||
let job_type = T::NAME;
|
||||
let mut tx = pool
|
||||
.acquire()
|
||||
.await
|
||||
.map_err(|e| StorageError::Database(Box::from(e)))?;
|
||||
let query = "UPDATE apalis.jobs
|
||||
SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, last_error ='Job was abandoned'
|
||||
WHERE id in
|
||||
(SELECT jobs.id from apalis.jobs INNER join apalis.workers ON lock_by = workers.id
|
||||
WHERE status = 'Running' AND workers.last_seen < NOW() - INTERVAL '5 minutes'
|
||||
AND workers.worker_type = $1 ORDER BY lock_at ASC LIMIT $2);";
|
||||
sqlx::query(query)
|
||||
.bind(job_type)
|
||||
.bind(count)
|
||||
.execute(&mut tx)
|
||||
.await
|
||||
.map_err(|e| StorageError::Database(Box::from(e)))?;
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn kill(&mut self, worker_id: &WorkerId, job_id: &JobId) -> StorageResult<()> {
|
||||
let pool = self.pool.clone();
|
||||
|
||||
let mut tx = pool
|
||||
.acquire()
|
||||
.await
|
||||
.map_err(|e| StorageError::Database(Box::from(e)))?;
|
||||
let query =
|
||||
"UPDATE apalis.jobs SET status = 'Killed', done_at = now() WHERE id = $1 AND lock_by = $2";
|
||||
sqlx::query(query)
|
||||
.bind(job_id.to_string())
|
||||
.bind(worker_id.to_string())
|
||||
.execute(&mut tx)
|
||||
.await
|
||||
.map_err(|e| StorageError::Database(Box::from(e)))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Puts the job instantly back into the queue
|
||||
/// Another [Worker] may consume
|
||||
async fn retry(&mut self, worker_id: &WorkerId, job_id: &JobId) -> StorageResult<()> {
|
||||
let pool = self.pool.clone();
|
||||
|
||||
let mut tx = pool
|
||||
.acquire()
|
||||
.await
|
||||
.map_err(|e| StorageError::Database(Box::from(e)))?;
|
||||
let query =
|
||||
"UPDATE apalis.jobs SET status = 'Pending', done_at = NULL, lock_by = NULL WHERE id = $1 AND lock_by = $2";
|
||||
sqlx::query(query)
|
||||
.bind(job_id.to_string())
|
||||
.bind(worker_id.to_string())
|
||||
.execute(&mut tx)
|
||||
.await
|
||||
.map_err(|e| StorageError::Database(Box::from(e)))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn consume(
|
||||
&mut self,
|
||||
worker_id: &WorkerId,
|
||||
interval: Duration,
|
||||
buffer_size: usize,
|
||||
) -> JobStreamResult<T> {
|
||||
Box::pin(
|
||||
self.stream_jobs(worker_id, interval, buffer_size)
|
||||
.map(|r| r.map(Some)),
|
||||
)
|
||||
}
|
||||
async fn len(&self) -> StorageResult<i64> {
|
||||
let pool = self.pool.clone();
|
||||
let query = "SELECT COUNT(*) AS count FROM apalis.jobs WHERE status = 'Pending'";
|
||||
let record = sqlx::query(query)
|
||||
.fetch_one(&pool)
|
||||
.await
|
||||
.map_err(|e| StorageError::Database(Box::from(e)))?;
|
||||
Ok(record
|
||||
.try_get("count")
|
||||
.map_err(|e| StorageError::Database(Box::from(e)))?)
|
||||
}
|
||||
async fn ack(&mut self, worker_id: &WorkerId, job_id: &JobId) -> StorageResult<()> {
|
||||
let pool = self.pool.clone();
|
||||
let query =
|
||||
"UPDATE apalis.jobs SET status = 'Done', done_at = now() WHERE id = $1 AND lock_by = $2";
|
||||
sqlx::query(query)
|
||||
.bind(job_id.to_string())
|
||||
.bind(worker_id.to_string())
|
||||
.execute(&pool)
|
||||
.await
|
||||
.map_err(|e| StorageError::Database(Box::from(e)))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn reschedule(&mut self, job: &JobRequest<T>, wait: Duration) -> StorageResult<()> {
|
||||
let pool = self.pool.clone();
|
||||
let job_id = job.id();
|
||||
|
||||
let wait: i64 = wait
|
||||
.as_secs()
|
||||
.try_into()
|
||||
.map_err(|e| StorageError::Database(Box::new(e)))?;
|
||||
let wait = chrono::Duration::seconds(wait);
|
||||
// TODO: should we use a clock here?
|
||||
#[allow(clippy::disallowed_methods)]
|
||||
let run_at = Utc::now().add(wait);
|
||||
|
||||
let mut tx = pool
|
||||
.acquire()
|
||||
.await
|
||||
.map_err(|e| StorageError::Database(Box::from(e)))?;
|
||||
let query =
|
||||
"UPDATE apalis.jobs SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, run_at = $2 WHERE id = $1";
|
||||
sqlx::query(query)
|
||||
.bind(job_id.to_string())
|
||||
.bind(run_at)
|
||||
.execute(&mut tx)
|
||||
.await
|
||||
.map_err(|e| StorageError::Database(Box::from(e)))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn update_by_id(
|
||||
&self,
|
||||
job_id: &JobId,
|
||||
job: &JobRequest<Self::Output>,
|
||||
) -> StorageResult<()> {
|
||||
let pool = self.pool.clone();
|
||||
let status = job.status().as_ref();
|
||||
let attempts = job.attempts();
|
||||
let done_at = *job.done_at();
|
||||
let lock_by = job.lock_by().clone();
|
||||
let lock_at = *job.lock_at();
|
||||
let last_error = job.last_error().clone();
|
||||
|
||||
let mut tx = pool
|
||||
.acquire()
|
||||
.await
|
||||
.map_err(|e| StorageError::Database(Box::from(e)))?;
|
||||
let query =
|
||||
"UPDATE apalis.jobs SET status = $1, attempts = $2, done_at = $3, lock_by = $4, lock_at = $5, last_error = $6 WHERE id = $7";
|
||||
sqlx::query(query)
|
||||
.bind(status.to_owned())
|
||||
.bind(attempts)
|
||||
.bind(done_at)
|
||||
.bind(lock_by.as_ref().map(WorkerId::name))
|
||||
.bind(lock_at)
|
||||
.bind(last_error)
|
||||
.bind(job_id.to_string())
|
||||
.execute(&mut tx)
|
||||
.await
|
||||
.map_err(|e| StorageError::Database(Box::from(e)))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn keep_alive<Service>(&mut self, worker_id: &WorkerId) -> StorageResult<()> {
|
||||
#[allow(clippy::disallowed_methods)]
|
||||
let now = Utc::now();
|
||||
|
||||
self.keep_alive_at::<Service>(worker_id, now).await
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user