diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index bda7ee2c..09a70023 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -179,6 +179,7 @@ pub mod compat; pub mod oauth2; pub(crate) mod pagination; pub(crate) mod repository; +pub(crate) mod tracing; pub mod upstream_oauth2; pub mod user; diff --git a/crates/storage/src/tracing.rs b/crates/storage/src/tracing.rs new file mode 100644 index 00000000..60eb284c --- /dev/null +++ b/crates/storage/src/tracing.rs @@ -0,0 +1,29 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +pub trait ExecuteExt<'q, DB> { + /// Records the statement as `db.statement` in the current span + fn traced(self) -> Self; +} + +impl<'q, DB, T> ExecuteExt<'q, DB> for T +where + T: sqlx::Execute<'q, DB>, + DB: sqlx::Database, +{ + fn traced(self) -> Self { + tracing::Span::current().record("db.statement", self.sql()); + self + } +} diff --git a/crates/storage/src/upstream_oauth2/link.rs b/crates/storage/src/upstream_oauth2/link.rs index 100e9833..0d443671 100644 --- a/crates/storage/src/upstream_oauth2/link.rs +++ b/crates/storage/src/upstream_oauth2/link.rs @@ -17,12 +17,12 @@ use chrono::{DateTime, Utc}; use mas_data_model::{UpstreamOAuthLink, UpstreamOAuthProvider, User}; use rand::RngCore; use sqlx::{PgConnection, QueryBuilder}; -use tracing::{info_span, Instrument}; use ulid::Ulid; use uuid::Uuid; use crate::{ pagination::{process_page, Page, QueryBuilderExt}, + tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt, }; @@ -103,8 +103,12 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { type Error = DatabaseError; #[tracing::instrument( + name = "db.upstream_oauth_link.lookup", skip_all, - fields(upstream_oauth_link.id = %id), + fields( + db.statement, + upstream_oauth_link.id = %id, + ), err, )] async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { @@ -122,6 +126,7 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { "#, Uuid::from(id), ) + .traced() .fetch_one(&mut *self.conn) .await .to_option()? @@ -131,8 +136,10 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { } #[tracing::instrument( + name = "db.upstream_oauth_link.find_by_subject", skip_all, fields( + db.statement, upstream_oauth_link.subject = subject, %upstream_oauth_provider.id, %upstream_oauth_provider.issuer, @@ -161,6 +168,7 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { Uuid::from(upstream_oauth_provider.id), subject, ) + .traced() .fetch_one(&mut *self.conn) .await .to_option()? @@ -170,8 +178,10 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { } #[tracing::instrument( + name = "db.upstream_oauth_link.add", skip_all, fields( + db.statement, upstream_oauth_link.id, upstream_oauth_link.subject = subject, %upstream_oauth_provider.id, @@ -206,6 +216,7 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { &subject, created_at, ) + .traced() .execute(&mut *self.conn) .await?; @@ -219,8 +230,10 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { } #[tracing::instrument( + name = "db.upstream_oauth_link.associate_to_user", skip_all, fields( + db.statement, %upstream_oauth_link.id, %upstream_oauth_link.subject, %user.id, @@ -242,6 +255,7 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { Uuid::from(user.id), Uuid::from(upstream_oauth_link.id), ) + .traced() .execute(&mut *self.conn) .await?; @@ -249,8 +263,13 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { } #[tracing::instrument( + name = "db.upstream_oauth_link.list_paginated", skip_all, - fields(%user.id, %user.username), + fields( + db.statement, + %user.id, + %user.username, + ), err )] async fn list_paginated( @@ -278,14 +297,10 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { .push_bind(Uuid::from(user.id)) .generate_pagination("upstream_oauth_link_id", before, after, first, last)?; - let span = info_span!( - "Fetch paginated upstream OAuth 2.0 user links", - db.statement = query.sql() - ); let page: Vec = query .build_query_as() + .traced() .fetch_all(&mut *self.conn) - .instrument(span) .await?; let (has_previous_page, has_next_page, edges) = process_page(page, first, last)?; diff --git a/crates/storage/src/upstream_oauth2/provider.rs b/crates/storage/src/upstream_oauth2/provider.rs index 3d8ba141..a7efb6c8 100644 --- a/crates/storage/src/upstream_oauth2/provider.rs +++ b/crates/storage/src/upstream_oauth2/provider.rs @@ -19,12 +19,12 @@ use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod use oauth2_types::scope::Scope; use rand::RngCore; use sqlx::{PgConnection, QueryBuilder}; -use tracing::{info_span, Instrument}; use ulid::Ulid; use uuid::Uuid; use crate::{ pagination::{process_page, Page, QueryBuilderExt}, + tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, }; @@ -129,8 +129,12 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' type Error = DatabaseError; #[tracing::instrument( + name = "db.upstream_oauth_provider.lookup", skip_all, - fields(upstream_oauth_provider.id = %id), + fields( + db.statement, + upstream_oauth_provider.id = %id, + ), err, )] async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { @@ -151,6 +155,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' "#, Uuid::from(id), ) + .traced() .fetch_one(&mut *self.conn) .await .to_option()?; @@ -164,8 +169,10 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' } #[tracing::instrument( + name = "db.upstream_oauth_provider.add", skip_all, fields( + db.statement, upstream_oauth_provider.id, upstream_oauth_provider.issuer = %issuer, upstream_oauth_provider.client_id = %client_id, @@ -210,6 +217,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' encrypted_client_secret.as_deref(), created_at, ) + .traced() .execute(&mut *self.conn) .await?; @@ -225,6 +233,14 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' }) } + #[tracing::instrument( + name = "db.upstream_oauth_provider.list_paginated", + skip_all, + fields( + db.statement, + ), + err, + )] async fn list_paginated( &mut self, before: Option, @@ -250,14 +266,10 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' query.generate_pagination("upstream_oauth_provider_id", before, after, first, last)?; - let span = info_span!( - "Fetch paginated upstream OAuth 2.0 providers", - db.statement = query.sql() - ); let page: Vec = query .build_query_as() + .traced() .fetch_all(&mut *self.conn) - .instrument(span) .await?; let (has_previous_page, has_next_page, edges) = process_page(page, first, last)?; @@ -269,7 +281,15 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' edges: edges?, }) } - #[tracing::instrument(skip_all, err)] + + #[tracing::instrument( + name = "db.upstream_oauth_provider.all", + skip_all, + fields( + db.statement, + ), + err, + )] async fn all(&mut self) -> Result, Self::Error> { let res = sqlx::query_as!( ProviderLookup, @@ -286,6 +306,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' FROM upstream_oauth_providers "#, ) + .traced() .fetch_all(&mut *self.conn) .await?; diff --git a/crates/storage/src/upstream_oauth2/session.rs b/crates/storage/src/upstream_oauth2/session.rs index f8dffcf3..f13c6ec8 100644 --- a/crates/storage/src/upstream_oauth2/session.rs +++ b/crates/storage/src/upstream_oauth2/session.rs @@ -20,7 +20,7 @@ use sqlx::PgConnection; use ulid::Ulid; use uuid::Uuid; -use crate::{Clock, DatabaseError, LookupResultExt}; +use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt}; #[async_trait] pub trait UpstreamOAuthSessionRepository: Send + Sync { @@ -88,8 +88,12 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> type Error = DatabaseError; #[tracing::instrument( + name = "db.upstream_oauth_authorization_session.lookup", skip_all, - fields(upstream_oauth_provider.id = %id), + fields( + db.statement, + upstream_oauth_provider.id = %id, + ), err, )] async fn lookup( @@ -115,6 +119,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> "#, Uuid::from(id), ) + .traced() .fetch_one(&mut *self.conn) .await .to_option()?; @@ -138,8 +143,10 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> } #[tracing::instrument( + name = "db.upstream_oauth_authorization_session.add", skip_all, fields( + db.statement, %upstream_oauth_provider.id, %upstream_oauth_provider.issuer, %upstream_oauth_provider.client_id, @@ -184,6 +191,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> nonce, created_at, ) + .traced() .execute(&mut *self.conn) .await?; @@ -202,8 +210,10 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> } #[tracing::instrument( + name = "db.upstream_oauth_authorization_session.complete_with_link", skip_all, fields( + db.statement, %upstream_oauth_authorization_session.id, %upstream_oauth_link.id, ), @@ -230,6 +240,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> id_token, Uuid::from(upstream_oauth_authorization_session.id), ) + .traced() .execute(&mut *self.conn) .await?; @@ -242,8 +253,10 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> /// Mark a session as consumed #[tracing::instrument( + name = "db.upstream_oauth_authorization_session.consume", skip_all, fields( + db.statement, %upstream_oauth_authorization_session.id, ), err, @@ -263,6 +276,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> consumed_at, Uuid::from(upstream_oauth_authorization_session.id), ) + .traced() .execute(&mut *self.conn) .await?; diff --git a/crates/storage/src/user/email.rs b/crates/storage/src/user/email.rs index 83784a56..2f748611 100644 --- a/crates/storage/src/user/email.rs +++ b/crates/storage/src/user/email.rs @@ -22,6 +22,7 @@ use uuid::Uuid; use crate::{ pagination::{process_page, Page, QueryBuilderExt}, + tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, }; @@ -152,8 +153,12 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { type Error = DatabaseError; #[tracing::instrument( + name = "db.user_email.lookup", skip_all, - fields(user_email.id = %id), + fields( + db.statement, + user_email.id = %id, + ), err, )] async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { @@ -171,6 +176,7 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { "#, Uuid::from(id), ) + .traced() .fetch_one(&mut *self.conn) .await .to_option()?; @@ -181,8 +187,13 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { } #[tracing::instrument( + name = "db.user_email.find", skip_all, - fields(%user.id, user_email.email = email), + fields( + db.statement, + %user.id, + user_email.email = email, + ), err, )] async fn find(&mut self, user: &User, email: &str) -> Result, Self::Error> { @@ -201,6 +212,7 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { Uuid::from(user.id), email, ) + .traced() .fetch_one(&mut *self.conn) .await .to_option()?; @@ -211,8 +223,12 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { } #[tracing::instrument( + name = "db.user_email.get_primary", skip_all, - fields(%user.id), + fields( + db.statement, + %user.id, + ), err, )] async fn get_primary(&mut self, user: &User) -> Result, Self::Error> { @@ -228,8 +244,12 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { } #[tracing::instrument( + name = "db.user_email.all", skip_all, - fields(%user.id), + fields( + db.statement, + %user.id, + ), err, )] async fn all(&mut self, user: &User) -> Result, Self::Error> { @@ -249,6 +269,7 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { "#, Uuid::from(user.id), ) + .traced() .fetch_all(&mut *self.conn) .await?; @@ -256,8 +277,12 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { } #[tracing::instrument( + name = "db.user_email.list_paginated", skip_all, - fields(%user.id), + fields( + db.statement, + %user.id, + ), err, )] async fn list_paginated( @@ -284,7 +309,11 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { .push_bind(Uuid::from(user.id)) .generate_pagination("ue.user_email_id", before, after, first, last)?; - let edges: Vec = query.build_query_as().fetch_all(&mut *self.conn).await?; + let edges: Vec = query + .build_query_as() + .traced() + .fetch_all(&mut *self.conn) + .await?; let (has_previous_page, has_next_page, edges) = process_page(edges, first, last)?; @@ -298,8 +327,12 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { } #[tracing::instrument( + name = "db.user_email.count", skip_all, - fields(%user.id), + fields( + db.statement, + %user.id, + ), err, )] async fn count(&mut self, user: &User) -> Result { @@ -311,6 +344,7 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { "#, Uuid::from(user.id), ) + .traced() .fetch_one(&mut *self.conn) .await?; @@ -322,8 +356,10 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { } #[tracing::instrument( + name = "db.user_email.add", skip_all, fields( + db.statement, %user.id, user_email.id, user_email.email = email, @@ -351,6 +387,7 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { &email, created_at, ) + .traced() .execute(&mut *self.conn) .await?; @@ -364,8 +401,10 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { } #[tracing::instrument( + name = "db.user_email.remove", skip_all, fields( + db.statement, user.id = %user_email.user_id, %user_email.id, %user_email.email, @@ -380,6 +419,7 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { "#, Uuid::from(user_email.id), ) + .traced() .execute(&mut *self.conn) .await?; @@ -426,8 +466,10 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { } #[tracing::instrument( + name = "db.user_email.add_verification_code", skip_all, fields( + db.statement, %user_email.id, %user_email.email, user_email_verification.id, @@ -460,6 +502,7 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { created_at, expires_at, ) + .traced() .execute(&mut *self.conn) .await?; @@ -475,8 +518,10 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { } #[tracing::instrument( + name = "db.user_email.find_verification_code", skip_all, fields( + db.statement, %user_email.id, user.id = %user_email.user_id, ), @@ -504,6 +549,7 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { code, Uuid::from(user_email.id), ) + .traced() .fetch_one(&mut *self.conn) .await .to_option()?; @@ -514,8 +560,10 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { } #[tracing::instrument( + name = "db.user_email.consume_verification_code", skip_all, fields( + db.statement, %user_email_verification.id, user_email.id = %user_email_verification.user_email_id, ), @@ -544,6 +592,7 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { Uuid::from(user_email_verification.id), consumed_at ) + .traced() .execute(&mut *self.conn) .await?; diff --git a/crates/storage/src/user/mod.rs b/crates/storage/src/user/mod.rs index 54b3689c..50f71752 100644 --- a/crates/storage/src/user/mod.rs +++ b/crates/storage/src/user/mod.rs @@ -23,6 +23,7 @@ use uuid::Uuid; use crate::{ pagination::{process_page, QueryBuilderExt}, + tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, }; @@ -88,8 +89,12 @@ impl<'c> UserRepository for PgUserRepository<'c> { type Error = DatabaseError; #[tracing::instrument( + name = "db.user.lookup", skip_all, - fields(user.id = %id), + fields( + db.statement, + user.id = %id, + ), err, )] async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { @@ -105,6 +110,7 @@ impl<'c> UserRepository for PgUserRepository<'c> { "#, Uuid::from(id), ) + .traced() .fetch_one(&mut *self.conn) .await .to_option()?; @@ -115,8 +121,12 @@ impl<'c> UserRepository for PgUserRepository<'c> { } #[tracing::instrument( + name = "db.user.find_by_username", skip_all, - fields(user.username = username), + fields( + db.statement, + user.username = username, + ), err, )] async fn find_by_username(&mut self, username: &str) -> Result, Self::Error> { @@ -132,6 +142,7 @@ impl<'c> UserRepository for PgUserRepository<'c> { "#, username, ) + .traced() .fetch_one(&mut *self.conn) .await .to_option()?; @@ -142,8 +153,10 @@ impl<'c> UserRepository for PgUserRepository<'c> { } #[tracing::instrument( + name = "db.user.add", skip_all, fields( + db.statement, user.username = username, user.id, ), @@ -168,6 +181,7 @@ impl<'c> UserRepository for PgUserRepository<'c> { username, created_at, ) + .traced() .execute(&mut *self.conn) .await?; @@ -180,8 +194,12 @@ impl<'c> UserRepository for PgUserRepository<'c> { } #[tracing::instrument( + name = "db.user.exists", skip_all, - fields(user.username = username), + fields( + db.statement, + user.username = username, + ), err, )] async fn exists(&mut self, username: &str) -> Result { @@ -193,6 +211,7 @@ impl<'c> UserRepository for PgUserRepository<'c> { "#, username ) + .traced() .fetch_one(&mut *self.conn) .await?;