1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-09 04:22:45 +03:00

storage{,-pg}: better documentation of both crates

This commit is contained in:
Quentin Gliech
2023-01-26 17:58:03 +01:00
parent 0bf1a1998e
commit 6ad8b82a35
7 changed files with 544 additions and 264 deletions

View File

@@ -0,0 +1,144 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use sqlx::postgres::PgQueryResult;
use thiserror::Error;
use ulid::Ulid;
/// Generic error when interacting with the database
#[derive(Debug, Error)]
#[error(transparent)]
pub enum DatabaseError {
/// An error which came from the database itself
Driver {
/// The underlying error from the database driver
#[from]
source: sqlx::Error,
},
/// An error which occured while converting the data from the database
Inconsistency(#[from] DatabaseInconsistencyError),
/// An error which happened because the requested database operation is
/// invalid
#[error("Invalid database operation")]
InvalidOperation {
/// The source of the error, if any
#[source]
source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
},
/// An error which happens when an operation affects not enough or too many
/// rows
#[error("Expected {expected} rows to be affected, but {actual} rows were affected")]
RowsAffected {
/// How many rows were expected to be affected
expected: u64,
/// How many rows were actually affected
actual: u64,
},
}
impl DatabaseError {
pub(crate) fn ensure_affected_rows(
result: &PgQueryResult,
expected: u64,
) -> Result<(), DatabaseError> {
let actual = result.rows_affected();
if actual == expected {
Ok(())
} else {
Err(DatabaseError::RowsAffected { expected, actual })
}
}
pub(crate) fn to_invalid_operation<E: std::error::Error + Send + Sync + 'static>(e: E) -> Self {
Self::InvalidOperation {
source: Some(Box::new(e)),
}
}
pub(crate) const fn invalid_operation() -> Self {
Self::InvalidOperation { source: None }
}
}
/// An error which occured while converting the data from the database
#[derive(Debug, Error)]
pub struct DatabaseInconsistencyError {
/// The table which was being queried
table: &'static str,
/// The column which was being queried
column: Option<&'static str>,
/// The row which was being queried
row: Option<Ulid>,
/// The source of the error
#[source]
source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
}
impl std::fmt::Display for DatabaseInconsistencyError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Database inconsistency on table {}", self.table)?;
if let Some(column) = self.column {
write!(f, " column {column}")?;
}
if let Some(row) = self.row {
write!(f, " row {row}")?;
}
Ok(())
}
}
impl DatabaseInconsistencyError {
/// Create a new [`DatabaseInconsistencyError`] for the given table
#[must_use]
pub(crate) const fn on(table: &'static str) -> Self {
Self {
table,
column: None,
row: None,
source: None,
}
}
/// Set the column which was being queried
#[must_use]
pub(crate) const fn column(mut self, column: &'static str) -> Self {
self.column = Some(column);
self
}
/// Set the row which was being queried
#[must_use]
pub(crate) const fn row(mut self, row: Ulid) -> Self {
self.row = Some(row);
self
}
/// Give the source of the error
#[must_use]
pub(crate) fn source<E: std::error::Error + Send + Sync + 'static>(
mut self,
source: E,
) -> Self {
self.source = Some(Box::new(source));
self
}
}

View File

@@ -18,6 +18,152 @@
//! type-checked, using introspection data recorded in the `sqlx-data.json` //! type-checked, using introspection data recorded in the `sqlx-data.json`
//! file. This file is generated by the `sqlx` CLI tool, and should be updated //! file. This file is generated by the `sqlx` CLI tool, and should be updated
//! whenever the database schema changes, or new queries are added. //! whenever the database schema changes, or new queries are added.
//!
//! # Implementing a new repository
//!
//! When a new repository is defined in [`mas_storage`], it should be
//! implemented here, with the PostgreSQL backend.
//!
//! A typical implementation will look like this:
//!
//! ```rust
//! # use async_trait::async_trait;
//! # use ulid::Ulid;
//! # use rand::RngCore;
//! # use mas_storage::Clock;
//! # use mas_storage_pg::{DatabaseError, ExecuteExt, LookupResultExt};
//! # use sqlx::PgConnection;
//! # use uuid::Uuid;
//! #
//! # // A fake data structure, usually defined in mas-data-model
//! # #[derive(sqlx::FromRow)]
//! # struct FakeData {
//! # id: Ulid,
//! # }
//! #
//! # // A fake repository trait, usually defined in mas-storage
//! # #[async_trait]
//! # pub trait FakeDataRepository: Send + Sync {
//! # type Error;
//! # async fn lookup(&mut self, id: Ulid) -> Result<Option<FakeData>, Self::Error>;
//! # async fn add(
//! # &mut self,
//! # rng: &mut (dyn RngCore + Send),
//! # clock: &dyn Clock,
//! # ) -> Result<FakeData, Self::Error>;
//! # }
//! #
//! /// An implementation of [`FakeDataRepository`] for a PostgreSQL connection
//! pub struct PgFakeDataRepository<'c> {
//! conn: &'c mut PgConnection,
//! }
//!
//! impl<'c> PgFakeDataRepository<'c> {
//! /// Create a new [`FakeDataRepository`] from an active PostgreSQL connection
//! pub fn new(conn: &'c mut PgConnection) -> Self {
//! Self { conn }
//! }
//! }
//!
//! #[derive(sqlx::FromRow)]
//! struct FakeDataLookup {
//! fake_data_id: Uuid,
//! }
//!
//! impl From<FakeDataLookup> for FakeData {
//! fn from(value: FakeDataLookup) -> Self {
//! Self {
//! id: value.fake_data_id.into(),
//! }
//! }
//! }
//!
//! #[async_trait]
//! impl<'c> FakeDataRepository for PgFakeDataRepository<'c> {
//! type Error = DatabaseError;
//!
//! #[tracing::instrument(
//! name = "db.fake_data.lookup",
//! skip_all,
//! fields(
//! db.statement,
//! fake_data.id = %id,
//! ),
//! err,
//! )]
//! async fn lookup(&mut self, id: Ulid) -> Result<Option<FakeData>, Self::Error> {
//! // Note: here we would use the macro version instead, but it's not possible here in
//! // this documentation example
//! let res: Option<FakeDataLookup> = sqlx::query_as(
//! r#"
//! SELECT fake_data_id
//! FROM fake_data
//! WHERE fake_data_id = $1
//! "#,
//! )
//! .bind(Uuid::from(id))
//! .traced()
//! .fetch_one(&mut *self.conn)
//! .await
//! .to_option()?;
//!
//! let Some(res) = res else { return Ok(None) };
//!
//! Ok(Some(res.into()))
//! }
//!
//! #[tracing::instrument(
//! name = "db.fake_data.add",
//! skip_all,
//! fields(
//! db.statement,
//! fake_data.id,
//! ),
//! err,
//! )]
//! async fn add(
//! &mut self,
//! rng: &mut (dyn RngCore + Send),
//! clock: &dyn Clock,
//! ) -> Result<FakeData, Self::Error> {
//! let created_at = clock.now();
//! let id = Ulid::from_datetime_with_source(created_at.into(), rng);
//! tracing::Span::current().record("fake_data.id", tracing::field::display(id));
//!
//! // Note: here we would use the macro version instead, but it's not possible here in
//! // this documentation example
//! sqlx::query(
//! r#"
//! INSERT INTO fake_data (id)
//! VALUES ($1)
//! "#,
//! )
//! .bind(Uuid::from(id))
//! .traced()
//! .execute(&mut *self.conn)
//! .await?;
//!
//! Ok(FakeData {
//! id,
//! })
//! }
//! }
//! ```
//!
//! A few things to note with the implementation:
//!
//! - All methods are traced, with an explicit, somewhat consistent name.
//! - The SQL statement is included as attribute, by declaring a `db.statement`
//! attribute on the tracing span, and then calling [`ExecuteExt::traced`].
//! - The IDs are all [`Ulid`], and generated from the clock and the random
//! number generated passed as parameters. The generated IDs are recorded in
//! the span.
//! - The IDs are stored as [`Uuid`] in PostgreSQL, so conversions are required
//! - "Not found" errors are handled by returning `Ok(None)` instead of an
//! error. The [`LookupResultExt::to_option`] method helps to do that.
//!
//! [`Ulid`]: ulid::Ulid
//! [`Uuid`]: uuid::Uuid
#![forbid(unsafe_code)] #![forbid(unsafe_code)]
#![deny( #![deny(
@@ -30,17 +176,23 @@
#![warn(clippy::pedantic)] #![warn(clippy::pedantic)]
#![allow(clippy::module_name_repetitions)] #![allow(clippy::module_name_repetitions)]
use sqlx::{migrate::Migrator, postgres::PgQueryResult}; use sqlx::migrate::Migrator;
use thiserror::Error;
use ulid::Ulid;
/// An extension trait for [`Result`] which adds a [`to_option`] method, useful /// An extension trait for [`Result`] which adds a [`to_option`] method, useful
/// for handling "not found" errors from [`sqlx`] /// for handling "not found" errors from [`sqlx`]
trait LookupResultExt { ///
/// [`to_option`]: LookupResultExt::to_option
pub trait LookupResultExt {
/// The output type
type Output; type Output;
/// Transform a [`Result`] from a sqlx query to transform "not found" errors /// Transform a [`Result`] from a sqlx query to transform "not found" errors
/// into [`None`] /// into [`None`]
///
/// # Errors
///
/// Returns the original error if the error was not a
/// [`sqlx::Error::RowNotFound`] error
fn to_option(self) -> Result<Option<Self::Output>, sqlx::Error>; fn to_option(self) -> Result<Option<Self::Output>, sqlx::Error>;
} }
@@ -56,143 +208,18 @@ impl<T> LookupResultExt for Result<T, sqlx::Error> {
} }
} }
/// Generic error when interacting with the database
#[derive(Debug, Error)]
#[error(transparent)]
pub enum DatabaseError {
/// An error which came from the database itself
Driver {
/// The underlying error from the database driver
#[from]
source: sqlx::Error,
},
/// An error which occured while converting the data from the database
Inconsistency(#[from] DatabaseInconsistencyError),
/// An error which happened because the requested database operation is
/// invalid
#[error("Invalid database operation")]
InvalidOperation {
/// The source of the error, if any
#[source]
source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
},
/// An error which happens when an operation affects not enough or too many
/// rows
#[error("Expected {expected} rows to be affected, but {actual} rows were affected")]
RowsAffected {
/// How many rows were expected to be affected
expected: u64,
/// How many rows were actually affected
actual: u64,
},
}
impl DatabaseError {
pub(crate) fn ensure_affected_rows(
result: &PgQueryResult,
expected: u64,
) -> Result<(), DatabaseError> {
let actual = result.rows_affected();
if actual == expected {
Ok(())
} else {
Err(DatabaseError::RowsAffected { expected, actual })
}
}
pub(crate) fn to_invalid_operation<E: std::error::Error + Send + Sync + 'static>(e: E) -> Self {
Self::InvalidOperation {
source: Some(Box::new(e)),
}
}
pub(crate) const fn invalid_operation() -> Self {
Self::InvalidOperation { source: None }
}
}
/// An error which occured while converting the data from the database
#[derive(Debug, Error)]
pub struct DatabaseInconsistencyError {
/// The table which was being queried
table: &'static str,
/// The column which was being queried
column: Option<&'static str>,
/// The row which was being queried
row: Option<Ulid>,
/// The source of the error
#[source]
source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
}
impl std::fmt::Display for DatabaseInconsistencyError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Database inconsistency on table {}", self.table)?;
if let Some(column) = self.column {
write!(f, " column {column}")?;
}
if let Some(row) = self.row {
write!(f, " row {row}")?;
}
Ok(())
}
}
impl DatabaseInconsistencyError {
/// Create a new [`DatabaseInconsistencyError`] for the given table
#[must_use]
pub(crate) const fn on(table: &'static str) -> Self {
Self {
table,
column: None,
row: None,
source: None,
}
}
/// Set the column which was being queried
#[must_use]
pub(crate) const fn column(mut self, column: &'static str) -> Self {
self.column = Some(column);
self
}
/// Set the row which was being queried
#[must_use]
pub(crate) const fn row(mut self, row: Ulid) -> Self {
self.row = Some(row);
self
}
/// Give the source of the error
#[must_use]
pub(crate) fn source<E: std::error::Error + Send + Sync + 'static>(
mut self,
source: E,
) -> Self {
self.source = Some(Box::new(source));
self
}
}
pub mod compat; pub mod compat;
pub mod oauth2; pub mod oauth2;
pub mod upstream_oauth2; pub mod upstream_oauth2;
pub mod user; pub mod user;
mod errors;
pub(crate) mod pagination; pub(crate) mod pagination;
pub(crate) mod repository; pub(crate) mod repository;
pub(crate) mod tracing; pub(crate) mod tracing;
pub use self::repository::PgRepository; pub(crate) use self::errors::DatabaseInconsistencyError;
pub use self::{errors::DatabaseError, repository::PgRepository, tracing::ExecuteExt};
/// Embedded migrations, allowing them to run on startup /// Embedded migrations, allowing them to run on startup
pub static MIGRATOR: Migrator = sqlx::migrate!(); pub static MIGRATOR: Migrator = sqlx::migrate!();

View File

@@ -18,11 +18,13 @@ use tracing::Span;
/// `db.statement` in a tracing span /// `db.statement` in a tracing span
pub trait ExecuteExt<'q, DB>: Sized { pub trait ExecuteExt<'q, DB>: Sized {
/// Records the statement as `db.statement` in the current span /// Records the statement as `db.statement` in the current span
#[must_use]
fn traced(self) -> Self { fn traced(self) -> Self {
self.record(&Span::current()) self.record(&Span::current())
} }
/// Records the statement as `db.statement` in the given span /// Records the statement as `db.statement` in the given span
#[must_use]
fn record(self, span: &Span) -> Self; fn record(self, span: &Span) -> Self;
} }

View File

@@ -13,6 +13,125 @@
// limitations under the License. // limitations under the License.
//! Interactions with the storage backend //! Interactions with the storage backend
//!
//! This crate provides a set of traits that can be implemented to interact with
//! the storage backend. Those traits are called repositories and are grouped by
//! the type of data they manage.
//!
//! Each of those reposotories can be accessed via the [`RepositoryAccess`]
//! trait. This trait can be wrapped in a [`BoxRepository`] to allow using it
//! without caring about the underlying storage backend, and without carrying
//! around the generic type parameter.
//!
//! This crate also defines a [`Clock`] trait that can be used to abstract the
//! way the current time is retrieved. It has two implementation:
//! [`SystemClock`] that uses the system time and [`MockClock`] which is useful
//! for testing.
//!
//! [`MockClock`]: crate::clock::MockClock
//!
//! # Defining a new repository
//!
//! To define a new repository, you have to:
//! 1. Define a new (async) repository trait, with the methods you need
//! 2. Write an implementation of this trait for each storage backend you want
//! (currently only for [`mas-storage-pg`])
//! 3. Make it accessible via the [`RepositoryAccess`] trait
//!
//! The repository trait definition should look like this:
//!
//! ```rust
//! # use async_trait::async_trait;
//! # use ulid::Ulid;
//! # use rand_core::RngCore;
//! # use mas_storage::Clock;
//! #
//! # // A fake data structure, usually defined in mas-data-model
//! # struct FakeData {
//! # id: Ulid,
//! # }
//! #
//! # // A fake empty macro, to replace `mas_storage::repository_impl`
//! # macro_rules! repository_impl { ($($tok:tt)*) => {} }
//!
//! #[async_trait]
//! pub trait FakeDataRepository: Send + Sync {
//! /// The error type returned by the repository
//! type Error;
//!
//! /// Lookup a [`FakeData`] by its ID
//! ///
//! /// Returns `None` if no [`FakeData`] was found
//! ///
//! /// # Parameters
//! ///
//! /// * `id`: The ID of the [`FakeData`] to lookup
//! ///
//! /// # Errors
//! ///
//! /// Returns [`Self::Error`] if the underlying repository fails
//! async fn lookup(&mut self, id: Ulid) -> Result<Option<FakeData>, Self::Error>;
//!
//! /// Create a new [`FakeData`]
//! ///
//! /// Returns the newly-created [`FakeData`].
//! ///
//! /// # Parameters
//! ///
//! /// * `rng`: The random number generator to use
//! /// * `clock`: The clock used to generate timestamps
//! ///
//! /// # Errors
//! ///
//! /// Returns [`Self::Error`] if the underlying repository fails
//! async fn add(
//! &mut self,
//! rng: &mut (dyn RngCore + Send),
//! clock: &dyn Clock,
//! ) -> Result<FakeData, Self::Error>;
//! }
//!
//! repository_impl!(FakeDataRepository:
//! async fn lookup(&mut self, id: Ulid) -> Result<Option<FakeData>, Self::Error>;
//! async fn add(
//! &mut self,
//! rng: &mut (dyn RngCore + Send),
//! clock: &dyn Clock,
//! ) -> Result<FakeData, Self::Error>;
//! );
//! ```
//!
//! Four things to note with the implementation:
//!
//! 1. It defined an assocated error type, and all functions are faillible,
//! and use that error type
//! 2. Lookups return an `Result<Option<T>, Self::Error>`, because 'not found'
//! errors are usually cases that are handled differently
//! 3. Operations that need to record the current type use a [`Clock`]
//! parameter. Operations that need to generate new IDs also use a random
//! number generator.
//! 4. All the methods use an `&mut self`. This is ensures only one operation
//! is done at a time on a single repository instance.
//!
//! Then update the [`RepositoryAccess`] trait to make the new repository
//! available:
//!
//! ```rust
//! # trait FakeDataRepository {
//! # type Error;
//! # }
//!
//! /// Access the various repositories the backend implements.
//! pub trait RepositoryAccess: Send {
//! /// The backend-specific error type used by each repository.
//! type Error: std::error::Error + Send + Sync + 'static;
//!
//! // ...other repositories...
//!
//! /// Get a [`FakeDataRepository`]
//! fn fake_data<'c>(&'c mut self) -> Box<dyn FakeDataRepository<Error = Self::Error> + 'c>;
//! }
//! ```
#![forbid(unsafe_code)] #![forbid(unsafe_code)]
#![deny( #![deny(
@@ -25,11 +144,10 @@
#![warn(clippy::pedantic)] #![warn(clippy::pedantic)]
#![allow(clippy::module_name_repetitions)] #![allow(clippy::module_name_repetitions)]
use rand_core::CryptoRngCore;
pub mod clock; pub mod clock;
pub mod pagination; pub mod pagination;
pub(crate) mod repository; pub(crate) mod repository;
mod utils;
pub mod compat; pub mod compat;
pub mod oauth2; pub mod oauth2;
@@ -42,66 +160,5 @@ pub use self::{
repository::{ repository::{
BoxRepository, Repository, RepositoryAccess, RepositoryError, RepositoryTransaction, BoxRepository, Repository, RepositoryAccess, RepositoryError, RepositoryTransaction,
}, },
utils::{BoxClock, BoxRng, MapErr},
}; };
/// A wrapper which is used to map the error type of a repository to another
pub struct MapErr<R, F> {
inner: R,
mapper: F,
}
impl<R, F> MapErr<R, F> {
fn new(inner: R, mapper: F) -> Self {
Self { inner, mapper }
}
}
/// A macro to implement a repository trait for the [`MapErr`] wrapper and for
/// [`Box<R>`]
#[macro_export]
macro_rules! repository_impl {
($repo_trait:ident:
$(
async fn $method:ident (
&mut self
$(, $arg:ident: $arg_ty:ty )*
$(,)?
) -> Result<$ret_ty:ty, Self::Error>;
)*
) => {
#[::async_trait::async_trait]
impl<R: ?Sized> $repo_trait for ::std::boxed::Box<R>
where
R: $repo_trait,
{
type Error = <R as $repo_trait>::Error;
$(
async fn $method (&mut self $(, $arg: $arg_ty)*) -> Result<$ret_ty, Self::Error> {
(**self).$method ( $($arg),* ).await
}
)*
}
#[::async_trait::async_trait]
impl<R, F, E> $repo_trait for $crate::MapErr<R, F>
where
R: $repo_trait,
F: FnMut(<R as $repo_trait>::Error) -> E + ::std::marker::Send + ::std::marker::Sync,
{
type Error = E;
$(
async fn $method (&mut self $(, $arg: $arg_ty)*) -> Result<$ret_ty, Self::Error> {
self.inner.$method ( $($arg),* ).await.map_err(&mut self.mapper)
}
)*
}
};
}
/// A boxed [`Clock`]
pub type BoxClock = Box<dyn Clock + Send>;
/// A boxed random number generator
pub type BoxRng = Box<dyn CryptoRngCore + Send>;

View File

@@ -101,6 +101,20 @@ pub trait RepositoryTransaction {
} }
/// Access the various repositories the backend implements. /// Access the various repositories the backend implements.
///
/// All the methods return a boxed trait object, which can be used to access a
/// particular repository. The lifetime of the returned object is bound to the
/// lifetime of the whole repository, so that only one mutable reference to the
/// repository is used at a time.
///
/// When adding a new repository, you should add a new method to this trait, and
/// update the implementations for [`MapErr`] and [`Box<R>`] below.
///
/// Note: this used to have generic associated types to avoid boxing all the
/// repository traits, but that was removed because it made almost impossible to
/// box the trait object. This might be a shortcoming of the initial
/// implementation of generic associated types, and might be fixed in the
/// future.
pub trait RepositoryAccess: Send { pub trait RepositoryAccess: Send {
/// The backend-specific error type used by each repository. /// The backend-specific error type used by each repository.
type Error: std::error::Error + Send + Sync + 'static; type Error: std::error::Error + Send + Sync + 'static;

View File

@@ -0,0 +1,86 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Wrappers and useful type aliases
use rand_core::CryptoRngCore;
use crate::Clock;
/// A wrapper which is used to map the error type of a repository to another
pub struct MapErr<R, F> {
pub(crate) inner: R,
pub(crate) mapper: F,
_private: (),
}
impl<R, F> MapErr<R, F> {
pub(crate) fn new(inner: R, mapper: F) -> Self {
Self {
inner,
mapper,
_private: (),
}
}
}
/// A boxed [`Clock`]
pub type BoxClock = Box<dyn Clock + Send>;
/// A boxed random number generator
pub type BoxRng = Box<dyn CryptoRngCore + Send>;
/// A macro to implement a repository trait for the [`MapErr`] wrapper and for
/// [`Box<R>`]
#[macro_export]
macro_rules! repository_impl {
($repo_trait:ident:
$(
async fn $method:ident (
&mut self
$(, $arg:ident: $arg_ty:ty )*
$(,)?
) -> Result<$ret_ty:ty, Self::Error>;
)*
) => {
#[::async_trait::async_trait]
impl<R: ?Sized> $repo_trait for ::std::boxed::Box<R>
where
R: $repo_trait,
{
type Error = <R as $repo_trait>::Error;
$(
async fn $method (&mut self $(, $arg: $arg_ty)*) -> Result<$ret_ty, Self::Error> {
(**self).$method ( $($arg),* ).await
}
)*
}
#[::async_trait::async_trait]
impl<R, F, E> $repo_trait for $crate::MapErr<R, F>
where
R: $repo_trait,
F: FnMut(<R as $repo_trait>::Error) -> E + ::std::marker::Send + ::std::marker::Sync,
{
type Error = E;
$(
async fn $method (&mut self $(, $arg: $arg_ty)*) -> Result<$ret_ty, Self::Error> {
self.inner.$method ( $($arg),* ).await.map_err(&mut self.mapper)
}
)*
}
};
}

View File

@@ -3,6 +3,21 @@
Interactions with the database goes through `sqlx`. Interactions with the database goes through `sqlx`.
It provides async database operations with connection pooling, migrations support and compile-time check of queries through macros. It provides async database operations with connection pooling, migrations support and compile-time check of queries through macros.
## Writing database interactions
All database interactions are done through repositoriy traits. Each repository trait usually manages one type of data, defined in the `mas-data-model` crate.
Defining a new data type and associated repository looks like this:
- Define new structs in `mas-data-model` crate
- Define the repository trait in `mas-storage` crate
- Make that repository trait available via the `RepositoryAccess` trait in `mas-storage` crate
- Setup the database schema by writing a migration file in `mas-storage-pg` crate
- Implement the new repository trait in `mas-storage-pg` crate
- Write tests for the PostgreSQL implementation in `mas-storage-pg` crate
Some of those steps are documented in more details in the `mas-storage` and `mas-storage-pg` crates.
## Compile-time check of queries ## Compile-time check of queries
To be able to check queries, `sqlx` has to introspect the live database. To be able to check queries, `sqlx` has to introspect the live database.
@@ -14,7 +29,7 @@ Preparing this flat file is done through `sqlx-cli`, and should be done everytim
# Install the CLI # Install the CLI
cargo install sqlx-cli --no-default-features --features postgres cargo install sqlx-cli --no-default-features --features postgres
cd crates/storage/ # Must be in the mas-storage crate folder cd crates/storage-pg/ # Must be in the mas-storage-pg crate folder
export DATABASE_URL=postgresql:///matrix_auth export DATABASE_URL=postgresql:///matrix_auth
cargo sqlx prepare cargo sqlx prepare
``` ```
@@ -24,75 +39,10 @@ cargo sqlx prepare
Migration files live in the `migrations` folder in the `mas-core` crate. Migration files live in the `migrations` folder in the `mas-core` crate.
```sh ```sh
cd crates/storage/ # Again, in the mas-storage crate folder cd crates/storage-pg/ # Again, in the mas-storage-pg crate folder
export DATABASE_URL=postgresql:///matrix_auth export DATABASE_URL=postgresql:///matrix_auth
cargo sqlx migrate run # Run pending migrations cargo sqlx migrate run # Run pending migrations
cargo sqlx migrate revert # Revert the last migration cargo sqlx migrate add [description] # Add new migration files
cargo sqlx migrate add -r [description] # Add new migration files
``` ```
Note that migrations are embedded in the final binary and can be run from the service CLI tool. Note that migrations are embedded in the final binary and can be run from the service CLI tool.
## Writing database interactions
**TODO**: *This section is outdated.*
A typical interaction with the database look like this:
```rust
pub async fn lookup_session(
executor: impl Executor<'_, Database = Postgres>,
id: i64,
) -> anyhow::Result<SessionInfo> {
sqlx::query_as!(
SessionInfo, // Struct that will be filled with the result
r#"
SELECT
s.id,
u.id as user_id,
u.username,
s.active,
s.created_at,
a.created_at as "last_authd_at?"
FROM user_sessions s
INNER JOIN users u
ON s.user_id = u.id
LEFT JOIN user_session_authentications a
ON a.session_id = s.id
WHERE s.id = $1
ORDER BY a.created_at DESC
LIMIT 1
"#,
id, // Query parameter
)
.fetch_one(executor)
.await
// Providing some context when there is an error
.context("could not fetch session")
}
```
Note that we pass an `impl Executor` as parameter here.
This allows us to use this function from either a simple connection or from an active transaction.
The caveat here is that the `executor` can be used only once, so if an interaction needs to do multiple queries, it should probably take an `impl Acquire` to then acquire a transaction and do multiple interactions.
```rust
pub async fn login(
conn: impl Acquire<'_, Database = Postgres>,
username: &str,
password: String,
) -> Result<SessionInfo, LoginError> {
let mut txn = conn.begin().await.context("could not start transaction")?;
// First interaction
let user = lookup_user_by_username(&mut txn, username)?;
// Second interaction
let mut session = start_session(&mut txn, user).await?;
// Third interaction
session.last_authd_at =
Some(authenticate_session(&mut txn, session.id, password).await?);
// Commit the transaction once everything went fine
txn.commit().await.context("could not commit transaction")?;
Ok(session)
}
```