You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-11-20 12:02:22 +03:00
WIP: use apalis to schedule jobs
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@@ -14,50 +14,66 @@
|
||||
|
||||
//! Database-related tasks
|
||||
|
||||
use mas_storage::{oauth2::OAuth2AccessTokenRepository, Repository, RepositoryAccess, SystemClock};
|
||||
use mas_storage_pg::PgRepository;
|
||||
use sqlx::{Pool, Postgres};
|
||||
use tracing::{debug, error, info};
|
||||
use std::str::FromStr;
|
||||
|
||||
use super::Task;
|
||||
use apalis_core::{
|
||||
builder::{WorkerBuilder, WorkerFactory},
|
||||
context::JobContext,
|
||||
executor::TokioExecutor,
|
||||
job::Job,
|
||||
job_fn::job_fn,
|
||||
monitor::Monitor,
|
||||
};
|
||||
use apalis_cron::CronStream;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_storage::{oauth2::OAuth2AccessTokenRepository, RepositoryAccess};
|
||||
use tracing::{debug, info};
|
||||
|
||||
#[derive(Clone)]
|
||||
struct CleanupExpired(Pool<Postgres>, SystemClock);
|
||||
use crate::{JobContextExt, State};
|
||||
|
||||
impl std::fmt::Debug for CleanupExpired {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("CleanupExpired").finish_non_exhaustive()
|
||||
#[derive(Default, Clone)]
|
||||
pub struct CleanupExpiredTokensJob {
|
||||
scheduled: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl From<DateTime<Utc>> for CleanupExpiredTokensJob {
|
||||
fn from(scheduled: DateTime<Utc>) -> Self {
|
||||
Self { scheduled }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Task for CleanupExpired {
|
||||
async fn run(&self) {
|
||||
let res = async move {
|
||||
let mut repo = PgRepository::from_pool(&self.0).await?.boxed();
|
||||
let res = repo.oauth2_access_token().cleanup_expired(&self.1).await;
|
||||
repo.save().await?;
|
||||
res
|
||||
}
|
||||
.await;
|
||||
impl Job for CleanupExpiredTokensJob {
|
||||
const NAME: &'static str = "cleanup-expired-tokens";
|
||||
}
|
||||
|
||||
match res {
|
||||
Ok(0) => {
|
||||
debug!("no token to clean up");
|
||||
}
|
||||
Ok(count) => {
|
||||
info!(count, "cleaned up expired tokens");
|
||||
}
|
||||
Err(error) => {
|
||||
error!(?error, "failed to cleanup expired tokens");
|
||||
}
|
||||
}
|
||||
pub async fn cleanup_expired_tokens(
|
||||
job: CleanupExpiredTokensJob,
|
||||
ctx: JobContext,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
|
||||
debug!("cleanup expired tokens job scheduled at {}", job.scheduled);
|
||||
|
||||
let state = ctx.state();
|
||||
let clock = state.clock();
|
||||
let mut repo = state.repository().await?;
|
||||
|
||||
let count = repo.oauth2_access_token().cleanup_expired(&clock).await?;
|
||||
repo.save().await?;
|
||||
|
||||
if count == 0 {
|
||||
debug!("no token to clean up");
|
||||
} else {
|
||||
info!(count, "cleaned up expired tokens");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Cleanup expired tokens
|
||||
#[must_use]
|
||||
pub fn cleanup_expired(pool: &Pool<Postgres>) -> impl Task + Clone {
|
||||
// XXX: the clock should come from somewhere else
|
||||
CleanupExpired(pool.clone(), SystemClock::default())
|
||||
pub(crate) fn register(monitor: Monitor<TokioExecutor>, state: &State) -> Monitor<TokioExecutor> {
|
||||
let schedule = apalis_cron::Schedule::from_str("*/15 * * * * *").unwrap();
|
||||
let worker = WorkerBuilder::new("cleanup-expired-tokens")
|
||||
.stream(CronStream::new(schedule).to_stream())
|
||||
.layer(state.inject())
|
||||
.build(job_fn(cleanup_expired_tokens));
|
||||
|
||||
monitor.register(worker)
|
||||
}
|
||||
|
||||
111
crates/tasks/src/email.rs
Normal file
111
crates/tasks/src/email.rs
Normal file
@@ -0,0 +1,111 @@
|
||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use anyhow::Context;
|
||||
use apalis_core::{
|
||||
builder::{WorkerBuilder, WorkerFactory},
|
||||
context::JobContext,
|
||||
executor::TokioExecutor,
|
||||
job::Job,
|
||||
job_fn::job_fn,
|
||||
monitor::Monitor,
|
||||
storage::builder::WithStorage,
|
||||
};
|
||||
use chrono::Duration;
|
||||
use mas_data_model::UserEmail;
|
||||
use mas_email::{Address, EmailVerificationContext, Mailbox};
|
||||
use rand::{distributions::Uniform, Rng};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::info;
|
||||
use ulid::Ulid;
|
||||
|
||||
use crate::{JobContextExt, State};
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct VerifyEmailJob {
|
||||
user_email_id: Ulid,
|
||||
}
|
||||
|
||||
impl VerifyEmailJob {
|
||||
#[must_use]
|
||||
pub fn new(user_email: &UserEmail) -> Self {
|
||||
Self {
|
||||
user_email_id: user_email.id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Job for VerifyEmailJob {
|
||||
const NAME: &'static str = "verify-email";
|
||||
}
|
||||
|
||||
async fn verify_email(job: VerifyEmailJob, ctx: JobContext) -> Result<(), anyhow::Error> {
|
||||
let state = ctx.state();
|
||||
let mut repo = state.repository().await?;
|
||||
let mut rng = state.rng();
|
||||
let mailer = state.mailer();
|
||||
let clock = state.clock();
|
||||
|
||||
// Lookup the user email
|
||||
let user_email = repo
|
||||
.user_email()
|
||||
.lookup(job.user_email_id)
|
||||
.await?
|
||||
.context("User email not found")?;
|
||||
|
||||
// Lookup the user associated with the email
|
||||
let user = repo
|
||||
.user()
|
||||
.lookup(user_email.user_id)
|
||||
.await?
|
||||
.context("User not found")?;
|
||||
|
||||
// Generate a verification code
|
||||
let range = Uniform::<u32>::from(0..1_000_000);
|
||||
let code = rng.sample(range);
|
||||
let code = format!("{code:06}");
|
||||
|
||||
let address: Address = user_email.email.parse()?;
|
||||
|
||||
// Save the verification code in the database
|
||||
let verification = repo
|
||||
.user_email()
|
||||
.add_verification_code(&mut rng, &clock, &user_email, Duration::hours(8), code)
|
||||
.await?;
|
||||
|
||||
// And send the verification email
|
||||
let mailbox = Mailbox::new(Some(user.username.clone()), address);
|
||||
|
||||
let context = EmailVerificationContext::new(user.clone(), verification.clone());
|
||||
|
||||
mailer.send_verification_email(mailbox, &context).await?;
|
||||
|
||||
info!(
|
||||
email.id = %user_email.id,
|
||||
"Verification email sent"
|
||||
);
|
||||
|
||||
repo.save().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn register(monitor: Monitor<TokioExecutor>, state: &State) -> Monitor<TokioExecutor> {
|
||||
let storage = state.store();
|
||||
let worker = WorkerBuilder::new("verify-email")
|
||||
.layer(state.inject())
|
||||
.with_storage(storage)
|
||||
.build(job_fn(verify_email));
|
||||
monitor.register(worker)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@@ -12,113 +12,96 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Generic, sequential task scheduler
|
||||
//!
|
||||
//! Tasks here are ran one after another to avoid having to unnecesarily lock
|
||||
//! resources and avoid database conflicts. Tasks are not persisted, which is
|
||||
//! considered "good enough" for now.
|
||||
|
||||
#![forbid(unsafe_code)]
|
||||
#![deny(
|
||||
clippy::all,
|
||||
clippy::str_to_string,
|
||||
missing_docs,
|
||||
rustdoc::broken_intra_doc_links
|
||||
)]
|
||||
#![deny(clippy::all, clippy::str_to_string, rustdoc::broken_intra_doc_links)]
|
||||
#![warn(clippy::pedantic)]
|
||||
|
||||
use std::{collections::VecDeque, sync::Arc, time::Duration};
|
||||
|
||||
use futures_util::StreamExt;
|
||||
use tokio::{
|
||||
sync::{Mutex, Notify},
|
||||
time::Interval,
|
||||
};
|
||||
use tokio_stream::wrappers::IntervalStream;
|
||||
use apalis_core::{executor::TokioExecutor, layers::extensions::Extension, monitor::Monitor};
|
||||
use apalis_sql::postgres::PostgresStorage;
|
||||
use mas_email::Mailer;
|
||||
use mas_storage::{BoxClock, BoxRepository, Repository, SystemClock};
|
||||
use mas_storage_pg::{DatabaseError, PgRepository};
|
||||
use rand::SeedableRng;
|
||||
use sqlx::{Pool, Postgres};
|
||||
use tracing::debug;
|
||||
|
||||
mod database;
|
||||
mod email;
|
||||
|
||||
pub use self::database::cleanup_expired;
|
||||
pub use self::email::VerifyEmailJob;
|
||||
|
||||
/// A [`Task`] can be executed by a [`TaskQueue`]
|
||||
#[async_trait::async_trait]
|
||||
pub trait Task: std::fmt::Debug + Send + Sync + 'static {
|
||||
/// Execute the [`Task`]
|
||||
async fn run(&self);
|
||||
#[derive(Clone)]
|
||||
struct State {
|
||||
pool: Pool<Postgres>,
|
||||
mailer: Mailer,
|
||||
clock: SystemClock,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct TaskQueueInner {
|
||||
pending_tasks: Mutex<VecDeque<Box<dyn Task>>>,
|
||||
notifier: Notify,
|
||||
}
|
||||
|
||||
impl TaskQueueInner {
|
||||
async fn recuring<T: Task + Clone>(&self, interval: Interval, task: T) {
|
||||
let mut stream = IntervalStream::new(interval);
|
||||
|
||||
while (stream.next()).await.is_some() {
|
||||
self.schedule(task.clone()).await;
|
||||
impl State {
|
||||
pub fn new(pool: Pool<Postgres>, clock: SystemClock, mailer: Mailer) -> Self {
|
||||
Self {
|
||||
pool,
|
||||
mailer,
|
||||
clock,
|
||||
}
|
||||
}
|
||||
|
||||
async fn schedule<T: Task>(&self, task: T) {
|
||||
let task = Box::new(task);
|
||||
self.pending_tasks.lock().await.push_back(task);
|
||||
self.notifier.notify_one();
|
||||
pub fn inject(&self) -> Extension<Self> {
|
||||
Extension(self.clone())
|
||||
}
|
||||
|
||||
async fn tick(&self) {
|
||||
loop {
|
||||
let pending = {
|
||||
let mut tasks = self.pending_tasks.lock().await;
|
||||
tasks.pop_front()
|
||||
};
|
||||
|
||||
if let Some(pending) = pending {
|
||||
pending.run().await;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
pub fn pool(&self) -> &Pool<Postgres> {
|
||||
&self.pool
|
||||
}
|
||||
|
||||
async fn run_forever(&self) {
|
||||
loop {
|
||||
self.notifier.notified().await;
|
||||
self.tick().await;
|
||||
}
|
||||
pub fn clock(&self) -> BoxClock {
|
||||
Box::new(self.clock.clone())
|
||||
}
|
||||
|
||||
pub fn store<J>(&self) -> PostgresStorage<J> {
|
||||
PostgresStorage::new(self.pool.clone())
|
||||
}
|
||||
|
||||
pub fn mailer(&self) -> &Mailer {
|
||||
&self.mailer
|
||||
}
|
||||
|
||||
pub fn rng(&self) -> rand_chacha::ChaChaRng {
|
||||
let _ = self;
|
||||
|
||||
// This is fine.
|
||||
#[allow(clippy::disallowed_methods)]
|
||||
rand_chacha::ChaChaRng::from_rng(rand::thread_rng()).expect("failed to seed rng")
|
||||
}
|
||||
|
||||
pub async fn repository(&self) -> Result<BoxRepository, DatabaseError> {
|
||||
let repo = PgRepository::from_pool(self.pool())
|
||||
.await?
|
||||
.map_err(mas_storage::RepositoryError::from_error)
|
||||
.boxed();
|
||||
|
||||
Ok(repo)
|
||||
}
|
||||
}
|
||||
|
||||
/// A [`TaskQueue`] executes tasks inserted in it in order
|
||||
#[derive(Default)]
|
||||
pub struct TaskQueue {
|
||||
inner: Arc<TaskQueueInner>,
|
||||
trait JobContextExt {
|
||||
fn state(&self) -> State;
|
||||
}
|
||||
|
||||
impl TaskQueue {
|
||||
/// Start the task queue to run forever
|
||||
pub fn start(&self) {
|
||||
let queue = self.inner.clone();
|
||||
tokio::task::spawn(async move {
|
||||
queue.run_forever().await;
|
||||
});
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
async fn schedule<T: Task>(&self, task: T) {
|
||||
let queue = self.inner.clone();
|
||||
queue.schedule(task).await;
|
||||
}
|
||||
|
||||
/// Schedule a task in the queue at regular intervals
|
||||
pub fn recuring(&self, every: Duration, task: impl Task + Clone + std::fmt::Debug) {
|
||||
debug!(?task, period = every.as_secs(), "Scheduling recuring task");
|
||||
let queue = self.inner.clone();
|
||||
tokio::task::spawn(async move {
|
||||
queue.recuring(tokio::time::interval(every), task).await;
|
||||
});
|
||||
impl JobContextExt for apalis_core::context::JobContext {
|
||||
fn state(&self) -> State {
|
||||
self.data_opt::<State>()
|
||||
.expect("state not injected in job context")
|
||||
.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn init(pool: &Pool<Postgres>, mailer: &Mailer) -> Monitor<TokioExecutor> {
|
||||
let state = State::new(pool.clone(), SystemClock::default(), mailer.clone());
|
||||
let monitor = Monitor::new();
|
||||
let monitor = self::database::register(monitor, &state);
|
||||
let monitor = self::email::register(monitor, &state);
|
||||
debug!(?monitor, "workers registered");
|
||||
monitor
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user