1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Split the storage trait from the implementation

This commit is contained in:
Quentin Gliech
2023-01-18 09:53:42 +01:00
parent b33a330b5f
commit 73a921cc30
95 changed files with 6294 additions and 5741 deletions

21
Cargo.lock generated
View File

@ -2705,6 +2705,7 @@ dependencies = [
"mas-router",
"mas-spa",
"mas-storage",
"mas-storage-pg",
"mas-tasks",
"mas-templates",
"oauth2-types",
@ -2803,6 +2804,7 @@ dependencies = [
"chrono",
"mas-data-model",
"mas-storage",
"mas-storage-pg",
"oauth2-types",
"serde",
"sqlx",
@ -2843,6 +2845,7 @@ dependencies = [
"mas-policy",
"mas-router",
"mas-storage",
"mas-storage-pg",
"mas-templates",
"mime",
"oauth2-types",
@ -3103,6 +3106,23 @@ dependencies = [
"mas-jose",
"oauth2-types",
"rand 0.8.5",
"thiserror",
"ulid",
"url",
]
[[package]]
name = "mas-storage-pg"
version = "0.1.0"
dependencies = [
"async-trait",
"chrono",
"mas-data-model",
"mas-iana",
"mas-jose",
"mas-storage",
"oauth2-types",
"rand 0.8.5",
"rand_chacha 0.3.1",
"serde",
"serde_json",
@ -3121,6 +3141,7 @@ dependencies = [
"async-trait",
"futures-util",
"mas-storage",
"mas-storage-pg",
"sqlx",
"tokio",
"tokio-stream",

View File

@ -50,6 +50,7 @@ mas-policy = { path = "../policy" }
mas-router = { path = "../router" }
mas-spa = { path = "../spa" }
mas-storage = { path = "../storage" }
mas-storage-pg = { path = "../storage-pg" }
mas-tasks = { path = "../tasks" }
mas-templates = { path = "../templates" }
oauth2-types = { path = "../oauth2-types" }

View File

@ -15,7 +15,7 @@
use anyhow::Context;
use clap::Parser;
use mas_config::DatabaseConfig;
use mas_storage::MIGRATOR;
use mas_storage_pg::MIGRATOR;
use tracing::{info_span, Instrument};
use crate::util::database_from_config;

View File

@ -21,8 +21,9 @@ use mas_storage::{
oauth2::OAuth2ClientRepository,
upstream_oauth2::UpstreamOAuthProviderRepository,
user::{UserEmailRepository, UserPasswordRepository, UserRepository},
Clock, PgRepository, Repository,
Clock, Repository,
};
use mas_storage_pg::PgRepository;
use oauth2_types::scope::Scope;
use rand::SeedableRng;
use tracing::{info, info_span, warn};

View File

@ -1,4 +1,4 @@
// Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -21,7 +21,7 @@ use mas_config::RootConfig;
use mas_handlers::{AppState, HttpClientFactory, MatrixHomeserver};
use mas_listener::{server::Server, shutdown::ShutdownStream};
use mas_router::UrlBuilder;
use mas_storage::MIGRATOR;
use mas_storage_pg::MIGRATOR;
use mas_tasks::TaskQueue;
use tokio::signal::unix::SignalKind;
use tracing::{info, info_span, warn, Instrument};

View File

@ -11,8 +11,8 @@ thiserror = "1.0.38"
serde = "1.0.152"
url = { version = "2.3.1", features = ["serde"] }
crc = "3.0.0"
ulid = { version = "1.0.0", features = ["serde"] }
rand = "0.8.5"
ulid = "1.0.0"
rand_chacha = "0.3.1"
mas-iana = { path = "../iana" }

View File

@ -19,6 +19,7 @@ url = "2.3.1"
oauth2-types = { path = "../oauth2-types" }
mas-data-model = { path = "../data-model" }
mas-storage = { path = "../storage" }
mas-storage-pg = { path = "../storage-pg" }
[[bin]]
name = "schema"

View File

@ -34,8 +34,9 @@ use mas_storage::{
oauth2::OAuth2ClientRepository,
upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository},
user::{BrowserSessionRepository, UserEmailRepository},
Pagination, PgRepository, Repository,
Pagination, Repository,
};
use mas_storage_pg::PgRepository;
use model::CreationEvent;
use sqlx::PgPool;

View File

@ -15,9 +15,8 @@
use anyhow::Context as _;
use async_graphql::{Context, Description, Object, ID};
use chrono::{DateTime, Utc};
use mas_storage::{
compat::CompatSessionRepository, user::UserRepository, PgRepository, Repository,
};
use mas_storage::{compat::CompatSessionRepository, user::UserRepository, Repository};
use mas_storage_pg::PgRepository;
use sqlx::PgPool;
use url::Url;

View File

@ -14,9 +14,8 @@
use anyhow::Context as _;
use async_graphql::{Context, Description, Object, ID};
use mas_storage::{
oauth2::OAuth2ClientRepository, user::BrowserSessionRepository, PgRepository, Repository,
};
use mas_storage::{oauth2::OAuth2ClientRepository, user::BrowserSessionRepository, Repository};
use mas_storage_pg::PgRepository;
use oauth2_types::scope::Scope;
use sqlx::PgPool;
use ulid::Ulid;

View File

@ -16,9 +16,9 @@ use anyhow::Context as _;
use async_graphql::{Context, Object, ID};
use chrono::{DateTime, Utc};
use mas_storage::{
upstream_oauth2::UpstreamOAuthProviderRepository, user::UserRepository, PgRepository,
Repository,
upstream_oauth2::UpstreamOAuthProviderRepository, user::UserRepository, Repository,
};
use mas_storage_pg::PgRepository;
use sqlx::PgPool;
use super::{NodeType, User};

View File

@ -22,8 +22,9 @@ use mas_storage::{
oauth2::OAuth2SessionRepository,
upstream_oauth2::UpstreamOAuthLinkRepository,
user::{BrowserSessionRepository, UserEmailRepository},
Pagination, PgRepository, Repository,
Pagination, Repository,
};
use mas_storage_pg::PgRepository;
use sqlx::PgPool;
use super::{

View File

@ -68,6 +68,7 @@ mas-oidc-client = { path = "../oidc-client" }
mas-policy = { path = "../policy" }
mas-router = { path = "../router" }
mas-storage = { path = "../storage" }
mas-storage-pg = { path = "../storage-pg" }
mas-templates = { path = "../templates" }
oauth2-types = { path = "../oauth2-types" }

View File

@ -22,8 +22,9 @@ use mas_storage::{
CompatSsoLoginRepository,
},
user::{UserPasswordRepository, UserRepository},
Clock, PgRepository, Repository,
Clock, Repository,
};
use mas_storage_pg::PgRepository;
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, skip_serializing_none, DurationMilliSeconds};
use sqlx::PgPool;
@ -154,7 +155,7 @@ pub enum RouteError {
}
impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_storage::DatabaseError);
impl_from_error_for_route!(mas_storage_pg::DatabaseError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {

View File

@ -31,8 +31,9 @@ use mas_keystore::Encrypter;
use mas_router::{CompatLoginSsoAction, PostAuthAction, Route};
use mas_storage::{
compat::{CompatSessionRepository, CompatSsoLoginRepository},
PgRepository, Repository,
Repository,
};
use mas_storage_pg::PgRepository;
use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;

View File

@ -19,7 +19,8 @@ use axum::{
};
use hyper::StatusCode;
use mas_router::{CompatLoginSsoAction, CompatLoginSsoComplete, UrlBuilder};
use mas_storage::{compat::CompatSsoLoginRepository, PgRepository, Repository};
use mas_storage::{compat::CompatSsoLoginRepository, Repository};
use mas_storage_pg::PgRepository;
use rand::distributions::{Alphanumeric, DistString};
use serde::Deserialize;
use serde_with::serde;
@ -49,7 +50,7 @@ pub enum RouteError {
}
impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_storage::DatabaseError);
impl_from_error_for_route!(mas_storage_pg::DatabaseError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {

View File

@ -18,8 +18,9 @@ use hyper::StatusCode;
use mas_data_model::TokenType;
use mas_storage::{
compat::{CompatAccessTokenRepository, CompatSessionRepository},
Clock, PgRepository, Repository,
Clock, Repository,
};
use mas_storage_pg::PgRepository;
use sqlx::PgPool;
use thiserror::Error;
@ -42,7 +43,7 @@ pub enum RouteError {
}
impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_storage::DatabaseError);
impl_from_error_for_route!(mas_storage_pg::DatabaseError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {

View File

@ -18,8 +18,9 @@ use hyper::StatusCode;
use mas_data_model::{TokenFormatError, TokenType};
use mas_storage::{
compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository},
PgRepository, Repository,
Repository,
};
use mas_storage_pg::PgRepository;
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, DurationMilliSeconds};
use sqlx::PgPool;
@ -70,7 +71,7 @@ impl IntoResponse for RouteError {
}
impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_storage::DatabaseError);
impl_from_error_for_route!(mas_storage_pg::DatabaseError);
impl From<TokenFormatError> for RouteError {
fn from(_e: TokenFormatError) -> Self {

View File

@ -28,7 +28,7 @@ use hyper::header::CACHE_CONTROL;
use mas_axum_utils::{FancyError, SessionInfoExt};
use mas_graphql::Schema;
use mas_keystore::Encrypter;
use mas_storage::PgRepository;
use mas_storage_pg::PgRepository;
use sqlx::PgPool;
use tracing::{info_span, Instrument};

View File

@ -1,4 +1,4 @@
// Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -36,7 +36,7 @@ mod tests {
use super::*;
#[sqlx::test(migrator = "mas_storage::MIGRATOR")]
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_get_health(pool: PgPool) -> Result<(), anyhow::Error> {
let state = crate::test_state(pool).await?;
let app = crate::healthcheck_router().with_state(state);

View File

@ -27,8 +27,9 @@ use mas_policy::PolicyFactory;
use mas_router::{PostAuthAction, Route};
use mas_storage::{
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2SessionRepository},
PgRepository, Repository,
Repository,
};
use mas_storage_pg::PgRepository;
use mas_templates::Templates;
use oauth2_types::requests::{AccessTokenResponse, AuthorizationResponse};
use sqlx::PgPool;
@ -70,7 +71,7 @@ impl IntoResponse for RouteError {
}
impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_storage::DatabaseError);
impl_from_error_for_route!(mas_storage_pg::DatabaseError);
impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::InstanciateError);
impl_from_error_for_route!(mas_policy::EvaluationError);
@ -149,7 +150,7 @@ pub enum GrantCompletionError {
}
impl_from_error_for_route!(GrantCompletionError: sqlx::Error);
impl_from_error_for_route!(GrantCompletionError: mas_storage::DatabaseError);
impl_from_error_for_route!(GrantCompletionError: mas_storage_pg::DatabaseError);
impl_from_error_for_route!(GrantCompletionError: super::callback::IntoCallbackDestinationError);
impl_from_error_for_route!(GrantCompletionError: mas_policy::LoadError);
impl_from_error_for_route!(GrantCompletionError: mas_policy::InstanciateError);

View File

@ -27,8 +27,9 @@ use mas_policy::PolicyFactory;
use mas_router::{PostAuthAction, Route};
use mas_storage::{
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository},
PgRepository, Repository,
Repository,
};
use mas_storage_pg::PgRepository;
use mas_templates::Templates;
use oauth2_types::{
errors::{ClientError, ClientErrorCode},
@ -91,7 +92,7 @@ impl IntoResponse for RouteError {
}
impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_storage::DatabaseError);
impl_from_error_for_route!(mas_storage_pg::DatabaseError);
impl_from_error_for_route!(self::callback::CallbackDestinationError);
impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::InstanciateError);

View File

@ -30,8 +30,9 @@ use mas_policy::PolicyFactory;
use mas_router::{PostAuthAction, Route};
use mas_storage::{
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository},
PgRepository, Repository,
Repository,
};
use mas_storage_pg::PgRepository;
use mas_templates::{ConsentContext, PolicyViolationContext, TemplateContext, Templates};
use sqlx::PgPool;
use thiserror::Error;
@ -62,7 +63,7 @@ pub enum RouteError {
impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_templates::TemplateError);
impl_from_error_for_route!(mas_storage::DatabaseError);
impl_from_error_for_route!(mas_storage_pg::DatabaseError);
impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::InstanciateError);
impl_from_error_for_route!(mas_policy::EvaluationError);

View File

@ -25,8 +25,9 @@ use mas_storage::{
compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository},
oauth2::{OAuth2AccessTokenRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository},
user::{BrowserSessionRepository, UserRepository},
Clock, PgRepository, Repository,
Clock, Repository,
};
use mas_storage_pg::PgRepository;
use oauth2_types::{
errors::{ClientError, ClientErrorCode},
requests::{IntrospectionRequest, IntrospectionResponse},
@ -97,7 +98,7 @@ impl IntoResponse for RouteError {
}
impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_storage::DatabaseError);
impl_from_error_for_route!(mas_storage_pg::DatabaseError);
impl From<TokenFormatError> for RouteError {
fn from(_e: TokenFormatError) -> Self {

View File

@ -19,7 +19,8 @@ use hyper::StatusCode;
use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_keystore::Encrypter;
use mas_policy::{PolicyFactory, Violation};
use mas_storage::{oauth2::OAuth2ClientRepository, PgRepository, Repository};
use mas_storage::{oauth2::OAuth2ClientRepository, Repository};
use mas_storage_pg::PgRepository;
use oauth2_types::{
errors::{ClientError, ClientErrorCode},
registration::{
@ -49,7 +50,7 @@ pub(crate) enum RouteError {
}
impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_storage::DatabaseError);
impl_from_error_for_route!(mas_storage_pg::DatabaseError);
impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::InstanciateError);
impl_from_error_for_route!(mas_policy::EvaluationError);

View File

@ -37,8 +37,9 @@ use mas_storage::{
OAuth2RefreshTokenRepository, OAuth2SessionRepository,
},
user::BrowserSessionRepository,
PgRepository, Repository,
Repository,
};
use mas_storage_pg::PgRepository;
use oauth2_types::{
errors::{ClientError, ClientErrorCode},
pkce::CodeChallengeError,
@ -151,7 +152,7 @@ impl IntoResponse for RouteError {
}
impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_storage::DatabaseError);
impl_from_error_for_route!(mas_storage_pg::DatabaseError);
impl_from_error_for_route!(mas_keystore::WrongAlgorithmError);
impl_from_error_for_route!(mas_jose::claims::ClaimError);
impl_from_error_for_route!(mas_jose::claims::TokenHashError);

View File

@ -31,8 +31,9 @@ use mas_router::UrlBuilder;
use mas_storage::{
oauth2::OAuth2ClientRepository,
user::{BrowserSessionRepository, UserEmailRepository},
DatabaseError, PgRepository, Repository,
Repository,
};
use mas_storage_pg::PgRepository;
use oauth2_types::scope;
use serde::Serialize;
use serde_with::skip_serializing_none;
@ -64,7 +65,9 @@ pub enum RouteError {
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("failed to authenticate")]
AuthorizationVerificationError(#[from] AuthorizationVerificationError<DatabaseError>),
AuthorizationVerificationError(
#[from] AuthorizationVerificationError<mas_storage_pg::DatabaseError>,
),
#[error("no suitable key found for signing")]
InvalidSigningKey,
@ -77,7 +80,7 @@ pub enum RouteError {
}
impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_storage::DatabaseError);
impl_from_error_for_route!(mas_storage_pg::DatabaseError);
impl_from_error_for_route!(mas_keystore::WrongAlgorithmError);
impl_from_error_for_route!(mas_jose::jwt::JwtSignatureError);

View File

@ -24,8 +24,9 @@ use mas_oidc_client::requests::authorization_code::AuthorizationRequestData;
use mas_router::UrlBuilder;
use mas_storage::{
upstream_oauth2::{UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository},
PgRepository, Repository,
Repository,
};
use mas_storage_pg::PgRepository;
use sqlx::PgPool;
use thiserror::Error;
use ulid::Ulid;
@ -46,7 +47,7 @@ impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_http::ClientInitError);
impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError);
impl_from_error_for_route!(mas_oidc_client::error::AuthorizationError);
impl_from_error_for_route!(mas_storage::DatabaseError);
impl_from_error_for_route!(mas_storage_pg::DatabaseError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {

View File

@ -30,8 +30,9 @@ use mas_storage::{
UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
UpstreamOAuthSessionRepository,
},
PgRepository, Repository,
Repository,
};
use mas_storage_pg::PgRepository;
use oauth2_types::errors::ClientErrorCode;
use serde::Deserialize;
use sqlx::PgPool;
@ -99,7 +100,7 @@ pub(crate) enum RouteError {
Internal(Box<dyn std::error::Error>),
}
impl_from_error_for_route!(mas_storage::DatabaseError);
impl_from_error_for_route!(mas_storage_pg::DatabaseError);
impl_from_error_for_route!(mas_http::ClientInitError);
impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError);

View File

@ -27,8 +27,9 @@ use mas_keystore::Encrypter;
use mas_storage::{
upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository},
user::{BrowserSessionRepository, UserRepository},
PgRepository, Repository,
Repository,
};
use mas_storage_pg::PgRepository;
use mas_templates::{
EmptyContext, TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister,
UpstreamSuggestLink,
@ -73,7 +74,7 @@ impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_templates::TemplateError);
impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError);
impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
impl_from_error_for_route!(mas_storage::DatabaseError);
impl_from_error_for_route!(mas_storage_pg::DatabaseError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {

View File

@ -24,7 +24,8 @@ use mas_axum_utils::{
use mas_email::Mailer;
use mas_keystore::Encrypter;
use mas_router::Route;
use mas_storage::{user::UserEmailRepository, PgRepository, Repository};
use mas_storage::{user::UserEmailRepository, Repository};
use mas_storage_pg::PgRepository;
use mas_templates::{EmailAddContext, TemplateContext, Templates};
use serde::Deserialize;
use sqlx::PgPool;

View File

@ -28,7 +28,8 @@ use mas_data_model::{BrowserSession, User, UserEmail};
use mas_email::Mailer;
use mas_keystore::Encrypter;
use mas_router::Route;
use mas_storage::{user::UserEmailRepository, Clock, PgRepository, Repository};
use mas_storage::{user::UserEmailRepository, Clock, Repository};
use mas_storage_pg::PgRepository;
use mas_templates::{AccountEmailsContext, EmailVerificationContext, TemplateContext, Templates};
use rand::{distributions::Uniform, Rng};
use serde::Deserialize;

View File

@ -24,7 +24,8 @@ use mas_axum_utils::{
};
use mas_keystore::Encrypter;
use mas_router::Route;
use mas_storage::{user::UserEmailRepository, Clock, PgRepository, Repository};
use mas_storage::{user::UserEmailRepository, Clock, Repository};
use mas_storage_pg::PgRepository;
use mas_templates::{EmailVerificationPageContext, TemplateContext, Templates};
use serde::Deserialize;
use sqlx::PgPool;

View File

@ -25,8 +25,9 @@ use mas_keystore::Encrypter;
use mas_router::Route;
use mas_storage::{
user::{BrowserSessionRepository, UserEmailRepository},
PgRepository, Repository,
Repository,
};
use mas_storage_pg::PgRepository;
use mas_templates::{AccountContext, TemplateContext, Templates};
use sqlx::PgPool;

View File

@ -27,8 +27,9 @@ use mas_keystore::Encrypter;
use mas_router::Route;
use mas_storage::{
user::{BrowserSessionRepository, UserPasswordRepository},
Clock, PgRepository, Repository,
Clock, Repository,
};
use mas_storage_pg::PgRepository;
use mas_templates::{EmptyContext, TemplateContext, Templates};
use rand::Rng;
use serde::Deserialize;

View File

@ -20,7 +20,7 @@ use axum_extra::extract::PrivateCookieJar;
use mas_axum_utils::{csrf::CsrfExt, FancyError, SessionInfoExt};
use mas_keystore::Encrypter;
use mas_router::UrlBuilder;
use mas_storage::PgRepository;
use mas_storage_pg::PgRepository;
use mas_templates::{IndexContext, TemplateContext, Templates};
use sqlx::PgPool;

View File

@ -26,8 +26,9 @@ use mas_keystore::Encrypter;
use mas_storage::{
upstream_oauth2::UpstreamOAuthProviderRepository,
user::{BrowserSessionRepository, UserPasswordRepository, UserRepository},
Clock, PgRepository, Repository,
Clock, Repository,
};
use mas_storage_pg::PgRepository;
use mas_templates::{
FieldError, FormError, LoginContext, LoginFormField, TemplateContext, Templates, ToFormState,
};

View File

@ -23,7 +23,8 @@ use mas_axum_utils::{
};
use mas_keystore::Encrypter;
use mas_router::{PostAuthAction, Route};
use mas_storage::{user::BrowserSessionRepository, Clock, PgRepository, Repository};
use mas_storage::{user::BrowserSessionRepository, Clock, Repository};
use mas_storage_pg::PgRepository;
use sqlx::PgPool;
pub(crate) async fn post(

View File

@ -26,8 +26,9 @@ use mas_keystore::Encrypter;
use mas_router::Route;
use mas_storage::{
user::{BrowserSessionRepository, UserPasswordRepository},
PgRepository, Repository,
Repository,
};
use mas_storage_pg::PgRepository;
use mas_templates::{ReauthContext, TemplateContext, Templates};
use serde::Deserialize;
use sqlx::PgPool;

View File

@ -33,8 +33,9 @@ use mas_policy::PolicyFactory;
use mas_router::Route;
use mas_storage::{
user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository},
PgRepository, Repository,
Repository,
};
use mas_storage_pg::PgRepository;
use mas_templates::{
EmailVerificationContext, FieldError, FormError, RegisterContext, RegisterFormField,
TemplateContext, Templates, ToFormState,

View File

@ -0,0 +1,27 @@
[package]
name = "mas-storage-pg"
version = "0.1.0"
authors = ["Quentin Gliech <quenting@element.io>"]
edition = "2021"
license = "Apache-2.0"
[dependencies]
async-trait = "0.1.60"
sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "offline", "json", "uuid"] }
chrono = { version = "0.4.23", features = ["serde"] }
serde = { version = "1.0.152", features = ["derive"] }
serde_json = "1.0.91"
thiserror = "1.0.38"
tracing = "0.1.37"
rand = "0.8.5"
rand_chacha = "0.3.1"
url = { version = "2.3.1", features = ["serde"] }
uuid = "1.2.2"
ulid = { version = "1.0.0", features = ["uuid", "serde"] }
oauth2-types = { path = "../oauth2-types" }
mas-storage = { path = "../storage" }
mas-data-model = { path = "../data-model" }
mas-iana = { path = "../iana" }
mas-jose = { path = "../jose" }

View File

@ -1,4 +1,4 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.

View File

@ -1336,6 +1336,24 @@
},
"query": "\n SELECT oauth2_client_id\n , encrypted_client_secret\n , ARRAY(\n SELECT redirect_uri\n FROM oauth2_client_redirect_uris r\n WHERE r.oauth2_client_id = c.oauth2_client_id\n ) AS \"redirect_uris!\"\n , grant_type_authorization_code\n , grant_type_refresh_token\n , client_name\n , logo_uri\n , client_uri\n , policy_uri\n , tos_uri\n , jwks_uri\n , jwks\n , id_token_signed_response_alg\n , userinfo_signed_response_alg\n , token_endpoint_auth_method\n , token_endpoint_auth_signing_alg\n , initiate_login_uri\n FROM oauth2_clients c\n\n WHERE oauth2_client_id = ANY($1::uuid[])\n "
},
"8a79c7c392dd930628caadec80c9b2645501475ab4feacbac59ca1bc52b16c3f": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid",
"Text",
"Bool",
"Bool",
"Text",
"Jsonb",
"Text"
]
}
},
"query": "\n INSERT INTO oauth2_clients\n ( oauth2_client_id\n , encrypted_client_secret\n , grant_type_authorization_code\n , grant_type_refresh_token\n , token_endpoint_auth_method\n , jwks\n , jwks_uri\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7)\n ON CONFLICT (oauth2_client_id)\n DO\n UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret\n , grant_type_authorization_code = EXCLUDED.grant_type_authorization_code\n , grant_type_refresh_token = EXCLUDED.grant_type_refresh_token\n , token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method\n , jwks = EXCLUDED.jwks\n , jwks_uri = EXCLUDED.jwks_uri\n "
},
"8b7297c263336d70c2b647212b16f7ae39bc5cb1572e3a2dcfcd67f196a1fa39": {
"describe": {
"columns": [
@ -1821,24 +1839,6 @@
},
"query": "\n UPDATE users\n SET primary_user_email_id = user_emails.user_email_id\n FROM user_emails\n WHERE user_emails.user_email_id = $1\n AND users.user_id = user_emails.user_id\n "
},
"c0b4996085f6f2127e1e8cfdf18b9029c22096fadfe6de59dce01c789791edb5": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid",
"Text",
"Bool",
"Bool",
"Text",
"Jsonb",
"Text"
]
}
},
"query": "\n INSERT INTO oauth2_clients\n ( oauth2_client_id\n , encrypted_client_secret\n , grant_type_authorization_code\n , grant_type_refresh_token\n , token_endpoint_auth_method\n , jwks\n , jwks_uri\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7)\n ON CONFLICT (oauth2_client_id)\n DO\n UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret\n , grant_type_authorization_code = EXCLUDED.grant_type_authorization_code\n , grant_type_refresh_token = EXCLUDED.grant_type_refresh_token\n , token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method\n , jwks = EXCLUDED.jwks\n , jwks_uri = EXCLUDED.jwks_uri\n "
},
"c0ed9d70e496433d8686a499055d8a8376459109b6154a2c0c13b28462afa523": {
"describe": {
"columns": [],

View File

@ -0,0 +1,216 @@
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};
use mas_data_model::{CompatAccessToken, CompatSession};
use mas_storage::{compat::CompatAccessTokenRepository, Clock};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{tracing::ExecuteExt, DatabaseError, LookupResultExt};
pub struct PgCompatAccessTokenRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgCompatAccessTokenRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
struct CompatAccessTokenLookup {
compat_access_token_id: Uuid,
access_token: String,
created_at: DateTime<Utc>,
expires_at: Option<DateTime<Utc>>,
compat_session_id: Uuid,
}
impl From<CompatAccessTokenLookup> for CompatAccessToken {
fn from(value: CompatAccessTokenLookup) -> Self {
Self {
id: value.compat_access_token_id.into(),
session_id: value.compat_session_id.into(),
token: value.access_token,
created_at: value.created_at,
expires_at: value.expires_at,
}
}
}
#[async_trait]
impl<'c> CompatAccessTokenRepository for PgCompatAccessTokenRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.compat_access_token.lookup",
skip_all,
fields(
db.statement,
compat_session.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatAccessToken>, Self::Error> {
let res = sqlx::query_as!(
CompatAccessTokenLookup,
r#"
SELECT compat_access_token_id
, access_token
, created_at
, expires_at
, compat_session_id
FROM compat_access_tokens
WHERE compat_access_token_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.compat_access_token.find_by_token",
skip_all,
fields(
db.statement,
),
err,
)]
async fn find_by_token(
&mut self,
access_token: &str,
) -> Result<Option<CompatAccessToken>, Self::Error> {
let res = sqlx::query_as!(
CompatAccessTokenLookup,
r#"
SELECT compat_access_token_id
, access_token
, created_at
, expires_at
, compat_session_id
FROM compat_access_tokens
WHERE access_token = $1
"#,
access_token,
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.compat_access_token.add",
skip_all,
fields(
db.statement,
compat_access_token.id,
%compat_session.id,
user.id = %compat_session.user_id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
compat_session: &CompatSession,
token: String,
expires_after: Option<Duration>,
) -> Result<CompatAccessToken, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("compat_access_token.id", tracing::field::display(id));
let expires_at = expires_after.map(|expires_after| created_at + expires_after);
sqlx::query!(
r#"
INSERT INTO compat_access_tokens
(compat_access_token_id, compat_session_id, access_token, created_at, expires_at)
VALUES ($1, $2, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(compat_session.id),
token,
created_at,
expires_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(CompatAccessToken {
id,
session_id: compat_session.id,
token,
created_at,
expires_at,
})
}
#[tracing::instrument(
name = "db.compat_access_token.expire",
skip_all,
fields(
db.statement,
%compat_access_token.id,
compat_session.id = %compat_access_token.session_id,
),
err,
)]
async fn expire(
&mut self,
clock: &Clock,
mut compat_access_token: CompatAccessToken,
) -> Result<CompatAccessToken, Self::Error> {
let expires_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE compat_access_tokens
SET expires_at = $2
WHERE compat_access_token_id = $1
"#,
Uuid::from(compat_access_token.id),
expires_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
compat_access_token.expires_at = Some(expires_at);
Ok(compat_access_token)
}
}

View File

@ -0,0 +1,322 @@
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
mod access_token;
mod refresh_token;
mod session;
mod sso_login;
pub use self::{
access_token::PgCompatAccessTokenRepository, refresh_token::PgCompatRefreshTokenRepository,
session::PgCompatSessionRepository, sso_login::PgCompatSsoLoginRepository,
};
#[cfg(test)]
mod tests {
use chrono::Duration;
use mas_data_model::Device;
use mas_storage::{
compat::{
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
},
user::UserRepository,
Clock, Repository,
};
use rand::SeedableRng;
use rand_chacha::ChaChaRng;
use sqlx::PgPool;
use crate::PgRepository;
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_session_repository(pool: PgPool) {
const FIRST_TOKEN: &str = "first_access_token";
const SECOND_TOKEN: &str = "second_access_token";
let mut rng = ChaChaRng::seed_from_u64(42);
let clock = Clock::mock();
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
// Create a user
let user = repo
.user()
.add(&mut rng, &clock, "john".to_owned())
.await
.unwrap();
// Start a compat session for that user
let device = Device::generate(&mut rng);
let device_str = device.as_str().to_owned();
let session = repo
.compat_session()
.add(&mut rng, &clock, &user, device)
.await
.unwrap();
assert_eq!(session.user_id, user.id);
assert_eq!(session.device.as_str(), device_str);
assert!(session.is_valid());
assert!(!session.is_finished());
// Lookup the session and check it didn't change
let session_lookup = repo
.compat_session()
.lookup(session.id)
.await
.unwrap()
.expect("compat session not found");
assert_eq!(session_lookup.id, session.id);
assert_eq!(session_lookup.user_id, user.id);
assert_eq!(session_lookup.device.as_str(), device_str);
assert!(session_lookup.is_valid());
assert!(!session_lookup.is_finished());
// Finish the session
let session = repo.compat_session().finish(&clock, session).await.unwrap();
assert!(!session.is_valid());
assert!(session.is_finished());
// Reload the session and check again
let session_lookup = repo
.compat_session()
.lookup(session.id)
.await
.unwrap()
.expect("compat session not found");
assert!(!session_lookup.is_valid());
assert!(session_lookup.is_finished());
}
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_access_token_repository(pool: PgPool) {
const FIRST_TOKEN: &str = "first_access_token";
const SECOND_TOKEN: &str = "second_access_token";
let mut rng = ChaChaRng::seed_from_u64(42);
let clock = Clock::mock();
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
// Create a user
let user = repo
.user()
.add(&mut rng, &clock, "john".to_owned())
.await
.unwrap();
// Start a compat session for that user
let device = Device::generate(&mut rng);
let session = repo
.compat_session()
.add(&mut rng, &clock, &user, device)
.await
.unwrap();
// Add an access token to that session
let token = repo
.compat_access_token()
.add(
&mut rng,
&clock,
&session,
FIRST_TOKEN.to_owned(),
Some(Duration::minutes(1)),
)
.await
.unwrap();
assert_eq!(token.session_id, session.id);
assert_eq!(token.token, FIRST_TOKEN);
// Commit the txn and grab a new transaction, to test a conflict
repo.save().await.unwrap();
{
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
// Adding the same token a second time should conflict
assert!(repo
.compat_access_token()
.add(
&mut rng,
&clock,
&session,
FIRST_TOKEN.to_owned(),
Some(Duration::minutes(1)),
)
.await
.is_err());
repo.cancel().await.unwrap();
}
// Grab a new repo
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
// Looking up via ID works
let token_lookup = repo
.compat_access_token()
.lookup(token.id)
.await
.unwrap()
.expect("compat access token not found");
assert_eq!(token.id, token_lookup.id);
assert_eq!(token_lookup.session_id, session.id);
// Looking up via the token value works
let token_lookup = repo
.compat_access_token()
.find_by_token(FIRST_TOKEN)
.await
.unwrap()
.expect("compat access token not found");
assert_eq!(token.id, token_lookup.id);
assert_eq!(token_lookup.session_id, session.id);
// Token is currently valid
assert!(token.is_valid(clock.now()));
clock.advance(Duration::minutes(1));
// Token should have expired
assert!(!token.is_valid(clock.now()));
// Add a second access token, this time without expiration
let token = repo
.compat_access_token()
.add(&mut rng, &clock, &session, SECOND_TOKEN.to_owned(), None)
.await
.unwrap();
assert_eq!(token.session_id, session.id);
assert_eq!(token.token, SECOND_TOKEN);
// Token is currently valid
assert!(token.is_valid(clock.now()));
// Make it expire
repo.compat_access_token()
.expire(&clock, token)
.await
.unwrap();
// Reload it
let token = repo
.compat_access_token()
.find_by_token(SECOND_TOKEN)
.await
.unwrap()
.expect("compat access token not found");
// Token is not valid anymore
assert!(!token.is_valid(clock.now()));
repo.save().await.unwrap();
}
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_refresh_token_repository(pool: PgPool) {
const ACCESS_TOKEN: &str = "access_token";
const REFRESH_TOKEN: &str = "refresh_token";
let mut rng = ChaChaRng::seed_from_u64(42);
let clock = Clock::mock();
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
// Create a user
let user = repo
.user()
.add(&mut rng, &clock, "john".to_owned())
.await
.unwrap();
// Start a compat session for that user
let device = Device::generate(&mut rng);
let session = repo
.compat_session()
.add(&mut rng, &clock, &user, device)
.await
.unwrap();
// Add an access token to that session
let access_token = repo
.compat_access_token()
.add(&mut rng, &clock, &session, ACCESS_TOKEN.to_owned(), None)
.await
.unwrap();
let refresh_token = repo
.compat_refresh_token()
.add(
&mut rng,
&clock,
&session,
&access_token,
REFRESH_TOKEN.to_owned(),
)
.await
.unwrap();
assert_eq!(refresh_token.session_id, session.id);
assert_eq!(refresh_token.access_token_id, access_token.id);
assert_eq!(refresh_token.token, REFRESH_TOKEN);
assert!(refresh_token.is_valid());
assert!(!refresh_token.is_consumed());
// Look it up by ID and check everything matches
let refresh_token_lookup = repo
.compat_refresh_token()
.lookup(refresh_token.id)
.await
.unwrap()
.expect("refresh token not found");
assert_eq!(refresh_token_lookup.id, refresh_token.id);
assert_eq!(refresh_token_lookup.session_id, session.id);
assert_eq!(refresh_token_lookup.access_token_id, access_token.id);
assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN);
assert!(refresh_token_lookup.is_valid());
assert!(!refresh_token_lookup.is_consumed());
// Look it up by token and check everything matches
let refresh_token_lookup = repo
.compat_refresh_token()
.find_by_token(REFRESH_TOKEN)
.await
.unwrap()
.expect("refresh token not found");
assert_eq!(refresh_token_lookup.id, refresh_token.id);
assert_eq!(refresh_token_lookup.session_id, session.id);
assert_eq!(refresh_token_lookup.access_token_id, access_token.id);
assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN);
assert!(refresh_token_lookup.is_valid());
assert!(!refresh_token_lookup.is_consumed());
// Consume it
let refresh_token = repo
.compat_refresh_token()
.consume(&clock, refresh_token)
.await
.unwrap();
assert!(!refresh_token.is_valid());
assert!(refresh_token.is_consumed());
// Reload it and check again
let refresh_token_lookup = repo
.compat_refresh_token()
.find_by_token(REFRESH_TOKEN)
.await
.unwrap()
.expect("refresh token not found");
assert!(!refresh_token_lookup.is_valid());
assert!(refresh_token_lookup.is_consumed());
// Consuming it again should not work
assert!(repo
.compat_refresh_token()
.consume(&clock, refresh_token)
.await
.is_err());
repo.save().await.unwrap();
}
}

View File

@ -0,0 +1,230 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{
CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession,
};
use mas_storage::{compat::CompatRefreshTokenRepository, Clock};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{tracing::ExecuteExt, DatabaseError, LookupResultExt};
pub struct PgCompatRefreshTokenRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgCompatRefreshTokenRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
struct CompatRefreshTokenLookup {
compat_refresh_token_id: Uuid,
refresh_token: String,
created_at: DateTime<Utc>,
consumed_at: Option<DateTime<Utc>>,
compat_access_token_id: Uuid,
compat_session_id: Uuid,
}
impl From<CompatRefreshTokenLookup> for CompatRefreshToken {
fn from(value: CompatRefreshTokenLookup) -> Self {
let state = match value.consumed_at {
Some(consumed_at) => CompatRefreshTokenState::Consumed { consumed_at },
None => CompatRefreshTokenState::Valid,
};
Self {
id: value.compat_refresh_token_id.into(),
state,
session_id: value.compat_session_id.into(),
token: value.refresh_token,
created_at: value.created_at,
access_token_id: value.compat_access_token_id.into(),
}
}
}
#[async_trait]
impl<'c> CompatRefreshTokenRepository for PgCompatRefreshTokenRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.compat_refresh_token.lookup",
skip_all,
fields(
db.statement,
compat_refresh_token.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatRefreshToken>, Self::Error> {
let res = sqlx::query_as!(
CompatRefreshTokenLookup,
r#"
SELECT compat_refresh_token_id
, refresh_token
, created_at
, consumed_at
, compat_session_id
, compat_access_token_id
FROM compat_refresh_tokens
WHERE compat_refresh_token_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.compat_refresh_token.find_by_token",
skip_all,
fields(
db.statement,
),
err,
)]
async fn find_by_token(
&mut self,
refresh_token: &str,
) -> Result<Option<CompatRefreshToken>, Self::Error> {
let res = sqlx::query_as!(
CompatRefreshTokenLookup,
r#"
SELECT compat_refresh_token_id
, refresh_token
, created_at
, consumed_at
, compat_session_id
, compat_access_token_id
FROM compat_refresh_tokens
WHERE refresh_token = $1
"#,
refresh_token,
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.compat_refresh_token.add",
skip_all,
fields(
db.statement,
compat_refresh_token.id,
%compat_session.id,
user.id = %compat_session.user_id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
compat_session: &CompatSession,
compat_access_token: &CompatAccessToken,
token: String,
) -> Result<CompatRefreshToken, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO compat_refresh_tokens
(compat_refresh_token_id, compat_session_id,
compat_access_token_id, refresh_token, created_at)
VALUES ($1, $2, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(compat_session.id),
Uuid::from(compat_access_token.id),
token,
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(CompatRefreshToken {
id,
state: CompatRefreshTokenState::default(),
session_id: compat_session.id,
access_token_id: compat_access_token.id,
token,
created_at,
})
}
#[tracing::instrument(
name = "db.compat_refresh_token.consume",
skip_all,
fields(
db.statement,
%compat_refresh_token.id,
compat_session.id = %compat_refresh_token.session_id,
),
err,
)]
async fn consume(
&mut self,
clock: &Clock,
compat_refresh_token: CompatRefreshToken,
) -> Result<CompatRefreshToken, Self::Error> {
let consumed_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE compat_refresh_tokens
SET consumed_at = $2
WHERE compat_refresh_token_id = $1
"#,
Uuid::from(compat_refresh_token.id),
consumed_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
let compat_refresh_token = compat_refresh_token
.consume(consumed_at)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(compat_refresh_token)
}
}

View File

@ -0,0 +1,195 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{CompatSession, CompatSessionState, Device, User};
use mas_storage::{compat::CompatSessionRepository, Clock};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, LookupResultExt};
pub struct PgCompatSessionRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgCompatSessionRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
struct CompatSessionLookup {
compat_session_id: Uuid,
device_id: String,
user_id: Uuid,
created_at: DateTime<Utc>,
finished_at: Option<DateTime<Utc>>,
}
impl TryFrom<CompatSessionLookup> for CompatSession {
type Error = DatabaseInconsistencyError;
fn try_from(value: CompatSessionLookup) -> Result<Self, Self::Error> {
let id = value.compat_session_id.into();
let device = Device::try_from(value.device_id).map_err(|e| {
DatabaseInconsistencyError::on("compat_sessions")
.column("device_id")
.row(id)
.source(e)
})?;
let state = match value.finished_at {
None => CompatSessionState::Valid,
Some(finished_at) => CompatSessionState::Finished { finished_at },
};
let session = CompatSession {
id,
state,
user_id: value.user_id.into(),
device,
created_at: value.created_at,
};
Ok(session)
}
}
#[async_trait]
impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.compat_session.lookup",
skip_all,
fields(
db.statement,
compat_session.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSession>, Self::Error> {
let res = sqlx::query_as!(
CompatSessionLookup,
r#"
SELECT compat_session_id
, device_id
, user_id
, created_at
, finished_at
FROM compat_sessions
WHERE compat_session_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
name = "db.compat_session.add",
skip_all,
fields(
db.statement,
compat_session.id,
%user.id,
%user.username,
compat_session.device.id = device.as_str(),
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
user: &User,
device: Device,
) -> Result<CompatSession, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("compat_session.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at)
VALUES ($1, $2, $3, $4)
"#,
Uuid::from(id),
Uuid::from(user.id),
device.as_str(),
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(CompatSession {
id,
state: CompatSessionState::default(),
user_id: user.id,
device,
created_at,
})
}
#[tracing::instrument(
name = "db.compat_session.finish",
skip_all,
fields(
db.statement,
%compat_session.id,
user.id = %compat_session.user_id,
compat_session.device.id = compat_session.device.as_str(),
),
err,
)]
async fn finish(
&mut self,
clock: &Clock,
compat_session: CompatSession,
) -> Result<CompatSession, Self::Error> {
let finished_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE compat_sessions cs
SET finished_at = $2
WHERE compat_session_id = $1
"#,
Uuid::from(compat_session.id),
finished_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
let compat_session = compat_session
.finish(finished_at)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(compat_session)
}
}

View File

@ -0,0 +1,342 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{CompatSession, CompatSsoLogin, CompatSsoLoginState, User};
use mas_storage::{compat::CompatSsoLoginRepository, Clock, Page, Pagination};
use rand::RngCore;
use sqlx::{PgConnection, QueryBuilder};
use ulid::Ulid;
use url::Url;
use uuid::Uuid;
use crate::{
pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError,
LookupResultExt,
};
pub struct PgCompatSsoLoginRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgCompatSsoLoginRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
#[derive(sqlx::FromRow)]
struct CompatSsoLoginLookup {
compat_sso_login_id: Uuid,
login_token: String,
redirect_uri: String,
created_at: DateTime<Utc>,
fulfilled_at: Option<DateTime<Utc>>,
exchanged_at: Option<DateTime<Utc>>,
compat_session_id: Option<Uuid>,
}
impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
type Error = DatabaseInconsistencyError;
fn try_from(res: CompatSsoLoginLookup) -> Result<Self, Self::Error> {
let id = res.compat_sso_login_id.into();
let redirect_uri = Url::parse(&res.redirect_uri).map_err(|e| {
DatabaseInconsistencyError::on("compat_sso_logins")
.column("redirect_uri")
.row(id)
.source(e)
})?;
let state = match (res.fulfilled_at, res.exchanged_at, res.compat_session_id) {
(None, None, None) => CompatSsoLoginState::Pending,
(Some(fulfilled_at), None, Some(session_id)) => CompatSsoLoginState::Fulfilled {
fulfilled_at,
session_id: session_id.into(),
},
(Some(fulfilled_at), Some(exchanged_at), Some(session_id)) => {
CompatSsoLoginState::Exchanged {
fulfilled_at,
exchanged_at,
session_id: session_id.into(),
}
}
_ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
};
Ok(CompatSsoLogin {
id,
login_token: res.login_token,
redirect_uri,
created_at: res.created_at,
state,
})
}
}
#[async_trait]
impl<'c> CompatSsoLoginRepository for PgCompatSsoLoginRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.compat_sso_login.lookup",
skip_all,
fields(
db.statement,
compat_sso_login.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSsoLogin>, Self::Error> {
let res = sqlx::query_as!(
CompatSsoLoginLookup,
r#"
SELECT compat_sso_login_id
, login_token
, redirect_uri
, created_at
, fulfilled_at
, exchanged_at
, compat_session_id
FROM compat_sso_logins
WHERE compat_sso_login_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
name = "db.compat_sso_login.find_by_token",
skip_all,
fields(
db.statement,
),
err,
)]
async fn find_by_token(
&mut self,
login_token: &str,
) -> Result<Option<CompatSsoLogin>, Self::Error> {
let res = sqlx::query_as!(
CompatSsoLoginLookup,
r#"
SELECT compat_sso_login_id
, login_token
, redirect_uri
, created_at
, fulfilled_at
, exchanged_at
, compat_session_id
FROM compat_sso_logins
WHERE login_token = $1
"#,
login_token,
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
name = "db.compat_sso_login.add",
skip_all,
fields(
db.statement,
compat_sso_login.id,
compat_sso_login.redirect_uri = %redirect_uri,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
login_token: String,
redirect_uri: Url,
) -> Result<CompatSsoLogin, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO compat_sso_logins
(compat_sso_login_id, login_token, redirect_uri, created_at)
VALUES ($1, $2, $3, $4)
"#,
Uuid::from(id),
&login_token,
redirect_uri.as_str(),
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(CompatSsoLogin {
id,
login_token,
redirect_uri,
created_at,
state: CompatSsoLoginState::default(),
})
}
#[tracing::instrument(
name = "db.compat_sso_login.fulfill",
skip_all,
fields(
db.statement,
%compat_sso_login.id,
%compat_session.id,
compat_session.device.id = compat_session.device.as_str(),
user.id = %compat_session.user_id,
),
err,
)]
async fn fulfill(
&mut self,
clock: &Clock,
compat_sso_login: CompatSsoLogin,
compat_session: &CompatSession,
) -> Result<CompatSsoLogin, Self::Error> {
let fulfilled_at = clock.now();
let compat_sso_login = compat_sso_login
.fulfill(fulfilled_at, compat_session)
.map_err(DatabaseError::to_invalid_operation)?;
let res = sqlx::query!(
r#"
UPDATE compat_sso_logins
SET
compat_session_id = $2,
fulfilled_at = $3
WHERE
compat_sso_login_id = $1
"#,
Uuid::from(compat_sso_login.id),
Uuid::from(compat_session.id),
fulfilled_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
Ok(compat_sso_login)
}
#[tracing::instrument(
name = "db.compat_sso_login.exchange",
skip_all,
fields(
db.statement,
%compat_sso_login.id,
),
err,
)]
async fn exchange(
&mut self,
clock: &Clock,
compat_sso_login: CompatSsoLogin,
) -> Result<CompatSsoLogin, Self::Error> {
let exchanged_at = clock.now();
let compat_sso_login = compat_sso_login
.exchange(exchanged_at)
.map_err(DatabaseError::to_invalid_operation)?;
let res = sqlx::query!(
r#"
UPDATE compat_sso_logins
SET
exchanged_at = $2
WHERE
compat_sso_login_id = $1
"#,
Uuid::from(compat_sso_login.id),
exchanged_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
Ok(compat_sso_login)
}
#[tracing::instrument(
name = "db.compat_sso_login.list_paginated",
skip_all,
fields(
db.statement,
%user.id,
%user.username,
),
err
)]
async fn list_paginated(
&mut self,
user: &User,
pagination: Pagination,
) -> Result<Page<CompatSsoLogin>, Self::Error> {
let mut query = QueryBuilder::new(
r#"
SELECT cl.compat_sso_login_id
, cl.login_token
, cl.redirect_uri
, cl.created_at
, cl.fulfilled_at
, cl.exchanged_at
, cl.compat_session_id
FROM compat_sso_logins cl
INNER JOIN compat_sessions ON compat_session_id
"#,
);
query
.push(" WHERE user_id = ")
.push_bind(Uuid::from(user.id))
.generate_pagination("cl.compat_sso_login_id", pagination);
let edges: Vec<CompatSsoLoginLookup> = query
.build_query_as()
.traced()
.fetch_all(&mut *self.conn)
.await?;
let page = pagination
.process(edges)
.try_map(CompatSsoLogin::try_from)?;
Ok(page)
}
}

View File

@ -0,0 +1,170 @@
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Interactions with the database
#![forbid(unsafe_code)]
#![deny(
clippy::all,
clippy::str_to_string,
clippy::future_not_send,
rustdoc::broken_intra_doc_links
)]
#![warn(clippy::pedantic)]
#![allow(
clippy::missing_errors_doc,
clippy::missing_panics_doc,
clippy::module_name_repetitions
)]
use sqlx::{migrate::Migrator, postgres::PgQueryResult};
use thiserror::Error;
use ulid::Ulid;
trait LookupResultExt {
type Output;
/// Transform a [`Result`] from a sqlx query to transform "not found" errors
/// into [`None`]
fn to_option(self) -> Result<Option<Self::Output>, sqlx::Error>;
}
impl<T> LookupResultExt for Result<T, sqlx::Error> {
type Output = T;
fn to_option(self) -> Result<Option<Self::Output>, sqlx::Error> {
match self {
Ok(v) => Ok(Some(v)),
Err(sqlx::Error::RowNotFound) => Ok(None),
Err(e) => Err(e),
}
}
}
/// Generic error when interacting with the database
#[derive(Debug, Error)]
#[error(transparent)]
pub enum DatabaseError {
/// An error which came from the database itself
Driver(#[from] sqlx::Error),
/// An error which occured while converting the data from the database
Inconsistency(#[from] DatabaseInconsistencyError),
/// An error which happened because the requested database operation is
/// invalid
#[error("Invalid database operation")]
InvalidOperation {
#[source]
source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
},
/// An error which happens when an operation affects not enough or too many
/// rows
#[error("Expected {expected} rows to be affected, but {actual} rows were affected")]
RowsAffected { expected: u64, actual: u64 },
}
impl DatabaseError {
pub(crate) fn ensure_affected_rows(
result: &PgQueryResult,
expected: u64,
) -> Result<(), DatabaseError> {
let actual = result.rows_affected();
if actual == expected {
Ok(())
} else {
Err(DatabaseError::RowsAffected { expected, actual })
}
}
pub(crate) fn to_invalid_operation<E: std::error::Error + Send + Sync + 'static>(e: E) -> Self {
Self::InvalidOperation {
source: Some(Box::new(e)),
}
}
pub(crate) const fn invalid_operation() -> Self {
Self::InvalidOperation { source: None }
}
}
#[derive(Debug, Error)]
pub struct DatabaseInconsistencyError {
table: &'static str,
column: Option<&'static str>,
row: Option<Ulid>,
#[source]
source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
}
impl std::fmt::Display for DatabaseInconsistencyError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Database inconsistency on table {}", self.table)?;
if let Some(column) = self.column {
write!(f, " column {column}")?;
}
if let Some(row) = self.row {
write!(f, " row {row}")?;
}
Ok(())
}
}
impl DatabaseInconsistencyError {
#[must_use]
pub(crate) const fn on(table: &'static str) -> Self {
Self {
table,
column: None,
row: None,
source: None,
}
}
#[must_use]
pub(crate) const fn column(mut self, column: &'static str) -> Self {
self.column = Some(column);
self
}
#[must_use]
pub(crate) const fn row(mut self, row: Ulid) -> Self {
self.row = Some(row);
self
}
pub(crate) fn source<E: std::error::Error + Send + Sync + 'static>(
mut self,
source: E,
) -> Self {
self.source = Some(Box::new(source));
self
}
}
pub mod compat;
pub mod oauth2;
pub(crate) mod pagination;
pub(crate) mod repository;
pub(crate) mod tracing;
pub mod upstream_oauth2;
pub mod user;
pub use self::repository::PgRepository;
/// Embedded migrations, allowing them to run on startup
pub static MIGRATOR: Migrator = sqlx::migrate!();

View File

@ -0,0 +1,223 @@
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};
use mas_data_model::{AccessToken, AccessTokenState, Session};
use mas_storage::{oauth2::OAuth2AccessTokenRepository, Clock};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{tracing::ExecuteExt, DatabaseError, LookupResultExt};
pub struct PgOAuth2AccessTokenRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgOAuth2AccessTokenRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
struct OAuth2AccessTokenLookup {
oauth2_access_token_id: Uuid,
oauth2_session_id: Uuid,
access_token: String,
created_at: DateTime<Utc>,
expires_at: DateTime<Utc>,
revoked_at: Option<DateTime<Utc>>,
}
impl From<OAuth2AccessTokenLookup> for AccessToken {
fn from(value: OAuth2AccessTokenLookup) -> Self {
let state = match value.revoked_at {
None => AccessTokenState::Valid,
Some(revoked_at) => AccessTokenState::Revoked { revoked_at },
};
Self {
id: value.oauth2_access_token_id.into(),
state,
session_id: value.oauth2_session_id.into(),
access_token: value.access_token,
created_at: value.created_at,
expires_at: value.expires_at,
}
}
}
#[async_trait]
impl<'c> OAuth2AccessTokenRepository for PgOAuth2AccessTokenRepository<'c> {
type Error = DatabaseError;
async fn lookup(&mut self, id: Ulid) -> Result<Option<AccessToken>, Self::Error> {
let res = sqlx::query_as!(
OAuth2AccessTokenLookup,
r#"
SELECT oauth2_access_token_id
, access_token
, created_at
, expires_at
, revoked_at
, oauth2_session_id
FROM oauth2_access_tokens
WHERE oauth2_access_token_id = $1
"#,
Uuid::from(id),
)
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.oauth2_access_token.find_by_token",
skip_all,
fields(
db.statement,
),
err,
)]
async fn find_by_token(
&mut self,
access_token: &str,
) -> Result<Option<AccessToken>, Self::Error> {
let res = sqlx::query_as!(
OAuth2AccessTokenLookup,
r#"
SELECT oauth2_access_token_id
, access_token
, created_at
, expires_at
, revoked_at
, oauth2_session_id
FROM oauth2_access_tokens
WHERE access_token = $1
"#,
access_token,
)
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.oauth2_access_token.add",
skip_all,
fields(
db.statement,
%session.id,
user_session.id = %session.user_session_id,
client.id = %session.client_id,
access_token.id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
session: &Session,
access_token: String,
expires_after: Duration,
) -> Result<AccessToken, Self::Error> {
let created_at = clock.now();
let expires_at = created_at + expires_after;
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("access_token.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO oauth2_access_tokens
(oauth2_access_token_id, oauth2_session_id, access_token, created_at, expires_at)
VALUES
($1, $2, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(session.id),
&access_token,
created_at,
expires_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(AccessToken {
id,
state: AccessTokenState::default(),
access_token,
session_id: session.id,
created_at,
expires_at,
})
}
async fn revoke(
&mut self,
clock: &Clock,
access_token: AccessToken,
) -> Result<AccessToken, Self::Error> {
let revoked_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE oauth2_access_tokens
SET revoked_at = $2
WHERE oauth2_access_token_id = $1
"#,
Uuid::from(access_token.id),
revoked_at,
)
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
access_token
.revoke(revoked_at)
.map_err(DatabaseError::to_invalid_operation)
}
async fn cleanup_expired(&mut self, clock: &Clock) -> Result<usize, Self::Error> {
// Cleanup token which expired more than 15 minutes ago
let threshold = clock.now() - Duration::minutes(15);
let res = sqlx::query!(
r#"
DELETE FROM oauth2_access_tokens
WHERE expires_at < $1
"#,
threshold,
)
.execute(&mut *self.conn)
.await?;
Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
}
}

View File

@ -0,0 +1,510 @@
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::num::NonZeroU32;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{
AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Pkce, Session,
};
use mas_iana::oauth::PkceCodeChallengeMethod;
use mas_storage::{oauth2::OAuth2AuthorizationGrantRepository, Clock};
use oauth2_types::{requests::ResponseMode, scope::Scope};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use url::Url;
use uuid::Uuid;
use crate::{tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, LookupResultExt};
pub struct PgOAuth2AuthorizationGrantRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgOAuth2AuthorizationGrantRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
#[allow(clippy::struct_excessive_bools)]
struct GrantLookup {
oauth2_authorization_grant_id: Uuid,
created_at: DateTime<Utc>,
cancelled_at: Option<DateTime<Utc>>,
fulfilled_at: Option<DateTime<Utc>>,
exchanged_at: Option<DateTime<Utc>>,
scope: String,
state: Option<String>,
nonce: Option<String>,
redirect_uri: String,
response_mode: String,
max_age: Option<i32>,
response_type_code: bool,
response_type_id_token: bool,
authorization_code: Option<String>,
code_challenge: Option<String>,
code_challenge_method: Option<String>,
requires_consent: bool,
oauth2_client_id: Uuid,
oauth2_session_id: Option<Uuid>,
}
impl TryFrom<GrantLookup> for AuthorizationGrant {
type Error = DatabaseInconsistencyError;
#[allow(clippy::too_many_lines)]
fn try_from(value: GrantLookup) -> Result<Self, Self::Error> {
let id = value.oauth2_authorization_grant_id.into();
let scope: Scope = value.scope.parse().map_err(|e| {
DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("scope")
.row(id)
.source(e)
})?;
let stage = match (
value.fulfilled_at,
value.exchanged_at,
value.cancelled_at,
value.oauth2_session_id,
) {
(None, None, None, None) => AuthorizationGrantStage::Pending,
(Some(fulfilled_at), None, None, Some(session_id)) => {
AuthorizationGrantStage::Fulfilled {
session_id: session_id.into(),
fulfilled_at,
}
}
(Some(fulfilled_at), Some(exchanged_at), None, Some(session_id)) => {
AuthorizationGrantStage::Exchanged {
session_id: session_id.into(),
fulfilled_at,
exchanged_at,
}
}
(None, None, Some(cancelled_at), None) => {
AuthorizationGrantStage::Cancelled { cancelled_at }
}
_ => {
return Err(
DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("stage")
.row(id),
);
}
};
let pkce = match (value.code_challenge, value.code_challenge_method) {
(Some(challenge), Some(challenge_method)) if challenge_method == "plain" => {
Some(Pkce {
challenge_method: PkceCodeChallengeMethod::Plain,
challenge,
})
}
(Some(challenge), Some(challenge_method)) if challenge_method == "S256" => Some(Pkce {
challenge_method: PkceCodeChallengeMethod::S256,
challenge,
}),
(None, None) => None,
_ => {
return Err(
DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("code_challenge_method")
.row(id),
);
}
};
let code: Option<AuthorizationCode> =
match (value.response_type_code, value.authorization_code, pkce) {
(false, None, None) => None,
(true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }),
_ => {
return Err(
DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("authorization_code")
.row(id),
);
}
};
let redirect_uri = value.redirect_uri.parse().map_err(|e| {
DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("redirect_uri")
.row(id)
.source(e)
})?;
let response_mode = value.response_mode.parse().map_err(|e| {
DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("response_mode")
.row(id)
.source(e)
})?;
let max_age = value
.max_age
.map(u32::try_from)
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("max_age")
.row(id)
.source(e)
})?
.map(NonZeroU32::try_from)
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("max_age")
.row(id)
.source(e)
})?;
Ok(AuthorizationGrant {
id,
stage,
client_id: value.oauth2_client_id.into(),
code,
scope,
state: value.state,
nonce: value.nonce,
max_age,
response_mode,
redirect_uri,
created_at: value.created_at,
response_type_id_token: value.response_type_id_token,
requires_consent: value.requires_consent,
})
}
}
#[async_trait]
impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.oauth2_authorization_grant.add",
skip_all,
fields(
db.statement,
grant.id,
grant.scope = %scope,
%client.id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
client: &Client,
redirect_uri: Url,
scope: Scope,
code: Option<AuthorizationCode>,
state: Option<String>,
nonce: Option<String>,
max_age: Option<NonZeroU32>,
response_mode: ResponseMode,
response_type_id_token: bool,
requires_consent: bool,
) -> Result<AuthorizationGrant, Self::Error> {
let code_challenge = code
.as_ref()
.and_then(|c| c.pkce.as_ref())
.map(|p| &p.challenge);
let code_challenge_method = code
.as_ref()
.and_then(|c| c.pkce.as_ref())
.map(|p| p.challenge_method.to_string());
// TODO: this conversion is a bit ugly
let max_age_i32 = max_age.map(|x| i32::try_from(u32::from(x)).unwrap_or(i32::MAX));
let code_str = code.as_ref().map(|c| &c.code);
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("grant.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO oauth2_authorization_grants (
oauth2_authorization_grant_id,
oauth2_client_id,
redirect_uri,
scope,
state,
nonce,
max_age,
response_mode,
code_challenge,
code_challenge_method,
response_type_code,
response_type_id_token,
authorization_code,
requires_consent,
created_at
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
"#,
Uuid::from(id),
Uuid::from(client.id),
redirect_uri.to_string(),
scope.to_string(),
state,
nonce,
max_age_i32,
response_mode.to_string(),
code_challenge,
code_challenge_method,
code.is_some(),
response_type_id_token,
code_str,
requires_consent,
created_at,
)
.execute(&mut *self.conn)
.await?;
Ok(AuthorizationGrant {
id,
stage: AuthorizationGrantStage::Pending,
code,
redirect_uri,
client_id: client.id,
scope,
state,
nonce,
max_age,
response_mode,
created_at,
response_type_id_token,
requires_consent,
})
}
#[tracing::instrument(
name = "db.oauth2_authorization_grant.lookup",
skip_all,
fields(
db.statement,
grant.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<AuthorizationGrant>, Self::Error> {
let res = sqlx::query_as!(
GrantLookup,
r#"
SELECT oauth2_authorization_grant_id
, created_at
, cancelled_at
, fulfilled_at
, exchanged_at
, scope
, state
, redirect_uri
, response_mode
, nonce
, max_age
, oauth2_client_id
, authorization_code
, response_type_code
, response_type_id_token
, code_challenge
, code_challenge_method
, requires_consent
, oauth2_session_id
FROM
oauth2_authorization_grants
WHERE oauth2_authorization_grant_id = $1
"#,
Uuid::from(id),
)
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
name = "db.oauth2_authorization_grant.find_by_code",
skip_all,
fields(
db.statement,
),
err,
)]
async fn find_by_code(
&mut self,
code: &str,
) -> Result<Option<AuthorizationGrant>, Self::Error> {
let res = sqlx::query_as!(
GrantLookup,
r#"
SELECT oauth2_authorization_grant_id
, created_at
, cancelled_at
, fulfilled_at
, exchanged_at
, scope
, state
, redirect_uri
, response_mode
, nonce
, max_age
, oauth2_client_id
, authorization_code
, response_type_code
, response_type_id_token
, code_challenge
, code_challenge_method
, requires_consent
, oauth2_session_id
FROM
oauth2_authorization_grants
WHERE authorization_code = $1
"#,
code,
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
name = "db.oauth2_authorization_grant.fulfill",
skip_all,
fields(
db.statement,
%grant.id,
client.id = %grant.client_id,
%session.id,
user_session.id = %session.user_session_id,
),
err,
)]
async fn fulfill(
&mut self,
clock: &Clock,
session: &Session,
grant: AuthorizationGrant,
) -> Result<AuthorizationGrant, Self::Error> {
let fulfilled_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE oauth2_authorization_grants
SET fulfilled_at = $2
, oauth2_session_id = $3
WHERE oauth2_authorization_grant_id = $1
"#,
Uuid::from(grant.id),
fulfilled_at,
Uuid::from(session.id),
)
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
// XXX: check affected rows & new methods
let grant = grant
.fulfill(fulfilled_at, session)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(grant)
}
#[tracing::instrument(
name = "db.oauth2_authorization_grant.exchange",
skip_all,
fields(
db.statement,
%grant.id,
client.id = %grant.client_id,
),
err,
)]
async fn exchange(
&mut self,
clock: &Clock,
grant: AuthorizationGrant,
) -> Result<AuthorizationGrant, Self::Error> {
let exchanged_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE oauth2_authorization_grants
SET exchanged_at = $2
WHERE oauth2_authorization_grant_id = $1
"#,
Uuid::from(grant.id),
exchanged_at,
)
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
let grant = grant
.exchange(exchanged_at)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(grant)
}
#[tracing::instrument(
name = "db.oauth2_authorization_grant.give_consent",
skip_all,
fields(
db.statement,
%grant.id,
client.id = %grant.client_id,
),
err,
)]
async fn give_consent(
&mut self,
mut grant: AuthorizationGrant,
) -> Result<AuthorizationGrant, Self::Error> {
sqlx::query!(
r#"
UPDATE oauth2_authorization_grants AS og
SET
requires_consent = 'f'
WHERE
og.oauth2_authorization_grant_id = $1
"#,
Uuid::from(grant.id),
)
.execute(&mut *self.conn)
.await?;
grant.requires_consent = false;
Ok(grant)
}
}

View File

@ -0,0 +1,745 @@
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
collections::{BTreeMap, BTreeSet},
str::FromStr,
string::ToString,
};
use async_trait::async_trait;
use mas_data_model::{Client, JwksOrJwksUri, User};
use mas_iana::{
jose::JsonWebSignatureAlg,
oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod},
};
use mas_jose::jwk::PublicJsonWebKeySet;
use mas_storage::{oauth2::OAuth2ClientRepository, Clock};
use oauth2_types::{
requests::GrantType,
scope::{Scope, ScopeToken},
};
use rand::{Rng, RngCore};
use sqlx::PgConnection;
use tracing::{info_span, Instrument};
use ulid::Ulid;
use url::Url;
use uuid::Uuid;
use crate::{tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, LookupResultExt};
pub struct PgOAuth2ClientRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgOAuth2ClientRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
// XXX: response_types & contacts
#[derive(Debug)]
struct OAuth2ClientLookup {
oauth2_client_id: Uuid,
encrypted_client_secret: Option<String>,
redirect_uris: Vec<String>,
// response_types: Vec<String>,
grant_type_authorization_code: bool,
grant_type_refresh_token: bool,
// contacts: Vec<String>,
client_name: Option<String>,
logo_uri: Option<String>,
client_uri: Option<String>,
policy_uri: Option<String>,
tos_uri: Option<String>,
jwks_uri: Option<String>,
jwks: Option<serde_json::Value>,
id_token_signed_response_alg: Option<String>,
userinfo_signed_response_alg: Option<String>,
token_endpoint_auth_method: Option<String>,
token_endpoint_auth_signing_alg: Option<String>,
initiate_login_uri: Option<String>,
}
impl TryInto<Client> for OAuth2ClientLookup {
type Error = DatabaseInconsistencyError;
#[allow(clippy::too_many_lines)] // TODO: refactor some of the field parsing
fn try_into(self) -> Result<Client, Self::Error> {
let id = Ulid::from(self.oauth2_client_id);
let redirect_uris: Result<Vec<Url>, _> =
self.redirect_uris.iter().map(|s| s.parse()).collect();
let redirect_uris = redirect_uris.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("redirect_uris")
.row(id)
.source(e)
})?;
let response_types = vec![
OAuthAuthorizationEndpointResponseType::Code,
OAuthAuthorizationEndpointResponseType::IdToken,
OAuthAuthorizationEndpointResponseType::None,
];
/* XXX
let response_types: Result<Vec<OAuthAuthorizationEndpointResponseType>, _> =
self.response_types.iter().map(|s| s.parse()).collect();
let response_types = response_types.map_err(|source| ClientFetchError::ParseField {
field: "response_types",
source,
})?;
*/
let mut grant_types = Vec::new();
if self.grant_type_authorization_code {
grant_types.push(GrantType::AuthorizationCode);
}
if self.grant_type_refresh_token {
grant_types.push(GrantType::RefreshToken);
}
let logo_uri = self.logo_uri.map(|s| s.parse()).transpose().map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("logo_uri")
.row(id)
.source(e)
})?;
let client_uri = self
.client_uri
.map(|s| s.parse())
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("client_uri")
.row(id)
.source(e)
})?;
let policy_uri = self
.policy_uri
.map(|s| s.parse())
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("policy_uri")
.row(id)
.source(e)
})?;
let tos_uri = self.tos_uri.map(|s| s.parse()).transpose().map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("tos_uri")
.row(id)
.source(e)
})?;
let id_token_signed_response_alg = self
.id_token_signed_response_alg
.map(|s| s.parse())
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("id_token_signed_response_alg")
.row(id)
.source(e)
})?;
let userinfo_signed_response_alg = self
.userinfo_signed_response_alg
.map(|s| s.parse())
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("userinfo_signed_response_alg")
.row(id)
.source(e)
})?;
let token_endpoint_auth_method = self
.token_endpoint_auth_method
.map(|s| s.parse())
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("token_endpoint_auth_method")
.row(id)
.source(e)
})?;
let token_endpoint_auth_signing_alg = self
.token_endpoint_auth_signing_alg
.map(|s| s.parse())
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("token_endpoint_auth_signing_alg")
.row(id)
.source(e)
})?;
let initiate_login_uri = self
.initiate_login_uri
.map(|s| s.parse())
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("initiate_login_uri")
.row(id)
.source(e)
})?;
let jwks = match (self.jwks, self.jwks_uri) {
(None, None) => None,
(Some(jwks), None) => {
let jwks = serde_json::from_value(jwks).map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("jwks")
.row(id)
.source(e)
})?;
Some(JwksOrJwksUri::Jwks(jwks))
}
(None, Some(jwks_uri)) => {
let jwks_uri = jwks_uri.parse().map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("jwks_uri")
.row(id)
.source(e)
})?;
Some(JwksOrJwksUri::JwksUri(jwks_uri))
}
_ => {
return Err(DatabaseInconsistencyError::on("oauth2_clients")
.column("jwks(_uri)")
.row(id))
}
};
Ok(Client {
id,
client_id: id.to_string(),
encrypted_client_secret: self.encrypted_client_secret,
redirect_uris,
response_types,
grant_types,
// contacts: self.contacts,
contacts: vec![],
client_name: self.client_name,
logo_uri,
client_uri,
policy_uri,
tos_uri,
jwks,
id_token_signed_response_alg,
userinfo_signed_response_alg,
token_endpoint_auth_method,
token_endpoint_auth_signing_alg,
initiate_login_uri,
})
}
}
#[async_trait]
impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.oauth2_client.lookup",
skip_all,
fields(
db.statement,
oauth2_client.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<Client>, Self::Error> {
let res = sqlx::query_as!(
OAuth2ClientLookup,
r#"
SELECT oauth2_client_id
, encrypted_client_secret
, ARRAY(
SELECT redirect_uri
FROM oauth2_client_redirect_uris r
WHERE r.oauth2_client_id = c.oauth2_client_id
) AS "redirect_uris!"
, grant_type_authorization_code
, grant_type_refresh_token
, client_name
, logo_uri
, client_uri
, policy_uri
, tos_uri
, jwks_uri
, jwks
, id_token_signed_response_alg
, userinfo_signed_response_alg
, token_endpoint_auth_method
, token_endpoint_auth_signing_alg
, initiate_login_uri
FROM oauth2_clients c
WHERE oauth2_client_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
name = "db.oauth2_client.load_batch",
skip_all,
fields(
db.statement,
),
err,
)]
async fn load_batch(
&mut self,
ids: BTreeSet<Ulid>,
) -> Result<BTreeMap<Ulid, Client>, Self::Error> {
let ids: Vec<Uuid> = ids.into_iter().map(Uuid::from).collect();
let res = sqlx::query_as!(
OAuth2ClientLookup,
r#"
SELECT oauth2_client_id
, encrypted_client_secret
, ARRAY(
SELECT redirect_uri
FROM oauth2_client_redirect_uris r
WHERE r.oauth2_client_id = c.oauth2_client_id
) AS "redirect_uris!"
, grant_type_authorization_code
, grant_type_refresh_token
, client_name
, logo_uri
, client_uri
, policy_uri
, tos_uri
, jwks_uri
, jwks
, id_token_signed_response_alg
, userinfo_signed_response_alg
, token_endpoint_auth_method
, token_endpoint_auth_signing_alg
, initiate_login_uri
FROM oauth2_clients c
WHERE oauth2_client_id = ANY($1::uuid[])
"#,
&ids,
)
.traced()
.fetch_all(&mut *self.conn)
.await?;
res.into_iter()
.map(|r| {
r.try_into()
.map(|c: Client| (c.id, c))
.map_err(DatabaseError::from)
})
.collect()
}
#[tracing::instrument(
name = "db.oauth2_client.add",
skip_all,
fields(
db.statement,
client.id,
client.name = client_name
),
err,
)]
#[allow(clippy::too_many_lines)]
async fn add(
&mut self,
mut rng: &mut (dyn RngCore + Send),
clock: &Clock,
redirect_uris: Vec<Url>,
encrypted_client_secret: Option<String>,
grant_types: Vec<GrantType>,
contacts: Vec<String>,
client_name: Option<String>,
logo_uri: Option<Url>,
client_uri: Option<Url>,
policy_uri: Option<Url>,
tos_uri: Option<Url>,
jwks_uri: Option<Url>,
jwks: Option<PublicJsonWebKeySet>,
id_token_signed_response_alg: Option<JsonWebSignatureAlg>,
userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
token_endpoint_auth_method: Option<OAuthClientAuthenticationMethod>,
token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
initiate_login_uri: Option<Url>,
) -> Result<Client, Self::Error> {
let now = clock.now();
let id = Ulid::from_datetime_with_source(now.into(), rng);
tracing::Span::current().record("client.id", tracing::field::display(id));
let jwks_json = jwks
.as_ref()
.map(serde_json::to_value)
.transpose()
.map_err(DatabaseError::to_invalid_operation)?;
sqlx::query!(
r#"
INSERT INTO oauth2_clients
( oauth2_client_id
, encrypted_client_secret
, grant_type_authorization_code
, grant_type_refresh_token
, client_name
, logo_uri
, client_uri
, policy_uri
, tos_uri
, jwks_uri
, jwks
, id_token_signed_response_alg
, userinfo_signed_response_alg
, token_endpoint_auth_method
, token_endpoint_auth_signing_alg
, initiate_login_uri
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
"#,
Uuid::from(id),
encrypted_client_secret,
grant_types.contains(&GrantType::AuthorizationCode),
grant_types.contains(&GrantType::RefreshToken),
client_name,
logo_uri.as_ref().map(Url::as_str),
client_uri.as_ref().map(Url::as_str),
policy_uri.as_ref().map(Url::as_str),
tos_uri.as_ref().map(Url::as_str),
jwks_uri.as_ref().map(Url::as_str),
jwks_json,
id_token_signed_response_alg
.as_ref()
.map(ToString::to_string),
userinfo_signed_response_alg
.as_ref()
.map(ToString::to_string),
token_endpoint_auth_method.as_ref().map(ToString::to_string),
token_endpoint_auth_signing_alg
.as_ref()
.map(ToString::to_string),
initiate_login_uri.as_ref().map(Url::as_str),
)
.traced()
.execute(&mut *self.conn)
.await?;
{
let span = info_span!(
"db.oauth2_client.add.redirect_uris",
db.statement = tracing::field::Empty,
client.id = %id,
);
let (uri_ids, redirect_uris): (Vec<Uuid>, Vec<String>) = redirect_uris
.iter()
.map(|uri| {
(
Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)),
uri.as_str().to_owned(),
)
})
.unzip();
sqlx::query!(
r#"
INSERT INTO oauth2_client_redirect_uris
(oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri)
SELECT id, $2, redirect_uri
FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri)
"#,
&uri_ids,
Uuid::from(id),
&redirect_uris,
)
.record(&span)
.execute(&mut *self.conn)
.instrument(span)
.await?;
}
let jwks = match (jwks, jwks_uri) {
(None, None) => None,
(Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)),
(None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)),
_ => return Err(DatabaseError::invalid_operation()),
};
Ok(Client {
id,
client_id: id.to_string(),
encrypted_client_secret,
redirect_uris,
response_types: vec![
OAuthAuthorizationEndpointResponseType::Code,
OAuthAuthorizationEndpointResponseType::IdToken,
OAuthAuthorizationEndpointResponseType::None,
],
grant_types,
contacts,
client_name,
logo_uri,
client_uri,
policy_uri,
tos_uri,
jwks,
id_token_signed_response_alg,
userinfo_signed_response_alg,
token_endpoint_auth_method,
token_endpoint_auth_signing_alg,
initiate_login_uri,
})
}
#[tracing::instrument(
name = "db.oauth2_client.add_from_config",
skip_all,
fields(
db.statement,
client.id = %client_id,
),
err,
)]
async fn add_from_config(
&mut self,
mut rng: impl Rng + Send,
clock: &Clock,
client_id: Ulid,
client_auth_method: OAuthClientAuthenticationMethod,
encrypted_client_secret: Option<String>,
jwks: Option<PublicJsonWebKeySet>,
jwks_uri: Option<Url>,
redirect_uris: Vec<Url>,
) -> Result<Client, Self::Error> {
let jwks_json = jwks
.as_ref()
.map(serde_json::to_value)
.transpose()
.map_err(DatabaseError::to_invalid_operation)?;
let client_auth_method = client_auth_method.to_string();
sqlx::query!(
r#"
INSERT INTO oauth2_clients
( oauth2_client_id
, encrypted_client_secret
, grant_type_authorization_code
, grant_type_refresh_token
, token_endpoint_auth_method
, jwks
, jwks_uri
)
VALUES
($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (oauth2_client_id)
DO
UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret
, grant_type_authorization_code = EXCLUDED.grant_type_authorization_code
, grant_type_refresh_token = EXCLUDED.grant_type_refresh_token
, token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method
, jwks = EXCLUDED.jwks
, jwks_uri = EXCLUDED.jwks_uri
"#,
Uuid::from(client_id),
encrypted_client_secret,
true,
true,
client_auth_method,
jwks_json,
jwks_uri.as_ref().map(Url::as_str),
)
.traced()
.execute(&mut *self.conn)
.await?;
{
let span = info_span!(
"db.oauth2_client.add_from_config.redirect_uris",
client.id = %client_id,
db.statement = tracing::field::Empty,
);
let now = clock.now();
let (ids, redirect_uris): (Vec<Uuid>, Vec<String>) = redirect_uris
.iter()
.map(|uri| {
(
Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)),
uri.as_str().to_owned(),
)
})
.unzip();
sqlx::query!(
r#"
INSERT INTO oauth2_client_redirect_uris
(oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri)
SELECT id, $2, redirect_uri
FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri)
"#,
&ids,
Uuid::from(client_id),
&redirect_uris,
)
.record(&span)
.execute(&mut *self.conn)
.instrument(span)
.await?;
}
let jwks = match (jwks, jwks_uri) {
(None, None) => None,
(Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)),
(None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)),
_ => return Err(DatabaseError::invalid_operation()),
};
Ok(Client {
id: client_id,
client_id: client_id.to_string(),
encrypted_client_secret,
redirect_uris,
response_types: vec![
OAuthAuthorizationEndpointResponseType::Code,
OAuthAuthorizationEndpointResponseType::IdToken,
OAuthAuthorizationEndpointResponseType::None,
],
grant_types: Vec::new(),
contacts: Vec::new(),
client_name: None,
logo_uri: None,
client_uri: None,
policy_uri: None,
tos_uri: None,
jwks,
id_token_signed_response_alg: None,
userinfo_signed_response_alg: None,
token_endpoint_auth_method: None,
token_endpoint_auth_signing_alg: None,
initiate_login_uri: None,
})
}
#[tracing::instrument(
name = "db.oauth2_client.get_consent_for_user",
skip_all,
fields(
db.statement,
%user.id,
%client.id,
),
err,
)]
async fn get_consent_for_user(
&mut self,
client: &Client,
user: &User,
) -> Result<Scope, Self::Error> {
let scope_tokens: Vec<String> = sqlx::query_scalar!(
r#"
SELECT scope_token
FROM oauth2_consents
WHERE user_id = $1 AND oauth2_client_id = $2
"#,
Uuid::from(user.id),
Uuid::from(client.id),
)
.fetch_all(&mut *self.conn)
.await?;
let scope: Result<Scope, _> = scope_tokens
.into_iter()
.map(|s| ScopeToken::from_str(&s))
.collect();
let scope = scope.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_consents")
.column("scope_token")
.source(e)
})?;
Ok(scope)
}
#[tracing::instrument(
skip_all,
fields(
db.statement,
%user.id,
%client.id,
%scope,
),
err,
)]
async fn give_consent_for_user(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
client: &Client,
user: &User,
scope: &Scope,
) -> Result<(), Self::Error> {
let now = clock.now();
let (tokens, ids): (Vec<String>, Vec<Uuid>) = scope
.iter()
.map(|token| {
(
token.to_string(),
Uuid::from(Ulid::from_datetime_with_source(now.into(), rng)),
)
})
.unzip();
sqlx::query!(
r#"
INSERT INTO oauth2_consents
(oauth2_consent_id, user_id, oauth2_client_id, scope_token, created_at)
SELECT id, $2, $3, scope_token, $5 FROM UNNEST($1::uuid[], $4::text[]) u(id, scope_token)
ON CONFLICT (user_id, oauth2_client_id, scope_token) DO UPDATE SET refreshed_at = $5
"#,
&ids,
Uuid::from(user.id),
Uuid::from(client.id),
&tokens,
now,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(())
}
}

View File

@ -0,0 +1,25 @@
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
mod access_token;
pub mod authorization_grant;
mod client;
mod refresh_token;
mod session;
pub use self::{
access_token::PgOAuth2AccessTokenRepository,
authorization_grant::PgOAuth2AuthorizationGrantRepository, client::PgOAuth2ClientRepository,
refresh_token::PgOAuth2RefreshTokenRepository, session::PgOAuth2SessionRepository,
};

View File

@ -0,0 +1,224 @@
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{AccessToken, RefreshToken, RefreshTokenState, Session};
use mas_storage::{oauth2::OAuth2RefreshTokenRepository, Clock};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{tracing::ExecuteExt, DatabaseError, LookupResultExt};
pub struct PgOAuth2RefreshTokenRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgOAuth2RefreshTokenRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
struct OAuth2RefreshTokenLookup {
oauth2_refresh_token_id: Uuid,
refresh_token: String,
created_at: DateTime<Utc>,
consumed_at: Option<DateTime<Utc>>,
oauth2_access_token_id: Option<Uuid>,
oauth2_session_id: Uuid,
}
impl From<OAuth2RefreshTokenLookup> for RefreshToken {
fn from(value: OAuth2RefreshTokenLookup) -> Self {
let state = match value.consumed_at {
None => RefreshTokenState::Valid,
Some(consumed_at) => RefreshTokenState::Consumed { consumed_at },
};
RefreshToken {
id: value.oauth2_refresh_token_id.into(),
state,
session_id: value.oauth2_session_id.into(),
refresh_token: value.refresh_token,
created_at: value.created_at,
access_token_id: value.oauth2_access_token_id.map(Ulid::from),
}
}
}
#[async_trait]
impl<'c> OAuth2RefreshTokenRepository for PgOAuth2RefreshTokenRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.oauth2_refresh_token.lookup",
skip_all,
fields(
db.statement,
refresh_token.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<RefreshToken>, Self::Error> {
let res = sqlx::query_as!(
OAuth2RefreshTokenLookup,
r#"
SELECT oauth2_refresh_token_id
, refresh_token
, created_at
, consumed_at
, oauth2_access_token_id
, oauth2_session_id
FROM oauth2_refresh_tokens
WHERE oauth2_refresh_token_id = $1
"#,
Uuid::from(id),
)
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.oauth2_refresh_token.find_by_token",
skip_all,
fields(
db.statement,
),
err,
)]
async fn find_by_token(
&mut self,
refresh_token: &str,
) -> Result<Option<RefreshToken>, Self::Error> {
let res = sqlx::query_as!(
OAuth2RefreshTokenLookup,
r#"
SELECT oauth2_refresh_token_id
, refresh_token
, created_at
, consumed_at
, oauth2_access_token_id
, oauth2_session_id
FROM oauth2_refresh_tokens
WHERE refresh_token = $1
"#,
refresh_token,
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.oauth2_refresh_token.add",
skip_all,
fields(
db.statement,
%session.id,
user_session.id = %session.user_session_id,
client.id = %session.client_id,
refresh_token.id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
session: &Session,
access_token: &AccessToken,
refresh_token: String,
) -> Result<RefreshToken, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("refresh_token.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO oauth2_refresh_tokens
(oauth2_refresh_token_id, oauth2_session_id, oauth2_access_token_id,
refresh_token, created_at)
VALUES
($1, $2, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(session.id),
Uuid::from(access_token.id),
refresh_token,
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(RefreshToken {
id,
state: RefreshTokenState::default(),
session_id: session.id,
refresh_token,
access_token_id: Some(access_token.id),
created_at,
})
}
#[tracing::instrument(
name = "db.oauth2_refresh_token.consume",
skip_all,
fields(
db.statement,
%refresh_token.id,
session.id = %refresh_token.session_id,
),
err,
)]
async fn consume(
&mut self,
clock: &Clock,
refresh_token: RefreshToken,
) -> Result<RefreshToken, Self::Error> {
let consumed_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE oauth2_refresh_tokens
SET consumed_at = $2
WHERE oauth2_refresh_token_id = $1
"#,
Uuid::from(refresh_token.id),
consumed_at,
)
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
refresh_token
.consume(consumed_at)
.map_err(DatabaseError::to_invalid_operation)
}
}

View File

@ -0,0 +1,248 @@
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{AuthorizationGrant, BrowserSession, Session, SessionState, User};
use mas_storage::{oauth2::OAuth2SessionRepository, Clock, Page, Pagination};
use rand::RngCore;
use sqlx::{PgConnection, QueryBuilder};
use ulid::Ulid;
use uuid::Uuid;
use crate::{
pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError,
LookupResultExt,
};
pub struct PgOAuth2SessionRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgOAuth2SessionRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
#[derive(sqlx::FromRow)]
struct OAuthSessionLookup {
oauth2_session_id: Uuid,
user_session_id: Uuid,
oauth2_client_id: Uuid,
scope: String,
#[allow(dead_code)]
created_at: DateTime<Utc>,
finished_at: Option<DateTime<Utc>>,
}
impl TryFrom<OAuthSessionLookup> for Session {
type Error = DatabaseInconsistencyError;
fn try_from(value: OAuthSessionLookup) -> Result<Self, Self::Error> {
let id = Ulid::from(value.oauth2_session_id);
let scope = value.scope.parse().map_err(|e| {
DatabaseInconsistencyError::on("oauth2_sessions")
.column("scope")
.row(id)
.source(e)
})?;
let state = match value.finished_at {
None => SessionState::Valid,
Some(finished_at) => SessionState::Finished { finished_at },
};
Ok(Session {
id,
state,
created_at: value.created_at,
client_id: value.oauth2_client_id.into(),
user_session_id: value.user_session_id.into(),
scope,
})
}
}
#[async_trait]
impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.oauth2_session.lookup",
skip_all,
fields(
db.statement,
session.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error> {
let res = sqlx::query_as!(
OAuthSessionLookup,
r#"
SELECT oauth2_session_id
, user_session_id
, oauth2_client_id
, scope
, created_at
, finished_at
FROM oauth2_sessions
WHERE oauth2_session_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(session) = res else { return Ok(None) };
Ok(Some(session.try_into()?))
}
#[tracing::instrument(
name = "db.oauth2_session.create_from_grant",
skip_all,
fields(
db.statement,
%user_session.id,
user.id = %user_session.user.id,
%grant.id,
client.id = %grant.client_id,
session.id,
session.scope = %grant.scope,
),
err,
)]
async fn create_from_grant(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
grant: &AuthorizationGrant,
user_session: &BrowserSession,
) -> Result<Session, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("session.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO oauth2_sessions
( oauth2_session_id
, user_session_id
, oauth2_client_id
, scope
, created_at
)
VALUES ($1, $2, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(user_session.id),
Uuid::from(grant.client_id),
grant.scope.to_string(),
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(Session {
id,
state: SessionState::Valid,
created_at,
user_session_id: user_session.id,
client_id: grant.client_id,
scope: grant.scope.clone(),
})
}
#[tracing::instrument(
name = "db.oauth2_session.finish",
skip_all,
fields(
db.statement,
%session.id,
%session.scope,
user_session.id = %session.user_session_id,
client.id = %session.client_id,
),
err,
)]
async fn finish(&mut self, clock: &Clock, session: Session) -> Result<Session, Self::Error> {
let finished_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE oauth2_sessions
SET finished_at = $2
WHERE oauth2_session_id = $1
"#,
Uuid::from(session.id),
finished_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
session
.finish(finished_at)
.map_err(DatabaseError::to_invalid_operation)
}
#[tracing::instrument(
name = "db.oauth2_session.list_paginated",
skip_all,
fields(
db.statement,
%user.id,
%user.username,
),
err,
)]
async fn list_paginated(
&mut self,
user: &User,
pagination: Pagination,
) -> Result<Page<Session>, Self::Error> {
let mut query = QueryBuilder::new(
r#"
SELECT oauth2_session_id
, user_session_id
, oauth2_client_id
, scope
, created_at
, finished_at
FROM oauth2_sessions os
"#,
);
query
.push(" WHERE us.user_id = ")
.push_bind(Uuid::from(user.id))
.generate_pagination("oauth2_session_id", pagination);
let edges: Vec<OAuthSessionLookup> = query
.build_query_as()
.traced()
.fetch_all(&mut *self.conn)
.await?;
let page = pagination.process(edges).try_map(Session::try_from)?;
Ok(page)
}
}

View File

@ -0,0 +1,78 @@
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Utilities to manage paginated queries.
use mas_storage::{pagination::PaginationDirection, Pagination};
use sqlx::{Database, QueryBuilder};
use uuid::Uuid;
/// An extension trait to the `sqlx` [`QueryBuilder`], to help adding pagination
/// to a query
pub trait QueryBuilderExt {
/// Add cursor-based pagination to a query, as used in paginated GraphQL
/// connections
fn generate_pagination(&mut self, id_field: &'static str, pagination: Pagination) -> &mut Self;
}
impl<'a, DB> QueryBuilderExt for QueryBuilder<'a, DB>
where
DB: Database,
Uuid: sqlx::Type<DB> + sqlx::Encode<'a, DB>,
i64: sqlx::Type<DB> + sqlx::Encode<'a, DB>,
{
fn generate_pagination(&mut self, id_field: &'static str, pagination: Pagination) -> &mut Self {
// ref: https://github.com/graphql/graphql-relay-js/issues/94#issuecomment-232410564
// 1. Start from the greedy query: SELECT * FROM table
// 2. If the after argument is provided, add `id > parsed_cursor` to the `WHERE`
// clause
if let Some(after) = pagination.after {
self.push(" AND ")
.push(id_field)
.push(" > ")
.push_bind(Uuid::from(after));
}
// 3. If the before argument is provided, add `id < parsed_cursor` to the
// `WHERE` clause
if let Some(before) = pagination.before {
self.push(" AND ")
.push(id_field)
.push(" < ")
.push_bind(Uuid::from(before));
}
match pagination.direction {
// 4. If the first argument is provided, add `ORDER BY id ASC LIMIT first+1` to the
// query
PaginationDirection::Forward => {
self.push(" ORDER BY ")
.push(id_field)
.push(" ASC LIMIT ")
.push_bind((pagination.count + 1) as i64);
}
// 5. If the first argument is provided, add `ORDER BY id DESC LIMIT last+1` to the
// query
PaginationDirection::Backward => {
self.push(" ORDER BY ")
.push(id_field)
.push(" DESC LIMIT ")
.push_bind((pagination.count + 1) as i64);
}
};
self
}
}

View File

@ -0,0 +1,142 @@
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use mas_storage::Repository;
use sqlx::{PgPool, Postgres, Transaction};
use crate::{
compat::{
PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository, PgCompatSessionRepository,
PgCompatSsoLoginRepository,
},
oauth2::{
PgOAuth2AccessTokenRepository, PgOAuth2AuthorizationGrantRepository,
PgOAuth2ClientRepository, PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository,
},
upstream_oauth2::{
PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository,
PgUpstreamOAuthSessionRepository,
},
user::{
PgBrowserSessionRepository, PgUserEmailRepository, PgUserPasswordRepository,
PgUserRepository,
},
DatabaseError,
};
pub struct PgRepository {
txn: Transaction<'static, Postgres>,
}
impl PgRepository {
pub async fn from_pool(pool: &PgPool) -> Result<Self, DatabaseError> {
let txn = pool.begin().await?;
Ok(PgRepository { txn })
}
pub async fn save(self) -> Result<(), DatabaseError> {
self.txn.commit().await?;
Ok(())
}
pub async fn cancel(self) -> Result<(), DatabaseError> {
self.txn.rollback().await?;
Ok(())
}
}
impl Repository for PgRepository {
type Error = DatabaseError;
type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c;
type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c;
type UpstreamOAuthSessionRepository<'c> = PgUpstreamOAuthSessionRepository<'c> where Self: 'c;
type UserRepository<'c> = PgUserRepository<'c> where Self: 'c;
type UserEmailRepository<'c> = PgUserEmailRepository<'c> where Self: 'c;
type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c;
type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c;
type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c;
type OAuth2AuthorizationGrantRepository<'c> = PgOAuth2AuthorizationGrantRepository<'c> where Self: 'c;
type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c;
type OAuth2AccessTokenRepository<'c> = PgOAuth2AccessTokenRepository<'c> where Self: 'c;
type OAuth2RefreshTokenRepository<'c> = PgOAuth2RefreshTokenRepository<'c> where Self: 'c;
type CompatSessionRepository<'c> = PgCompatSessionRepository<'c> where Self: 'c;
type CompatSsoLoginRepository<'c> = PgCompatSsoLoginRepository<'c> where Self: 'c;
type CompatAccessTokenRepository<'c> = PgCompatAccessTokenRepository<'c> where Self: 'c;
type CompatRefreshTokenRepository<'c> = PgCompatRefreshTokenRepository<'c> where Self: 'c;
fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> {
PgUpstreamOAuthLinkRepository::new(&mut self.txn)
}
fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> {
PgUpstreamOAuthProviderRepository::new(&mut self.txn)
}
fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_> {
PgUpstreamOAuthSessionRepository::new(&mut self.txn)
}
fn user(&mut self) -> Self::UserRepository<'_> {
PgUserRepository::new(&mut self.txn)
}
fn user_email(&mut self) -> Self::UserEmailRepository<'_> {
PgUserEmailRepository::new(&mut self.txn)
}
fn user_password(&mut self) -> Self::UserPasswordRepository<'_> {
PgUserPasswordRepository::new(&mut self.txn)
}
fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_> {
PgBrowserSessionRepository::new(&mut self.txn)
}
fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_> {
PgOAuth2ClientRepository::new(&mut self.txn)
}
fn oauth2_authorization_grant(&mut self) -> Self::OAuth2AuthorizationGrantRepository<'_> {
PgOAuth2AuthorizationGrantRepository::new(&mut self.txn)
}
fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> {
PgOAuth2SessionRepository::new(&mut self.txn)
}
fn oauth2_access_token(&mut self) -> Self::OAuth2AccessTokenRepository<'_> {
PgOAuth2AccessTokenRepository::new(&mut self.txn)
}
fn oauth2_refresh_token(&mut self) -> Self::OAuth2RefreshTokenRepository<'_> {
PgOAuth2RefreshTokenRepository::new(&mut self.txn)
}
fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> {
PgCompatSessionRepository::new(&mut self.txn)
}
fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_> {
PgCompatSsoLoginRepository::new(&mut self.txn)
}
fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_> {
PgCompatAccessTokenRepository::new(&mut self.txn)
}
fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_> {
PgCompatRefreshTokenRepository::new(&mut self.txn)
}
}

View File

@ -1,4 +1,4 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.

View File

@ -0,0 +1,262 @@
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{UpstreamOAuthLink, UpstreamOAuthProvider, User};
use mas_storage::{upstream_oauth2::UpstreamOAuthLinkRepository, Clock, Page, Pagination};
use rand::RngCore;
use sqlx::{PgConnection, QueryBuilder};
use ulid::Ulid;
use uuid::Uuid;
use crate::{pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, LookupResultExt};
pub struct PgUpstreamOAuthLinkRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgUpstreamOAuthLinkRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
#[derive(sqlx::FromRow)]
struct LinkLookup {
upstream_oauth_link_id: Uuid,
upstream_oauth_provider_id: Uuid,
user_id: Option<Uuid>,
subject: String,
created_at: DateTime<Utc>,
}
impl From<LinkLookup> for UpstreamOAuthLink {
fn from(value: LinkLookup) -> Self {
UpstreamOAuthLink {
id: Ulid::from(value.upstream_oauth_link_id),
provider_id: Ulid::from(value.upstream_oauth_provider_id),
user_id: value.user_id.map(Ulid::from),
subject: value.subject,
created_at: value.created_at,
}
}
}
#[async_trait]
impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.upstream_oauth_link.lookup",
skip_all,
fields(
db.statement,
upstream_oauth_link.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthLink>, Self::Error> {
let res = sqlx::query_as!(
LinkLookup,
r#"
SELECT
upstream_oauth_link_id,
upstream_oauth_provider_id,
user_id,
subject,
created_at
FROM upstream_oauth_links
WHERE upstream_oauth_link_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?
.map(Into::into);
Ok(res)
}
#[tracing::instrument(
name = "db.upstream_oauth_link.find_by_subject",
skip_all,
fields(
db.statement,
upstream_oauth_link.subject = subject,
%upstream_oauth_provider.id,
%upstream_oauth_provider.issuer,
%upstream_oauth_provider.client_id,
),
err,
)]
async fn find_by_subject(
&mut self,
upstream_oauth_provider: &UpstreamOAuthProvider,
subject: &str,
) -> Result<Option<UpstreamOAuthLink>, Self::Error> {
let res = sqlx::query_as!(
LinkLookup,
r#"
SELECT
upstream_oauth_link_id,
upstream_oauth_provider_id,
user_id,
subject,
created_at
FROM upstream_oauth_links
WHERE upstream_oauth_provider_id = $1
AND subject = $2
"#,
Uuid::from(upstream_oauth_provider.id),
subject,
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?
.map(Into::into);
Ok(res)
}
#[tracing::instrument(
name = "db.upstream_oauth_link.add",
skip_all,
fields(
db.statement,
upstream_oauth_link.id,
upstream_oauth_link.subject = subject,
%upstream_oauth_provider.id,
%upstream_oauth_provider.issuer,
%upstream_oauth_provider.client_id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
upstream_oauth_provider: &UpstreamOAuthProvider,
subject: String,
) -> Result<UpstreamOAuthLink, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("upstream_oauth_link.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO upstream_oauth_links (
upstream_oauth_link_id,
upstream_oauth_provider_id,
user_id,
subject,
created_at
) VALUES ($1, $2, NULL, $3, $4)
"#,
Uuid::from(id),
Uuid::from(upstream_oauth_provider.id),
&subject,
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(UpstreamOAuthLink {
id,
provider_id: upstream_oauth_provider.id,
user_id: None,
subject,
created_at,
})
}
#[tracing::instrument(
name = "db.upstream_oauth_link.associate_to_user",
skip_all,
fields(
db.statement,
%upstream_oauth_link.id,
%upstream_oauth_link.subject,
%user.id,
%user.username,
),
err,
)]
async fn associate_to_user(
&mut self,
upstream_oauth_link: &UpstreamOAuthLink,
user: &User,
) -> Result<(), Self::Error> {
sqlx::query!(
r#"
UPDATE upstream_oauth_links
SET user_id = $1
WHERE upstream_oauth_link_id = $2
"#,
Uuid::from(user.id),
Uuid::from(upstream_oauth_link.id),
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(())
}
#[tracing::instrument(
name = "db.upstream_oauth_link.list_paginated",
skip_all,
fields(
db.statement,
%user.id,
%user.username,
),
err
)]
async fn list_paginated(
&mut self,
user: &User,
pagination: Pagination,
) -> Result<Page<UpstreamOAuthLink>, Self::Error> {
let mut query = QueryBuilder::new(
r#"
SELECT
upstream_oauth_link_id,
upstream_oauth_provider_id,
user_id,
subject,
created_at
FROM upstream_oauth_links
"#,
);
query
.push(" WHERE user_id = ")
.push_bind(Uuid::from(user.id))
.generate_pagination("upstream_oauth_link_id", pagination);
let edges: Vec<LinkLookup> = query
.build_query_as()
.traced()
.fetch_all(&mut *self.conn)
.await?;
let page = pagination.process(edges).map(UpstreamOAuthLink::from);
Ok(page)
}
}

View File

@ -0,0 +1,271 @@
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
mod link;
mod provider;
mod session;
pub use self::{
link::PgUpstreamOAuthLinkRepository, provider::PgUpstreamOAuthProviderRepository,
session::PgUpstreamOAuthSessionRepository,
};
#[cfg(test)]
mod tests {
use chrono::Duration;
use mas_storage::{
upstream_oauth2::{
UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
UpstreamOAuthSessionRepository,
},
user::UserRepository,
Clock, Pagination, Repository,
};
use oauth2_types::scope::{Scope, OPENID};
use rand::SeedableRng;
use sqlx::PgPool;
use crate::PgRepository;
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_repository(pool: PgPool) {
let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
let clock = Clock::mock();
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
// The provider list should be empty at the start
let all_providers = repo.upstream_oauth_provider().all().await.unwrap();
assert!(all_providers.is_empty());
// Let's add a provider
let provider = repo
.upstream_oauth_provider()
.add(
&mut rng,
&clock,
"https://example.com/".to_owned(),
Scope::from_iter([OPENID]),
mas_iana::oauth::OAuthClientAuthenticationMethod::None,
None,
"client-id".to_owned(),
None,
)
.await
.unwrap();
// Look it up in the database
let provider = repo
.upstream_oauth_provider()
.lookup(provider.id)
.await
.unwrap()
.expect("provider to be found in the database");
assert_eq!(provider.issuer, "https://example.com/");
assert_eq!(provider.client_id, "client-id");
// Start a session
let session = repo
.upstream_oauth_session()
.add(
&mut rng,
&clock,
&provider,
"some-state".to_owned(),
None,
"some-nonce".to_owned(),
)
.await
.unwrap();
// Look it up in the database
let session = repo
.upstream_oauth_session()
.lookup(session.id)
.await
.unwrap()
.expect("session to be found in the database");
assert_eq!(session.provider_id, provider.id);
assert_eq!(session.link_id(), None);
assert!(session.is_pending());
assert!(!session.is_completed());
assert!(!session.is_consumed());
// Create a link
let link = repo
.upstream_oauth_link()
.add(&mut rng, &clock, &provider, "a-subject".to_owned())
.await
.unwrap();
// We can look it up by its ID
repo.upstream_oauth_link()
.lookup(link.id)
.await
.unwrap()
.expect("link to be found in database");
// or by its subject
let link = repo
.upstream_oauth_link()
.find_by_subject(&provider, "a-subject")
.await
.unwrap()
.expect("link to be found in database");
assert_eq!(link.subject, "a-subject");
assert_eq!(link.provider_id, provider.id);
let session = repo
.upstream_oauth_session()
.complete_with_link(&clock, session, &link, None)
.await
.unwrap();
// Reload the session
let session = repo
.upstream_oauth_session()
.lookup(session.id)
.await
.unwrap()
.expect("session to be found in the database");
assert!(session.is_completed());
assert!(!session.is_consumed());
assert_eq!(session.link_id(), Some(link.id));
let session = repo
.upstream_oauth_session()
.consume(&clock, session)
.await
.unwrap();
// Reload the session
let session = repo
.upstream_oauth_session()
.lookup(session.id)
.await
.unwrap()
.expect("session to be found in the database");
assert!(session.is_consumed());
let user = repo
.user()
.add(&mut rng, &clock, "john".to_owned())
.await
.unwrap();
repo.upstream_oauth_link()
.associate_to_user(&link, &user)
.await
.unwrap();
let links = repo
.upstream_oauth_link()
.list_paginated(&user, Pagination::first(10))
.await
.unwrap();
assert!(!links.has_previous_page);
assert!(!links.has_next_page);
assert_eq!(links.edges.len(), 1);
assert_eq!(links.edges[0].id, link.id);
assert_eq!(links.edges[0].user_id, Some(user.id));
}
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_provider_repository_pagination(pool: PgPool) {
const ISSUER: &str = "https://example.com/";
let scope = Scope::from_iter([OPENID]);
let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
let clock = Clock::mock();
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
let mut ids = Vec::with_capacity(20);
// Create 20 providers
for idx in 0..20 {
let client_id = format!("client-{idx}");
let provider = repo
.upstream_oauth_provider()
.add(
&mut rng,
&clock,
ISSUER.to_owned(),
scope.clone(),
mas_iana::oauth::OAuthClientAuthenticationMethod::None,
None,
client_id,
None,
)
.await
.unwrap();
ids.push(provider.id);
clock.advance(Duration::seconds(10));
}
// Lookup the first 10 items
let page = repo
.upstream_oauth_provider()
.list_paginated(Pagination::first(10))
.await
.unwrap();
// It returned the first 10 items
assert!(page.has_next_page);
let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
assert_eq!(&edge_ids, &ids[..10]);
// Lookup the next 10 items
let page = repo
.upstream_oauth_provider()
.list_paginated(Pagination::first(10).after(ids[9]))
.await
.unwrap();
// It returned the next 10 items
assert!(!page.has_next_page);
let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
assert_eq!(&edge_ids, &ids[10..]);
// Lookup the last 10 items
let page = repo
.upstream_oauth_provider()
.list_paginated(Pagination::last(10))
.await
.unwrap();
// It returned the last 10 items
assert!(page.has_previous_page);
let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
assert_eq!(&edge_ids, &ids[10..]);
// Lookup the previous 10 items
let page = repo
.upstream_oauth_provider()
.list_paginated(Pagination::last(10).before(ids[10]))
.await
.unwrap();
// It returned the previous 10 items
assert!(!page.has_previous_page);
let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
assert_eq!(&edge_ids, &ids[..10]);
// Lookup 10 items between two IDs
let page = repo
.upstream_oauth_provider()
.list_paginated(Pagination::first(10).after(ids[5]).before(ids[8]))
.await
.unwrap();
// It returned the items in between
assert!(!page.has_next_page);
let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
assert_eq!(&edge_ids, &ids[6..8]);
}
}

View File

@ -0,0 +1,273 @@
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::UpstreamOAuthProvider;
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, Clock, Page, Pagination};
use oauth2_types::scope::Scope;
use rand::RngCore;
use sqlx::{PgConnection, QueryBuilder};
use ulid::Ulid;
use uuid::Uuid;
use crate::{
pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError,
LookupResultExt,
};
pub struct PgUpstreamOAuthProviderRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgUpstreamOAuthProviderRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
#[derive(sqlx::FromRow)]
struct ProviderLookup {
upstream_oauth_provider_id: Uuid,
issuer: String,
scope: String,
client_id: String,
encrypted_client_secret: Option<String>,
token_endpoint_signing_alg: Option<String>,
token_endpoint_auth_method: String,
created_at: DateTime<Utc>,
}
impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
type Error = DatabaseInconsistencyError;
fn try_from(value: ProviderLookup) -> Result<Self, Self::Error> {
let id = value.upstream_oauth_provider_id.into();
let scope = value.scope.parse().map_err(|e| {
DatabaseInconsistencyError::on("upstream_oauth_providers")
.column("scope")
.row(id)
.source(e)
})?;
let token_endpoint_auth_method = value.token_endpoint_auth_method.parse().map_err(|e| {
DatabaseInconsistencyError::on("upstream_oauth_providers")
.column("token_endpoint_auth_method")
.row(id)
.source(e)
})?;
let token_endpoint_signing_alg = value
.token_endpoint_signing_alg
.map(|x| x.parse())
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("upstream_oauth_providers")
.column("token_endpoint_signing_alg")
.row(id)
.source(e)
})?;
Ok(UpstreamOAuthProvider {
id,
issuer: value.issuer,
scope,
client_id: value.client_id,
encrypted_client_secret: value.encrypted_client_secret,
token_endpoint_auth_method,
token_endpoint_signing_alg,
created_at: value.created_at,
})
}
}
#[async_trait]
impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.upstream_oauth_provider.lookup",
skip_all,
fields(
db.statement,
upstream_oauth_provider.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error> {
let res = sqlx::query_as!(
ProviderLookup,
r#"
SELECT
upstream_oauth_provider_id,
issuer,
scope,
client_id,
encrypted_client_secret,
token_endpoint_signing_alg,
token_endpoint_auth_method,
created_at
FROM upstream_oauth_providers
WHERE upstream_oauth_provider_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let res = res
.map(UpstreamOAuthProvider::try_from)
.transpose()
.map_err(DatabaseError::from)?;
Ok(res)
}
#[tracing::instrument(
name = "db.upstream_oauth_provider.add",
skip_all,
fields(
db.statement,
upstream_oauth_provider.id,
upstream_oauth_provider.issuer = %issuer,
upstream_oauth_provider.client_id = %client_id,
),
err,
)]
#[allow(clippy::too_many_arguments)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
issuer: String,
scope: Scope,
token_endpoint_auth_method: OAuthClientAuthenticationMethod,
token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
client_id: String,
encrypted_client_secret: Option<String>,
) -> Result<UpstreamOAuthProvider, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("upstream_oauth_provider.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO upstream_oauth_providers (
upstream_oauth_provider_id,
issuer,
scope,
token_endpoint_auth_method,
token_endpoint_signing_alg,
client_id,
encrypted_client_secret,
created_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
"#,
Uuid::from(id),
&issuer,
scope.to_string(),
token_endpoint_auth_method.to_string(),
token_endpoint_signing_alg.as_ref().map(ToString::to_string),
&client_id,
encrypted_client_secret.as_deref(),
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(UpstreamOAuthProvider {
id,
issuer,
scope,
client_id,
encrypted_client_secret,
token_endpoint_signing_alg,
token_endpoint_auth_method,
created_at,
})
}
#[tracing::instrument(
name = "db.upstream_oauth_provider.list_paginated",
skip_all,
fields(
db.statement,
),
err,
)]
async fn list_paginated(
&mut self,
pagination: Pagination,
) -> Result<Page<UpstreamOAuthProvider>, Self::Error> {
let mut query = QueryBuilder::new(
r#"
SELECT
upstream_oauth_provider_id,
issuer,
scope,
client_id,
encrypted_client_secret,
token_endpoint_signing_alg,
token_endpoint_auth_method,
created_at
FROM upstream_oauth_providers
WHERE 1 = 1
"#,
);
query.generate_pagination("upstream_oauth_provider_id", pagination);
let edges: Vec<ProviderLookup> = query
.build_query_as()
.traced()
.fetch_all(&mut *self.conn)
.await?;
let page = pagination.process(edges).try_map(TryInto::try_into)?;
Ok(page)
}
#[tracing::instrument(
name = "db.upstream_oauth_provider.all",
skip_all,
fields(
db.statement,
),
err,
)]
async fn all(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error> {
let res = sqlx::query_as!(
ProviderLookup,
r#"
SELECT
upstream_oauth_provider_id,
issuer,
scope,
client_id,
encrypted_client_secret,
token_endpoint_signing_alg,
token_endpoint_auth_method,
created_at
FROM upstream_oauth_providers
"#,
)
.traced()
.fetch_all(&mut *self.conn)
.await?;
let res: Result<Vec<_>, _> = res.into_iter().map(TryInto::try_into).collect();
Ok(res?)
}
}

View File

@ -0,0 +1,286 @@
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{
UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState, UpstreamOAuthLink,
UpstreamOAuthProvider,
};
use mas_storage::{upstream_oauth2::UpstreamOAuthSessionRepository, Clock};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, LookupResultExt};
pub struct PgUpstreamOAuthSessionRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgUpstreamOAuthSessionRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
struct SessionLookup {
upstream_oauth_authorization_session_id: Uuid,
upstream_oauth_provider_id: Uuid,
upstream_oauth_link_id: Option<Uuid>,
state: String,
code_challenge_verifier: Option<String>,
nonce: String,
id_token: Option<String>,
created_at: DateTime<Utc>,
completed_at: Option<DateTime<Utc>>,
consumed_at: Option<DateTime<Utc>>,
}
impl TryFrom<SessionLookup> for UpstreamOAuthAuthorizationSession {
type Error = DatabaseInconsistencyError;
fn try_from(value: SessionLookup) -> Result<Self, Self::Error> {
let id = value.upstream_oauth_authorization_session_id.into();
let state = match (
value.upstream_oauth_link_id,
value.id_token,
value.completed_at,
value.consumed_at,
) {
(None, None, None, None) => UpstreamOAuthAuthorizationSessionState::Pending,
(Some(link_id), id_token, Some(completed_at), None) => {
UpstreamOAuthAuthorizationSessionState::Completed {
completed_at,
link_id: link_id.into(),
id_token,
}
}
(Some(link_id), id_token, Some(completed_at), Some(consumed_at)) => {
UpstreamOAuthAuthorizationSessionState::Consumed {
completed_at,
link_id: link_id.into(),
id_token,
consumed_at,
}
}
_ => {
return Err(
DatabaseInconsistencyError::on("upstream_oauth_authorization_sessions").row(id),
)
}
};
Ok(Self {
id,
provider_id: value.upstream_oauth_provider_id.into(),
state_str: value.state,
nonce: value.nonce,
code_challenge_verifier: value.code_challenge_verifier,
created_at: value.created_at,
state,
})
}
}
#[async_trait]
impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.upstream_oauth_authorization_session.lookup",
skip_all,
fields(
db.statement,
upstream_oauth_provider.id = %id,
),
err,
)]
async fn lookup(
&mut self,
id: Ulid,
) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error> {
let res = sqlx::query_as!(
SessionLookup,
r#"
SELECT
upstream_oauth_authorization_session_id,
upstream_oauth_provider_id,
upstream_oauth_link_id,
state,
code_challenge_verifier,
nonce,
id_token,
created_at,
completed_at,
consumed_at
FROM upstream_oauth_authorization_sessions
WHERE upstream_oauth_authorization_session_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
name = "db.upstream_oauth_authorization_session.add",
skip_all,
fields(
db.statement,
%upstream_oauth_provider.id,
%upstream_oauth_provider.issuer,
%upstream_oauth_provider.client_id,
upstream_oauth_authorization_session.id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
upstream_oauth_provider: &UpstreamOAuthProvider,
state_str: String,
code_challenge_verifier: Option<String>,
nonce: String,
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record(
"upstream_oauth_authorization_session.id",
tracing::field::display(id),
);
sqlx::query!(
r#"
INSERT INTO upstream_oauth_authorization_sessions (
upstream_oauth_authorization_session_id,
upstream_oauth_provider_id,
state,
code_challenge_verifier,
nonce,
created_at,
completed_at,
consumed_at,
id_token
) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL)
"#,
Uuid::from(id),
Uuid::from(upstream_oauth_provider.id),
&state_str,
code_challenge_verifier.as_deref(),
nonce,
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(UpstreamOAuthAuthorizationSession {
id,
state: UpstreamOAuthAuthorizationSessionState::default(),
provider_id: upstream_oauth_provider.id,
state_str,
code_challenge_verifier,
nonce,
created_at,
})
}
#[tracing::instrument(
name = "db.upstream_oauth_authorization_session.complete_with_link",
skip_all,
fields(
db.statement,
%upstream_oauth_authorization_session.id,
%upstream_oauth_link.id,
),
err,
)]
async fn complete_with_link(
&mut self,
clock: &Clock,
upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
upstream_oauth_link: &UpstreamOAuthLink,
id_token: Option<String>,
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
let completed_at = clock.now();
sqlx::query!(
r#"
UPDATE upstream_oauth_authorization_sessions
SET upstream_oauth_link_id = $1,
completed_at = $2,
id_token = $3
WHERE upstream_oauth_authorization_session_id = $4
"#,
Uuid::from(upstream_oauth_link.id),
completed_at,
id_token,
Uuid::from(upstream_oauth_authorization_session.id),
)
.traced()
.execute(&mut *self.conn)
.await?;
let upstream_oauth_authorization_session = upstream_oauth_authorization_session
.complete(completed_at, upstream_oauth_link, id_token)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(upstream_oauth_authorization_session)
}
/// Mark a session as consumed
#[tracing::instrument(
name = "db.upstream_oauth_authorization_session.consume",
skip_all,
fields(
db.statement,
%upstream_oauth_authorization_session.id,
),
err,
)]
async fn consume(
&mut self,
clock: &Clock,
upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
let consumed_at = clock.now();
sqlx::query!(
r#"
UPDATE upstream_oauth_authorization_sessions
SET consumed_at = $1
WHERE upstream_oauth_authorization_session_id = $2
"#,
consumed_at,
Uuid::from(upstream_oauth_authorization_session.id),
)
.traced()
.execute(&mut *self.conn)
.await?;
let upstream_oauth_authorization_session = upstream_oauth_authorization_session
.consume(consumed_at)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(upstream_oauth_authorization_session)
}
}

View File

@ -0,0 +1,554 @@
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{User, UserEmail, UserEmailVerification, UserEmailVerificationState};
use mas_storage::{user::UserEmailRepository, Clock, Page, Pagination};
use rand::RngCore;
use sqlx::{PgConnection, QueryBuilder};
use tracing::{info_span, Instrument};
use ulid::Ulid;
use uuid::Uuid;
use crate::{
pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError,
LookupResultExt,
};
pub struct PgUserEmailRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgUserEmailRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
#[derive(Debug, Clone, sqlx::FromRow)]
struct UserEmailLookup {
user_email_id: Uuid,
user_id: Uuid,
email: String,
created_at: DateTime<Utc>,
confirmed_at: Option<DateTime<Utc>>,
}
impl From<UserEmailLookup> for UserEmail {
fn from(e: UserEmailLookup) -> UserEmail {
UserEmail {
id: e.user_email_id.into(),
user_id: e.user_id.into(),
email: e.email,
created_at: e.created_at,
confirmed_at: e.confirmed_at,
}
}
}
struct UserEmailConfirmationCodeLookup {
user_email_confirmation_code_id: Uuid,
user_email_id: Uuid,
code: String,
created_at: DateTime<Utc>,
expires_at: DateTime<Utc>,
consumed_at: Option<DateTime<Utc>>,
}
impl UserEmailConfirmationCodeLookup {
fn into_verification(self, clock: &Clock) -> UserEmailVerification {
let now = clock.now();
let state = if let Some(when) = self.consumed_at {
UserEmailVerificationState::AlreadyUsed { when }
} else if self.expires_at < now {
UserEmailVerificationState::Expired {
when: self.expires_at,
}
} else {
UserEmailVerificationState::Valid
};
UserEmailVerification {
id: self.user_email_confirmation_code_id.into(),
user_email_id: self.user_email_id.into(),
code: self.code,
state,
created_at: self.created_at,
}
}
}
#[async_trait]
impl<'c> UserEmailRepository for PgUserEmailRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.user_email.lookup",
skip_all,
fields(
db.statement,
user_email.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<UserEmail>, Self::Error> {
let res = sqlx::query_as!(
UserEmailLookup,
r#"
SELECT user_email_id
, user_id
, email
, created_at
, confirmed_at
FROM user_emails
WHERE user_email_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(user_email) = res else { return Ok(None) };
Ok(Some(user_email.into()))
}
#[tracing::instrument(
name = "db.user_email.find",
skip_all,
fields(
db.statement,
%user.id,
user_email.email = email,
),
err,
)]
async fn find(&mut self, user: &User, email: &str) -> Result<Option<UserEmail>, Self::Error> {
let res = sqlx::query_as!(
UserEmailLookup,
r#"
SELECT user_email_id
, user_id
, email
, created_at
, confirmed_at
FROM user_emails
WHERE user_id = $1 AND email = $2
"#,
Uuid::from(user.id),
email,
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(user_email) = res else { return Ok(None) };
Ok(Some(user_email.into()))
}
#[tracing::instrument(
name = "db.user_email.get_primary",
skip_all,
fields(
db.statement,
%user.id,
),
err,
)]
async fn get_primary(&mut self, user: &User) -> Result<Option<UserEmail>, Self::Error> {
let Some(id) = user.primary_user_email_id else { return Ok(None) };
let user_email = self.lookup(id).await?.ok_or_else(|| {
DatabaseInconsistencyError::on("users")
.column("primary_user_email_id")
.row(user.id)
})?;
Ok(Some(user_email))
}
#[tracing::instrument(
name = "db.user_email.all",
skip_all,
fields(
db.statement,
%user.id,
),
err,
)]
async fn all(&mut self, user: &User) -> Result<Vec<UserEmail>, Self::Error> {
let res = sqlx::query_as!(
UserEmailLookup,
r#"
SELECT user_email_id
, user_id
, email
, created_at
, confirmed_at
FROM user_emails
WHERE user_id = $1
ORDER BY email ASC
"#,
Uuid::from(user.id),
)
.traced()
.fetch_all(&mut *self.conn)
.await?;
Ok(res.into_iter().map(Into::into).collect())
}
#[tracing::instrument(
name = "db.user_email.list_paginated",
skip_all,
fields(
db.statement,
%user.id,
),
err,
)]
async fn list_paginated(
&mut self,
user: &User,
pagination: Pagination,
) -> Result<Page<UserEmail>, DatabaseError> {
let mut query = QueryBuilder::new(
r#"
SELECT user_email_id
, user_id
, email
, created_at
, confirmed_at
FROM user_emails
"#,
);
query
.push(" WHERE user_id = ")
.push_bind(Uuid::from(user.id))
.generate_pagination("ue.user_email_id", pagination);
let edges: Vec<UserEmailLookup> = query
.build_query_as()
.traced()
.fetch_all(&mut *self.conn)
.await?;
let page = pagination.process(edges).map(UserEmail::from);
Ok(page)
}
#[tracing::instrument(
name = "db.user_email.count",
skip_all,
fields(
db.statement,
%user.id,
),
err,
)]
async fn count(&mut self, user: &User) -> Result<usize, Self::Error> {
let res = sqlx::query_scalar!(
r#"
SELECT COUNT(*)
FROM user_emails
WHERE user_id = $1
"#,
Uuid::from(user.id),
)
.traced()
.fetch_one(&mut *self.conn)
.await?;
let res = res.unwrap_or_default();
Ok(res
.try_into()
.map_err(DatabaseError::to_invalid_operation)?)
}
#[tracing::instrument(
name = "db.user_email.add",
skip_all,
fields(
db.statement,
%user.id,
user_email.id,
user_email.email = email,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
user: &User,
email: String,
) -> Result<UserEmail, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("user_email.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO user_emails (user_email_id, user_id, email, created_at)
VALUES ($1, $2, $3, $4)
"#,
Uuid::from(id),
Uuid::from(user.id),
&email,
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(UserEmail {
id,
user_id: user.id,
email,
created_at,
confirmed_at: None,
})
}
#[tracing::instrument(
name = "db.user_email.remove",
skip_all,
fields(
db.statement,
user.id = %user_email.user_id,
%user_email.id,
%user_email.email,
),
err,
)]
async fn remove(&mut self, user_email: UserEmail) -> Result<(), Self::Error> {
let span = info_span!(
"db.user_email.remove.codes",
db.statement = tracing::field::Empty
);
sqlx::query!(
r#"
DELETE FROM user_email_confirmation_codes
WHERE user_email_id = $1
"#,
Uuid::from(user_email.id),
)
.record(&span)
.execute(&mut *self.conn)
.instrument(span)
.await?;
let res = sqlx::query!(
r#"
DELETE FROM user_emails
WHERE user_email_id = $1
"#,
Uuid::from(user_email.id),
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
Ok(())
}
async fn mark_as_verified(
&mut self,
clock: &Clock,
mut user_email: UserEmail,
) -> Result<UserEmail, Self::Error> {
let confirmed_at = clock.now();
sqlx::query!(
r#"
UPDATE user_emails
SET confirmed_at = $2
WHERE user_email_id = $1
"#,
Uuid::from(user_email.id),
confirmed_at,
)
.execute(&mut *self.conn)
.await?;
user_email.confirmed_at = Some(confirmed_at);
Ok(user_email)
}
async fn set_as_primary(&mut self, user_email: &UserEmail) -> Result<(), Self::Error> {
sqlx::query!(
r#"
UPDATE users
SET primary_user_email_id = user_emails.user_email_id
FROM user_emails
WHERE user_emails.user_email_id = $1
AND users.user_id = user_emails.user_id
"#,
Uuid::from(user_email.id),
)
.execute(&mut *self.conn)
.await?;
Ok(())
}
#[tracing::instrument(
name = "db.user_email.add_verification_code",
skip_all,
fields(
db.statement,
%user_email.id,
%user_email.email,
user_email_verification.id,
user_email_verification.code = code,
),
err,
)]
async fn add_verification_code(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
user_email: &UserEmail,
max_age: chrono::Duration,
code: String,
) -> Result<UserEmailVerification, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("user_email_confirmation.id", tracing::field::display(id));
let expires_at = created_at + max_age;
sqlx::query!(
r#"
INSERT INTO user_email_confirmation_codes
(user_email_confirmation_code_id, user_email_id, code, created_at, expires_at)
VALUES ($1, $2, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(user_email.id),
code,
created_at,
expires_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
let verification = UserEmailVerification {
id,
user_email_id: user_email.id,
code,
created_at,
state: UserEmailVerificationState::Valid,
};
Ok(verification)
}
#[tracing::instrument(
name = "db.user_email.find_verification_code",
skip_all,
fields(
db.statement,
%user_email.id,
user.id = %user_email.user_id,
),
err,
)]
async fn find_verification_code(
&mut self,
clock: &Clock,
user_email: &UserEmail,
code: &str,
) -> Result<Option<UserEmailVerification>, Self::Error> {
let res = sqlx::query_as!(
UserEmailConfirmationCodeLookup,
r#"
SELECT user_email_confirmation_code_id
, user_email_id
, code
, created_at
, expires_at
, consumed_at
FROM user_email_confirmation_codes
WHERE code = $1
AND user_email_id = $2
"#,
code,
Uuid::from(user_email.id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into_verification(clock)))
}
#[tracing::instrument(
name = "db.user_email.consume_verification_code",
skip_all,
fields(
db.statement,
%user_email_verification.id,
user_email.id = %user_email_verification.user_email_id,
),
err,
)]
async fn consume_verification_code(
&mut self,
clock: &Clock,
mut user_email_verification: UserEmailVerification,
) -> Result<UserEmailVerification, Self::Error> {
if !matches!(
user_email_verification.state,
UserEmailVerificationState::Valid
) {
return Err(DatabaseError::invalid_operation());
}
let consumed_at = clock.now();
sqlx::query!(
r#"
UPDATE user_email_confirmation_codes
SET consumed_at = $2
WHERE user_email_confirmation_code_id = $1
"#,
Uuid::from(user_email_verification.id),
consumed_at
)
.traced()
.execute(&mut *self.conn)
.await?;
user_email_verification.state =
UserEmailVerificationState::AlreadyUsed { when: consumed_at };
Ok(user_email_verification)
}
}

View File

@ -0,0 +1,203 @@
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::User;
use mas_storage::{user::UserRepository, Clock};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{tracing::ExecuteExt, DatabaseError, LookupResultExt};
mod email;
mod password;
mod session;
#[cfg(test)]
mod tests;
pub use self::{
email::PgUserEmailRepository, password::PgUserPasswordRepository,
session::PgBrowserSessionRepository,
};
pub struct PgUserRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgUserRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
#[derive(Debug, Clone)]
struct UserLookup {
user_id: Uuid,
username: String,
primary_user_email_id: Option<Uuid>,
#[allow(dead_code)]
created_at: DateTime<Utc>,
}
impl From<UserLookup> for User {
fn from(value: UserLookup) -> Self {
let id = value.user_id.into();
Self {
id,
username: value.username,
sub: id.to_string(),
primary_user_email_id: value.primary_user_email_id.map(Into::into),
}
}
}
#[async_trait]
impl<'c> UserRepository for PgUserRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.user.lookup",
skip_all,
fields(
db.statement,
user.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<User>, Self::Error> {
let res = sqlx::query_as!(
UserLookup,
r#"
SELECT user_id
, username
, primary_user_email_id
, created_at
FROM users
WHERE user_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.user.find_by_username",
skip_all,
fields(
db.statement,
user.username = username,
),
err,
)]
async fn find_by_username(&mut self, username: &str) -> Result<Option<User>, Self::Error> {
let res = sqlx::query_as!(
UserLookup,
r#"
SELECT user_id
, username
, primary_user_email_id
, created_at
FROM users
WHERE username = $1
"#,
username,
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.user.add",
skip_all,
fields(
db.statement,
user.username = username,
user.id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
username: String,
) -> Result<User, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("user.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO users (user_id, username, created_at)
VALUES ($1, $2, $3)
"#,
Uuid::from(id),
username,
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(User {
id,
username,
sub: id.to_string(),
primary_user_email_id: None,
})
}
#[tracing::instrument(
name = "db.user.exists",
skip_all,
fields(
db.statement,
user.username = username,
),
err,
)]
async fn exists(&mut self, username: &str) -> Result<bool, Self::Error> {
let exists = sqlx::query_scalar!(
r#"
SELECT EXISTS(
SELECT 1 FROM users WHERE username = $1
) AS "exists!"
"#,
username
)
.traced()
.fetch_one(&mut *self.conn)
.await?;
Ok(exists)
}
}

View File

@ -0,0 +1,155 @@
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{Password, User};
use mas_storage::{user::UserPasswordRepository, Clock};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, LookupResultExt};
pub struct PgUserPasswordRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgUserPasswordRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
struct UserPasswordLookup {
user_password_id: Uuid,
hashed_password: String,
version: i32,
upgraded_from_id: Option<Uuid>,
created_at: DateTime<Utc>,
}
#[async_trait]
impl<'c> UserPasswordRepository for PgUserPasswordRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.user_password.active",
skip_all,
fields(
db.statement,
%user.id,
%user.username,
),
err,
)]
async fn active(&mut self, user: &User) -> Result<Option<Password>, Self::Error> {
let res = sqlx::query_as!(
UserPasswordLookup,
r#"
SELECT up.user_password_id
, up.hashed_password
, up.version
, up.upgraded_from_id
, up.created_at
FROM user_passwords up
WHERE up.user_id = $1
ORDER BY up.created_at DESC
LIMIT 1
"#,
Uuid::from(user.id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
let id = Ulid::from(res.user_password_id);
let version = res.version.try_into().map_err(|e| {
DatabaseInconsistencyError::on("user_passwords")
.column("version")
.row(id)
.source(e)
})?;
let upgraded_from_id = res.upgraded_from_id.map(Ulid::from);
let created_at = res.created_at;
let hashed_password = res.hashed_password;
Ok(Some(Password {
id,
hashed_password,
version,
upgraded_from_id,
created_at,
}))
}
#[tracing::instrument(
name = "db.user_password.add",
skip_all,
fields(
db.statement,
%user.id,
%user.username,
user_password.id,
user_password.version = version,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
user: &User,
version: u16,
hashed_password: String,
upgraded_from: Option<&Password>,
) -> Result<Password, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("user_password.id", tracing::field::display(id));
let upgraded_from_id = upgraded_from.map(|p| p.id);
sqlx::query!(
r#"
INSERT INTO user_passwords
(user_password_id, user_id, hashed_password, version, upgraded_from_id, created_at)
VALUES ($1, $2, $3, $4, $5, $6)
"#,
Uuid::from(id),
Uuid::from(user.id),
hashed_password,
i32::from(version),
upgraded_from_id.map(Uuid::from),
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(Password {
id,
hashed_password,
version,
upgraded_from_id,
created_at,
})
}
}

View File

@ -0,0 +1,375 @@
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{Authentication, BrowserSession, Password, UpstreamOAuthLink, User};
use mas_storage::{user::BrowserSessionRepository, Clock, Page, Pagination};
use rand::RngCore;
use sqlx::{PgConnection, QueryBuilder};
use ulid::Ulid;
use uuid::Uuid;
use crate::{
pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError,
LookupResultExt,
};
pub struct PgBrowserSessionRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgBrowserSessionRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
#[derive(sqlx::FromRow)]
struct SessionLookup {
user_session_id: Uuid,
user_session_created_at: DateTime<Utc>,
user_session_finished_at: Option<DateTime<Utc>>,
user_id: Uuid,
user_username: String,
user_primary_user_email_id: Option<Uuid>,
last_authentication_id: Option<Uuid>,
last_authd_at: Option<DateTime<Utc>>,
}
impl TryFrom<SessionLookup> for BrowserSession {
type Error = DatabaseInconsistencyError;
fn try_from(value: SessionLookup) -> Result<Self, Self::Error> {
let id = Ulid::from(value.user_id);
let user = User {
id,
username: value.user_username,
sub: id.to_string(),
primary_user_email_id: value.user_primary_user_email_id.map(Into::into),
};
let last_authentication = match (value.last_authentication_id, value.last_authd_at) {
(Some(id), Some(created_at)) => Some(Authentication {
id: id.into(),
created_at,
}),
(None, None) => None,
_ => {
return Err(DatabaseInconsistencyError::on(
"user_session_authentications",
))
}
};
Ok(BrowserSession {
id: value.user_session_id.into(),
user,
created_at: value.user_session_created_at,
finished_at: value.user_session_finished_at,
last_authentication,
})
}
}
#[async_trait]
impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.browser_session.lookup",
skip_all,
fields(
db.statement,
user_session.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<BrowserSession>, Self::Error> {
let res = sqlx::query_as!(
SessionLookup,
r#"
SELECT s.user_session_id
, s.created_at AS "user_session_created_at"
, s.finished_at AS "user_session_finished_at"
, u.user_id
, u.username AS "user_username"
, u.primary_user_email_id AS "user_primary_user_email_id"
, a.user_session_authentication_id AS "last_authentication_id?"
, a.created_at AS "last_authd_at?"
FROM user_sessions s
INNER JOIN users u
USING (user_id)
LEFT JOIN user_session_authentications a
USING (user_session_id)
WHERE s.user_session_id = $1
ORDER BY a.created_at DESC
LIMIT 1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
name = "db.browser_session.add",
skip_all,
fields(
db.statement,
%user.id,
user_session.id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
user: &User,
) -> Result<BrowserSession, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("user_session.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO user_sessions (user_session_id, user_id, created_at)
VALUES ($1, $2, $3)
"#,
Uuid::from(id),
Uuid::from(user.id),
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
let session = BrowserSession {
id,
// XXX
user: user.clone(),
created_at,
finished_at: None,
last_authentication: None,
};
Ok(session)
}
#[tracing::instrument(
name = "db.browser_session.finish",
skip_all,
fields(
db.statement,
%user_session.id,
),
err,
)]
async fn finish(
&mut self,
clock: &Clock,
mut user_session: BrowserSession,
) -> Result<BrowserSession, Self::Error> {
let finished_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE user_sessions
SET finished_at = $1
WHERE user_session_id = $2
"#,
finished_at,
Uuid::from(user_session.id),
)
.traced()
.execute(&mut *self.conn)
.await?;
user_session.finished_at = Some(finished_at);
DatabaseError::ensure_affected_rows(&res, 1)?;
Ok(user_session)
}
#[tracing::instrument(
name = "db.browser_session.list_active_paginated",
skip_all,
fields(
db.statement,
%user.id,
),
err,
)]
async fn list_active_paginated(
&mut self,
user: &User,
pagination: Pagination,
) -> Result<Page<BrowserSession>, Self::Error> {
// TODO: ordering of last authentication is wrong
let mut query = QueryBuilder::new(
r#"
SELECT DISTINCT ON (s.user_session_id)
s.user_session_id,
u.user_id,
u.username,
s.created_at,
a.user_session_authentication_id AS "last_authentication_id",
a.created_at AS "last_authd_at",
FROM user_sessions s
INNER JOIN users u
USING (user_id)
LEFT JOIN user_session_authentications a
USING (user_session_id)
"#,
);
query
.push(" WHERE s.finished_at IS NULL AND s.user_id = ")
.push_bind(Uuid::from(user.id))
.generate_pagination("s.user_session_id", pagination);
let edges: Vec<SessionLookup> = query
.build_query_as()
.traced()
.fetch_all(&mut *self.conn)
.await?;
let page = pagination
.process(edges)
.try_map(BrowserSession::try_from)?;
Ok(page)
}
#[tracing::instrument(
name = "db.browser_session.count_active",
skip_all,
fields(
db.statement,
%user.id,
),
err,
)]
async fn count_active(&mut self, user: &User) -> Result<usize, Self::Error> {
let res = sqlx::query_scalar!(
r#"
SELECT COUNT(*) as "count!"
FROM user_sessions s
WHERE s.user_id = $1 AND s.finished_at IS NULL
"#,
Uuid::from(user.id),
)
.traced()
.fetch_one(&mut *self.conn)
.await?;
res.try_into().map_err(DatabaseError::to_invalid_operation)
}
#[tracing::instrument(
name = "db.browser_session.authenticate_with_password",
skip_all,
fields(
db.statement,
%user_session.id,
%user_password.id,
user_session_authentication.id,
),
err,
)]
async fn authenticate_with_password(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
mut user_session: BrowserSession,
user_password: &Password,
) -> Result<BrowserSession, Self::Error> {
let _user_password = user_password;
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record(
"user_session_authentication.id",
tracing::field::display(id),
);
sqlx::query!(
r#"
INSERT INTO user_session_authentications
(user_session_authentication_id, user_session_id, created_at)
VALUES ($1, $2, $3)
"#,
Uuid::from(id),
Uuid::from(user_session.id),
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
user_session.last_authentication = Some(Authentication { id, created_at });
Ok(user_session)
}
#[tracing::instrument(
name = "db.browser_session.authenticate_with_upstream",
skip_all,
fields(
db.statement,
%user_session.id,
%upstream_oauth_link.id,
user_session_authentication.id,
),
err,
)]
async fn authenticate_with_upstream(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
mut user_session: BrowserSession,
upstream_oauth_link: &UpstreamOAuthLink,
) -> Result<BrowserSession, Self::Error> {
let _upstream_oauth_link = upstream_oauth_link;
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record(
"user_session_authentication.id",
tracing::field::display(id),
);
sqlx::query!(
r#"
INSERT INTO user_session_authentications
(user_session_authentication_id, user_session_id, created_at)
VALUES ($1, $2, $3)
"#,
Uuid::from(id),
Uuid::from(user_session.id),
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
user_session.last_authentication = Some(Authentication { id, created_at });
Ok(user_session)
}
}

View File

@ -13,14 +13,15 @@
// limitations under the License.
use chrono::Duration;
use mas_storage::{
user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository},
Clock, Repository,
};
use rand::SeedableRng;
use rand_chacha::ChaChaRng;
use sqlx::PgPool;
use crate::{
user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository},
Clock, PgRepository, Repository,
};
use crate::PgRepository;
/// Test the user repository, by adding and looking up a user
#[sqlx::test(migrator = "crate::MIGRATOR")]
@ -88,7 +89,7 @@ async fn test_user_email_repo(pool: PgPool) {
// The user email should not exist yet
assert!(repo
.user_email()
.find(&user, &EMAIL)
.find(&user, EMAIL)
.await
.unwrap()
.is_none());
@ -109,7 +110,7 @@ async fn test_user_email_repo(pool: PgPool) {
assert!(repo
.user_email()
.find(&user, &EMAIL)
.find(&user, EMAIL)
.await
.unwrap()
.is_some());
@ -179,7 +180,7 @@ async fn test_user_email_repo(pool: PgPool) {
// Reload the user_email
let user_email = repo
.user_email()
.find(&user, &EMAIL)
.find(&user, EMAIL)
.await
.unwrap()
.expect("user email was not found");

View File

@ -7,18 +7,12 @@ license = "Apache-2.0"
[dependencies]
async-trait = "0.1.60"
sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "offline", "json", "uuid"] }
chrono = { version = "0.4.23", features = ["serde"] }
serde = { version = "1.0.152", features = ["derive"] }
serde_json = "1.0.91"
chrono = "0.4.23"
thiserror = "1.0.38"
tracing = "0.1.37"
rand = "0.8.5"
rand_chacha = "0.3.1"
url = { version = "2.3.1", features = ["serde"] }
uuid = "1.2.2"
ulid = { version = "1.0.0", features = ["uuid", "serde"] }
url = "2.3.1"
ulid = "1.0.0"
oauth2-types = { path = "../oauth2-types" }
mas-data-model = { path = "../data-model" }

View File

@ -13,14 +13,12 @@
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};
use chrono::Duration;
use mas_data_model::{CompatAccessToken, CompatSession};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt};
use crate::Clock;
#[async_trait]
pub trait CompatAccessTokenRepository: Send + Sync {
@ -52,195 +50,3 @@ pub trait CompatAccessTokenRepository: Send + Sync {
compat_access_token: CompatAccessToken,
) -> Result<CompatAccessToken, Self::Error>;
}
pub struct PgCompatAccessTokenRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgCompatAccessTokenRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
struct CompatAccessTokenLookup {
compat_access_token_id: Uuid,
access_token: String,
created_at: DateTime<Utc>,
expires_at: Option<DateTime<Utc>>,
compat_session_id: Uuid,
}
impl From<CompatAccessTokenLookup> for CompatAccessToken {
fn from(value: CompatAccessTokenLookup) -> Self {
Self {
id: value.compat_access_token_id.into(),
session_id: value.compat_session_id.into(),
token: value.access_token,
created_at: value.created_at,
expires_at: value.expires_at,
}
}
}
#[async_trait]
impl<'c> CompatAccessTokenRepository for PgCompatAccessTokenRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.compat_access_token.lookup",
skip_all,
fields(
db.statement,
compat_session.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatAccessToken>, Self::Error> {
let res = sqlx::query_as!(
CompatAccessTokenLookup,
r#"
SELECT compat_access_token_id
, access_token
, created_at
, expires_at
, compat_session_id
FROM compat_access_tokens
WHERE compat_access_token_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.compat_access_token.find_by_token",
skip_all,
fields(
db.statement,
),
err,
)]
async fn find_by_token(
&mut self,
access_token: &str,
) -> Result<Option<CompatAccessToken>, Self::Error> {
let res = sqlx::query_as!(
CompatAccessTokenLookup,
r#"
SELECT compat_access_token_id
, access_token
, created_at
, expires_at
, compat_session_id
FROM compat_access_tokens
WHERE access_token = $1
"#,
access_token,
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.compat_access_token.add",
skip_all,
fields(
db.statement,
compat_access_token.id,
%compat_session.id,
user.id = %compat_session.user_id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
compat_session: &CompatSession,
token: String,
expires_after: Option<Duration>,
) -> Result<CompatAccessToken, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("compat_access_token.id", tracing::field::display(id));
let expires_at = expires_after.map(|expires_after| created_at + expires_after);
sqlx::query!(
r#"
INSERT INTO compat_access_tokens
(compat_access_token_id, compat_session_id, access_token, created_at, expires_at)
VALUES ($1, $2, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(compat_session.id),
token,
created_at,
expires_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(CompatAccessToken {
id,
session_id: compat_session.id,
token,
created_at,
expires_at,
})
}
#[tracing::instrument(
name = "db.compat_access_token.expire",
skip_all,
fields(
db.statement,
%compat_access_token.id,
compat_session.id = %compat_access_token.session_id,
),
err,
)]
async fn expire(
&mut self,
clock: &Clock,
mut compat_access_token: CompatAccessToken,
) -> Result<CompatAccessToken, Self::Error> {
let expires_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE compat_access_tokens
SET expires_at = $2
WHERE compat_access_token_id = $1
"#,
Uuid::from(compat_access_token.id),
expires_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
compat_access_token.expires_at = Some(expires_at);
Ok(compat_access_token)
}
}

View File

@ -18,301 +18,6 @@ mod session;
mod sso_login;
pub use self::{
access_token::{CompatAccessTokenRepository, PgCompatAccessTokenRepository},
refresh_token::{CompatRefreshTokenRepository, PgCompatRefreshTokenRepository},
session::{CompatSessionRepository, PgCompatSessionRepository},
sso_login::{CompatSsoLoginRepository, PgCompatSsoLoginRepository},
access_token::CompatAccessTokenRepository, refresh_token::CompatRefreshTokenRepository,
session::CompatSessionRepository, sso_login::CompatSsoLoginRepository,
};
#[cfg(test)]
mod tests {
use chrono::Duration;
use mas_data_model::Device;
use rand::SeedableRng;
use rand_chacha::ChaChaRng;
use sqlx::PgPool;
use super::*;
use crate::{user::UserRepository, Clock, PgRepository, Repository};
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_session_repository(pool: PgPool) {
const FIRST_TOKEN: &str = "first_access_token";
const SECOND_TOKEN: &str = "second_access_token";
let mut rng = ChaChaRng::seed_from_u64(42);
let clock = Clock::mock();
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
// Create a user
let user = repo
.user()
.add(&mut rng, &clock, "john".to_owned())
.await
.unwrap();
// Start a compat session for that user
let device = Device::generate(&mut rng);
let device_str = device.as_str().to_owned();
let session = repo
.compat_session()
.add(&mut rng, &clock, &user, device)
.await
.unwrap();
assert_eq!(session.user_id, user.id);
assert_eq!(session.device.as_str(), device_str);
assert!(session.is_valid());
assert!(!session.is_finished());
// Lookup the session and check it didn't change
let session_lookup = repo
.compat_session()
.lookup(session.id)
.await
.unwrap()
.expect("compat session not found");
assert_eq!(session_lookup.id, session.id);
assert_eq!(session_lookup.user_id, user.id);
assert_eq!(session_lookup.device.as_str(), device_str);
assert!(session_lookup.is_valid());
assert!(!session_lookup.is_finished());
// Finish the session
let session = repo.compat_session().finish(&clock, session).await.unwrap();
assert!(!session.is_valid());
assert!(session.is_finished());
// Reload the session and check again
let session_lookup = repo
.compat_session()
.lookup(session.id)
.await
.unwrap()
.expect("compat session not found");
assert!(!session_lookup.is_valid());
assert!(session_lookup.is_finished());
}
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_access_token_repository(pool: PgPool) {
const FIRST_TOKEN: &str = "first_access_token";
const SECOND_TOKEN: &str = "second_access_token";
let mut rng = ChaChaRng::seed_from_u64(42);
let clock = Clock::mock();
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
// Create a user
let user = repo
.user()
.add(&mut rng, &clock, "john".to_owned())
.await
.unwrap();
// Start a compat session for that user
let device = Device::generate(&mut rng);
let session = repo
.compat_session()
.add(&mut rng, &clock, &user, device)
.await
.unwrap();
// Add an access token to that session
let token = repo
.compat_access_token()
.add(
&mut rng,
&clock,
&session,
FIRST_TOKEN.to_owned(),
Some(Duration::minutes(1)),
)
.await
.unwrap();
assert_eq!(token.session_id, session.id);
assert_eq!(token.token, FIRST_TOKEN);
// Commit the txn and grab a new transaction, to test a conflict
repo.save().await.unwrap();
{
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
// Adding the same token a second time should conflict
assert!(repo
.compat_access_token()
.add(
&mut rng,
&clock,
&session,
FIRST_TOKEN.to_owned(),
Some(Duration::minutes(1)),
)
.await
.is_err());
repo.cancel().await.unwrap();
}
// Grab a new repo
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
// Looking up via ID works
let token_lookup = repo
.compat_access_token()
.lookup(token.id)
.await
.unwrap()
.expect("compat access token not found");
assert_eq!(token.id, token_lookup.id);
assert_eq!(token_lookup.session_id, session.id);
// Looking up via the token value works
let token_lookup = repo
.compat_access_token()
.find_by_token(FIRST_TOKEN)
.await
.unwrap()
.expect("compat access token not found");
assert_eq!(token.id, token_lookup.id);
assert_eq!(token_lookup.session_id, session.id);
// Token is currently valid
assert!(token.is_valid(clock.now()));
clock.advance(Duration::minutes(1));
// Token should have expired
assert!(!token.is_valid(clock.now()));
// Add a second access token, this time without expiration
let token = repo
.compat_access_token()
.add(&mut rng, &clock, &session, SECOND_TOKEN.to_owned(), None)
.await
.unwrap();
assert_eq!(token.session_id, session.id);
assert_eq!(token.token, SECOND_TOKEN);
// Token is currently valid
assert!(token.is_valid(clock.now()));
// Make it expire
repo.compat_access_token()
.expire(&clock, token)
.await
.unwrap();
// Reload it
let token = repo
.compat_access_token()
.find_by_token(SECOND_TOKEN)
.await
.unwrap()
.expect("compat access token not found");
// Token is not valid anymore
assert!(!token.is_valid(clock.now()));
repo.save().await.unwrap();
}
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_refresh_token_repository(pool: PgPool) {
const ACCESS_TOKEN: &str = "access_token";
const REFRESH_TOKEN: &str = "refresh_token";
let mut rng = ChaChaRng::seed_from_u64(42);
let clock = Clock::mock();
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
// Create a user
let user = repo
.user()
.add(&mut rng, &clock, "john".to_owned())
.await
.unwrap();
// Start a compat session for that user
let device = Device::generate(&mut rng);
let session = repo
.compat_session()
.add(&mut rng, &clock, &user, device)
.await
.unwrap();
// Add an access token to that session
let access_token = repo
.compat_access_token()
.add(&mut rng, &clock, &session, ACCESS_TOKEN.to_owned(), None)
.await
.unwrap();
let refresh_token = repo
.compat_refresh_token()
.add(
&mut rng,
&clock,
&session,
&access_token,
REFRESH_TOKEN.to_owned(),
)
.await
.unwrap();
assert_eq!(refresh_token.session_id, session.id);
assert_eq!(refresh_token.access_token_id, access_token.id);
assert_eq!(refresh_token.token, REFRESH_TOKEN);
assert!(refresh_token.is_valid());
assert!(!refresh_token.is_consumed());
// Look it up by ID and check everything matches
let refresh_token_lookup = repo
.compat_refresh_token()
.lookup(refresh_token.id)
.await
.unwrap()
.expect("refresh token not found");
assert_eq!(refresh_token_lookup.id, refresh_token.id);
assert_eq!(refresh_token_lookup.session_id, session.id);
assert_eq!(refresh_token_lookup.access_token_id, access_token.id);
assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN);
assert!(refresh_token_lookup.is_valid());
assert!(!refresh_token_lookup.is_consumed());
// Look it up by token and check everything matches
let refresh_token_lookup = repo
.compat_refresh_token()
.find_by_token(REFRESH_TOKEN)
.await
.unwrap()
.expect("refresh token not found");
assert_eq!(refresh_token_lookup.id, refresh_token.id);
assert_eq!(refresh_token_lookup.session_id, session.id);
assert_eq!(refresh_token_lookup.access_token_id, access_token.id);
assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN);
assert!(refresh_token_lookup.is_valid());
assert!(!refresh_token_lookup.is_consumed());
// Consume it
let refresh_token = repo
.compat_refresh_token()
.consume(&clock, refresh_token)
.await
.unwrap();
assert!(!refresh_token.is_valid());
assert!(refresh_token.is_consumed());
// Reload it and check again
let refresh_token_lookup = repo
.compat_refresh_token()
.find_by_token(REFRESH_TOKEN)
.await
.unwrap()
.expect("refresh token not found");
assert!(!refresh_token_lookup.is_valid());
assert!(refresh_token_lookup.is_consumed());
// Consuming it again should not work
assert!(repo
.compat_refresh_token()
.consume(&clock, refresh_token)
.await
.is_err());
repo.save().await.unwrap();
}
}

View File

@ -13,16 +13,11 @@
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{
CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession,
};
use mas_data_model::{CompatAccessToken, CompatRefreshToken, CompatSession};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt};
use crate::Clock;
#[async_trait]
pub trait CompatRefreshTokenRepository: Send + Sync {
@ -54,207 +49,3 @@ pub trait CompatRefreshTokenRepository: Send + Sync {
compat_refresh_token: CompatRefreshToken,
) -> Result<CompatRefreshToken, Self::Error>;
}
pub struct PgCompatRefreshTokenRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgCompatRefreshTokenRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
struct CompatRefreshTokenLookup {
compat_refresh_token_id: Uuid,
refresh_token: String,
created_at: DateTime<Utc>,
consumed_at: Option<DateTime<Utc>>,
compat_access_token_id: Uuid,
compat_session_id: Uuid,
}
impl From<CompatRefreshTokenLookup> for CompatRefreshToken {
fn from(value: CompatRefreshTokenLookup) -> Self {
let state = match value.consumed_at {
Some(consumed_at) => CompatRefreshTokenState::Consumed { consumed_at },
None => CompatRefreshTokenState::Valid,
};
Self {
id: value.compat_refresh_token_id.into(),
state,
session_id: value.compat_session_id.into(),
token: value.refresh_token,
created_at: value.created_at,
access_token_id: value.compat_access_token_id.into(),
}
}
}
#[async_trait]
impl<'c> CompatRefreshTokenRepository for PgCompatRefreshTokenRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.compat_refresh_token.lookup",
skip_all,
fields(
db.statement,
compat_refresh_token.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatRefreshToken>, Self::Error> {
let res = sqlx::query_as!(
CompatRefreshTokenLookup,
r#"
SELECT compat_refresh_token_id
, refresh_token
, created_at
, consumed_at
, compat_session_id
, compat_access_token_id
FROM compat_refresh_tokens
WHERE compat_refresh_token_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.compat_refresh_token.find_by_token",
skip_all,
fields(
db.statement,
),
err,
)]
async fn find_by_token(
&mut self,
refresh_token: &str,
) -> Result<Option<CompatRefreshToken>, Self::Error> {
let res = sqlx::query_as!(
CompatRefreshTokenLookup,
r#"
SELECT compat_refresh_token_id
, refresh_token
, created_at
, consumed_at
, compat_session_id
, compat_access_token_id
FROM compat_refresh_tokens
WHERE refresh_token = $1
"#,
refresh_token,
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.compat_refresh_token.add",
skip_all,
fields(
db.statement,
compat_refresh_token.id,
%compat_session.id,
user.id = %compat_session.user_id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
compat_session: &CompatSession,
compat_access_token: &CompatAccessToken,
token: String,
) -> Result<CompatRefreshToken, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO compat_refresh_tokens
(compat_refresh_token_id, compat_session_id,
compat_access_token_id, refresh_token, created_at)
VALUES ($1, $2, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(compat_session.id),
Uuid::from(compat_access_token.id),
token,
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(CompatRefreshToken {
id,
state: CompatRefreshTokenState::default(),
session_id: compat_session.id,
access_token_id: compat_access_token.id,
token,
created_at,
})
}
#[tracing::instrument(
name = "db.compat_refresh_token.consume",
skip_all,
fields(
db.statement,
%compat_refresh_token.id,
compat_session.id = %compat_refresh_token.session_id,
),
err,
)]
async fn consume(
&mut self,
clock: &Clock,
compat_refresh_token: CompatRefreshToken,
) -> Result<CompatRefreshToken, Self::Error> {
let consumed_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE compat_refresh_tokens
SET consumed_at = $2
WHERE compat_refresh_token_id = $1
"#,
Uuid::from(compat_refresh_token.id),
consumed_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
let compat_refresh_token = compat_refresh_token
.consume(consumed_at)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(compat_refresh_token)
}
}

View File

@ -13,16 +13,11 @@
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{CompatSession, CompatSessionState, Device, User};
use mas_data_model::{CompatSession, Device, User};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{
tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
};
use crate::Clock;
#[async_trait]
pub trait CompatSessionRepository: Send + Sync {
@ -47,174 +42,3 @@ pub trait CompatSessionRepository: Send + Sync {
compat_session: CompatSession,
) -> Result<CompatSession, Self::Error>;
}
pub struct PgCompatSessionRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgCompatSessionRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
struct CompatSessionLookup {
compat_session_id: Uuid,
device_id: String,
user_id: Uuid,
created_at: DateTime<Utc>,
finished_at: Option<DateTime<Utc>>,
}
impl TryFrom<CompatSessionLookup> for CompatSession {
type Error = DatabaseInconsistencyError;
fn try_from(value: CompatSessionLookup) -> Result<Self, Self::Error> {
let id = value.compat_session_id.into();
let device = Device::try_from(value.device_id).map_err(|e| {
DatabaseInconsistencyError::on("compat_sessions")
.column("device_id")
.row(id)
.source(e)
})?;
let state = match value.finished_at {
None => CompatSessionState::Valid,
Some(finished_at) => CompatSessionState::Finished { finished_at },
};
let session = CompatSession {
id,
state,
user_id: value.user_id.into(),
device,
created_at: value.created_at,
};
Ok(session)
}
}
#[async_trait]
impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.compat_session.lookup",
skip_all,
fields(
db.statement,
compat_session.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSession>, Self::Error> {
let res = sqlx::query_as!(
CompatSessionLookup,
r#"
SELECT compat_session_id
, device_id
, user_id
, created_at
, finished_at
FROM compat_sessions
WHERE compat_session_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
name = "db.compat_session.add",
skip_all,
fields(
db.statement,
compat_session.id,
%user.id,
%user.username,
compat_session.device.id = device.as_str(),
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
user: &User,
device: Device,
) -> Result<CompatSession, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("compat_session.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at)
VALUES ($1, $2, $3, $4)
"#,
Uuid::from(id),
Uuid::from(user.id),
device.as_str(),
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(CompatSession {
id,
state: CompatSessionState::default(),
user_id: user.id,
device,
created_at,
})
}
#[tracing::instrument(
name = "db.compat_session.finish",
skip_all,
fields(
db.statement,
%compat_session.id,
user.id = %compat_session.user_id,
compat_session.device.id = compat_session.device.as_str(),
),
err,
)]
async fn finish(
&mut self,
clock: &Clock,
compat_session: CompatSession,
) -> Result<CompatSession, Self::Error> {
let finished_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE compat_sessions cs
SET finished_at = $2
WHERE compat_session_id = $1
"#,
Uuid::from(compat_session.id),
finished_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
let compat_session = compat_session
.finish(finished_at)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(compat_session)
}
}

View File

@ -13,19 +13,12 @@
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{CompatSession, CompatSsoLogin, CompatSsoLoginState, User};
use mas_data_model::{CompatSession, CompatSsoLogin, User};
use rand::RngCore;
use sqlx::{PgConnection, QueryBuilder};
use ulid::Ulid;
use url::Url;
use uuid::Uuid;
use crate::{
pagination::{Page, QueryBuilderExt},
tracing::ExecuteExt,
Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Pagination,
};
use crate::{pagination::Page, Clock, Pagination};
#[async_trait]
pub trait CompatSsoLoginRepository: Send + Sync {
@ -71,317 +64,3 @@ pub trait CompatSsoLoginRepository: Send + Sync {
pagination: Pagination,
) -> Result<Page<CompatSsoLogin>, Self::Error>;
}
pub struct PgCompatSsoLoginRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgCompatSsoLoginRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
#[derive(sqlx::FromRow)]
struct CompatSsoLoginLookup {
compat_sso_login_id: Uuid,
login_token: String,
redirect_uri: String,
created_at: DateTime<Utc>,
fulfilled_at: Option<DateTime<Utc>>,
exchanged_at: Option<DateTime<Utc>>,
compat_session_id: Option<Uuid>,
}
impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
type Error = DatabaseInconsistencyError;
fn try_from(res: CompatSsoLoginLookup) -> Result<Self, Self::Error> {
let id = res.compat_sso_login_id.into();
let redirect_uri = Url::parse(&res.redirect_uri).map_err(|e| {
DatabaseInconsistencyError::on("compat_sso_logins")
.column("redirect_uri")
.row(id)
.source(e)
})?;
let state = match (res.fulfilled_at, res.exchanged_at, res.compat_session_id) {
(None, None, None) => CompatSsoLoginState::Pending,
(Some(fulfilled_at), None, Some(session_id)) => CompatSsoLoginState::Fulfilled {
fulfilled_at,
session_id: session_id.into(),
},
(Some(fulfilled_at), Some(exchanged_at), Some(session_id)) => {
CompatSsoLoginState::Exchanged {
fulfilled_at,
exchanged_at,
session_id: session_id.into(),
}
}
_ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
};
Ok(CompatSsoLogin {
id,
login_token: res.login_token,
redirect_uri,
created_at: res.created_at,
state,
})
}
}
#[async_trait]
impl<'c> CompatSsoLoginRepository for PgCompatSsoLoginRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.compat_sso_login.lookup",
skip_all,
fields(
db.statement,
compat_sso_login.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSsoLogin>, Self::Error> {
let res = sqlx::query_as!(
CompatSsoLoginLookup,
r#"
SELECT compat_sso_login_id
, login_token
, redirect_uri
, created_at
, fulfilled_at
, exchanged_at
, compat_session_id
FROM compat_sso_logins
WHERE compat_sso_login_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
name = "db.compat_sso_login.find_by_token",
skip_all,
fields(
db.statement,
),
err,
)]
async fn find_by_token(
&mut self,
login_token: &str,
) -> Result<Option<CompatSsoLogin>, Self::Error> {
let res = sqlx::query_as!(
CompatSsoLoginLookup,
r#"
SELECT compat_sso_login_id
, login_token
, redirect_uri
, created_at
, fulfilled_at
, exchanged_at
, compat_session_id
FROM compat_sso_logins
WHERE login_token = $1
"#,
login_token,
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
name = "db.compat_sso_login.add",
skip_all,
fields(
db.statement,
compat_sso_login.id,
compat_sso_login.redirect_uri = %redirect_uri,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
login_token: String,
redirect_uri: Url,
) -> Result<CompatSsoLogin, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO compat_sso_logins
(compat_sso_login_id, login_token, redirect_uri, created_at)
VALUES ($1, $2, $3, $4)
"#,
Uuid::from(id),
&login_token,
redirect_uri.as_str(),
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(CompatSsoLogin {
id,
login_token,
redirect_uri,
created_at,
state: CompatSsoLoginState::default(),
})
}
#[tracing::instrument(
name = "db.compat_sso_login.fulfill",
skip_all,
fields(
db.statement,
%compat_sso_login.id,
%compat_session.id,
compat_session.device.id = compat_session.device.as_str(),
user.id = %compat_session.user_id,
),
err,
)]
async fn fulfill(
&mut self,
clock: &Clock,
compat_sso_login: CompatSsoLogin,
compat_session: &CompatSession,
) -> Result<CompatSsoLogin, Self::Error> {
let fulfilled_at = clock.now();
let compat_sso_login = compat_sso_login
.fulfill(fulfilled_at, compat_session)
.map_err(DatabaseError::to_invalid_operation)?;
let res = sqlx::query!(
r#"
UPDATE compat_sso_logins
SET
compat_session_id = $2,
fulfilled_at = $3
WHERE
compat_sso_login_id = $1
"#,
Uuid::from(compat_sso_login.id),
Uuid::from(compat_session.id),
fulfilled_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
Ok(compat_sso_login)
}
#[tracing::instrument(
name = "db.compat_sso_login.exchange",
skip_all,
fields(
db.statement,
%compat_sso_login.id,
),
err,
)]
async fn exchange(
&mut self,
clock: &Clock,
compat_sso_login: CompatSsoLogin,
) -> Result<CompatSsoLogin, Self::Error> {
let exchanged_at = clock.now();
let compat_sso_login = compat_sso_login
.exchange(exchanged_at)
.map_err(DatabaseError::to_invalid_operation)?;
let res = sqlx::query!(
r#"
UPDATE compat_sso_logins
SET
exchanged_at = $2
WHERE
compat_sso_login_id = $1
"#,
Uuid::from(compat_sso_login.id),
exchanged_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
Ok(compat_sso_login)
}
#[tracing::instrument(
name = "db.compat_sso_login.list_paginated",
skip_all,
fields(
db.statement,
%user.id,
%user.username,
),
err
)]
async fn list_paginated(
&mut self,
user: &User,
pagination: Pagination,
) -> Result<Page<CompatSsoLogin>, Self::Error> {
let mut query = QueryBuilder::new(
r#"
SELECT cl.compat_sso_login_id
, cl.login_token
, cl.redirect_uri
, cl.created_at
, cl.fulfilled_at
, cl.exchanged_at
, cl.compat_session_id
FROM compat_sso_logins cl
INNER JOIN compat_sessions ON compat_session_id
"#,
);
query
.push(" WHERE user_id = ")
.push_bind(Uuid::from(user.id))
.generate_pagination("cl.compat_sso_login_id", pagination);
let edges: Vec<CompatSsoLoginLookup> = query
.build_query_as()
.traced()
.fetch_all(&mut *self.conn)
.await?;
let page = pagination
.process(edges)
.try_map(CompatSsoLogin::try_from)?;
Ok(page)
}
}

View File

@ -1,4 +1,4 @@
// Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -29,150 +29,19 @@
)]
use chrono::{DateTime, Utc};
use pagination::InvalidPagination;
use sqlx::{migrate::Migrator, postgres::PgQueryResult};
use thiserror::Error;
use ulid::Ulid;
trait LookupResultExt {
type Output;
/// Transform a [`Result`] from a sqlx query to transform "not found" errors
/// into [`None`]
fn to_option(self) -> Result<Option<Self::Output>, sqlx::Error>;
}
impl<T> LookupResultExt for Result<T, sqlx::Error> {
type Output = T;
fn to_option(self) -> Result<Option<Self::Output>, sqlx::Error> {
match self {
Ok(v) => Ok(Some(v)),
Err(sqlx::Error::RowNotFound) => Ok(None),
Err(e) => Err(e),
}
}
}
/// Generic error when interacting with the database
#[derive(Debug, Error)]
#[error(transparent)]
pub enum DatabaseError {
/// An error which came from the database itself
Driver(#[from] sqlx::Error),
/// An error which occured while converting the data from the database
Inconsistency(#[from] DatabaseInconsistencyError),
/// An error which occured while generating the paginated query
Pagination(#[from] InvalidPagination),
/// An error which happened because the requested database operation is
/// invalid
#[error("Invalid database operation")]
InvalidOperation {
#[source]
source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
},
/// An error which happens when an operation affects not enough or too many
/// rows
#[error("Expected {expected} rows to be affected, but {actual} rows were affected")]
RowsAffected { expected: u64, actual: u64 },
}
impl DatabaseError {
pub(crate) fn ensure_affected_rows(
result: &PgQueryResult,
expected: u64,
) -> Result<(), DatabaseError> {
let actual = result.rows_affected();
if actual == expected {
Ok(())
} else {
Err(DatabaseError::RowsAffected { expected, actual })
}
}
pub(crate) fn to_invalid_operation<E: std::error::Error + Send + Sync + 'static>(e: E) -> Self {
Self::InvalidOperation {
source: Some(Box::new(e)),
}
}
pub(crate) const fn invalid_operation() -> Self {
Self::InvalidOperation { source: None }
}
}
#[derive(Debug, Error)]
pub struct DatabaseInconsistencyError {
table: &'static str,
column: Option<&'static str>,
row: Option<Ulid>,
#[source]
source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
}
impl std::fmt::Display for DatabaseInconsistencyError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Database inconsistency on table {}", self.table)?;
if let Some(column) = self.column {
write!(f, " column {column}")?;
}
if let Some(row) = self.row {
write!(f, " row {row}")?;
}
Ok(())
}
}
impl DatabaseInconsistencyError {
#[must_use]
pub(crate) const fn on(table: &'static str) -> Self {
Self {
table,
column: None,
row: None,
source: None,
}
}
#[must_use]
pub(crate) const fn column(mut self, column: &'static str) -> Self {
self.column = Some(column);
self
}
#[must_use]
pub(crate) const fn row(mut self, row: Ulid) -> Self {
self.row = Some(row);
self
}
pub(crate) fn source<E: std::error::Error + Send + Sync + 'static>(
mut self,
source: E,
) -> Self {
self.source = Some(Box::new(source));
self
}
}
#[derive(Debug, Clone, Default)]
pub struct Clock {
_private: (),
#[cfg(test)]
// #[cfg(test)]
mock: Option<std::sync::Arc<std::sync::atomic::AtomicI64>>,
}
impl Clock {
#[must_use]
pub fn now(&self) -> DateTime<Utc> {
#[cfg(test)]
// #[cfg(test)]
if let Some(timestamp) = &self.mock {
let timestamp = timestamp.load(std::sync::atomic::Ordering::Relaxed);
return chrono::TimeZone::timestamp_opt(&Utc, timestamp, 0).unwrap();
@ -183,13 +52,14 @@ impl Clock {
Utc::now()
}
#[cfg(test)]
// #[cfg(test)]
#[must_use]
pub fn mock() -> Self {
use std::sync::{atomic::AtomicI64, Arc};
use chrono::TimeZone;
let datetime = Utc.with_ymd_and_hms(2022, 01, 16, 14, 40, 0).unwrap();
let datetime = Utc.with_ymd_and_hms(2022, 1, 16, 14, 40, 0).unwrap();
let timestamp = datetime.timestamp();
Self {
@ -198,7 +68,7 @@ impl Clock {
}
}
#[cfg(test)]
// #[cfg(test)]
pub fn advance(&self, duration: chrono::Duration) {
let timestamp = self
.mock
@ -247,16 +117,12 @@ mod tests {
pub mod compat;
pub mod oauth2;
pub(crate) mod pagination;
pub mod pagination;
pub(crate) mod repository;
pub(crate) mod tracing;
pub mod upstream_oauth2;
pub mod user;
pub use self::{
pagination::Pagination,
repository::{PgRepository, Repository},
pagination::{Page, Pagination},
repository::Repository,
};
/// Embedded migrations, allowing them to run on startup
pub static MIGRATOR: Migrator = sqlx::migrate!();

View File

@ -13,14 +13,12 @@
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};
use mas_data_model::{AccessToken, AccessTokenState, Session};
use chrono::Duration;
use mas_data_model::{AccessToken, Session};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt};
use crate::Clock;
#[async_trait]
pub trait OAuth2AccessTokenRepository: Send + Sync {
@ -55,202 +53,3 @@ pub trait OAuth2AccessTokenRepository: Send + Sync {
/// Cleanup expired access tokens
async fn cleanup_expired(&mut self, clock: &Clock) -> Result<usize, Self::Error>;
}
pub struct PgOAuth2AccessTokenRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgOAuth2AccessTokenRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
struct OAuth2AccessTokenLookup {
oauth2_access_token_id: Uuid,
oauth2_session_id: Uuid,
access_token: String,
created_at: DateTime<Utc>,
expires_at: DateTime<Utc>,
revoked_at: Option<DateTime<Utc>>,
}
impl From<OAuth2AccessTokenLookup> for AccessToken {
fn from(value: OAuth2AccessTokenLookup) -> Self {
let state = match value.revoked_at {
None => AccessTokenState::Valid,
Some(revoked_at) => AccessTokenState::Revoked { revoked_at },
};
Self {
id: value.oauth2_access_token_id.into(),
state,
session_id: value.oauth2_session_id.into(),
access_token: value.access_token,
created_at: value.created_at,
expires_at: value.expires_at,
}
}
}
#[async_trait]
impl<'c> OAuth2AccessTokenRepository for PgOAuth2AccessTokenRepository<'c> {
type Error = DatabaseError;
async fn lookup(&mut self, id: Ulid) -> Result<Option<AccessToken>, Self::Error> {
let res = sqlx::query_as!(
OAuth2AccessTokenLookup,
r#"
SELECT oauth2_access_token_id
, access_token
, created_at
, expires_at
, revoked_at
, oauth2_session_id
FROM oauth2_access_tokens
WHERE oauth2_access_token_id = $1
"#,
Uuid::from(id),
)
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.oauth2_access_token.find_by_token",
skip_all,
fields(
db.statement,
),
err,
)]
async fn find_by_token(
&mut self,
access_token: &str,
) -> Result<Option<AccessToken>, Self::Error> {
let res = sqlx::query_as!(
OAuth2AccessTokenLookup,
r#"
SELECT oauth2_access_token_id
, access_token
, created_at
, expires_at
, revoked_at
, oauth2_session_id
FROM oauth2_access_tokens
WHERE access_token = $1
"#,
access_token,
)
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.oauth2_access_token.add",
skip_all,
fields(
db.statement,
%session.id,
user_session.id = %session.user_session_id,
client.id = %session.client_id,
access_token.id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
session: &Session,
access_token: String,
expires_after: Duration,
) -> Result<AccessToken, Self::Error> {
let created_at = clock.now();
let expires_at = created_at + expires_after;
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("access_token.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO oauth2_access_tokens
(oauth2_access_token_id, oauth2_session_id, access_token, created_at, expires_at)
VALUES
($1, $2, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(session.id),
&access_token,
created_at,
expires_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(AccessToken {
id,
state: AccessTokenState::default(),
access_token,
session_id: session.id,
created_at,
expires_at,
})
}
async fn revoke(
&mut self,
clock: &Clock,
access_token: AccessToken,
) -> Result<AccessToken, Self::Error> {
let revoked_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE oauth2_access_tokens
SET revoked_at = $2
WHERE oauth2_access_token_id = $1
"#,
Uuid::from(access_token.id),
revoked_at,
)
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
access_token
.revoke(revoked_at)
.map_err(DatabaseError::to_invalid_operation)
}
async fn cleanup_expired(&mut self, clock: &Clock) -> Result<usize, Self::Error> {
// Cleanup token which expired more than 15 minutes ago
let threshold = clock.now() - Duration::minutes(15);
let res = sqlx::query!(
r#"
DELETE FROM oauth2_access_tokens
WHERE expires_at < $1
"#,
threshold,
)
.execute(&mut *self.conn)
.await?;
Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
}
}

View File

@ -1,4 +1,4 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -15,21 +15,13 @@
use std::num::NonZeroU32;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{
AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Pkce, Session,
};
use mas_iana::oauth::PkceCodeChallengeMethod;
use mas_data_model::{AuthorizationCode, AuthorizationGrant, Client, Session};
use oauth2_types::{requests::ResponseMode, scope::Scope};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use url::Url;
use uuid::Uuid;
use crate::{
tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
};
use crate::Clock;
#[async_trait]
pub trait OAuth2AuthorizationGrantRepository: Send + Sync {
@ -75,482 +67,3 @@ pub trait OAuth2AuthorizationGrantRepository: Send + Sync {
authorization_grant: AuthorizationGrant,
) -> Result<AuthorizationGrant, Self::Error>;
}
pub struct PgOAuth2AuthorizationGrantRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgOAuth2AuthorizationGrantRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
#[allow(clippy::struct_excessive_bools)]
struct GrantLookup {
oauth2_authorization_grant_id: Uuid,
created_at: DateTime<Utc>,
cancelled_at: Option<DateTime<Utc>>,
fulfilled_at: Option<DateTime<Utc>>,
exchanged_at: Option<DateTime<Utc>>,
scope: String,
state: Option<String>,
nonce: Option<String>,
redirect_uri: String,
response_mode: String,
max_age: Option<i32>,
response_type_code: bool,
response_type_id_token: bool,
authorization_code: Option<String>,
code_challenge: Option<String>,
code_challenge_method: Option<String>,
requires_consent: bool,
oauth2_client_id: Uuid,
oauth2_session_id: Option<Uuid>,
}
impl TryFrom<GrantLookup> for AuthorizationGrant {
type Error = DatabaseInconsistencyError;
#[allow(clippy::too_many_lines)]
fn try_from(value: GrantLookup) -> Result<Self, Self::Error> {
let id = value.oauth2_authorization_grant_id.into();
let scope: Scope = value.scope.parse().map_err(|e| {
DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("scope")
.row(id)
.source(e)
})?;
let stage = match (
value.fulfilled_at,
value.exchanged_at,
value.cancelled_at,
value.oauth2_session_id,
) {
(None, None, None, None) => AuthorizationGrantStage::Pending,
(Some(fulfilled_at), None, None, Some(session_id)) => {
AuthorizationGrantStage::Fulfilled {
session_id: session_id.into(),
fulfilled_at,
}
}
(Some(fulfilled_at), Some(exchanged_at), None, Some(session_id)) => {
AuthorizationGrantStage::Exchanged {
session_id: session_id.into(),
fulfilled_at,
exchanged_at,
}
}
(None, None, Some(cancelled_at), None) => {
AuthorizationGrantStage::Cancelled { cancelled_at }
}
_ => {
return Err(
DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("stage")
.row(id),
);
}
};
let pkce = match (value.code_challenge, value.code_challenge_method) {
(Some(challenge), Some(challenge_method)) if challenge_method == "plain" => {
Some(Pkce {
challenge_method: PkceCodeChallengeMethod::Plain,
challenge,
})
}
(Some(challenge), Some(challenge_method)) if challenge_method == "S256" => Some(Pkce {
challenge_method: PkceCodeChallengeMethod::S256,
challenge,
}),
(None, None) => None,
_ => {
return Err(
DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("code_challenge_method")
.row(id),
);
}
};
let code: Option<AuthorizationCode> =
match (value.response_type_code, value.authorization_code, pkce) {
(false, None, None) => None,
(true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }),
_ => {
return Err(
DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("authorization_code")
.row(id),
);
}
};
let redirect_uri = value.redirect_uri.parse().map_err(|e| {
DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("redirect_uri")
.row(id)
.source(e)
})?;
let response_mode = value.response_mode.parse().map_err(|e| {
DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("response_mode")
.row(id)
.source(e)
})?;
let max_age = value
.max_age
.map(u32::try_from)
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("max_age")
.row(id)
.source(e)
})?
.map(NonZeroU32::try_from)
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("max_age")
.row(id)
.source(e)
})?;
Ok(AuthorizationGrant {
id,
stage,
client_id: value.oauth2_client_id.into(),
code,
scope,
state: value.state,
nonce: value.nonce,
max_age,
response_mode,
redirect_uri,
created_at: value.created_at,
response_type_id_token: value.response_type_id_token,
requires_consent: value.requires_consent,
})
}
}
#[async_trait]
impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.oauth2_authorization_grant.add",
skip_all,
fields(
db.statement,
grant.id,
grant.scope = %scope,
%client.id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
client: &Client,
redirect_uri: Url,
scope: Scope,
code: Option<AuthorizationCode>,
state: Option<String>,
nonce: Option<String>,
max_age: Option<NonZeroU32>,
response_mode: ResponseMode,
response_type_id_token: bool,
requires_consent: bool,
) -> Result<AuthorizationGrant, Self::Error> {
let code_challenge = code
.as_ref()
.and_then(|c| c.pkce.as_ref())
.map(|p| &p.challenge);
let code_challenge_method = code
.as_ref()
.and_then(|c| c.pkce.as_ref())
.map(|p| p.challenge_method.to_string());
// TODO: this conversion is a bit ugly
let max_age_i32 = max_age.map(|x| i32::try_from(u32::from(x)).unwrap_or(i32::MAX));
let code_str = code.as_ref().map(|c| &c.code);
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("grant.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO oauth2_authorization_grants (
oauth2_authorization_grant_id,
oauth2_client_id,
redirect_uri,
scope,
state,
nonce,
max_age,
response_mode,
code_challenge,
code_challenge_method,
response_type_code,
response_type_id_token,
authorization_code,
requires_consent,
created_at
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
"#,
Uuid::from(id),
Uuid::from(client.id),
redirect_uri.to_string(),
scope.to_string(),
state,
nonce,
max_age_i32,
response_mode.to_string(),
code_challenge,
code_challenge_method,
code.is_some(),
response_type_id_token,
code_str,
requires_consent,
created_at,
)
.execute(&mut *self.conn)
.await?;
Ok(AuthorizationGrant {
id,
stage: AuthorizationGrantStage::Pending,
code,
redirect_uri,
client_id: client.id,
scope,
state,
nonce,
max_age,
response_mode,
created_at,
response_type_id_token,
requires_consent,
})
}
#[tracing::instrument(
name = "db.oauth2_authorization_grant.lookup",
skip_all,
fields(
db.statement,
grant.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<AuthorizationGrant>, Self::Error> {
let res = sqlx::query_as!(
GrantLookup,
r#"
SELECT oauth2_authorization_grant_id
, created_at
, cancelled_at
, fulfilled_at
, exchanged_at
, scope
, state
, redirect_uri
, response_mode
, nonce
, max_age
, oauth2_client_id
, authorization_code
, response_type_code
, response_type_id_token
, code_challenge
, code_challenge_method
, requires_consent
, oauth2_session_id
FROM
oauth2_authorization_grants
WHERE oauth2_authorization_grant_id = $1
"#,
Uuid::from(id),
)
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
name = "db.oauth2_authorization_grant.find_by_code",
skip_all,
fields(
db.statement,
),
err,
)]
async fn find_by_code(
&mut self,
code: &str,
) -> Result<Option<AuthorizationGrant>, Self::Error> {
let res = sqlx::query_as!(
GrantLookup,
r#"
SELECT oauth2_authorization_grant_id
, created_at
, cancelled_at
, fulfilled_at
, exchanged_at
, scope
, state
, redirect_uri
, response_mode
, nonce
, max_age
, oauth2_client_id
, authorization_code
, response_type_code
, response_type_id_token
, code_challenge
, code_challenge_method
, requires_consent
, oauth2_session_id
FROM
oauth2_authorization_grants
WHERE authorization_code = $1
"#,
code,
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
name = "db.oauth2_authorization_grant.fulfill",
skip_all,
fields(
db.statement,
%grant.id,
client.id = %grant.client_id,
%session.id,
user_session.id = %session.user_session_id,
),
err,
)]
async fn fulfill(
&mut self,
clock: &Clock,
session: &Session,
grant: AuthorizationGrant,
) -> Result<AuthorizationGrant, Self::Error> {
let fulfilled_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE oauth2_authorization_grants
SET fulfilled_at = $2
, oauth2_session_id = $3
WHERE oauth2_authorization_grant_id = $1
"#,
Uuid::from(grant.id),
fulfilled_at,
Uuid::from(session.id),
)
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
// XXX: check affected rows & new methods
let grant = grant
.fulfill(fulfilled_at, session)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(grant)
}
#[tracing::instrument(
name = "db.oauth2_authorization_grant.exchange",
skip_all,
fields(
db.statement,
%grant.id,
client.id = %grant.client_id,
),
err,
)]
async fn exchange(
&mut self,
clock: &Clock,
grant: AuthorizationGrant,
) -> Result<AuthorizationGrant, Self::Error> {
let exchanged_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE oauth2_authorization_grants
SET exchanged_at = $2
WHERE oauth2_authorization_grant_id = $1
"#,
Uuid::from(grant.id),
exchanged_at,
)
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
let grant = grant
.exchange(exchanged_at)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(grant)
}
#[tracing::instrument(
name = "db.oauth2_authorization_grant.give_consent",
skip_all,
fields(
db.statement,
%grant.id,
client.id = %grant.client_id,
),
err,
)]
async fn give_consent(
&mut self,
mut grant: AuthorizationGrant,
) -> Result<AuthorizationGrant, Self::Error> {
sqlx::query!(
r#"
UPDATE oauth2_authorization_grants AS og
SET
requires_consent = 'f'
WHERE
og.oauth2_authorization_grant_id = $1
"#,
Uuid::from(grant.id),
)
.execute(&mut *self.conn)
.await?;
grant.requires_consent = false;
Ok(grant)
}
}

View File

@ -1,4 +1,4 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -12,33 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
collections::{BTreeMap, BTreeSet},
str::FromStr,
string::ToString,
};
use std::collections::{BTreeMap, BTreeSet};
use async_trait::async_trait;
use mas_data_model::{Client, JwksOrJwksUri, User};
use mas_iana::{
jose::JsonWebSignatureAlg,
oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod},
};
use mas_data_model::{Client, User};
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use mas_jose::jwk::PublicJsonWebKeySet;
use oauth2_types::{
requests::GrantType,
scope::{Scope, ScopeToken},
};
use oauth2_types::{requests::GrantType, scope::Scope};
use rand::{Rng, RngCore};
use sqlx::PgConnection;
use tracing::{info_span, Instrument};
use ulid::Ulid;
use url::Url;
use uuid::Uuid;
use crate::{
tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
};
use crate::Clock;
#[async_trait]
pub trait OAuth2ClientRepository: Send + Sync {
@ -107,708 +92,3 @@ pub trait OAuth2ClientRepository: Send + Sync {
scope: &Scope,
) -> Result<(), Self::Error>;
}
pub struct PgOAuth2ClientRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgOAuth2ClientRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
// XXX: response_types & contacts
#[derive(Debug)]
struct OAuth2ClientLookup {
oauth2_client_id: Uuid,
encrypted_client_secret: Option<String>,
redirect_uris: Vec<String>,
// response_types: Vec<String>,
grant_type_authorization_code: bool,
grant_type_refresh_token: bool,
// contacts: Vec<String>,
client_name: Option<String>,
logo_uri: Option<String>,
client_uri: Option<String>,
policy_uri: Option<String>,
tos_uri: Option<String>,
jwks_uri: Option<String>,
jwks: Option<serde_json::Value>,
id_token_signed_response_alg: Option<String>,
userinfo_signed_response_alg: Option<String>,
token_endpoint_auth_method: Option<String>,
token_endpoint_auth_signing_alg: Option<String>,
initiate_login_uri: Option<String>,
}
impl TryInto<Client> for OAuth2ClientLookup {
type Error = DatabaseInconsistencyError;
#[allow(clippy::too_many_lines)] // TODO: refactor some of the field parsing
fn try_into(self) -> Result<Client, Self::Error> {
let id = Ulid::from(self.oauth2_client_id);
let redirect_uris: Result<Vec<Url>, _> =
self.redirect_uris.iter().map(|s| s.parse()).collect();
let redirect_uris = redirect_uris.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("redirect_uris")
.row(id)
.source(e)
})?;
let response_types = vec![
OAuthAuthorizationEndpointResponseType::Code,
OAuthAuthorizationEndpointResponseType::IdToken,
OAuthAuthorizationEndpointResponseType::None,
];
/* XXX
let response_types: Result<Vec<OAuthAuthorizationEndpointResponseType>, _> =
self.response_types.iter().map(|s| s.parse()).collect();
let response_types = response_types.map_err(|source| ClientFetchError::ParseField {
field: "response_types",
source,
})?;
*/
let mut grant_types = Vec::new();
if self.grant_type_authorization_code {
grant_types.push(GrantType::AuthorizationCode);
}
if self.grant_type_refresh_token {
grant_types.push(GrantType::RefreshToken);
}
let logo_uri = self.logo_uri.map(|s| s.parse()).transpose().map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("logo_uri")
.row(id)
.source(e)
})?;
let client_uri = self
.client_uri
.map(|s| s.parse())
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("client_uri")
.row(id)
.source(e)
})?;
let policy_uri = self
.policy_uri
.map(|s| s.parse())
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("policy_uri")
.row(id)
.source(e)
})?;
let tos_uri = self.tos_uri.map(|s| s.parse()).transpose().map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("tos_uri")
.row(id)
.source(e)
})?;
let id_token_signed_response_alg = self
.id_token_signed_response_alg
.map(|s| s.parse())
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("id_token_signed_response_alg")
.row(id)
.source(e)
})?;
let userinfo_signed_response_alg = self
.userinfo_signed_response_alg
.map(|s| s.parse())
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("userinfo_signed_response_alg")
.row(id)
.source(e)
})?;
let token_endpoint_auth_method = self
.token_endpoint_auth_method
.map(|s| s.parse())
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("token_endpoint_auth_method")
.row(id)
.source(e)
})?;
let token_endpoint_auth_signing_alg = self
.token_endpoint_auth_signing_alg
.map(|s| s.parse())
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("token_endpoint_auth_signing_alg")
.row(id)
.source(e)
})?;
let initiate_login_uri = self
.initiate_login_uri
.map(|s| s.parse())
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("initiate_login_uri")
.row(id)
.source(e)
})?;
let jwks = match (self.jwks, self.jwks_uri) {
(None, None) => None,
(Some(jwks), None) => {
let jwks = serde_json::from_value(jwks).map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("jwks")
.row(id)
.source(e)
})?;
Some(JwksOrJwksUri::Jwks(jwks))
}
(None, Some(jwks_uri)) => {
let jwks_uri = jwks_uri.parse().map_err(|e| {
DatabaseInconsistencyError::on("oauth2_clients")
.column("jwks_uri")
.row(id)
.source(e)
})?;
Some(JwksOrJwksUri::JwksUri(jwks_uri))
}
_ => {
return Err(DatabaseInconsistencyError::on("oauth2_clients")
.column("jwks(_uri)")
.row(id))
}
};
Ok(Client {
id,
client_id: id.to_string(),
encrypted_client_secret: self.encrypted_client_secret,
redirect_uris,
response_types,
grant_types,
// contacts: self.contacts,
contacts: vec![],
client_name: self.client_name,
logo_uri,
client_uri,
policy_uri,
tos_uri,
jwks,
id_token_signed_response_alg,
userinfo_signed_response_alg,
token_endpoint_auth_method,
token_endpoint_auth_signing_alg,
initiate_login_uri,
})
}
}
#[async_trait]
impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.oauth2_client.lookup",
skip_all,
fields(
db.statement,
oauth2_client.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<Client>, Self::Error> {
let res = sqlx::query_as!(
OAuth2ClientLookup,
r#"
SELECT oauth2_client_id
, encrypted_client_secret
, ARRAY(
SELECT redirect_uri
FROM oauth2_client_redirect_uris r
WHERE r.oauth2_client_id = c.oauth2_client_id
) AS "redirect_uris!"
, grant_type_authorization_code
, grant_type_refresh_token
, client_name
, logo_uri
, client_uri
, policy_uri
, tos_uri
, jwks_uri
, jwks
, id_token_signed_response_alg
, userinfo_signed_response_alg
, token_endpoint_auth_method
, token_endpoint_auth_signing_alg
, initiate_login_uri
FROM oauth2_clients c
WHERE oauth2_client_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
name = "db.oauth2_client.load_batch",
skip_all,
fields(
db.statement,
),
err,
)]
async fn load_batch(
&mut self,
ids: BTreeSet<Ulid>,
) -> Result<BTreeMap<Ulid, Client>, Self::Error> {
let ids: Vec<Uuid> = ids.into_iter().map(Uuid::from).collect();
let res = sqlx::query_as!(
OAuth2ClientLookup,
r#"
SELECT oauth2_client_id
, encrypted_client_secret
, ARRAY(
SELECT redirect_uri
FROM oauth2_client_redirect_uris r
WHERE r.oauth2_client_id = c.oauth2_client_id
) AS "redirect_uris!"
, grant_type_authorization_code
, grant_type_refresh_token
, client_name
, logo_uri
, client_uri
, policy_uri
, tos_uri
, jwks_uri
, jwks
, id_token_signed_response_alg
, userinfo_signed_response_alg
, token_endpoint_auth_method
, token_endpoint_auth_signing_alg
, initiate_login_uri
FROM oauth2_clients c
WHERE oauth2_client_id = ANY($1::uuid[])
"#,
&ids,
)
.traced()
.fetch_all(&mut *self.conn)
.await?;
res.into_iter()
.map(|r| {
r.try_into()
.map(|c: Client| (c.id, c))
.map_err(DatabaseError::from)
})
.collect()
}
#[tracing::instrument(
name = "db.oauth2_client.add",
skip_all,
fields(
db.statement,
client.id,
client.name = client_name
),
err,
)]
#[allow(clippy::too_many_lines)]
async fn add(
&mut self,
mut rng: &mut (dyn RngCore + Send),
clock: &Clock,
redirect_uris: Vec<Url>,
encrypted_client_secret: Option<String>,
grant_types: Vec<GrantType>,
contacts: Vec<String>,
client_name: Option<String>,
logo_uri: Option<Url>,
client_uri: Option<Url>,
policy_uri: Option<Url>,
tos_uri: Option<Url>,
jwks_uri: Option<Url>,
jwks: Option<PublicJsonWebKeySet>,
id_token_signed_response_alg: Option<JsonWebSignatureAlg>,
userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
token_endpoint_auth_method: Option<OAuthClientAuthenticationMethod>,
token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
initiate_login_uri: Option<Url>,
) -> Result<Client, Self::Error> {
let now = clock.now();
let id = Ulid::from_datetime_with_source(now.into(), rng);
tracing::Span::current().record("client.id", tracing::field::display(id));
let jwks_json = jwks
.as_ref()
.map(serde_json::to_value)
.transpose()
.map_err(DatabaseError::to_invalid_operation)?;
sqlx::query!(
r#"
INSERT INTO oauth2_clients
( oauth2_client_id
, encrypted_client_secret
, grant_type_authorization_code
, grant_type_refresh_token
, client_name
, logo_uri
, client_uri
, policy_uri
, tos_uri
, jwks_uri
, jwks
, id_token_signed_response_alg
, userinfo_signed_response_alg
, token_endpoint_auth_method
, token_endpoint_auth_signing_alg
, initiate_login_uri
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
"#,
Uuid::from(id),
encrypted_client_secret,
grant_types.contains(&GrantType::AuthorizationCode),
grant_types.contains(&GrantType::RefreshToken),
client_name,
logo_uri.as_ref().map(Url::as_str),
client_uri.as_ref().map(Url::as_str),
policy_uri.as_ref().map(Url::as_str),
tos_uri.as_ref().map(Url::as_str),
jwks_uri.as_ref().map(Url::as_str),
jwks_json,
id_token_signed_response_alg
.as_ref()
.map(ToString::to_string),
userinfo_signed_response_alg
.as_ref()
.map(ToString::to_string),
token_endpoint_auth_method.as_ref().map(ToString::to_string),
token_endpoint_auth_signing_alg
.as_ref()
.map(ToString::to_string),
initiate_login_uri.as_ref().map(Url::as_str),
)
.traced()
.execute(&mut *self.conn)
.await?;
{
let span = info_span!(
"db.oauth2_client.add.redirect_uris",
db.statement = tracing::field::Empty,
client.id = %id,
);
let (uri_ids, redirect_uris): (Vec<Uuid>, Vec<String>) = redirect_uris
.iter()
.map(|uri| {
(
Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)),
uri.as_str().to_owned(),
)
})
.unzip();
sqlx::query!(
r#"
INSERT INTO oauth2_client_redirect_uris
(oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri)
SELECT id, $2, redirect_uri
FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri)
"#,
&uri_ids,
Uuid::from(id),
&redirect_uris,
)
.record(&span)
.execute(&mut *self.conn)
.instrument(span)
.await?;
}
let jwks = match (jwks, jwks_uri) {
(None, None) => None,
(Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)),
(None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)),
_ => return Err(DatabaseError::invalid_operation()),
};
Ok(Client {
id,
client_id: id.to_string(),
encrypted_client_secret,
redirect_uris,
response_types: vec![
OAuthAuthorizationEndpointResponseType::Code,
OAuthAuthorizationEndpointResponseType::IdToken,
OAuthAuthorizationEndpointResponseType::None,
],
grant_types,
contacts,
client_name,
logo_uri,
client_uri,
policy_uri,
tos_uri,
jwks,
id_token_signed_response_alg,
userinfo_signed_response_alg,
token_endpoint_auth_method,
token_endpoint_auth_signing_alg,
initiate_login_uri,
})
}
#[tracing::instrument(
name = "db.oauth2_client.add_from_config",
skip_all,
fields(
db.statement,
client.id = %client_id,
),
err,
)]
async fn add_from_config(
&mut self,
mut rng: impl Rng + Send,
clock: &Clock,
client_id: Ulid,
client_auth_method: OAuthClientAuthenticationMethod,
encrypted_client_secret: Option<String>,
jwks: Option<PublicJsonWebKeySet>,
jwks_uri: Option<Url>,
redirect_uris: Vec<Url>,
) -> Result<Client, Self::Error> {
let jwks_json = jwks
.as_ref()
.map(serde_json::to_value)
.transpose()
.map_err(DatabaseError::to_invalid_operation)?;
let client_auth_method = client_auth_method.to_string();
sqlx::query!(
r#"
INSERT INTO oauth2_clients
( oauth2_client_id
, encrypted_client_secret
, grant_type_authorization_code
, grant_type_refresh_token
, token_endpoint_auth_method
, jwks
, jwks_uri
)
VALUES
($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (oauth2_client_id)
DO
UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret
, grant_type_authorization_code = EXCLUDED.grant_type_authorization_code
, grant_type_refresh_token = EXCLUDED.grant_type_refresh_token
, token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method
, jwks = EXCLUDED.jwks
, jwks_uri = EXCLUDED.jwks_uri
"#,
Uuid::from(client_id),
encrypted_client_secret,
true,
true,
client_auth_method,
jwks_json,
jwks_uri.as_ref().map(Url::as_str),
)
.traced()
.execute(&mut *self.conn)
.await?;
{
let span = info_span!(
"db.oauth2_client.add_from_config.redirect_uris",
client.id = %client_id,
db.statement = tracing::field::Empty,
);
let now = clock.now();
let (ids, redirect_uris): (Vec<Uuid>, Vec<String>) = redirect_uris
.iter()
.map(|uri| {
(
Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)),
uri.as_str().to_owned(),
)
})
.unzip();
sqlx::query!(
r#"
INSERT INTO oauth2_client_redirect_uris
(oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri)
SELECT id, $2, redirect_uri
FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri)
"#,
&ids,
Uuid::from(client_id),
&redirect_uris,
)
.record(&span)
.execute(&mut *self.conn)
.instrument(span)
.await?;
}
let jwks = match (jwks, jwks_uri) {
(None, None) => None,
(Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)),
(None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)),
_ => return Err(DatabaseError::invalid_operation()),
};
Ok(Client {
id: client_id,
client_id: client_id.to_string(),
encrypted_client_secret,
redirect_uris,
response_types: vec![
OAuthAuthorizationEndpointResponseType::Code,
OAuthAuthorizationEndpointResponseType::IdToken,
OAuthAuthorizationEndpointResponseType::None,
],
grant_types: Vec::new(),
contacts: Vec::new(),
client_name: None,
logo_uri: None,
client_uri: None,
policy_uri: None,
tos_uri: None,
jwks,
id_token_signed_response_alg: None,
userinfo_signed_response_alg: None,
token_endpoint_auth_method: None,
token_endpoint_auth_signing_alg: None,
initiate_login_uri: None,
})
}
#[tracing::instrument(
name = "db.oauth2_client.get_consent_for_user",
skip_all,
fields(
db.statement,
%user.id,
%client.id,
),
err,
)]
async fn get_consent_for_user(
&mut self,
client: &Client,
user: &User,
) -> Result<Scope, Self::Error> {
let scope_tokens: Vec<String> = sqlx::query_scalar!(
r#"
SELECT scope_token
FROM oauth2_consents
WHERE user_id = $1 AND oauth2_client_id = $2
"#,
Uuid::from(user.id),
Uuid::from(client.id),
)
.fetch_all(&mut *self.conn)
.await?;
let scope: Result<Scope, _> = scope_tokens
.into_iter()
.map(|s| ScopeToken::from_str(&s))
.collect();
let scope = scope.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_consents")
.column("scope_token")
.source(e)
})?;
Ok(scope)
}
#[tracing::instrument(
skip_all,
fields(
db.statement,
%user.id,
%client.id,
%scope,
),
err,
)]
async fn give_consent_for_user(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
client: &Client,
user: &User,
scope: &Scope,
) -> Result<(), Self::Error> {
let now = clock.now();
let (tokens, ids): (Vec<String>, Vec<Uuid>) = scope
.iter()
.map(|token| {
(
token.to_string(),
Uuid::from(Ulid::from_datetime_with_source(now.into(), rng)),
)
})
.unzip();
sqlx::query!(
r#"
INSERT INTO oauth2_consents
(oauth2_consent_id, user_id, oauth2_client_id, scope_token, created_at)
SELECT id, $2, $3, scope_token, $5 FROM UNNEST($1::uuid[], $4::text[]) u(id, scope_token)
ON CONFLICT (user_id, oauth2_client_id, scope_token) DO UPDATE SET refreshed_at = $5
"#,
&ids,
Uuid::from(user.id),
Uuid::from(client.id),
&tokens,
now,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(())
}
}

View File

@ -1,4 +1,4 @@
// Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -19,11 +19,7 @@ mod refresh_token;
mod session;
pub use self::{
access_token::{OAuth2AccessTokenRepository, PgOAuth2AccessTokenRepository},
authorization_grant::{
OAuth2AuthorizationGrantRepository, PgOAuth2AuthorizationGrantRepository,
},
client::{OAuth2ClientRepository, PgOAuth2ClientRepository},
refresh_token::{OAuth2RefreshTokenRepository, PgOAuth2RefreshTokenRepository},
session::{OAuth2SessionRepository, PgOAuth2SessionRepository},
access_token::OAuth2AccessTokenRepository,
authorization_grant::OAuth2AuthorizationGrantRepository, client::OAuth2ClientRepository,
refresh_token::OAuth2RefreshTokenRepository, session::OAuth2SessionRepository,
};

View File

@ -1,4 +1,4 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -13,14 +13,11 @@
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{AccessToken, RefreshToken, RefreshTokenState, Session};
use mas_data_model::{AccessToken, RefreshToken, Session};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt};
use crate::Clock;
#[async_trait]
pub trait OAuth2RefreshTokenRepository: Send + Sync {
@ -52,203 +49,3 @@ pub trait OAuth2RefreshTokenRepository: Send + Sync {
refresh_token: RefreshToken,
) -> Result<RefreshToken, Self::Error>;
}
pub struct PgOAuth2RefreshTokenRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgOAuth2RefreshTokenRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
struct OAuth2RefreshTokenLookup {
oauth2_refresh_token_id: Uuid,
refresh_token: String,
created_at: DateTime<Utc>,
consumed_at: Option<DateTime<Utc>>,
oauth2_access_token_id: Option<Uuid>,
oauth2_session_id: Uuid,
}
impl From<OAuth2RefreshTokenLookup> for RefreshToken {
fn from(value: OAuth2RefreshTokenLookup) -> Self {
let state = match value.consumed_at {
None => RefreshTokenState::Valid,
Some(consumed_at) => RefreshTokenState::Consumed { consumed_at },
};
RefreshToken {
id: value.oauth2_refresh_token_id.into(),
state,
session_id: value.oauth2_session_id.into(),
refresh_token: value.refresh_token,
created_at: value.created_at,
access_token_id: value.oauth2_access_token_id.map(Ulid::from),
}
}
}
#[async_trait]
impl<'c> OAuth2RefreshTokenRepository for PgOAuth2RefreshTokenRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.oauth2_refresh_token.lookup",
skip_all,
fields(
db.statement,
refresh_token.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<RefreshToken>, Self::Error> {
let res = sqlx::query_as!(
OAuth2RefreshTokenLookup,
r#"
SELECT oauth2_refresh_token_id
, refresh_token
, created_at
, consumed_at
, oauth2_access_token_id
, oauth2_session_id
FROM oauth2_refresh_tokens
WHERE oauth2_refresh_token_id = $1
"#,
Uuid::from(id),
)
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.oauth2_refresh_token.find_by_token",
skip_all,
fields(
db.statement,
),
err,
)]
async fn find_by_token(
&mut self,
refresh_token: &str,
) -> Result<Option<RefreshToken>, Self::Error> {
let res = sqlx::query_as!(
OAuth2RefreshTokenLookup,
r#"
SELECT oauth2_refresh_token_id
, refresh_token
, created_at
, consumed_at
, oauth2_access_token_id
, oauth2_session_id
FROM oauth2_refresh_tokens
WHERE refresh_token = $1
"#,
refresh_token,
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.oauth2_refresh_token.add",
skip_all,
fields(
db.statement,
%session.id,
user_session.id = %session.user_session_id,
client.id = %session.client_id,
refresh_token.id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
session: &Session,
access_token: &AccessToken,
refresh_token: String,
) -> Result<RefreshToken, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("refresh_token.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO oauth2_refresh_tokens
(oauth2_refresh_token_id, oauth2_session_id, oauth2_access_token_id,
refresh_token, created_at)
VALUES
($1, $2, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(session.id),
Uuid::from(access_token.id),
refresh_token,
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(RefreshToken {
id,
state: RefreshTokenState::default(),
session_id: session.id,
refresh_token,
access_token_id: Some(access_token.id),
created_at,
})
}
#[tracing::instrument(
name = "db.oauth2_refresh_token.consume",
skip_all,
fields(
db.statement,
%refresh_token.id,
session.id = %refresh_token.session_id,
),
err,
)]
async fn consume(
&mut self,
clock: &Clock,
refresh_token: RefreshToken,
) -> Result<RefreshToken, Self::Error> {
let consumed_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE oauth2_refresh_tokens
SET consumed_at = $2
WHERE oauth2_refresh_token_id = $1
"#,
Uuid::from(refresh_token.id),
consumed_at,
)
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
refresh_token
.consume(consumed_at)
.map_err(DatabaseError::to_invalid_operation)
}
}

View File

@ -1,4 +1,4 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -13,18 +13,11 @@
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{AuthorizationGrant, BrowserSession, Session, SessionState, User};
use mas_data_model::{AuthorizationGrant, BrowserSession, Session, User};
use rand::RngCore;
use sqlx::{PgConnection, QueryBuilder};
use ulid::Ulid;
use uuid::Uuid;
use crate::{
pagination::{Page, QueryBuilderExt},
tracing::ExecuteExt,
Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Pagination,
};
use crate::{pagination::Page, Clock, Pagination};
#[async_trait]
pub trait OAuth2SessionRepository: Send + Sync {
@ -48,224 +41,3 @@ pub trait OAuth2SessionRepository: Send + Sync {
pagination: Pagination,
) -> Result<Page<Session>, Self::Error>;
}
pub struct PgOAuth2SessionRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgOAuth2SessionRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
#[derive(sqlx::FromRow)]
struct OAuthSessionLookup {
oauth2_session_id: Uuid,
user_session_id: Uuid,
oauth2_client_id: Uuid,
scope: String,
#[allow(dead_code)]
created_at: DateTime<Utc>,
finished_at: Option<DateTime<Utc>>,
}
impl TryFrom<OAuthSessionLookup> for Session {
type Error = DatabaseInconsistencyError;
fn try_from(value: OAuthSessionLookup) -> Result<Self, Self::Error> {
let id = Ulid::from(value.oauth2_session_id);
let scope = value.scope.parse().map_err(|e| {
DatabaseInconsistencyError::on("oauth2_sessions")
.column("scope")
.row(id)
.source(e)
})?;
let state = match value.finished_at {
None => SessionState::Valid,
Some(finished_at) => SessionState::Finished { finished_at },
};
Ok(Session {
id,
state,
created_at: value.created_at,
client_id: value.oauth2_client_id.into(),
user_session_id: value.user_session_id.into(),
scope,
})
}
}
#[async_trait]
impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.oauth2_session.lookup",
skip_all,
fields(
db.statement,
session.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error> {
let res = sqlx::query_as!(
OAuthSessionLookup,
r#"
SELECT oauth2_session_id
, user_session_id
, oauth2_client_id
, scope
, created_at
, finished_at
FROM oauth2_sessions
WHERE oauth2_session_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(session) = res else { return Ok(None) };
Ok(Some(session.try_into()?))
}
#[tracing::instrument(
name = "db.oauth2_session.create_from_grant",
skip_all,
fields(
db.statement,
%user_session.id,
user.id = %user_session.user.id,
%grant.id,
client.id = %grant.client_id,
session.id,
session.scope = %grant.scope,
),
err,
)]
async fn create_from_grant(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
grant: &AuthorizationGrant,
user_session: &BrowserSession,
) -> Result<Session, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("session.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO oauth2_sessions
( oauth2_session_id
, user_session_id
, oauth2_client_id
, scope
, created_at
)
VALUES ($1, $2, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(user_session.id),
Uuid::from(grant.client_id),
grant.scope.to_string(),
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(Session {
id,
state: SessionState::Valid,
created_at,
user_session_id: user_session.id,
client_id: grant.client_id,
scope: grant.scope.clone(),
})
}
#[tracing::instrument(
name = "db.oauth2_session.finish",
skip_all,
fields(
db.statement,
%session.id,
%session.scope,
user_session.id = %session.user_session_id,
client.id = %session.client_id,
),
err,
)]
async fn finish(&mut self, clock: &Clock, session: Session) -> Result<Session, Self::Error> {
let finished_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE oauth2_sessions
SET finished_at = $2
WHERE oauth2_session_id = $1
"#,
Uuid::from(session.id),
finished_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
session
.finish(finished_at)
.map_err(DatabaseError::to_invalid_operation)
}
#[tracing::instrument(
name = "db.oauth2_session.list_paginated",
skip_all,
fields(
db.statement,
%user.id,
%user.username,
),
err,
)]
async fn list_paginated(
&mut self,
user: &User,
pagination: Pagination,
) -> Result<Page<Session>, Self::Error> {
let mut query = QueryBuilder::new(
r#"
SELECT oauth2_session_id
, user_session_id
, oauth2_client_id
, scope
, created_at
, finished_at
FROM oauth2_sessions os
"#,
);
query
.push(" WHERE us.user_id = ")
.push_bind(Uuid::from(user.id))
.generate_pagination("oauth2_session_id", pagination);
let edges: Vec<OAuthSessionLookup> = query
.build_query_as()
.traced()
.fetch_all(&mut *self.conn)
.await?;
let page = pagination.process(edges).try_map(Session::try_from)?;
Ok(page)
}
}

View File

@ -14,10 +14,8 @@
//! Utilities to manage paginated queries.
use sqlx::{Database, QueryBuilder};
use thiserror::Error;
use ulid::Ulid;
use uuid::Uuid;
/// An error returned when invalid pagination parameters are provided
#[derive(Debug, Error)]
@ -26,14 +24,14 @@ pub struct InvalidPagination;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Pagination {
before: Option<Ulid>,
after: Option<Ulid>,
count: usize,
direction: PaginationDirection,
pub before: Option<Ulid>,
pub after: Option<Ulid>,
pub count: usize,
pub direction: PaginationDirection,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PaginationDirection {
pub enum PaginationDirection {
Forward,
Backward,
}
@ -101,60 +99,8 @@ impl Pagination {
self
}
/// Add cursor-based pagination to a query, as used in paginated GraphQL
/// connections
fn generate_pagination<'a, DB>(&self, query: &mut QueryBuilder<'a, DB>, id_field: &'static str)
where
DB: Database,
Uuid: sqlx::Type<DB> + sqlx::Encode<'a, DB>,
i64: sqlx::Type<DB> + sqlx::Encode<'a, DB>,
{
// ref: https://github.com/graphql/graphql-relay-js/issues/94#issuecomment-232410564
// 1. Start from the greedy query: SELECT * FROM table
// 2. If the after argument is provided, add `id > parsed_cursor` to the `WHERE`
// clause
if let Some(after) = self.after {
query
.push(" AND ")
.push(id_field)
.push(" > ")
.push_bind(Uuid::from(after));
}
// 3. If the before argument is provided, add `id < parsed_cursor` to the
// `WHERE` clause
if let Some(before) = self.before {
query
.push(" AND ")
.push(id_field)
.push(" < ")
.push_bind(Uuid::from(before));
}
match self.direction {
// 4. If the first argument is provided, add `ORDER BY id ASC LIMIT first+1` to the
// query
PaginationDirection::Forward => {
query
.push(" ORDER BY ")
.push(id_field)
.push(" ASC LIMIT ")
.push_bind((self.count + 1) as i64);
}
// 5. If the first argument is provided, add `ORDER BY id DESC LIMIT last+1` to the
// query
PaginationDirection::Backward => {
query
.push(" ORDER BY ")
.push(id_field)
.push(" DESC LIMIT ")
.push_bind((self.count + 1) as i64);
}
};
}
/// Process a page returned by a paginated query
#[must_use]
pub fn process<T>(&self, mut edges: Vec<T>) -> Page<T> {
let is_full = edges.len() == (self.count + 1);
if is_full {
@ -198,7 +144,6 @@ impl<T> Page<T> {
}
}
#[must_use]
pub fn try_map<F, E, T2>(self, f: F) -> Result<Page<T2>, E>
where
F: FnMut(T) -> Result<T2, E>,
@ -211,23 +156,3 @@ impl<T> Page<T> {
})
}
}
/// An extension trait to the `sqlx` [`QueryBuilder`], to help adding pagination
/// to a query
pub trait QueryBuilderExt {
/// Add cursor-based pagination to a query, as used in paginated GraphQL
/// connections
fn generate_pagination(&mut self, id_field: &'static str, pagination: Pagination) -> &mut Self;
}
impl<'a, DB> QueryBuilderExt for QueryBuilder<'a, DB>
where
DB: Database,
Uuid: sqlx::Type<DB> + sqlx::Encode<'a, DB>,
i64: sqlx::Type<DB> + sqlx::Encode<'a, DB>,
{
fn generate_pagination(&mut self, id_field: &'static str, pagination: Pagination) -> &mut Self {
pagination.generate_pagination(self, id_field);
self
}
}

View File

@ -1,4 +1,4 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -12,31 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use sqlx::{PgPool, Postgres, Transaction};
use crate::{
compat::{
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
CompatSsoLoginRepository, PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository,
PgCompatSessionRepository, PgCompatSsoLoginRepository,
CompatSsoLoginRepository,
},
oauth2::{
OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository,
OAuth2RefreshTokenRepository, OAuth2SessionRepository, PgOAuth2AccessTokenRepository,
PgOAuth2AuthorizationGrantRepository, PgOAuth2ClientRepository,
PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository,
OAuth2RefreshTokenRepository, OAuth2SessionRepository,
},
upstream_oauth2::{
PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository,
PgUpstreamOAuthSessionRepository, UpstreamOAuthLinkRepository,
UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository,
UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
UpstreamOAuthSessionRepository,
},
user::{
BrowserSessionRepository, PgBrowserSessionRepository, PgUserEmailRepository,
PgUserPasswordRepository, PgUserRepository, UserEmailRepository, UserPasswordRepository,
UserRepository,
},
DatabaseError,
user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository},
};
pub trait Repository: Send {
@ -126,109 +115,3 @@ pub trait Repository: Send {
fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_>;
fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_>;
}
pub struct PgRepository {
txn: Transaction<'static, Postgres>,
}
impl PgRepository {
pub async fn from_pool(pool: &PgPool) -> Result<Self, DatabaseError> {
let txn = pool.begin().await?;
Ok(PgRepository { txn })
}
pub async fn save(self) -> Result<(), DatabaseError> {
self.txn.commit().await?;
Ok(())
}
pub async fn cancel(self) -> Result<(), DatabaseError> {
self.txn.rollback().await?;
Ok(())
}
}
impl Repository for PgRepository {
type Error = DatabaseError;
type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c;
type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c;
type UpstreamOAuthSessionRepository<'c> = PgUpstreamOAuthSessionRepository<'c> where Self: 'c;
type UserRepository<'c> = PgUserRepository<'c> where Self: 'c;
type UserEmailRepository<'c> = PgUserEmailRepository<'c> where Self: 'c;
type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c;
type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c;
type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c;
type OAuth2AuthorizationGrantRepository<'c> = PgOAuth2AuthorizationGrantRepository<'c> where Self: 'c;
type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c;
type OAuth2AccessTokenRepository<'c> = PgOAuth2AccessTokenRepository<'c> where Self: 'c;
type OAuth2RefreshTokenRepository<'c> = PgOAuth2RefreshTokenRepository<'c> where Self: 'c;
type CompatSessionRepository<'c> = PgCompatSessionRepository<'c> where Self: 'c;
type CompatSsoLoginRepository<'c> = PgCompatSsoLoginRepository<'c> where Self: 'c;
type CompatAccessTokenRepository<'c> = PgCompatAccessTokenRepository<'c> where Self: 'c;
type CompatRefreshTokenRepository<'c> = PgCompatRefreshTokenRepository<'c> where Self: 'c;
fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> {
PgUpstreamOAuthLinkRepository::new(&mut self.txn)
}
fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> {
PgUpstreamOAuthProviderRepository::new(&mut self.txn)
}
fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_> {
PgUpstreamOAuthSessionRepository::new(&mut self.txn)
}
fn user(&mut self) -> Self::UserRepository<'_> {
PgUserRepository::new(&mut self.txn)
}
fn user_email(&mut self) -> Self::UserEmailRepository<'_> {
PgUserEmailRepository::new(&mut self.txn)
}
fn user_password(&mut self) -> Self::UserPasswordRepository<'_> {
PgUserPasswordRepository::new(&mut self.txn)
}
fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_> {
PgBrowserSessionRepository::new(&mut self.txn)
}
fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_> {
PgOAuth2ClientRepository::new(&mut self.txn)
}
fn oauth2_authorization_grant(&mut self) -> Self::OAuth2AuthorizationGrantRepository<'_> {
PgOAuth2AuthorizationGrantRepository::new(&mut self.txn)
}
fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> {
PgOAuth2SessionRepository::new(&mut self.txn)
}
fn oauth2_access_token(&mut self) -> Self::OAuth2AccessTokenRepository<'_> {
PgOAuth2AccessTokenRepository::new(&mut self.txn)
}
fn oauth2_refresh_token(&mut self) -> Self::OAuth2RefreshTokenRepository<'_> {
PgOAuth2RefreshTokenRepository::new(&mut self.txn)
}
fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> {
PgCompatSessionRepository::new(&mut self.txn)
}
fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_> {
PgCompatSsoLoginRepository::new(&mut self.txn)
}
fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_> {
PgCompatAccessTokenRepository::new(&mut self.txn)
}
fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_> {
PgCompatRefreshTokenRepository::new(&mut self.txn)
}
}

View File

@ -1,4 +1,4 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -13,18 +13,11 @@
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{UpstreamOAuthLink, UpstreamOAuthProvider, User};
use rand::RngCore;
use sqlx::{PgConnection, QueryBuilder};
use ulid::Ulid;
use uuid::Uuid;
use crate::{
pagination::{Page, QueryBuilderExt},
tracing::ExecuteExt,
Clock, DatabaseError, LookupResultExt, Pagination,
};
use crate::{pagination::Page, Clock, Pagination};
#[async_trait]
pub trait UpstreamOAuthLinkRepository: Send + Sync {
@ -63,241 +56,3 @@ pub trait UpstreamOAuthLinkRepository: Send + Sync {
pagination: Pagination,
) -> Result<Page<UpstreamOAuthLink>, Self::Error>;
}
pub struct PgUpstreamOAuthLinkRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgUpstreamOAuthLinkRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
#[derive(sqlx::FromRow)]
struct LinkLookup {
upstream_oauth_link_id: Uuid,
upstream_oauth_provider_id: Uuid,
user_id: Option<Uuid>,
subject: String,
created_at: DateTime<Utc>,
}
impl From<LinkLookup> for UpstreamOAuthLink {
fn from(value: LinkLookup) -> Self {
UpstreamOAuthLink {
id: Ulid::from(value.upstream_oauth_link_id),
provider_id: Ulid::from(value.upstream_oauth_provider_id),
user_id: value.user_id.map(Ulid::from),
subject: value.subject,
created_at: value.created_at,
}
}
}
#[async_trait]
impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.upstream_oauth_link.lookup",
skip_all,
fields(
db.statement,
upstream_oauth_link.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthLink>, Self::Error> {
let res = sqlx::query_as!(
LinkLookup,
r#"
SELECT
upstream_oauth_link_id,
upstream_oauth_provider_id,
user_id,
subject,
created_at
FROM upstream_oauth_links
WHERE upstream_oauth_link_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?
.map(Into::into);
Ok(res)
}
#[tracing::instrument(
name = "db.upstream_oauth_link.find_by_subject",
skip_all,
fields(
db.statement,
upstream_oauth_link.subject = subject,
%upstream_oauth_provider.id,
%upstream_oauth_provider.issuer,
%upstream_oauth_provider.client_id,
),
err,
)]
async fn find_by_subject(
&mut self,
upstream_oauth_provider: &UpstreamOAuthProvider,
subject: &str,
) -> Result<Option<UpstreamOAuthLink>, Self::Error> {
let res = sqlx::query_as!(
LinkLookup,
r#"
SELECT
upstream_oauth_link_id,
upstream_oauth_provider_id,
user_id,
subject,
created_at
FROM upstream_oauth_links
WHERE upstream_oauth_provider_id = $1
AND subject = $2
"#,
Uuid::from(upstream_oauth_provider.id),
subject,
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?
.map(Into::into);
Ok(res)
}
#[tracing::instrument(
name = "db.upstream_oauth_link.add",
skip_all,
fields(
db.statement,
upstream_oauth_link.id,
upstream_oauth_link.subject = subject,
%upstream_oauth_provider.id,
%upstream_oauth_provider.issuer,
%upstream_oauth_provider.client_id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
upstream_oauth_provider: &UpstreamOAuthProvider,
subject: String,
) -> Result<UpstreamOAuthLink, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("upstream_oauth_link.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO upstream_oauth_links (
upstream_oauth_link_id,
upstream_oauth_provider_id,
user_id,
subject,
created_at
) VALUES ($1, $2, NULL, $3, $4)
"#,
Uuid::from(id),
Uuid::from(upstream_oauth_provider.id),
&subject,
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(UpstreamOAuthLink {
id,
provider_id: upstream_oauth_provider.id,
user_id: None,
subject,
created_at,
})
}
#[tracing::instrument(
name = "db.upstream_oauth_link.associate_to_user",
skip_all,
fields(
db.statement,
%upstream_oauth_link.id,
%upstream_oauth_link.subject,
%user.id,
%user.username,
),
err,
)]
async fn associate_to_user(
&mut self,
upstream_oauth_link: &UpstreamOAuthLink,
user: &User,
) -> Result<(), Self::Error> {
sqlx::query!(
r#"
UPDATE upstream_oauth_links
SET user_id = $1
WHERE upstream_oauth_link_id = $2
"#,
Uuid::from(user.id),
Uuid::from(upstream_oauth_link.id),
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(())
}
#[tracing::instrument(
name = "db.upstream_oauth_link.list_paginated",
skip_all,
fields(
db.statement,
%user.id,
%user.username,
),
err
)]
async fn list_paginated(
&mut self,
user: &User,
pagination: Pagination,
) -> Result<Page<UpstreamOAuthLink>, Self::Error> {
let mut query = QueryBuilder::new(
r#"
SELECT
upstream_oauth_link_id,
upstream_oauth_provider_id,
user_id,
subject,
created_at
FROM upstream_oauth_links
"#,
);
query
.push(" WHERE user_id = ")
.push_bind(Uuid::from(user.id))
.generate_pagination("upstream_oauth_link_id", pagination);
let edges: Vec<LinkLookup> = query
.build_query_as()
.traced()
.fetch_all(&mut *self.conn)
.await?;
let page = pagination.process(edges).map(UpstreamOAuthLink::from);
Ok(page)
}
}

View File

@ -1,4 +1,4 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -17,249 +17,6 @@ mod provider;
mod session;
pub use self::{
link::{PgUpstreamOAuthLinkRepository, UpstreamOAuthLinkRepository},
provider::{PgUpstreamOAuthProviderRepository, UpstreamOAuthProviderRepository},
session::{PgUpstreamOAuthSessionRepository, UpstreamOAuthSessionRepository},
link::UpstreamOAuthLinkRepository, provider::UpstreamOAuthProviderRepository,
session::UpstreamOAuthSessionRepository,
};
#[cfg(test)]
mod tests {
use chrono::Duration;
use oauth2_types::scope::{Scope, OPENID};
use rand::SeedableRng;
use sqlx::PgPool;
use super::*;
use crate::{user::UserRepository, Clock, Pagination, PgRepository, Repository};
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_repository(pool: PgPool) {
let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
let clock = Clock::mock();
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
// The provider list should be empty at the start
let all_providers = repo.upstream_oauth_provider().all().await.unwrap();
assert!(all_providers.is_empty());
// Let's add a provider
let provider = repo
.upstream_oauth_provider()
.add(
&mut rng,
&clock,
"https://example.com/".to_owned(),
Scope::from_iter([OPENID]),
mas_iana::oauth::OAuthClientAuthenticationMethod::None,
None,
"client-id".to_owned(),
None,
)
.await
.unwrap();
// Look it up in the database
let provider = repo
.upstream_oauth_provider()
.lookup(provider.id)
.await
.unwrap()
.expect("provider to be found in the database");
assert_eq!(provider.issuer, "https://example.com/");
assert_eq!(provider.client_id, "client-id");
// Start a session
let session = repo
.upstream_oauth_session()
.add(
&mut rng,
&clock,
&provider,
"some-state".to_owned(),
None,
"some-nonce".to_owned(),
)
.await
.unwrap();
// Look it up in the database
let session = repo
.upstream_oauth_session()
.lookup(session.id)
.await
.unwrap()
.expect("session to be found in the database");
assert_eq!(session.provider_id, provider.id);
assert_eq!(session.link_id(), None);
assert!(session.is_pending());
assert!(!session.is_completed());
assert!(!session.is_consumed());
// Create a link
let link = repo
.upstream_oauth_link()
.add(&mut rng, &clock, &provider, "a-subject".to_owned())
.await
.unwrap();
// We can look it up by its ID
repo.upstream_oauth_link()
.lookup(link.id)
.await
.unwrap()
.expect("link to be found in database");
// or by its subject
let link = repo
.upstream_oauth_link()
.find_by_subject(&provider, "a-subject")
.await
.unwrap()
.expect("link to be found in database");
assert_eq!(link.subject, "a-subject");
assert_eq!(link.provider_id, provider.id);
let session = repo
.upstream_oauth_session()
.complete_with_link(&clock, session, &link, None)
.await
.unwrap();
// Reload the session
let session = repo
.upstream_oauth_session()
.lookup(session.id)
.await
.unwrap()
.expect("session to be found in the database");
assert!(session.is_completed());
assert!(!session.is_consumed());
assert_eq!(session.link_id(), Some(link.id));
let session = repo
.upstream_oauth_session()
.consume(&clock, session)
.await
.unwrap();
// Reload the session
let session = repo
.upstream_oauth_session()
.lookup(session.id)
.await
.unwrap()
.expect("session to be found in the database");
assert!(session.is_consumed());
let user = repo
.user()
.add(&mut rng, &clock, "john".to_owned())
.await
.unwrap();
repo.upstream_oauth_link()
.associate_to_user(&link, &user)
.await
.unwrap();
let links = repo
.upstream_oauth_link()
.list_paginated(&user, Pagination::first(10))
.await
.unwrap();
assert!(!links.has_previous_page);
assert!(!links.has_next_page);
assert_eq!(links.edges.len(), 1);
assert_eq!(links.edges[0].id, link.id);
assert_eq!(links.edges[0].user_id, Some(user.id));
}
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_provider_repository_pagination(pool: PgPool) {
const ISSUER: &str = "https://example.com/";
let scope = Scope::from_iter([OPENID]);
let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
let clock = Clock::mock();
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
let mut ids = Vec::with_capacity(20);
// Create 20 providers
for idx in 0..20 {
let client_id = format!("client-{idx}");
let provider = repo
.upstream_oauth_provider()
.add(
&mut rng,
&clock,
ISSUER.to_owned(),
scope.clone(),
mas_iana::oauth::OAuthClientAuthenticationMethod::None,
None,
client_id,
None,
)
.await
.unwrap();
ids.push(provider.id);
clock.advance(Duration::seconds(10));
}
// Lookup the first 10 items
let page = repo
.upstream_oauth_provider()
.list_paginated(Pagination::first(10))
.await
.unwrap();
// It returned the first 10 items
assert!(page.has_next_page);
let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
assert_eq!(&edge_ids, &ids[..10]);
// Lookup the next 10 items
let page = repo
.upstream_oauth_provider()
.list_paginated(Pagination::first(10).after(ids[9]))
.await
.unwrap();
// It returned the next 10 items
assert!(!page.has_next_page);
let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
assert_eq!(&edge_ids, &ids[10..]);
// Lookup the last 10 items
let page = repo
.upstream_oauth_provider()
.list_paginated(Pagination::last(10))
.await
.unwrap();
// It returned the last 10 items
assert!(page.has_previous_page);
let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
assert_eq!(&edge_ids, &ids[10..]);
// Lookup the previous 10 items
let page = repo
.upstream_oauth_provider()
.list_paginated(Pagination::last(10).before(ids[10]))
.await
.unwrap();
// It returned the previous 10 items
assert!(!page.has_previous_page);
let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
assert_eq!(&edge_ids, &ids[..10]);
// Lookup 10 items between two IDs
let page = repo
.upstream_oauth_provider()
.list_paginated(Pagination::first(10).after(ids[5]).before(ids[8]))
.await
.unwrap();
// It returned the items in between
assert!(!page.has_next_page);
let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
assert_eq!(&edge_ids, &ids[6..8]);
}
}

View File

@ -1,4 +1,4 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -13,20 +13,13 @@
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::UpstreamOAuthProvider;
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use oauth2_types::scope::Scope;
use rand::RngCore;
use sqlx::{PgConnection, QueryBuilder};
use ulid::Ulid;
use uuid::Uuid;
use crate::{
pagination::{Page, QueryBuilderExt},
tracing::ExecuteExt,
Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Pagination,
};
use crate::{pagination::Page, Clock, Pagination};
#[async_trait]
pub trait UpstreamOAuthProviderRepository: Send + Sync {
@ -58,247 +51,3 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync {
/// Get all upstream OAuth providers
async fn all(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
}
pub struct PgUpstreamOAuthProviderRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgUpstreamOAuthProviderRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
#[derive(sqlx::FromRow)]
struct ProviderLookup {
upstream_oauth_provider_id: Uuid,
issuer: String,
scope: String,
client_id: String,
encrypted_client_secret: Option<String>,
token_endpoint_signing_alg: Option<String>,
token_endpoint_auth_method: String,
created_at: DateTime<Utc>,
}
impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
type Error = DatabaseInconsistencyError;
fn try_from(value: ProviderLookup) -> Result<Self, Self::Error> {
let id = value.upstream_oauth_provider_id.into();
let scope = value.scope.parse().map_err(|e| {
DatabaseInconsistencyError::on("upstream_oauth_providers")
.column("scope")
.row(id)
.source(e)
})?;
let token_endpoint_auth_method = value.token_endpoint_auth_method.parse().map_err(|e| {
DatabaseInconsistencyError::on("upstream_oauth_providers")
.column("token_endpoint_auth_method")
.row(id)
.source(e)
})?;
let token_endpoint_signing_alg = value
.token_endpoint_signing_alg
.map(|x| x.parse())
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("upstream_oauth_providers")
.column("token_endpoint_signing_alg")
.row(id)
.source(e)
})?;
Ok(UpstreamOAuthProvider {
id,
issuer: value.issuer,
scope,
client_id: value.client_id,
encrypted_client_secret: value.encrypted_client_secret,
token_endpoint_auth_method,
token_endpoint_signing_alg,
created_at: value.created_at,
})
}
}
#[async_trait]
impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.upstream_oauth_provider.lookup",
skip_all,
fields(
db.statement,
upstream_oauth_provider.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error> {
let res = sqlx::query_as!(
ProviderLookup,
r#"
SELECT
upstream_oauth_provider_id,
issuer,
scope,
client_id,
encrypted_client_secret,
token_endpoint_signing_alg,
token_endpoint_auth_method,
created_at
FROM upstream_oauth_providers
WHERE upstream_oauth_provider_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let res = res
.map(UpstreamOAuthProvider::try_from)
.transpose()
.map_err(DatabaseError::from)?;
Ok(res)
}
#[tracing::instrument(
name = "db.upstream_oauth_provider.add",
skip_all,
fields(
db.statement,
upstream_oauth_provider.id,
upstream_oauth_provider.issuer = %issuer,
upstream_oauth_provider.client_id = %client_id,
),
err,
)]
#[allow(clippy::too_many_arguments)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
issuer: String,
scope: Scope,
token_endpoint_auth_method: OAuthClientAuthenticationMethod,
token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
client_id: String,
encrypted_client_secret: Option<String>,
) -> Result<UpstreamOAuthProvider, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("upstream_oauth_provider.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO upstream_oauth_providers (
upstream_oauth_provider_id,
issuer,
scope,
token_endpoint_auth_method,
token_endpoint_signing_alg,
client_id,
encrypted_client_secret,
created_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
"#,
Uuid::from(id),
&issuer,
scope.to_string(),
token_endpoint_auth_method.to_string(),
token_endpoint_signing_alg.as_ref().map(ToString::to_string),
&client_id,
encrypted_client_secret.as_deref(),
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(UpstreamOAuthProvider {
id,
issuer,
scope,
client_id,
encrypted_client_secret,
token_endpoint_signing_alg,
token_endpoint_auth_method,
created_at,
})
}
#[tracing::instrument(
name = "db.upstream_oauth_provider.list_paginated",
skip_all,
fields(
db.statement,
),
err,
)]
async fn list_paginated(
&mut self,
pagination: Pagination,
) -> Result<Page<UpstreamOAuthProvider>, Self::Error> {
let mut query = QueryBuilder::new(
r#"
SELECT
upstream_oauth_provider_id,
issuer,
scope,
client_id,
encrypted_client_secret,
token_endpoint_signing_alg,
token_endpoint_auth_method,
created_at
FROM upstream_oauth_providers
WHERE 1 = 1
"#,
);
query.generate_pagination("upstream_oauth_provider_id", pagination);
let edges: Vec<ProviderLookup> = query
.build_query_as()
.traced()
.fetch_all(&mut *self.conn)
.await?;
let page = pagination.process(edges).try_map(TryInto::try_into)?;
Ok(page)
}
#[tracing::instrument(
name = "db.upstream_oauth_provider.all",
skip_all,
fields(
db.statement,
),
err,
)]
async fn all(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error> {
let res = sqlx::query_as!(
ProviderLookup,
r#"
SELECT
upstream_oauth_provider_id,
issuer,
scope,
client_id,
encrypted_client_secret,
token_endpoint_signing_alg,
token_endpoint_auth_method,
created_at
FROM upstream_oauth_providers
"#,
)
.traced()
.fetch_all(&mut *self.conn)
.await?;
let res: Result<Vec<_>, _> = res.into_iter().map(TryInto::try_into).collect();
Ok(res?)
}
}

View File

@ -1,4 +1,4 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -13,19 +13,11 @@
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{
UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState, UpstreamOAuthLink,
UpstreamOAuthProvider,
};
use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{
tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
};
use crate::Clock;
#[async_trait]
pub trait UpstreamOAuthSessionRepository: Send + Sync {
@ -64,262 +56,3 @@ pub trait UpstreamOAuthSessionRepository: Send + Sync {
upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
}
pub struct PgUpstreamOAuthSessionRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgUpstreamOAuthSessionRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
struct SessionLookup {
upstream_oauth_authorization_session_id: Uuid,
upstream_oauth_provider_id: Uuid,
upstream_oauth_link_id: Option<Uuid>,
state: String,
code_challenge_verifier: Option<String>,
nonce: String,
id_token: Option<String>,
created_at: DateTime<Utc>,
completed_at: Option<DateTime<Utc>>,
consumed_at: Option<DateTime<Utc>>,
}
impl TryFrom<SessionLookup> for UpstreamOAuthAuthorizationSession {
type Error = DatabaseInconsistencyError;
fn try_from(value: SessionLookup) -> Result<Self, Self::Error> {
let id = value.upstream_oauth_authorization_session_id.into();
let state = match (
value.upstream_oauth_link_id,
value.id_token,
value.completed_at,
value.consumed_at,
) {
(None, None, None, None) => UpstreamOAuthAuthorizationSessionState::Pending,
(Some(link_id), id_token, Some(completed_at), None) => {
UpstreamOAuthAuthorizationSessionState::Completed {
completed_at,
link_id: link_id.into(),
id_token,
}
}
(Some(link_id), id_token, Some(completed_at), Some(consumed_at)) => {
UpstreamOAuthAuthorizationSessionState::Consumed {
completed_at,
link_id: link_id.into(),
id_token,
consumed_at,
}
}
_ => {
return Err(
DatabaseInconsistencyError::on("upstream_oauth_authorization_sessions").row(id),
)
}
};
Ok(Self {
id,
provider_id: value.upstream_oauth_provider_id.into(),
state_str: value.state,
nonce: value.nonce,
code_challenge_verifier: value.code_challenge_verifier,
created_at: value.created_at,
state,
})
}
}
#[async_trait]
impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.upstream_oauth_authorization_session.lookup",
skip_all,
fields(
db.statement,
upstream_oauth_provider.id = %id,
),
err,
)]
async fn lookup(
&mut self,
id: Ulid,
) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error> {
let res = sqlx::query_as!(
SessionLookup,
r#"
SELECT
upstream_oauth_authorization_session_id,
upstream_oauth_provider_id,
upstream_oauth_link_id,
state,
code_challenge_verifier,
nonce,
id_token,
created_at,
completed_at,
consumed_at
FROM upstream_oauth_authorization_sessions
WHERE upstream_oauth_authorization_session_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
name = "db.upstream_oauth_authorization_session.add",
skip_all,
fields(
db.statement,
%upstream_oauth_provider.id,
%upstream_oauth_provider.issuer,
%upstream_oauth_provider.client_id,
upstream_oauth_authorization_session.id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
upstream_oauth_provider: &UpstreamOAuthProvider,
state_str: String,
code_challenge_verifier: Option<String>,
nonce: String,
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record(
"upstream_oauth_authorization_session.id",
tracing::field::display(id),
);
sqlx::query!(
r#"
INSERT INTO upstream_oauth_authorization_sessions (
upstream_oauth_authorization_session_id,
upstream_oauth_provider_id,
state,
code_challenge_verifier,
nonce,
created_at,
completed_at,
consumed_at,
id_token
) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL)
"#,
Uuid::from(id),
Uuid::from(upstream_oauth_provider.id),
&state_str,
code_challenge_verifier.as_deref(),
nonce,
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(UpstreamOAuthAuthorizationSession {
id,
state: UpstreamOAuthAuthorizationSessionState::default(),
provider_id: upstream_oauth_provider.id,
state_str,
code_challenge_verifier,
nonce,
created_at,
})
}
#[tracing::instrument(
name = "db.upstream_oauth_authorization_session.complete_with_link",
skip_all,
fields(
db.statement,
%upstream_oauth_authorization_session.id,
%upstream_oauth_link.id,
),
err,
)]
async fn complete_with_link(
&mut self,
clock: &Clock,
upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
upstream_oauth_link: &UpstreamOAuthLink,
id_token: Option<String>,
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
let completed_at = clock.now();
sqlx::query!(
r#"
UPDATE upstream_oauth_authorization_sessions
SET upstream_oauth_link_id = $1,
completed_at = $2,
id_token = $3
WHERE upstream_oauth_authorization_session_id = $4
"#,
Uuid::from(upstream_oauth_link.id),
completed_at,
id_token,
Uuid::from(upstream_oauth_authorization_session.id),
)
.traced()
.execute(&mut *self.conn)
.await?;
let upstream_oauth_authorization_session = upstream_oauth_authorization_session
.complete(completed_at, upstream_oauth_link, id_token)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(upstream_oauth_authorization_session)
}
/// Mark a session as consumed
#[tracing::instrument(
name = "db.upstream_oauth_authorization_session.consume",
skip_all,
fields(
db.statement,
%upstream_oauth_authorization_session.id,
),
err,
)]
async fn consume(
&mut self,
clock: &Clock,
upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
let consumed_at = clock.now();
sqlx::query!(
r#"
UPDATE upstream_oauth_authorization_sessions
SET consumed_at = $1
WHERE upstream_oauth_authorization_session_id = $2
"#,
consumed_at,
Uuid::from(upstream_oauth_authorization_session.id),
)
.traced()
.execute(&mut *self.conn)
.await?;
let upstream_oauth_authorization_session = upstream_oauth_authorization_session
.consume(consumed_at)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(upstream_oauth_authorization_session)
}
}

View File

@ -1,4 +1,4 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -13,19 +13,11 @@
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{User, UserEmail, UserEmailVerification, UserEmailVerificationState};
use mas_data_model::{User, UserEmail, UserEmailVerification};
use rand::RngCore;
use sqlx::{PgConnection, QueryBuilder};
use tracing::{info_span, Instrument};
use ulid::Ulid;
use uuid::Uuid;
use crate::{
pagination::{Page, QueryBuilderExt},
tracing::ExecuteExt,
Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Pagination,
};
use crate::{pagination::Page, Clock, Pagination};
#[async_trait]
pub trait UserEmailRepository: Send + Sync {
@ -82,529 +74,3 @@ pub trait UserEmailRepository: Send + Sync {
verification: UserEmailVerification,
) -> Result<UserEmailVerification, Self::Error>;
}
pub struct PgUserEmailRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgUserEmailRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
#[derive(Debug, Clone, sqlx::FromRow)]
struct UserEmailLookup {
user_email_id: Uuid,
user_id: Uuid,
email: String,
created_at: DateTime<Utc>,
confirmed_at: Option<DateTime<Utc>>,
}
impl From<UserEmailLookup> for UserEmail {
fn from(e: UserEmailLookup) -> UserEmail {
UserEmail {
id: e.user_email_id.into(),
user_id: e.user_id.into(),
email: e.email,
created_at: e.created_at,
confirmed_at: e.confirmed_at,
}
}
}
struct UserEmailConfirmationCodeLookup {
user_email_confirmation_code_id: Uuid,
user_email_id: Uuid,
code: String,
created_at: DateTime<Utc>,
expires_at: DateTime<Utc>,
consumed_at: Option<DateTime<Utc>>,
}
impl UserEmailConfirmationCodeLookup {
fn into_verification(self, clock: &Clock) -> UserEmailVerification {
let now = clock.now();
let state = if let Some(when) = self.consumed_at {
UserEmailVerificationState::AlreadyUsed { when }
} else if self.expires_at < now {
UserEmailVerificationState::Expired {
when: self.expires_at,
}
} else {
UserEmailVerificationState::Valid
};
UserEmailVerification {
id: self.user_email_confirmation_code_id.into(),
user_email_id: self.user_email_id.into(),
code: self.code,
state,
created_at: self.created_at,
}
}
}
#[async_trait]
impl<'c> UserEmailRepository for PgUserEmailRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.user_email.lookup",
skip_all,
fields(
db.statement,
user_email.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<UserEmail>, Self::Error> {
let res = sqlx::query_as!(
UserEmailLookup,
r#"
SELECT user_email_id
, user_id
, email
, created_at
, confirmed_at
FROM user_emails
WHERE user_email_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(user_email) = res else { return Ok(None) };
Ok(Some(user_email.into()))
}
#[tracing::instrument(
name = "db.user_email.find",
skip_all,
fields(
db.statement,
%user.id,
user_email.email = email,
),
err,
)]
async fn find(&mut self, user: &User, email: &str) -> Result<Option<UserEmail>, Self::Error> {
let res = sqlx::query_as!(
UserEmailLookup,
r#"
SELECT user_email_id
, user_id
, email
, created_at
, confirmed_at
FROM user_emails
WHERE user_id = $1 AND email = $2
"#,
Uuid::from(user.id),
email,
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(user_email) = res else { return Ok(None) };
Ok(Some(user_email.into()))
}
#[tracing::instrument(
name = "db.user_email.get_primary",
skip_all,
fields(
db.statement,
%user.id,
),
err,
)]
async fn get_primary(&mut self, user: &User) -> Result<Option<UserEmail>, Self::Error> {
let Some(id) = user.primary_user_email_id else { return Ok(None) };
let user_email = self.lookup(id).await?.ok_or_else(|| {
DatabaseInconsistencyError::on("users")
.column("primary_user_email_id")
.row(user.id)
})?;
Ok(Some(user_email))
}
#[tracing::instrument(
name = "db.user_email.all",
skip_all,
fields(
db.statement,
%user.id,
),
err,
)]
async fn all(&mut self, user: &User) -> Result<Vec<UserEmail>, Self::Error> {
let res = sqlx::query_as!(
UserEmailLookup,
r#"
SELECT user_email_id
, user_id
, email
, created_at
, confirmed_at
FROM user_emails
WHERE user_id = $1
ORDER BY email ASC
"#,
Uuid::from(user.id),
)
.traced()
.fetch_all(&mut *self.conn)
.await?;
Ok(res.into_iter().map(Into::into).collect())
}
#[tracing::instrument(
name = "db.user_email.list_paginated",
skip_all,
fields(
db.statement,
%user.id,
),
err,
)]
async fn list_paginated(
&mut self,
user: &User,
pagination: Pagination,
) -> Result<Page<UserEmail>, DatabaseError> {
let mut query = QueryBuilder::new(
r#"
SELECT user_email_id
, user_id
, email
, created_at
, confirmed_at
FROM user_emails
"#,
);
query
.push(" WHERE user_id = ")
.push_bind(Uuid::from(user.id))
.generate_pagination("ue.user_email_id", pagination);
let edges: Vec<UserEmailLookup> = query
.build_query_as()
.traced()
.fetch_all(&mut *self.conn)
.await?;
let page = pagination.process(edges).map(UserEmail::from);
Ok(page)
}
#[tracing::instrument(
name = "db.user_email.count",
skip_all,
fields(
db.statement,
%user.id,
),
err,
)]
async fn count(&mut self, user: &User) -> Result<usize, Self::Error> {
let res = sqlx::query_scalar!(
r#"
SELECT COUNT(*)
FROM user_emails
WHERE user_id = $1
"#,
Uuid::from(user.id),
)
.traced()
.fetch_one(&mut *self.conn)
.await?;
let res = res.unwrap_or_default();
Ok(res
.try_into()
.map_err(DatabaseError::to_invalid_operation)?)
}
#[tracing::instrument(
name = "db.user_email.add",
skip_all,
fields(
db.statement,
%user.id,
user_email.id,
user_email.email = email,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
user: &User,
email: String,
) -> Result<UserEmail, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("user_email.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO user_emails (user_email_id, user_id, email, created_at)
VALUES ($1, $2, $3, $4)
"#,
Uuid::from(id),
Uuid::from(user.id),
&email,
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(UserEmail {
id,
user_id: user.id,
email,
created_at,
confirmed_at: None,
})
}
#[tracing::instrument(
name = "db.user_email.remove",
skip_all,
fields(
db.statement,
user.id = %user_email.user_id,
%user_email.id,
%user_email.email,
),
err,
)]
async fn remove(&mut self, user_email: UserEmail) -> Result<(), Self::Error> {
let span = info_span!(
"db.user_email.remove.codes",
db.statement = tracing::field::Empty
);
sqlx::query!(
r#"
DELETE FROM user_email_confirmation_codes
WHERE user_email_id = $1
"#,
Uuid::from(user_email.id),
)
.record(&span)
.execute(&mut *self.conn)
.instrument(span)
.await?;
let res = sqlx::query!(
r#"
DELETE FROM user_emails
WHERE user_email_id = $1
"#,
Uuid::from(user_email.id),
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
Ok(())
}
async fn mark_as_verified(
&mut self,
clock: &Clock,
mut user_email: UserEmail,
) -> Result<UserEmail, Self::Error> {
let confirmed_at = clock.now();
sqlx::query!(
r#"
UPDATE user_emails
SET confirmed_at = $2
WHERE user_email_id = $1
"#,
Uuid::from(user_email.id),
confirmed_at,
)
.execute(&mut *self.conn)
.await?;
user_email.confirmed_at = Some(confirmed_at);
Ok(user_email)
}
async fn set_as_primary(&mut self, user_email: &UserEmail) -> Result<(), Self::Error> {
sqlx::query!(
r#"
UPDATE users
SET primary_user_email_id = user_emails.user_email_id
FROM user_emails
WHERE user_emails.user_email_id = $1
AND users.user_id = user_emails.user_id
"#,
Uuid::from(user_email.id),
)
.execute(&mut *self.conn)
.await?;
Ok(())
}
#[tracing::instrument(
name = "db.user_email.add_verification_code",
skip_all,
fields(
db.statement,
%user_email.id,
%user_email.email,
user_email_verification.id,
user_email_verification.code = code,
),
err,
)]
async fn add_verification_code(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
user_email: &UserEmail,
max_age: chrono::Duration,
code: String,
) -> Result<UserEmailVerification, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("user_email_confirmation.id", tracing::field::display(id));
let expires_at = created_at + max_age;
sqlx::query!(
r#"
INSERT INTO user_email_confirmation_codes
(user_email_confirmation_code_id, user_email_id, code, created_at, expires_at)
VALUES ($1, $2, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(user_email.id),
code,
created_at,
expires_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
let verification = UserEmailVerification {
id,
user_email_id: user_email.id,
code,
created_at,
state: UserEmailVerificationState::Valid,
};
Ok(verification)
}
#[tracing::instrument(
name = "db.user_email.find_verification_code",
skip_all,
fields(
db.statement,
%user_email.id,
user.id = %user_email.user_id,
),
err,
)]
async fn find_verification_code(
&mut self,
clock: &Clock,
user_email: &UserEmail,
code: &str,
) -> Result<Option<UserEmailVerification>, Self::Error> {
let res = sqlx::query_as!(
UserEmailConfirmationCodeLookup,
r#"
SELECT user_email_confirmation_code_id
, user_email_id
, code
, created_at
, expires_at
, consumed_at
FROM user_email_confirmation_codes
WHERE code = $1
AND user_email_id = $2
"#,
code,
Uuid::from(user_email.id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into_verification(clock)))
}
#[tracing::instrument(
name = "db.user_email.consume_verification_code",
skip_all,
fields(
db.statement,
%user_email_verification.id,
user_email.id = %user_email_verification.user_email_id,
),
err,
)]
async fn consume_verification_code(
&mut self,
clock: &Clock,
mut user_email_verification: UserEmailVerification,
) -> Result<UserEmailVerification, Self::Error> {
if !matches!(
user_email_verification.state,
UserEmailVerificationState::Valid
) {
return Err(DatabaseError::invalid_operation());
}
let consumed_at = clock.now();
sqlx::query!(
r#"
UPDATE user_email_confirmation_codes
SET consumed_at = $2
WHERE user_email_confirmation_code_id = $1
"#,
Uuid::from(user_email_verification.id),
consumed_at
)
.traced()
.execute(&mut *self.conn)
.await?;
user_email_verification.state =
UserEmailVerificationState::AlreadyUsed { when: consumed_at };
Ok(user_email_verification)
}
}

View File

@ -13,26 +13,18 @@
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::User;
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt};
use crate::Clock;
mod email;
mod password;
mod session;
#[cfg(test)]
mod tests;
pub use self::{
email::{PgUserEmailRepository, UserEmailRepository},
password::{PgUserPasswordRepository, UserPasswordRepository},
session::{BrowserSessionRepository, PgBrowserSessionRepository},
email::UserEmailRepository, password::UserPasswordRepository, session::BrowserSessionRepository,
};
#[async_trait]
@ -49,170 +41,3 @@ pub trait UserRepository: Send + Sync {
) -> Result<User, Self::Error>;
async fn exists(&mut self, username: &str) -> Result<bool, Self::Error>;
}
pub struct PgUserRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgUserRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
#[derive(Debug, Clone)]
struct UserLookup {
user_id: Uuid,
username: String,
primary_user_email_id: Option<Uuid>,
#[allow(dead_code)]
created_at: DateTime<Utc>,
}
impl From<UserLookup> for User {
fn from(value: UserLookup) -> Self {
let id = value.user_id.into();
Self {
id,
username: value.username,
sub: id.to_string(),
primary_user_email_id: value.primary_user_email_id.map(Into::into),
}
}
}
#[async_trait]
impl<'c> UserRepository for PgUserRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.user.lookup",
skip_all,
fields(
db.statement,
user.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<User>, Self::Error> {
let res = sqlx::query_as!(
UserLookup,
r#"
SELECT user_id
, username
, primary_user_email_id
, created_at
FROM users
WHERE user_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.user.find_by_username",
skip_all,
fields(
db.statement,
user.username = username,
),
err,
)]
async fn find_by_username(&mut self, username: &str) -> Result<Option<User>, Self::Error> {
let res = sqlx::query_as!(
UserLookup,
r#"
SELECT user_id
, username
, primary_user_email_id
, created_at
FROM users
WHERE username = $1
"#,
username,
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.user.add",
skip_all,
fields(
db.statement,
user.username = username,
user.id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
username: String,
) -> Result<User, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("user.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO users (user_id, username, created_at)
VALUES ($1, $2, $3)
"#,
Uuid::from(id),
username,
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(User {
id,
username,
sub: id.to_string(),
primary_user_email_id: None,
})
}
#[tracing::instrument(
name = "db.user.exists",
skip_all,
fields(
db.statement,
user.username = username,
),
err,
)]
async fn exists(&mut self, username: &str) -> Result<bool, Self::Error> {
let exists = sqlx::query_scalar!(
r#"
SELECT EXISTS(
SELECT 1 FROM users WHERE username = $1
) AS "exists!"
"#,
username
)
.traced()
.fetch_one(&mut *self.conn)
.await?;
Ok(exists)
}
}

View File

@ -1,4 +1,4 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -13,16 +13,10 @@
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{Password, User};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{
tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
};
use crate::Clock;
#[async_trait]
pub trait UserPasswordRepository: Send + Sync {
@ -39,134 +33,3 @@ pub trait UserPasswordRepository: Send + Sync {
upgraded_from: Option<&Password>,
) -> Result<Password, Self::Error>;
}
pub struct PgUserPasswordRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgUserPasswordRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
struct UserPasswordLookup {
user_password_id: Uuid,
hashed_password: String,
version: i32,
upgraded_from_id: Option<Uuid>,
created_at: DateTime<Utc>,
}
#[async_trait]
impl<'c> UserPasswordRepository for PgUserPasswordRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.user_password.active",
skip_all,
fields(
db.statement,
%user.id,
%user.username,
),
err,
)]
async fn active(&mut self, user: &User) -> Result<Option<Password>, Self::Error> {
let res = sqlx::query_as!(
UserPasswordLookup,
r#"
SELECT up.user_password_id
, up.hashed_password
, up.version
, up.upgraded_from_id
, up.created_at
FROM user_passwords up
WHERE up.user_id = $1
ORDER BY up.created_at DESC
LIMIT 1
"#,
Uuid::from(user.id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
let id = Ulid::from(res.user_password_id);
let version = res.version.try_into().map_err(|e| {
DatabaseInconsistencyError::on("user_passwords")
.column("version")
.row(id)
.source(e)
})?;
let upgraded_from_id = res.upgraded_from_id.map(Ulid::from);
let created_at = res.created_at;
let hashed_password = res.hashed_password;
Ok(Some(Password {
id,
hashed_password,
version,
upgraded_from_id,
created_at,
}))
}
#[tracing::instrument(
name = "db.user_password.add",
skip_all,
fields(
db.statement,
%user.id,
%user.username,
user_password.id,
user_password.version = version,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
user: &User,
version: u16,
hashed_password: String,
upgraded_from: Option<&Password>,
) -> Result<Password, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("user_password.id", tracing::field::display(id));
let upgraded_from_id = upgraded_from.map(|p| p.id);
sqlx::query!(
r#"
INSERT INTO user_passwords
(user_password_id, user_id, hashed_password, version, upgraded_from_id, created_at)
VALUES ($1, $2, $3, $4, $5, $6)
"#,
Uuid::from(id),
Uuid::from(user.id),
hashed_password,
i32::from(version),
upgraded_from_id.map(Uuid::from),
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(Password {
id,
hashed_password,
version,
upgraded_from_id,
created_at,
})
}
}

View File

@ -1,4 +1,4 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -13,18 +13,11 @@
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{Authentication, BrowserSession, Password, UpstreamOAuthLink, User};
use mas_data_model::{BrowserSession, Password, UpstreamOAuthLink, User};
use rand::RngCore;
use sqlx::{PgConnection, QueryBuilder};
use ulid::Ulid;
use uuid::Uuid;
use crate::{
pagination::{Page, QueryBuilderExt},
tracing::ExecuteExt,
Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Pagination,
};
use crate::{pagination::Page, Clock, Pagination};
#[async_trait]
pub trait BrowserSessionRepository: Send + Sync {
@ -65,351 +58,3 @@ pub trait BrowserSessionRepository: Send + Sync {
upstream_oauth_link: &UpstreamOAuthLink,
) -> Result<BrowserSession, Self::Error>;
}
pub struct PgBrowserSessionRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgBrowserSessionRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
#[derive(sqlx::FromRow)]
struct SessionLookup {
user_session_id: Uuid,
user_session_created_at: DateTime<Utc>,
user_session_finished_at: Option<DateTime<Utc>>,
user_id: Uuid,
user_username: String,
user_primary_user_email_id: Option<Uuid>,
last_authentication_id: Option<Uuid>,
last_authd_at: Option<DateTime<Utc>>,
}
impl TryFrom<SessionLookup> for BrowserSession {
type Error = DatabaseInconsistencyError;
fn try_from(value: SessionLookup) -> Result<Self, Self::Error> {
let id = Ulid::from(value.user_id);
let user = User {
id,
username: value.user_username,
sub: id.to_string(),
primary_user_email_id: value.user_primary_user_email_id.map(Into::into),
};
let last_authentication = match (value.last_authentication_id, value.last_authd_at) {
(Some(id), Some(created_at)) => Some(Authentication {
id: id.into(),
created_at,
}),
(None, None) => None,
_ => {
return Err(DatabaseInconsistencyError::on(
"user_session_authentications",
))
}
};
Ok(BrowserSession {
id: value.user_session_id.into(),
user,
created_at: value.user_session_created_at,
finished_at: value.user_session_finished_at,
last_authentication,
})
}
}
#[async_trait]
impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.browser_session.lookup",
skip_all,
fields(
db.statement,
user_session.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<BrowserSession>, Self::Error> {
let res = sqlx::query_as!(
SessionLookup,
r#"
SELECT s.user_session_id
, s.created_at AS "user_session_created_at"
, s.finished_at AS "user_session_finished_at"
, u.user_id
, u.username AS "user_username"
, u.primary_user_email_id AS "user_primary_user_email_id"
, a.user_session_authentication_id AS "last_authentication_id?"
, a.created_at AS "last_authd_at?"
FROM user_sessions s
INNER JOIN users u
USING (user_id)
LEFT JOIN user_session_authentications a
USING (user_session_id)
WHERE s.user_session_id = $1
ORDER BY a.created_at DESC
LIMIT 1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
name = "db.browser_session.add",
skip_all,
fields(
db.statement,
%user.id,
user_session.id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
user: &User,
) -> Result<BrowserSession, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("user_session.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO user_sessions (user_session_id, user_id, created_at)
VALUES ($1, $2, $3)
"#,
Uuid::from(id),
Uuid::from(user.id),
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
let session = BrowserSession {
id,
// XXX
user: user.clone(),
created_at,
finished_at: None,
last_authentication: None,
};
Ok(session)
}
#[tracing::instrument(
name = "db.browser_session.finish",
skip_all,
fields(
db.statement,
%user_session.id,
),
err,
)]
async fn finish(
&mut self,
clock: &Clock,
mut user_session: BrowserSession,
) -> Result<BrowserSession, Self::Error> {
let finished_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE user_sessions
SET finished_at = $1
WHERE user_session_id = $2
"#,
finished_at,
Uuid::from(user_session.id),
)
.traced()
.execute(&mut *self.conn)
.await?;
user_session.finished_at = Some(finished_at);
DatabaseError::ensure_affected_rows(&res, 1)?;
Ok(user_session)
}
#[tracing::instrument(
name = "db.browser_session.list_active_paginated",
skip_all,
fields(
db.statement,
%user.id,
),
err,
)]
async fn list_active_paginated(
&mut self,
user: &User,
pagination: Pagination,
) -> Result<Page<BrowserSession>, Self::Error> {
// TODO: ordering of last authentication is wrong
let mut query = QueryBuilder::new(
r#"
SELECT DISTINCT ON (s.user_session_id)
s.user_session_id,
u.user_id,
u.username,
s.created_at,
a.user_session_authentication_id AS "last_authentication_id",
a.created_at AS "last_authd_at",
FROM user_sessions s
INNER JOIN users u
USING (user_id)
LEFT JOIN user_session_authentications a
USING (user_session_id)
"#,
);
query
.push(" WHERE s.finished_at IS NULL AND s.user_id = ")
.push_bind(Uuid::from(user.id))
.generate_pagination("s.user_session_id", pagination);
let edges: Vec<SessionLookup> = query
.build_query_as()
.traced()
.fetch_all(&mut *self.conn)
.await?;
let page = pagination
.process(edges)
.try_map(BrowserSession::try_from)?;
Ok(page)
}
#[tracing::instrument(
name = "db.browser_session.count_active",
skip_all,
fields(
db.statement,
%user.id,
),
err,
)]
async fn count_active(&mut self, user: &User) -> Result<usize, Self::Error> {
let res = sqlx::query_scalar!(
r#"
SELECT COUNT(*) as "count!"
FROM user_sessions s
WHERE s.user_id = $1 AND s.finished_at IS NULL
"#,
Uuid::from(user.id),
)
.traced()
.fetch_one(&mut *self.conn)
.await?;
res.try_into().map_err(DatabaseError::to_invalid_operation)
}
#[tracing::instrument(
name = "db.browser_session.authenticate_with_password",
skip_all,
fields(
db.statement,
%user_session.id,
%user_password.id,
user_session_authentication.id,
),
err,
)]
async fn authenticate_with_password(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
mut user_session: BrowserSession,
user_password: &Password,
) -> Result<BrowserSession, Self::Error> {
let _user_password = user_password;
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record(
"user_session_authentication.id",
tracing::field::display(id),
);
sqlx::query!(
r#"
INSERT INTO user_session_authentications
(user_session_authentication_id, user_session_id, created_at)
VALUES ($1, $2, $3)
"#,
Uuid::from(id),
Uuid::from(user_session.id),
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
user_session.last_authentication = Some(Authentication { id, created_at });
Ok(user_session)
}
#[tracing::instrument(
name = "db.browser_session.authenticate_with_upstream",
skip_all,
fields(
db.statement,
%user_session.id,
%upstream_oauth_link.id,
user_session_authentication.id,
),
err,
)]
async fn authenticate_with_upstream(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
mut user_session: BrowserSession,
upstream_oauth_link: &UpstreamOAuthLink,
) -> Result<BrowserSession, Self::Error> {
let _upstream_oauth_link = upstream_oauth_link;
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record(
"user_session_authentication.id",
tracing::field::display(id),
);
sqlx::query!(
r#"
INSERT INTO user_session_authentications
(user_session_authentication_id, user_session_id, created_at)
VALUES ($1, $2, $3)
"#,
Uuid::from(id),
Uuid::from(user_session.id),
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
user_session.last_authentication = Some(Authentication { id, created_at });
Ok(user_session)
}
}

View File

@ -14,3 +14,4 @@ tracing = "0.1.37"
sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres"] }
mas-storage = { path = "../storage" }
mas-storage-pg = { path = "../storage-pg" }

View File

@ -14,7 +14,8 @@
//! Database-related tasks
use mas_storage::{oauth2::OAuth2AccessTokenRepository, Clock, PgRepository, Repository};
use mas_storage::{oauth2::OAuth2AccessTokenRepository, Clock, Repository};
use mas_storage_pg::PgRepository;
use sqlx::{Pool, Postgres};
use tracing::{debug, error, info};