diff --git a/Cargo.lock b/Cargo.lock index b780f20f..6ecc512d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2705,6 +2705,7 @@ dependencies = [ "mas-router", "mas-spa", "mas-storage", + "mas-storage-pg", "mas-tasks", "mas-templates", "oauth2-types", @@ -2803,6 +2804,7 @@ dependencies = [ "chrono", "mas-data-model", "mas-storage", + "mas-storage-pg", "oauth2-types", "serde", "sqlx", @@ -2843,6 +2845,7 @@ dependencies = [ "mas-policy", "mas-router", "mas-storage", + "mas-storage-pg", "mas-templates", "mime", "oauth2-types", @@ -3103,6 +3106,23 @@ dependencies = [ "mas-jose", "oauth2-types", "rand 0.8.5", + "thiserror", + "ulid", + "url", +] + +[[package]] +name = "mas-storage-pg" +version = "0.1.0" +dependencies = [ + "async-trait", + "chrono", + "mas-data-model", + "mas-iana", + "mas-jose", + "mas-storage", + "oauth2-types", + "rand 0.8.5", "rand_chacha 0.3.1", "serde", "serde_json", @@ -3121,6 +3141,7 @@ dependencies = [ "async-trait", "futures-util", "mas-storage", + "mas-storage-pg", "sqlx", "tokio", "tokio-stream", diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index a3bc4ec3..451e8d32 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -50,6 +50,7 @@ mas-policy = { path = "../policy" } mas-router = { path = "../router" } mas-spa = { path = "../spa" } mas-storage = { path = "../storage" } +mas-storage-pg = { path = "../storage-pg" } mas-tasks = { path = "../tasks" } mas-templates = { path = "../templates" } oauth2-types = { path = "../oauth2-types" } diff --git a/crates/cli/src/commands/database.rs b/crates/cli/src/commands/database.rs index ca59ce1d..0e4d68af 100644 --- a/crates/cli/src/commands/database.rs +++ b/crates/cli/src/commands/database.rs @@ -15,7 +15,7 @@ use anyhow::Context; use clap::Parser; use mas_config::DatabaseConfig; -use mas_storage::MIGRATOR; +use mas_storage_pg::MIGRATOR; use tracing::{info_span, Instrument}; use crate::util::database_from_config; diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index 15378940..c608db83 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -21,8 +21,9 @@ use mas_storage::{ oauth2::OAuth2ClientRepository, upstream_oauth2::UpstreamOAuthProviderRepository, user::{UserEmailRepository, UserPasswordRepository, UserRepository}, - Clock, PgRepository, Repository, + Clock, Repository, }; +use mas_storage_pg::PgRepository; use oauth2_types::scope::Scope; use rand::SeedableRng; use tracing::{info, info_span, warn}; diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index fb2a3f16..1a7e39e6 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -1,4 +1,4 @@ -// Copyright 2021, 2022 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ use mas_config::RootConfig; use mas_handlers::{AppState, HttpClientFactory, MatrixHomeserver}; use mas_listener::{server::Server, shutdown::ShutdownStream}; use mas_router::UrlBuilder; -use mas_storage::MIGRATOR; +use mas_storage_pg::MIGRATOR; use mas_tasks::TaskQueue; use tokio::signal::unix::SignalKind; use tracing::{info, info_span, warn, Instrument}; diff --git a/crates/data-model/Cargo.toml b/crates/data-model/Cargo.toml index 661e9bbc..b96d7b83 100644 --- a/crates/data-model/Cargo.toml +++ b/crates/data-model/Cargo.toml @@ -11,8 +11,8 @@ thiserror = "1.0.38" serde = "1.0.152" url = { version = "2.3.1", features = ["serde"] } crc = "3.0.0" +ulid = { version = "1.0.0", features = ["serde"] } rand = "0.8.5" -ulid = "1.0.0" rand_chacha = "0.3.1" mas-iana = { path = "../iana" } diff --git a/crates/graphql/Cargo.toml b/crates/graphql/Cargo.toml index ff3159b0..16f7e5b5 100644 --- a/crates/graphql/Cargo.toml +++ b/crates/graphql/Cargo.toml @@ -19,6 +19,7 @@ url = "2.3.1" oauth2-types = { path = "../oauth2-types" } mas-data-model = { path = "../data-model" } mas-storage = { path = "../storage" } +mas-storage-pg = { path = "../storage-pg" } [[bin]] name = "schema" diff --git a/crates/graphql/src/lib.rs b/crates/graphql/src/lib.rs index d2de0c24..159387ae 100644 --- a/crates/graphql/src/lib.rs +++ b/crates/graphql/src/lib.rs @@ -34,8 +34,9 @@ use mas_storage::{ oauth2::OAuth2ClientRepository, upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository}, user::{BrowserSessionRepository, UserEmailRepository}, - Pagination, PgRepository, Repository, + Pagination, Repository, }; +use mas_storage_pg::PgRepository; use model::CreationEvent; use sqlx::PgPool; diff --git a/crates/graphql/src/model/compat_sessions.rs b/crates/graphql/src/model/compat_sessions.rs index a2196e36..e5cd66bc 100644 --- a/crates/graphql/src/model/compat_sessions.rs +++ b/crates/graphql/src/model/compat_sessions.rs @@ -15,9 +15,8 @@ use anyhow::Context as _; use async_graphql::{Context, Description, Object, ID}; use chrono::{DateTime, Utc}; -use mas_storage::{ - compat::CompatSessionRepository, user::UserRepository, PgRepository, Repository, -}; +use mas_storage::{compat::CompatSessionRepository, user::UserRepository, Repository}; +use mas_storage_pg::PgRepository; use sqlx::PgPool; use url::Url; diff --git a/crates/graphql/src/model/oauth.rs b/crates/graphql/src/model/oauth.rs index 171c800f..90a0c6b7 100644 --- a/crates/graphql/src/model/oauth.rs +++ b/crates/graphql/src/model/oauth.rs @@ -14,9 +14,8 @@ use anyhow::Context as _; use async_graphql::{Context, Description, Object, ID}; -use mas_storage::{ - oauth2::OAuth2ClientRepository, user::BrowserSessionRepository, PgRepository, Repository, -}; +use mas_storage::{oauth2::OAuth2ClientRepository, user::BrowserSessionRepository, Repository}; +use mas_storage_pg::PgRepository; use oauth2_types::scope::Scope; use sqlx::PgPool; use ulid::Ulid; diff --git a/crates/graphql/src/model/upstream_oauth.rs b/crates/graphql/src/model/upstream_oauth.rs index 4a4c223b..5767f8d4 100644 --- a/crates/graphql/src/model/upstream_oauth.rs +++ b/crates/graphql/src/model/upstream_oauth.rs @@ -16,9 +16,9 @@ use anyhow::Context as _; use async_graphql::{Context, Object, ID}; use chrono::{DateTime, Utc}; use mas_storage::{ - upstream_oauth2::UpstreamOAuthProviderRepository, user::UserRepository, PgRepository, - Repository, + upstream_oauth2::UpstreamOAuthProviderRepository, user::UserRepository, Repository, }; +use mas_storage_pg::PgRepository; use sqlx::PgPool; use super::{NodeType, User}; diff --git a/crates/graphql/src/model/users.rs b/crates/graphql/src/model/users.rs index 68daff1b..3f587eb0 100644 --- a/crates/graphql/src/model/users.rs +++ b/crates/graphql/src/model/users.rs @@ -22,8 +22,9 @@ use mas_storage::{ oauth2::OAuth2SessionRepository, upstream_oauth2::UpstreamOAuthLinkRepository, user::{BrowserSessionRepository, UserEmailRepository}, - Pagination, PgRepository, Repository, + Pagination, Repository, }; +use mas_storage_pg::PgRepository; use sqlx::PgPool; use super::{ diff --git a/crates/handlers/Cargo.toml b/crates/handlers/Cargo.toml index 47dd2775..b8fbb5af 100644 --- a/crates/handlers/Cargo.toml +++ b/crates/handlers/Cargo.toml @@ -68,6 +68,7 @@ mas-oidc-client = { path = "../oidc-client" } mas-policy = { path = "../policy" } mas-router = { path = "../router" } mas-storage = { path = "../storage" } +mas-storage-pg = { path = "../storage-pg" } mas-templates = { path = "../templates" } oauth2-types = { path = "../oauth2-types" } diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index f344f7e0..bfd36d8a 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -22,8 +22,9 @@ use mas_storage::{ CompatSsoLoginRepository, }, user::{UserPasswordRepository, UserRepository}, - Clock, PgRepository, Repository, + Clock, Repository, }; +use mas_storage_pg::PgRepository; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, skip_serializing_none, DurationMilliSeconds}; use sqlx::PgPool; @@ -154,7 +155,7 @@ pub enum RouteError { } impl_from_error_for_route!(sqlx::Error); -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { diff --git a/crates/handlers/src/compat/login_sso_complete.rs b/crates/handlers/src/compat/login_sso_complete.rs index 7ca61ab2..1fea922e 100644 --- a/crates/handlers/src/compat/login_sso_complete.rs +++ b/crates/handlers/src/compat/login_sso_complete.rs @@ -31,8 +31,9 @@ use mas_keystore::Encrypter; use mas_router::{CompatLoginSsoAction, PostAuthAction, Route}; use mas_storage::{ compat::{CompatSessionRepository, CompatSsoLoginRepository}, - PgRepository, Repository, + Repository, }; +use mas_storage_pg::PgRepository; use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates}; use serde::{Deserialize, Serialize}; use sqlx::PgPool; diff --git a/crates/handlers/src/compat/login_sso_redirect.rs b/crates/handlers/src/compat/login_sso_redirect.rs index befd3e32..38aa0894 100644 --- a/crates/handlers/src/compat/login_sso_redirect.rs +++ b/crates/handlers/src/compat/login_sso_redirect.rs @@ -19,7 +19,8 @@ use axum::{ }; use hyper::StatusCode; use mas_router::{CompatLoginSsoAction, CompatLoginSsoComplete, UrlBuilder}; -use mas_storage::{compat::CompatSsoLoginRepository, PgRepository, Repository}; +use mas_storage::{compat::CompatSsoLoginRepository, Repository}; +use mas_storage_pg::PgRepository; use rand::distributions::{Alphanumeric, DistString}; use serde::Deserialize; use serde_with::serde; @@ -49,7 +50,7 @@ pub enum RouteError { } impl_from_error_for_route!(sqlx::Error); -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { diff --git a/crates/handlers/src/compat/logout.rs b/crates/handlers/src/compat/logout.rs index 762f77b2..21229fe7 100644 --- a/crates/handlers/src/compat/logout.rs +++ b/crates/handlers/src/compat/logout.rs @@ -18,8 +18,9 @@ use hyper::StatusCode; use mas_data_model::TokenType; use mas_storage::{ compat::{CompatAccessTokenRepository, CompatSessionRepository}, - Clock, PgRepository, Repository, + Clock, Repository, }; +use mas_storage_pg::PgRepository; use sqlx::PgPool; use thiserror::Error; @@ -42,7 +43,7 @@ pub enum RouteError { } impl_from_error_for_route!(sqlx::Error); -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { diff --git a/crates/handlers/src/compat/refresh.rs b/crates/handlers/src/compat/refresh.rs index ea6d5d23..e1601395 100644 --- a/crates/handlers/src/compat/refresh.rs +++ b/crates/handlers/src/compat/refresh.rs @@ -18,8 +18,9 @@ use hyper::StatusCode; use mas_data_model::{TokenFormatError, TokenType}; use mas_storage::{ compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository}, - PgRepository, Repository, + Repository, }; +use mas_storage_pg::PgRepository; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, DurationMilliSeconds}; use sqlx::PgPool; @@ -70,7 +71,7 @@ impl IntoResponse for RouteError { } impl_from_error_for_route!(sqlx::Error); -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl From for RouteError { fn from(_e: TokenFormatError) -> Self { diff --git a/crates/handlers/src/graphql.rs b/crates/handlers/src/graphql.rs index d3a610b6..fcc6aa3c 100644 --- a/crates/handlers/src/graphql.rs +++ b/crates/handlers/src/graphql.rs @@ -28,7 +28,7 @@ use hyper::header::CACHE_CONTROL; use mas_axum_utils::{FancyError, SessionInfoExt}; use mas_graphql::Schema; use mas_keystore::Encrypter; -use mas_storage::PgRepository; +use mas_storage_pg::PgRepository; use sqlx::PgPool; use tracing::{info_span, Instrument}; diff --git a/crates/handlers/src/health.rs b/crates/handlers/src/health.rs index 6322dffd..10638497 100644 --- a/crates/handlers/src/health.rs +++ b/crates/handlers/src/health.rs @@ -1,4 +1,4 @@ -// Copyright 2021, 2022 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -36,7 +36,7 @@ mod tests { use super::*; - #[sqlx::test(migrator = "mas_storage::MIGRATOR")] + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] async fn test_get_health(pool: PgPool) -> Result<(), anyhow::Error> { let state = crate::test_state(pool).await?; let app = crate::healthcheck_router().with_state(state); diff --git a/crates/handlers/src/oauth2/authorization/complete.rs b/crates/handlers/src/oauth2/authorization/complete.rs index c983e79c..05554e12 100644 --- a/crates/handlers/src/oauth2/authorization/complete.rs +++ b/crates/handlers/src/oauth2/authorization/complete.rs @@ -27,8 +27,9 @@ use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; use mas_storage::{ oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2SessionRepository}, - PgRepository, Repository, + Repository, }; +use mas_storage_pg::PgRepository; use mas_templates::Templates; use oauth2_types::requests::{AccessTokenResponse, AuthorizationResponse}; use sqlx::PgPool; @@ -70,7 +71,7 @@ impl IntoResponse for RouteError { } impl_from_error_for_route!(sqlx::Error); -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl_from_error_for_route!(mas_policy::LoadError); impl_from_error_for_route!(mas_policy::InstanciateError); impl_from_error_for_route!(mas_policy::EvaluationError); @@ -149,7 +150,7 @@ pub enum GrantCompletionError { } impl_from_error_for_route!(GrantCompletionError: sqlx::Error); -impl_from_error_for_route!(GrantCompletionError: mas_storage::DatabaseError); +impl_from_error_for_route!(GrantCompletionError: mas_storage_pg::DatabaseError); impl_from_error_for_route!(GrantCompletionError: super::callback::IntoCallbackDestinationError); impl_from_error_for_route!(GrantCompletionError: mas_policy::LoadError); impl_from_error_for_route!(GrantCompletionError: mas_policy::InstanciateError); diff --git a/crates/handlers/src/oauth2/authorization/mod.rs b/crates/handlers/src/oauth2/authorization/mod.rs index 155f72f7..43bda928 100644 --- a/crates/handlers/src/oauth2/authorization/mod.rs +++ b/crates/handlers/src/oauth2/authorization/mod.rs @@ -27,8 +27,9 @@ use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; use mas_storage::{ oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository}, - PgRepository, Repository, + Repository, }; +use mas_storage_pg::PgRepository; use mas_templates::Templates; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, @@ -91,7 +92,7 @@ impl IntoResponse for RouteError { } impl_from_error_for_route!(sqlx::Error); -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl_from_error_for_route!(self::callback::CallbackDestinationError); impl_from_error_for_route!(mas_policy::LoadError); impl_from_error_for_route!(mas_policy::InstanciateError); diff --git a/crates/handlers/src/oauth2/consent.rs b/crates/handlers/src/oauth2/consent.rs index f3d4fd46..b0f752f7 100644 --- a/crates/handlers/src/oauth2/consent.rs +++ b/crates/handlers/src/oauth2/consent.rs @@ -30,8 +30,9 @@ use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; use mas_storage::{ oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository}, - PgRepository, Repository, + Repository, }; +use mas_storage_pg::PgRepository; use mas_templates::{ConsentContext, PolicyViolationContext, TemplateContext, Templates}; use sqlx::PgPool; use thiserror::Error; @@ -62,7 +63,7 @@ pub enum RouteError { impl_from_error_for_route!(sqlx::Error); impl_from_error_for_route!(mas_templates::TemplateError); -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl_from_error_for_route!(mas_policy::LoadError); impl_from_error_for_route!(mas_policy::InstanciateError); impl_from_error_for_route!(mas_policy::EvaluationError); diff --git a/crates/handlers/src/oauth2/introspection.rs b/crates/handlers/src/oauth2/introspection.rs index 2837928f..e8f9941f 100644 --- a/crates/handlers/src/oauth2/introspection.rs +++ b/crates/handlers/src/oauth2/introspection.rs @@ -25,8 +25,9 @@ use mas_storage::{ compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository}, oauth2::{OAuth2AccessTokenRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository}, user::{BrowserSessionRepository, UserRepository}, - Clock, PgRepository, Repository, + Clock, Repository, }; +use mas_storage_pg::PgRepository; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, requests::{IntrospectionRequest, IntrospectionResponse}, @@ -97,7 +98,7 @@ impl IntoResponse for RouteError { } impl_from_error_for_route!(sqlx::Error); -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl From for RouteError { fn from(_e: TokenFormatError) -> Self { diff --git a/crates/handlers/src/oauth2/registration.rs b/crates/handlers/src/oauth2/registration.rs index d6180f9a..8e9489e8 100644 --- a/crates/handlers/src/oauth2/registration.rs +++ b/crates/handlers/src/oauth2/registration.rs @@ -19,7 +19,8 @@ use hyper::StatusCode; use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_keystore::Encrypter; use mas_policy::{PolicyFactory, Violation}; -use mas_storage::{oauth2::OAuth2ClientRepository, PgRepository, Repository}; +use mas_storage::{oauth2::OAuth2ClientRepository, Repository}; +use mas_storage_pg::PgRepository; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, registration::{ @@ -49,7 +50,7 @@ pub(crate) enum RouteError { } impl_from_error_for_route!(sqlx::Error); -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl_from_error_for_route!(mas_policy::LoadError); impl_from_error_for_route!(mas_policy::InstanciateError); impl_from_error_for_route!(mas_policy::EvaluationError); diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index 6365a0ad..67ecb498 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -37,8 +37,9 @@ use mas_storage::{ OAuth2RefreshTokenRepository, OAuth2SessionRepository, }, user::BrowserSessionRepository, - PgRepository, Repository, + Repository, }; +use mas_storage_pg::PgRepository; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, pkce::CodeChallengeError, @@ -151,7 +152,7 @@ impl IntoResponse for RouteError { } impl_from_error_for_route!(sqlx::Error); -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl_from_error_for_route!(mas_keystore::WrongAlgorithmError); impl_from_error_for_route!(mas_jose::claims::ClaimError); impl_from_error_for_route!(mas_jose::claims::TokenHashError); diff --git a/crates/handlers/src/oauth2/userinfo.rs b/crates/handlers/src/oauth2/userinfo.rs index a125c5dd..2f560037 100644 --- a/crates/handlers/src/oauth2/userinfo.rs +++ b/crates/handlers/src/oauth2/userinfo.rs @@ -31,8 +31,9 @@ use mas_router::UrlBuilder; use mas_storage::{ oauth2::OAuth2ClientRepository, user::{BrowserSessionRepository, UserEmailRepository}, - DatabaseError, PgRepository, Repository, + Repository, }; +use mas_storage_pg::PgRepository; use oauth2_types::scope; use serde::Serialize; use serde_with::skip_serializing_none; @@ -64,7 +65,9 @@ pub enum RouteError { Internal(Box), #[error("failed to authenticate")] - AuthorizationVerificationError(#[from] AuthorizationVerificationError), + AuthorizationVerificationError( + #[from] AuthorizationVerificationError, + ), #[error("no suitable key found for signing")] InvalidSigningKey, @@ -77,7 +80,7 @@ pub enum RouteError { } impl_from_error_for_route!(sqlx::Error); -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl_from_error_for_route!(mas_keystore::WrongAlgorithmError); impl_from_error_for_route!(mas_jose::jwt::JwtSignatureError); diff --git a/crates/handlers/src/upstream_oauth2/authorize.rs b/crates/handlers/src/upstream_oauth2/authorize.rs index bdd19b7b..fcf5a7d1 100644 --- a/crates/handlers/src/upstream_oauth2/authorize.rs +++ b/crates/handlers/src/upstream_oauth2/authorize.rs @@ -24,8 +24,9 @@ use mas_oidc_client::requests::authorization_code::AuthorizationRequestData; use mas_router::UrlBuilder; use mas_storage::{ upstream_oauth2::{UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository}, - PgRepository, Repository, + Repository, }; +use mas_storage_pg::PgRepository; use sqlx::PgPool; use thiserror::Error; use ulid::Ulid; @@ -46,7 +47,7 @@ impl_from_error_for_route!(sqlx::Error); impl_from_error_for_route!(mas_http::ClientInitError); impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError); impl_from_error_for_route!(mas_oidc_client::error::AuthorizationError); -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index 521efd7b..d243666d 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -30,8 +30,9 @@ use mas_storage::{ UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository, }, - PgRepository, Repository, + Repository, }; +use mas_storage_pg::PgRepository; use oauth2_types::errors::ClientErrorCode; use serde::Deserialize; use sqlx::PgPool; @@ -99,7 +100,7 @@ pub(crate) enum RouteError { Internal(Box), } -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl_from_error_for_route!(mas_http::ClientInitError); impl_from_error_for_route!(sqlx::Error); impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError); diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index 18849be8..8709ff21 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -27,8 +27,9 @@ use mas_keystore::Encrypter; use mas_storage::{ upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository}, user::{BrowserSessionRepository, UserRepository}, - PgRepository, Repository, + Repository, }; +use mas_storage_pg::PgRepository; use mas_templates::{ EmptyContext, TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister, UpstreamSuggestLink, @@ -73,7 +74,7 @@ impl_from_error_for_route!(sqlx::Error); impl_from_error_for_route!(mas_templates::TemplateError); impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError); impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound); -impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_storage_pg::DatabaseError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { diff --git a/crates/handlers/src/views/account/emails/add.rs b/crates/handlers/src/views/account/emails/add.rs index e0cc063d..e99c8e4d 100644 --- a/crates/handlers/src/views/account/emails/add.rs +++ b/crates/handlers/src/views/account/emails/add.rs @@ -24,7 +24,8 @@ use mas_axum_utils::{ use mas_email::Mailer; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::{user::UserEmailRepository, PgRepository, Repository}; +use mas_storage::{user::UserEmailRepository, Repository}; +use mas_storage_pg::PgRepository; use mas_templates::{EmailAddContext, TemplateContext, Templates}; use serde::Deserialize; use sqlx::PgPool; diff --git a/crates/handlers/src/views/account/emails/mod.rs b/crates/handlers/src/views/account/emails/mod.rs index 3fda398a..4d70ab33 100644 --- a/crates/handlers/src/views/account/emails/mod.rs +++ b/crates/handlers/src/views/account/emails/mod.rs @@ -28,7 +28,8 @@ use mas_data_model::{BrowserSession, User, UserEmail}; use mas_email::Mailer; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::{user::UserEmailRepository, Clock, PgRepository, Repository}; +use mas_storage::{user::UserEmailRepository, Clock, Repository}; +use mas_storage_pg::PgRepository; use mas_templates::{AccountEmailsContext, EmailVerificationContext, TemplateContext, Templates}; use rand::{distributions::Uniform, Rng}; use serde::Deserialize; diff --git a/crates/handlers/src/views/account/emails/verify.rs b/crates/handlers/src/views/account/emails/verify.rs index 085b9a33..2b398b42 100644 --- a/crates/handlers/src/views/account/emails/verify.rs +++ b/crates/handlers/src/views/account/emails/verify.rs @@ -24,7 +24,8 @@ use mas_axum_utils::{ }; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::{user::UserEmailRepository, Clock, PgRepository, Repository}; +use mas_storage::{user::UserEmailRepository, Clock, Repository}; +use mas_storage_pg::PgRepository; use mas_templates::{EmailVerificationPageContext, TemplateContext, Templates}; use serde::Deserialize; use sqlx::PgPool; diff --git a/crates/handlers/src/views/account/mod.rs b/crates/handlers/src/views/account/mod.rs index 5017db00..8d2eb3e2 100644 --- a/crates/handlers/src/views/account/mod.rs +++ b/crates/handlers/src/views/account/mod.rs @@ -25,8 +25,9 @@ use mas_keystore::Encrypter; use mas_router::Route; use mas_storage::{ user::{BrowserSessionRepository, UserEmailRepository}, - PgRepository, Repository, + Repository, }; +use mas_storage_pg::PgRepository; use mas_templates::{AccountContext, TemplateContext, Templates}; use sqlx::PgPool; diff --git a/crates/handlers/src/views/account/password.rs b/crates/handlers/src/views/account/password.rs index 8d496432..089093f6 100644 --- a/crates/handlers/src/views/account/password.rs +++ b/crates/handlers/src/views/account/password.rs @@ -27,8 +27,9 @@ use mas_keystore::Encrypter; use mas_router::Route; use mas_storage::{ user::{BrowserSessionRepository, UserPasswordRepository}, - Clock, PgRepository, Repository, + Clock, Repository, }; +use mas_storage_pg::PgRepository; use mas_templates::{EmptyContext, TemplateContext, Templates}; use rand::Rng; use serde::Deserialize; diff --git a/crates/handlers/src/views/index.rs b/crates/handlers/src/views/index.rs index 49668dae..cab7c743 100644 --- a/crates/handlers/src/views/index.rs +++ b/crates/handlers/src/views/index.rs @@ -20,7 +20,7 @@ use axum_extra::extract::PrivateCookieJar; use mas_axum_utils::{csrf::CsrfExt, FancyError, SessionInfoExt}; use mas_keystore::Encrypter; use mas_router::UrlBuilder; -use mas_storage::PgRepository; +use mas_storage_pg::PgRepository; use mas_templates::{IndexContext, TemplateContext, Templates}; use sqlx::PgPool; diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index 76ffa455..87ba9e84 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -26,8 +26,9 @@ use mas_keystore::Encrypter; use mas_storage::{ upstream_oauth2::UpstreamOAuthProviderRepository, user::{BrowserSessionRepository, UserPasswordRepository, UserRepository}, - Clock, PgRepository, Repository, + Clock, Repository, }; +use mas_storage_pg::PgRepository; use mas_templates::{ FieldError, FormError, LoginContext, LoginFormField, TemplateContext, Templates, ToFormState, }; diff --git a/crates/handlers/src/views/logout.rs b/crates/handlers/src/views/logout.rs index 156e6afb..373264d0 100644 --- a/crates/handlers/src/views/logout.rs +++ b/crates/handlers/src/views/logout.rs @@ -23,7 +23,8 @@ use mas_axum_utils::{ }; use mas_keystore::Encrypter; use mas_router::{PostAuthAction, Route}; -use mas_storage::{user::BrowserSessionRepository, Clock, PgRepository, Repository}; +use mas_storage::{user::BrowserSessionRepository, Clock, Repository}; +use mas_storage_pg::PgRepository; use sqlx::PgPool; pub(crate) async fn post( diff --git a/crates/handlers/src/views/reauth.rs b/crates/handlers/src/views/reauth.rs index aac51abd..49249f3c 100644 --- a/crates/handlers/src/views/reauth.rs +++ b/crates/handlers/src/views/reauth.rs @@ -26,8 +26,9 @@ use mas_keystore::Encrypter; use mas_router::Route; use mas_storage::{ user::{BrowserSessionRepository, UserPasswordRepository}, - PgRepository, Repository, + Repository, }; +use mas_storage_pg::PgRepository; use mas_templates::{ReauthContext, TemplateContext, Templates}; use serde::Deserialize; use sqlx::PgPool; diff --git a/crates/handlers/src/views/register.rs b/crates/handlers/src/views/register.rs index a014eb9d..58db6ec1 100644 --- a/crates/handlers/src/views/register.rs +++ b/crates/handlers/src/views/register.rs @@ -33,8 +33,9 @@ use mas_policy::PolicyFactory; use mas_router::Route; use mas_storage::{ user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, - PgRepository, Repository, + Repository, }; +use mas_storage_pg::PgRepository; use mas_templates::{ EmailVerificationContext, FieldError, FormError, RegisterContext, RegisterFormField, TemplateContext, Templates, ToFormState, diff --git a/crates/storage-pg/Cargo.toml b/crates/storage-pg/Cargo.toml new file mode 100644 index 00000000..fad6e30e --- /dev/null +++ b/crates/storage-pg/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "mas-storage-pg" +version = "0.1.0" +authors = ["Quentin Gliech "] +edition = "2021" +license = "Apache-2.0" + +[dependencies] +async-trait = "0.1.60" +sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "offline", "json", "uuid"] } +chrono = { version = "0.4.23", features = ["serde"] } +serde = { version = "1.0.152", features = ["derive"] } +serde_json = "1.0.91" +thiserror = "1.0.38" +tracing = "0.1.37" + +rand = "0.8.5" +rand_chacha = "0.3.1" +url = { version = "2.3.1", features = ["serde"] } +uuid = "1.2.2" +ulid = { version = "1.0.0", features = ["uuid", "serde"] } + +oauth2-types = { path = "../oauth2-types" } +mas-storage = { path = "../storage" } +mas-data-model = { path = "../data-model" } +mas-iana = { path = "../iana" } +mas-jose = { path = "../jose" } diff --git a/crates/storage/build.rs b/crates/storage-pg/build.rs similarity index 92% rename from crates/storage/build.rs rename to crates/storage-pg/build.rs index dd5b1142..dca71bd6 100644 --- a/crates/storage/build.rs +++ b/crates/storage-pg/build.rs @@ -1,4 +1,4 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/crates/storage/migrations/20221018142001_init.sql b/crates/storage-pg/migrations/20221018142001_init.sql similarity index 100% rename from crates/storage/migrations/20221018142001_init.sql rename to crates/storage-pg/migrations/20221018142001_init.sql diff --git a/crates/storage/migrations/20221121151402_upstream_oauth.sql b/crates/storage-pg/migrations/20221121151402_upstream_oauth.sql similarity index 100% rename from crates/storage/migrations/20221121151402_upstream_oauth.sql rename to crates/storage-pg/migrations/20221121151402_upstream_oauth.sql diff --git a/crates/storage/migrations/20221213145242_password_schemes.sql b/crates/storage-pg/migrations/20221213145242_password_schemes.sql similarity index 100% rename from crates/storage/migrations/20221213145242_password_schemes.sql rename to crates/storage-pg/migrations/20221213145242_password_schemes.sql diff --git a/crates/storage/sqlx-data.json b/crates/storage-pg/sqlx-data.json similarity index 98% rename from crates/storage/sqlx-data.json rename to crates/storage-pg/sqlx-data.json index 8148f796..94527512 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage-pg/sqlx-data.json @@ -1336,6 +1336,24 @@ }, "query": "\n SELECT oauth2_client_id\n , encrypted_client_secret\n , ARRAY(\n SELECT redirect_uri\n FROM oauth2_client_redirect_uris r\n WHERE r.oauth2_client_id = c.oauth2_client_id\n ) AS \"redirect_uris!\"\n , grant_type_authorization_code\n , grant_type_refresh_token\n , client_name\n , logo_uri\n , client_uri\n , policy_uri\n , tos_uri\n , jwks_uri\n , jwks\n , id_token_signed_response_alg\n , userinfo_signed_response_alg\n , token_endpoint_auth_method\n , token_endpoint_auth_signing_alg\n , initiate_login_uri\n FROM oauth2_clients c\n\n WHERE oauth2_client_id = ANY($1::uuid[])\n " }, + "8a79c7c392dd930628caadec80c9b2645501475ab4feacbac59ca1bc52b16c3f": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Text", + "Bool", + "Bool", + "Text", + "Jsonb", + "Text" + ] + } + }, + "query": "\n INSERT INTO oauth2_clients\n ( oauth2_client_id\n , encrypted_client_secret\n , grant_type_authorization_code\n , grant_type_refresh_token\n , token_endpoint_auth_method\n , jwks\n , jwks_uri\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7)\n ON CONFLICT (oauth2_client_id)\n DO\n UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret\n , grant_type_authorization_code = EXCLUDED.grant_type_authorization_code\n , grant_type_refresh_token = EXCLUDED.grant_type_refresh_token\n , token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method\n , jwks = EXCLUDED.jwks\n , jwks_uri = EXCLUDED.jwks_uri\n " + }, "8b7297c263336d70c2b647212b16f7ae39bc5cb1572e3a2dcfcd67f196a1fa39": { "describe": { "columns": [ @@ -1821,24 +1839,6 @@ }, "query": "\n UPDATE users\n SET primary_user_email_id = user_emails.user_email_id\n FROM user_emails\n WHERE user_emails.user_email_id = $1\n AND users.user_id = user_emails.user_id\n " }, - "c0b4996085f6f2127e1e8cfdf18b9029c22096fadfe6de59dce01c789791edb5": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Text", - "Bool", - "Bool", - "Text", - "Jsonb", - "Text" - ] - } - }, - "query": "\n INSERT INTO oauth2_clients\n ( oauth2_client_id\n , encrypted_client_secret\n , grant_type_authorization_code\n , grant_type_refresh_token\n , token_endpoint_auth_method\n , jwks\n , jwks_uri\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7)\n ON CONFLICT (oauth2_client_id)\n DO\n UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret\n , grant_type_authorization_code = EXCLUDED.grant_type_authorization_code\n , grant_type_refresh_token = EXCLUDED.grant_type_refresh_token\n , token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method\n , jwks = EXCLUDED.jwks\n , jwks_uri = EXCLUDED.jwks_uri\n " - }, "c0ed9d70e496433d8686a499055d8a8376459109b6154a2c0c13b28462afa523": { "describe": { "columns": [], diff --git a/crates/storage-pg/src/compat/access_token.rs b/crates/storage-pg/src/compat/access_token.rs new file mode 100644 index 00000000..5f73ed9e --- /dev/null +++ b/crates/storage-pg/src/compat/access_token.rs @@ -0,0 +1,216 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Duration, Utc}; +use mas_data_model::{CompatAccessToken, CompatSession}; +use mas_storage::{compat::CompatAccessTokenRepository, Clock}; +use rand::RngCore; +use sqlx::PgConnection; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{tracing::ExecuteExt, DatabaseError, LookupResultExt}; + +pub struct PgCompatAccessTokenRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgCompatAccessTokenRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +struct CompatAccessTokenLookup { + compat_access_token_id: Uuid, + access_token: String, + created_at: DateTime, + expires_at: Option>, + compat_session_id: Uuid, +} + +impl From for CompatAccessToken { + fn from(value: CompatAccessTokenLookup) -> Self { + Self { + id: value.compat_access_token_id.into(), + session_id: value.compat_session_id.into(), + token: value.access_token, + created_at: value.created_at, + expires_at: value.expires_at, + } + } +} + +#[async_trait] +impl<'c> CompatAccessTokenRepository for PgCompatAccessTokenRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.compat_access_token.lookup", + skip_all, + fields( + db.statement, + compat_session.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + CompatAccessTokenLookup, + r#" + SELECT compat_access_token_id + , access_token + , created_at + , expires_at + , compat_session_id + + FROM compat_access_tokens + + WHERE compat_access_token_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.compat_access_token.find_by_token", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn find_by_token( + &mut self, + access_token: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + CompatAccessTokenLookup, + r#" + SELECT compat_access_token_id + , access_token + , created_at + , expires_at + , compat_session_id + + FROM compat_access_tokens + + WHERE access_token = $1 + "#, + access_token, + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.compat_access_token.add", + skip_all, + fields( + db.statement, + compat_access_token.id, + %compat_session.id, + user.id = %compat_session.user_id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + compat_session: &CompatSession, + token: String, + expires_after: Option, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("compat_access_token.id", tracing::field::display(id)); + + let expires_at = expires_after.map(|expires_after| created_at + expires_after); + + sqlx::query!( + r#" + INSERT INTO compat_access_tokens + (compat_access_token_id, compat_session_id, access_token, created_at, expires_at) + VALUES ($1, $2, $3, $4, $5) + "#, + Uuid::from(id), + Uuid::from(compat_session.id), + token, + created_at, + expires_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(CompatAccessToken { + id, + session_id: compat_session.id, + token, + created_at, + expires_at, + }) + } + + #[tracing::instrument( + name = "db.compat_access_token.expire", + skip_all, + fields( + db.statement, + %compat_access_token.id, + compat_session.id = %compat_access_token.session_id, + ), + err, + )] + async fn expire( + &mut self, + clock: &Clock, + mut compat_access_token: CompatAccessToken, + ) -> Result { + let expires_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE compat_access_tokens + SET expires_at = $2 + WHERE compat_access_token_id = $1 + "#, + Uuid::from(compat_access_token.id), + expires_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + compat_access_token.expires_at = Some(expires_at); + Ok(compat_access_token) + } +} diff --git a/crates/storage-pg/src/compat/mod.rs b/crates/storage-pg/src/compat/mod.rs new file mode 100644 index 00000000..732ce3aa --- /dev/null +++ b/crates/storage-pg/src/compat/mod.rs @@ -0,0 +1,322 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +mod access_token; +mod refresh_token; +mod session; +mod sso_login; + +pub use self::{ + access_token::PgCompatAccessTokenRepository, refresh_token::PgCompatRefreshTokenRepository, + session::PgCompatSessionRepository, sso_login::PgCompatSsoLoginRepository, +}; + +#[cfg(test)] +mod tests { + use chrono::Duration; + use mas_data_model::Device; + use mas_storage::{ + compat::{ + CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, + }, + user::UserRepository, + Clock, Repository, + }; + use rand::SeedableRng; + use rand_chacha::ChaChaRng; + use sqlx::PgPool; + + use crate::PgRepository; + + #[sqlx::test(migrator = "crate::MIGRATOR")] + async fn test_session_repository(pool: PgPool) { + const FIRST_TOKEN: &str = "first_access_token"; + const SECOND_TOKEN: &str = "second_access_token"; + let mut rng = ChaChaRng::seed_from_u64(42); + let clock = Clock::mock(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + + // Create a user + let user = repo + .user() + .add(&mut rng, &clock, "john".to_owned()) + .await + .unwrap(); + + // Start a compat session for that user + let device = Device::generate(&mut rng); + let device_str = device.as_str().to_owned(); + let session = repo + .compat_session() + .add(&mut rng, &clock, &user, device) + .await + .unwrap(); + assert_eq!(session.user_id, user.id); + assert_eq!(session.device.as_str(), device_str); + assert!(session.is_valid()); + assert!(!session.is_finished()); + + // Lookup the session and check it didn't change + let session_lookup = repo + .compat_session() + .lookup(session.id) + .await + .unwrap() + .expect("compat session not found"); + assert_eq!(session_lookup.id, session.id); + assert_eq!(session_lookup.user_id, user.id); + assert_eq!(session_lookup.device.as_str(), device_str); + assert!(session_lookup.is_valid()); + assert!(!session_lookup.is_finished()); + + // Finish the session + let session = repo.compat_session().finish(&clock, session).await.unwrap(); + assert!(!session.is_valid()); + assert!(session.is_finished()); + + // Reload the session and check again + let session_lookup = repo + .compat_session() + .lookup(session.id) + .await + .unwrap() + .expect("compat session not found"); + assert!(!session_lookup.is_valid()); + assert!(session_lookup.is_finished()); + } + + #[sqlx::test(migrator = "crate::MIGRATOR")] + async fn test_access_token_repository(pool: PgPool) { + const FIRST_TOKEN: &str = "first_access_token"; + const SECOND_TOKEN: &str = "second_access_token"; + let mut rng = ChaChaRng::seed_from_u64(42); + let clock = Clock::mock(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + + // Create a user + let user = repo + .user() + .add(&mut rng, &clock, "john".to_owned()) + .await + .unwrap(); + + // Start a compat session for that user + let device = Device::generate(&mut rng); + let session = repo + .compat_session() + .add(&mut rng, &clock, &user, device) + .await + .unwrap(); + + // Add an access token to that session + let token = repo + .compat_access_token() + .add( + &mut rng, + &clock, + &session, + FIRST_TOKEN.to_owned(), + Some(Duration::minutes(1)), + ) + .await + .unwrap(); + assert_eq!(token.session_id, session.id); + assert_eq!(token.token, FIRST_TOKEN); + + // Commit the txn and grab a new transaction, to test a conflict + repo.save().await.unwrap(); + + { + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + // Adding the same token a second time should conflict + assert!(repo + .compat_access_token() + .add( + &mut rng, + &clock, + &session, + FIRST_TOKEN.to_owned(), + Some(Duration::minutes(1)), + ) + .await + .is_err()); + repo.cancel().await.unwrap(); + } + + // Grab a new repo + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + + // Looking up via ID works + let token_lookup = repo + .compat_access_token() + .lookup(token.id) + .await + .unwrap() + .expect("compat access token not found"); + assert_eq!(token.id, token_lookup.id); + assert_eq!(token_lookup.session_id, session.id); + + // Looking up via the token value works + let token_lookup = repo + .compat_access_token() + .find_by_token(FIRST_TOKEN) + .await + .unwrap() + .expect("compat access token not found"); + assert_eq!(token.id, token_lookup.id); + assert_eq!(token_lookup.session_id, session.id); + + // Token is currently valid + assert!(token.is_valid(clock.now())); + + clock.advance(Duration::minutes(1)); + // Token should have expired + assert!(!token.is_valid(clock.now())); + + // Add a second access token, this time without expiration + let token = repo + .compat_access_token() + .add(&mut rng, &clock, &session, SECOND_TOKEN.to_owned(), None) + .await + .unwrap(); + assert_eq!(token.session_id, session.id); + assert_eq!(token.token, SECOND_TOKEN); + + // Token is currently valid + assert!(token.is_valid(clock.now())); + + // Make it expire + repo.compat_access_token() + .expire(&clock, token) + .await + .unwrap(); + + // Reload it + let token = repo + .compat_access_token() + .find_by_token(SECOND_TOKEN) + .await + .unwrap() + .expect("compat access token not found"); + + // Token is not valid anymore + assert!(!token.is_valid(clock.now())); + + repo.save().await.unwrap(); + } + + #[sqlx::test(migrator = "crate::MIGRATOR")] + async fn test_refresh_token_repository(pool: PgPool) { + const ACCESS_TOKEN: &str = "access_token"; + const REFRESH_TOKEN: &str = "refresh_token"; + let mut rng = ChaChaRng::seed_from_u64(42); + let clock = Clock::mock(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + + // Create a user + let user = repo + .user() + .add(&mut rng, &clock, "john".to_owned()) + .await + .unwrap(); + + // Start a compat session for that user + let device = Device::generate(&mut rng); + let session = repo + .compat_session() + .add(&mut rng, &clock, &user, device) + .await + .unwrap(); + + // Add an access token to that session + let access_token = repo + .compat_access_token() + .add(&mut rng, &clock, &session, ACCESS_TOKEN.to_owned(), None) + .await + .unwrap(); + + let refresh_token = repo + .compat_refresh_token() + .add( + &mut rng, + &clock, + &session, + &access_token, + REFRESH_TOKEN.to_owned(), + ) + .await + .unwrap(); + assert_eq!(refresh_token.session_id, session.id); + assert_eq!(refresh_token.access_token_id, access_token.id); + assert_eq!(refresh_token.token, REFRESH_TOKEN); + assert!(refresh_token.is_valid()); + assert!(!refresh_token.is_consumed()); + + // Look it up by ID and check everything matches + let refresh_token_lookup = repo + .compat_refresh_token() + .lookup(refresh_token.id) + .await + .unwrap() + .expect("refresh token not found"); + assert_eq!(refresh_token_lookup.id, refresh_token.id); + assert_eq!(refresh_token_lookup.session_id, session.id); + assert_eq!(refresh_token_lookup.access_token_id, access_token.id); + assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN); + assert!(refresh_token_lookup.is_valid()); + assert!(!refresh_token_lookup.is_consumed()); + + // Look it up by token and check everything matches + let refresh_token_lookup = repo + .compat_refresh_token() + .find_by_token(REFRESH_TOKEN) + .await + .unwrap() + .expect("refresh token not found"); + assert_eq!(refresh_token_lookup.id, refresh_token.id); + assert_eq!(refresh_token_lookup.session_id, session.id); + assert_eq!(refresh_token_lookup.access_token_id, access_token.id); + assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN); + assert!(refresh_token_lookup.is_valid()); + assert!(!refresh_token_lookup.is_consumed()); + + // Consume it + let refresh_token = repo + .compat_refresh_token() + .consume(&clock, refresh_token) + .await + .unwrap(); + assert!(!refresh_token.is_valid()); + assert!(refresh_token.is_consumed()); + + // Reload it and check again + let refresh_token_lookup = repo + .compat_refresh_token() + .find_by_token(REFRESH_TOKEN) + .await + .unwrap() + .expect("refresh token not found"); + assert!(!refresh_token_lookup.is_valid()); + assert!(refresh_token_lookup.is_consumed()); + + // Consuming it again should not work + assert!(repo + .compat_refresh_token() + .consume(&clock, refresh_token) + .await + .is_err()); + + repo.save().await.unwrap(); + } +} diff --git a/crates/storage-pg/src/compat/refresh_token.rs b/crates/storage-pg/src/compat/refresh_token.rs new file mode 100644 index 00000000..314e8147 --- /dev/null +++ b/crates/storage-pg/src/compat/refresh_token.rs @@ -0,0 +1,230 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{ + CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession, +}; +use mas_storage::{compat::CompatRefreshTokenRepository, Clock}; +use rand::RngCore; +use sqlx::PgConnection; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{tracing::ExecuteExt, DatabaseError, LookupResultExt}; + +pub struct PgCompatRefreshTokenRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgCompatRefreshTokenRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +struct CompatRefreshTokenLookup { + compat_refresh_token_id: Uuid, + refresh_token: String, + created_at: DateTime, + consumed_at: Option>, + compat_access_token_id: Uuid, + compat_session_id: Uuid, +} + +impl From for CompatRefreshToken { + fn from(value: CompatRefreshTokenLookup) -> Self { + let state = match value.consumed_at { + Some(consumed_at) => CompatRefreshTokenState::Consumed { consumed_at }, + None => CompatRefreshTokenState::Valid, + }; + + Self { + id: value.compat_refresh_token_id.into(), + state, + session_id: value.compat_session_id.into(), + token: value.refresh_token, + created_at: value.created_at, + access_token_id: value.compat_access_token_id.into(), + } + } +} + +#[async_trait] +impl<'c> CompatRefreshTokenRepository for PgCompatRefreshTokenRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.compat_refresh_token.lookup", + skip_all, + fields( + db.statement, + compat_refresh_token.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + CompatRefreshTokenLookup, + r#" + SELECT compat_refresh_token_id + , refresh_token + , created_at + , consumed_at + , compat_session_id + , compat_access_token_id + + FROM compat_refresh_tokens + + WHERE compat_refresh_token_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.compat_refresh_token.find_by_token", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn find_by_token( + &mut self, + refresh_token: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + CompatRefreshTokenLookup, + r#" + SELECT compat_refresh_token_id + , refresh_token + , created_at + , consumed_at + , compat_session_id + , compat_access_token_id + + FROM compat_refresh_tokens + + WHERE refresh_token = $1 + "#, + refresh_token, + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.compat_refresh_token.add", + skip_all, + fields( + db.statement, + compat_refresh_token.id, + %compat_session.id, + user.id = %compat_session.user_id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + compat_session: &CompatSession, + compat_access_token: &CompatAccessToken, + token: String, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO compat_refresh_tokens + (compat_refresh_token_id, compat_session_id, + compat_access_token_id, refresh_token, created_at) + VALUES ($1, $2, $3, $4, $5) + "#, + Uuid::from(id), + Uuid::from(compat_session.id), + Uuid::from(compat_access_token.id), + token, + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(CompatRefreshToken { + id, + state: CompatRefreshTokenState::default(), + session_id: compat_session.id, + access_token_id: compat_access_token.id, + token, + created_at, + }) + } + + #[tracing::instrument( + name = "db.compat_refresh_token.consume", + skip_all, + fields( + db.statement, + %compat_refresh_token.id, + compat_session.id = %compat_refresh_token.session_id, + ), + err, + )] + async fn consume( + &mut self, + clock: &Clock, + compat_refresh_token: CompatRefreshToken, + ) -> Result { + let consumed_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE compat_refresh_tokens + SET consumed_at = $2 + WHERE compat_refresh_token_id = $1 + "#, + Uuid::from(compat_refresh_token.id), + consumed_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + let compat_refresh_token = compat_refresh_token + .consume(consumed_at) + .map_err(DatabaseError::to_invalid_operation)?; + + Ok(compat_refresh_token) + } +} diff --git a/crates/storage-pg/src/compat/session.rs b/crates/storage-pg/src/compat/session.rs new file mode 100644 index 00000000..a6e65f9a --- /dev/null +++ b/crates/storage-pg/src/compat/session.rs @@ -0,0 +1,195 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{CompatSession, CompatSessionState, Device, User}; +use mas_storage::{compat::CompatSessionRepository, Clock}; +use rand::RngCore; +use sqlx::PgConnection; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; + +pub struct PgCompatSessionRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgCompatSessionRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +struct CompatSessionLookup { + compat_session_id: Uuid, + device_id: String, + user_id: Uuid, + created_at: DateTime, + finished_at: Option>, +} + +impl TryFrom for CompatSession { + type Error = DatabaseInconsistencyError; + + fn try_from(value: CompatSessionLookup) -> Result { + let id = value.compat_session_id.into(); + let device = Device::try_from(value.device_id).map_err(|e| { + DatabaseInconsistencyError::on("compat_sessions") + .column("device_id") + .row(id) + .source(e) + })?; + + let state = match value.finished_at { + None => CompatSessionState::Valid, + Some(finished_at) => CompatSessionState::Finished { finished_at }, + }; + + let session = CompatSession { + id, + state, + user_id: value.user_id.into(), + device, + created_at: value.created_at, + }; + + Ok(session) + } +} + +#[async_trait] +impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.compat_session.lookup", + skip_all, + fields( + db.statement, + compat_session.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + CompatSessionLookup, + r#" + SELECT compat_session_id + , device_id + , user_id + , created_at + , finished_at + FROM compat_sessions + WHERE compat_session_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.compat_session.add", + skip_all, + fields( + db.statement, + compat_session.id, + %user.id, + %user.username, + compat_session.device.id = device.as_str(), + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + user: &User, + device: Device, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("compat_session.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at) + VALUES ($1, $2, $3, $4) + "#, + Uuid::from(id), + Uuid::from(user.id), + device.as_str(), + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(CompatSession { + id, + state: CompatSessionState::default(), + user_id: user.id, + device, + created_at, + }) + } + + #[tracing::instrument( + name = "db.compat_session.finish", + skip_all, + fields( + db.statement, + %compat_session.id, + user.id = %compat_session.user_id, + compat_session.device.id = compat_session.device.as_str(), + ), + err, + )] + async fn finish( + &mut self, + clock: &Clock, + compat_session: CompatSession, + ) -> Result { + let finished_at = clock.now(); + + let res = sqlx::query!( + r#" + UPDATE compat_sessions cs + SET finished_at = $2 + WHERE compat_session_id = $1 + "#, + Uuid::from(compat_session.id), + finished_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + let compat_session = compat_session + .finish(finished_at) + .map_err(DatabaseError::to_invalid_operation)?; + + Ok(compat_session) + } +} diff --git a/crates/storage-pg/src/compat/sso_login.rs b/crates/storage-pg/src/compat/sso_login.rs new file mode 100644 index 00000000..a2eeb926 --- /dev/null +++ b/crates/storage-pg/src/compat/sso_login.rs @@ -0,0 +1,342 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{CompatSession, CompatSsoLogin, CompatSsoLoginState, User}; +use mas_storage::{compat::CompatSsoLoginRepository, Clock, Page, Pagination}; +use rand::RngCore; +use sqlx::{PgConnection, QueryBuilder}; +use ulid::Ulid; +use url::Url; +use uuid::Uuid; + +use crate::{ + pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, + LookupResultExt, +}; + +pub struct PgCompatSsoLoginRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgCompatSsoLoginRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +#[derive(sqlx::FromRow)] +struct CompatSsoLoginLookup { + compat_sso_login_id: Uuid, + login_token: String, + redirect_uri: String, + created_at: DateTime, + fulfilled_at: Option>, + exchanged_at: Option>, + compat_session_id: Option, +} + +impl TryFrom for CompatSsoLogin { + type Error = DatabaseInconsistencyError; + + fn try_from(res: CompatSsoLoginLookup) -> Result { + let id = res.compat_sso_login_id.into(); + let redirect_uri = Url::parse(&res.redirect_uri).map_err(|e| { + DatabaseInconsistencyError::on("compat_sso_logins") + .column("redirect_uri") + .row(id) + .source(e) + })?; + + let state = match (res.fulfilled_at, res.exchanged_at, res.compat_session_id) { + (None, None, None) => CompatSsoLoginState::Pending, + (Some(fulfilled_at), None, Some(session_id)) => CompatSsoLoginState::Fulfilled { + fulfilled_at, + session_id: session_id.into(), + }, + (Some(fulfilled_at), Some(exchanged_at), Some(session_id)) => { + CompatSsoLoginState::Exchanged { + fulfilled_at, + exchanged_at, + session_id: session_id.into(), + } + } + _ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)), + }; + + Ok(CompatSsoLogin { + id, + login_token: res.login_token, + redirect_uri, + created_at: res.created_at, + state, + }) + } +} + +#[async_trait] +impl<'c> CompatSsoLoginRepository for PgCompatSsoLoginRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.compat_sso_login.lookup", + skip_all, + fields( + db.statement, + compat_sso_login.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + CompatSsoLoginLookup, + r#" + SELECT compat_sso_login_id + , login_token + , redirect_uri + , created_at + , fulfilled_at + , exchanged_at + , compat_session_id + + FROM compat_sso_logins + WHERE compat_sso_login_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.compat_sso_login.find_by_token", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn find_by_token( + &mut self, + login_token: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + CompatSsoLoginLookup, + r#" + SELECT compat_sso_login_id + , login_token + , redirect_uri + , created_at + , fulfilled_at + , exchanged_at + , compat_session_id + + FROM compat_sso_logins + WHERE login_token = $1 + "#, + login_token, + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.compat_sso_login.add", + skip_all, + fields( + db.statement, + compat_sso_login.id, + compat_sso_login.redirect_uri = %redirect_uri, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + login_token: String, + redirect_uri: Url, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO compat_sso_logins + (compat_sso_login_id, login_token, redirect_uri, created_at) + VALUES ($1, $2, $3, $4) + "#, + Uuid::from(id), + &login_token, + redirect_uri.as_str(), + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(CompatSsoLogin { + id, + login_token, + redirect_uri, + created_at, + state: CompatSsoLoginState::default(), + }) + } + + #[tracing::instrument( + name = "db.compat_sso_login.fulfill", + skip_all, + fields( + db.statement, + %compat_sso_login.id, + %compat_session.id, + compat_session.device.id = compat_session.device.as_str(), + user.id = %compat_session.user_id, + ), + err, + )] + async fn fulfill( + &mut self, + clock: &Clock, + compat_sso_login: CompatSsoLogin, + compat_session: &CompatSession, + ) -> Result { + let fulfilled_at = clock.now(); + let compat_sso_login = compat_sso_login + .fulfill(fulfilled_at, compat_session) + .map_err(DatabaseError::to_invalid_operation)?; + + let res = sqlx::query!( + r#" + UPDATE compat_sso_logins + SET + compat_session_id = $2, + fulfilled_at = $3 + WHERE + compat_sso_login_id = $1 + "#, + Uuid::from(compat_sso_login.id), + Uuid::from(compat_session.id), + fulfilled_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + Ok(compat_sso_login) + } + + #[tracing::instrument( + name = "db.compat_sso_login.exchange", + skip_all, + fields( + db.statement, + %compat_sso_login.id, + ), + err, + )] + async fn exchange( + &mut self, + clock: &Clock, + compat_sso_login: CompatSsoLogin, + ) -> Result { + let exchanged_at = clock.now(); + let compat_sso_login = compat_sso_login + .exchange(exchanged_at) + .map_err(DatabaseError::to_invalid_operation)?; + + let res = sqlx::query!( + r#" + UPDATE compat_sso_logins + SET + exchanged_at = $2 + WHERE + compat_sso_login_id = $1 + "#, + Uuid::from(compat_sso_login.id), + exchanged_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + Ok(compat_sso_login) + } + + #[tracing::instrument( + name = "db.compat_sso_login.list_paginated", + skip_all, + fields( + db.statement, + %user.id, + %user.username, + ), + err + )] + async fn list_paginated( + &mut self, + user: &User, + pagination: Pagination, + ) -> Result, Self::Error> { + let mut query = QueryBuilder::new( + r#" + SELECT cl.compat_sso_login_id + , cl.login_token + , cl.redirect_uri + , cl.created_at + , cl.fulfilled_at + , cl.exchanged_at + , cl.compat_session_id + + FROM compat_sso_logins cl + INNER JOIN compat_sessions ON compat_session_id + "#, + ); + + query + .push(" WHERE user_id = ") + .push_bind(Uuid::from(user.id)) + .generate_pagination("cl.compat_sso_login_id", pagination); + + let edges: Vec = query + .build_query_as() + .traced() + .fetch_all(&mut *self.conn) + .await?; + + let page = pagination + .process(edges) + .try_map(CompatSsoLogin::try_from)?; + Ok(page) + } +} diff --git a/crates/storage-pg/src/lib.rs b/crates/storage-pg/src/lib.rs new file mode 100644 index 00000000..459c8c3b --- /dev/null +++ b/crates/storage-pg/src/lib.rs @@ -0,0 +1,170 @@ +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Interactions with the database + +#![forbid(unsafe_code)] +#![deny( + clippy::all, + clippy::str_to_string, + clippy::future_not_send, + rustdoc::broken_intra_doc_links +)] +#![warn(clippy::pedantic)] +#![allow( + clippy::missing_errors_doc, + clippy::missing_panics_doc, + clippy::module_name_repetitions +)] + +use sqlx::{migrate::Migrator, postgres::PgQueryResult}; +use thiserror::Error; +use ulid::Ulid; + +trait LookupResultExt { + type Output; + + /// Transform a [`Result`] from a sqlx query to transform "not found" errors + /// into [`None`] + fn to_option(self) -> Result, sqlx::Error>; +} + +impl LookupResultExt for Result { + type Output = T; + + fn to_option(self) -> Result, sqlx::Error> { + match self { + Ok(v) => Ok(Some(v)), + Err(sqlx::Error::RowNotFound) => Ok(None), + Err(e) => Err(e), + } + } +} + +/// Generic error when interacting with the database +#[derive(Debug, Error)] +#[error(transparent)] +pub enum DatabaseError { + /// An error which came from the database itself + Driver(#[from] sqlx::Error), + + /// An error which occured while converting the data from the database + Inconsistency(#[from] DatabaseInconsistencyError), + + /// An error which happened because the requested database operation is + /// invalid + #[error("Invalid database operation")] + InvalidOperation { + #[source] + source: Option>, + }, + + /// An error which happens when an operation affects not enough or too many + /// rows + #[error("Expected {expected} rows to be affected, but {actual} rows were affected")] + RowsAffected { expected: u64, actual: u64 }, +} + +impl DatabaseError { + pub(crate) fn ensure_affected_rows( + result: &PgQueryResult, + expected: u64, + ) -> Result<(), DatabaseError> { + let actual = result.rows_affected(); + if actual == expected { + Ok(()) + } else { + Err(DatabaseError::RowsAffected { expected, actual }) + } + } + + pub(crate) fn to_invalid_operation(e: E) -> Self { + Self::InvalidOperation { + source: Some(Box::new(e)), + } + } + + pub(crate) const fn invalid_operation() -> Self { + Self::InvalidOperation { source: None } + } +} + +#[derive(Debug, Error)] +pub struct DatabaseInconsistencyError { + table: &'static str, + column: Option<&'static str>, + row: Option, + + #[source] + source: Option>, +} + +impl std::fmt::Display for DatabaseInconsistencyError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Database inconsistency on table {}", self.table)?; + if let Some(column) = self.column { + write!(f, " column {column}")?; + } + if let Some(row) = self.row { + write!(f, " row {row}")?; + } + + Ok(()) + } +} + +impl DatabaseInconsistencyError { + #[must_use] + pub(crate) const fn on(table: &'static str) -> Self { + Self { + table, + column: None, + row: None, + source: None, + } + } + + #[must_use] + pub(crate) const fn column(mut self, column: &'static str) -> Self { + self.column = Some(column); + self + } + + #[must_use] + pub(crate) const fn row(mut self, row: Ulid) -> Self { + self.row = Some(row); + self + } + + pub(crate) fn source( + mut self, + source: E, + ) -> Self { + self.source = Some(Box::new(source)); + self + } +} + +pub mod compat; +pub mod oauth2; +pub(crate) mod pagination; +pub(crate) mod repository; +pub(crate) mod tracing; +pub mod upstream_oauth2; +pub mod user; + +pub use self::repository::PgRepository; + +/// Embedded migrations, allowing them to run on startup +pub static MIGRATOR: Migrator = sqlx::migrate!(); diff --git a/crates/storage-pg/src/oauth2/access_token.rs b/crates/storage-pg/src/oauth2/access_token.rs new file mode 100644 index 00000000..33d95242 --- /dev/null +++ b/crates/storage-pg/src/oauth2/access_token.rs @@ -0,0 +1,223 @@ +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Duration, Utc}; +use mas_data_model::{AccessToken, AccessTokenState, Session}; +use mas_storage::{oauth2::OAuth2AccessTokenRepository, Clock}; +use rand::RngCore; +use sqlx::PgConnection; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{tracing::ExecuteExt, DatabaseError, LookupResultExt}; + +pub struct PgOAuth2AccessTokenRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgOAuth2AccessTokenRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +struct OAuth2AccessTokenLookup { + oauth2_access_token_id: Uuid, + oauth2_session_id: Uuid, + access_token: String, + created_at: DateTime, + expires_at: DateTime, + revoked_at: Option>, +} + +impl From for AccessToken { + fn from(value: OAuth2AccessTokenLookup) -> Self { + let state = match value.revoked_at { + None => AccessTokenState::Valid, + Some(revoked_at) => AccessTokenState::Revoked { revoked_at }, + }; + + Self { + id: value.oauth2_access_token_id.into(), + state, + session_id: value.oauth2_session_id.into(), + access_token: value.access_token, + created_at: value.created_at, + expires_at: value.expires_at, + } + } +} + +#[async_trait] +impl<'c> OAuth2AccessTokenRepository for PgOAuth2AccessTokenRepository<'c> { + type Error = DatabaseError; + + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + OAuth2AccessTokenLookup, + r#" + SELECT oauth2_access_token_id + , access_token + , created_at + , expires_at + , revoked_at + , oauth2_session_id + + FROM oauth2_access_tokens + + WHERE oauth2_access_token_id = $1 + "#, + Uuid::from(id), + ) + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.oauth2_access_token.find_by_token", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn find_by_token( + &mut self, + access_token: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + OAuth2AccessTokenLookup, + r#" + SELECT oauth2_access_token_id + , access_token + , created_at + , expires_at + , revoked_at + , oauth2_session_id + + FROM oauth2_access_tokens + + WHERE access_token = $1 + "#, + access_token, + ) + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.oauth2_access_token.add", + skip_all, + fields( + db.statement, + %session.id, + user_session.id = %session.user_session_id, + client.id = %session.client_id, + access_token.id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + session: &Session, + access_token: String, + expires_after: Duration, + ) -> Result { + let created_at = clock.now(); + let expires_at = created_at + expires_after; + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + + tracing::Span::current().record("access_token.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO oauth2_access_tokens + (oauth2_access_token_id, oauth2_session_id, access_token, created_at, expires_at) + VALUES + ($1, $2, $3, $4, $5) + "#, + Uuid::from(id), + Uuid::from(session.id), + &access_token, + created_at, + expires_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(AccessToken { + id, + state: AccessTokenState::default(), + access_token, + session_id: session.id, + created_at, + expires_at, + }) + } + + async fn revoke( + &mut self, + clock: &Clock, + access_token: AccessToken, + ) -> Result { + let revoked_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE oauth2_access_tokens + SET revoked_at = $2 + WHERE oauth2_access_token_id = $1 + "#, + Uuid::from(access_token.id), + revoked_at, + ) + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + access_token + .revoke(revoked_at) + .map_err(DatabaseError::to_invalid_operation) + } + + async fn cleanup_expired(&mut self, clock: &Clock) -> Result { + // Cleanup token which expired more than 15 minutes ago + let threshold = clock.now() - Duration::minutes(15); + let res = sqlx::query!( + r#" + DELETE FROM oauth2_access_tokens + WHERE expires_at < $1 + "#, + threshold, + ) + .execute(&mut *self.conn) + .await?; + + Ok(res.rows_affected().try_into().unwrap_or(usize::MAX)) + } +} diff --git a/crates/storage-pg/src/oauth2/authorization_grant.rs b/crates/storage-pg/src/oauth2/authorization_grant.rs new file mode 100644 index 00000000..027111a7 --- /dev/null +++ b/crates/storage-pg/src/oauth2/authorization_grant.rs @@ -0,0 +1,510 @@ +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::num::NonZeroU32; + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{ + AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Pkce, Session, +}; +use mas_iana::oauth::PkceCodeChallengeMethod; +use mas_storage::{oauth2::OAuth2AuthorizationGrantRepository, Clock}; +use oauth2_types::{requests::ResponseMode, scope::Scope}; +use rand::RngCore; +use sqlx::PgConnection; +use ulid::Ulid; +use url::Url; +use uuid::Uuid; + +use crate::{tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; + +pub struct PgOAuth2AuthorizationGrantRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgOAuth2AuthorizationGrantRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +#[allow(clippy::struct_excessive_bools)] +struct GrantLookup { + oauth2_authorization_grant_id: Uuid, + created_at: DateTime, + cancelled_at: Option>, + fulfilled_at: Option>, + exchanged_at: Option>, + scope: String, + state: Option, + nonce: Option, + redirect_uri: String, + response_mode: String, + max_age: Option, + response_type_code: bool, + response_type_id_token: bool, + authorization_code: Option, + code_challenge: Option, + code_challenge_method: Option, + requires_consent: bool, + oauth2_client_id: Uuid, + oauth2_session_id: Option, +} + +impl TryFrom for AuthorizationGrant { + type Error = DatabaseInconsistencyError; + + #[allow(clippy::too_many_lines)] + fn try_from(value: GrantLookup) -> Result { + let id = value.oauth2_authorization_grant_id.into(); + let scope: Scope = value.scope.parse().map_err(|e| { + DatabaseInconsistencyError::on("oauth2_authorization_grants") + .column("scope") + .row(id) + .source(e) + })?; + + let stage = match ( + value.fulfilled_at, + value.exchanged_at, + value.cancelled_at, + value.oauth2_session_id, + ) { + (None, None, None, None) => AuthorizationGrantStage::Pending, + (Some(fulfilled_at), None, None, Some(session_id)) => { + AuthorizationGrantStage::Fulfilled { + session_id: session_id.into(), + fulfilled_at, + } + } + (Some(fulfilled_at), Some(exchanged_at), None, Some(session_id)) => { + AuthorizationGrantStage::Exchanged { + session_id: session_id.into(), + fulfilled_at, + exchanged_at, + } + } + (None, None, Some(cancelled_at), None) => { + AuthorizationGrantStage::Cancelled { cancelled_at } + } + _ => { + return Err( + DatabaseInconsistencyError::on("oauth2_authorization_grants") + .column("stage") + .row(id), + ); + } + }; + + let pkce = match (value.code_challenge, value.code_challenge_method) { + (Some(challenge), Some(challenge_method)) if challenge_method == "plain" => { + Some(Pkce { + challenge_method: PkceCodeChallengeMethod::Plain, + challenge, + }) + } + (Some(challenge), Some(challenge_method)) if challenge_method == "S256" => Some(Pkce { + challenge_method: PkceCodeChallengeMethod::S256, + challenge, + }), + (None, None) => None, + _ => { + return Err( + DatabaseInconsistencyError::on("oauth2_authorization_grants") + .column("code_challenge_method") + .row(id), + ); + } + }; + + let code: Option = + match (value.response_type_code, value.authorization_code, pkce) { + (false, None, None) => None, + (true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }), + _ => { + return Err( + DatabaseInconsistencyError::on("oauth2_authorization_grants") + .column("authorization_code") + .row(id), + ); + } + }; + + let redirect_uri = value.redirect_uri.parse().map_err(|e| { + DatabaseInconsistencyError::on("oauth2_authorization_grants") + .column("redirect_uri") + .row(id) + .source(e) + })?; + + let response_mode = value.response_mode.parse().map_err(|e| { + DatabaseInconsistencyError::on("oauth2_authorization_grants") + .column("response_mode") + .row(id) + .source(e) + })?; + + let max_age = value + .max_age + .map(u32::try_from) + .transpose() + .map_err(|e| { + DatabaseInconsistencyError::on("oauth2_authorization_grants") + .column("max_age") + .row(id) + .source(e) + })? + .map(NonZeroU32::try_from) + .transpose() + .map_err(|e| { + DatabaseInconsistencyError::on("oauth2_authorization_grants") + .column("max_age") + .row(id) + .source(e) + })?; + + Ok(AuthorizationGrant { + id, + stage, + client_id: value.oauth2_client_id.into(), + code, + scope, + state: value.state, + nonce: value.nonce, + max_age, + response_mode, + redirect_uri, + created_at: value.created_at, + response_type_id_token: value.response_type_id_token, + requires_consent: value.requires_consent, + }) + } +} + +#[async_trait] +impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.oauth2_authorization_grant.add", + skip_all, + fields( + db.statement, + grant.id, + grant.scope = %scope, + %client.id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + client: &Client, + redirect_uri: Url, + scope: Scope, + code: Option, + state: Option, + nonce: Option, + max_age: Option, + response_mode: ResponseMode, + response_type_id_token: bool, + requires_consent: bool, + ) -> Result { + let code_challenge = code + .as_ref() + .and_then(|c| c.pkce.as_ref()) + .map(|p| &p.challenge); + let code_challenge_method = code + .as_ref() + .and_then(|c| c.pkce.as_ref()) + .map(|p| p.challenge_method.to_string()); + // TODO: this conversion is a bit ugly + let max_age_i32 = max_age.map(|x| i32::try_from(u32::from(x)).unwrap_or(i32::MAX)); + let code_str = code.as_ref().map(|c| &c.code); + + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("grant.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO oauth2_authorization_grants ( + oauth2_authorization_grant_id, + oauth2_client_id, + redirect_uri, + scope, + state, + nonce, + max_age, + response_mode, + code_challenge, + code_challenge_method, + response_type_code, + response_type_id_token, + authorization_code, + requires_consent, + created_at + ) + VALUES + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15) + "#, + Uuid::from(id), + Uuid::from(client.id), + redirect_uri.to_string(), + scope.to_string(), + state, + nonce, + max_age_i32, + response_mode.to_string(), + code_challenge, + code_challenge_method, + code.is_some(), + response_type_id_token, + code_str, + requires_consent, + created_at, + ) + .execute(&mut *self.conn) + .await?; + + Ok(AuthorizationGrant { + id, + stage: AuthorizationGrantStage::Pending, + code, + redirect_uri, + client_id: client.id, + scope, + state, + nonce, + max_age, + response_mode, + created_at, + response_type_id_token, + requires_consent, + }) + } + + #[tracing::instrument( + name = "db.oauth2_authorization_grant.lookup", + skip_all, + fields( + db.statement, + grant.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + GrantLookup, + r#" + SELECT oauth2_authorization_grant_id + , created_at + , cancelled_at + , fulfilled_at + , exchanged_at + , scope + , state + , redirect_uri + , response_mode + , nonce + , max_age + , oauth2_client_id + , authorization_code + , response_type_code + , response_type_id_token + , code_challenge + , code_challenge_method + , requires_consent + , oauth2_session_id + FROM + oauth2_authorization_grants + + WHERE oauth2_authorization_grant_id = $1 + "#, + Uuid::from(id), + ) + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.oauth2_authorization_grant.find_by_code", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn find_by_code( + &mut self, + code: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + GrantLookup, + r#" + SELECT oauth2_authorization_grant_id + , created_at + , cancelled_at + , fulfilled_at + , exchanged_at + , scope + , state + , redirect_uri + , response_mode + , nonce + , max_age + , oauth2_client_id + , authorization_code + , response_type_code + , response_type_id_token + , code_challenge + , code_challenge_method + , requires_consent + , oauth2_session_id + FROM + oauth2_authorization_grants + + WHERE authorization_code = $1 + "#, + code, + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.oauth2_authorization_grant.fulfill", + skip_all, + fields( + db.statement, + %grant.id, + client.id = %grant.client_id, + %session.id, + user_session.id = %session.user_session_id, + ), + err, + )] + async fn fulfill( + &mut self, + clock: &Clock, + session: &Session, + grant: AuthorizationGrant, + ) -> Result { + let fulfilled_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE oauth2_authorization_grants + SET fulfilled_at = $2 + , oauth2_session_id = $3 + WHERE oauth2_authorization_grant_id = $1 + "#, + Uuid::from(grant.id), + fulfilled_at, + Uuid::from(session.id), + ) + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + // XXX: check affected rows & new methods + let grant = grant + .fulfill(fulfilled_at, session) + .map_err(DatabaseError::to_invalid_operation)?; + + Ok(grant) + } + + #[tracing::instrument( + name = "db.oauth2_authorization_grant.exchange", + skip_all, + fields( + db.statement, + %grant.id, + client.id = %grant.client_id, + ), + err, + )] + async fn exchange( + &mut self, + clock: &Clock, + grant: AuthorizationGrant, + ) -> Result { + let exchanged_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE oauth2_authorization_grants + SET exchanged_at = $2 + WHERE oauth2_authorization_grant_id = $1 + "#, + Uuid::from(grant.id), + exchanged_at, + ) + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + let grant = grant + .exchange(exchanged_at) + .map_err(DatabaseError::to_invalid_operation)?; + + Ok(grant) + } + + #[tracing::instrument( + name = "db.oauth2_authorization_grant.give_consent", + skip_all, + fields( + db.statement, + %grant.id, + client.id = %grant.client_id, + ), + err, + )] + async fn give_consent( + &mut self, + mut grant: AuthorizationGrant, + ) -> Result { + sqlx::query!( + r#" + UPDATE oauth2_authorization_grants AS og + SET + requires_consent = 'f' + WHERE + og.oauth2_authorization_grant_id = $1 + "#, + Uuid::from(grant.id), + ) + .execute(&mut *self.conn) + .await?; + + grant.requires_consent = false; + + Ok(grant) + } +} diff --git a/crates/storage-pg/src/oauth2/client.rs b/crates/storage-pg/src/oauth2/client.rs new file mode 100644 index 00000000..4430c669 --- /dev/null +++ b/crates/storage-pg/src/oauth2/client.rs @@ -0,0 +1,745 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{ + collections::{BTreeMap, BTreeSet}, + str::FromStr, + string::ToString, +}; + +use async_trait::async_trait; +use mas_data_model::{Client, JwksOrJwksUri, User}; +use mas_iana::{ + jose::JsonWebSignatureAlg, + oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod}, +}; +use mas_jose::jwk::PublicJsonWebKeySet; +use mas_storage::{oauth2::OAuth2ClientRepository, Clock}; +use oauth2_types::{ + requests::GrantType, + scope::{Scope, ScopeToken}, +}; +use rand::{Rng, RngCore}; +use sqlx::PgConnection; +use tracing::{info_span, Instrument}; +use ulid::Ulid; +use url::Url; +use uuid::Uuid; + +use crate::{tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; + +pub struct PgOAuth2ClientRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgOAuth2ClientRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +// XXX: response_types & contacts +#[derive(Debug)] +struct OAuth2ClientLookup { + oauth2_client_id: Uuid, + encrypted_client_secret: Option, + redirect_uris: Vec, + // response_types: Vec, + grant_type_authorization_code: bool, + grant_type_refresh_token: bool, + // contacts: Vec, + client_name: Option, + logo_uri: Option, + client_uri: Option, + policy_uri: Option, + tos_uri: Option, + jwks_uri: Option, + jwks: Option, + id_token_signed_response_alg: Option, + userinfo_signed_response_alg: Option, + token_endpoint_auth_method: Option, + token_endpoint_auth_signing_alg: Option, + initiate_login_uri: Option, +} + +impl TryInto for OAuth2ClientLookup { + type Error = DatabaseInconsistencyError; + + #[allow(clippy::too_many_lines)] // TODO: refactor some of the field parsing + fn try_into(self) -> Result { + let id = Ulid::from(self.oauth2_client_id); + + let redirect_uris: Result, _> = + self.redirect_uris.iter().map(|s| s.parse()).collect(); + let redirect_uris = redirect_uris.map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("redirect_uris") + .row(id) + .source(e) + })?; + + let response_types = vec![ + OAuthAuthorizationEndpointResponseType::Code, + OAuthAuthorizationEndpointResponseType::IdToken, + OAuthAuthorizationEndpointResponseType::None, + ]; + /* XXX + let response_types: Result, _> = + self.response_types.iter().map(|s| s.parse()).collect(); + let response_types = response_types.map_err(|source| ClientFetchError::ParseField { + field: "response_types", + source, + })?; + */ + + let mut grant_types = Vec::new(); + if self.grant_type_authorization_code { + grant_types.push(GrantType::AuthorizationCode); + } + if self.grant_type_refresh_token { + grant_types.push(GrantType::RefreshToken); + } + + let logo_uri = self.logo_uri.map(|s| s.parse()).transpose().map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("logo_uri") + .row(id) + .source(e) + })?; + + let client_uri = self + .client_uri + .map(|s| s.parse()) + .transpose() + .map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("client_uri") + .row(id) + .source(e) + })?; + + let policy_uri = self + .policy_uri + .map(|s| s.parse()) + .transpose() + .map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("policy_uri") + .row(id) + .source(e) + })?; + + let tos_uri = self.tos_uri.map(|s| s.parse()).transpose().map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("tos_uri") + .row(id) + .source(e) + })?; + + let id_token_signed_response_alg = self + .id_token_signed_response_alg + .map(|s| s.parse()) + .transpose() + .map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("id_token_signed_response_alg") + .row(id) + .source(e) + })?; + + let userinfo_signed_response_alg = self + .userinfo_signed_response_alg + .map(|s| s.parse()) + .transpose() + .map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("userinfo_signed_response_alg") + .row(id) + .source(e) + })?; + + let token_endpoint_auth_method = self + .token_endpoint_auth_method + .map(|s| s.parse()) + .transpose() + .map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("token_endpoint_auth_method") + .row(id) + .source(e) + })?; + + let token_endpoint_auth_signing_alg = self + .token_endpoint_auth_signing_alg + .map(|s| s.parse()) + .transpose() + .map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("token_endpoint_auth_signing_alg") + .row(id) + .source(e) + })?; + + let initiate_login_uri = self + .initiate_login_uri + .map(|s| s.parse()) + .transpose() + .map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("initiate_login_uri") + .row(id) + .source(e) + })?; + + let jwks = match (self.jwks, self.jwks_uri) { + (None, None) => None, + (Some(jwks), None) => { + let jwks = serde_json::from_value(jwks).map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("jwks") + .row(id) + .source(e) + })?; + Some(JwksOrJwksUri::Jwks(jwks)) + } + (None, Some(jwks_uri)) => { + let jwks_uri = jwks_uri.parse().map_err(|e| { + DatabaseInconsistencyError::on("oauth2_clients") + .column("jwks_uri") + .row(id) + .source(e) + })?; + + Some(JwksOrJwksUri::JwksUri(jwks_uri)) + } + _ => { + return Err(DatabaseInconsistencyError::on("oauth2_clients") + .column("jwks(_uri)") + .row(id)) + } + }; + + Ok(Client { + id, + client_id: id.to_string(), + encrypted_client_secret: self.encrypted_client_secret, + redirect_uris, + response_types, + grant_types, + // contacts: self.contacts, + contacts: vec![], + client_name: self.client_name, + logo_uri, + client_uri, + policy_uri, + tos_uri, + jwks, + id_token_signed_response_alg, + userinfo_signed_response_alg, + token_endpoint_auth_method, + token_endpoint_auth_signing_alg, + initiate_login_uri, + }) + } +} + +#[async_trait] +impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.oauth2_client.lookup", + skip_all, + fields( + db.statement, + oauth2_client.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + OAuth2ClientLookup, + r#" + SELECT oauth2_client_id + , encrypted_client_secret + , ARRAY( + SELECT redirect_uri + FROM oauth2_client_redirect_uris r + WHERE r.oauth2_client_id = c.oauth2_client_id + ) AS "redirect_uris!" + , grant_type_authorization_code + , grant_type_refresh_token + , client_name + , logo_uri + , client_uri + , policy_uri + , tos_uri + , jwks_uri + , jwks + , id_token_signed_response_alg + , userinfo_signed_response_alg + , token_endpoint_auth_method + , token_endpoint_auth_signing_alg + , initiate_login_uri + FROM oauth2_clients c + + WHERE oauth2_client_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.oauth2_client.load_batch", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn load_batch( + &mut self, + ids: BTreeSet, + ) -> Result, Self::Error> { + let ids: Vec = ids.into_iter().map(Uuid::from).collect(); + let res = sqlx::query_as!( + OAuth2ClientLookup, + r#" + SELECT oauth2_client_id + , encrypted_client_secret + , ARRAY( + SELECT redirect_uri + FROM oauth2_client_redirect_uris r + WHERE r.oauth2_client_id = c.oauth2_client_id + ) AS "redirect_uris!" + , grant_type_authorization_code + , grant_type_refresh_token + , client_name + , logo_uri + , client_uri + , policy_uri + , tos_uri + , jwks_uri + , jwks + , id_token_signed_response_alg + , userinfo_signed_response_alg + , token_endpoint_auth_method + , token_endpoint_auth_signing_alg + , initiate_login_uri + FROM oauth2_clients c + + WHERE oauth2_client_id = ANY($1::uuid[]) + "#, + &ids, + ) + .traced() + .fetch_all(&mut *self.conn) + .await?; + + res.into_iter() + .map(|r| { + r.try_into() + .map(|c: Client| (c.id, c)) + .map_err(DatabaseError::from) + }) + .collect() + } + + #[tracing::instrument( + name = "db.oauth2_client.add", + skip_all, + fields( + db.statement, + client.id, + client.name = client_name + ), + err, + )] + #[allow(clippy::too_many_lines)] + async fn add( + &mut self, + mut rng: &mut (dyn RngCore + Send), + clock: &Clock, + redirect_uris: Vec, + encrypted_client_secret: Option, + grant_types: Vec, + contacts: Vec, + client_name: Option, + logo_uri: Option, + client_uri: Option, + policy_uri: Option, + tos_uri: Option, + jwks_uri: Option, + jwks: Option, + id_token_signed_response_alg: Option, + userinfo_signed_response_alg: Option, + token_endpoint_auth_method: Option, + token_endpoint_auth_signing_alg: Option, + initiate_login_uri: Option, + ) -> Result { + let now = clock.now(); + let id = Ulid::from_datetime_with_source(now.into(), rng); + tracing::Span::current().record("client.id", tracing::field::display(id)); + + let jwks_json = jwks + .as_ref() + .map(serde_json::to_value) + .transpose() + .map_err(DatabaseError::to_invalid_operation)?; + + sqlx::query!( + r#" + INSERT INTO oauth2_clients + ( oauth2_client_id + , encrypted_client_secret + , grant_type_authorization_code + , grant_type_refresh_token + , client_name + , logo_uri + , client_uri + , policy_uri + , tos_uri + , jwks_uri + , jwks + , id_token_signed_response_alg + , userinfo_signed_response_alg + , token_endpoint_auth_method + , token_endpoint_auth_signing_alg + , initiate_login_uri + ) + VALUES + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16) + "#, + Uuid::from(id), + encrypted_client_secret, + grant_types.contains(&GrantType::AuthorizationCode), + grant_types.contains(&GrantType::RefreshToken), + client_name, + logo_uri.as_ref().map(Url::as_str), + client_uri.as_ref().map(Url::as_str), + policy_uri.as_ref().map(Url::as_str), + tos_uri.as_ref().map(Url::as_str), + jwks_uri.as_ref().map(Url::as_str), + jwks_json, + id_token_signed_response_alg + .as_ref() + .map(ToString::to_string), + userinfo_signed_response_alg + .as_ref() + .map(ToString::to_string), + token_endpoint_auth_method.as_ref().map(ToString::to_string), + token_endpoint_auth_signing_alg + .as_ref() + .map(ToString::to_string), + initiate_login_uri.as_ref().map(Url::as_str), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + { + let span = info_span!( + "db.oauth2_client.add.redirect_uris", + db.statement = tracing::field::Empty, + client.id = %id, + ); + + let (uri_ids, redirect_uris): (Vec, Vec) = redirect_uris + .iter() + .map(|uri| { + ( + Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)), + uri.as_str().to_owned(), + ) + }) + .unzip(); + + sqlx::query!( + r#" + INSERT INTO oauth2_client_redirect_uris + (oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri) + SELECT id, $2, redirect_uri + FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri) + "#, + &uri_ids, + Uuid::from(id), + &redirect_uris, + ) + .record(&span) + .execute(&mut *self.conn) + .instrument(span) + .await?; + } + + let jwks = match (jwks, jwks_uri) { + (None, None) => None, + (Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)), + (None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)), + _ => return Err(DatabaseError::invalid_operation()), + }; + + Ok(Client { + id, + client_id: id.to_string(), + encrypted_client_secret, + redirect_uris, + response_types: vec![ + OAuthAuthorizationEndpointResponseType::Code, + OAuthAuthorizationEndpointResponseType::IdToken, + OAuthAuthorizationEndpointResponseType::None, + ], + grant_types, + contacts, + client_name, + logo_uri, + client_uri, + policy_uri, + tos_uri, + jwks, + id_token_signed_response_alg, + userinfo_signed_response_alg, + token_endpoint_auth_method, + token_endpoint_auth_signing_alg, + initiate_login_uri, + }) + } + + #[tracing::instrument( + name = "db.oauth2_client.add_from_config", + skip_all, + fields( + db.statement, + client.id = %client_id, + ), + err, + )] + async fn add_from_config( + &mut self, + mut rng: impl Rng + Send, + clock: &Clock, + client_id: Ulid, + client_auth_method: OAuthClientAuthenticationMethod, + encrypted_client_secret: Option, + jwks: Option, + jwks_uri: Option, + redirect_uris: Vec, + ) -> Result { + let jwks_json = jwks + .as_ref() + .map(serde_json::to_value) + .transpose() + .map_err(DatabaseError::to_invalid_operation)?; + + let client_auth_method = client_auth_method.to_string(); + + sqlx::query!( + r#" + INSERT INTO oauth2_clients + ( oauth2_client_id + , encrypted_client_secret + , grant_type_authorization_code + , grant_type_refresh_token + , token_endpoint_auth_method + , jwks + , jwks_uri + ) + VALUES + ($1, $2, $3, $4, $5, $6, $7) + ON CONFLICT (oauth2_client_id) + DO + UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret + , grant_type_authorization_code = EXCLUDED.grant_type_authorization_code + , grant_type_refresh_token = EXCLUDED.grant_type_refresh_token + , token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method + , jwks = EXCLUDED.jwks + , jwks_uri = EXCLUDED.jwks_uri + "#, + Uuid::from(client_id), + encrypted_client_secret, + true, + true, + client_auth_method, + jwks_json, + jwks_uri.as_ref().map(Url::as_str), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + { + let span = info_span!( + "db.oauth2_client.add_from_config.redirect_uris", + client.id = %client_id, + db.statement = tracing::field::Empty, + ); + + let now = clock.now(); + let (ids, redirect_uris): (Vec, Vec) = redirect_uris + .iter() + .map(|uri| { + ( + Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)), + uri.as_str().to_owned(), + ) + }) + .unzip(); + + sqlx::query!( + r#" + INSERT INTO oauth2_client_redirect_uris + (oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri) + SELECT id, $2, redirect_uri + FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri) + "#, + &ids, + Uuid::from(client_id), + &redirect_uris, + ) + .record(&span) + .execute(&mut *self.conn) + .instrument(span) + .await?; + } + + let jwks = match (jwks, jwks_uri) { + (None, None) => None, + (Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)), + (None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)), + _ => return Err(DatabaseError::invalid_operation()), + }; + + Ok(Client { + id: client_id, + client_id: client_id.to_string(), + encrypted_client_secret, + redirect_uris, + response_types: vec![ + OAuthAuthorizationEndpointResponseType::Code, + OAuthAuthorizationEndpointResponseType::IdToken, + OAuthAuthorizationEndpointResponseType::None, + ], + grant_types: Vec::new(), + contacts: Vec::new(), + client_name: None, + logo_uri: None, + client_uri: None, + policy_uri: None, + tos_uri: None, + jwks, + id_token_signed_response_alg: None, + userinfo_signed_response_alg: None, + token_endpoint_auth_method: None, + token_endpoint_auth_signing_alg: None, + initiate_login_uri: None, + }) + } + + #[tracing::instrument( + name = "db.oauth2_client.get_consent_for_user", + skip_all, + fields( + db.statement, + %user.id, + %client.id, + ), + err, + )] + async fn get_consent_for_user( + &mut self, + client: &Client, + user: &User, + ) -> Result { + let scope_tokens: Vec = sqlx::query_scalar!( + r#" + SELECT scope_token + FROM oauth2_consents + WHERE user_id = $1 AND oauth2_client_id = $2 + "#, + Uuid::from(user.id), + Uuid::from(client.id), + ) + .fetch_all(&mut *self.conn) + .await?; + + let scope: Result = scope_tokens + .into_iter() + .map(|s| ScopeToken::from_str(&s)) + .collect(); + + let scope = scope.map_err(|e| { + DatabaseInconsistencyError::on("oauth2_consents") + .column("scope_token") + .source(e) + })?; + + Ok(scope) + } + + #[tracing::instrument( + skip_all, + fields( + db.statement, + %user.id, + %client.id, + %scope, + ), + err, + )] + async fn give_consent_for_user( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + client: &Client, + user: &User, + scope: &Scope, + ) -> Result<(), Self::Error> { + let now = clock.now(); + let (tokens, ids): (Vec, Vec) = scope + .iter() + .map(|token| { + ( + token.to_string(), + Uuid::from(Ulid::from_datetime_with_source(now.into(), rng)), + ) + }) + .unzip(); + + sqlx::query!( + r#" + INSERT INTO oauth2_consents + (oauth2_consent_id, user_id, oauth2_client_id, scope_token, created_at) + SELECT id, $2, $3, scope_token, $5 FROM UNNEST($1::uuid[], $4::text[]) u(id, scope_token) + ON CONFLICT (user_id, oauth2_client_id, scope_token) DO UPDATE SET refreshed_at = $5 + "#, + &ids, + Uuid::from(user.id), + Uuid::from(client.id), + &tokens, + now, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(()) + } +} diff --git a/crates/storage-pg/src/oauth2/mod.rs b/crates/storage-pg/src/oauth2/mod.rs new file mode 100644 index 00000000..edad2beb --- /dev/null +++ b/crates/storage-pg/src/oauth2/mod.rs @@ -0,0 +1,25 @@ +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +mod access_token; +pub mod authorization_grant; +mod client; +mod refresh_token; +mod session; + +pub use self::{ + access_token::PgOAuth2AccessTokenRepository, + authorization_grant::PgOAuth2AuthorizationGrantRepository, client::PgOAuth2ClientRepository, + refresh_token::PgOAuth2RefreshTokenRepository, session::PgOAuth2SessionRepository, +}; diff --git a/crates/storage-pg/src/oauth2/refresh_token.rs b/crates/storage-pg/src/oauth2/refresh_token.rs new file mode 100644 index 00000000..47281d93 --- /dev/null +++ b/crates/storage-pg/src/oauth2/refresh_token.rs @@ -0,0 +1,224 @@ +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{AccessToken, RefreshToken, RefreshTokenState, Session}; +use mas_storage::{oauth2::OAuth2RefreshTokenRepository, Clock}; +use rand::RngCore; +use sqlx::PgConnection; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{tracing::ExecuteExt, DatabaseError, LookupResultExt}; + +pub struct PgOAuth2RefreshTokenRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgOAuth2RefreshTokenRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +struct OAuth2RefreshTokenLookup { + oauth2_refresh_token_id: Uuid, + refresh_token: String, + created_at: DateTime, + consumed_at: Option>, + oauth2_access_token_id: Option, + oauth2_session_id: Uuid, +} + +impl From for RefreshToken { + fn from(value: OAuth2RefreshTokenLookup) -> Self { + let state = match value.consumed_at { + None => RefreshTokenState::Valid, + Some(consumed_at) => RefreshTokenState::Consumed { consumed_at }, + }; + + RefreshToken { + id: value.oauth2_refresh_token_id.into(), + state, + session_id: value.oauth2_session_id.into(), + refresh_token: value.refresh_token, + created_at: value.created_at, + access_token_id: value.oauth2_access_token_id.map(Ulid::from), + } + } +} + +#[async_trait] +impl<'c> OAuth2RefreshTokenRepository for PgOAuth2RefreshTokenRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.oauth2_refresh_token.lookup", + skip_all, + fields( + db.statement, + refresh_token.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + OAuth2RefreshTokenLookup, + r#" + SELECT oauth2_refresh_token_id + , refresh_token + , created_at + , consumed_at + , oauth2_access_token_id + , oauth2_session_id + FROM oauth2_refresh_tokens + + WHERE oauth2_refresh_token_id = $1 + "#, + Uuid::from(id), + ) + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.oauth2_refresh_token.find_by_token", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn find_by_token( + &mut self, + refresh_token: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + OAuth2RefreshTokenLookup, + r#" + SELECT oauth2_refresh_token_id + , refresh_token + , created_at + , consumed_at + , oauth2_access_token_id + , oauth2_session_id + FROM oauth2_refresh_tokens + + WHERE refresh_token = $1 + "#, + refresh_token, + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.oauth2_refresh_token.add", + skip_all, + fields( + db.statement, + %session.id, + user_session.id = %session.user_session_id, + client.id = %session.client_id, + refresh_token.id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + session: &Session, + access_token: &AccessToken, + refresh_token: String, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("refresh_token.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO oauth2_refresh_tokens + (oauth2_refresh_token_id, oauth2_session_id, oauth2_access_token_id, + refresh_token, created_at) + VALUES + ($1, $2, $3, $4, $5) + "#, + Uuid::from(id), + Uuid::from(session.id), + Uuid::from(access_token.id), + refresh_token, + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(RefreshToken { + id, + state: RefreshTokenState::default(), + session_id: session.id, + refresh_token, + access_token_id: Some(access_token.id), + created_at, + }) + } + + #[tracing::instrument( + name = "db.oauth2_refresh_token.consume", + skip_all, + fields( + db.statement, + %refresh_token.id, + session.id = %refresh_token.session_id, + ), + err, + )] + async fn consume( + &mut self, + clock: &Clock, + refresh_token: RefreshToken, + ) -> Result { + let consumed_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE oauth2_refresh_tokens + SET consumed_at = $2 + WHERE oauth2_refresh_token_id = $1 + "#, + Uuid::from(refresh_token.id), + consumed_at, + ) + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + refresh_token + .consume(consumed_at) + .map_err(DatabaseError::to_invalid_operation) + } +} diff --git a/crates/storage-pg/src/oauth2/session.rs b/crates/storage-pg/src/oauth2/session.rs new file mode 100644 index 00000000..96f798e6 --- /dev/null +++ b/crates/storage-pg/src/oauth2/session.rs @@ -0,0 +1,248 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{AuthorizationGrant, BrowserSession, Session, SessionState, User}; +use mas_storage::{oauth2::OAuth2SessionRepository, Clock, Page, Pagination}; +use rand::RngCore; +use sqlx::{PgConnection, QueryBuilder}; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{ + pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, + LookupResultExt, +}; + +pub struct PgOAuth2SessionRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgOAuth2SessionRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +#[derive(sqlx::FromRow)] +struct OAuthSessionLookup { + oauth2_session_id: Uuid, + user_session_id: Uuid, + oauth2_client_id: Uuid, + scope: String, + #[allow(dead_code)] + created_at: DateTime, + finished_at: Option>, +} + +impl TryFrom for Session { + type Error = DatabaseInconsistencyError; + + fn try_from(value: OAuthSessionLookup) -> Result { + let id = Ulid::from(value.oauth2_session_id); + let scope = value.scope.parse().map_err(|e| { + DatabaseInconsistencyError::on("oauth2_sessions") + .column("scope") + .row(id) + .source(e) + })?; + + let state = match value.finished_at { + None => SessionState::Valid, + Some(finished_at) => SessionState::Finished { finished_at }, + }; + + Ok(Session { + id, + state, + created_at: value.created_at, + client_id: value.oauth2_client_id.into(), + user_session_id: value.user_session_id.into(), + scope, + }) + } +} + +#[async_trait] +impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.oauth2_session.lookup", + skip_all, + fields( + db.statement, + session.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + OAuthSessionLookup, + r#" + SELECT oauth2_session_id + , user_session_id + , oauth2_client_id + , scope + , created_at + , finished_at + FROM oauth2_sessions + + WHERE oauth2_session_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(session) = res else { return Ok(None) }; + + Ok(Some(session.try_into()?)) + } + + #[tracing::instrument( + name = "db.oauth2_session.create_from_grant", + skip_all, + fields( + db.statement, + %user_session.id, + user.id = %user_session.user.id, + %grant.id, + client.id = %grant.client_id, + session.id, + session.scope = %grant.scope, + ), + err, + )] + async fn create_from_grant( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + grant: &AuthorizationGrant, + user_session: &BrowserSession, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("session.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO oauth2_sessions + ( oauth2_session_id + , user_session_id + , oauth2_client_id + , scope + , created_at + ) + VALUES ($1, $2, $3, $4, $5) + "#, + Uuid::from(id), + Uuid::from(user_session.id), + Uuid::from(grant.client_id), + grant.scope.to_string(), + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(Session { + id, + state: SessionState::Valid, + created_at, + user_session_id: user_session.id, + client_id: grant.client_id, + scope: grant.scope.clone(), + }) + } + + #[tracing::instrument( + name = "db.oauth2_session.finish", + skip_all, + fields( + db.statement, + %session.id, + %session.scope, + user_session.id = %session.user_session_id, + client.id = %session.client_id, + ), + err, + )] + async fn finish(&mut self, clock: &Clock, session: Session) -> Result { + let finished_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE oauth2_sessions + SET finished_at = $2 + WHERE oauth2_session_id = $1 + "#, + Uuid::from(session.id), + finished_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + session + .finish(finished_at) + .map_err(DatabaseError::to_invalid_operation) + } + + #[tracing::instrument( + name = "db.oauth2_session.list_paginated", + skip_all, + fields( + db.statement, + %user.id, + %user.username, + ), + err, + )] + async fn list_paginated( + &mut self, + user: &User, + pagination: Pagination, + ) -> Result, Self::Error> { + let mut query = QueryBuilder::new( + r#" + SELECT oauth2_session_id + , user_session_id + , oauth2_client_id + , scope + , created_at + , finished_at + FROM oauth2_sessions os + "#, + ); + + query + .push(" WHERE us.user_id = ") + .push_bind(Uuid::from(user.id)) + .generate_pagination("oauth2_session_id", pagination); + + let edges: Vec = query + .build_query_as() + .traced() + .fetch_all(&mut *self.conn) + .await?; + + let page = pagination.process(edges).try_map(Session::try_from)?; + Ok(page) + } +} diff --git a/crates/storage-pg/src/pagination.rs b/crates/storage-pg/src/pagination.rs new file mode 100644 index 00000000..97e5220f --- /dev/null +++ b/crates/storage-pg/src/pagination.rs @@ -0,0 +1,78 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Utilities to manage paginated queries. + +use mas_storage::{pagination::PaginationDirection, Pagination}; +use sqlx::{Database, QueryBuilder}; +use uuid::Uuid; + +/// An extension trait to the `sqlx` [`QueryBuilder`], to help adding pagination +/// to a query +pub trait QueryBuilderExt { + /// Add cursor-based pagination to a query, as used in paginated GraphQL + /// connections + fn generate_pagination(&mut self, id_field: &'static str, pagination: Pagination) -> &mut Self; +} + +impl<'a, DB> QueryBuilderExt for QueryBuilder<'a, DB> +where + DB: Database, + Uuid: sqlx::Type + sqlx::Encode<'a, DB>, + i64: sqlx::Type + sqlx::Encode<'a, DB>, +{ + fn generate_pagination(&mut self, id_field: &'static str, pagination: Pagination) -> &mut Self { + // ref: https://github.com/graphql/graphql-relay-js/issues/94#issuecomment-232410564 + // 1. Start from the greedy query: SELECT * FROM table + + // 2. If the after argument is provided, add `id > parsed_cursor` to the `WHERE` + // clause + if let Some(after) = pagination.after { + self.push(" AND ") + .push(id_field) + .push(" > ") + .push_bind(Uuid::from(after)); + } + + // 3. If the before argument is provided, add `id < parsed_cursor` to the + // `WHERE` clause + if let Some(before) = pagination.before { + self.push(" AND ") + .push(id_field) + .push(" < ") + .push_bind(Uuid::from(before)); + } + + match pagination.direction { + // 4. If the first argument is provided, add `ORDER BY id ASC LIMIT first+1` to the + // query + PaginationDirection::Forward => { + self.push(" ORDER BY ") + .push(id_field) + .push(" ASC LIMIT ") + .push_bind((pagination.count + 1) as i64); + } + // 5. If the first argument is provided, add `ORDER BY id DESC LIMIT last+1` to the + // query + PaginationDirection::Backward => { + self.push(" ORDER BY ") + .push(id_field) + .push(" DESC LIMIT ") + .push_bind((pagination.count + 1) as i64); + } + }; + + self + } +} diff --git a/crates/storage-pg/src/repository.rs b/crates/storage-pg/src/repository.rs new file mode 100644 index 00000000..288181a6 --- /dev/null +++ b/crates/storage-pg/src/repository.rs @@ -0,0 +1,142 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use mas_storage::Repository; +use sqlx::{PgPool, Postgres, Transaction}; + +use crate::{ + compat::{ + PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository, PgCompatSessionRepository, + PgCompatSsoLoginRepository, + }, + oauth2::{ + PgOAuth2AccessTokenRepository, PgOAuth2AuthorizationGrantRepository, + PgOAuth2ClientRepository, PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository, + }, + upstream_oauth2::{ + PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository, + PgUpstreamOAuthSessionRepository, + }, + user::{ + PgBrowserSessionRepository, PgUserEmailRepository, PgUserPasswordRepository, + PgUserRepository, + }, + DatabaseError, +}; + +pub struct PgRepository { + txn: Transaction<'static, Postgres>, +} + +impl PgRepository { + pub async fn from_pool(pool: &PgPool) -> Result { + let txn = pool.begin().await?; + Ok(PgRepository { txn }) + } + + pub async fn save(self) -> Result<(), DatabaseError> { + self.txn.commit().await?; + Ok(()) + } + + pub async fn cancel(self) -> Result<(), DatabaseError> { + self.txn.rollback().await?; + Ok(()) + } +} + +impl Repository for PgRepository { + type Error = DatabaseError; + + type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c; + type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c; + type UpstreamOAuthSessionRepository<'c> = PgUpstreamOAuthSessionRepository<'c> where Self: 'c; + type UserRepository<'c> = PgUserRepository<'c> where Self: 'c; + type UserEmailRepository<'c> = PgUserEmailRepository<'c> where Self: 'c; + type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c; + type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c; + type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c; + type OAuth2AuthorizationGrantRepository<'c> = PgOAuth2AuthorizationGrantRepository<'c> where Self: 'c; + type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c; + type OAuth2AccessTokenRepository<'c> = PgOAuth2AccessTokenRepository<'c> where Self: 'c; + type OAuth2RefreshTokenRepository<'c> = PgOAuth2RefreshTokenRepository<'c> where Self: 'c; + type CompatSessionRepository<'c> = PgCompatSessionRepository<'c> where Self: 'c; + type CompatSsoLoginRepository<'c> = PgCompatSsoLoginRepository<'c> where Self: 'c; + type CompatAccessTokenRepository<'c> = PgCompatAccessTokenRepository<'c> where Self: 'c; + type CompatRefreshTokenRepository<'c> = PgCompatRefreshTokenRepository<'c> where Self: 'c; + + fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { + PgUpstreamOAuthLinkRepository::new(&mut self.txn) + } + + fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> { + PgUpstreamOAuthProviderRepository::new(&mut self.txn) + } + + fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_> { + PgUpstreamOAuthSessionRepository::new(&mut self.txn) + } + + fn user(&mut self) -> Self::UserRepository<'_> { + PgUserRepository::new(&mut self.txn) + } + + fn user_email(&mut self) -> Self::UserEmailRepository<'_> { + PgUserEmailRepository::new(&mut self.txn) + } + + fn user_password(&mut self) -> Self::UserPasswordRepository<'_> { + PgUserPasswordRepository::new(&mut self.txn) + } + + fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_> { + PgBrowserSessionRepository::new(&mut self.txn) + } + + fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_> { + PgOAuth2ClientRepository::new(&mut self.txn) + } + + fn oauth2_authorization_grant(&mut self) -> Self::OAuth2AuthorizationGrantRepository<'_> { + PgOAuth2AuthorizationGrantRepository::new(&mut self.txn) + } + + fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> { + PgOAuth2SessionRepository::new(&mut self.txn) + } + + fn oauth2_access_token(&mut self) -> Self::OAuth2AccessTokenRepository<'_> { + PgOAuth2AccessTokenRepository::new(&mut self.txn) + } + + fn oauth2_refresh_token(&mut self) -> Self::OAuth2RefreshTokenRepository<'_> { + PgOAuth2RefreshTokenRepository::new(&mut self.txn) + } + + fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> { + PgCompatSessionRepository::new(&mut self.txn) + } + + fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_> { + PgCompatSsoLoginRepository::new(&mut self.txn) + } + + fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_> { + PgCompatAccessTokenRepository::new(&mut self.txn) + } + + fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_> { + PgCompatRefreshTokenRepository::new(&mut self.txn) + } +} diff --git a/crates/storage/src/tracing.rs b/crates/storage-pg/src/tracing.rs similarity index 95% rename from crates/storage/src/tracing.rs rename to crates/storage-pg/src/tracing.rs index 08c62e46..1210816c 100644 --- a/crates/storage/src/tracing.rs +++ b/crates/storage-pg/src/tracing.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/crates/storage-pg/src/upstream_oauth2/link.rs b/crates/storage-pg/src/upstream_oauth2/link.rs new file mode 100644 index 00000000..4087e2c7 --- /dev/null +++ b/crates/storage-pg/src/upstream_oauth2/link.rs @@ -0,0 +1,262 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{UpstreamOAuthLink, UpstreamOAuthProvider, User}; +use mas_storage::{upstream_oauth2::UpstreamOAuthLinkRepository, Clock, Page, Pagination}; +use rand::RngCore; +use sqlx::{PgConnection, QueryBuilder}; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, LookupResultExt}; + +pub struct PgUpstreamOAuthLinkRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgUpstreamOAuthLinkRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +#[derive(sqlx::FromRow)] +struct LinkLookup { + upstream_oauth_link_id: Uuid, + upstream_oauth_provider_id: Uuid, + user_id: Option, + subject: String, + created_at: DateTime, +} + +impl From for UpstreamOAuthLink { + fn from(value: LinkLookup) -> Self { + UpstreamOAuthLink { + id: Ulid::from(value.upstream_oauth_link_id), + provider_id: Ulid::from(value.upstream_oauth_provider_id), + user_id: value.user_id.map(Ulid::from), + subject: value.subject, + created_at: value.created_at, + } + } +} + +#[async_trait] +impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.upstream_oauth_link.lookup", + skip_all, + fields( + db.statement, + upstream_oauth_link.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + LinkLookup, + r#" + SELECT + upstream_oauth_link_id, + upstream_oauth_provider_id, + user_id, + subject, + created_at + FROM upstream_oauth_links + WHERE upstream_oauth_link_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()? + .map(Into::into); + + Ok(res) + } + + #[tracing::instrument( + name = "db.upstream_oauth_link.find_by_subject", + skip_all, + fields( + db.statement, + upstream_oauth_link.subject = subject, + %upstream_oauth_provider.id, + %upstream_oauth_provider.issuer, + %upstream_oauth_provider.client_id, + ), + err, + )] + async fn find_by_subject( + &mut self, + upstream_oauth_provider: &UpstreamOAuthProvider, + subject: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + LinkLookup, + r#" + SELECT + upstream_oauth_link_id, + upstream_oauth_provider_id, + user_id, + subject, + created_at + FROM upstream_oauth_links + WHERE upstream_oauth_provider_id = $1 + AND subject = $2 + "#, + Uuid::from(upstream_oauth_provider.id), + subject, + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()? + .map(Into::into); + + Ok(res) + } + + #[tracing::instrument( + name = "db.upstream_oauth_link.add", + skip_all, + fields( + db.statement, + upstream_oauth_link.id, + upstream_oauth_link.subject = subject, + %upstream_oauth_provider.id, + %upstream_oauth_provider.issuer, + %upstream_oauth_provider.client_id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + upstream_oauth_provider: &UpstreamOAuthProvider, + subject: String, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("upstream_oauth_link.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO upstream_oauth_links ( + upstream_oauth_link_id, + upstream_oauth_provider_id, + user_id, + subject, + created_at + ) VALUES ($1, $2, NULL, $3, $4) + "#, + Uuid::from(id), + Uuid::from(upstream_oauth_provider.id), + &subject, + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(UpstreamOAuthLink { + id, + provider_id: upstream_oauth_provider.id, + user_id: None, + subject, + created_at, + }) + } + + #[tracing::instrument( + name = "db.upstream_oauth_link.associate_to_user", + skip_all, + fields( + db.statement, + %upstream_oauth_link.id, + %upstream_oauth_link.subject, + %user.id, + %user.username, + ), + err, + )] + async fn associate_to_user( + &mut self, + upstream_oauth_link: &UpstreamOAuthLink, + user: &User, + ) -> Result<(), Self::Error> { + sqlx::query!( + r#" + UPDATE upstream_oauth_links + SET user_id = $1 + WHERE upstream_oauth_link_id = $2 + "#, + Uuid::from(user.id), + Uuid::from(upstream_oauth_link.id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(()) + } + + #[tracing::instrument( + name = "db.upstream_oauth_link.list_paginated", + skip_all, + fields( + db.statement, + %user.id, + %user.username, + ), + err + )] + async fn list_paginated( + &mut self, + user: &User, + pagination: Pagination, + ) -> Result, Self::Error> { + let mut query = QueryBuilder::new( + r#" + SELECT + upstream_oauth_link_id, + upstream_oauth_provider_id, + user_id, + subject, + created_at + FROM upstream_oauth_links + "#, + ); + + query + .push(" WHERE user_id = ") + .push_bind(Uuid::from(user.id)) + .generate_pagination("upstream_oauth_link_id", pagination); + + let edges: Vec = query + .build_query_as() + .traced() + .fetch_all(&mut *self.conn) + .await?; + + let page = pagination.process(edges).map(UpstreamOAuthLink::from); + Ok(page) + } +} diff --git a/crates/storage-pg/src/upstream_oauth2/mod.rs b/crates/storage-pg/src/upstream_oauth2/mod.rs new file mode 100644 index 00000000..e77daba2 --- /dev/null +++ b/crates/storage-pg/src/upstream_oauth2/mod.rs @@ -0,0 +1,271 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +mod link; +mod provider; +mod session; + +pub use self::{ + link::PgUpstreamOAuthLinkRepository, provider::PgUpstreamOAuthProviderRepository, + session::PgUpstreamOAuthSessionRepository, +}; + +#[cfg(test)] +mod tests { + use chrono::Duration; + use mas_storage::{ + upstream_oauth2::{ + UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, + UpstreamOAuthSessionRepository, + }, + user::UserRepository, + Clock, Pagination, Repository, + }; + use oauth2_types::scope::{Scope, OPENID}; + use rand::SeedableRng; + use sqlx::PgPool; + + use crate::PgRepository; + + #[sqlx::test(migrator = "crate::MIGRATOR")] + async fn test_repository(pool: PgPool) { + let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42); + let clock = Clock::mock(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + + // The provider list should be empty at the start + let all_providers = repo.upstream_oauth_provider().all().await.unwrap(); + assert!(all_providers.is_empty()); + + // Let's add a provider + let provider = repo + .upstream_oauth_provider() + .add( + &mut rng, + &clock, + "https://example.com/".to_owned(), + Scope::from_iter([OPENID]), + mas_iana::oauth::OAuthClientAuthenticationMethod::None, + None, + "client-id".to_owned(), + None, + ) + .await + .unwrap(); + + // Look it up in the database + let provider = repo + .upstream_oauth_provider() + .lookup(provider.id) + .await + .unwrap() + .expect("provider to be found in the database"); + assert_eq!(provider.issuer, "https://example.com/"); + assert_eq!(provider.client_id, "client-id"); + + // Start a session + let session = repo + .upstream_oauth_session() + .add( + &mut rng, + &clock, + &provider, + "some-state".to_owned(), + None, + "some-nonce".to_owned(), + ) + .await + .unwrap(); + + // Look it up in the database + let session = repo + .upstream_oauth_session() + .lookup(session.id) + .await + .unwrap() + .expect("session to be found in the database"); + assert_eq!(session.provider_id, provider.id); + assert_eq!(session.link_id(), None); + assert!(session.is_pending()); + assert!(!session.is_completed()); + assert!(!session.is_consumed()); + + // Create a link + let link = repo + .upstream_oauth_link() + .add(&mut rng, &clock, &provider, "a-subject".to_owned()) + .await + .unwrap(); + + // We can look it up by its ID + repo.upstream_oauth_link() + .lookup(link.id) + .await + .unwrap() + .expect("link to be found in database"); + + // or by its subject + let link = repo + .upstream_oauth_link() + .find_by_subject(&provider, "a-subject") + .await + .unwrap() + .expect("link to be found in database"); + assert_eq!(link.subject, "a-subject"); + assert_eq!(link.provider_id, provider.id); + + let session = repo + .upstream_oauth_session() + .complete_with_link(&clock, session, &link, None) + .await + .unwrap(); + // Reload the session + let session = repo + .upstream_oauth_session() + .lookup(session.id) + .await + .unwrap() + .expect("session to be found in the database"); + assert!(session.is_completed()); + assert!(!session.is_consumed()); + assert_eq!(session.link_id(), Some(link.id)); + + let session = repo + .upstream_oauth_session() + .consume(&clock, session) + .await + .unwrap(); + // Reload the session + let session = repo + .upstream_oauth_session() + .lookup(session.id) + .await + .unwrap() + .expect("session to be found in the database"); + assert!(session.is_consumed()); + + let user = repo + .user() + .add(&mut rng, &clock, "john".to_owned()) + .await + .unwrap(); + repo.upstream_oauth_link() + .associate_to_user(&link, &user) + .await + .unwrap(); + + let links = repo + .upstream_oauth_link() + .list_paginated(&user, Pagination::first(10)) + .await + .unwrap(); + assert!(!links.has_previous_page); + assert!(!links.has_next_page); + assert_eq!(links.edges.len(), 1); + assert_eq!(links.edges[0].id, link.id); + assert_eq!(links.edges[0].user_id, Some(user.id)); + } + + #[sqlx::test(migrator = "crate::MIGRATOR")] + async fn test_provider_repository_pagination(pool: PgPool) { + const ISSUER: &str = "https://example.com/"; + let scope = Scope::from_iter([OPENID]); + + let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42); + let clock = Clock::mock(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + + let mut ids = Vec::with_capacity(20); + // Create 20 providers + for idx in 0..20 { + let client_id = format!("client-{idx}"); + let provider = repo + .upstream_oauth_provider() + .add( + &mut rng, + &clock, + ISSUER.to_owned(), + scope.clone(), + mas_iana::oauth::OAuthClientAuthenticationMethod::None, + None, + client_id, + None, + ) + .await + .unwrap(); + ids.push(provider.id); + clock.advance(Duration::seconds(10)); + } + + // Lookup the first 10 items + let page = repo + .upstream_oauth_provider() + .list_paginated(Pagination::first(10)) + .await + .unwrap(); + + // It returned the first 10 items + assert!(page.has_next_page); + let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); + assert_eq!(&edge_ids, &ids[..10]); + + // Lookup the next 10 items + let page = repo + .upstream_oauth_provider() + .list_paginated(Pagination::first(10).after(ids[9])) + .await + .unwrap(); + + // It returned the next 10 items + assert!(!page.has_next_page); + let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); + assert_eq!(&edge_ids, &ids[10..]); + + // Lookup the last 10 items + let page = repo + .upstream_oauth_provider() + .list_paginated(Pagination::last(10)) + .await + .unwrap(); + + // It returned the last 10 items + assert!(page.has_previous_page); + let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); + assert_eq!(&edge_ids, &ids[10..]); + + // Lookup the previous 10 items + let page = repo + .upstream_oauth_provider() + .list_paginated(Pagination::last(10).before(ids[10])) + .await + .unwrap(); + + // It returned the previous 10 items + assert!(!page.has_previous_page); + let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); + assert_eq!(&edge_ids, &ids[..10]); + + // Lookup 10 items between two IDs + let page = repo + .upstream_oauth_provider() + .list_paginated(Pagination::first(10).after(ids[5]).before(ids[8])) + .await + .unwrap(); + + // It returned the items in between + assert!(!page.has_next_page); + let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); + assert_eq!(&edge_ids, &ids[6..8]); + } +} diff --git a/crates/storage-pg/src/upstream_oauth2/provider.rs b/crates/storage-pg/src/upstream_oauth2/provider.rs new file mode 100644 index 00000000..480249ee --- /dev/null +++ b/crates/storage-pg/src/upstream_oauth2/provider.rs @@ -0,0 +1,273 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::UpstreamOAuthProvider; +use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; +use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, Clock, Page, Pagination}; +use oauth2_types::scope::Scope; +use rand::RngCore; +use sqlx::{PgConnection, QueryBuilder}; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{ + pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, + LookupResultExt, +}; + +pub struct PgUpstreamOAuthProviderRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgUpstreamOAuthProviderRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +#[derive(sqlx::FromRow)] +struct ProviderLookup { + upstream_oauth_provider_id: Uuid, + issuer: String, + scope: String, + client_id: String, + encrypted_client_secret: Option, + token_endpoint_signing_alg: Option, + token_endpoint_auth_method: String, + created_at: DateTime, +} + +impl TryFrom for UpstreamOAuthProvider { + type Error = DatabaseInconsistencyError; + fn try_from(value: ProviderLookup) -> Result { + let id = value.upstream_oauth_provider_id.into(); + let scope = value.scope.parse().map_err(|e| { + DatabaseInconsistencyError::on("upstream_oauth_providers") + .column("scope") + .row(id) + .source(e) + })?; + let token_endpoint_auth_method = value.token_endpoint_auth_method.parse().map_err(|e| { + DatabaseInconsistencyError::on("upstream_oauth_providers") + .column("token_endpoint_auth_method") + .row(id) + .source(e) + })?; + let token_endpoint_signing_alg = value + .token_endpoint_signing_alg + .map(|x| x.parse()) + .transpose() + .map_err(|e| { + DatabaseInconsistencyError::on("upstream_oauth_providers") + .column("token_endpoint_signing_alg") + .row(id) + .source(e) + })?; + + Ok(UpstreamOAuthProvider { + id, + issuer: value.issuer, + scope, + client_id: value.client_id, + encrypted_client_secret: value.encrypted_client_secret, + token_endpoint_auth_method, + token_endpoint_signing_alg, + created_at: value.created_at, + }) + } +} + +#[async_trait] +impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.upstream_oauth_provider.lookup", + skip_all, + fields( + db.statement, + upstream_oauth_provider.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + ProviderLookup, + r#" + SELECT + upstream_oauth_provider_id, + issuer, + scope, + client_id, + encrypted_client_secret, + token_endpoint_signing_alg, + token_endpoint_auth_method, + created_at + FROM upstream_oauth_providers + WHERE upstream_oauth_provider_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let res = res + .map(UpstreamOAuthProvider::try_from) + .transpose() + .map_err(DatabaseError::from)?; + + Ok(res) + } + + #[tracing::instrument( + name = "db.upstream_oauth_provider.add", + skip_all, + fields( + db.statement, + upstream_oauth_provider.id, + upstream_oauth_provider.issuer = %issuer, + upstream_oauth_provider.client_id = %client_id, + ), + err, + )] + #[allow(clippy::too_many_arguments)] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + issuer: String, + scope: Scope, + token_endpoint_auth_method: OAuthClientAuthenticationMethod, + token_endpoint_signing_alg: Option, + client_id: String, + encrypted_client_secret: Option, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("upstream_oauth_provider.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO upstream_oauth_providers ( + upstream_oauth_provider_id, + issuer, + scope, + token_endpoint_auth_method, + token_endpoint_signing_alg, + client_id, + encrypted_client_secret, + created_at + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + "#, + Uuid::from(id), + &issuer, + scope.to_string(), + token_endpoint_auth_method.to_string(), + token_endpoint_signing_alg.as_ref().map(ToString::to_string), + &client_id, + encrypted_client_secret.as_deref(), + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(UpstreamOAuthProvider { + id, + issuer, + scope, + client_id, + encrypted_client_secret, + token_endpoint_signing_alg, + token_endpoint_auth_method, + created_at, + }) + } + + #[tracing::instrument( + name = "db.upstream_oauth_provider.list_paginated", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn list_paginated( + &mut self, + pagination: Pagination, + ) -> Result, Self::Error> { + let mut query = QueryBuilder::new( + r#" + SELECT + upstream_oauth_provider_id, + issuer, + scope, + client_id, + encrypted_client_secret, + token_endpoint_signing_alg, + token_endpoint_auth_method, + created_at + FROM upstream_oauth_providers + WHERE 1 = 1 + "#, + ); + + query.generate_pagination("upstream_oauth_provider_id", pagination); + + let edges: Vec = query + .build_query_as() + .traced() + .fetch_all(&mut *self.conn) + .await?; + + let page = pagination.process(edges).try_map(TryInto::try_into)?; + Ok(page) + } + + #[tracing::instrument( + name = "db.upstream_oauth_provider.all", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn all(&mut self) -> Result, Self::Error> { + let res = sqlx::query_as!( + ProviderLookup, + r#" + SELECT + upstream_oauth_provider_id, + issuer, + scope, + client_id, + encrypted_client_secret, + token_endpoint_signing_alg, + token_endpoint_auth_method, + created_at + FROM upstream_oauth_providers + "#, + ) + .traced() + .fetch_all(&mut *self.conn) + .await?; + + let res: Result, _> = res.into_iter().map(TryInto::try_into).collect(); + Ok(res?) + } +} diff --git a/crates/storage-pg/src/upstream_oauth2/session.rs b/crates/storage-pg/src/upstream_oauth2/session.rs new file mode 100644 index 00000000..699a463f --- /dev/null +++ b/crates/storage-pg/src/upstream_oauth2/session.rs @@ -0,0 +1,286 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{ + UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState, UpstreamOAuthLink, + UpstreamOAuthProvider, +}; +use mas_storage::{upstream_oauth2::UpstreamOAuthSessionRepository, Clock}; +use rand::RngCore; +use sqlx::PgConnection; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; + +pub struct PgUpstreamOAuthSessionRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgUpstreamOAuthSessionRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +struct SessionLookup { + upstream_oauth_authorization_session_id: Uuid, + upstream_oauth_provider_id: Uuid, + upstream_oauth_link_id: Option, + state: String, + code_challenge_verifier: Option, + nonce: String, + id_token: Option, + created_at: DateTime, + completed_at: Option>, + consumed_at: Option>, +} + +impl TryFrom for UpstreamOAuthAuthorizationSession { + type Error = DatabaseInconsistencyError; + + fn try_from(value: SessionLookup) -> Result { + let id = value.upstream_oauth_authorization_session_id.into(); + let state = match ( + value.upstream_oauth_link_id, + value.id_token, + value.completed_at, + value.consumed_at, + ) { + (None, None, None, None) => UpstreamOAuthAuthorizationSessionState::Pending, + (Some(link_id), id_token, Some(completed_at), None) => { + UpstreamOAuthAuthorizationSessionState::Completed { + completed_at, + link_id: link_id.into(), + id_token, + } + } + (Some(link_id), id_token, Some(completed_at), Some(consumed_at)) => { + UpstreamOAuthAuthorizationSessionState::Consumed { + completed_at, + link_id: link_id.into(), + id_token, + consumed_at, + } + } + _ => { + return Err( + DatabaseInconsistencyError::on("upstream_oauth_authorization_sessions").row(id), + ) + } + }; + + Ok(Self { + id, + provider_id: value.upstream_oauth_provider_id.into(), + state_str: value.state, + nonce: value.nonce, + code_challenge_verifier: value.code_challenge_verifier, + created_at: value.created_at, + state, + }) + } +} + +#[async_trait] +impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.upstream_oauth_authorization_session.lookup", + skip_all, + fields( + db.statement, + upstream_oauth_provider.id = %id, + ), + err, + )] + async fn lookup( + &mut self, + id: Ulid, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + SessionLookup, + r#" + SELECT + upstream_oauth_authorization_session_id, + upstream_oauth_provider_id, + upstream_oauth_link_id, + state, + code_challenge_verifier, + nonce, + id_token, + created_at, + completed_at, + consumed_at + FROM upstream_oauth_authorization_sessions + WHERE upstream_oauth_authorization_session_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.upstream_oauth_authorization_session.add", + skip_all, + fields( + db.statement, + %upstream_oauth_provider.id, + %upstream_oauth_provider.issuer, + %upstream_oauth_provider.client_id, + upstream_oauth_authorization_session.id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + upstream_oauth_provider: &UpstreamOAuthProvider, + state_str: String, + code_challenge_verifier: Option, + nonce: String, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record( + "upstream_oauth_authorization_session.id", + tracing::field::display(id), + ); + + sqlx::query!( + r#" + INSERT INTO upstream_oauth_authorization_sessions ( + upstream_oauth_authorization_session_id, + upstream_oauth_provider_id, + state, + code_challenge_verifier, + nonce, + created_at, + completed_at, + consumed_at, + id_token + ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL) + "#, + Uuid::from(id), + Uuid::from(upstream_oauth_provider.id), + &state_str, + code_challenge_verifier.as_deref(), + nonce, + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(UpstreamOAuthAuthorizationSession { + id, + state: UpstreamOAuthAuthorizationSessionState::default(), + provider_id: upstream_oauth_provider.id, + state_str, + code_challenge_verifier, + nonce, + created_at, + }) + } + + #[tracing::instrument( + name = "db.upstream_oauth_authorization_session.complete_with_link", + skip_all, + fields( + db.statement, + %upstream_oauth_authorization_session.id, + %upstream_oauth_link.id, + ), + err, + )] + async fn complete_with_link( + &mut self, + clock: &Clock, + upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, + upstream_oauth_link: &UpstreamOAuthLink, + id_token: Option, + ) -> Result { + let completed_at = clock.now(); + + sqlx::query!( + r#" + UPDATE upstream_oauth_authorization_sessions + SET upstream_oauth_link_id = $1, + completed_at = $2, + id_token = $3 + WHERE upstream_oauth_authorization_session_id = $4 + "#, + Uuid::from(upstream_oauth_link.id), + completed_at, + id_token, + Uuid::from(upstream_oauth_authorization_session.id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + let upstream_oauth_authorization_session = upstream_oauth_authorization_session + .complete(completed_at, upstream_oauth_link, id_token) + .map_err(DatabaseError::to_invalid_operation)?; + + Ok(upstream_oauth_authorization_session) + } + + /// Mark a session as consumed + #[tracing::instrument( + name = "db.upstream_oauth_authorization_session.consume", + skip_all, + fields( + db.statement, + %upstream_oauth_authorization_session.id, + ), + err, + )] + async fn consume( + &mut self, + clock: &Clock, + upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, + ) -> Result { + let consumed_at = clock.now(); + sqlx::query!( + r#" + UPDATE upstream_oauth_authorization_sessions + SET consumed_at = $1 + WHERE upstream_oauth_authorization_session_id = $2 + "#, + consumed_at, + Uuid::from(upstream_oauth_authorization_session.id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + let upstream_oauth_authorization_session = upstream_oauth_authorization_session + .consume(consumed_at) + .map_err(DatabaseError::to_invalid_operation)?; + + Ok(upstream_oauth_authorization_session) + } +} diff --git a/crates/storage-pg/src/user/email.rs b/crates/storage-pg/src/user/email.rs new file mode 100644 index 00000000..936b0649 --- /dev/null +++ b/crates/storage-pg/src/user/email.rs @@ -0,0 +1,554 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{User, UserEmail, UserEmailVerification, UserEmailVerificationState}; +use mas_storage::{user::UserEmailRepository, Clock, Page, Pagination}; +use rand::RngCore; +use sqlx::{PgConnection, QueryBuilder}; +use tracing::{info_span, Instrument}; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{ + pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, + LookupResultExt, +}; + +pub struct PgUserEmailRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgUserEmailRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +#[derive(Debug, Clone, sqlx::FromRow)] +struct UserEmailLookup { + user_email_id: Uuid, + user_id: Uuid, + email: String, + created_at: DateTime, + confirmed_at: Option>, +} + +impl From for UserEmail { + fn from(e: UserEmailLookup) -> UserEmail { + UserEmail { + id: e.user_email_id.into(), + user_id: e.user_id.into(), + email: e.email, + created_at: e.created_at, + confirmed_at: e.confirmed_at, + } + } +} + +struct UserEmailConfirmationCodeLookup { + user_email_confirmation_code_id: Uuid, + user_email_id: Uuid, + code: String, + created_at: DateTime, + expires_at: DateTime, + consumed_at: Option>, +} + +impl UserEmailConfirmationCodeLookup { + fn into_verification(self, clock: &Clock) -> UserEmailVerification { + let now = clock.now(); + let state = if let Some(when) = self.consumed_at { + UserEmailVerificationState::AlreadyUsed { when } + } else if self.expires_at < now { + UserEmailVerificationState::Expired { + when: self.expires_at, + } + } else { + UserEmailVerificationState::Valid + }; + + UserEmailVerification { + id: self.user_email_confirmation_code_id.into(), + user_email_id: self.user_email_id.into(), + code: self.code, + state, + created_at: self.created_at, + } + } +} + +#[async_trait] +impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.user_email.lookup", + skip_all, + fields( + db.statement, + user_email.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + UserEmailLookup, + r#" + SELECT user_email_id + , user_id + , email + , created_at + , confirmed_at + FROM user_emails + + WHERE user_email_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(user_email) = res else { return Ok(None) }; + + Ok(Some(user_email.into())) + } + + #[tracing::instrument( + name = "db.user_email.find", + skip_all, + fields( + db.statement, + %user.id, + user_email.email = email, + ), + err, + )] + async fn find(&mut self, user: &User, email: &str) -> Result, Self::Error> { + let res = sqlx::query_as!( + UserEmailLookup, + r#" + SELECT user_email_id + , user_id + , email + , created_at + , confirmed_at + FROM user_emails + + WHERE user_id = $1 AND email = $2 + "#, + Uuid::from(user.id), + email, + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(user_email) = res else { return Ok(None) }; + + Ok(Some(user_email.into())) + } + + #[tracing::instrument( + name = "db.user_email.get_primary", + skip_all, + fields( + db.statement, + %user.id, + ), + err, + )] + async fn get_primary(&mut self, user: &User) -> Result, Self::Error> { + let Some(id) = user.primary_user_email_id else { return Ok(None) }; + + let user_email = self.lookup(id).await?.ok_or_else(|| { + DatabaseInconsistencyError::on("users") + .column("primary_user_email_id") + .row(user.id) + })?; + + Ok(Some(user_email)) + } + + #[tracing::instrument( + name = "db.user_email.all", + skip_all, + fields( + db.statement, + %user.id, + ), + err, + )] + async fn all(&mut self, user: &User) -> Result, Self::Error> { + let res = sqlx::query_as!( + UserEmailLookup, + r#" + SELECT user_email_id + , user_id + , email + , created_at + , confirmed_at + FROM user_emails + + WHERE user_id = $1 + + ORDER BY email ASC + "#, + Uuid::from(user.id), + ) + .traced() + .fetch_all(&mut *self.conn) + .await?; + + Ok(res.into_iter().map(Into::into).collect()) + } + + #[tracing::instrument( + name = "db.user_email.list_paginated", + skip_all, + fields( + db.statement, + %user.id, + ), + err, + )] + async fn list_paginated( + &mut self, + user: &User, + pagination: Pagination, + ) -> Result, DatabaseError> { + let mut query = QueryBuilder::new( + r#" + SELECT user_email_id + , user_id + , email + , created_at + , confirmed_at + FROM user_emails + "#, + ); + + query + .push(" WHERE user_id = ") + .push_bind(Uuid::from(user.id)) + .generate_pagination("ue.user_email_id", pagination); + + let edges: Vec = query + .build_query_as() + .traced() + .fetch_all(&mut *self.conn) + .await?; + + let page = pagination.process(edges).map(UserEmail::from); + Ok(page) + } + + #[tracing::instrument( + name = "db.user_email.count", + skip_all, + fields( + db.statement, + %user.id, + ), + err, + )] + async fn count(&mut self, user: &User) -> Result { + let res = sqlx::query_scalar!( + r#" + SELECT COUNT(*) + FROM user_emails + WHERE user_id = $1 + "#, + Uuid::from(user.id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await?; + + let res = res.unwrap_or_default(); + + Ok(res + .try_into() + .map_err(DatabaseError::to_invalid_operation)?) + } + + #[tracing::instrument( + name = "db.user_email.add", + skip_all, + fields( + db.statement, + %user.id, + user_email.id, + user_email.email = email, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + user: &User, + email: String, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("user_email.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO user_emails (user_email_id, user_id, email, created_at) + VALUES ($1, $2, $3, $4) + "#, + Uuid::from(id), + Uuid::from(user.id), + &email, + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(UserEmail { + id, + user_id: user.id, + email, + created_at, + confirmed_at: None, + }) + } + + #[tracing::instrument( + name = "db.user_email.remove", + skip_all, + fields( + db.statement, + user.id = %user_email.user_id, + %user_email.id, + %user_email.email, + ), + err, + )] + async fn remove(&mut self, user_email: UserEmail) -> Result<(), Self::Error> { + let span = info_span!( + "db.user_email.remove.codes", + db.statement = tracing::field::Empty + ); + sqlx::query!( + r#" + DELETE FROM user_email_confirmation_codes + WHERE user_email_id = $1 + "#, + Uuid::from(user_email.id), + ) + .record(&span) + .execute(&mut *self.conn) + .instrument(span) + .await?; + + let res = sqlx::query!( + r#" + DELETE FROM user_emails + WHERE user_email_id = $1 + "#, + Uuid::from(user_email.id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + Ok(()) + } + + async fn mark_as_verified( + &mut self, + clock: &Clock, + mut user_email: UserEmail, + ) -> Result { + let confirmed_at = clock.now(); + sqlx::query!( + r#" + UPDATE user_emails + SET confirmed_at = $2 + WHERE user_email_id = $1 + "#, + Uuid::from(user_email.id), + confirmed_at, + ) + .execute(&mut *self.conn) + .await?; + + user_email.confirmed_at = Some(confirmed_at); + Ok(user_email) + } + + async fn set_as_primary(&mut self, user_email: &UserEmail) -> Result<(), Self::Error> { + sqlx::query!( + r#" + UPDATE users + SET primary_user_email_id = user_emails.user_email_id + FROM user_emails + WHERE user_emails.user_email_id = $1 + AND users.user_id = user_emails.user_id + "#, + Uuid::from(user_email.id), + ) + .execute(&mut *self.conn) + .await?; + + Ok(()) + } + + #[tracing::instrument( + name = "db.user_email.add_verification_code", + skip_all, + fields( + db.statement, + %user_email.id, + %user_email.email, + user_email_verification.id, + user_email_verification.code = code, + ), + err, + )] + async fn add_verification_code( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + user_email: &UserEmail, + max_age: chrono::Duration, + code: String, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("user_email_confirmation.id", tracing::field::display(id)); + let expires_at = created_at + max_age; + + sqlx::query!( + r#" + INSERT INTO user_email_confirmation_codes + (user_email_confirmation_code_id, user_email_id, code, created_at, expires_at) + VALUES ($1, $2, $3, $4, $5) + "#, + Uuid::from(id), + Uuid::from(user_email.id), + code, + created_at, + expires_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + let verification = UserEmailVerification { + id, + user_email_id: user_email.id, + code, + created_at, + state: UserEmailVerificationState::Valid, + }; + + Ok(verification) + } + + #[tracing::instrument( + name = "db.user_email.find_verification_code", + skip_all, + fields( + db.statement, + %user_email.id, + user.id = %user_email.user_id, + ), + err, + )] + async fn find_verification_code( + &mut self, + clock: &Clock, + user_email: &UserEmail, + code: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + UserEmailConfirmationCodeLookup, + r#" + SELECT user_email_confirmation_code_id + , user_email_id + , code + , created_at + , expires_at + , consumed_at + FROM user_email_confirmation_codes + WHERE code = $1 + AND user_email_id = $2 + "#, + code, + Uuid::from(user_email.id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into_verification(clock))) + } + + #[tracing::instrument( + name = "db.user_email.consume_verification_code", + skip_all, + fields( + db.statement, + %user_email_verification.id, + user_email.id = %user_email_verification.user_email_id, + ), + err, + )] + async fn consume_verification_code( + &mut self, + clock: &Clock, + mut user_email_verification: UserEmailVerification, + ) -> Result { + if !matches!( + user_email_verification.state, + UserEmailVerificationState::Valid + ) { + return Err(DatabaseError::invalid_operation()); + } + + let consumed_at = clock.now(); + + sqlx::query!( + r#" + UPDATE user_email_confirmation_codes + SET consumed_at = $2 + WHERE user_email_confirmation_code_id = $1 + "#, + Uuid::from(user_email_verification.id), + consumed_at + ) + .traced() + .execute(&mut *self.conn) + .await?; + + user_email_verification.state = + UserEmailVerificationState::AlreadyUsed { when: consumed_at }; + + Ok(user_email_verification) + } +} diff --git a/crates/storage-pg/src/user/mod.rs b/crates/storage-pg/src/user/mod.rs new file mode 100644 index 00000000..d7320261 --- /dev/null +++ b/crates/storage-pg/src/user/mod.rs @@ -0,0 +1,203 @@ +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::User; +use mas_storage::{user::UserRepository, Clock}; +use rand::RngCore; +use sqlx::PgConnection; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{tracing::ExecuteExt, DatabaseError, LookupResultExt}; + +mod email; +mod password; +mod session; + +#[cfg(test)] +mod tests; + +pub use self::{ + email::PgUserEmailRepository, password::PgUserPasswordRepository, + session::PgBrowserSessionRepository, +}; + +pub struct PgUserRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgUserRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +#[derive(Debug, Clone)] +struct UserLookup { + user_id: Uuid, + username: String, + primary_user_email_id: Option, + + #[allow(dead_code)] + created_at: DateTime, +} + +impl From for User { + fn from(value: UserLookup) -> Self { + let id = value.user_id.into(); + Self { + id, + username: value.username, + sub: id.to_string(), + primary_user_email_id: value.primary_user_email_id.map(Into::into), + } + } +} + +#[async_trait] +impl<'c> UserRepository for PgUserRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.user.lookup", + skip_all, + fields( + db.statement, + user.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + UserLookup, + r#" + SELECT user_id + , username + , primary_user_email_id + , created_at + FROM users + WHERE user_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.user.find_by_username", + skip_all, + fields( + db.statement, + user.username = username, + ), + err, + )] + async fn find_by_username(&mut self, username: &str) -> Result, Self::Error> { + let res = sqlx::query_as!( + UserLookup, + r#" + SELECT user_id + , username + , primary_user_email_id + , created_at + FROM users + WHERE username = $1 + "#, + username, + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.user.add", + skip_all, + fields( + db.statement, + user.username = username, + user.id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + username: String, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("user.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO users (user_id, username, created_at) + VALUES ($1, $2, $3) + "#, + Uuid::from(id), + username, + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(User { + id, + username, + sub: id.to_string(), + primary_user_email_id: None, + }) + } + + #[tracing::instrument( + name = "db.user.exists", + skip_all, + fields( + db.statement, + user.username = username, + ), + err, + )] + async fn exists(&mut self, username: &str) -> Result { + let exists = sqlx::query_scalar!( + r#" + SELECT EXISTS( + SELECT 1 FROM users WHERE username = $1 + ) AS "exists!" + "#, + username + ) + .traced() + .fetch_one(&mut *self.conn) + .await?; + + Ok(exists) + } +} diff --git a/crates/storage-pg/src/user/password.rs b/crates/storage-pg/src/user/password.rs new file mode 100644 index 00000000..997b1227 --- /dev/null +++ b/crates/storage-pg/src/user/password.rs @@ -0,0 +1,155 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{Password, User}; +use mas_storage::{user::UserPasswordRepository, Clock}; +use rand::RngCore; +use sqlx::PgConnection; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; + +pub struct PgUserPasswordRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgUserPasswordRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +struct UserPasswordLookup { + user_password_id: Uuid, + hashed_password: String, + version: i32, + upgraded_from_id: Option, + created_at: DateTime, +} + +#[async_trait] +impl<'c> UserPasswordRepository for PgUserPasswordRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.user_password.active", + skip_all, + fields( + db.statement, + %user.id, + %user.username, + ), + err, + )] + async fn active(&mut self, user: &User) -> Result, Self::Error> { + let res = sqlx::query_as!( + UserPasswordLookup, + r#" + SELECT up.user_password_id + , up.hashed_password + , up.version + , up.upgraded_from_id + , up.created_at + FROM user_passwords up + WHERE up.user_id = $1 + ORDER BY up.created_at DESC + LIMIT 1 + "#, + Uuid::from(user.id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + let id = Ulid::from(res.user_password_id); + + let version = res.version.try_into().map_err(|e| { + DatabaseInconsistencyError::on("user_passwords") + .column("version") + .row(id) + .source(e) + })?; + + let upgraded_from_id = res.upgraded_from_id.map(Ulid::from); + let created_at = res.created_at; + let hashed_password = res.hashed_password; + + Ok(Some(Password { + id, + hashed_password, + version, + upgraded_from_id, + created_at, + })) + } + + #[tracing::instrument( + name = "db.user_password.add", + skip_all, + fields( + db.statement, + %user.id, + %user.username, + user_password.id, + user_password.version = version, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + user: &User, + version: u16, + hashed_password: String, + upgraded_from: Option<&Password>, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("user_password.id", tracing::field::display(id)); + + let upgraded_from_id = upgraded_from.map(|p| p.id); + + sqlx::query!( + r#" + INSERT INTO user_passwords + (user_password_id, user_id, hashed_password, version, upgraded_from_id, created_at) + VALUES ($1, $2, $3, $4, $5, $6) + "#, + Uuid::from(id), + Uuid::from(user.id), + hashed_password, + i32::from(version), + upgraded_from_id.map(Uuid::from), + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(Password { + id, + hashed_password, + version, + upgraded_from_id, + created_at, + }) + } +} diff --git a/crates/storage-pg/src/user/session.rs b/crates/storage-pg/src/user/session.rs new file mode 100644 index 00000000..d216c067 --- /dev/null +++ b/crates/storage-pg/src/user/session.rs @@ -0,0 +1,375 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{Authentication, BrowserSession, Password, UpstreamOAuthLink, User}; +use mas_storage::{user::BrowserSessionRepository, Clock, Page, Pagination}; +use rand::RngCore; +use sqlx::{PgConnection, QueryBuilder}; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{ + pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, + LookupResultExt, +}; + +pub struct PgBrowserSessionRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgBrowserSessionRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +#[derive(sqlx::FromRow)] +struct SessionLookup { + user_session_id: Uuid, + user_session_created_at: DateTime, + user_session_finished_at: Option>, + user_id: Uuid, + user_username: String, + user_primary_user_email_id: Option, + last_authentication_id: Option, + last_authd_at: Option>, +} + +impl TryFrom for BrowserSession { + type Error = DatabaseInconsistencyError; + + fn try_from(value: SessionLookup) -> Result { + let id = Ulid::from(value.user_id); + let user = User { + id, + username: value.user_username, + sub: id.to_string(), + primary_user_email_id: value.user_primary_user_email_id.map(Into::into), + }; + + let last_authentication = match (value.last_authentication_id, value.last_authd_at) { + (Some(id), Some(created_at)) => Some(Authentication { + id: id.into(), + created_at, + }), + (None, None) => None, + _ => { + return Err(DatabaseInconsistencyError::on( + "user_session_authentications", + )) + } + }; + + Ok(BrowserSession { + id: value.user_session_id.into(), + user, + created_at: value.user_session_created_at, + finished_at: value.user_session_finished_at, + last_authentication, + }) + } +} + +#[async_trait] +impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.browser_session.lookup", + skip_all, + fields( + db.statement, + user_session.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + SessionLookup, + r#" + SELECT s.user_session_id + , s.created_at AS "user_session_created_at" + , s.finished_at AS "user_session_finished_at" + , u.user_id + , u.username AS "user_username" + , u.primary_user_email_id AS "user_primary_user_email_id" + , a.user_session_authentication_id AS "last_authentication_id?" + , a.created_at AS "last_authd_at?" + FROM user_sessions s + INNER JOIN users u + USING (user_id) + LEFT JOIN user_session_authentications a + USING (user_session_id) + WHERE s.user_session_id = $1 + ORDER BY a.created_at DESC + LIMIT 1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.browser_session.add", + skip_all, + fields( + db.statement, + %user.id, + user_session.id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + user: &User, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("user_session.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO user_sessions (user_session_id, user_id, created_at) + VALUES ($1, $2, $3) + "#, + Uuid::from(id), + Uuid::from(user.id), + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + let session = BrowserSession { + id, + // XXX + user: user.clone(), + created_at, + finished_at: None, + last_authentication: None, + }; + + Ok(session) + } + + #[tracing::instrument( + name = "db.browser_session.finish", + skip_all, + fields( + db.statement, + %user_session.id, + ), + err, + )] + async fn finish( + &mut self, + clock: &Clock, + mut user_session: BrowserSession, + ) -> Result { + let finished_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE user_sessions + SET finished_at = $1 + WHERE user_session_id = $2 + "#, + finished_at, + Uuid::from(user_session.id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + user_session.finished_at = Some(finished_at); + + DatabaseError::ensure_affected_rows(&res, 1)?; + + Ok(user_session) + } + + #[tracing::instrument( + name = "db.browser_session.list_active_paginated", + skip_all, + fields( + db.statement, + %user.id, + ), + err, + )] + async fn list_active_paginated( + &mut self, + user: &User, + pagination: Pagination, + ) -> Result, Self::Error> { + // TODO: ordering of last authentication is wrong + let mut query = QueryBuilder::new( + r#" + SELECT DISTINCT ON (s.user_session_id) + s.user_session_id, + u.user_id, + u.username, + s.created_at, + a.user_session_authentication_id AS "last_authentication_id", + a.created_at AS "last_authd_at", + FROM user_sessions s + INNER JOIN users u + USING (user_id) + LEFT JOIN user_session_authentications a + USING (user_session_id) + "#, + ); + + query + .push(" WHERE s.finished_at IS NULL AND s.user_id = ") + .push_bind(Uuid::from(user.id)) + .generate_pagination("s.user_session_id", pagination); + + let edges: Vec = query + .build_query_as() + .traced() + .fetch_all(&mut *self.conn) + .await?; + + let page = pagination + .process(edges) + .try_map(BrowserSession::try_from)?; + Ok(page) + } + + #[tracing::instrument( + name = "db.browser_session.count_active", + skip_all, + fields( + db.statement, + %user.id, + ), + err, + )] + async fn count_active(&mut self, user: &User) -> Result { + let res = sqlx::query_scalar!( + r#" + SELECT COUNT(*) as "count!" + FROM user_sessions s + WHERE s.user_id = $1 AND s.finished_at IS NULL + "#, + Uuid::from(user.id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await?; + + res.try_into().map_err(DatabaseError::to_invalid_operation) + } + + #[tracing::instrument( + name = "db.browser_session.authenticate_with_password", + skip_all, + fields( + db.statement, + %user_session.id, + %user_password.id, + user_session_authentication.id, + ), + err, + )] + async fn authenticate_with_password( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + mut user_session: BrowserSession, + user_password: &Password, + ) -> Result { + let _user_password = user_password; + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record( + "user_session_authentication.id", + tracing::field::display(id), + ); + + sqlx::query!( + r#" + INSERT INTO user_session_authentications + (user_session_authentication_id, user_session_id, created_at) + VALUES ($1, $2, $3) + "#, + Uuid::from(id), + Uuid::from(user_session.id), + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + user_session.last_authentication = Some(Authentication { id, created_at }); + + Ok(user_session) + } + + #[tracing::instrument( + name = "db.browser_session.authenticate_with_upstream", + skip_all, + fields( + db.statement, + %user_session.id, + %upstream_oauth_link.id, + user_session_authentication.id, + ), + err, + )] + async fn authenticate_with_upstream( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + mut user_session: BrowserSession, + upstream_oauth_link: &UpstreamOAuthLink, + ) -> Result { + let _upstream_oauth_link = upstream_oauth_link; + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record( + "user_session_authentication.id", + tracing::field::display(id), + ); + + sqlx::query!( + r#" + INSERT INTO user_session_authentications + (user_session_authentication_id, user_session_id, created_at) + VALUES ($1, $2, $3) + "#, + Uuid::from(id), + Uuid::from(user_session.id), + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + user_session.last_authentication = Some(Authentication { id, created_at }); + + Ok(user_session) + } +} diff --git a/crates/storage/src/user/tests.rs b/crates/storage-pg/src/user/tests.rs similarity index 98% rename from crates/storage/src/user/tests.rs rename to crates/storage-pg/src/user/tests.rs index fca35ce0..f0a071b0 100644 --- a/crates/storage/src/user/tests.rs +++ b/crates/storage-pg/src/user/tests.rs @@ -13,14 +13,15 @@ // limitations under the License. use chrono::Duration; +use mas_storage::{ + user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, + Clock, Repository, +}; use rand::SeedableRng; use rand_chacha::ChaChaRng; use sqlx::PgPool; -use crate::{ - user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, - Clock, PgRepository, Repository, -}; +use crate::PgRepository; /// Test the user repository, by adding and looking up a user #[sqlx::test(migrator = "crate::MIGRATOR")] @@ -88,7 +89,7 @@ async fn test_user_email_repo(pool: PgPool) { // The user email should not exist yet assert!(repo .user_email() - .find(&user, &EMAIL) + .find(&user, EMAIL) .await .unwrap() .is_none()); @@ -109,7 +110,7 @@ async fn test_user_email_repo(pool: PgPool) { assert!(repo .user_email() - .find(&user, &EMAIL) + .find(&user, EMAIL) .await .unwrap() .is_some()); @@ -179,7 +180,7 @@ async fn test_user_email_repo(pool: PgPool) { // Reload the user_email let user_email = repo .user_email() - .find(&user, &EMAIL) + .find(&user, EMAIL) .await .unwrap() .expect("user email was not found"); diff --git a/crates/storage/Cargo.toml b/crates/storage/Cargo.toml index fb6c0fdc..97089e95 100644 --- a/crates/storage/Cargo.toml +++ b/crates/storage/Cargo.toml @@ -7,18 +7,12 @@ license = "Apache-2.0" [dependencies] async-trait = "0.1.60" -sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "offline", "json", "uuid"] } -chrono = { version = "0.4.23", features = ["serde"] } -serde = { version = "1.0.152", features = ["derive"] } -serde_json = "1.0.91" +chrono = "0.4.23" thiserror = "1.0.38" -tracing = "0.1.37" rand = "0.8.5" -rand_chacha = "0.3.1" -url = { version = "2.3.1", features = ["serde"] } -uuid = "1.2.2" -ulid = { version = "1.0.0", features = ["uuid", "serde"] } +url = "2.3.1" +ulid = "1.0.0" oauth2-types = { path = "../oauth2-types" } mas-data-model = { path = "../data-model" } diff --git a/crates/storage/src/compat/access_token.rs b/crates/storage/src/compat/access_token.rs index 86d2dd19..46ff6e3f 100644 --- a/crates/storage/src/compat/access_token.rs +++ b/crates/storage/src/compat/access_token.rs @@ -13,14 +13,12 @@ // limitations under the License. use async_trait::async_trait; -use chrono::{DateTime, Duration, Utc}; +use chrono::Duration; use mas_data_model::{CompatAccessToken, CompatSession}; use rand::RngCore; -use sqlx::PgConnection; use ulid::Ulid; -use uuid::Uuid; -use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt}; +use crate::Clock; #[async_trait] pub trait CompatAccessTokenRepository: Send + Sync { @@ -52,195 +50,3 @@ pub trait CompatAccessTokenRepository: Send + Sync { compat_access_token: CompatAccessToken, ) -> Result; } - -pub struct PgCompatAccessTokenRepository<'c> { - conn: &'c mut PgConnection, -} - -impl<'c> PgCompatAccessTokenRepository<'c> { - pub fn new(conn: &'c mut PgConnection) -> Self { - Self { conn } - } -} - -struct CompatAccessTokenLookup { - compat_access_token_id: Uuid, - access_token: String, - created_at: DateTime, - expires_at: Option>, - compat_session_id: Uuid, -} - -impl From for CompatAccessToken { - fn from(value: CompatAccessTokenLookup) -> Self { - Self { - id: value.compat_access_token_id.into(), - session_id: value.compat_session_id.into(), - token: value.access_token, - created_at: value.created_at, - expires_at: value.expires_at, - } - } -} - -#[async_trait] -impl<'c> CompatAccessTokenRepository for PgCompatAccessTokenRepository<'c> { - type Error = DatabaseError; - - #[tracing::instrument( - name = "db.compat_access_token.lookup", - skip_all, - fields( - db.statement, - compat_session.id = %id, - ), - err, - )] - async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { - let res = sqlx::query_as!( - CompatAccessTokenLookup, - r#" - SELECT compat_access_token_id - , access_token - , created_at - , expires_at - , compat_session_id - - FROM compat_access_tokens - - WHERE compat_access_token_id = $1 - "#, - Uuid::from(id), - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.into())) - } - - #[tracing::instrument( - name = "db.compat_access_token.find_by_token", - skip_all, - fields( - db.statement, - ), - err, - )] - async fn find_by_token( - &mut self, - access_token: &str, - ) -> Result, Self::Error> { - let res = sqlx::query_as!( - CompatAccessTokenLookup, - r#" - SELECT compat_access_token_id - , access_token - , created_at - , expires_at - , compat_session_id - - FROM compat_access_tokens - - WHERE access_token = $1 - "#, - access_token, - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.into())) - } - - #[tracing::instrument( - name = "db.compat_access_token.add", - skip_all, - fields( - db.statement, - compat_access_token.id, - %compat_session.id, - user.id = %compat_session.user_id, - ), - err, - )] - async fn add( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - compat_session: &CompatSession, - token: String, - expires_after: Option, - ) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - tracing::Span::current().record("compat_access_token.id", tracing::field::display(id)); - - let expires_at = expires_after.map(|expires_after| created_at + expires_after); - - sqlx::query!( - r#" - INSERT INTO compat_access_tokens - (compat_access_token_id, compat_session_id, access_token, created_at, expires_at) - VALUES ($1, $2, $3, $4, $5) - "#, - Uuid::from(id), - Uuid::from(compat_session.id), - token, - created_at, - expires_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - Ok(CompatAccessToken { - id, - session_id: compat_session.id, - token, - created_at, - expires_at, - }) - } - - #[tracing::instrument( - name = "db.compat_access_token.expire", - skip_all, - fields( - db.statement, - %compat_access_token.id, - compat_session.id = %compat_access_token.session_id, - ), - err, - )] - async fn expire( - &mut self, - clock: &Clock, - mut compat_access_token: CompatAccessToken, - ) -> Result { - let expires_at = clock.now(); - let res = sqlx::query!( - r#" - UPDATE compat_access_tokens - SET expires_at = $2 - WHERE compat_access_token_id = $1 - "#, - Uuid::from(compat_access_token.id), - expires_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1)?; - - compat_access_token.expires_at = Some(expires_at); - Ok(compat_access_token) - } -} diff --git a/crates/storage/src/compat/mod.rs b/crates/storage/src/compat/mod.rs index c37081b8..634c04a7 100644 --- a/crates/storage/src/compat/mod.rs +++ b/crates/storage/src/compat/mod.rs @@ -18,301 +18,6 @@ mod session; mod sso_login; pub use self::{ - access_token::{CompatAccessTokenRepository, PgCompatAccessTokenRepository}, - refresh_token::{CompatRefreshTokenRepository, PgCompatRefreshTokenRepository}, - session::{CompatSessionRepository, PgCompatSessionRepository}, - sso_login::{CompatSsoLoginRepository, PgCompatSsoLoginRepository}, + access_token::CompatAccessTokenRepository, refresh_token::CompatRefreshTokenRepository, + session::CompatSessionRepository, sso_login::CompatSsoLoginRepository, }; - -#[cfg(test)] -mod tests { - use chrono::Duration; - use mas_data_model::Device; - use rand::SeedableRng; - use rand_chacha::ChaChaRng; - use sqlx::PgPool; - - use super::*; - use crate::{user::UserRepository, Clock, PgRepository, Repository}; - - #[sqlx::test(migrator = "crate::MIGRATOR")] - async fn test_session_repository(pool: PgPool) { - const FIRST_TOKEN: &str = "first_access_token"; - const SECOND_TOKEN: &str = "second_access_token"; - let mut rng = ChaChaRng::seed_from_u64(42); - let clock = Clock::mock(); - let mut repo = PgRepository::from_pool(&pool).await.unwrap(); - - // Create a user - let user = repo - .user() - .add(&mut rng, &clock, "john".to_owned()) - .await - .unwrap(); - - // Start a compat session for that user - let device = Device::generate(&mut rng); - let device_str = device.as_str().to_owned(); - let session = repo - .compat_session() - .add(&mut rng, &clock, &user, device) - .await - .unwrap(); - assert_eq!(session.user_id, user.id); - assert_eq!(session.device.as_str(), device_str); - assert!(session.is_valid()); - assert!(!session.is_finished()); - - // Lookup the session and check it didn't change - let session_lookup = repo - .compat_session() - .lookup(session.id) - .await - .unwrap() - .expect("compat session not found"); - assert_eq!(session_lookup.id, session.id); - assert_eq!(session_lookup.user_id, user.id); - assert_eq!(session_lookup.device.as_str(), device_str); - assert!(session_lookup.is_valid()); - assert!(!session_lookup.is_finished()); - - // Finish the session - let session = repo.compat_session().finish(&clock, session).await.unwrap(); - assert!(!session.is_valid()); - assert!(session.is_finished()); - - // Reload the session and check again - let session_lookup = repo - .compat_session() - .lookup(session.id) - .await - .unwrap() - .expect("compat session not found"); - assert!(!session_lookup.is_valid()); - assert!(session_lookup.is_finished()); - } - - #[sqlx::test(migrator = "crate::MIGRATOR")] - async fn test_access_token_repository(pool: PgPool) { - const FIRST_TOKEN: &str = "first_access_token"; - const SECOND_TOKEN: &str = "second_access_token"; - let mut rng = ChaChaRng::seed_from_u64(42); - let clock = Clock::mock(); - let mut repo = PgRepository::from_pool(&pool).await.unwrap(); - - // Create a user - let user = repo - .user() - .add(&mut rng, &clock, "john".to_owned()) - .await - .unwrap(); - - // Start a compat session for that user - let device = Device::generate(&mut rng); - let session = repo - .compat_session() - .add(&mut rng, &clock, &user, device) - .await - .unwrap(); - - // Add an access token to that session - let token = repo - .compat_access_token() - .add( - &mut rng, - &clock, - &session, - FIRST_TOKEN.to_owned(), - Some(Duration::minutes(1)), - ) - .await - .unwrap(); - assert_eq!(token.session_id, session.id); - assert_eq!(token.token, FIRST_TOKEN); - - // Commit the txn and grab a new transaction, to test a conflict - repo.save().await.unwrap(); - - { - let mut repo = PgRepository::from_pool(&pool).await.unwrap(); - // Adding the same token a second time should conflict - assert!(repo - .compat_access_token() - .add( - &mut rng, - &clock, - &session, - FIRST_TOKEN.to_owned(), - Some(Duration::minutes(1)), - ) - .await - .is_err()); - repo.cancel().await.unwrap(); - } - - // Grab a new repo - let mut repo = PgRepository::from_pool(&pool).await.unwrap(); - - // Looking up via ID works - let token_lookup = repo - .compat_access_token() - .lookup(token.id) - .await - .unwrap() - .expect("compat access token not found"); - assert_eq!(token.id, token_lookup.id); - assert_eq!(token_lookup.session_id, session.id); - - // Looking up via the token value works - let token_lookup = repo - .compat_access_token() - .find_by_token(FIRST_TOKEN) - .await - .unwrap() - .expect("compat access token not found"); - assert_eq!(token.id, token_lookup.id); - assert_eq!(token_lookup.session_id, session.id); - - // Token is currently valid - assert!(token.is_valid(clock.now())); - - clock.advance(Duration::minutes(1)); - // Token should have expired - assert!(!token.is_valid(clock.now())); - - // Add a second access token, this time without expiration - let token = repo - .compat_access_token() - .add(&mut rng, &clock, &session, SECOND_TOKEN.to_owned(), None) - .await - .unwrap(); - assert_eq!(token.session_id, session.id); - assert_eq!(token.token, SECOND_TOKEN); - - // Token is currently valid - assert!(token.is_valid(clock.now())); - - // Make it expire - repo.compat_access_token() - .expire(&clock, token) - .await - .unwrap(); - - // Reload it - let token = repo - .compat_access_token() - .find_by_token(SECOND_TOKEN) - .await - .unwrap() - .expect("compat access token not found"); - - // Token is not valid anymore - assert!(!token.is_valid(clock.now())); - - repo.save().await.unwrap(); - } - - #[sqlx::test(migrator = "crate::MIGRATOR")] - async fn test_refresh_token_repository(pool: PgPool) { - const ACCESS_TOKEN: &str = "access_token"; - const REFRESH_TOKEN: &str = "refresh_token"; - let mut rng = ChaChaRng::seed_from_u64(42); - let clock = Clock::mock(); - let mut repo = PgRepository::from_pool(&pool).await.unwrap(); - - // Create a user - let user = repo - .user() - .add(&mut rng, &clock, "john".to_owned()) - .await - .unwrap(); - - // Start a compat session for that user - let device = Device::generate(&mut rng); - let session = repo - .compat_session() - .add(&mut rng, &clock, &user, device) - .await - .unwrap(); - - // Add an access token to that session - let access_token = repo - .compat_access_token() - .add(&mut rng, &clock, &session, ACCESS_TOKEN.to_owned(), None) - .await - .unwrap(); - - let refresh_token = repo - .compat_refresh_token() - .add( - &mut rng, - &clock, - &session, - &access_token, - REFRESH_TOKEN.to_owned(), - ) - .await - .unwrap(); - assert_eq!(refresh_token.session_id, session.id); - assert_eq!(refresh_token.access_token_id, access_token.id); - assert_eq!(refresh_token.token, REFRESH_TOKEN); - assert!(refresh_token.is_valid()); - assert!(!refresh_token.is_consumed()); - - // Look it up by ID and check everything matches - let refresh_token_lookup = repo - .compat_refresh_token() - .lookup(refresh_token.id) - .await - .unwrap() - .expect("refresh token not found"); - assert_eq!(refresh_token_lookup.id, refresh_token.id); - assert_eq!(refresh_token_lookup.session_id, session.id); - assert_eq!(refresh_token_lookup.access_token_id, access_token.id); - assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN); - assert!(refresh_token_lookup.is_valid()); - assert!(!refresh_token_lookup.is_consumed()); - - // Look it up by token and check everything matches - let refresh_token_lookup = repo - .compat_refresh_token() - .find_by_token(REFRESH_TOKEN) - .await - .unwrap() - .expect("refresh token not found"); - assert_eq!(refresh_token_lookup.id, refresh_token.id); - assert_eq!(refresh_token_lookup.session_id, session.id); - assert_eq!(refresh_token_lookup.access_token_id, access_token.id); - assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN); - assert!(refresh_token_lookup.is_valid()); - assert!(!refresh_token_lookup.is_consumed()); - - // Consume it - let refresh_token = repo - .compat_refresh_token() - .consume(&clock, refresh_token) - .await - .unwrap(); - assert!(!refresh_token.is_valid()); - assert!(refresh_token.is_consumed()); - - // Reload it and check again - let refresh_token_lookup = repo - .compat_refresh_token() - .find_by_token(REFRESH_TOKEN) - .await - .unwrap() - .expect("refresh token not found"); - assert!(!refresh_token_lookup.is_valid()); - assert!(refresh_token_lookup.is_consumed()); - - // Consuming it again should not work - assert!(repo - .compat_refresh_token() - .consume(&clock, refresh_token) - .await - .is_err()); - - repo.save().await.unwrap(); - } -} diff --git a/crates/storage/src/compat/refresh_token.rs b/crates/storage/src/compat/refresh_token.rs index 30054622..7a1057ff 100644 --- a/crates/storage/src/compat/refresh_token.rs +++ b/crates/storage/src/compat/refresh_token.rs @@ -13,16 +13,11 @@ // limitations under the License. use async_trait::async_trait; -use chrono::{DateTime, Utc}; -use mas_data_model::{ - CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession, -}; +use mas_data_model::{CompatAccessToken, CompatRefreshToken, CompatSession}; use rand::RngCore; -use sqlx::PgConnection; use ulid::Ulid; -use uuid::Uuid; -use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt}; +use crate::Clock; #[async_trait] pub trait CompatRefreshTokenRepository: Send + Sync { @@ -54,207 +49,3 @@ pub trait CompatRefreshTokenRepository: Send + Sync { compat_refresh_token: CompatRefreshToken, ) -> Result; } - -pub struct PgCompatRefreshTokenRepository<'c> { - conn: &'c mut PgConnection, -} - -impl<'c> PgCompatRefreshTokenRepository<'c> { - pub fn new(conn: &'c mut PgConnection) -> Self { - Self { conn } - } -} - -struct CompatRefreshTokenLookup { - compat_refresh_token_id: Uuid, - refresh_token: String, - created_at: DateTime, - consumed_at: Option>, - compat_access_token_id: Uuid, - compat_session_id: Uuid, -} - -impl From for CompatRefreshToken { - fn from(value: CompatRefreshTokenLookup) -> Self { - let state = match value.consumed_at { - Some(consumed_at) => CompatRefreshTokenState::Consumed { consumed_at }, - None => CompatRefreshTokenState::Valid, - }; - - Self { - id: value.compat_refresh_token_id.into(), - state, - session_id: value.compat_session_id.into(), - token: value.refresh_token, - created_at: value.created_at, - access_token_id: value.compat_access_token_id.into(), - } - } -} - -#[async_trait] -impl<'c> CompatRefreshTokenRepository for PgCompatRefreshTokenRepository<'c> { - type Error = DatabaseError; - - #[tracing::instrument( - name = "db.compat_refresh_token.lookup", - skip_all, - fields( - db.statement, - compat_refresh_token.id = %id, - ), - err, - )] - async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { - let res = sqlx::query_as!( - CompatRefreshTokenLookup, - r#" - SELECT compat_refresh_token_id - , refresh_token - , created_at - , consumed_at - , compat_session_id - , compat_access_token_id - - FROM compat_refresh_tokens - - WHERE compat_refresh_token_id = $1 - "#, - Uuid::from(id), - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.into())) - } - - #[tracing::instrument( - name = "db.compat_refresh_token.find_by_token", - skip_all, - fields( - db.statement, - ), - err, - )] - async fn find_by_token( - &mut self, - refresh_token: &str, - ) -> Result, Self::Error> { - let res = sqlx::query_as!( - CompatRefreshTokenLookup, - r#" - SELECT compat_refresh_token_id - , refresh_token - , created_at - , consumed_at - , compat_session_id - , compat_access_token_id - - FROM compat_refresh_tokens - - WHERE refresh_token = $1 - "#, - refresh_token, - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.into())) - } - - #[tracing::instrument( - name = "db.compat_refresh_token.add", - skip_all, - fields( - db.statement, - compat_refresh_token.id, - %compat_session.id, - user.id = %compat_session.user_id, - ), - err, - )] - async fn add( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - compat_session: &CompatSession, - compat_access_token: &CompatAccessToken, - token: String, - ) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO compat_refresh_tokens - (compat_refresh_token_id, compat_session_id, - compat_access_token_id, refresh_token, created_at) - VALUES ($1, $2, $3, $4, $5) - "#, - Uuid::from(id), - Uuid::from(compat_session.id), - Uuid::from(compat_access_token.id), - token, - created_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - Ok(CompatRefreshToken { - id, - state: CompatRefreshTokenState::default(), - session_id: compat_session.id, - access_token_id: compat_access_token.id, - token, - created_at, - }) - } - - #[tracing::instrument( - name = "db.compat_refresh_token.consume", - skip_all, - fields( - db.statement, - %compat_refresh_token.id, - compat_session.id = %compat_refresh_token.session_id, - ), - err, - )] - async fn consume( - &mut self, - clock: &Clock, - compat_refresh_token: CompatRefreshToken, - ) -> Result { - let consumed_at = clock.now(); - let res = sqlx::query!( - r#" - UPDATE compat_refresh_tokens - SET consumed_at = $2 - WHERE compat_refresh_token_id = $1 - "#, - Uuid::from(compat_refresh_token.id), - consumed_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1)?; - - let compat_refresh_token = compat_refresh_token - .consume(consumed_at) - .map_err(DatabaseError::to_invalid_operation)?; - - Ok(compat_refresh_token) - } -} diff --git a/crates/storage/src/compat/session.rs b/crates/storage/src/compat/session.rs index 3068be73..34bc6838 100644 --- a/crates/storage/src/compat/session.rs +++ b/crates/storage/src/compat/session.rs @@ -13,16 +13,11 @@ // limitations under the License. use async_trait::async_trait; -use chrono::{DateTime, Utc}; -use mas_data_model::{CompatSession, CompatSessionState, Device, User}; +use mas_data_model::{CompatSession, Device, User}; use rand::RngCore; -use sqlx::PgConnection; use ulid::Ulid; -use uuid::Uuid; -use crate::{ - tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, -}; +use crate::Clock; #[async_trait] pub trait CompatSessionRepository: Send + Sync { @@ -47,174 +42,3 @@ pub trait CompatSessionRepository: Send + Sync { compat_session: CompatSession, ) -> Result; } - -pub struct PgCompatSessionRepository<'c> { - conn: &'c mut PgConnection, -} - -impl<'c> PgCompatSessionRepository<'c> { - pub fn new(conn: &'c mut PgConnection) -> Self { - Self { conn } - } -} - -struct CompatSessionLookup { - compat_session_id: Uuid, - device_id: String, - user_id: Uuid, - created_at: DateTime, - finished_at: Option>, -} - -impl TryFrom for CompatSession { - type Error = DatabaseInconsistencyError; - - fn try_from(value: CompatSessionLookup) -> Result { - let id = value.compat_session_id.into(); - let device = Device::try_from(value.device_id).map_err(|e| { - DatabaseInconsistencyError::on("compat_sessions") - .column("device_id") - .row(id) - .source(e) - })?; - - let state = match value.finished_at { - None => CompatSessionState::Valid, - Some(finished_at) => CompatSessionState::Finished { finished_at }, - }; - - let session = CompatSession { - id, - state, - user_id: value.user_id.into(), - device, - created_at: value.created_at, - }; - - Ok(session) - } -} - -#[async_trait] -impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> { - type Error = DatabaseError; - - #[tracing::instrument( - name = "db.compat_session.lookup", - skip_all, - fields( - db.statement, - compat_session.id = %id, - ), - err, - )] - async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { - let res = sqlx::query_as!( - CompatSessionLookup, - r#" - SELECT compat_session_id - , device_id - , user_id - , created_at - , finished_at - FROM compat_sessions - WHERE compat_session_id = $1 - "#, - Uuid::from(id), - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.try_into()?)) - } - - #[tracing::instrument( - name = "db.compat_session.add", - skip_all, - fields( - db.statement, - compat_session.id, - %user.id, - %user.username, - compat_session.device.id = device.as_str(), - ), - err, - )] - async fn add( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - user: &User, - device: Device, - ) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - tracing::Span::current().record("compat_session.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at) - VALUES ($1, $2, $3, $4) - "#, - Uuid::from(id), - Uuid::from(user.id), - device.as_str(), - created_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - Ok(CompatSession { - id, - state: CompatSessionState::default(), - user_id: user.id, - device, - created_at, - }) - } - - #[tracing::instrument( - name = "db.compat_session.finish", - skip_all, - fields( - db.statement, - %compat_session.id, - user.id = %compat_session.user_id, - compat_session.device.id = compat_session.device.as_str(), - ), - err, - )] - async fn finish( - &mut self, - clock: &Clock, - compat_session: CompatSession, - ) -> Result { - let finished_at = clock.now(); - - let res = sqlx::query!( - r#" - UPDATE compat_sessions cs - SET finished_at = $2 - WHERE compat_session_id = $1 - "#, - Uuid::from(compat_session.id), - finished_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1)?; - - let compat_session = compat_session - .finish(finished_at) - .map_err(DatabaseError::to_invalid_operation)?; - - Ok(compat_session) - } -} diff --git a/crates/storage/src/compat/sso_login.rs b/crates/storage/src/compat/sso_login.rs index 76cf1ede..348e0ac5 100644 --- a/crates/storage/src/compat/sso_login.rs +++ b/crates/storage/src/compat/sso_login.rs @@ -13,19 +13,12 @@ // limitations under the License. use async_trait::async_trait; -use chrono::{DateTime, Utc}; -use mas_data_model::{CompatSession, CompatSsoLogin, CompatSsoLoginState, User}; +use mas_data_model::{CompatSession, CompatSsoLogin, User}; use rand::RngCore; -use sqlx::{PgConnection, QueryBuilder}; use ulid::Ulid; use url::Url; -use uuid::Uuid; -use crate::{ - pagination::{Page, QueryBuilderExt}, - tracing::ExecuteExt, - Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Pagination, -}; +use crate::{pagination::Page, Clock, Pagination}; #[async_trait] pub trait CompatSsoLoginRepository: Send + Sync { @@ -71,317 +64,3 @@ pub trait CompatSsoLoginRepository: Send + Sync { pagination: Pagination, ) -> Result, Self::Error>; } - -pub struct PgCompatSsoLoginRepository<'c> { - conn: &'c mut PgConnection, -} - -impl<'c> PgCompatSsoLoginRepository<'c> { - pub fn new(conn: &'c mut PgConnection) -> Self { - Self { conn } - } -} - -#[derive(sqlx::FromRow)] -struct CompatSsoLoginLookup { - compat_sso_login_id: Uuid, - login_token: String, - redirect_uri: String, - created_at: DateTime, - fulfilled_at: Option>, - exchanged_at: Option>, - compat_session_id: Option, -} - -impl TryFrom for CompatSsoLogin { - type Error = DatabaseInconsistencyError; - - fn try_from(res: CompatSsoLoginLookup) -> Result { - let id = res.compat_sso_login_id.into(); - let redirect_uri = Url::parse(&res.redirect_uri).map_err(|e| { - DatabaseInconsistencyError::on("compat_sso_logins") - .column("redirect_uri") - .row(id) - .source(e) - })?; - - let state = match (res.fulfilled_at, res.exchanged_at, res.compat_session_id) { - (None, None, None) => CompatSsoLoginState::Pending, - (Some(fulfilled_at), None, Some(session_id)) => CompatSsoLoginState::Fulfilled { - fulfilled_at, - session_id: session_id.into(), - }, - (Some(fulfilled_at), Some(exchanged_at), Some(session_id)) => { - CompatSsoLoginState::Exchanged { - fulfilled_at, - exchanged_at, - session_id: session_id.into(), - } - } - _ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)), - }; - - Ok(CompatSsoLogin { - id, - login_token: res.login_token, - redirect_uri, - created_at: res.created_at, - state, - }) - } -} - -#[async_trait] -impl<'c> CompatSsoLoginRepository for PgCompatSsoLoginRepository<'c> { - type Error = DatabaseError; - - #[tracing::instrument( - name = "db.compat_sso_login.lookup", - skip_all, - fields( - db.statement, - compat_sso_login.id = %id, - ), - err, - )] - async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { - let res = sqlx::query_as!( - CompatSsoLoginLookup, - r#" - SELECT compat_sso_login_id - , login_token - , redirect_uri - , created_at - , fulfilled_at - , exchanged_at - , compat_session_id - - FROM compat_sso_logins - WHERE compat_sso_login_id = $1 - "#, - Uuid::from(id), - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.try_into()?)) - } - - #[tracing::instrument( - name = "db.compat_sso_login.find_by_token", - skip_all, - fields( - db.statement, - ), - err, - )] - async fn find_by_token( - &mut self, - login_token: &str, - ) -> Result, Self::Error> { - let res = sqlx::query_as!( - CompatSsoLoginLookup, - r#" - SELECT compat_sso_login_id - , login_token - , redirect_uri - , created_at - , fulfilled_at - , exchanged_at - , compat_session_id - - FROM compat_sso_logins - WHERE login_token = $1 - "#, - login_token, - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.try_into()?)) - } - - #[tracing::instrument( - name = "db.compat_sso_login.add", - skip_all, - fields( - db.statement, - compat_sso_login.id, - compat_sso_login.redirect_uri = %redirect_uri, - ), - err, - )] - async fn add( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - login_token: String, - redirect_uri: Url, - ) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO compat_sso_logins - (compat_sso_login_id, login_token, redirect_uri, created_at) - VALUES ($1, $2, $3, $4) - "#, - Uuid::from(id), - &login_token, - redirect_uri.as_str(), - created_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - Ok(CompatSsoLogin { - id, - login_token, - redirect_uri, - created_at, - state: CompatSsoLoginState::default(), - }) - } - - #[tracing::instrument( - name = "db.compat_sso_login.fulfill", - skip_all, - fields( - db.statement, - %compat_sso_login.id, - %compat_session.id, - compat_session.device.id = compat_session.device.as_str(), - user.id = %compat_session.user_id, - ), - err, - )] - async fn fulfill( - &mut self, - clock: &Clock, - compat_sso_login: CompatSsoLogin, - compat_session: &CompatSession, - ) -> Result { - let fulfilled_at = clock.now(); - let compat_sso_login = compat_sso_login - .fulfill(fulfilled_at, compat_session) - .map_err(DatabaseError::to_invalid_operation)?; - - let res = sqlx::query!( - r#" - UPDATE compat_sso_logins - SET - compat_session_id = $2, - fulfilled_at = $3 - WHERE - compat_sso_login_id = $1 - "#, - Uuid::from(compat_sso_login.id), - Uuid::from(compat_session.id), - fulfilled_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1)?; - - Ok(compat_sso_login) - } - - #[tracing::instrument( - name = "db.compat_sso_login.exchange", - skip_all, - fields( - db.statement, - %compat_sso_login.id, - ), - err, - )] - async fn exchange( - &mut self, - clock: &Clock, - compat_sso_login: CompatSsoLogin, - ) -> Result { - let exchanged_at = clock.now(); - let compat_sso_login = compat_sso_login - .exchange(exchanged_at) - .map_err(DatabaseError::to_invalid_operation)?; - - let res = sqlx::query!( - r#" - UPDATE compat_sso_logins - SET - exchanged_at = $2 - WHERE - compat_sso_login_id = $1 - "#, - Uuid::from(compat_sso_login.id), - exchanged_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1)?; - - Ok(compat_sso_login) - } - - #[tracing::instrument( - name = "db.compat_sso_login.list_paginated", - skip_all, - fields( - db.statement, - %user.id, - %user.username, - ), - err - )] - async fn list_paginated( - &mut self, - user: &User, - pagination: Pagination, - ) -> Result, Self::Error> { - let mut query = QueryBuilder::new( - r#" - SELECT cl.compat_sso_login_id - , cl.login_token - , cl.redirect_uri - , cl.created_at - , cl.fulfilled_at - , cl.exchanged_at - , cl.compat_session_id - - FROM compat_sso_logins cl - INNER JOIN compat_sessions ON compat_session_id - "#, - ); - - query - .push(" WHERE user_id = ") - .push_bind(Uuid::from(user.id)) - .generate_pagination("cl.compat_sso_login_id", pagination); - - let edges: Vec = query - .build_query_as() - .traced() - .fetch_all(&mut *self.conn) - .await?; - - let page = pagination - .process(edges) - .try_map(CompatSsoLogin::try_from)?; - Ok(page) - } -} diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index e92d37fe..a65c806c 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright 2021, 2022 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -29,150 +29,19 @@ )] use chrono::{DateTime, Utc}; -use pagination::InvalidPagination; -use sqlx::{migrate::Migrator, postgres::PgQueryResult}; -use thiserror::Error; -use ulid::Ulid; - -trait LookupResultExt { - type Output; - - /// Transform a [`Result`] from a sqlx query to transform "not found" errors - /// into [`None`] - fn to_option(self) -> Result, sqlx::Error>; -} - -impl LookupResultExt for Result { - type Output = T; - - fn to_option(self) -> Result, sqlx::Error> { - match self { - Ok(v) => Ok(Some(v)), - Err(sqlx::Error::RowNotFound) => Ok(None), - Err(e) => Err(e), - } - } -} - -/// Generic error when interacting with the database -#[derive(Debug, Error)] -#[error(transparent)] -pub enum DatabaseError { - /// An error which came from the database itself - Driver(#[from] sqlx::Error), - - /// An error which occured while converting the data from the database - Inconsistency(#[from] DatabaseInconsistencyError), - - /// An error which occured while generating the paginated query - Pagination(#[from] InvalidPagination), - - /// An error which happened because the requested database operation is - /// invalid - #[error("Invalid database operation")] - InvalidOperation { - #[source] - source: Option>, - }, - - /// An error which happens when an operation affects not enough or too many - /// rows - #[error("Expected {expected} rows to be affected, but {actual} rows were affected")] - RowsAffected { expected: u64, actual: u64 }, -} - -impl DatabaseError { - pub(crate) fn ensure_affected_rows( - result: &PgQueryResult, - expected: u64, - ) -> Result<(), DatabaseError> { - let actual = result.rows_affected(); - if actual == expected { - Ok(()) - } else { - Err(DatabaseError::RowsAffected { expected, actual }) - } - } - - pub(crate) fn to_invalid_operation(e: E) -> Self { - Self::InvalidOperation { - source: Some(Box::new(e)), - } - } - - pub(crate) const fn invalid_operation() -> Self { - Self::InvalidOperation { source: None } - } -} - -#[derive(Debug, Error)] -pub struct DatabaseInconsistencyError { - table: &'static str, - column: Option<&'static str>, - row: Option, - - #[source] - source: Option>, -} - -impl std::fmt::Display for DatabaseInconsistencyError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Database inconsistency on table {}", self.table)?; - if let Some(column) = self.column { - write!(f, " column {column}")?; - } - if let Some(row) = self.row { - write!(f, " row {row}")?; - } - - Ok(()) - } -} - -impl DatabaseInconsistencyError { - #[must_use] - pub(crate) const fn on(table: &'static str) -> Self { - Self { - table, - column: None, - row: None, - source: None, - } - } - - #[must_use] - pub(crate) const fn column(mut self, column: &'static str) -> Self { - self.column = Some(column); - self - } - - #[must_use] - pub(crate) const fn row(mut self, row: Ulid) -> Self { - self.row = Some(row); - self - } - - pub(crate) fn source( - mut self, - source: E, - ) -> Self { - self.source = Some(Box::new(source)); - self - } -} #[derive(Debug, Clone, Default)] pub struct Clock { _private: (), - #[cfg(test)] + // #[cfg(test)] mock: Option>, } impl Clock { #[must_use] pub fn now(&self) -> DateTime { - #[cfg(test)] + // #[cfg(test)] if let Some(timestamp) = &self.mock { let timestamp = timestamp.load(std::sync::atomic::Ordering::Relaxed); return chrono::TimeZone::timestamp_opt(&Utc, timestamp, 0).unwrap(); @@ -183,13 +52,14 @@ impl Clock { Utc::now() } - #[cfg(test)] + // #[cfg(test)] + #[must_use] pub fn mock() -> Self { use std::sync::{atomic::AtomicI64, Arc}; use chrono::TimeZone; - let datetime = Utc.with_ymd_and_hms(2022, 01, 16, 14, 40, 0).unwrap(); + let datetime = Utc.with_ymd_and_hms(2022, 1, 16, 14, 40, 0).unwrap(); let timestamp = datetime.timestamp(); Self { @@ -198,7 +68,7 @@ impl Clock { } } - #[cfg(test)] + // #[cfg(test)] pub fn advance(&self, duration: chrono::Duration) { let timestamp = self .mock @@ -247,16 +117,12 @@ mod tests { pub mod compat; pub mod oauth2; -pub(crate) mod pagination; +pub mod pagination; pub(crate) mod repository; -pub(crate) mod tracing; pub mod upstream_oauth2; pub mod user; pub use self::{ - pagination::Pagination, - repository::{PgRepository, Repository}, + pagination::{Page, Pagination}, + repository::Repository, }; - -/// Embedded migrations, allowing them to run on startup -pub static MIGRATOR: Migrator = sqlx::migrate!(); diff --git a/crates/storage/src/oauth2/access_token.rs b/crates/storage/src/oauth2/access_token.rs index db10ed72..a0406e44 100644 --- a/crates/storage/src/oauth2/access_token.rs +++ b/crates/storage/src/oauth2/access_token.rs @@ -13,14 +13,12 @@ // limitations under the License. use async_trait::async_trait; -use chrono::{DateTime, Duration, Utc}; -use mas_data_model::{AccessToken, AccessTokenState, Session}; +use chrono::Duration; +use mas_data_model::{AccessToken, Session}; use rand::RngCore; -use sqlx::PgConnection; use ulid::Ulid; -use uuid::Uuid; -use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt}; +use crate::Clock; #[async_trait] pub trait OAuth2AccessTokenRepository: Send + Sync { @@ -55,202 +53,3 @@ pub trait OAuth2AccessTokenRepository: Send + Sync { /// Cleanup expired access tokens async fn cleanup_expired(&mut self, clock: &Clock) -> Result; } - -pub struct PgOAuth2AccessTokenRepository<'c> { - conn: &'c mut PgConnection, -} - -impl<'c> PgOAuth2AccessTokenRepository<'c> { - pub fn new(conn: &'c mut PgConnection) -> Self { - Self { conn } - } -} - -struct OAuth2AccessTokenLookup { - oauth2_access_token_id: Uuid, - oauth2_session_id: Uuid, - access_token: String, - created_at: DateTime, - expires_at: DateTime, - revoked_at: Option>, -} - -impl From for AccessToken { - fn from(value: OAuth2AccessTokenLookup) -> Self { - let state = match value.revoked_at { - None => AccessTokenState::Valid, - Some(revoked_at) => AccessTokenState::Revoked { revoked_at }, - }; - - Self { - id: value.oauth2_access_token_id.into(), - state, - session_id: value.oauth2_session_id.into(), - access_token: value.access_token, - created_at: value.created_at, - expires_at: value.expires_at, - } - } -} - -#[async_trait] -impl<'c> OAuth2AccessTokenRepository for PgOAuth2AccessTokenRepository<'c> { - type Error = DatabaseError; - - async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { - let res = sqlx::query_as!( - OAuth2AccessTokenLookup, - r#" - SELECT oauth2_access_token_id - , access_token - , created_at - , expires_at - , revoked_at - , oauth2_session_id - - FROM oauth2_access_tokens - - WHERE oauth2_access_token_id = $1 - "#, - Uuid::from(id), - ) - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.into())) - } - - #[tracing::instrument( - name = "db.oauth2_access_token.find_by_token", - skip_all, - fields( - db.statement, - ), - err, - )] - async fn find_by_token( - &mut self, - access_token: &str, - ) -> Result, Self::Error> { - let res = sqlx::query_as!( - OAuth2AccessTokenLookup, - r#" - SELECT oauth2_access_token_id - , access_token - , created_at - , expires_at - , revoked_at - , oauth2_session_id - - FROM oauth2_access_tokens - - WHERE access_token = $1 - "#, - access_token, - ) - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.into())) - } - - #[tracing::instrument( - name = "db.oauth2_access_token.add", - skip_all, - fields( - db.statement, - %session.id, - user_session.id = %session.user_session_id, - client.id = %session.client_id, - access_token.id, - ), - err, - )] - async fn add( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - session: &Session, - access_token: String, - expires_after: Duration, - ) -> Result { - let created_at = clock.now(); - let expires_at = created_at + expires_after; - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - - tracing::Span::current().record("access_token.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO oauth2_access_tokens - (oauth2_access_token_id, oauth2_session_id, access_token, created_at, expires_at) - VALUES - ($1, $2, $3, $4, $5) - "#, - Uuid::from(id), - Uuid::from(session.id), - &access_token, - created_at, - expires_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - Ok(AccessToken { - id, - state: AccessTokenState::default(), - access_token, - session_id: session.id, - created_at, - expires_at, - }) - } - - async fn revoke( - &mut self, - clock: &Clock, - access_token: AccessToken, - ) -> Result { - let revoked_at = clock.now(); - let res = sqlx::query!( - r#" - UPDATE oauth2_access_tokens - SET revoked_at = $2 - WHERE oauth2_access_token_id = $1 - "#, - Uuid::from(access_token.id), - revoked_at, - ) - .execute(&mut *self.conn) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1)?; - - access_token - .revoke(revoked_at) - .map_err(DatabaseError::to_invalid_operation) - } - - async fn cleanup_expired(&mut self, clock: &Clock) -> Result { - // Cleanup token which expired more than 15 minutes ago - let threshold = clock.now() - Duration::minutes(15); - let res = sqlx::query!( - r#" - DELETE FROM oauth2_access_tokens - WHERE expires_at < $1 - "#, - threshold, - ) - .execute(&mut *self.conn) - .await?; - - Ok(res.rows_affected().try_into().unwrap_or(usize::MAX)) - } -} diff --git a/crates/storage/src/oauth2/authorization_grant.rs b/crates/storage/src/oauth2/authorization_grant.rs index c57c5dcd..ce1a716f 100644 --- a/crates/storage/src/oauth2/authorization_grant.rs +++ b/crates/storage/src/oauth2/authorization_grant.rs @@ -1,4 +1,4 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,21 +15,13 @@ use std::num::NonZeroU32; use async_trait::async_trait; -use chrono::{DateTime, Utc}; -use mas_data_model::{ - AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Pkce, Session, -}; -use mas_iana::oauth::PkceCodeChallengeMethod; +use mas_data_model::{AuthorizationCode, AuthorizationGrant, Client, Session}; use oauth2_types::{requests::ResponseMode, scope::Scope}; use rand::RngCore; -use sqlx::PgConnection; use ulid::Ulid; use url::Url; -use uuid::Uuid; -use crate::{ - tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, -}; +use crate::Clock; #[async_trait] pub trait OAuth2AuthorizationGrantRepository: Send + Sync { @@ -75,482 +67,3 @@ pub trait OAuth2AuthorizationGrantRepository: Send + Sync { authorization_grant: AuthorizationGrant, ) -> Result; } - -pub struct PgOAuth2AuthorizationGrantRepository<'c> { - conn: &'c mut PgConnection, -} - -impl<'c> PgOAuth2AuthorizationGrantRepository<'c> { - pub fn new(conn: &'c mut PgConnection) -> Self { - Self { conn } - } -} - -#[allow(clippy::struct_excessive_bools)] -struct GrantLookup { - oauth2_authorization_grant_id: Uuid, - created_at: DateTime, - cancelled_at: Option>, - fulfilled_at: Option>, - exchanged_at: Option>, - scope: String, - state: Option, - nonce: Option, - redirect_uri: String, - response_mode: String, - max_age: Option, - response_type_code: bool, - response_type_id_token: bool, - authorization_code: Option, - code_challenge: Option, - code_challenge_method: Option, - requires_consent: bool, - oauth2_client_id: Uuid, - oauth2_session_id: Option, -} - -impl TryFrom for AuthorizationGrant { - type Error = DatabaseInconsistencyError; - - #[allow(clippy::too_many_lines)] - fn try_from(value: GrantLookup) -> Result { - let id = value.oauth2_authorization_grant_id.into(); - let scope: Scope = value.scope.parse().map_err(|e| { - DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("scope") - .row(id) - .source(e) - })?; - - let stage = match ( - value.fulfilled_at, - value.exchanged_at, - value.cancelled_at, - value.oauth2_session_id, - ) { - (None, None, None, None) => AuthorizationGrantStage::Pending, - (Some(fulfilled_at), None, None, Some(session_id)) => { - AuthorizationGrantStage::Fulfilled { - session_id: session_id.into(), - fulfilled_at, - } - } - (Some(fulfilled_at), Some(exchanged_at), None, Some(session_id)) => { - AuthorizationGrantStage::Exchanged { - session_id: session_id.into(), - fulfilled_at, - exchanged_at, - } - } - (None, None, Some(cancelled_at), None) => { - AuthorizationGrantStage::Cancelled { cancelled_at } - } - _ => { - return Err( - DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("stage") - .row(id), - ); - } - }; - - let pkce = match (value.code_challenge, value.code_challenge_method) { - (Some(challenge), Some(challenge_method)) if challenge_method == "plain" => { - Some(Pkce { - challenge_method: PkceCodeChallengeMethod::Plain, - challenge, - }) - } - (Some(challenge), Some(challenge_method)) if challenge_method == "S256" => Some(Pkce { - challenge_method: PkceCodeChallengeMethod::S256, - challenge, - }), - (None, None) => None, - _ => { - return Err( - DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("code_challenge_method") - .row(id), - ); - } - }; - - let code: Option = - match (value.response_type_code, value.authorization_code, pkce) { - (false, None, None) => None, - (true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }), - _ => { - return Err( - DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("authorization_code") - .row(id), - ); - } - }; - - let redirect_uri = value.redirect_uri.parse().map_err(|e| { - DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("redirect_uri") - .row(id) - .source(e) - })?; - - let response_mode = value.response_mode.parse().map_err(|e| { - DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("response_mode") - .row(id) - .source(e) - })?; - - let max_age = value - .max_age - .map(u32::try_from) - .transpose() - .map_err(|e| { - DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("max_age") - .row(id) - .source(e) - })? - .map(NonZeroU32::try_from) - .transpose() - .map_err(|e| { - DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("max_age") - .row(id) - .source(e) - })?; - - Ok(AuthorizationGrant { - id, - stage, - client_id: value.oauth2_client_id.into(), - code, - scope, - state: value.state, - nonce: value.nonce, - max_age, - response_mode, - redirect_uri, - created_at: value.created_at, - response_type_id_token: value.response_type_id_token, - requires_consent: value.requires_consent, - }) - } -} - -#[async_trait] -impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantRepository<'c> { - type Error = DatabaseError; - - #[tracing::instrument( - name = "db.oauth2_authorization_grant.add", - skip_all, - fields( - db.statement, - grant.id, - grant.scope = %scope, - %client.id, - ), - err, - )] - async fn add( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - client: &Client, - redirect_uri: Url, - scope: Scope, - code: Option, - state: Option, - nonce: Option, - max_age: Option, - response_mode: ResponseMode, - response_type_id_token: bool, - requires_consent: bool, - ) -> Result { - let code_challenge = code - .as_ref() - .and_then(|c| c.pkce.as_ref()) - .map(|p| &p.challenge); - let code_challenge_method = code - .as_ref() - .and_then(|c| c.pkce.as_ref()) - .map(|p| p.challenge_method.to_string()); - // TODO: this conversion is a bit ugly - let max_age_i32 = max_age.map(|x| i32::try_from(u32::from(x)).unwrap_or(i32::MAX)); - let code_str = code.as_ref().map(|c| &c.code); - - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - tracing::Span::current().record("grant.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO oauth2_authorization_grants ( - oauth2_authorization_grant_id, - oauth2_client_id, - redirect_uri, - scope, - state, - nonce, - max_age, - response_mode, - code_challenge, - code_challenge_method, - response_type_code, - response_type_id_token, - authorization_code, - requires_consent, - created_at - ) - VALUES - ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15) - "#, - Uuid::from(id), - Uuid::from(client.id), - redirect_uri.to_string(), - scope.to_string(), - state, - nonce, - max_age_i32, - response_mode.to_string(), - code_challenge, - code_challenge_method, - code.is_some(), - response_type_id_token, - code_str, - requires_consent, - created_at, - ) - .execute(&mut *self.conn) - .await?; - - Ok(AuthorizationGrant { - id, - stage: AuthorizationGrantStage::Pending, - code, - redirect_uri, - client_id: client.id, - scope, - state, - nonce, - max_age, - response_mode, - created_at, - response_type_id_token, - requires_consent, - }) - } - - #[tracing::instrument( - name = "db.oauth2_authorization_grant.lookup", - skip_all, - fields( - db.statement, - grant.id = %id, - ), - err, - )] - async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { - let res = sqlx::query_as!( - GrantLookup, - r#" - SELECT oauth2_authorization_grant_id - , created_at - , cancelled_at - , fulfilled_at - , exchanged_at - , scope - , state - , redirect_uri - , response_mode - , nonce - , max_age - , oauth2_client_id - , authorization_code - , response_type_code - , response_type_id_token - , code_challenge - , code_challenge_method - , requires_consent - , oauth2_session_id - FROM - oauth2_authorization_grants - - WHERE oauth2_authorization_grant_id = $1 - "#, - Uuid::from(id), - ) - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.try_into()?)) - } - - #[tracing::instrument( - name = "db.oauth2_authorization_grant.find_by_code", - skip_all, - fields( - db.statement, - ), - err, - )] - async fn find_by_code( - &mut self, - code: &str, - ) -> Result, Self::Error> { - let res = sqlx::query_as!( - GrantLookup, - r#" - SELECT oauth2_authorization_grant_id - , created_at - , cancelled_at - , fulfilled_at - , exchanged_at - , scope - , state - , redirect_uri - , response_mode - , nonce - , max_age - , oauth2_client_id - , authorization_code - , response_type_code - , response_type_id_token - , code_challenge - , code_challenge_method - , requires_consent - , oauth2_session_id - FROM - oauth2_authorization_grants - - WHERE authorization_code = $1 - "#, - code, - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.try_into()?)) - } - - #[tracing::instrument( - name = "db.oauth2_authorization_grant.fulfill", - skip_all, - fields( - db.statement, - %grant.id, - client.id = %grant.client_id, - %session.id, - user_session.id = %session.user_session_id, - ), - err, - )] - async fn fulfill( - &mut self, - clock: &Clock, - session: &Session, - grant: AuthorizationGrant, - ) -> Result { - let fulfilled_at = clock.now(); - let res = sqlx::query!( - r#" - UPDATE oauth2_authorization_grants - SET fulfilled_at = $2 - , oauth2_session_id = $3 - WHERE oauth2_authorization_grant_id = $1 - "#, - Uuid::from(grant.id), - fulfilled_at, - Uuid::from(session.id), - ) - .execute(&mut *self.conn) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1)?; - - // XXX: check affected rows & new methods - let grant = grant - .fulfill(fulfilled_at, session) - .map_err(DatabaseError::to_invalid_operation)?; - - Ok(grant) - } - - #[tracing::instrument( - name = "db.oauth2_authorization_grant.exchange", - skip_all, - fields( - db.statement, - %grant.id, - client.id = %grant.client_id, - ), - err, - )] - async fn exchange( - &mut self, - clock: &Clock, - grant: AuthorizationGrant, - ) -> Result { - let exchanged_at = clock.now(); - let res = sqlx::query!( - r#" - UPDATE oauth2_authorization_grants - SET exchanged_at = $2 - WHERE oauth2_authorization_grant_id = $1 - "#, - Uuid::from(grant.id), - exchanged_at, - ) - .execute(&mut *self.conn) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1)?; - - let grant = grant - .exchange(exchanged_at) - .map_err(DatabaseError::to_invalid_operation)?; - - Ok(grant) - } - - #[tracing::instrument( - name = "db.oauth2_authorization_grant.give_consent", - skip_all, - fields( - db.statement, - %grant.id, - client.id = %grant.client_id, - ), - err, - )] - async fn give_consent( - &mut self, - mut grant: AuthorizationGrant, - ) -> Result { - sqlx::query!( - r#" - UPDATE oauth2_authorization_grants AS og - SET - requires_consent = 'f' - WHERE - og.oauth2_authorization_grant_id = $1 - "#, - Uuid::from(grant.id), - ) - .execute(&mut *self.conn) - .await?; - - grant.requires_consent = false; - - Ok(grant) - } -} diff --git a/crates/storage/src/oauth2/client.rs b/crates/storage/src/oauth2/client.rs index 756017b8..093369a4 100644 --- a/crates/storage/src/oauth2/client.rs +++ b/crates/storage/src/oauth2/client.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,33 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{ - collections::{BTreeMap, BTreeSet}, - str::FromStr, - string::ToString, -}; +use std::collections::{BTreeMap, BTreeSet}; use async_trait::async_trait; -use mas_data_model::{Client, JwksOrJwksUri, User}; -use mas_iana::{ - jose::JsonWebSignatureAlg, - oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod}, -}; +use mas_data_model::{Client, User}; +use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; use mas_jose::jwk::PublicJsonWebKeySet; -use oauth2_types::{ - requests::GrantType, - scope::{Scope, ScopeToken}, -}; +use oauth2_types::{requests::GrantType, scope::Scope}; use rand::{Rng, RngCore}; -use sqlx::PgConnection; -use tracing::{info_span, Instrument}; use ulid::Ulid; use url::Url; -use uuid::Uuid; -use crate::{ - tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, -}; +use crate::Clock; #[async_trait] pub trait OAuth2ClientRepository: Send + Sync { @@ -107,708 +92,3 @@ pub trait OAuth2ClientRepository: Send + Sync { scope: &Scope, ) -> Result<(), Self::Error>; } - -pub struct PgOAuth2ClientRepository<'c> { - conn: &'c mut PgConnection, -} - -impl<'c> PgOAuth2ClientRepository<'c> { - pub fn new(conn: &'c mut PgConnection) -> Self { - Self { conn } - } -} - -// XXX: response_types & contacts -#[derive(Debug)] -struct OAuth2ClientLookup { - oauth2_client_id: Uuid, - encrypted_client_secret: Option, - redirect_uris: Vec, - // response_types: Vec, - grant_type_authorization_code: bool, - grant_type_refresh_token: bool, - // contacts: Vec, - client_name: Option, - logo_uri: Option, - client_uri: Option, - policy_uri: Option, - tos_uri: Option, - jwks_uri: Option, - jwks: Option, - id_token_signed_response_alg: Option, - userinfo_signed_response_alg: Option, - token_endpoint_auth_method: Option, - token_endpoint_auth_signing_alg: Option, - initiate_login_uri: Option, -} - -impl TryInto for OAuth2ClientLookup { - type Error = DatabaseInconsistencyError; - - #[allow(clippy::too_many_lines)] // TODO: refactor some of the field parsing - fn try_into(self) -> Result { - let id = Ulid::from(self.oauth2_client_id); - - let redirect_uris: Result, _> = - self.redirect_uris.iter().map(|s| s.parse()).collect(); - let redirect_uris = redirect_uris.map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("redirect_uris") - .row(id) - .source(e) - })?; - - let response_types = vec![ - OAuthAuthorizationEndpointResponseType::Code, - OAuthAuthorizationEndpointResponseType::IdToken, - OAuthAuthorizationEndpointResponseType::None, - ]; - /* XXX - let response_types: Result, _> = - self.response_types.iter().map(|s| s.parse()).collect(); - let response_types = response_types.map_err(|source| ClientFetchError::ParseField { - field: "response_types", - source, - })?; - */ - - let mut grant_types = Vec::new(); - if self.grant_type_authorization_code { - grant_types.push(GrantType::AuthorizationCode); - } - if self.grant_type_refresh_token { - grant_types.push(GrantType::RefreshToken); - } - - let logo_uri = self.logo_uri.map(|s| s.parse()).transpose().map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("logo_uri") - .row(id) - .source(e) - })?; - - let client_uri = self - .client_uri - .map(|s| s.parse()) - .transpose() - .map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("client_uri") - .row(id) - .source(e) - })?; - - let policy_uri = self - .policy_uri - .map(|s| s.parse()) - .transpose() - .map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("policy_uri") - .row(id) - .source(e) - })?; - - let tos_uri = self.tos_uri.map(|s| s.parse()).transpose().map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("tos_uri") - .row(id) - .source(e) - })?; - - let id_token_signed_response_alg = self - .id_token_signed_response_alg - .map(|s| s.parse()) - .transpose() - .map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("id_token_signed_response_alg") - .row(id) - .source(e) - })?; - - let userinfo_signed_response_alg = self - .userinfo_signed_response_alg - .map(|s| s.parse()) - .transpose() - .map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("userinfo_signed_response_alg") - .row(id) - .source(e) - })?; - - let token_endpoint_auth_method = self - .token_endpoint_auth_method - .map(|s| s.parse()) - .transpose() - .map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("token_endpoint_auth_method") - .row(id) - .source(e) - })?; - - let token_endpoint_auth_signing_alg = self - .token_endpoint_auth_signing_alg - .map(|s| s.parse()) - .transpose() - .map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("token_endpoint_auth_signing_alg") - .row(id) - .source(e) - })?; - - let initiate_login_uri = self - .initiate_login_uri - .map(|s| s.parse()) - .transpose() - .map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("initiate_login_uri") - .row(id) - .source(e) - })?; - - let jwks = match (self.jwks, self.jwks_uri) { - (None, None) => None, - (Some(jwks), None) => { - let jwks = serde_json::from_value(jwks).map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("jwks") - .row(id) - .source(e) - })?; - Some(JwksOrJwksUri::Jwks(jwks)) - } - (None, Some(jwks_uri)) => { - let jwks_uri = jwks_uri.parse().map_err(|e| { - DatabaseInconsistencyError::on("oauth2_clients") - .column("jwks_uri") - .row(id) - .source(e) - })?; - - Some(JwksOrJwksUri::JwksUri(jwks_uri)) - } - _ => { - return Err(DatabaseInconsistencyError::on("oauth2_clients") - .column("jwks(_uri)") - .row(id)) - } - }; - - Ok(Client { - id, - client_id: id.to_string(), - encrypted_client_secret: self.encrypted_client_secret, - redirect_uris, - response_types, - grant_types, - // contacts: self.contacts, - contacts: vec![], - client_name: self.client_name, - logo_uri, - client_uri, - policy_uri, - tos_uri, - jwks, - id_token_signed_response_alg, - userinfo_signed_response_alg, - token_endpoint_auth_method, - token_endpoint_auth_signing_alg, - initiate_login_uri, - }) - } -} - -#[async_trait] -impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> { - type Error = DatabaseError; - - #[tracing::instrument( - name = "db.oauth2_client.lookup", - skip_all, - fields( - db.statement, - oauth2_client.id = %id, - ), - err, - )] - async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { - let res = sqlx::query_as!( - OAuth2ClientLookup, - r#" - SELECT oauth2_client_id - , encrypted_client_secret - , ARRAY( - SELECT redirect_uri - FROM oauth2_client_redirect_uris r - WHERE r.oauth2_client_id = c.oauth2_client_id - ) AS "redirect_uris!" - , grant_type_authorization_code - , grant_type_refresh_token - , client_name - , logo_uri - , client_uri - , policy_uri - , tos_uri - , jwks_uri - , jwks - , id_token_signed_response_alg - , userinfo_signed_response_alg - , token_endpoint_auth_method - , token_endpoint_auth_signing_alg - , initiate_login_uri - FROM oauth2_clients c - - WHERE oauth2_client_id = $1 - "#, - Uuid::from(id), - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.try_into()?)) - } - - #[tracing::instrument( - name = "db.oauth2_client.load_batch", - skip_all, - fields( - db.statement, - ), - err, - )] - async fn load_batch( - &mut self, - ids: BTreeSet, - ) -> Result, Self::Error> { - let ids: Vec = ids.into_iter().map(Uuid::from).collect(); - let res = sqlx::query_as!( - OAuth2ClientLookup, - r#" - SELECT oauth2_client_id - , encrypted_client_secret - , ARRAY( - SELECT redirect_uri - FROM oauth2_client_redirect_uris r - WHERE r.oauth2_client_id = c.oauth2_client_id - ) AS "redirect_uris!" - , grant_type_authorization_code - , grant_type_refresh_token - , client_name - , logo_uri - , client_uri - , policy_uri - , tos_uri - , jwks_uri - , jwks - , id_token_signed_response_alg - , userinfo_signed_response_alg - , token_endpoint_auth_method - , token_endpoint_auth_signing_alg - , initiate_login_uri - FROM oauth2_clients c - - WHERE oauth2_client_id = ANY($1::uuid[]) - "#, - &ids, - ) - .traced() - .fetch_all(&mut *self.conn) - .await?; - - res.into_iter() - .map(|r| { - r.try_into() - .map(|c: Client| (c.id, c)) - .map_err(DatabaseError::from) - }) - .collect() - } - - #[tracing::instrument( - name = "db.oauth2_client.add", - skip_all, - fields( - db.statement, - client.id, - client.name = client_name - ), - err, - )] - #[allow(clippy::too_many_lines)] - async fn add( - &mut self, - mut rng: &mut (dyn RngCore + Send), - clock: &Clock, - redirect_uris: Vec, - encrypted_client_secret: Option, - grant_types: Vec, - contacts: Vec, - client_name: Option, - logo_uri: Option, - client_uri: Option, - policy_uri: Option, - tos_uri: Option, - jwks_uri: Option, - jwks: Option, - id_token_signed_response_alg: Option, - userinfo_signed_response_alg: Option, - token_endpoint_auth_method: Option, - token_endpoint_auth_signing_alg: Option, - initiate_login_uri: Option, - ) -> Result { - let now = clock.now(); - let id = Ulid::from_datetime_with_source(now.into(), rng); - tracing::Span::current().record("client.id", tracing::field::display(id)); - - let jwks_json = jwks - .as_ref() - .map(serde_json::to_value) - .transpose() - .map_err(DatabaseError::to_invalid_operation)?; - - sqlx::query!( - r#" - INSERT INTO oauth2_clients - ( oauth2_client_id - , encrypted_client_secret - , grant_type_authorization_code - , grant_type_refresh_token - , client_name - , logo_uri - , client_uri - , policy_uri - , tos_uri - , jwks_uri - , jwks - , id_token_signed_response_alg - , userinfo_signed_response_alg - , token_endpoint_auth_method - , token_endpoint_auth_signing_alg - , initiate_login_uri - ) - VALUES - ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16) - "#, - Uuid::from(id), - encrypted_client_secret, - grant_types.contains(&GrantType::AuthorizationCode), - grant_types.contains(&GrantType::RefreshToken), - client_name, - logo_uri.as_ref().map(Url::as_str), - client_uri.as_ref().map(Url::as_str), - policy_uri.as_ref().map(Url::as_str), - tos_uri.as_ref().map(Url::as_str), - jwks_uri.as_ref().map(Url::as_str), - jwks_json, - id_token_signed_response_alg - .as_ref() - .map(ToString::to_string), - userinfo_signed_response_alg - .as_ref() - .map(ToString::to_string), - token_endpoint_auth_method.as_ref().map(ToString::to_string), - token_endpoint_auth_signing_alg - .as_ref() - .map(ToString::to_string), - initiate_login_uri.as_ref().map(Url::as_str), - ) - .traced() - .execute(&mut *self.conn) - .await?; - - { - let span = info_span!( - "db.oauth2_client.add.redirect_uris", - db.statement = tracing::field::Empty, - client.id = %id, - ); - - let (uri_ids, redirect_uris): (Vec, Vec) = redirect_uris - .iter() - .map(|uri| { - ( - Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)), - uri.as_str().to_owned(), - ) - }) - .unzip(); - - sqlx::query!( - r#" - INSERT INTO oauth2_client_redirect_uris - (oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri) - SELECT id, $2, redirect_uri - FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri) - "#, - &uri_ids, - Uuid::from(id), - &redirect_uris, - ) - .record(&span) - .execute(&mut *self.conn) - .instrument(span) - .await?; - } - - let jwks = match (jwks, jwks_uri) { - (None, None) => None, - (Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)), - (None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)), - _ => return Err(DatabaseError::invalid_operation()), - }; - - Ok(Client { - id, - client_id: id.to_string(), - encrypted_client_secret, - redirect_uris, - response_types: vec![ - OAuthAuthorizationEndpointResponseType::Code, - OAuthAuthorizationEndpointResponseType::IdToken, - OAuthAuthorizationEndpointResponseType::None, - ], - grant_types, - contacts, - client_name, - logo_uri, - client_uri, - policy_uri, - tos_uri, - jwks, - id_token_signed_response_alg, - userinfo_signed_response_alg, - token_endpoint_auth_method, - token_endpoint_auth_signing_alg, - initiate_login_uri, - }) - } - - #[tracing::instrument( - name = "db.oauth2_client.add_from_config", - skip_all, - fields( - db.statement, - client.id = %client_id, - ), - err, - )] - async fn add_from_config( - &mut self, - mut rng: impl Rng + Send, - clock: &Clock, - client_id: Ulid, - client_auth_method: OAuthClientAuthenticationMethod, - encrypted_client_secret: Option, - jwks: Option, - jwks_uri: Option, - redirect_uris: Vec, - ) -> Result { - let jwks_json = jwks - .as_ref() - .map(serde_json::to_value) - .transpose() - .map_err(DatabaseError::to_invalid_operation)?; - - let client_auth_method = client_auth_method.to_string(); - - sqlx::query!( - r#" - INSERT INTO oauth2_clients - ( oauth2_client_id - , encrypted_client_secret - , grant_type_authorization_code - , grant_type_refresh_token - , token_endpoint_auth_method - , jwks - , jwks_uri - ) - VALUES - ($1, $2, $3, $4, $5, $6, $7) - ON CONFLICT (oauth2_client_id) - DO - UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret - , grant_type_authorization_code = EXCLUDED.grant_type_authorization_code - , grant_type_refresh_token = EXCLUDED.grant_type_refresh_token - , token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method - , jwks = EXCLUDED.jwks - , jwks_uri = EXCLUDED.jwks_uri - "#, - Uuid::from(client_id), - encrypted_client_secret, - true, - true, - client_auth_method, - jwks_json, - jwks_uri.as_ref().map(Url::as_str), - ) - .traced() - .execute(&mut *self.conn) - .await?; - - { - let span = info_span!( - "db.oauth2_client.add_from_config.redirect_uris", - client.id = %client_id, - db.statement = tracing::field::Empty, - ); - - let now = clock.now(); - let (ids, redirect_uris): (Vec, Vec) = redirect_uris - .iter() - .map(|uri| { - ( - Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)), - uri.as_str().to_owned(), - ) - }) - .unzip(); - - sqlx::query!( - r#" - INSERT INTO oauth2_client_redirect_uris - (oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri) - SELECT id, $2, redirect_uri - FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri) - "#, - &ids, - Uuid::from(client_id), - &redirect_uris, - ) - .record(&span) - .execute(&mut *self.conn) - .instrument(span) - .await?; - } - - let jwks = match (jwks, jwks_uri) { - (None, None) => None, - (Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)), - (None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)), - _ => return Err(DatabaseError::invalid_operation()), - }; - - Ok(Client { - id: client_id, - client_id: client_id.to_string(), - encrypted_client_secret, - redirect_uris, - response_types: vec![ - OAuthAuthorizationEndpointResponseType::Code, - OAuthAuthorizationEndpointResponseType::IdToken, - OAuthAuthorizationEndpointResponseType::None, - ], - grant_types: Vec::new(), - contacts: Vec::new(), - client_name: None, - logo_uri: None, - client_uri: None, - policy_uri: None, - tos_uri: None, - jwks, - id_token_signed_response_alg: None, - userinfo_signed_response_alg: None, - token_endpoint_auth_method: None, - token_endpoint_auth_signing_alg: None, - initiate_login_uri: None, - }) - } - - #[tracing::instrument( - name = "db.oauth2_client.get_consent_for_user", - skip_all, - fields( - db.statement, - %user.id, - %client.id, - ), - err, - )] - async fn get_consent_for_user( - &mut self, - client: &Client, - user: &User, - ) -> Result { - let scope_tokens: Vec = sqlx::query_scalar!( - r#" - SELECT scope_token - FROM oauth2_consents - WHERE user_id = $1 AND oauth2_client_id = $2 - "#, - Uuid::from(user.id), - Uuid::from(client.id), - ) - .fetch_all(&mut *self.conn) - .await?; - - let scope: Result = scope_tokens - .into_iter() - .map(|s| ScopeToken::from_str(&s)) - .collect(); - - let scope = scope.map_err(|e| { - DatabaseInconsistencyError::on("oauth2_consents") - .column("scope_token") - .source(e) - })?; - - Ok(scope) - } - - #[tracing::instrument( - skip_all, - fields( - db.statement, - %user.id, - %client.id, - %scope, - ), - err, - )] - async fn give_consent_for_user( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - client: &Client, - user: &User, - scope: &Scope, - ) -> Result<(), Self::Error> { - let now = clock.now(); - let (tokens, ids): (Vec, Vec) = scope - .iter() - .map(|token| { - ( - token.to_string(), - Uuid::from(Ulid::from_datetime_with_source(now.into(), rng)), - ) - }) - .unzip(); - - sqlx::query!( - r#" - INSERT INTO oauth2_consents - (oauth2_consent_id, user_id, oauth2_client_id, scope_token, created_at) - SELECT id, $2, $3, scope_token, $5 FROM UNNEST($1::uuid[], $4::text[]) u(id, scope_token) - ON CONFLICT (user_id, oauth2_client_id, scope_token) DO UPDATE SET refreshed_at = $5 - "#, - &ids, - Uuid::from(user.id), - Uuid::from(client.id), - &tokens, - now, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - Ok(()) - } -} diff --git a/crates/storage/src/oauth2/mod.rs b/crates/storage/src/oauth2/mod.rs index 480c4515..eaa5e317 100644 --- a/crates/storage/src/oauth2/mod.rs +++ b/crates/storage/src/oauth2/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2021, 2022 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -19,11 +19,7 @@ mod refresh_token; mod session; pub use self::{ - access_token::{OAuth2AccessTokenRepository, PgOAuth2AccessTokenRepository}, - authorization_grant::{ - OAuth2AuthorizationGrantRepository, PgOAuth2AuthorizationGrantRepository, - }, - client::{OAuth2ClientRepository, PgOAuth2ClientRepository}, - refresh_token::{OAuth2RefreshTokenRepository, PgOAuth2RefreshTokenRepository}, - session::{OAuth2SessionRepository, PgOAuth2SessionRepository}, + access_token::OAuth2AccessTokenRepository, + authorization_grant::OAuth2AuthorizationGrantRepository, client::OAuth2ClientRepository, + refresh_token::OAuth2RefreshTokenRepository, session::OAuth2SessionRepository, }; diff --git a/crates/storage/src/oauth2/refresh_token.rs b/crates/storage/src/oauth2/refresh_token.rs index 5d3bb013..1e23634a 100644 --- a/crates/storage/src/oauth2/refresh_token.rs +++ b/crates/storage/src/oauth2/refresh_token.rs @@ -1,4 +1,4 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,14 +13,11 @@ // limitations under the License. use async_trait::async_trait; -use chrono::{DateTime, Utc}; -use mas_data_model::{AccessToken, RefreshToken, RefreshTokenState, Session}; +use mas_data_model::{AccessToken, RefreshToken, Session}; use rand::RngCore; -use sqlx::PgConnection; use ulid::Ulid; -use uuid::Uuid; -use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt}; +use crate::Clock; #[async_trait] pub trait OAuth2RefreshTokenRepository: Send + Sync { @@ -52,203 +49,3 @@ pub trait OAuth2RefreshTokenRepository: Send + Sync { refresh_token: RefreshToken, ) -> Result; } - -pub struct PgOAuth2RefreshTokenRepository<'c> { - conn: &'c mut PgConnection, -} - -impl<'c> PgOAuth2RefreshTokenRepository<'c> { - pub fn new(conn: &'c mut PgConnection) -> Self { - Self { conn } - } -} - -struct OAuth2RefreshTokenLookup { - oauth2_refresh_token_id: Uuid, - refresh_token: String, - created_at: DateTime, - consumed_at: Option>, - oauth2_access_token_id: Option, - oauth2_session_id: Uuid, -} - -impl From for RefreshToken { - fn from(value: OAuth2RefreshTokenLookup) -> Self { - let state = match value.consumed_at { - None => RefreshTokenState::Valid, - Some(consumed_at) => RefreshTokenState::Consumed { consumed_at }, - }; - - RefreshToken { - id: value.oauth2_refresh_token_id.into(), - state, - session_id: value.oauth2_session_id.into(), - refresh_token: value.refresh_token, - created_at: value.created_at, - access_token_id: value.oauth2_access_token_id.map(Ulid::from), - } - } -} - -#[async_trait] -impl<'c> OAuth2RefreshTokenRepository for PgOAuth2RefreshTokenRepository<'c> { - type Error = DatabaseError; - - #[tracing::instrument( - name = "db.oauth2_refresh_token.lookup", - skip_all, - fields( - db.statement, - refresh_token.id = %id, - ), - err, - )] - async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { - let res = sqlx::query_as!( - OAuth2RefreshTokenLookup, - r#" - SELECT oauth2_refresh_token_id - , refresh_token - , created_at - , consumed_at - , oauth2_access_token_id - , oauth2_session_id - FROM oauth2_refresh_tokens - - WHERE oauth2_refresh_token_id = $1 - "#, - Uuid::from(id), - ) - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.into())) - } - - #[tracing::instrument( - name = "db.oauth2_refresh_token.find_by_token", - skip_all, - fields( - db.statement, - ), - err, - )] - async fn find_by_token( - &mut self, - refresh_token: &str, - ) -> Result, Self::Error> { - let res = sqlx::query_as!( - OAuth2RefreshTokenLookup, - r#" - SELECT oauth2_refresh_token_id - , refresh_token - , created_at - , consumed_at - , oauth2_access_token_id - , oauth2_session_id - FROM oauth2_refresh_tokens - - WHERE refresh_token = $1 - "#, - refresh_token, - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.into())) - } - - #[tracing::instrument( - name = "db.oauth2_refresh_token.add", - skip_all, - fields( - db.statement, - %session.id, - user_session.id = %session.user_session_id, - client.id = %session.client_id, - refresh_token.id, - ), - err, - )] - async fn add( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - session: &Session, - access_token: &AccessToken, - refresh_token: String, - ) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - tracing::Span::current().record("refresh_token.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO oauth2_refresh_tokens - (oauth2_refresh_token_id, oauth2_session_id, oauth2_access_token_id, - refresh_token, created_at) - VALUES - ($1, $2, $3, $4, $5) - "#, - Uuid::from(id), - Uuid::from(session.id), - Uuid::from(access_token.id), - refresh_token, - created_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - Ok(RefreshToken { - id, - state: RefreshTokenState::default(), - session_id: session.id, - refresh_token, - access_token_id: Some(access_token.id), - created_at, - }) - } - - #[tracing::instrument( - name = "db.oauth2_refresh_token.consume", - skip_all, - fields( - db.statement, - %refresh_token.id, - session.id = %refresh_token.session_id, - ), - err, - )] - async fn consume( - &mut self, - clock: &Clock, - refresh_token: RefreshToken, - ) -> Result { - let consumed_at = clock.now(); - let res = sqlx::query!( - r#" - UPDATE oauth2_refresh_tokens - SET consumed_at = $2 - WHERE oauth2_refresh_token_id = $1 - "#, - Uuid::from(refresh_token.id), - consumed_at, - ) - .execute(&mut *self.conn) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1)?; - - refresh_token - .consume(consumed_at) - .map_err(DatabaseError::to_invalid_operation) - } -} diff --git a/crates/storage/src/oauth2/session.rs b/crates/storage/src/oauth2/session.rs index dc21fbcb..5e6498d8 100644 --- a/crates/storage/src/oauth2/session.rs +++ b/crates/storage/src/oauth2/session.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,18 +13,11 @@ // limitations under the License. use async_trait::async_trait; -use chrono::{DateTime, Utc}; -use mas_data_model::{AuthorizationGrant, BrowserSession, Session, SessionState, User}; +use mas_data_model::{AuthorizationGrant, BrowserSession, Session, User}; use rand::RngCore; -use sqlx::{PgConnection, QueryBuilder}; use ulid::Ulid; -use uuid::Uuid; -use crate::{ - pagination::{Page, QueryBuilderExt}, - tracing::ExecuteExt, - Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Pagination, -}; +use crate::{pagination::Page, Clock, Pagination}; #[async_trait] pub trait OAuth2SessionRepository: Send + Sync { @@ -48,224 +41,3 @@ pub trait OAuth2SessionRepository: Send + Sync { pagination: Pagination, ) -> Result, Self::Error>; } - -pub struct PgOAuth2SessionRepository<'c> { - conn: &'c mut PgConnection, -} - -impl<'c> PgOAuth2SessionRepository<'c> { - pub fn new(conn: &'c mut PgConnection) -> Self { - Self { conn } - } -} - -#[derive(sqlx::FromRow)] -struct OAuthSessionLookup { - oauth2_session_id: Uuid, - user_session_id: Uuid, - oauth2_client_id: Uuid, - scope: String, - #[allow(dead_code)] - created_at: DateTime, - finished_at: Option>, -} - -impl TryFrom for Session { - type Error = DatabaseInconsistencyError; - - fn try_from(value: OAuthSessionLookup) -> Result { - let id = Ulid::from(value.oauth2_session_id); - let scope = value.scope.parse().map_err(|e| { - DatabaseInconsistencyError::on("oauth2_sessions") - .column("scope") - .row(id) - .source(e) - })?; - - let state = match value.finished_at { - None => SessionState::Valid, - Some(finished_at) => SessionState::Finished { finished_at }, - }; - - Ok(Session { - id, - state, - created_at: value.created_at, - client_id: value.oauth2_client_id.into(), - user_session_id: value.user_session_id.into(), - scope, - }) - } -} - -#[async_trait] -impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { - type Error = DatabaseError; - - #[tracing::instrument( - name = "db.oauth2_session.lookup", - skip_all, - fields( - db.statement, - session.id = %id, - ), - err, - )] - async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { - let res = sqlx::query_as!( - OAuthSessionLookup, - r#" - SELECT oauth2_session_id - , user_session_id - , oauth2_client_id - , scope - , created_at - , finished_at - FROM oauth2_sessions - - WHERE oauth2_session_id = $1 - "#, - Uuid::from(id), - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(session) = res else { return Ok(None) }; - - Ok(Some(session.try_into()?)) - } - - #[tracing::instrument( - name = "db.oauth2_session.create_from_grant", - skip_all, - fields( - db.statement, - %user_session.id, - user.id = %user_session.user.id, - %grant.id, - client.id = %grant.client_id, - session.id, - session.scope = %grant.scope, - ), - err, - )] - async fn create_from_grant( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - grant: &AuthorizationGrant, - user_session: &BrowserSession, - ) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - tracing::Span::current().record("session.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO oauth2_sessions - ( oauth2_session_id - , user_session_id - , oauth2_client_id - , scope - , created_at - ) - VALUES ($1, $2, $3, $4, $5) - "#, - Uuid::from(id), - Uuid::from(user_session.id), - Uuid::from(grant.client_id), - grant.scope.to_string(), - created_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - Ok(Session { - id, - state: SessionState::Valid, - created_at, - user_session_id: user_session.id, - client_id: grant.client_id, - scope: grant.scope.clone(), - }) - } - - #[tracing::instrument( - name = "db.oauth2_session.finish", - skip_all, - fields( - db.statement, - %session.id, - %session.scope, - user_session.id = %session.user_session_id, - client.id = %session.client_id, - ), - err, - )] - async fn finish(&mut self, clock: &Clock, session: Session) -> Result { - let finished_at = clock.now(); - let res = sqlx::query!( - r#" - UPDATE oauth2_sessions - SET finished_at = $2 - WHERE oauth2_session_id = $1 - "#, - Uuid::from(session.id), - finished_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1)?; - - session - .finish(finished_at) - .map_err(DatabaseError::to_invalid_operation) - } - - #[tracing::instrument( - name = "db.oauth2_session.list_paginated", - skip_all, - fields( - db.statement, - %user.id, - %user.username, - ), - err, - )] - async fn list_paginated( - &mut self, - user: &User, - pagination: Pagination, - ) -> Result, Self::Error> { - let mut query = QueryBuilder::new( - r#" - SELECT oauth2_session_id - , user_session_id - , oauth2_client_id - , scope - , created_at - , finished_at - FROM oauth2_sessions os - "#, - ); - - query - .push(" WHERE us.user_id = ") - .push_bind(Uuid::from(user.id)) - .generate_pagination("oauth2_session_id", pagination); - - let edges: Vec = query - .build_query_as() - .traced() - .fetch_all(&mut *self.conn) - .await?; - - let page = pagination.process(edges).try_map(Session::try_from)?; - Ok(page) - } -} diff --git a/crates/storage/src/pagination.rs b/crates/storage/src/pagination.rs index 1fa74dda..6af45641 100644 --- a/crates/storage/src/pagination.rs +++ b/crates/storage/src/pagination.rs @@ -14,10 +14,8 @@ //! Utilities to manage paginated queries. -use sqlx::{Database, QueryBuilder}; use thiserror::Error; use ulid::Ulid; -use uuid::Uuid; /// An error returned when invalid pagination parameters are provided #[derive(Debug, Error)] @@ -26,14 +24,14 @@ pub struct InvalidPagination; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct Pagination { - before: Option, - after: Option, - count: usize, - direction: PaginationDirection, + pub before: Option, + pub after: Option, + pub count: usize, + pub direction: PaginationDirection, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum PaginationDirection { +pub enum PaginationDirection { Forward, Backward, } @@ -101,60 +99,8 @@ impl Pagination { self } - /// Add cursor-based pagination to a query, as used in paginated GraphQL - /// connections - fn generate_pagination<'a, DB>(&self, query: &mut QueryBuilder<'a, DB>, id_field: &'static str) - where - DB: Database, - Uuid: sqlx::Type + sqlx::Encode<'a, DB>, - i64: sqlx::Type + sqlx::Encode<'a, DB>, - { - // ref: https://github.com/graphql/graphql-relay-js/issues/94#issuecomment-232410564 - // 1. Start from the greedy query: SELECT * FROM table - - // 2. If the after argument is provided, add `id > parsed_cursor` to the `WHERE` - // clause - if let Some(after) = self.after { - query - .push(" AND ") - .push(id_field) - .push(" > ") - .push_bind(Uuid::from(after)); - } - - // 3. If the before argument is provided, add `id < parsed_cursor` to the - // `WHERE` clause - if let Some(before) = self.before { - query - .push(" AND ") - .push(id_field) - .push(" < ") - .push_bind(Uuid::from(before)); - } - - match self.direction { - // 4. If the first argument is provided, add `ORDER BY id ASC LIMIT first+1` to the - // query - PaginationDirection::Forward => { - query - .push(" ORDER BY ") - .push(id_field) - .push(" ASC LIMIT ") - .push_bind((self.count + 1) as i64); - } - // 5. If the first argument is provided, add `ORDER BY id DESC LIMIT last+1` to the - // query - PaginationDirection::Backward => { - query - .push(" ORDER BY ") - .push(id_field) - .push(" DESC LIMIT ") - .push_bind((self.count + 1) as i64); - } - }; - } - /// Process a page returned by a paginated query + #[must_use] pub fn process(&self, mut edges: Vec) -> Page { let is_full = edges.len() == (self.count + 1); if is_full { @@ -198,7 +144,6 @@ impl Page { } } - #[must_use] pub fn try_map(self, f: F) -> Result, E> where F: FnMut(T) -> Result, @@ -211,23 +156,3 @@ impl Page { }) } } - -/// An extension trait to the `sqlx` [`QueryBuilder`], to help adding pagination -/// to a query -pub trait QueryBuilderExt { - /// Add cursor-based pagination to a query, as used in paginated GraphQL - /// connections - fn generate_pagination(&mut self, id_field: &'static str, pagination: Pagination) -> &mut Self; -} - -impl<'a, DB> QueryBuilderExt for QueryBuilder<'a, DB> -where - DB: Database, - Uuid: sqlx::Type + sqlx::Encode<'a, DB>, - i64: sqlx::Type + sqlx::Encode<'a, DB>, -{ - fn generate_pagination(&mut self, id_field: &'static str, pagination: Pagination) -> &mut Self { - pagination.generate_pagination(self, id_field); - self - } -} diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs index 1fde4b41..55afe41b 100644 --- a/crates/storage/src/repository.rs +++ b/crates/storage/src/repository.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,31 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -use sqlx::{PgPool, Postgres, Transaction}; - use crate::{ compat::{ CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, - CompatSsoLoginRepository, PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository, - PgCompatSessionRepository, PgCompatSsoLoginRepository, + CompatSsoLoginRepository, }, oauth2::{ OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, - OAuth2RefreshTokenRepository, OAuth2SessionRepository, PgOAuth2AccessTokenRepository, - PgOAuth2AuthorizationGrantRepository, PgOAuth2ClientRepository, - PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository, + OAuth2RefreshTokenRepository, OAuth2SessionRepository, }, upstream_oauth2::{ - PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository, - PgUpstreamOAuthSessionRepository, UpstreamOAuthLinkRepository, - UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository, + UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, + UpstreamOAuthSessionRepository, }, - user::{ - BrowserSessionRepository, PgBrowserSessionRepository, PgUserEmailRepository, - PgUserPasswordRepository, PgUserRepository, UserEmailRepository, UserPasswordRepository, - UserRepository, - }, - DatabaseError, + user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, }; pub trait Repository: Send { @@ -126,109 +115,3 @@ pub trait Repository: Send { fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_>; fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_>; } - -pub struct PgRepository { - txn: Transaction<'static, Postgres>, -} - -impl PgRepository { - pub async fn from_pool(pool: &PgPool) -> Result { - let txn = pool.begin().await?; - Ok(PgRepository { txn }) - } - - pub async fn save(self) -> Result<(), DatabaseError> { - self.txn.commit().await?; - Ok(()) - } - - pub async fn cancel(self) -> Result<(), DatabaseError> { - self.txn.rollback().await?; - Ok(()) - } -} - -impl Repository for PgRepository { - type Error = DatabaseError; - - type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c; - type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c; - type UpstreamOAuthSessionRepository<'c> = PgUpstreamOAuthSessionRepository<'c> where Self: 'c; - type UserRepository<'c> = PgUserRepository<'c> where Self: 'c; - type UserEmailRepository<'c> = PgUserEmailRepository<'c> where Self: 'c; - type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c; - type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c; - type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c; - type OAuth2AuthorizationGrantRepository<'c> = PgOAuth2AuthorizationGrantRepository<'c> where Self: 'c; - type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c; - type OAuth2AccessTokenRepository<'c> = PgOAuth2AccessTokenRepository<'c> where Self: 'c; - type OAuth2RefreshTokenRepository<'c> = PgOAuth2RefreshTokenRepository<'c> where Self: 'c; - type CompatSessionRepository<'c> = PgCompatSessionRepository<'c> where Self: 'c; - type CompatSsoLoginRepository<'c> = PgCompatSsoLoginRepository<'c> where Self: 'c; - type CompatAccessTokenRepository<'c> = PgCompatAccessTokenRepository<'c> where Self: 'c; - type CompatRefreshTokenRepository<'c> = PgCompatRefreshTokenRepository<'c> where Self: 'c; - - fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { - PgUpstreamOAuthLinkRepository::new(&mut self.txn) - } - - fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> { - PgUpstreamOAuthProviderRepository::new(&mut self.txn) - } - - fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_> { - PgUpstreamOAuthSessionRepository::new(&mut self.txn) - } - - fn user(&mut self) -> Self::UserRepository<'_> { - PgUserRepository::new(&mut self.txn) - } - - fn user_email(&mut self) -> Self::UserEmailRepository<'_> { - PgUserEmailRepository::new(&mut self.txn) - } - - fn user_password(&mut self) -> Self::UserPasswordRepository<'_> { - PgUserPasswordRepository::new(&mut self.txn) - } - - fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_> { - PgBrowserSessionRepository::new(&mut self.txn) - } - - fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_> { - PgOAuth2ClientRepository::new(&mut self.txn) - } - - fn oauth2_authorization_grant(&mut self) -> Self::OAuth2AuthorizationGrantRepository<'_> { - PgOAuth2AuthorizationGrantRepository::new(&mut self.txn) - } - - fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> { - PgOAuth2SessionRepository::new(&mut self.txn) - } - - fn oauth2_access_token(&mut self) -> Self::OAuth2AccessTokenRepository<'_> { - PgOAuth2AccessTokenRepository::new(&mut self.txn) - } - - fn oauth2_refresh_token(&mut self) -> Self::OAuth2RefreshTokenRepository<'_> { - PgOAuth2RefreshTokenRepository::new(&mut self.txn) - } - - fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> { - PgCompatSessionRepository::new(&mut self.txn) - } - - fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_> { - PgCompatSsoLoginRepository::new(&mut self.txn) - } - - fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_> { - PgCompatAccessTokenRepository::new(&mut self.txn) - } - - fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_> { - PgCompatRefreshTokenRepository::new(&mut self.txn) - } -} diff --git a/crates/storage/src/upstream_oauth2/link.rs b/crates/storage/src/upstream_oauth2/link.rs index 76364afe..bc20c6ea 100644 --- a/crates/storage/src/upstream_oauth2/link.rs +++ b/crates/storage/src/upstream_oauth2/link.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,18 +13,11 @@ // limitations under the License. use async_trait::async_trait; -use chrono::{DateTime, Utc}; use mas_data_model::{UpstreamOAuthLink, UpstreamOAuthProvider, User}; use rand::RngCore; -use sqlx::{PgConnection, QueryBuilder}; use ulid::Ulid; -use uuid::Uuid; -use crate::{ - pagination::{Page, QueryBuilderExt}, - tracing::ExecuteExt, - Clock, DatabaseError, LookupResultExt, Pagination, -}; +use crate::{pagination::Page, Clock, Pagination}; #[async_trait] pub trait UpstreamOAuthLinkRepository: Send + Sync { @@ -63,241 +56,3 @@ pub trait UpstreamOAuthLinkRepository: Send + Sync { pagination: Pagination, ) -> Result, Self::Error>; } - -pub struct PgUpstreamOAuthLinkRepository<'c> { - conn: &'c mut PgConnection, -} - -impl<'c> PgUpstreamOAuthLinkRepository<'c> { - pub fn new(conn: &'c mut PgConnection) -> Self { - Self { conn } - } -} - -#[derive(sqlx::FromRow)] -struct LinkLookup { - upstream_oauth_link_id: Uuid, - upstream_oauth_provider_id: Uuid, - user_id: Option, - subject: String, - created_at: DateTime, -} - -impl From for UpstreamOAuthLink { - fn from(value: LinkLookup) -> Self { - UpstreamOAuthLink { - id: Ulid::from(value.upstream_oauth_link_id), - provider_id: Ulid::from(value.upstream_oauth_provider_id), - user_id: value.user_id.map(Ulid::from), - subject: value.subject, - created_at: value.created_at, - } - } -} - -#[async_trait] -impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { - type Error = DatabaseError; - - #[tracing::instrument( - name = "db.upstream_oauth_link.lookup", - skip_all, - fields( - db.statement, - upstream_oauth_link.id = %id, - ), - err, - )] - async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { - let res = sqlx::query_as!( - LinkLookup, - r#" - SELECT - upstream_oauth_link_id, - upstream_oauth_provider_id, - user_id, - subject, - created_at - FROM upstream_oauth_links - WHERE upstream_oauth_link_id = $1 - "#, - Uuid::from(id), - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()? - .map(Into::into); - - Ok(res) - } - - #[tracing::instrument( - name = "db.upstream_oauth_link.find_by_subject", - skip_all, - fields( - db.statement, - upstream_oauth_link.subject = subject, - %upstream_oauth_provider.id, - %upstream_oauth_provider.issuer, - %upstream_oauth_provider.client_id, - ), - err, - )] - async fn find_by_subject( - &mut self, - upstream_oauth_provider: &UpstreamOAuthProvider, - subject: &str, - ) -> Result, Self::Error> { - let res = sqlx::query_as!( - LinkLookup, - r#" - SELECT - upstream_oauth_link_id, - upstream_oauth_provider_id, - user_id, - subject, - created_at - FROM upstream_oauth_links - WHERE upstream_oauth_provider_id = $1 - AND subject = $2 - "#, - Uuid::from(upstream_oauth_provider.id), - subject, - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()? - .map(Into::into); - - Ok(res) - } - - #[tracing::instrument( - name = "db.upstream_oauth_link.add", - skip_all, - fields( - db.statement, - upstream_oauth_link.id, - upstream_oauth_link.subject = subject, - %upstream_oauth_provider.id, - %upstream_oauth_provider.issuer, - %upstream_oauth_provider.client_id, - ), - err, - )] - async fn add( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - upstream_oauth_provider: &UpstreamOAuthProvider, - subject: String, - ) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - tracing::Span::current().record("upstream_oauth_link.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO upstream_oauth_links ( - upstream_oauth_link_id, - upstream_oauth_provider_id, - user_id, - subject, - created_at - ) VALUES ($1, $2, NULL, $3, $4) - "#, - Uuid::from(id), - Uuid::from(upstream_oauth_provider.id), - &subject, - created_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - Ok(UpstreamOAuthLink { - id, - provider_id: upstream_oauth_provider.id, - user_id: None, - subject, - created_at, - }) - } - - #[tracing::instrument( - name = "db.upstream_oauth_link.associate_to_user", - skip_all, - fields( - db.statement, - %upstream_oauth_link.id, - %upstream_oauth_link.subject, - %user.id, - %user.username, - ), - err, - )] - async fn associate_to_user( - &mut self, - upstream_oauth_link: &UpstreamOAuthLink, - user: &User, - ) -> Result<(), Self::Error> { - sqlx::query!( - r#" - UPDATE upstream_oauth_links - SET user_id = $1 - WHERE upstream_oauth_link_id = $2 - "#, - Uuid::from(user.id), - Uuid::from(upstream_oauth_link.id), - ) - .traced() - .execute(&mut *self.conn) - .await?; - - Ok(()) - } - - #[tracing::instrument( - name = "db.upstream_oauth_link.list_paginated", - skip_all, - fields( - db.statement, - %user.id, - %user.username, - ), - err - )] - async fn list_paginated( - &mut self, - user: &User, - pagination: Pagination, - ) -> Result, Self::Error> { - let mut query = QueryBuilder::new( - r#" - SELECT - upstream_oauth_link_id, - upstream_oauth_provider_id, - user_id, - subject, - created_at - FROM upstream_oauth_links - "#, - ); - - query - .push(" WHERE user_id = ") - .push_bind(Uuid::from(user.id)) - .generate_pagination("upstream_oauth_link_id", pagination); - - let edges: Vec = query - .build_query_as() - .traced() - .fetch_all(&mut *self.conn) - .await?; - - let page = pagination.process(edges).map(UpstreamOAuthLink::from); - Ok(page) - } -} diff --git a/crates/storage/src/upstream_oauth2/mod.rs b/crates/storage/src/upstream_oauth2/mod.rs index d1b6809f..1648a644 100644 --- a/crates/storage/src/upstream_oauth2/mod.rs +++ b/crates/storage/src/upstream_oauth2/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -17,249 +17,6 @@ mod provider; mod session; pub use self::{ - link::{PgUpstreamOAuthLinkRepository, UpstreamOAuthLinkRepository}, - provider::{PgUpstreamOAuthProviderRepository, UpstreamOAuthProviderRepository}, - session::{PgUpstreamOAuthSessionRepository, UpstreamOAuthSessionRepository}, + link::UpstreamOAuthLinkRepository, provider::UpstreamOAuthProviderRepository, + session::UpstreamOAuthSessionRepository, }; - -#[cfg(test)] -mod tests { - use chrono::Duration; - use oauth2_types::scope::{Scope, OPENID}; - use rand::SeedableRng; - use sqlx::PgPool; - - use super::*; - use crate::{user::UserRepository, Clock, Pagination, PgRepository, Repository}; - - #[sqlx::test(migrator = "crate::MIGRATOR")] - async fn test_repository(pool: PgPool) { - let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42); - let clock = Clock::mock(); - let mut repo = PgRepository::from_pool(&pool).await.unwrap(); - - // The provider list should be empty at the start - let all_providers = repo.upstream_oauth_provider().all().await.unwrap(); - assert!(all_providers.is_empty()); - - // Let's add a provider - let provider = repo - .upstream_oauth_provider() - .add( - &mut rng, - &clock, - "https://example.com/".to_owned(), - Scope::from_iter([OPENID]), - mas_iana::oauth::OAuthClientAuthenticationMethod::None, - None, - "client-id".to_owned(), - None, - ) - .await - .unwrap(); - - // Look it up in the database - let provider = repo - .upstream_oauth_provider() - .lookup(provider.id) - .await - .unwrap() - .expect("provider to be found in the database"); - assert_eq!(provider.issuer, "https://example.com/"); - assert_eq!(provider.client_id, "client-id"); - - // Start a session - let session = repo - .upstream_oauth_session() - .add( - &mut rng, - &clock, - &provider, - "some-state".to_owned(), - None, - "some-nonce".to_owned(), - ) - .await - .unwrap(); - - // Look it up in the database - let session = repo - .upstream_oauth_session() - .lookup(session.id) - .await - .unwrap() - .expect("session to be found in the database"); - assert_eq!(session.provider_id, provider.id); - assert_eq!(session.link_id(), None); - assert!(session.is_pending()); - assert!(!session.is_completed()); - assert!(!session.is_consumed()); - - // Create a link - let link = repo - .upstream_oauth_link() - .add(&mut rng, &clock, &provider, "a-subject".to_owned()) - .await - .unwrap(); - - // We can look it up by its ID - repo.upstream_oauth_link() - .lookup(link.id) - .await - .unwrap() - .expect("link to be found in database"); - - // or by its subject - let link = repo - .upstream_oauth_link() - .find_by_subject(&provider, "a-subject") - .await - .unwrap() - .expect("link to be found in database"); - assert_eq!(link.subject, "a-subject"); - assert_eq!(link.provider_id, provider.id); - - let session = repo - .upstream_oauth_session() - .complete_with_link(&clock, session, &link, None) - .await - .unwrap(); - // Reload the session - let session = repo - .upstream_oauth_session() - .lookup(session.id) - .await - .unwrap() - .expect("session to be found in the database"); - assert!(session.is_completed()); - assert!(!session.is_consumed()); - assert_eq!(session.link_id(), Some(link.id)); - - let session = repo - .upstream_oauth_session() - .consume(&clock, session) - .await - .unwrap(); - // Reload the session - let session = repo - .upstream_oauth_session() - .lookup(session.id) - .await - .unwrap() - .expect("session to be found in the database"); - assert!(session.is_consumed()); - - let user = repo - .user() - .add(&mut rng, &clock, "john".to_owned()) - .await - .unwrap(); - repo.upstream_oauth_link() - .associate_to_user(&link, &user) - .await - .unwrap(); - - let links = repo - .upstream_oauth_link() - .list_paginated(&user, Pagination::first(10)) - .await - .unwrap(); - assert!(!links.has_previous_page); - assert!(!links.has_next_page); - assert_eq!(links.edges.len(), 1); - assert_eq!(links.edges[0].id, link.id); - assert_eq!(links.edges[0].user_id, Some(user.id)); - } - - #[sqlx::test(migrator = "crate::MIGRATOR")] - async fn test_provider_repository_pagination(pool: PgPool) { - const ISSUER: &str = "https://example.com/"; - let scope = Scope::from_iter([OPENID]); - - let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42); - let clock = Clock::mock(); - let mut repo = PgRepository::from_pool(&pool).await.unwrap(); - - let mut ids = Vec::with_capacity(20); - // Create 20 providers - for idx in 0..20 { - let client_id = format!("client-{idx}"); - let provider = repo - .upstream_oauth_provider() - .add( - &mut rng, - &clock, - ISSUER.to_owned(), - scope.clone(), - mas_iana::oauth::OAuthClientAuthenticationMethod::None, - None, - client_id, - None, - ) - .await - .unwrap(); - ids.push(provider.id); - clock.advance(Duration::seconds(10)); - } - - // Lookup the first 10 items - let page = repo - .upstream_oauth_provider() - .list_paginated(Pagination::first(10)) - .await - .unwrap(); - - // It returned the first 10 items - assert!(page.has_next_page); - let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); - assert_eq!(&edge_ids, &ids[..10]); - - // Lookup the next 10 items - let page = repo - .upstream_oauth_provider() - .list_paginated(Pagination::first(10).after(ids[9])) - .await - .unwrap(); - - // It returned the next 10 items - assert!(!page.has_next_page); - let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); - assert_eq!(&edge_ids, &ids[10..]); - - // Lookup the last 10 items - let page = repo - .upstream_oauth_provider() - .list_paginated(Pagination::last(10)) - .await - .unwrap(); - - // It returned the last 10 items - assert!(page.has_previous_page); - let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); - assert_eq!(&edge_ids, &ids[10..]); - - // Lookup the previous 10 items - let page = repo - .upstream_oauth_provider() - .list_paginated(Pagination::last(10).before(ids[10])) - .await - .unwrap(); - - // It returned the previous 10 items - assert!(!page.has_previous_page); - let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); - assert_eq!(&edge_ids, &ids[..10]); - - // Lookup 10 items between two IDs - let page = repo - .upstream_oauth_provider() - .list_paginated(Pagination::first(10).after(ids[5]).before(ids[8])) - .await - .unwrap(); - - // It returned the items in between - assert!(!page.has_next_page); - let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); - assert_eq!(&edge_ids, &ids[6..8]); - } -} diff --git a/crates/storage/src/upstream_oauth2/provider.rs b/crates/storage/src/upstream_oauth2/provider.rs index 14bd6547..4be8f127 100644 --- a/crates/storage/src/upstream_oauth2/provider.rs +++ b/crates/storage/src/upstream_oauth2/provider.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,20 +13,13 @@ // limitations under the License. use async_trait::async_trait; -use chrono::{DateTime, Utc}; use mas_data_model::UpstreamOAuthProvider; use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; use oauth2_types::scope::Scope; use rand::RngCore; -use sqlx::{PgConnection, QueryBuilder}; use ulid::Ulid; -use uuid::Uuid; -use crate::{ - pagination::{Page, QueryBuilderExt}, - tracing::ExecuteExt, - Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Pagination, -}; +use crate::{pagination::Page, Clock, Pagination}; #[async_trait] pub trait UpstreamOAuthProviderRepository: Send + Sync { @@ -58,247 +51,3 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync { /// Get all upstream OAuth providers async fn all(&mut self) -> Result, Self::Error>; } - -pub struct PgUpstreamOAuthProviderRepository<'c> { - conn: &'c mut PgConnection, -} - -impl<'c> PgUpstreamOAuthProviderRepository<'c> { - pub fn new(conn: &'c mut PgConnection) -> Self { - Self { conn } - } -} - -#[derive(sqlx::FromRow)] -struct ProviderLookup { - upstream_oauth_provider_id: Uuid, - issuer: String, - scope: String, - client_id: String, - encrypted_client_secret: Option, - token_endpoint_signing_alg: Option, - token_endpoint_auth_method: String, - created_at: DateTime, -} - -impl TryFrom for UpstreamOAuthProvider { - type Error = DatabaseInconsistencyError; - fn try_from(value: ProviderLookup) -> Result { - let id = value.upstream_oauth_provider_id.into(); - let scope = value.scope.parse().map_err(|e| { - DatabaseInconsistencyError::on("upstream_oauth_providers") - .column("scope") - .row(id) - .source(e) - })?; - let token_endpoint_auth_method = value.token_endpoint_auth_method.parse().map_err(|e| { - DatabaseInconsistencyError::on("upstream_oauth_providers") - .column("token_endpoint_auth_method") - .row(id) - .source(e) - })?; - let token_endpoint_signing_alg = value - .token_endpoint_signing_alg - .map(|x| x.parse()) - .transpose() - .map_err(|e| { - DatabaseInconsistencyError::on("upstream_oauth_providers") - .column("token_endpoint_signing_alg") - .row(id) - .source(e) - })?; - - Ok(UpstreamOAuthProvider { - id, - issuer: value.issuer, - scope, - client_id: value.client_id, - encrypted_client_secret: value.encrypted_client_secret, - token_endpoint_auth_method, - token_endpoint_signing_alg, - created_at: value.created_at, - }) - } -} - -#[async_trait] -impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'c> { - type Error = DatabaseError; - - #[tracing::instrument( - name = "db.upstream_oauth_provider.lookup", - skip_all, - fields( - db.statement, - upstream_oauth_provider.id = %id, - ), - err, - )] - async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { - let res = sqlx::query_as!( - ProviderLookup, - r#" - SELECT - upstream_oauth_provider_id, - issuer, - scope, - client_id, - encrypted_client_secret, - token_endpoint_signing_alg, - token_endpoint_auth_method, - created_at - FROM upstream_oauth_providers - WHERE upstream_oauth_provider_id = $1 - "#, - Uuid::from(id), - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let res = res - .map(UpstreamOAuthProvider::try_from) - .transpose() - .map_err(DatabaseError::from)?; - - Ok(res) - } - - #[tracing::instrument( - name = "db.upstream_oauth_provider.add", - skip_all, - fields( - db.statement, - upstream_oauth_provider.id, - upstream_oauth_provider.issuer = %issuer, - upstream_oauth_provider.client_id = %client_id, - ), - err, - )] - #[allow(clippy::too_many_arguments)] - async fn add( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - issuer: String, - scope: Scope, - token_endpoint_auth_method: OAuthClientAuthenticationMethod, - token_endpoint_signing_alg: Option, - client_id: String, - encrypted_client_secret: Option, - ) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - tracing::Span::current().record("upstream_oauth_provider.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO upstream_oauth_providers ( - upstream_oauth_provider_id, - issuer, - scope, - token_endpoint_auth_method, - token_endpoint_signing_alg, - client_id, - encrypted_client_secret, - created_at - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) - "#, - Uuid::from(id), - &issuer, - scope.to_string(), - token_endpoint_auth_method.to_string(), - token_endpoint_signing_alg.as_ref().map(ToString::to_string), - &client_id, - encrypted_client_secret.as_deref(), - created_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - Ok(UpstreamOAuthProvider { - id, - issuer, - scope, - client_id, - encrypted_client_secret, - token_endpoint_signing_alg, - token_endpoint_auth_method, - created_at, - }) - } - - #[tracing::instrument( - name = "db.upstream_oauth_provider.list_paginated", - skip_all, - fields( - db.statement, - ), - err, - )] - async fn list_paginated( - &mut self, - pagination: Pagination, - ) -> Result, Self::Error> { - let mut query = QueryBuilder::new( - r#" - SELECT - upstream_oauth_provider_id, - issuer, - scope, - client_id, - encrypted_client_secret, - token_endpoint_signing_alg, - token_endpoint_auth_method, - created_at - FROM upstream_oauth_providers - WHERE 1 = 1 - "#, - ); - - query.generate_pagination("upstream_oauth_provider_id", pagination); - - let edges: Vec = query - .build_query_as() - .traced() - .fetch_all(&mut *self.conn) - .await?; - - let page = pagination.process(edges).try_map(TryInto::try_into)?; - Ok(page) - } - - #[tracing::instrument( - name = "db.upstream_oauth_provider.all", - skip_all, - fields( - db.statement, - ), - err, - )] - async fn all(&mut self) -> Result, Self::Error> { - let res = sqlx::query_as!( - ProviderLookup, - r#" - SELECT - upstream_oauth_provider_id, - issuer, - scope, - client_id, - encrypted_client_secret, - token_endpoint_signing_alg, - token_endpoint_auth_method, - created_at - FROM upstream_oauth_providers - "#, - ) - .traced() - .fetch_all(&mut *self.conn) - .await?; - - let res: Result, _> = res.into_iter().map(TryInto::try_into).collect(); - Ok(res?) - } -} diff --git a/crates/storage/src/upstream_oauth2/session.rs b/crates/storage/src/upstream_oauth2/session.rs index d5da6ef8..4d41a8ec 100644 --- a/crates/storage/src/upstream_oauth2/session.rs +++ b/crates/storage/src/upstream_oauth2/session.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,19 +13,11 @@ // limitations under the License. use async_trait::async_trait; -use chrono::{DateTime, Utc}; -use mas_data_model::{ - UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState, UpstreamOAuthLink, - UpstreamOAuthProvider, -}; +use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider}; use rand::RngCore; -use sqlx::PgConnection; use ulid::Ulid; -use uuid::Uuid; -use crate::{ - tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, -}; +use crate::Clock; #[async_trait] pub trait UpstreamOAuthSessionRepository: Send + Sync { @@ -64,262 +56,3 @@ pub trait UpstreamOAuthSessionRepository: Send + Sync { upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, ) -> Result; } - -pub struct PgUpstreamOAuthSessionRepository<'c> { - conn: &'c mut PgConnection, -} - -impl<'c> PgUpstreamOAuthSessionRepository<'c> { - pub fn new(conn: &'c mut PgConnection) -> Self { - Self { conn } - } -} - -struct SessionLookup { - upstream_oauth_authorization_session_id: Uuid, - upstream_oauth_provider_id: Uuid, - upstream_oauth_link_id: Option, - state: String, - code_challenge_verifier: Option, - nonce: String, - id_token: Option, - created_at: DateTime, - completed_at: Option>, - consumed_at: Option>, -} - -impl TryFrom for UpstreamOAuthAuthorizationSession { - type Error = DatabaseInconsistencyError; - - fn try_from(value: SessionLookup) -> Result { - let id = value.upstream_oauth_authorization_session_id.into(); - let state = match ( - value.upstream_oauth_link_id, - value.id_token, - value.completed_at, - value.consumed_at, - ) { - (None, None, None, None) => UpstreamOAuthAuthorizationSessionState::Pending, - (Some(link_id), id_token, Some(completed_at), None) => { - UpstreamOAuthAuthorizationSessionState::Completed { - completed_at, - link_id: link_id.into(), - id_token, - } - } - (Some(link_id), id_token, Some(completed_at), Some(consumed_at)) => { - UpstreamOAuthAuthorizationSessionState::Consumed { - completed_at, - link_id: link_id.into(), - id_token, - consumed_at, - } - } - _ => { - return Err( - DatabaseInconsistencyError::on("upstream_oauth_authorization_sessions").row(id), - ) - } - }; - - Ok(Self { - id, - provider_id: value.upstream_oauth_provider_id.into(), - state_str: value.state, - nonce: value.nonce, - code_challenge_verifier: value.code_challenge_verifier, - created_at: value.created_at, - state, - }) - } -} - -#[async_trait] -impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> { - type Error = DatabaseError; - - #[tracing::instrument( - name = "db.upstream_oauth_authorization_session.lookup", - skip_all, - fields( - db.statement, - upstream_oauth_provider.id = %id, - ), - err, - )] - async fn lookup( - &mut self, - id: Ulid, - ) -> Result, Self::Error> { - let res = sqlx::query_as!( - SessionLookup, - r#" - SELECT - upstream_oauth_authorization_session_id, - upstream_oauth_provider_id, - upstream_oauth_link_id, - state, - code_challenge_verifier, - nonce, - id_token, - created_at, - completed_at, - consumed_at - FROM upstream_oauth_authorization_sessions - WHERE upstream_oauth_authorization_session_id = $1 - "#, - Uuid::from(id), - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.try_into()?)) - } - - #[tracing::instrument( - name = "db.upstream_oauth_authorization_session.add", - skip_all, - fields( - db.statement, - %upstream_oauth_provider.id, - %upstream_oauth_provider.issuer, - %upstream_oauth_provider.client_id, - upstream_oauth_authorization_session.id, - ), - err, - )] - async fn add( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - upstream_oauth_provider: &UpstreamOAuthProvider, - state_str: String, - code_challenge_verifier: Option, - nonce: String, - ) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - tracing::Span::current().record( - "upstream_oauth_authorization_session.id", - tracing::field::display(id), - ); - - sqlx::query!( - r#" - INSERT INTO upstream_oauth_authorization_sessions ( - upstream_oauth_authorization_session_id, - upstream_oauth_provider_id, - state, - code_challenge_verifier, - nonce, - created_at, - completed_at, - consumed_at, - id_token - ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL) - "#, - Uuid::from(id), - Uuid::from(upstream_oauth_provider.id), - &state_str, - code_challenge_verifier.as_deref(), - nonce, - created_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - Ok(UpstreamOAuthAuthorizationSession { - id, - state: UpstreamOAuthAuthorizationSessionState::default(), - provider_id: upstream_oauth_provider.id, - state_str, - code_challenge_verifier, - nonce, - created_at, - }) - } - - #[tracing::instrument( - name = "db.upstream_oauth_authorization_session.complete_with_link", - skip_all, - fields( - db.statement, - %upstream_oauth_authorization_session.id, - %upstream_oauth_link.id, - ), - err, - )] - async fn complete_with_link( - &mut self, - clock: &Clock, - upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, - upstream_oauth_link: &UpstreamOAuthLink, - id_token: Option, - ) -> Result { - let completed_at = clock.now(); - - sqlx::query!( - r#" - UPDATE upstream_oauth_authorization_sessions - SET upstream_oauth_link_id = $1, - completed_at = $2, - id_token = $3 - WHERE upstream_oauth_authorization_session_id = $4 - "#, - Uuid::from(upstream_oauth_link.id), - completed_at, - id_token, - Uuid::from(upstream_oauth_authorization_session.id), - ) - .traced() - .execute(&mut *self.conn) - .await?; - - let upstream_oauth_authorization_session = upstream_oauth_authorization_session - .complete(completed_at, upstream_oauth_link, id_token) - .map_err(DatabaseError::to_invalid_operation)?; - - Ok(upstream_oauth_authorization_session) - } - - /// Mark a session as consumed - #[tracing::instrument( - name = "db.upstream_oauth_authorization_session.consume", - skip_all, - fields( - db.statement, - %upstream_oauth_authorization_session.id, - ), - err, - )] - async fn consume( - &mut self, - clock: &Clock, - upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, - ) -> Result { - let consumed_at = clock.now(); - sqlx::query!( - r#" - UPDATE upstream_oauth_authorization_sessions - SET consumed_at = $1 - WHERE upstream_oauth_authorization_session_id = $2 - "#, - consumed_at, - Uuid::from(upstream_oauth_authorization_session.id), - ) - .traced() - .execute(&mut *self.conn) - .await?; - - let upstream_oauth_authorization_session = upstream_oauth_authorization_session - .consume(consumed_at) - .map_err(DatabaseError::to_invalid_operation)?; - - Ok(upstream_oauth_authorization_session) - } -} diff --git a/crates/storage/src/user/email.rs b/crates/storage/src/user/email.rs index 8c8efe1b..41a7d293 100644 --- a/crates/storage/src/user/email.rs +++ b/crates/storage/src/user/email.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,19 +13,11 @@ // limitations under the License. use async_trait::async_trait; -use chrono::{DateTime, Utc}; -use mas_data_model::{User, UserEmail, UserEmailVerification, UserEmailVerificationState}; +use mas_data_model::{User, UserEmail, UserEmailVerification}; use rand::RngCore; -use sqlx::{PgConnection, QueryBuilder}; -use tracing::{info_span, Instrument}; use ulid::Ulid; -use uuid::Uuid; -use crate::{ - pagination::{Page, QueryBuilderExt}, - tracing::ExecuteExt, - Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Pagination, -}; +use crate::{pagination::Page, Clock, Pagination}; #[async_trait] pub trait UserEmailRepository: Send + Sync { @@ -82,529 +74,3 @@ pub trait UserEmailRepository: Send + Sync { verification: UserEmailVerification, ) -> Result; } - -pub struct PgUserEmailRepository<'c> { - conn: &'c mut PgConnection, -} - -impl<'c> PgUserEmailRepository<'c> { - pub fn new(conn: &'c mut PgConnection) -> Self { - Self { conn } - } -} - -#[derive(Debug, Clone, sqlx::FromRow)] -struct UserEmailLookup { - user_email_id: Uuid, - user_id: Uuid, - email: String, - created_at: DateTime, - confirmed_at: Option>, -} - -impl From for UserEmail { - fn from(e: UserEmailLookup) -> UserEmail { - UserEmail { - id: e.user_email_id.into(), - user_id: e.user_id.into(), - email: e.email, - created_at: e.created_at, - confirmed_at: e.confirmed_at, - } - } -} - -struct UserEmailConfirmationCodeLookup { - user_email_confirmation_code_id: Uuid, - user_email_id: Uuid, - code: String, - created_at: DateTime, - expires_at: DateTime, - consumed_at: Option>, -} - -impl UserEmailConfirmationCodeLookup { - fn into_verification(self, clock: &Clock) -> UserEmailVerification { - let now = clock.now(); - let state = if let Some(when) = self.consumed_at { - UserEmailVerificationState::AlreadyUsed { when } - } else if self.expires_at < now { - UserEmailVerificationState::Expired { - when: self.expires_at, - } - } else { - UserEmailVerificationState::Valid - }; - - UserEmailVerification { - id: self.user_email_confirmation_code_id.into(), - user_email_id: self.user_email_id.into(), - code: self.code, - state, - created_at: self.created_at, - } - } -} - -#[async_trait] -impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { - type Error = DatabaseError; - - #[tracing::instrument( - name = "db.user_email.lookup", - skip_all, - fields( - db.statement, - user_email.id = %id, - ), - err, - )] - async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { - let res = sqlx::query_as!( - UserEmailLookup, - r#" - SELECT user_email_id - , user_id - , email - , created_at - , confirmed_at - FROM user_emails - - WHERE user_email_id = $1 - "#, - Uuid::from(id), - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(user_email) = res else { return Ok(None) }; - - Ok(Some(user_email.into())) - } - - #[tracing::instrument( - name = "db.user_email.find", - skip_all, - fields( - db.statement, - %user.id, - user_email.email = email, - ), - err, - )] - async fn find(&mut self, user: &User, email: &str) -> Result, Self::Error> { - let res = sqlx::query_as!( - UserEmailLookup, - r#" - SELECT user_email_id - , user_id - , email - , created_at - , confirmed_at - FROM user_emails - - WHERE user_id = $1 AND email = $2 - "#, - Uuid::from(user.id), - email, - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(user_email) = res else { return Ok(None) }; - - Ok(Some(user_email.into())) - } - - #[tracing::instrument( - name = "db.user_email.get_primary", - skip_all, - fields( - db.statement, - %user.id, - ), - err, - )] - async fn get_primary(&mut self, user: &User) -> Result, Self::Error> { - let Some(id) = user.primary_user_email_id else { return Ok(None) }; - - let user_email = self.lookup(id).await?.ok_or_else(|| { - DatabaseInconsistencyError::on("users") - .column("primary_user_email_id") - .row(user.id) - })?; - - Ok(Some(user_email)) - } - - #[tracing::instrument( - name = "db.user_email.all", - skip_all, - fields( - db.statement, - %user.id, - ), - err, - )] - async fn all(&mut self, user: &User) -> Result, Self::Error> { - let res = sqlx::query_as!( - UserEmailLookup, - r#" - SELECT user_email_id - , user_id - , email - , created_at - , confirmed_at - FROM user_emails - - WHERE user_id = $1 - - ORDER BY email ASC - "#, - Uuid::from(user.id), - ) - .traced() - .fetch_all(&mut *self.conn) - .await?; - - Ok(res.into_iter().map(Into::into).collect()) - } - - #[tracing::instrument( - name = "db.user_email.list_paginated", - skip_all, - fields( - db.statement, - %user.id, - ), - err, - )] - async fn list_paginated( - &mut self, - user: &User, - pagination: Pagination, - ) -> Result, DatabaseError> { - let mut query = QueryBuilder::new( - r#" - SELECT user_email_id - , user_id - , email - , created_at - , confirmed_at - FROM user_emails - "#, - ); - - query - .push(" WHERE user_id = ") - .push_bind(Uuid::from(user.id)) - .generate_pagination("ue.user_email_id", pagination); - - let edges: Vec = query - .build_query_as() - .traced() - .fetch_all(&mut *self.conn) - .await?; - - let page = pagination.process(edges).map(UserEmail::from); - Ok(page) - } - - #[tracing::instrument( - name = "db.user_email.count", - skip_all, - fields( - db.statement, - %user.id, - ), - err, - )] - async fn count(&mut self, user: &User) -> Result { - let res = sqlx::query_scalar!( - r#" - SELECT COUNT(*) - FROM user_emails - WHERE user_id = $1 - "#, - Uuid::from(user.id), - ) - .traced() - .fetch_one(&mut *self.conn) - .await?; - - let res = res.unwrap_or_default(); - - Ok(res - .try_into() - .map_err(DatabaseError::to_invalid_operation)?) - } - - #[tracing::instrument( - name = "db.user_email.add", - skip_all, - fields( - db.statement, - %user.id, - user_email.id, - user_email.email = email, - ), - err, - )] - async fn add( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - user: &User, - email: String, - ) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - tracing::Span::current().record("user_email.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO user_emails (user_email_id, user_id, email, created_at) - VALUES ($1, $2, $3, $4) - "#, - Uuid::from(id), - Uuid::from(user.id), - &email, - created_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - Ok(UserEmail { - id, - user_id: user.id, - email, - created_at, - confirmed_at: None, - }) - } - - #[tracing::instrument( - name = "db.user_email.remove", - skip_all, - fields( - db.statement, - user.id = %user_email.user_id, - %user_email.id, - %user_email.email, - ), - err, - )] - async fn remove(&mut self, user_email: UserEmail) -> Result<(), Self::Error> { - let span = info_span!( - "db.user_email.remove.codes", - db.statement = tracing::field::Empty - ); - sqlx::query!( - r#" - DELETE FROM user_email_confirmation_codes - WHERE user_email_id = $1 - "#, - Uuid::from(user_email.id), - ) - .record(&span) - .execute(&mut *self.conn) - .instrument(span) - .await?; - - let res = sqlx::query!( - r#" - DELETE FROM user_emails - WHERE user_email_id = $1 - "#, - Uuid::from(user_email.id), - ) - .traced() - .execute(&mut *self.conn) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1)?; - - Ok(()) - } - - async fn mark_as_verified( - &mut self, - clock: &Clock, - mut user_email: UserEmail, - ) -> Result { - let confirmed_at = clock.now(); - sqlx::query!( - r#" - UPDATE user_emails - SET confirmed_at = $2 - WHERE user_email_id = $1 - "#, - Uuid::from(user_email.id), - confirmed_at, - ) - .execute(&mut *self.conn) - .await?; - - user_email.confirmed_at = Some(confirmed_at); - Ok(user_email) - } - - async fn set_as_primary(&mut self, user_email: &UserEmail) -> Result<(), Self::Error> { - sqlx::query!( - r#" - UPDATE users - SET primary_user_email_id = user_emails.user_email_id - FROM user_emails - WHERE user_emails.user_email_id = $1 - AND users.user_id = user_emails.user_id - "#, - Uuid::from(user_email.id), - ) - .execute(&mut *self.conn) - .await?; - - Ok(()) - } - - #[tracing::instrument( - name = "db.user_email.add_verification_code", - skip_all, - fields( - db.statement, - %user_email.id, - %user_email.email, - user_email_verification.id, - user_email_verification.code = code, - ), - err, - )] - async fn add_verification_code( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - user_email: &UserEmail, - max_age: chrono::Duration, - code: String, - ) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - tracing::Span::current().record("user_email_confirmation.id", tracing::field::display(id)); - let expires_at = created_at + max_age; - - sqlx::query!( - r#" - INSERT INTO user_email_confirmation_codes - (user_email_confirmation_code_id, user_email_id, code, created_at, expires_at) - VALUES ($1, $2, $3, $4, $5) - "#, - Uuid::from(id), - Uuid::from(user_email.id), - code, - created_at, - expires_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - let verification = UserEmailVerification { - id, - user_email_id: user_email.id, - code, - created_at, - state: UserEmailVerificationState::Valid, - }; - - Ok(verification) - } - - #[tracing::instrument( - name = "db.user_email.find_verification_code", - skip_all, - fields( - db.statement, - %user_email.id, - user.id = %user_email.user_id, - ), - err, - )] - async fn find_verification_code( - &mut self, - clock: &Clock, - user_email: &UserEmail, - code: &str, - ) -> Result, Self::Error> { - let res = sqlx::query_as!( - UserEmailConfirmationCodeLookup, - r#" - SELECT user_email_confirmation_code_id - , user_email_id - , code - , created_at - , expires_at - , consumed_at - FROM user_email_confirmation_codes - WHERE code = $1 - AND user_email_id = $2 - "#, - code, - Uuid::from(user_email.id), - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.into_verification(clock))) - } - - #[tracing::instrument( - name = "db.user_email.consume_verification_code", - skip_all, - fields( - db.statement, - %user_email_verification.id, - user_email.id = %user_email_verification.user_email_id, - ), - err, - )] - async fn consume_verification_code( - &mut self, - clock: &Clock, - mut user_email_verification: UserEmailVerification, - ) -> Result { - if !matches!( - user_email_verification.state, - UserEmailVerificationState::Valid - ) { - return Err(DatabaseError::invalid_operation()); - } - - let consumed_at = clock.now(); - - sqlx::query!( - r#" - UPDATE user_email_confirmation_codes - SET consumed_at = $2 - WHERE user_email_confirmation_code_id = $1 - "#, - Uuid::from(user_email_verification.id), - consumed_at - ) - .traced() - .execute(&mut *self.conn) - .await?; - - user_email_verification.state = - UserEmailVerificationState::AlreadyUsed { when: consumed_at }; - - Ok(user_email_verification) - } -} diff --git a/crates/storage/src/user/mod.rs b/crates/storage/src/user/mod.rs index 9dd3d2ca..23c2f6d1 100644 --- a/crates/storage/src/user/mod.rs +++ b/crates/storage/src/user/mod.rs @@ -13,26 +13,18 @@ // limitations under the License. use async_trait::async_trait; -use chrono::{DateTime, Utc}; use mas_data_model::User; use rand::RngCore; -use sqlx::PgConnection; use ulid::Ulid; -use uuid::Uuid; -use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt}; +use crate::Clock; mod email; mod password; mod session; -#[cfg(test)] -mod tests; - pub use self::{ - email::{PgUserEmailRepository, UserEmailRepository}, - password::{PgUserPasswordRepository, UserPasswordRepository}, - session::{BrowserSessionRepository, PgBrowserSessionRepository}, + email::UserEmailRepository, password::UserPasswordRepository, session::BrowserSessionRepository, }; #[async_trait] @@ -49,170 +41,3 @@ pub trait UserRepository: Send + Sync { ) -> Result; async fn exists(&mut self, username: &str) -> Result; } - -pub struct PgUserRepository<'c> { - conn: &'c mut PgConnection, -} - -impl<'c> PgUserRepository<'c> { - pub fn new(conn: &'c mut PgConnection) -> Self { - Self { conn } - } -} - -#[derive(Debug, Clone)] -struct UserLookup { - user_id: Uuid, - username: String, - primary_user_email_id: Option, - - #[allow(dead_code)] - created_at: DateTime, -} - -impl From for User { - fn from(value: UserLookup) -> Self { - let id = value.user_id.into(); - Self { - id, - username: value.username, - sub: id.to_string(), - primary_user_email_id: value.primary_user_email_id.map(Into::into), - } - } -} - -#[async_trait] -impl<'c> UserRepository for PgUserRepository<'c> { - type Error = DatabaseError; - - #[tracing::instrument( - name = "db.user.lookup", - skip_all, - fields( - db.statement, - user.id = %id, - ), - err, - )] - async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { - let res = sqlx::query_as!( - UserLookup, - r#" - SELECT user_id - , username - , primary_user_email_id - , created_at - FROM users - WHERE user_id = $1 - "#, - Uuid::from(id), - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.into())) - } - - #[tracing::instrument( - name = "db.user.find_by_username", - skip_all, - fields( - db.statement, - user.username = username, - ), - err, - )] - async fn find_by_username(&mut self, username: &str) -> Result, Self::Error> { - let res = sqlx::query_as!( - UserLookup, - r#" - SELECT user_id - , username - , primary_user_email_id - , created_at - FROM users - WHERE username = $1 - "#, - username, - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.into())) - } - - #[tracing::instrument( - name = "db.user.add", - skip_all, - fields( - db.statement, - user.username = username, - user.id, - ), - err, - )] - async fn add( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - username: String, - ) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - tracing::Span::current().record("user.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO users (user_id, username, created_at) - VALUES ($1, $2, $3) - "#, - Uuid::from(id), - username, - created_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - Ok(User { - id, - username, - sub: id.to_string(), - primary_user_email_id: None, - }) - } - - #[tracing::instrument( - name = "db.user.exists", - skip_all, - fields( - db.statement, - user.username = username, - ), - err, - )] - async fn exists(&mut self, username: &str) -> Result { - let exists = sqlx::query_scalar!( - r#" - SELECT EXISTS( - SELECT 1 FROM users WHERE username = $1 - ) AS "exists!" - "#, - username - ) - .traced() - .fetch_one(&mut *self.conn) - .await?; - - Ok(exists) - } -} diff --git a/crates/storage/src/user/password.rs b/crates/storage/src/user/password.rs index 56c8a439..2d2d2534 100644 --- a/crates/storage/src/user/password.rs +++ b/crates/storage/src/user/password.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,16 +13,10 @@ // limitations under the License. use async_trait::async_trait; -use chrono::{DateTime, Utc}; use mas_data_model::{Password, User}; use rand::RngCore; -use sqlx::PgConnection; -use ulid::Ulid; -use uuid::Uuid; -use crate::{ - tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, -}; +use crate::Clock; #[async_trait] pub trait UserPasswordRepository: Send + Sync { @@ -39,134 +33,3 @@ pub trait UserPasswordRepository: Send + Sync { upgraded_from: Option<&Password>, ) -> Result; } - -pub struct PgUserPasswordRepository<'c> { - conn: &'c mut PgConnection, -} - -impl<'c> PgUserPasswordRepository<'c> { - pub fn new(conn: &'c mut PgConnection) -> Self { - Self { conn } - } -} - -struct UserPasswordLookup { - user_password_id: Uuid, - hashed_password: String, - version: i32, - upgraded_from_id: Option, - created_at: DateTime, -} - -#[async_trait] -impl<'c> UserPasswordRepository for PgUserPasswordRepository<'c> { - type Error = DatabaseError; - - #[tracing::instrument( - name = "db.user_password.active", - skip_all, - fields( - db.statement, - %user.id, - %user.username, - ), - err, - )] - async fn active(&mut self, user: &User) -> Result, Self::Error> { - let res = sqlx::query_as!( - UserPasswordLookup, - r#" - SELECT up.user_password_id - , up.hashed_password - , up.version - , up.upgraded_from_id - , up.created_at - FROM user_passwords up - WHERE up.user_id = $1 - ORDER BY up.created_at DESC - LIMIT 1 - "#, - Uuid::from(user.id), - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - let id = Ulid::from(res.user_password_id); - - let version = res.version.try_into().map_err(|e| { - DatabaseInconsistencyError::on("user_passwords") - .column("version") - .row(id) - .source(e) - })?; - - let upgraded_from_id = res.upgraded_from_id.map(Ulid::from); - let created_at = res.created_at; - let hashed_password = res.hashed_password; - - Ok(Some(Password { - id, - hashed_password, - version, - upgraded_from_id, - created_at, - })) - } - - #[tracing::instrument( - name = "db.user_password.add", - skip_all, - fields( - db.statement, - %user.id, - %user.username, - user_password.id, - user_password.version = version, - ), - err, - )] - async fn add( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - user: &User, - version: u16, - hashed_password: String, - upgraded_from: Option<&Password>, - ) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - tracing::Span::current().record("user_password.id", tracing::field::display(id)); - - let upgraded_from_id = upgraded_from.map(|p| p.id); - - sqlx::query!( - r#" - INSERT INTO user_passwords - (user_password_id, user_id, hashed_password, version, upgraded_from_id, created_at) - VALUES ($1, $2, $3, $4, $5, $6) - "#, - Uuid::from(id), - Uuid::from(user.id), - hashed_password, - i32::from(version), - upgraded_from_id.map(Uuid::from), - created_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - Ok(Password { - id, - hashed_password, - version, - upgraded_from_id, - created_at, - }) - } -} diff --git a/crates/storage/src/user/session.rs b/crates/storage/src/user/session.rs index 10b96da7..2e55f40c 100644 --- a/crates/storage/src/user/session.rs +++ b/crates/storage/src/user/session.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,18 +13,11 @@ // limitations under the License. use async_trait::async_trait; -use chrono::{DateTime, Utc}; -use mas_data_model::{Authentication, BrowserSession, Password, UpstreamOAuthLink, User}; +use mas_data_model::{BrowserSession, Password, UpstreamOAuthLink, User}; use rand::RngCore; -use sqlx::{PgConnection, QueryBuilder}; use ulid::Ulid; -use uuid::Uuid; -use crate::{ - pagination::{Page, QueryBuilderExt}, - tracing::ExecuteExt, - Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Pagination, -}; +use crate::{pagination::Page, Clock, Pagination}; #[async_trait] pub trait BrowserSessionRepository: Send + Sync { @@ -65,351 +58,3 @@ pub trait BrowserSessionRepository: Send + Sync { upstream_oauth_link: &UpstreamOAuthLink, ) -> Result; } - -pub struct PgBrowserSessionRepository<'c> { - conn: &'c mut PgConnection, -} - -impl<'c> PgBrowserSessionRepository<'c> { - pub fn new(conn: &'c mut PgConnection) -> Self { - Self { conn } - } -} - -#[derive(sqlx::FromRow)] -struct SessionLookup { - user_session_id: Uuid, - user_session_created_at: DateTime, - user_session_finished_at: Option>, - user_id: Uuid, - user_username: String, - user_primary_user_email_id: Option, - last_authentication_id: Option, - last_authd_at: Option>, -} - -impl TryFrom for BrowserSession { - type Error = DatabaseInconsistencyError; - - fn try_from(value: SessionLookup) -> Result { - let id = Ulid::from(value.user_id); - let user = User { - id, - username: value.user_username, - sub: id.to_string(), - primary_user_email_id: value.user_primary_user_email_id.map(Into::into), - }; - - let last_authentication = match (value.last_authentication_id, value.last_authd_at) { - (Some(id), Some(created_at)) => Some(Authentication { - id: id.into(), - created_at, - }), - (None, None) => None, - _ => { - return Err(DatabaseInconsistencyError::on( - "user_session_authentications", - )) - } - }; - - Ok(BrowserSession { - id: value.user_session_id.into(), - user, - created_at: value.user_session_created_at, - finished_at: value.user_session_finished_at, - last_authentication, - }) - } -} - -#[async_trait] -impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { - type Error = DatabaseError; - - #[tracing::instrument( - name = "db.browser_session.lookup", - skip_all, - fields( - db.statement, - user_session.id = %id, - ), - err, - )] - async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { - let res = sqlx::query_as!( - SessionLookup, - r#" - SELECT s.user_session_id - , s.created_at AS "user_session_created_at" - , s.finished_at AS "user_session_finished_at" - , u.user_id - , u.username AS "user_username" - , u.primary_user_email_id AS "user_primary_user_email_id" - , a.user_session_authentication_id AS "last_authentication_id?" - , a.created_at AS "last_authd_at?" - FROM user_sessions s - INNER JOIN users u - USING (user_id) - LEFT JOIN user_session_authentications a - USING (user_session_id) - WHERE s.user_session_id = $1 - ORDER BY a.created_at DESC - LIMIT 1 - "#, - Uuid::from(id), - ) - .traced() - .fetch_one(&mut *self.conn) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.try_into()?)) - } - - #[tracing::instrument( - name = "db.browser_session.add", - skip_all, - fields( - db.statement, - %user.id, - user_session.id, - ), - err, - )] - async fn add( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - user: &User, - ) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - tracing::Span::current().record("user_session.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO user_sessions (user_session_id, user_id, created_at) - VALUES ($1, $2, $3) - "#, - Uuid::from(id), - Uuid::from(user.id), - created_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - let session = BrowserSession { - id, - // XXX - user: user.clone(), - created_at, - finished_at: None, - last_authentication: None, - }; - - Ok(session) - } - - #[tracing::instrument( - name = "db.browser_session.finish", - skip_all, - fields( - db.statement, - %user_session.id, - ), - err, - )] - async fn finish( - &mut self, - clock: &Clock, - mut user_session: BrowserSession, - ) -> Result { - let finished_at = clock.now(); - let res = sqlx::query!( - r#" - UPDATE user_sessions - SET finished_at = $1 - WHERE user_session_id = $2 - "#, - finished_at, - Uuid::from(user_session.id), - ) - .traced() - .execute(&mut *self.conn) - .await?; - - user_session.finished_at = Some(finished_at); - - DatabaseError::ensure_affected_rows(&res, 1)?; - - Ok(user_session) - } - - #[tracing::instrument( - name = "db.browser_session.list_active_paginated", - skip_all, - fields( - db.statement, - %user.id, - ), - err, - )] - async fn list_active_paginated( - &mut self, - user: &User, - pagination: Pagination, - ) -> Result, Self::Error> { - // TODO: ordering of last authentication is wrong - let mut query = QueryBuilder::new( - r#" - SELECT DISTINCT ON (s.user_session_id) - s.user_session_id, - u.user_id, - u.username, - s.created_at, - a.user_session_authentication_id AS "last_authentication_id", - a.created_at AS "last_authd_at", - FROM user_sessions s - INNER JOIN users u - USING (user_id) - LEFT JOIN user_session_authentications a - USING (user_session_id) - "#, - ); - - query - .push(" WHERE s.finished_at IS NULL AND s.user_id = ") - .push_bind(Uuid::from(user.id)) - .generate_pagination("s.user_session_id", pagination); - - let edges: Vec = query - .build_query_as() - .traced() - .fetch_all(&mut *self.conn) - .await?; - - let page = pagination - .process(edges) - .try_map(BrowserSession::try_from)?; - Ok(page) - } - - #[tracing::instrument( - name = "db.browser_session.count_active", - skip_all, - fields( - db.statement, - %user.id, - ), - err, - )] - async fn count_active(&mut self, user: &User) -> Result { - let res = sqlx::query_scalar!( - r#" - SELECT COUNT(*) as "count!" - FROM user_sessions s - WHERE s.user_id = $1 AND s.finished_at IS NULL - "#, - Uuid::from(user.id), - ) - .traced() - .fetch_one(&mut *self.conn) - .await?; - - res.try_into().map_err(DatabaseError::to_invalid_operation) - } - - #[tracing::instrument( - name = "db.browser_session.authenticate_with_password", - skip_all, - fields( - db.statement, - %user_session.id, - %user_password.id, - user_session_authentication.id, - ), - err, - )] - async fn authenticate_with_password( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - mut user_session: BrowserSession, - user_password: &Password, - ) -> Result { - let _user_password = user_password; - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - tracing::Span::current().record( - "user_session_authentication.id", - tracing::field::display(id), - ); - - sqlx::query!( - r#" - INSERT INTO user_session_authentications - (user_session_authentication_id, user_session_id, created_at) - VALUES ($1, $2, $3) - "#, - Uuid::from(id), - Uuid::from(user_session.id), - created_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - user_session.last_authentication = Some(Authentication { id, created_at }); - - Ok(user_session) - } - - #[tracing::instrument( - name = "db.browser_session.authenticate_with_upstream", - skip_all, - fields( - db.statement, - %user_session.id, - %upstream_oauth_link.id, - user_session_authentication.id, - ), - err, - )] - async fn authenticate_with_upstream( - &mut self, - rng: &mut (dyn RngCore + Send), - clock: &Clock, - mut user_session: BrowserSession, - upstream_oauth_link: &UpstreamOAuthLink, - ) -> Result { - let _upstream_oauth_link = upstream_oauth_link; - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), rng); - tracing::Span::current().record( - "user_session_authentication.id", - tracing::field::display(id), - ); - - sqlx::query!( - r#" - INSERT INTO user_session_authentications - (user_session_authentication_id, user_session_id, created_at) - VALUES ($1, $2, $3) - "#, - Uuid::from(id), - Uuid::from(user_session.id), - created_at, - ) - .traced() - .execute(&mut *self.conn) - .await?; - - user_session.last_authentication = Some(Authentication { id, created_at }); - - Ok(user_session) - } -} diff --git a/crates/tasks/Cargo.toml b/crates/tasks/Cargo.toml index b82e16c2..99270a72 100644 --- a/crates/tasks/Cargo.toml +++ b/crates/tasks/Cargo.toml @@ -14,3 +14,4 @@ tracing = "0.1.37" sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres"] } mas-storage = { path = "../storage" } +mas-storage-pg = { path = "../storage-pg" } diff --git a/crates/tasks/src/database.rs b/crates/tasks/src/database.rs index 39a33b8d..66068860 100644 --- a/crates/tasks/src/database.rs +++ b/crates/tasks/src/database.rs @@ -14,7 +14,8 @@ //! Database-related tasks -use mas_storage::{oauth2::OAuth2AccessTokenRepository, Clock, PgRepository, Repository}; +use mas_storage::{oauth2::OAuth2AccessTokenRepository, Clock, Repository}; +use mas_storage_pg::PgRepository; use sqlx::{Pool, Postgres}; use tracing::{debug, error, info};