1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-11-23 11:02:35 +03:00

Split the storage trait from the implementation

This commit is contained in:
Quentin Gliech
2023-01-18 09:53:42 +01:00
parent b33a330b5f
commit 73a921cc30
95 changed files with 6294 additions and 5741 deletions

View File

@@ -1,4 +1,4 @@
// Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -29,150 +29,19 @@
)]
use chrono::{DateTime, Utc};
use pagination::InvalidPagination;
use sqlx::{migrate::Migrator, postgres::PgQueryResult};
use thiserror::Error;
use ulid::Ulid;
trait LookupResultExt {
type Output;
/// Transform a [`Result`] from a sqlx query to transform "not found" errors
/// into [`None`]
fn to_option(self) -> Result<Option<Self::Output>, sqlx::Error>;
}
impl<T> LookupResultExt for Result<T, sqlx::Error> {
type Output = T;
fn to_option(self) -> Result<Option<Self::Output>, sqlx::Error> {
match self {
Ok(v) => Ok(Some(v)),
Err(sqlx::Error::RowNotFound) => Ok(None),
Err(e) => Err(e),
}
}
}
/// Generic error when interacting with the database
#[derive(Debug, Error)]
#[error(transparent)]
pub enum DatabaseError {
/// An error which came from the database itself
Driver(#[from] sqlx::Error),
/// An error which occured while converting the data from the database
Inconsistency(#[from] DatabaseInconsistencyError),
/// An error which occured while generating the paginated query
Pagination(#[from] InvalidPagination),
/// An error which happened because the requested database operation is
/// invalid
#[error("Invalid database operation")]
InvalidOperation {
#[source]
source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
},
/// An error which happens when an operation affects not enough or too many
/// rows
#[error("Expected {expected} rows to be affected, but {actual} rows were affected")]
RowsAffected { expected: u64, actual: u64 },
}
impl DatabaseError {
pub(crate) fn ensure_affected_rows(
result: &PgQueryResult,
expected: u64,
) -> Result<(), DatabaseError> {
let actual = result.rows_affected();
if actual == expected {
Ok(())
} else {
Err(DatabaseError::RowsAffected { expected, actual })
}
}
pub(crate) fn to_invalid_operation<E: std::error::Error + Send + Sync + 'static>(e: E) -> Self {
Self::InvalidOperation {
source: Some(Box::new(e)),
}
}
pub(crate) const fn invalid_operation() -> Self {
Self::InvalidOperation { source: None }
}
}
#[derive(Debug, Error)]
pub struct DatabaseInconsistencyError {
table: &'static str,
column: Option<&'static str>,
row: Option<Ulid>,
#[source]
source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
}
impl std::fmt::Display for DatabaseInconsistencyError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Database inconsistency on table {}", self.table)?;
if let Some(column) = self.column {
write!(f, " column {column}")?;
}
if let Some(row) = self.row {
write!(f, " row {row}")?;
}
Ok(())
}
}
impl DatabaseInconsistencyError {
#[must_use]
pub(crate) const fn on(table: &'static str) -> Self {
Self {
table,
column: None,
row: None,
source: None,
}
}
#[must_use]
pub(crate) const fn column(mut self, column: &'static str) -> Self {
self.column = Some(column);
self
}
#[must_use]
pub(crate) const fn row(mut self, row: Ulid) -> Self {
self.row = Some(row);
self
}
pub(crate) fn source<E: std::error::Error + Send + Sync + 'static>(
mut self,
source: E,
) -> Self {
self.source = Some(Box::new(source));
self
}
}
#[derive(Debug, Clone, Default)]
pub struct Clock {
_private: (),
#[cfg(test)]
// #[cfg(test)]
mock: Option<std::sync::Arc<std::sync::atomic::AtomicI64>>,
}
impl Clock {
#[must_use]
pub fn now(&self) -> DateTime<Utc> {
#[cfg(test)]
// #[cfg(test)]
if let Some(timestamp) = &self.mock {
let timestamp = timestamp.load(std::sync::atomic::Ordering::Relaxed);
return chrono::TimeZone::timestamp_opt(&Utc, timestamp, 0).unwrap();
@@ -183,13 +52,14 @@ impl Clock {
Utc::now()
}
#[cfg(test)]
// #[cfg(test)]
#[must_use]
pub fn mock() -> Self {
use std::sync::{atomic::AtomicI64, Arc};
use chrono::TimeZone;
let datetime = Utc.with_ymd_and_hms(2022, 01, 16, 14, 40, 0).unwrap();
let datetime = Utc.with_ymd_and_hms(2022, 1, 16, 14, 40, 0).unwrap();
let timestamp = datetime.timestamp();
Self {
@@ -198,7 +68,7 @@ impl Clock {
}
}
#[cfg(test)]
// #[cfg(test)]
pub fn advance(&self, duration: chrono::Duration) {
let timestamp = self
.mock
@@ -247,16 +117,12 @@ mod tests {
pub mod compat;
pub mod oauth2;
pub(crate) mod pagination;
pub mod pagination;
pub(crate) mod repository;
pub(crate) mod tracing;
pub mod upstream_oauth2;
pub mod user;
pub use self::{
pagination::Pagination,
repository::{PgRepository, Repository},
pagination::{Page, Pagination},
repository::Repository,
};
/// Embedded migrations, allowing them to run on startup
pub static MIGRATOR: Migrator = sqlx::migrate!();