mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Reimplementation of a postgres-backed storage with a shared PG listener

Quentin Gliech
2023-07-17 17:36:37 +02:00
parent 3f17e2215c
commit 68db56c2a2
15 changed files with 540 additions and 95 deletions
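In short: instead of each `PostgresStorage` holding its own dedicated `LISTEN` connection, a single `PgListener` task now receives the `apalis::job` notifications and fans them out to every storage through a `tokio::sync::watch` channel. A minimal sketch of the pattern (standalone, not the actual MAS types; only sqlx, tokio and tracing are assumed):

use sqlx::postgres::{PgListener, PgPool};

/// One listener connection; any number of cheap subscribers.
async fn shared_listener(pool: PgPool) -> Result<tokio::sync::watch::Receiver<()>, sqlx::Error> {
    let (tx, rx) = tokio::sync::watch::channel(());
    let mut listener = PgListener::connect_with(&pool).await?;
    listener.listen("apalis::job").await?;
    tokio::spawn(async move {
        while let Ok(notification) = listener.recv().await {
            tracing::debug!(?notification, "received notification");
            // Every subscriber observes the change and re-polls the jobs table.
            let _ = tx.send(());
        }
    });
    Ok(rx)
}

Each storage clones the receiver and, in its fetch loop, races `rx.changed()` against a polling interval; that is exactly what `stream_jobs` in the new `postgres.rs` below does.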

View File

@@ -63,7 +63,7 @@ impl HttpClientFactory {
let client = (
MapErrLayer::new(BoxError::from),
MapRequestLayer::new(|req: http::Request<_>| req.map(Full::new)),
BodyToBytesResponseLayer::default(),
BodyToBytesResponseLayer,
)
.layer(client);

View File

@@ -110,7 +110,7 @@ impl Options {
config.matrix.secret.clone(),
http_client_factory,
);
let monitor = mas_tasks::init(&worker_name, &pool, &mailer, conn);
let monitor = mas_tasks::init(&worker_name, &pool, &mailer, conn).await?;
// TODO: grab the handle
tokio::spawn(monitor.run());
}

View File

@@ -64,7 +64,7 @@ impl Options {
let worker_name = Alphanumeric.sample_string(&mut rng, 10);
info!(worker_name, "Starting task scheduler");
let monitor = mas_tasks::init(&worker_name, &pool, &mailer, conn);
let monitor = mas_tasks::init(&worker_name, &pool, &mailer, conn).await?;
span.exit();
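`init` is now async and fallible (it has to connect the shared listener), so both call sites gain `.await?`. A hypothetical caller shape, for illustration only:

async fn start_tasks(
    pool: sqlx::PgPool,
    mailer: mas_email::Mailer,
    conn: impl mas_matrix::HomeserverConnection<Error = anyhow::Error> + 'static,
) -> anyhow::Result<()> {
    let monitor = mas_tasks::init("worker-name", &pool, &mailer, conn).await?;
    // As the TODO notes, the handle is currently detached.
    tokio::spawn(monitor.run());
    Ok(())
}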

View File

@@ -106,9 +106,7 @@ pub struct FormUrlencodedRequestLayer<T> {
impl<T> Default for FormUrlencodedRequestLayer<T> {
fn default() -> Self {
Self {
_t: PhantomData::default(),
}
Self { _t: PhantomData }
}
}
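This hunk, the two identical ones below, and the `BodyToBytesResponseLayer::default()` changes above all look like fixes for clippy's `default_constructed_unit_structs` lint (stabilised around Rust 1.71): for a unit struct, or for `PhantomData`, the literal already names the only possible value, so calling `::default()` is redundant. A self-contained illustration:

use std::marker::PhantomData;

struct JsonRequestLayer<T> {
    _t: PhantomData<T>,
}

impl<T> Default for JsonRequestLayer<T> {
    fn default() -> Self {
        // `PhantomData::default()` here would trip the lint; the literal suffices.
        Self { _t: PhantomData }
    }
}

fn main() {
    let _layer = JsonRequestLayer::<String>::default();
}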

View File

@@ -106,9 +106,7 @@ pub struct JsonRequestLayer<T> {
impl<T> Default for JsonRequestLayer<T> {
fn default() -> Self {
Self {
_t: PhantomData::default(),
}
Self { _t: PhantomData }
}
}

View File

@@ -106,9 +106,7 @@ pub struct JsonResponseLayer<T> {
impl<T> Default for JsonResponseLayer<T> {
fn default() -> Self {
Self {
_t: PhantomData::default(),
}
Self { _t: PhantomData }
}
}

View File

@@ -57,7 +57,7 @@ pub fn hyper_service() -> HttpService {
let client = ServiceBuilder::new()
.map_err(BoxError::from)
.map_request_body(Full::new)
.layer(BodyToBytesResponseLayer::default())
.layer(BodyToBytesResponseLayer)
.override_request_header(USER_AGENT, MAS_USER_AGENT.clone())
.concurrency_limit(10)
.follow_redirects()

View File

@@ -7,15 +7,17 @@ license = "Apache-2.0"
[dependencies]
anyhow = "1.0.71"
apalis-core = { version = "0.4.2", features = ["extensions", "tokio-comp"] }
apalis-core = { version = "0.4.2", features = ["extensions", "tokio-comp", "storage"] }
apalis-cron = "0.4.2"
apalis-sql = { version = "0.4.2", features = ["postgres", "tokio-comp"] }
async-stream = "0.3.5"
async-trait = "0.1.71"
chrono = "0.4.26"
futures-lite = "1.13.0"
rand = "0.8.5"
rand_chacha = "0.3.1"
sqlx = { version = "0.6.3", features = ["runtime-tokio-rustls", "postgres"] }
thiserror = "1.0.41"
tokio = { version = "1.28.2", features = ["macros", "time"] }
tower = "0.4.13"
tracing = "0.1.37"
tracing-opentelemetry = "0.19.0"
@@ -23,6 +25,7 @@ opentelemetry = "0.19.0"
ulid = "1.0.0"
url = "2.4.0"
serde = { version = "1.0.166", features = ["derive"] }
serde_json = "1.0.97"
mas-data-model = { path = "../data-model" }
mas-email = { path = "../email" }

View File

@@ -28,6 +28,7 @@ use rand::{distributions::Uniform, Rng};
use tracing::info;
use crate::{
storage::PostgresStorageFactory,
utils::{metrics_layer, trace_layer},
JobContextExt, State,
};
@@ -96,8 +97,9 @@ pub(crate) fn register(
suffix: &str,
monitor: Monitor<TokioExecutor>,
state: &State,
storage_factory: &PostgresStorageFactory,
) -> Monitor<TokioExecutor> {
let storage = state.store();
let storage = storage_factory.build();
let worker_name = format!("{job}-{suffix}", job = VerifyEmailJob::NAME);
let worker = WorkerBuilder::new(worker_name)
.layer(state.inject())

View File

@@ -19,7 +19,6 @@
use std::sync::Arc;
use apalis_core::{executor::TokioExecutor, layers::extensions::Extension, monitor::Monitor};
use apalis_sql::postgres::PostgresStorage;
use mas_email::Mailer;
use mas_matrix::HomeserverConnection;
use mas_storage::{BoxClock, BoxRepository, Repository, SystemClock};
@@ -28,9 +27,12 @@ use rand::SeedableRng;
use sqlx::{Pool, Postgres};
use tracing::debug;
use crate::storage::PostgresStorageFactory;
mod database;
mod email;
mod matrix;
mod storage;
mod utils;
#[derive(Clone)]
@@ -68,10 +70,6 @@ impl State {
Box::new(self.clock.clone())
}
pub fn store<J>(&self) -> PostgresStorage<J> {
PostgresStorage::new(self.pool.clone())
}
pub fn mailer(&self) -> &Mailer {
&self.mailer
}
@@ -108,23 +106,30 @@ impl JobContextExt for apalis_core::context::JobContext {
}
}
#[must_use]
pub fn init(
/// Initialise the workers.
///
/// # Errors
///
/// This function can fail if the database connection fails.
pub async fn init(
name: &str,
pool: &Pool<Postgres>,
mailer: &Mailer,
homeserver: impl HomeserverConnection<Error = anyhow::Error> + 'static,
) -> Monitor<TokioExecutor> {
) -> Result<Monitor<TokioExecutor>, sqlx::Error> {
let state = State::new(
pool.clone(),
SystemClock::default(),
mailer.clone(),
homeserver,
);
let factory = PostgresStorageFactory::new(pool.clone());
let monitor = Monitor::new().executor(TokioExecutor::new());
let monitor = self::database::register(name, monitor, &state);
let monitor = self::email::register(name, monitor, &state);
let monitor = self::matrix::register(name, monitor, &state);
let monitor = self::email::register(name, monitor, &state, &factory);
let monitor = self::matrix::register(name, monitor, &state, &factory);
// TODO: we might want to grab the join handle here
factory.listen().await?;
debug!(?monitor, "workers registered");
monitor
Ok(monitor)
}
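Note the ordering inside `init`: the `register` calls build their storages first (each `build()` only clones the pool and the watch receiver), and `listen().await?` runs last, consuming the factory and spawning the single notification task. Its `JoinHandle` is currently dropped, hence the TODO above.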

View File

@@ -30,6 +30,7 @@ use mas_storage::{
use tracing::info;
use crate::{
storage::PostgresStorageFactory,
utils::{metrics_layer, trace_layer},
JobContextExt, State,
};
@@ -158,8 +159,9 @@ pub(crate) fn register(
suffix: &str,
monitor: Monitor<TokioExecutor>,
state: &State,
storage_factory: &PostgresStorageFactory,
) -> Monitor<TokioExecutor> {
let storage = state.store();
let storage = storage_factory.build();
let worker_name = format!("{job}-{suffix}", job = ProvisionUserJob::NAME);
let provision_user_worker = WorkerBuilder::new(worker_name)
.layer(state.inject())
@@ -168,7 +170,7 @@
.with_storage(storage)
.build_fn(provision_user);
let storage = state.store();
let storage = storage_factory.build();
let worker_name = format!("{job}-{suffix}", job = ProvisionDeviceJob::NAME);
let provision_device_worker = WorkerBuilder::new(worker_name)
.layer(state.inject())
@@ -177,7 +179,7 @@
.with_storage(storage)
.build_fn(provision_device);
let storage = state.store();
let storage = storage_factory.build();
let worker_name = format!("{job}-{suffix}", job = DeleteDeviceJob::NAME);
let delete_device_worker = WorkerBuilder::new(worker_name)
.layer(state.inject())
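The same substitution as in the email worker: each of the three workers in this file (user provisioning, device provisioning, device deletion) gets its own `Storage` from `storage_factory.build()`, so they all share the one notification stream rather than each needing its own listener connection.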

View File

@@ -0,0 +1,78 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::str::FromStr;
use apalis_core::{context::JobContext, job::JobId, request::JobRequest, worker::WorkerId};
use chrono::{DateTime, Utc};
use serde_json::Value;
use sqlx::Row;
/// Wrapper for [`JobRequest`]
pub(crate) struct SqlJobRequest<T>(JobRequest<T>);
impl<T> From<SqlJobRequest<T>> for JobRequest<T> {
fn from(val: SqlJobRequest<T>) -> Self {
val.0
}
}
impl<'r, T: serde::de::DeserializeOwned> sqlx::FromRow<'r, sqlx::postgres::PgRow>
for SqlJobRequest<T>
{
fn from_row(row: &'r sqlx::postgres::PgRow) -> Result<Self, sqlx::Error> {
let job: Value = row.try_get("job")?;
let id: JobId =
JobId::from_str(row.try_get("id")?).map_err(|e| sqlx::Error::ColumnDecode {
index: "id".to_owned(),
source: Box::new(e),
})?;
let mut context = JobContext::new(id);
let run_at = row.try_get("run_at")?;
context.set_run_at(run_at);
let attempts = row.try_get("attempts").unwrap_or(0);
context.set_attempts(attempts);
let max_attempts = row.try_get("max_attempts").unwrap_or(25);
context.set_max_attempts(max_attempts);
let done_at: Option<DateTime<Utc>> = row.try_get("done_at").unwrap_or_default();
context.set_done_at(done_at);
let lock_at: Option<DateTime<Utc>> = row.try_get("lock_at").unwrap_or_default();
context.set_lock_at(lock_at);
let last_error = row.try_get("last_error").unwrap_or_default();
context.set_last_error(last_error);
let status: String = row.try_get("status")?;
context.set_status(status.parse().map_err(|e| sqlx::Error::ColumnDecode {
index: "job".to_owned(),
source: Box::new(e),
})?);
let lock_by: Option<String> = row.try_get("lock_by").unwrap_or_default();
context.set_lock_by(lock_by.map(WorkerId::new));
Ok(SqlJobRequest(JobRequest::new_with_context(
serde_json::from_value(job).map_err(|e| sqlx::Error::ColumnDecode {
index: "job".to_owned(),
source: Box::new(e),
})?,
context,
)))
}
}
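This manual `FromRow` impl is what lets the storage below feed rows from `apalis.get_jobs` and `apalis.jobs` straight into `sqlx::query_as` and convert the results into apalis `JobRequest`s. Sketch of the call shape (`MyJob` is a placeholder job type, `pool` an open `PgPool`):

let jobs: Vec<SqlJobRequest<MyJob>> = sqlx::query_as("SELECT * FROM apalis.jobs WHERE id = $1")
    .bind(job_id.to_string())
    .fetch_all(&pool)
    .await?;
let requests: Vec<JobRequest<MyJob>> = jobs.into_iter().map(Into::into).collect();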

View File

@@ -0,0 +1,22 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Reimplementation of the [`apalis_sql::storage::PostgresStorage`] using a
//! shared connection for the [`PgListener`]
mod from_row;
mod postgres;
use self::from_row::SqlJobRequest;
pub(crate) use self::postgres::StorageFactory as PostgresStorageFactory;
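The apparent motivation: a `PgListener` pins a dedicated database connection, and with the upstream storage every job type that wants `LISTEN`-based wake-ups needs its own. The factory below multiplexes one listener connection across all job types through a `tokio::sync::watch` channel.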

View File

@@ -0,0 +1,400 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{convert::TryInto, marker::PhantomData, ops::Add, time::Duration};
use apalis_core::{
error::JobStreamError,
job::{Job, JobId, JobStreamResult},
request::JobRequest,
storage::{StorageError, StorageResult, StorageWorkerPulse},
utils::Timer,
worker::WorkerId,
};
use async_stream::try_stream;
use chrono::{DateTime, Utc};
use futures_lite::{Stream, StreamExt};
use serde::{de::DeserializeOwned, Serialize};
use sqlx::{postgres::PgListener, PgPool, Pool, Postgres, Row};
use tokio::task::JoinHandle;
use super::SqlJobRequest;
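/// Factory for [`Storage`] instances that all share one `LISTEN` connection.
///
/// The `listen` task owns the [`PgListener`]; each built storage only keeps a
/// clone of the watch receiver, which is woken on every `apalis::job`
/// notification.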
pub struct StorageFactory {
pool: PgPool,
sender: tokio::sync::watch::Sender<()>,
receiver: tokio::sync::watch::Receiver<()>,
}
impl StorageFactory {
pub fn new(pool: Pool<Postgres>) -> Self {
let (sender, receiver) = tokio::sync::watch::channel(());
StorageFactory {
pool,
sender,
receiver,
}
}
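/// Spawn the background task that owns the shared [`PgListener`]: it
/// subscribes to the `apalis::job` channel and broadcasts every notification
/// to all storages built from this factory.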
pub async fn listen(self) -> Result<JoinHandle<()>, sqlx::Error> {
let mut listener = PgListener::connect_with(&self.pool).await?;
listener.listen("apalis::job").await?;
let handle = tokio::spawn(async move {
loop {
let notification = listener.recv().await.expect("Failed to poll notification");
self.sender.send(()).expect("Failed to send notification");
tracing::debug!(?notification, "Received notification");
}
});
Ok(handle)
}
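/// Create the storage for one job type; this is cheap (a pool clone plus a
/// watch-channel subscription), so it is called once per worker.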
pub fn build<T>(&self) -> Storage<T> {
Storage {
pool: self.pool.clone(),
notifier: self.receiver.clone(),
job_type: PhantomData,
}
}
}
/// Represents a [`apalis_core::storage::Storage`] that persists to Postgres
#[derive(Debug)]
pub struct Storage<T> {
pool: PgPool,
notifier: tokio::sync::watch::Receiver<()>,
job_type: PhantomData<T>,
}
impl<T> Clone for Storage<T> {
fn clone(&self) -> Self {
Storage {
pool: self.pool.clone(),
notifier: self.notifier.clone(),
job_type: PhantomData,
}
}
}
impl<T: DeserializeOwned + Send + Unpin + Job> Storage<T> {
fn stream_jobs(
&self,
worker_id: &WorkerId,
interval: Duration,
buffer_size: usize,
) -> impl Stream<Item = Result<JobRequest<T>, JobStreamError>> {
let pool = self.pool.clone();
let mut notifier = self.notifier.clone();
let sleeper = apalis_core::utils::timer::TokioTimer;
let worker_id = worker_id.clone();
try_stream! {
loop {
// Wait for a notification or a timeout
let interval = sleeper.sleep(interval);
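// `biased` polls the notifier arm first, so a pending wake-up is
// handled before falling back to the interval tick.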
let res = tokio::select! {
biased;
changed = notifier.changed() => changed.map_err(|e| JobStreamError::BrokenPipe(Box::from(e))),
_ = interval => Ok(()),
};
res?;
let tx = pool.clone();
let job_type = T::NAME;
let fetch_query = "SELECT * FROM apalis.get_jobs($1, $2, $3);";
let jobs: Vec<SqlJobRequest<T>> = sqlx::query_as(fetch_query)
.bind(worker_id.to_string())
.bind(job_type)
// https://docs.rs/sqlx/latest/sqlx/postgres/types/index.html
.bind(i32::try_from(buffer_size).map_err(|e| JobStreamError::BrokenPipe(Box::from(e)))?)
.fetch_all(&tx)
.await.map_err(|e| JobStreamError::BrokenPipe(Box::from(e)))?;
for job in jobs {
yield job.into()
}
}
}
}
async fn keep_alive_at<Service>(
&mut self,
worker_id: &WorkerId,
last_seen: DateTime<Utc>,
) -> StorageResult<()> {
let pool = self.pool.clone();
let worker_type = T::NAME;
let storage_name = std::any::type_name::<Self>();
let query = "INSERT INTO apalis.workers (id, worker_type, storage_name, layers, last_seen)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (id) DO
UPDATE SET last_seen = EXCLUDED.last_seen";
sqlx::query(query)
.bind(worker_id.to_string())
.bind(worker_type)
.bind(storage_name)
.bind(std::any::type_name::<Service>())
.bind(last_seen)
.execute(&pool)
.await
.map_err(|e| StorageError::Database(Box::from(e)))?;
Ok(())
}
}
#[async_trait::async_trait]
impl<T> apalis_core::storage::Storage for Storage<T>
where
T: Job + Serialize + DeserializeOwned + Send + 'static + Unpin + Sync,
{
type Output = T;
/// Push a job to Postgres [Storage]
///
/// # SQL Example
///
/// ```sql
/// SELECT apalis.push_job(job_type::text, job::json);
/// ```
async fn push(&mut self, job: Self::Output) -> StorageResult<JobId> {
let id = JobId::new();
let query = "INSERT INTO apalis.jobs VALUES ($1, $2, $3, 'Pending', 0, 25, NOW() , NULL, NULL, NULL, NULL)";
let pool = self.pool.clone();
let job = serde_json::to_value(&job).map_err(|e| StorageError::Parse(Box::from(e)))?;
let job_type = T::NAME;
sqlx::query(query)
.bind(job)
.bind(id.to_string())
.bind(job_type)
.execute(&pool)
.await
.map_err(|e| StorageError::Database(Box::from(e)))?;
Ok(id)
}
async fn schedule(
&mut self,
job: Self::Output,
on: chrono::DateTime<Utc>,
) -> StorageResult<JobId> {
let query =
"INSERT INTO apalis.jobs VALUES ($1, $2, $3, 'Pending', 0, 25, $4, NULL, NULL, NULL, NULL)";
let mut conn = self
.pool
.acquire()
.await
.map_err(|e| StorageError::Connection(Box::from(e)))?;
let id = JobId::new();
let job = serde_json::to_value(&job).map_err(|e| StorageError::Parse(Box::from(e)))?;
let job_type = T::NAME;
sqlx::query(query)
.bind(job)
.bind(id.to_string())
.bind(job_type)
.bind(on)
.execute(&mut conn)
.await
.map_err(|e| StorageError::Database(Box::from(e)))?;
Ok(id)
}
async fn fetch_by_id(&self, job_id: &JobId) -> StorageResult<Option<JobRequest<Self::Output>>> {
let pool = self.pool.clone();
let fetch_query = "SELECT * FROM apalis.jobs WHERE id = $1";
let res: Option<SqlJobRequest<T>> = sqlx::query_as(fetch_query)
.bind(job_id.to_string())
.fetch_optional(&pool)
.await
.map_err(|e| StorageError::Database(Box::from(e)))?;
Ok(res.map(Into::into))
}
async fn heartbeat(&mut self, pulse: StorageWorkerPulse) -> StorageResult<bool> {
let pool = self.pool.clone();
match pulse {
StorageWorkerPulse::EnqueueScheduled { count: _ } => {
// Ideally jobs are queued via run_at, so this is not necessary
Ok(true)
}
// Worker not seen in 5 minutes yet has running jobs
StorageWorkerPulse::ReenqueueOrphaned { count } => {
let job_type = T::NAME;
let mut tx = pool
.acquire()
.await
.map_err(|e| StorageError::Database(Box::from(e)))?;
let query = "UPDATE apalis.jobs
SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, last_error ='Job was abandoned'
WHERE id in
(SELECT jobs.id from apalis.jobs INNER join apalis.workers ON lock_by = workers.id
WHERE status = 'Running' AND workers.last_seen < NOW() - INTERVAL '5 minutes'
AND workers.worker_type = $1 ORDER BY lock_at ASC LIMIT $2);";
sqlx::query(query)
.bind(job_type)
.bind(count)
.execute(&mut tx)
.await
.map_err(|e| StorageError::Database(Box::from(e)))?;
Ok(true)
}
_ => unimplemented!(),
}
}
async fn kill(&mut self, worker_id: &WorkerId, job_id: &JobId) -> StorageResult<()> {
let pool = self.pool.clone();
let mut tx = pool
.acquire()
.await
.map_err(|e| StorageError::Database(Box::from(e)))?;
let query =
"UPDATE apalis.jobs SET status = 'Killed', done_at = now() WHERE id = $1 AND lock_by = $2";
sqlx::query(query)
.bind(job_id.to_string())
.bind(worker_id.to_string())
.execute(&mut tx)
.await
.map_err(|e| StorageError::Database(Box::from(e)))?;
Ok(())
}
/// Puts the job instantly back into the queue
/// Another [Worker] may consume it
async fn retry(&mut self, worker_id: &WorkerId, job_id: &JobId) -> StorageResult<()> {
let pool = self.pool.clone();
let mut tx = pool
.acquire()
.await
.map_err(|e| StorageError::Database(Box::from(e)))?;
let query =
"UPDATE apalis.jobs SET status = 'Pending', done_at = NULL, lock_by = NULL WHERE id = $1 AND lock_by = $2";
sqlx::query(query)
.bind(job_id.to_string())
.bind(worker_id.to_string())
.execute(&mut tx)
.await
.map_err(|e| StorageError::Database(Box::from(e)))?;
Ok(())
}
fn consume(
&mut self,
worker_id: &WorkerId,
interval: Duration,
buffer_size: usize,
) -> JobStreamResult<T> {
Box::pin(
self.stream_jobs(worker_id, interval, buffer_size)
.map(|r| r.map(Some)),
)
}
async fn len(&self) -> StorageResult<i64> {
let pool = self.pool.clone();
let query = "SELECT COUNT(*) AS count FROM apalis.jobs WHERE status = 'Pending'";
let record = sqlx::query(query)
.fetch_one(&pool)
.await
.map_err(|e| StorageError::Database(Box::from(e)))?;
Ok(record
.try_get("count")
.map_err(|e| StorageError::Database(Box::from(e)))?)
}
async fn ack(&mut self, worker_id: &WorkerId, job_id: &JobId) -> StorageResult<()> {
let pool = self.pool.clone();
let query =
"UPDATE apalis.jobs SET status = 'Done', done_at = now() WHERE id = $1 AND lock_by = $2";
sqlx::query(query)
.bind(job_id.to_string())
.bind(worker_id.to_string())
.execute(&pool)
.await
.map_err(|e| StorageError::Database(Box::from(e)))?;
Ok(())
}
async fn reschedule(&mut self, job: &JobRequest<T>, wait: Duration) -> StorageResult<()> {
let pool = self.pool.clone();
let job_id = job.id();
let wait: i64 = wait
.as_secs()
.try_into()
.map_err(|e| StorageError::Database(Box::new(e)))?;
let wait = chrono::Duration::seconds(wait);
// TODO: should we use a clock here?
#[allow(clippy::disallowed_methods)]
let run_at = Utc::now().add(wait);
let mut tx = pool
.acquire()
.await
.map_err(|e| StorageError::Database(Box::from(e)))?;
let query =
"UPDATE apalis.jobs SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, run_at = $2 WHERE id = $1";
sqlx::query(query)
.bind(job_id.to_string())
.bind(run_at)
.execute(&mut tx)
.await
.map_err(|e| StorageError::Database(Box::from(e)))?;
Ok(())
}
async fn update_by_id(
&self,
job_id: &JobId,
job: &JobRequest<Self::Output>,
) -> StorageResult<()> {
let pool = self.pool.clone();
let status = job.status().as_ref();
let attempts = job.attempts();
let done_at = *job.done_at();
let lock_by = job.lock_by().clone();
let lock_at = *job.lock_at();
let last_error = job.last_error().clone();
let mut tx = pool
.acquire()
.await
.map_err(|e| StorageError::Database(Box::from(e)))?;
let query =
"UPDATE apalis.jobs SET status = $1, attempts = $2, done_at = $3, lock_by = $4, lock_at = $5, last_error = $6 WHERE id = $7";
sqlx::query(query)
.bind(status.to_owned())
.bind(attempts)
.bind(done_at)
.bind(lock_by.as_ref().map(WorkerId::name))
.bind(lock_at)
.bind(last_error)
.bind(job_id.to_string())
.execute(&mut tx)
.await
.map_err(|e| StorageError::Database(Box::from(e)))?;
Ok(())
}
async fn keep_alive<Service>(&mut self, worker_id: &WorkerId) -> StorageResult<()> {
#[allow(clippy::disallowed_methods)]
let now = Utc::now();
self.keep_alive_at::<Service>(worker_id, now).await
}
}