You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-08-09 04:22:45 +03:00
storage{,-pg}: better documentation of both crates
This commit is contained in:
144
crates/storage-pg/src/errors.rs
Normal file
144
crates/storage-pg/src/errors.rs
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
use sqlx::postgres::PgQueryResult;
|
||||||
|
use thiserror::Error;
|
||||||
|
use ulid::Ulid;
|
||||||
|
|
||||||
|
/// Generic error when interacting with the database
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
#[error(transparent)]
|
||||||
|
pub enum DatabaseError {
|
||||||
|
/// An error which came from the database itself
|
||||||
|
Driver {
|
||||||
|
/// The underlying error from the database driver
|
||||||
|
#[from]
|
||||||
|
source: sqlx::Error,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// An error which occured while converting the data from the database
|
||||||
|
Inconsistency(#[from] DatabaseInconsistencyError),
|
||||||
|
|
||||||
|
/// An error which happened because the requested database operation is
|
||||||
|
/// invalid
|
||||||
|
#[error("Invalid database operation")]
|
||||||
|
InvalidOperation {
|
||||||
|
/// The source of the error, if any
|
||||||
|
#[source]
|
||||||
|
source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// An error which happens when an operation affects not enough or too many
|
||||||
|
/// rows
|
||||||
|
#[error("Expected {expected} rows to be affected, but {actual} rows were affected")]
|
||||||
|
RowsAffected {
|
||||||
|
/// How many rows were expected to be affected
|
||||||
|
expected: u64,
|
||||||
|
|
||||||
|
/// How many rows were actually affected
|
||||||
|
actual: u64,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DatabaseError {
|
||||||
|
pub(crate) fn ensure_affected_rows(
|
||||||
|
result: &PgQueryResult,
|
||||||
|
expected: u64,
|
||||||
|
) -> Result<(), DatabaseError> {
|
||||||
|
let actual = result.rows_affected();
|
||||||
|
if actual == expected {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(DatabaseError::RowsAffected { expected, actual })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn to_invalid_operation<E: std::error::Error + Send + Sync + 'static>(e: E) -> Self {
|
||||||
|
Self::InvalidOperation {
|
||||||
|
source: Some(Box::new(e)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) const fn invalid_operation() -> Self {
|
||||||
|
Self::InvalidOperation { source: None }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An error which occured while converting the data from the database
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub struct DatabaseInconsistencyError {
|
||||||
|
/// The table which was being queried
|
||||||
|
table: &'static str,
|
||||||
|
|
||||||
|
/// The column which was being queried
|
||||||
|
column: Option<&'static str>,
|
||||||
|
|
||||||
|
/// The row which was being queried
|
||||||
|
row: Option<Ulid>,
|
||||||
|
|
||||||
|
/// The source of the error
|
||||||
|
#[source]
|
||||||
|
source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for DatabaseInconsistencyError {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "Database inconsistency on table {}", self.table)?;
|
||||||
|
if let Some(column) = self.column {
|
||||||
|
write!(f, " column {column}")?;
|
||||||
|
}
|
||||||
|
if let Some(row) = self.row {
|
||||||
|
write!(f, " row {row}")?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DatabaseInconsistencyError {
|
||||||
|
/// Create a new [`DatabaseInconsistencyError`] for the given table
|
||||||
|
#[must_use]
|
||||||
|
pub(crate) const fn on(table: &'static str) -> Self {
|
||||||
|
Self {
|
||||||
|
table,
|
||||||
|
column: None,
|
||||||
|
row: None,
|
||||||
|
source: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the column which was being queried
|
||||||
|
#[must_use]
|
||||||
|
pub(crate) const fn column(mut self, column: &'static str) -> Self {
|
||||||
|
self.column = Some(column);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the row which was being queried
|
||||||
|
#[must_use]
|
||||||
|
pub(crate) const fn row(mut self, row: Ulid) -> Self {
|
||||||
|
self.row = Some(row);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Give the source of the error
|
||||||
|
#[must_use]
|
||||||
|
pub(crate) fn source<E: std::error::Error + Send + Sync + 'static>(
|
||||||
|
mut self,
|
||||||
|
source: E,
|
||||||
|
) -> Self {
|
||||||
|
self.source = Some(Box::new(source));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
@@ -18,6 +18,152 @@
|
|||||||
//! type-checked, using introspection data recorded in the `sqlx-data.json`
|
//! type-checked, using introspection data recorded in the `sqlx-data.json`
|
||||||
//! file. This file is generated by the `sqlx` CLI tool, and should be updated
|
//! file. This file is generated by the `sqlx` CLI tool, and should be updated
|
||||||
//! whenever the database schema changes, or new queries are added.
|
//! whenever the database schema changes, or new queries are added.
|
||||||
|
//!
|
||||||
|
//! # Implementing a new repository
|
||||||
|
//!
|
||||||
|
//! When a new repository is defined in [`mas_storage`], it should be
|
||||||
|
//! implemented here, with the PostgreSQL backend.
|
||||||
|
//!
|
||||||
|
//! A typical implementation will look like this:
|
||||||
|
//!
|
||||||
|
//! ```rust
|
||||||
|
//! # use async_trait::async_trait;
|
||||||
|
//! # use ulid::Ulid;
|
||||||
|
//! # use rand::RngCore;
|
||||||
|
//! # use mas_storage::Clock;
|
||||||
|
//! # use mas_storage_pg::{DatabaseError, ExecuteExt, LookupResultExt};
|
||||||
|
//! # use sqlx::PgConnection;
|
||||||
|
//! # use uuid::Uuid;
|
||||||
|
//! #
|
||||||
|
//! # // A fake data structure, usually defined in mas-data-model
|
||||||
|
//! # #[derive(sqlx::FromRow)]
|
||||||
|
//! # struct FakeData {
|
||||||
|
//! # id: Ulid,
|
||||||
|
//! # }
|
||||||
|
//! #
|
||||||
|
//! # // A fake repository trait, usually defined in mas-storage
|
||||||
|
//! # #[async_trait]
|
||||||
|
//! # pub trait FakeDataRepository: Send + Sync {
|
||||||
|
//! # type Error;
|
||||||
|
//! # async fn lookup(&mut self, id: Ulid) -> Result<Option<FakeData>, Self::Error>;
|
||||||
|
//! # async fn add(
|
||||||
|
//! # &mut self,
|
||||||
|
//! # rng: &mut (dyn RngCore + Send),
|
||||||
|
//! # clock: &dyn Clock,
|
||||||
|
//! # ) -> Result<FakeData, Self::Error>;
|
||||||
|
//! # }
|
||||||
|
//! #
|
||||||
|
//! /// An implementation of [`FakeDataRepository`] for a PostgreSQL connection
|
||||||
|
//! pub struct PgFakeDataRepository<'c> {
|
||||||
|
//! conn: &'c mut PgConnection,
|
||||||
|
//! }
|
||||||
|
//!
|
||||||
|
//! impl<'c> PgFakeDataRepository<'c> {
|
||||||
|
//! /// Create a new [`FakeDataRepository`] from an active PostgreSQL connection
|
||||||
|
//! pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||||
|
//! Self { conn }
|
||||||
|
//! }
|
||||||
|
//! }
|
||||||
|
//!
|
||||||
|
//! #[derive(sqlx::FromRow)]
|
||||||
|
//! struct FakeDataLookup {
|
||||||
|
//! fake_data_id: Uuid,
|
||||||
|
//! }
|
||||||
|
//!
|
||||||
|
//! impl From<FakeDataLookup> for FakeData {
|
||||||
|
//! fn from(value: FakeDataLookup) -> Self {
|
||||||
|
//! Self {
|
||||||
|
//! id: value.fake_data_id.into(),
|
||||||
|
//! }
|
||||||
|
//! }
|
||||||
|
//! }
|
||||||
|
//!
|
||||||
|
//! #[async_trait]
|
||||||
|
//! impl<'c> FakeDataRepository for PgFakeDataRepository<'c> {
|
||||||
|
//! type Error = DatabaseError;
|
||||||
|
//!
|
||||||
|
//! #[tracing::instrument(
|
||||||
|
//! name = "db.fake_data.lookup",
|
||||||
|
//! skip_all,
|
||||||
|
//! fields(
|
||||||
|
//! db.statement,
|
||||||
|
//! fake_data.id = %id,
|
||||||
|
//! ),
|
||||||
|
//! err,
|
||||||
|
//! )]
|
||||||
|
//! async fn lookup(&mut self, id: Ulid) -> Result<Option<FakeData>, Self::Error> {
|
||||||
|
//! // Note: here we would use the macro version instead, but it's not possible here in
|
||||||
|
//! // this documentation example
|
||||||
|
//! let res: Option<FakeDataLookup> = sqlx::query_as(
|
||||||
|
//! r#"
|
||||||
|
//! SELECT fake_data_id
|
||||||
|
//! FROM fake_data
|
||||||
|
//! WHERE fake_data_id = $1
|
||||||
|
//! "#,
|
||||||
|
//! )
|
||||||
|
//! .bind(Uuid::from(id))
|
||||||
|
//! .traced()
|
||||||
|
//! .fetch_one(&mut *self.conn)
|
||||||
|
//! .await
|
||||||
|
//! .to_option()?;
|
||||||
|
//!
|
||||||
|
//! let Some(res) = res else { return Ok(None) };
|
||||||
|
//!
|
||||||
|
//! Ok(Some(res.into()))
|
||||||
|
//! }
|
||||||
|
//!
|
||||||
|
//! #[tracing::instrument(
|
||||||
|
//! name = "db.fake_data.add",
|
||||||
|
//! skip_all,
|
||||||
|
//! fields(
|
||||||
|
//! db.statement,
|
||||||
|
//! fake_data.id,
|
||||||
|
//! ),
|
||||||
|
//! err,
|
||||||
|
//! )]
|
||||||
|
//! async fn add(
|
||||||
|
//! &mut self,
|
||||||
|
//! rng: &mut (dyn RngCore + Send),
|
||||||
|
//! clock: &dyn Clock,
|
||||||
|
//! ) -> Result<FakeData, Self::Error> {
|
||||||
|
//! let created_at = clock.now();
|
||||||
|
//! let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||||
|
//! tracing::Span::current().record("fake_data.id", tracing::field::display(id));
|
||||||
|
//!
|
||||||
|
//! // Note: here we would use the macro version instead, but it's not possible here in
|
||||||
|
//! // this documentation example
|
||||||
|
//! sqlx::query(
|
||||||
|
//! r#"
|
||||||
|
//! INSERT INTO fake_data (id)
|
||||||
|
//! VALUES ($1)
|
||||||
|
//! "#,
|
||||||
|
//! )
|
||||||
|
//! .bind(Uuid::from(id))
|
||||||
|
//! .traced()
|
||||||
|
//! .execute(&mut *self.conn)
|
||||||
|
//! .await?;
|
||||||
|
//!
|
||||||
|
//! Ok(FakeData {
|
||||||
|
//! id,
|
||||||
|
//! })
|
||||||
|
//! }
|
||||||
|
//! }
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! A few things to note with the implementation:
|
||||||
|
//!
|
||||||
|
//! - All methods are traced, with an explicit, somewhat consistent name.
|
||||||
|
//! - The SQL statement is included as attribute, by declaring a `db.statement`
|
||||||
|
//! attribute on the tracing span, and then calling [`ExecuteExt::traced`].
|
||||||
|
//! - The IDs are all [`Ulid`], and generated from the clock and the random
|
||||||
|
//! number generated passed as parameters. The generated IDs are recorded in
|
||||||
|
//! the span.
|
||||||
|
//! - The IDs are stored as [`Uuid`] in PostgreSQL, so conversions are required
|
||||||
|
//! - "Not found" errors are handled by returning `Ok(None)` instead of an
|
||||||
|
//! error. The [`LookupResultExt::to_option`] method helps to do that.
|
||||||
|
//!
|
||||||
|
//! [`Ulid`]: ulid::Ulid
|
||||||
|
//! [`Uuid`]: uuid::Uuid
|
||||||
|
|
||||||
#![forbid(unsafe_code)]
|
#![forbid(unsafe_code)]
|
||||||
#![deny(
|
#![deny(
|
||||||
@@ -30,17 +176,23 @@
|
|||||||
#![warn(clippy::pedantic)]
|
#![warn(clippy::pedantic)]
|
||||||
#![allow(clippy::module_name_repetitions)]
|
#![allow(clippy::module_name_repetitions)]
|
||||||
|
|
||||||
use sqlx::{migrate::Migrator, postgres::PgQueryResult};
|
use sqlx::migrate::Migrator;
|
||||||
use thiserror::Error;
|
|
||||||
use ulid::Ulid;
|
|
||||||
|
|
||||||
/// An extension trait for [`Result`] which adds a [`to_option`] method, useful
|
/// An extension trait for [`Result`] which adds a [`to_option`] method, useful
|
||||||
/// for handling "not found" errors from [`sqlx`]
|
/// for handling "not found" errors from [`sqlx`]
|
||||||
trait LookupResultExt {
|
///
|
||||||
|
/// [`to_option`]: LookupResultExt::to_option
|
||||||
|
pub trait LookupResultExt {
|
||||||
|
/// The output type
|
||||||
type Output;
|
type Output;
|
||||||
|
|
||||||
/// Transform a [`Result`] from a sqlx query to transform "not found" errors
|
/// Transform a [`Result`] from a sqlx query to transform "not found" errors
|
||||||
/// into [`None`]
|
/// into [`None`]
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns the original error if the error was not a
|
||||||
|
/// [`sqlx::Error::RowNotFound`] error
|
||||||
fn to_option(self) -> Result<Option<Self::Output>, sqlx::Error>;
|
fn to_option(self) -> Result<Option<Self::Output>, sqlx::Error>;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -56,143 +208,18 @@ impl<T> LookupResultExt for Result<T, sqlx::Error> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generic error when interacting with the database
|
|
||||||
#[derive(Debug, Error)]
|
|
||||||
#[error(transparent)]
|
|
||||||
pub enum DatabaseError {
|
|
||||||
/// An error which came from the database itself
|
|
||||||
Driver {
|
|
||||||
/// The underlying error from the database driver
|
|
||||||
#[from]
|
|
||||||
source: sqlx::Error,
|
|
||||||
},
|
|
||||||
|
|
||||||
/// An error which occured while converting the data from the database
|
|
||||||
Inconsistency(#[from] DatabaseInconsistencyError),
|
|
||||||
|
|
||||||
/// An error which happened because the requested database operation is
|
|
||||||
/// invalid
|
|
||||||
#[error("Invalid database operation")]
|
|
||||||
InvalidOperation {
|
|
||||||
/// The source of the error, if any
|
|
||||||
#[source]
|
|
||||||
source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
|
|
||||||
},
|
|
||||||
|
|
||||||
/// An error which happens when an operation affects not enough or too many
|
|
||||||
/// rows
|
|
||||||
#[error("Expected {expected} rows to be affected, but {actual} rows were affected")]
|
|
||||||
RowsAffected {
|
|
||||||
/// How many rows were expected to be affected
|
|
||||||
expected: u64,
|
|
||||||
|
|
||||||
/// How many rows were actually affected
|
|
||||||
actual: u64,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
impl DatabaseError {
|
|
||||||
pub(crate) fn ensure_affected_rows(
|
|
||||||
result: &PgQueryResult,
|
|
||||||
expected: u64,
|
|
||||||
) -> Result<(), DatabaseError> {
|
|
||||||
let actual = result.rows_affected();
|
|
||||||
if actual == expected {
|
|
||||||
Ok(())
|
|
||||||
} else {
|
|
||||||
Err(DatabaseError::RowsAffected { expected, actual })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn to_invalid_operation<E: std::error::Error + Send + Sync + 'static>(e: E) -> Self {
|
|
||||||
Self::InvalidOperation {
|
|
||||||
source: Some(Box::new(e)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) const fn invalid_operation() -> Self {
|
|
||||||
Self::InvalidOperation { source: None }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// An error which occured while converting the data from the database
|
|
||||||
#[derive(Debug, Error)]
|
|
||||||
pub struct DatabaseInconsistencyError {
|
|
||||||
/// The table which was being queried
|
|
||||||
table: &'static str,
|
|
||||||
|
|
||||||
/// The column which was being queried
|
|
||||||
column: Option<&'static str>,
|
|
||||||
|
|
||||||
/// The row which was being queried
|
|
||||||
row: Option<Ulid>,
|
|
||||||
|
|
||||||
/// The source of the error
|
|
||||||
#[source]
|
|
||||||
source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::fmt::Display for DatabaseInconsistencyError {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
write!(f, "Database inconsistency on table {}", self.table)?;
|
|
||||||
if let Some(column) = self.column {
|
|
||||||
write!(f, " column {column}")?;
|
|
||||||
}
|
|
||||||
if let Some(row) = self.row {
|
|
||||||
write!(f, " row {row}")?;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl DatabaseInconsistencyError {
|
|
||||||
/// Create a new [`DatabaseInconsistencyError`] for the given table
|
|
||||||
#[must_use]
|
|
||||||
pub(crate) const fn on(table: &'static str) -> Self {
|
|
||||||
Self {
|
|
||||||
table,
|
|
||||||
column: None,
|
|
||||||
row: None,
|
|
||||||
source: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Set the column which was being queried
|
|
||||||
#[must_use]
|
|
||||||
pub(crate) const fn column(mut self, column: &'static str) -> Self {
|
|
||||||
self.column = Some(column);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Set the row which was being queried
|
|
||||||
#[must_use]
|
|
||||||
pub(crate) const fn row(mut self, row: Ulid) -> Self {
|
|
||||||
self.row = Some(row);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Give the source of the error
|
|
||||||
#[must_use]
|
|
||||||
pub(crate) fn source<E: std::error::Error + Send + Sync + 'static>(
|
|
||||||
mut self,
|
|
||||||
source: E,
|
|
||||||
) -> Self {
|
|
||||||
self.source = Some(Box::new(source));
|
|
||||||
self
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub mod compat;
|
pub mod compat;
|
||||||
pub mod oauth2;
|
pub mod oauth2;
|
||||||
pub mod upstream_oauth2;
|
pub mod upstream_oauth2;
|
||||||
pub mod user;
|
pub mod user;
|
||||||
|
|
||||||
|
mod errors;
|
||||||
pub(crate) mod pagination;
|
pub(crate) mod pagination;
|
||||||
pub(crate) mod repository;
|
pub(crate) mod repository;
|
||||||
pub(crate) mod tracing;
|
pub(crate) mod tracing;
|
||||||
|
|
||||||
pub use self::repository::PgRepository;
|
pub(crate) use self::errors::DatabaseInconsistencyError;
|
||||||
|
pub use self::{errors::DatabaseError, repository::PgRepository, tracing::ExecuteExt};
|
||||||
|
|
||||||
/// Embedded migrations, allowing them to run on startup
|
/// Embedded migrations, allowing them to run on startup
|
||||||
pub static MIGRATOR: Migrator = sqlx::migrate!();
|
pub static MIGRATOR: Migrator = sqlx::migrate!();
|
||||||
|
@@ -18,11 +18,13 @@ use tracing::Span;
|
|||||||
/// `db.statement` in a tracing span
|
/// `db.statement` in a tracing span
|
||||||
pub trait ExecuteExt<'q, DB>: Sized {
|
pub trait ExecuteExt<'q, DB>: Sized {
|
||||||
/// Records the statement as `db.statement` in the current span
|
/// Records the statement as `db.statement` in the current span
|
||||||
|
#[must_use]
|
||||||
fn traced(self) -> Self {
|
fn traced(self) -> Self {
|
||||||
self.record(&Span::current())
|
self.record(&Span::current())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Records the statement as `db.statement` in the given span
|
/// Records the statement as `db.statement` in the given span
|
||||||
|
#[must_use]
|
||||||
fn record(self, span: &Span) -> Self;
|
fn record(self, span: &Span) -> Self;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -13,6 +13,125 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
//! Interactions with the storage backend
|
//! Interactions with the storage backend
|
||||||
|
//!
|
||||||
|
//! This crate provides a set of traits that can be implemented to interact with
|
||||||
|
//! the storage backend. Those traits are called repositories and are grouped by
|
||||||
|
//! the type of data they manage.
|
||||||
|
//!
|
||||||
|
//! Each of those reposotories can be accessed via the [`RepositoryAccess`]
|
||||||
|
//! trait. This trait can be wrapped in a [`BoxRepository`] to allow using it
|
||||||
|
//! without caring about the underlying storage backend, and without carrying
|
||||||
|
//! around the generic type parameter.
|
||||||
|
//!
|
||||||
|
//! This crate also defines a [`Clock`] trait that can be used to abstract the
|
||||||
|
//! way the current time is retrieved. It has two implementation:
|
||||||
|
//! [`SystemClock`] that uses the system time and [`MockClock`] which is useful
|
||||||
|
//! for testing.
|
||||||
|
//!
|
||||||
|
//! [`MockClock`]: crate::clock::MockClock
|
||||||
|
//!
|
||||||
|
//! # Defining a new repository
|
||||||
|
//!
|
||||||
|
//! To define a new repository, you have to:
|
||||||
|
//! 1. Define a new (async) repository trait, with the methods you need
|
||||||
|
//! 2. Write an implementation of this trait for each storage backend you want
|
||||||
|
//! (currently only for [`mas-storage-pg`])
|
||||||
|
//! 3. Make it accessible via the [`RepositoryAccess`] trait
|
||||||
|
//!
|
||||||
|
//! The repository trait definition should look like this:
|
||||||
|
//!
|
||||||
|
//! ```rust
|
||||||
|
//! # use async_trait::async_trait;
|
||||||
|
//! # use ulid::Ulid;
|
||||||
|
//! # use rand_core::RngCore;
|
||||||
|
//! # use mas_storage::Clock;
|
||||||
|
//! #
|
||||||
|
//! # // A fake data structure, usually defined in mas-data-model
|
||||||
|
//! # struct FakeData {
|
||||||
|
//! # id: Ulid,
|
||||||
|
//! # }
|
||||||
|
//! #
|
||||||
|
//! # // A fake empty macro, to replace `mas_storage::repository_impl`
|
||||||
|
//! # macro_rules! repository_impl { ($($tok:tt)*) => {} }
|
||||||
|
//!
|
||||||
|
//! #[async_trait]
|
||||||
|
//! pub trait FakeDataRepository: Send + Sync {
|
||||||
|
//! /// The error type returned by the repository
|
||||||
|
//! type Error;
|
||||||
|
//!
|
||||||
|
//! /// Lookup a [`FakeData`] by its ID
|
||||||
|
//! ///
|
||||||
|
//! /// Returns `None` if no [`FakeData`] was found
|
||||||
|
//! ///
|
||||||
|
//! /// # Parameters
|
||||||
|
//! ///
|
||||||
|
//! /// * `id`: The ID of the [`FakeData`] to lookup
|
||||||
|
//! ///
|
||||||
|
//! /// # Errors
|
||||||
|
//! ///
|
||||||
|
//! /// Returns [`Self::Error`] if the underlying repository fails
|
||||||
|
//! async fn lookup(&mut self, id: Ulid) -> Result<Option<FakeData>, Self::Error>;
|
||||||
|
//!
|
||||||
|
//! /// Create a new [`FakeData`]
|
||||||
|
//! ///
|
||||||
|
//! /// Returns the newly-created [`FakeData`].
|
||||||
|
//! ///
|
||||||
|
//! /// # Parameters
|
||||||
|
//! ///
|
||||||
|
//! /// * `rng`: The random number generator to use
|
||||||
|
//! /// * `clock`: The clock used to generate timestamps
|
||||||
|
//! ///
|
||||||
|
//! /// # Errors
|
||||||
|
//! ///
|
||||||
|
//! /// Returns [`Self::Error`] if the underlying repository fails
|
||||||
|
//! async fn add(
|
||||||
|
//! &mut self,
|
||||||
|
//! rng: &mut (dyn RngCore + Send),
|
||||||
|
//! clock: &dyn Clock,
|
||||||
|
//! ) -> Result<FakeData, Self::Error>;
|
||||||
|
//! }
|
||||||
|
//!
|
||||||
|
//! repository_impl!(FakeDataRepository:
|
||||||
|
//! async fn lookup(&mut self, id: Ulid) -> Result<Option<FakeData>, Self::Error>;
|
||||||
|
//! async fn add(
|
||||||
|
//! &mut self,
|
||||||
|
//! rng: &mut (dyn RngCore + Send),
|
||||||
|
//! clock: &dyn Clock,
|
||||||
|
//! ) -> Result<FakeData, Self::Error>;
|
||||||
|
//! );
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! Four things to note with the implementation:
|
||||||
|
//!
|
||||||
|
//! 1. It defined an assocated error type, and all functions are faillible,
|
||||||
|
//! and use that error type
|
||||||
|
//! 2. Lookups return an `Result<Option<T>, Self::Error>`, because 'not found'
|
||||||
|
//! errors are usually cases that are handled differently
|
||||||
|
//! 3. Operations that need to record the current type use a [`Clock`]
|
||||||
|
//! parameter. Operations that need to generate new IDs also use a random
|
||||||
|
//! number generator.
|
||||||
|
//! 4. All the methods use an `&mut self`. This is ensures only one operation
|
||||||
|
//! is done at a time on a single repository instance.
|
||||||
|
//!
|
||||||
|
//! Then update the [`RepositoryAccess`] trait to make the new repository
|
||||||
|
//! available:
|
||||||
|
//!
|
||||||
|
//! ```rust
|
||||||
|
//! # trait FakeDataRepository {
|
||||||
|
//! # type Error;
|
||||||
|
//! # }
|
||||||
|
//!
|
||||||
|
//! /// Access the various repositories the backend implements.
|
||||||
|
//! pub trait RepositoryAccess: Send {
|
||||||
|
//! /// The backend-specific error type used by each repository.
|
||||||
|
//! type Error: std::error::Error + Send + Sync + 'static;
|
||||||
|
//!
|
||||||
|
//! // ...other repositories...
|
||||||
|
//!
|
||||||
|
//! /// Get a [`FakeDataRepository`]
|
||||||
|
//! fn fake_data<'c>(&'c mut self) -> Box<dyn FakeDataRepository<Error = Self::Error> + 'c>;
|
||||||
|
//! }
|
||||||
|
//! ```
|
||||||
|
|
||||||
#![forbid(unsafe_code)]
|
#![forbid(unsafe_code)]
|
||||||
#![deny(
|
#![deny(
|
||||||
@@ -25,11 +144,10 @@
|
|||||||
#![warn(clippy::pedantic)]
|
#![warn(clippy::pedantic)]
|
||||||
#![allow(clippy::module_name_repetitions)]
|
#![allow(clippy::module_name_repetitions)]
|
||||||
|
|
||||||
use rand_core::CryptoRngCore;
|
|
||||||
|
|
||||||
pub mod clock;
|
pub mod clock;
|
||||||
pub mod pagination;
|
pub mod pagination;
|
||||||
pub(crate) mod repository;
|
pub(crate) mod repository;
|
||||||
|
mod utils;
|
||||||
|
|
||||||
pub mod compat;
|
pub mod compat;
|
||||||
pub mod oauth2;
|
pub mod oauth2;
|
||||||
@@ -42,66 +160,5 @@ pub use self::{
|
|||||||
repository::{
|
repository::{
|
||||||
BoxRepository, Repository, RepositoryAccess, RepositoryError, RepositoryTransaction,
|
BoxRepository, Repository, RepositoryAccess, RepositoryError, RepositoryTransaction,
|
||||||
},
|
},
|
||||||
|
utils::{BoxClock, BoxRng, MapErr},
|
||||||
};
|
};
|
||||||
|
|
||||||
/// A wrapper which is used to map the error type of a repository to another
|
|
||||||
pub struct MapErr<R, F> {
|
|
||||||
inner: R,
|
|
||||||
mapper: F,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<R, F> MapErr<R, F> {
|
|
||||||
fn new(inner: R, mapper: F) -> Self {
|
|
||||||
Self { inner, mapper }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A macro to implement a repository trait for the [`MapErr`] wrapper and for
|
|
||||||
/// [`Box<R>`]
|
|
||||||
#[macro_export]
|
|
||||||
macro_rules! repository_impl {
|
|
||||||
($repo_trait:ident:
|
|
||||||
$(
|
|
||||||
async fn $method:ident (
|
|
||||||
&mut self
|
|
||||||
$(, $arg:ident: $arg_ty:ty )*
|
|
||||||
$(,)?
|
|
||||||
) -> Result<$ret_ty:ty, Self::Error>;
|
|
||||||
)*
|
|
||||||
) => {
|
|
||||||
#[::async_trait::async_trait]
|
|
||||||
impl<R: ?Sized> $repo_trait for ::std::boxed::Box<R>
|
|
||||||
where
|
|
||||||
R: $repo_trait,
|
|
||||||
{
|
|
||||||
type Error = <R as $repo_trait>::Error;
|
|
||||||
|
|
||||||
$(
|
|
||||||
async fn $method (&mut self $(, $arg: $arg_ty)*) -> Result<$ret_ty, Self::Error> {
|
|
||||||
(**self).$method ( $($arg),* ).await
|
|
||||||
}
|
|
||||||
)*
|
|
||||||
}
|
|
||||||
|
|
||||||
#[::async_trait::async_trait]
|
|
||||||
impl<R, F, E> $repo_trait for $crate::MapErr<R, F>
|
|
||||||
where
|
|
||||||
R: $repo_trait,
|
|
||||||
F: FnMut(<R as $repo_trait>::Error) -> E + ::std::marker::Send + ::std::marker::Sync,
|
|
||||||
{
|
|
||||||
type Error = E;
|
|
||||||
|
|
||||||
$(
|
|
||||||
async fn $method (&mut self $(, $arg: $arg_ty)*) -> Result<$ret_ty, Self::Error> {
|
|
||||||
self.inner.$method ( $($arg),* ).await.map_err(&mut self.mapper)
|
|
||||||
}
|
|
||||||
)*
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A boxed [`Clock`]
|
|
||||||
pub type BoxClock = Box<dyn Clock + Send>;
|
|
||||||
|
|
||||||
/// A boxed random number generator
|
|
||||||
pub type BoxRng = Box<dyn CryptoRngCore + Send>;
|
|
||||||
|
@@ -101,6 +101,20 @@ pub trait RepositoryTransaction {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Access the various repositories the backend implements.
|
/// Access the various repositories the backend implements.
|
||||||
|
///
|
||||||
|
/// All the methods return a boxed trait object, which can be used to access a
|
||||||
|
/// particular repository. The lifetime of the returned object is bound to the
|
||||||
|
/// lifetime of the whole repository, so that only one mutable reference to the
|
||||||
|
/// repository is used at a time.
|
||||||
|
///
|
||||||
|
/// When adding a new repository, you should add a new method to this trait, and
|
||||||
|
/// update the implementations for [`MapErr`] and [`Box<R>`] below.
|
||||||
|
///
|
||||||
|
/// Note: this used to have generic associated types to avoid boxing all the
|
||||||
|
/// repository traits, but that was removed because it made almost impossible to
|
||||||
|
/// box the trait object. This might be a shortcoming of the initial
|
||||||
|
/// implementation of generic associated types, and might be fixed in the
|
||||||
|
/// future.
|
||||||
pub trait RepositoryAccess: Send {
|
pub trait RepositoryAccess: Send {
|
||||||
/// The backend-specific error type used by each repository.
|
/// The backend-specific error type used by each repository.
|
||||||
type Error: std::error::Error + Send + Sync + 'static;
|
type Error: std::error::Error + Send + Sync + 'static;
|
||||||
|
86
crates/storage/src/utils.rs
Normal file
86
crates/storage/src/utils.rs
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
//! Wrappers and useful type aliases
|
||||||
|
|
||||||
|
use rand_core::CryptoRngCore;
|
||||||
|
|
||||||
|
use crate::Clock;
|
||||||
|
|
||||||
|
/// A wrapper which is used to map the error type of a repository to another
|
||||||
|
pub struct MapErr<R, F> {
|
||||||
|
pub(crate) inner: R,
|
||||||
|
pub(crate) mapper: F,
|
||||||
|
_private: (),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R, F> MapErr<R, F> {
|
||||||
|
pub(crate) fn new(inner: R, mapper: F) -> Self {
|
||||||
|
Self {
|
||||||
|
inner,
|
||||||
|
mapper,
|
||||||
|
_private: (),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A boxed [`Clock`]
|
||||||
|
pub type BoxClock = Box<dyn Clock + Send>;
|
||||||
|
|
||||||
|
/// A boxed random number generator
|
||||||
|
pub type BoxRng = Box<dyn CryptoRngCore + Send>;
|
||||||
|
|
||||||
|
/// A macro to implement a repository trait for the [`MapErr`] wrapper and for
|
||||||
|
/// [`Box<R>`]
|
||||||
|
#[macro_export]
|
||||||
|
macro_rules! repository_impl {
|
||||||
|
($repo_trait:ident:
|
||||||
|
$(
|
||||||
|
async fn $method:ident (
|
||||||
|
&mut self
|
||||||
|
$(, $arg:ident: $arg_ty:ty )*
|
||||||
|
$(,)?
|
||||||
|
) -> Result<$ret_ty:ty, Self::Error>;
|
||||||
|
)*
|
||||||
|
) => {
|
||||||
|
#[::async_trait::async_trait]
|
||||||
|
impl<R: ?Sized> $repo_trait for ::std::boxed::Box<R>
|
||||||
|
where
|
||||||
|
R: $repo_trait,
|
||||||
|
{
|
||||||
|
type Error = <R as $repo_trait>::Error;
|
||||||
|
|
||||||
|
$(
|
||||||
|
async fn $method (&mut self $(, $arg: $arg_ty)*) -> Result<$ret_ty, Self::Error> {
|
||||||
|
(**self).$method ( $($arg),* ).await
|
||||||
|
}
|
||||||
|
)*
|
||||||
|
}
|
||||||
|
|
||||||
|
#[::async_trait::async_trait]
|
||||||
|
impl<R, F, E> $repo_trait for $crate::MapErr<R, F>
|
||||||
|
where
|
||||||
|
R: $repo_trait,
|
||||||
|
F: FnMut(<R as $repo_trait>::Error) -> E + ::std::marker::Send + ::std::marker::Sync,
|
||||||
|
{
|
||||||
|
type Error = E;
|
||||||
|
|
||||||
|
$(
|
||||||
|
async fn $method (&mut self $(, $arg: $arg_ty)*) -> Result<$ret_ty, Self::Error> {
|
||||||
|
self.inner.$method ( $($arg),* ).await.map_err(&mut self.mapper)
|
||||||
|
}
|
||||||
|
)*
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
@@ -3,6 +3,21 @@
|
|||||||
Interactions with the database goes through `sqlx`.
|
Interactions with the database goes through `sqlx`.
|
||||||
It provides async database operations with connection pooling, migrations support and compile-time check of queries through macros.
|
It provides async database operations with connection pooling, migrations support and compile-time check of queries through macros.
|
||||||
|
|
||||||
|
## Writing database interactions
|
||||||
|
|
||||||
|
All database interactions are done through repositoriy traits. Each repository trait usually manages one type of data, defined in the `mas-data-model` crate.
|
||||||
|
|
||||||
|
Defining a new data type and associated repository looks like this:
|
||||||
|
|
||||||
|
- Define new structs in `mas-data-model` crate
|
||||||
|
- Define the repository trait in `mas-storage` crate
|
||||||
|
- Make that repository trait available via the `RepositoryAccess` trait in `mas-storage` crate
|
||||||
|
- Setup the database schema by writing a migration file in `mas-storage-pg` crate
|
||||||
|
- Implement the new repository trait in `mas-storage-pg` crate
|
||||||
|
- Write tests for the PostgreSQL implementation in `mas-storage-pg` crate
|
||||||
|
|
||||||
|
Some of those steps are documented in more details in the `mas-storage` and `mas-storage-pg` crates.
|
||||||
|
|
||||||
## Compile-time check of queries
|
## Compile-time check of queries
|
||||||
|
|
||||||
To be able to check queries, `sqlx` has to introspect the live database.
|
To be able to check queries, `sqlx` has to introspect the live database.
|
||||||
@@ -14,7 +29,7 @@ Preparing this flat file is done through `sqlx-cli`, and should be done everytim
|
|||||||
# Install the CLI
|
# Install the CLI
|
||||||
cargo install sqlx-cli --no-default-features --features postgres
|
cargo install sqlx-cli --no-default-features --features postgres
|
||||||
|
|
||||||
cd crates/storage/ # Must be in the mas-storage crate folder
|
cd crates/storage-pg/ # Must be in the mas-storage-pg crate folder
|
||||||
export DATABASE_URL=postgresql:///matrix_auth
|
export DATABASE_URL=postgresql:///matrix_auth
|
||||||
cargo sqlx prepare
|
cargo sqlx prepare
|
||||||
```
|
```
|
||||||
@@ -24,75 +39,10 @@ cargo sqlx prepare
|
|||||||
Migration files live in the `migrations` folder in the `mas-core` crate.
|
Migration files live in the `migrations` folder in the `mas-core` crate.
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
cd crates/storage/ # Again, in the mas-storage crate folder
|
cd crates/storage-pg/ # Again, in the mas-storage-pg crate folder
|
||||||
export DATABASE_URL=postgresql:///matrix_auth
|
export DATABASE_URL=postgresql:///matrix_auth
|
||||||
cargo sqlx migrate run # Run pending migrations
|
cargo sqlx migrate run # Run pending migrations
|
||||||
cargo sqlx migrate revert # Revert the last migration
|
cargo sqlx migrate add [description] # Add new migration files
|
||||||
cargo sqlx migrate add -r [description] # Add new migration files
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Note that migrations are embedded in the final binary and can be run from the service CLI tool.
|
Note that migrations are embedded in the final binary and can be run from the service CLI tool.
|
||||||
|
|
||||||
## Writing database interactions
|
|
||||||
|
|
||||||
**TODO**: *This section is outdated.*
|
|
||||||
|
|
||||||
A typical interaction with the database look like this:
|
|
||||||
|
|
||||||
```rust
|
|
||||||
pub async fn lookup_session(
|
|
||||||
executor: impl Executor<'_, Database = Postgres>,
|
|
||||||
id: i64,
|
|
||||||
) -> anyhow::Result<SessionInfo> {
|
|
||||||
sqlx::query_as!(
|
|
||||||
SessionInfo, // Struct that will be filled with the result
|
|
||||||
r#"
|
|
||||||
SELECT
|
|
||||||
s.id,
|
|
||||||
u.id as user_id,
|
|
||||||
u.username,
|
|
||||||
s.active,
|
|
||||||
s.created_at,
|
|
||||||
a.created_at as "last_authd_at?"
|
|
||||||
FROM user_sessions s
|
|
||||||
INNER JOIN users u
|
|
||||||
ON s.user_id = u.id
|
|
||||||
LEFT JOIN user_session_authentications a
|
|
||||||
ON a.session_id = s.id
|
|
||||||
WHERE s.id = $1
|
|
||||||
ORDER BY a.created_at DESC
|
|
||||||
LIMIT 1
|
|
||||||
"#,
|
|
||||||
id, // Query parameter
|
|
||||||
)
|
|
||||||
.fetch_one(executor)
|
|
||||||
.await
|
|
||||||
// Providing some context when there is an error
|
|
||||||
.context("could not fetch session")
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
Note that we pass an `impl Executor` as parameter here.
|
|
||||||
This allows us to use this function from either a simple connection or from an active transaction.
|
|
||||||
|
|
||||||
The caveat here is that the `executor` can be used only once, so if an interaction needs to do multiple queries, it should probably take an `impl Acquire` to then acquire a transaction and do multiple interactions.
|
|
||||||
|
|
||||||
```rust
|
|
||||||
pub async fn login(
|
|
||||||
conn: impl Acquire<'_, Database = Postgres>,
|
|
||||||
username: &str,
|
|
||||||
password: String,
|
|
||||||
) -> Result<SessionInfo, LoginError> {
|
|
||||||
let mut txn = conn.begin().await.context("could not start transaction")?;
|
|
||||||
// First interaction
|
|
||||||
let user = lookup_user_by_username(&mut txn, username)?;
|
|
||||||
// Second interaction
|
|
||||||
let mut session = start_session(&mut txn, user).await?;
|
|
||||||
// Third interaction
|
|
||||||
session.last_authd_at =
|
|
||||||
Some(authenticate_session(&mut txn, session.id, password).await?);
|
|
||||||
// Commit the transaction once everything went fine
|
|
||||||
txn.commit().await.context("could not commit transaction")?;
|
|
||||||
Ok(session)
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
Reference in New Issue
Block a user