1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-28 11:02:02 +03:00

storage: repository pattern for upstream oauth2 providers

This commit is contained in:
Quentin Gliech
2022-12-30 10:55:37 +01:00
parent 5969b574e2
commit 0faf08fce2
11 changed files with 380 additions and 309 deletions

View File

@ -19,10 +19,11 @@ use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod
use mas_router::UrlBuilder; use mas_router::UrlBuilder;
use mas_storage::{ use mas_storage::{
oauth2::client::{insert_client_from_config, lookup_client, truncate_clients}, oauth2::client::{insert_client_from_config, lookup_client, truncate_clients},
upstream_oauth2::UpstreamOAuthProviderRepository,
user::{ user::{
add_user_password, lookup_user_by_username, lookup_user_email, mark_user_email_as_verified, add_user_password, lookup_user_by_username, lookup_user_email, mark_user_email_as_verified,
}, },
Clock, Clock, Repository,
}; };
use oauth2_types::scope::Scope; use oauth2_types::scope::Scope;
use rand::SeedableRng; use rand::SeedableRng;
@ -329,18 +330,19 @@ impl Options {
.map(|client_secret| encrypter.encryt_to_string(client_secret.as_bytes())) .map(|client_secret| encrypter.encryt_to_string(client_secret.as_bytes()))
.transpose()?; .transpose()?;
let provider = mas_storage::upstream_oauth2::add_provider( let provider = conn
&mut conn, .upstream_oauth_provider()
&mut rng, .add(
&clock, &mut rng,
issuer.clone(), &clock,
scope.clone(), issuer.clone(),
token_endpoint_auth_method, scope.clone(),
token_endpoint_signing_alg, token_endpoint_auth_method,
client_id.clone(), token_endpoint_signing_alg,
encrypted_client_secret, client_id.clone(),
) encrypted_client_secret,
.await?; )
.await?;
let redirect_uri = url_builder.upstream_oauth_callback(provider.id); let redirect_uri = url_builder.upstream_oauth_callback(provider.id);
let auth_uri = url_builder.upstream_oauth_authorize(provider.id); let auth_uri = url_builder.upstream_oauth_authorize(provider.id);

View File

@ -30,7 +30,9 @@ use async_graphql::{
connection::{query, Connection, Edge, OpaqueCursor}, connection::{query, Connection, Edge, OpaqueCursor},
Context, Description, EmptyMutation, EmptySubscription, ID, Context, Description, EmptyMutation, EmptySubscription, ID,
}; };
use mas_storage::{Repository, UpstreamOAuthLinkRepository}; use mas_storage::{
upstream_oauth2::UpstreamOAuthProviderRepository, Repository, UpstreamOAuthLinkRepository,
};
use model::CreationEvent; use model::CreationEvent;
use sqlx::PgPool; use sqlx::PgPool;
@ -190,7 +192,7 @@ impl RootQuery {
let database = ctx.data::<PgPool>()?; let database = ctx.data::<PgPool>()?;
let mut conn = database.acquire().await?; let mut conn = database.acquire().await?;
let provider = mas_storage::upstream_oauth2::lookup_provider(&mut conn, id).await?; let provider = conn.upstream_oauth_provider().lookup(id).await?;
Ok(provider.map(UpstreamOAuth2Provider::new)) Ok(provider.map(UpstreamOAuth2Provider::new))
} }
@ -227,14 +229,13 @@ impl RootQuery {
}) })
.transpose()?; .transpose()?;
let (has_previous_page, has_next_page, edges) = let page = conn
mas_storage::upstream_oauth2::get_paginated_providers( .upstream_oauth_provider()
&mut conn, before_id, after_id, first, last, .list_paginated(before_id, after_id, first, last)
)
.await?; .await?;
let mut connection = Connection::new(has_previous_page, has_next_page); let mut connection = Connection::new(page.has_previous_page, page.has_next_page);
connection.edges.extend(edges.into_iter().map(|p| { connection.edges.extend(page.edges.into_iter().map(|p| {
Edge::new( Edge::new(
OpaqueCursor(NodeCursor(NodeType::UpstreamOAuth2Provider, p.id)), OpaqueCursor(NodeCursor(NodeType::UpstreamOAuth2Provider, p.id)),
UpstreamOAuth2Provider::new(p), UpstreamOAuth2Provider::new(p),

View File

@ -15,6 +15,7 @@
use anyhow::Context as _; use anyhow::Context as _;
use async_graphql::{Context, Object, ID}; use async_graphql::{Context, Object, ID};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, Repository};
use sqlx::PgPool; use sqlx::PgPool;
use super::{NodeType, User}; use super::{NodeType, User};
@ -101,7 +102,8 @@ impl UpstreamOAuth2Link {
// Fetch on-the-fly // Fetch on-the-fly
let database = ctx.data::<PgPool>()?; let database = ctx.data::<PgPool>()?;
let mut conn = database.acquire().await?; let mut conn = database.acquire().await?;
mas_storage::upstream_oauth2::lookup_provider(&mut conn, self.link.provider_id) conn.upstream_oauth_provider()
.lookup(self.link.provider_id)
.await? .await?
.context("Upstream OAuth 2.0 provider not found")? .context("Upstream OAuth 2.0 provider not found")?
}; };

View File

@ -22,7 +22,7 @@ use mas_axum_utils::http_client_factory::HttpClientFactory;
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_oidc_client::requests::authorization_code::AuthorizationRequestData; use mas_oidc_client::requests::authorization_code::AuthorizationRequestData;
use mas_router::UrlBuilder; use mas_router::UrlBuilder;
use mas_storage::upstream_oauth2::lookup_provider; use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, Repository};
use sqlx::PgPool; use sqlx::PgPool;
use thiserror::Error; use thiserror::Error;
use ulid::Ulid; use ulid::Ulid;
@ -66,7 +66,9 @@ pub(crate) async fn get(
let mut txn = pool.begin().await?; let mut txn = pool.begin().await?;
let provider = lookup_provider(&mut txn, provider_id) let provider = txn
.upstream_oauth_provider()
.lookup(provider_id)
.await? .await?
.ok_or(RouteError::ProviderNotFound)?; .ok_or(RouteError::ProviderNotFound)?;

View File

@ -24,11 +24,12 @@ use mas_axum_utils::{
use mas_data_model::BrowserSession; use mas_data_model::BrowserSession;
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_storage::{ use mas_storage::{
upstream_oauth2::UpstreamOAuthProviderRepository,
user::{ user::{
add_user_password, authenticate_session_with_password, lookup_user_by_username, add_user_password, authenticate_session_with_password, lookup_user_by_username,
lookup_user_password, start_session, lookup_user_password, start_session,
}, },
Clock, Clock, Repository,
}; };
use mas_templates::{ use mas_templates::{
FieldError, FormError, LoginContext, LoginFormField, TemplateContext, Templates, ToFormState, FieldError, FormError, LoginContext, LoginFormField, TemplateContext, Templates, ToFormState,
@ -69,7 +70,7 @@ pub(crate) async fn get(
let reply = query.go_next(); let reply = query.go_next();
Ok((cookie_jar, reply).into_response()) Ok((cookie_jar, reply).into_response())
} else { } else {
let providers = mas_storage::upstream_oauth2::get_providers(&mut conn).await?; let providers = conn.upstream_oauth_provider().all().await?;
let content = render( let content = render(
LoginContext::default().with_upstrem_providers(providers), LoginContext::default().with_upstrem_providers(providers),
query, query,
@ -114,7 +115,7 @@ pub(crate) async fn post(
}; };
if !state.is_valid() { if !state.is_valid() {
let providers = mas_storage::upstream_oauth2::get_providers(&mut conn).await?; let providers = conn.upstream_oauth_provider().all().await?;
let content = render( let content = render(
LoginContext::default() LoginContext::default()
.with_form_state(state) .with_form_state(state)

View File

@ -15,8 +15,8 @@
use anyhow::Context; use anyhow::Context;
use mas_router::{PostAuthAction, Route}; use mas_router::{PostAuthAction, Route};
use mas_storage::{ use mas_storage::{
compat::get_compat_sso_login_by_id, oauth2::authorization_grant::get_grant_by_id, Repository, compat::get_compat_sso_login_by_id, oauth2::authorization_grant::get_grant_by_id,
UpstreamOAuthLinkRepository, upstream_oauth2::UpstreamOAuthProviderRepository, Repository, UpstreamOAuthLinkRepository,
}; };
use mas_templates::{PostAuthContext, PostAuthContextInner}; use mas_templates::{PostAuthContext, PostAuthContextInner};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -70,10 +70,11 @@ impl OptionalPostAuthAction {
.await? .await?
.context("Failed to load upstream OAuth 2.0 link")?; .context("Failed to load upstream OAuth 2.0 link")?;
let provider = let provider = conn
mas_storage::upstream_oauth2::lookup_provider(&mut *conn, link.provider_id) .upstream_oauth_provider()
.await? .lookup(link.provider_id)
.context("Failed to load upstream OAuth 2.0 provider")?; .await?
.context("Failed to load upstream OAuth 2.0 provider")?;
let provider = Box::new(provider); let provider = Box::new(provider);
let link = Box::new(link); let link = Box::new(link);

View File

@ -116,68 +116,6 @@
}, },
"query": "\n SELECT\n c.oauth2_client_id,\n c.encrypted_client_secret,\n ARRAY(\n SELECT redirect_uri\n FROM oauth2_client_redirect_uris r\n WHERE r.oauth2_client_id = c.oauth2_client_id\n ) AS \"redirect_uris!\",\n c.grant_type_authorization_code,\n c.grant_type_refresh_token,\n c.client_name,\n c.logo_uri,\n c.client_uri,\n c.policy_uri,\n c.tos_uri,\n c.jwks_uri,\n c.jwks,\n c.id_token_signed_response_alg,\n c.userinfo_signed_response_alg,\n c.token_endpoint_auth_method,\n c.token_endpoint_auth_signing_alg,\n c.initiate_login_uri\n FROM oauth2_clients c\n\n WHERE c.oauth2_client_id = $1\n " "query": "\n SELECT\n c.oauth2_client_id,\n c.encrypted_client_secret,\n ARRAY(\n SELECT redirect_uri\n FROM oauth2_client_redirect_uris r\n WHERE r.oauth2_client_id = c.oauth2_client_id\n ) AS \"redirect_uris!\",\n c.grant_type_authorization_code,\n c.grant_type_refresh_token,\n c.client_name,\n c.logo_uri,\n c.client_uri,\n c.policy_uri,\n c.tos_uri,\n c.jwks_uri,\n c.jwks,\n c.id_token_signed_response_alg,\n c.userinfo_signed_response_alg,\n c.token_endpoint_auth_method,\n c.token_endpoint_auth_signing_alg,\n c.initiate_login_uri\n FROM oauth2_clients c\n\n WHERE c.oauth2_client_id = $1\n "
}, },
"0af182315b36766eca8e232280986bade0202d1b1d64160a99cd14eadcbfc25b": {
"describe": {
"columns": [
{
"name": "upstream_oauth_provider_id",
"ordinal": 0,
"type_info": "Uuid"
},
{
"name": "issuer",
"ordinal": 1,
"type_info": "Text"
},
{
"name": "scope",
"ordinal": 2,
"type_info": "Text"
},
{
"name": "client_id",
"ordinal": 3,
"type_info": "Text"
},
{
"name": "encrypted_client_secret",
"ordinal": 4,
"type_info": "Text"
},
{
"name": "token_endpoint_signing_alg",
"ordinal": 5,
"type_info": "Text"
},
{
"name": "token_endpoint_auth_method",
"ordinal": 6,
"type_info": "Text"
},
{
"name": "created_at",
"ordinal": 7,
"type_info": "Timestamptz"
}
],
"nullable": [
false,
false,
false,
false,
true,
true,
false,
false
],
"parameters": {
"Left": [
"Uuid"
]
}
},
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at\n FROM upstream_oauth_providers\n WHERE upstream_oauth_provider_id = $1\n "
},
"0b49cde0b7b79f79ec261502ab89bcffa81f9f5ed2f922a41b1718274b9e3073": { "0b49cde0b7b79f79ec261502ab89bcffa81f9f5ed2f922a41b1718274b9e3073": {
"describe": { "describe": {
"columns": [ "columns": [
@ -241,6 +179,66 @@
}, },
"query": "\n UPDATE user_emails\n SET confirmed_at = $2\n WHERE user_email_id = $1\n " "query": "\n UPDATE user_emails\n SET confirmed_at = $2\n WHERE user_email_id = $1\n "
}, },
"154e2e4488ff87e09163698750b56a43127cee4e1392785416a586d40a4d9b21": {
"describe": {
"columns": [
{
"name": "upstream_oauth_provider_id",
"ordinal": 0,
"type_info": "Uuid"
},
{
"name": "issuer",
"ordinal": 1,
"type_info": "Text"
},
{
"name": "scope",
"ordinal": 2,
"type_info": "Text"
},
{
"name": "client_id",
"ordinal": 3,
"type_info": "Text"
},
{
"name": "encrypted_client_secret",
"ordinal": 4,
"type_info": "Text"
},
{
"name": "token_endpoint_signing_alg",
"ordinal": 5,
"type_info": "Text"
},
{
"name": "token_endpoint_auth_method",
"ordinal": 6,
"type_info": "Text"
},
{
"name": "created_at",
"ordinal": 7,
"type_info": "Timestamptz"
}
],
"nullable": [
false,
false,
false,
false,
true,
true,
false,
false
],
"parameters": {
"Left": []
}
},
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at\n FROM upstream_oauth_providers\n "
},
"1eb6d13e75d8f526c2785749a020731c18012f03e07995213acd38ab560ce497": { "1eb6d13e75d8f526c2785749a020731c18012f03e07995213acd38ab560ce497": {
"describe": { "describe": {
"columns": [], "columns": [],
@ -2089,6 +2087,68 @@
}, },
"query": "\n SELECT COUNT(*)\n FROM user_emails ue\n WHERE ue.user_id = $1\n " "query": "\n SELECT COUNT(*)\n FROM user_emails ue\n WHERE ue.user_id = $1\n "
}, },
"8f7a9fb1f24c24f8dbc3c193df2a742c9ac730ab958587b67297de2d4b843863": {
"describe": {
"columns": [
{
"name": "upstream_oauth_provider_id",
"ordinal": 0,
"type_info": "Uuid"
},
{
"name": "issuer",
"ordinal": 1,
"type_info": "Text"
},
{
"name": "scope",
"ordinal": 2,
"type_info": "Text"
},
{
"name": "client_id",
"ordinal": 3,
"type_info": "Text"
},
{
"name": "encrypted_client_secret",
"ordinal": 4,
"type_info": "Text"
},
{
"name": "token_endpoint_signing_alg",
"ordinal": 5,
"type_info": "Text"
},
{
"name": "token_endpoint_auth_method",
"ordinal": 6,
"type_info": "Text"
},
{
"name": "created_at",
"ordinal": 7,
"type_info": "Timestamptz"
}
],
"nullable": [
false,
false,
false,
false,
true,
true,
false,
false
],
"parameters": {
"Left": [
"Uuid"
]
}
},
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at\n FROM upstream_oauth_providers\n WHERE upstream_oauth_provider_id = $1\n "
},
"99f5f9eb0adc5ec120ed8194cbf6a8545155bef09e6d94d92fb67fd1b14d4f28": { "99f5f9eb0adc5ec120ed8194cbf6a8545155bef09e6d94d92fb67fd1b14d4f28": {
"describe": { "describe": {
"columns": [], "columns": [],
@ -2586,66 +2646,6 @@
}, },
"query": "\n INSERT INTO oauth2_clients\n (oauth2_client_id,\n encrypted_client_secret,\n grant_type_authorization_code,\n grant_type_refresh_token,\n client_name,\n logo_uri,\n client_uri,\n policy_uri,\n tos_uri,\n jwks_uri,\n jwks,\n id_token_signed_response_alg,\n userinfo_signed_response_alg,\n token_endpoint_auth_method,\n token_endpoint_auth_signing_alg,\n initiate_login_uri)\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)\n " "query": "\n INSERT INTO oauth2_clients\n (oauth2_client_id,\n encrypted_client_secret,\n grant_type_authorization_code,\n grant_type_refresh_token,\n client_name,\n logo_uri,\n client_uri,\n policy_uri,\n tos_uri,\n jwks_uri,\n jwks,\n id_token_signed_response_alg,\n userinfo_signed_response_alg,\n token_endpoint_auth_method,\n token_endpoint_auth_signing_alg,\n initiate_login_uri)\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)\n "
}, },
"cf00e0ad529bcb5c0640adcfe0880a3560d9739f355b90ca3ba88dd1eaf26565": {
"describe": {
"columns": [
{
"name": "upstream_oauth_provider_id",
"ordinal": 0,
"type_info": "Uuid"
},
{
"name": "issuer",
"ordinal": 1,
"type_info": "Text"
},
{
"name": "scope",
"ordinal": 2,
"type_info": "Text"
},
{
"name": "client_id",
"ordinal": 3,
"type_info": "Text"
},
{
"name": "encrypted_client_secret",
"ordinal": 4,
"type_info": "Text"
},
{
"name": "token_endpoint_signing_alg",
"ordinal": 5,
"type_info": "Text"
},
{
"name": "token_endpoint_auth_method",
"ordinal": 6,
"type_info": "Text"
},
{
"name": "created_at",
"ordinal": 7,
"type_info": "Timestamptz"
}
],
"nullable": [
false,
false,
false,
false,
true,
true,
false,
false
],
"parameters": {
"Left": []
}
},
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at\n FROM upstream_oauth_providers\n "
},
"d1738c27339b81f0844da4bd9b040b9b07a91aa4d9b199b98f24c9cee5709b2b": { "d1738c27339b81f0844da4bd9b040b9b07a91aa4d9b199b98f24c9cee5709b2b": {
"describe": { "describe": {
"columns": [], "columns": [],

View File

@ -14,28 +14,43 @@
use sqlx::{PgConnection, Postgres, Transaction}; use sqlx::{PgConnection, Postgres, Transaction};
use crate::upstream_oauth2::PgUpstreamOAuthLinkRepository; use crate::upstream_oauth2::{PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository};
pub trait Repository { pub trait Repository {
type UpstreamOAuthLinkRepository<'c> type UpstreamOAuthLinkRepository<'c>
where where
Self: 'c; Self: 'c;
type UpstreamOAuthProviderRepository<'c>
where
Self: 'c;
fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_>; fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_>;
fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_>;
} }
impl Repository for PgConnection { impl Repository for PgConnection {
type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c; type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c;
type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c;
fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> {
PgUpstreamOAuthLinkRepository::new(self) PgUpstreamOAuthLinkRepository::new(self)
} }
fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> {
PgUpstreamOAuthProviderRepository::new(self)
}
} }
impl<'t> Repository for Transaction<'t, Postgres> { impl<'t> Repository for Transaction<'t, Postgres> {
type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c; type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c;
type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c;
fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> {
PgUpstreamOAuthLinkRepository::new(self) PgUpstreamOAuthLinkRepository::new(self)
} }
fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> {
PgUpstreamOAuthProviderRepository::new(self)
}
} }

View File

@ -56,7 +56,7 @@ pub trait UpstreamOAuthLinkRepository: Send + Sync {
user: &User, user: &User,
) -> Result<(), Self::Error>; ) -> Result<(), Self::Error>;
/// Get a paginated list of upstream OAuth links /// Get a paginated list of upstream OAuth links on a user
async fn list_paginated( async fn list_paginated(
&mut self, &mut self,
user: &User, user: &User,

View File

@ -18,7 +18,7 @@ mod session;
pub use self::{ pub use self::{
link::{PgUpstreamOAuthLinkRepository, UpstreamOAuthLinkRepository}, link::{PgUpstreamOAuthLinkRepository, UpstreamOAuthLinkRepository},
provider::{add_provider, get_paginated_providers, get_providers, lookup_provider}, provider::{PgUpstreamOAuthProviderRepository, UpstreamOAuthProviderRepository},
session::{ session::{
add_session, complete_session, consume_session, lookup_session, lookup_session_on_link, add_session, complete_session, consume_session, lookup_session, lookup_session_on_link,
}, },

View File

@ -12,21 +12,66 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use mas_data_model::UpstreamOAuthProvider; use mas_data_model::UpstreamOAuthProvider;
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use oauth2_types::scope::Scope; use oauth2_types::scope::Scope;
use rand::Rng; use rand::RngCore;
use sqlx::{PgExecutor, QueryBuilder}; use sqlx::{PgConnection, QueryBuilder};
use tracing::{info_span, Instrument}; use tracing::{info_span, Instrument};
use ulid::Ulid; use ulid::Ulid;
use uuid::Uuid; use uuid::Uuid;
use crate::{ use crate::{
pagination::{process_page, QueryBuilderExt}, pagination::{process_page, Page, QueryBuilderExt},
Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
}; };
#[async_trait]
pub trait UpstreamOAuthProviderRepository: Send + Sync {
type Error;
/// Lookup an upstream OAuth provider by its ID
async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error>;
/// Add a new upstream OAuth provider
#[allow(clippy::too_many_arguments)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
issuer: String,
scope: Scope,
token_endpoint_auth_method: OAuthClientAuthenticationMethod,
token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
client_id: String,
encrypted_client_secret: Option<String>,
) -> Result<UpstreamOAuthProvider, Self::Error>;
/// Get a paginated list of upstream OAuth providers
async fn list_paginated(
&mut self,
before: Option<Ulid>,
after: Option<Ulid>,
first: Option<usize>,
last: Option<usize>,
) -> Result<Page<UpstreamOAuthProvider>, Self::Error>;
/// Get all upstream OAuth providers
async fn all(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
}
pub struct PgUpstreamOAuthProviderRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgUpstreamOAuthProviderRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
#[derive(sqlx::FromRow)] #[derive(sqlx::FromRow)]
struct ProviderLookup { struct ProviderLookup {
upstream_oauth_provider_id: Uuid, upstream_oauth_provider_id: Uuid,
@ -79,71 +124,72 @@ impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
} }
} }
#[tracing::instrument( #[async_trait]
skip_all, impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'c> {
fields(upstream_oauth_provider.id = %id), type Error = DatabaseError;
err,
)]
pub async fn lookup_provider(
executor: impl PgExecutor<'_>,
id: Ulid,
) -> Result<Option<UpstreamOAuthProvider>, DatabaseError> {
let res = sqlx::query_as!(
ProviderLookup,
r#"
SELECT
upstream_oauth_provider_id,
issuer,
scope,
client_id,
encrypted_client_secret,
token_endpoint_signing_alg,
token_endpoint_auth_method,
created_at
FROM upstream_oauth_providers
WHERE upstream_oauth_provider_id = $1
"#,
Uuid::from(id),
)
.fetch_one(executor)
.await
.to_option()?;
let res = res #[tracing::instrument(
.map(UpstreamOAuthProvider::try_from) skip_all,
.transpose() fields(upstream_oauth_provider.id = %id),
.map_err(DatabaseError::from)?; err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error> {
let res = sqlx::query_as!(
ProviderLookup,
r#"
SELECT
upstream_oauth_provider_id,
issuer,
scope,
client_id,
encrypted_client_secret,
token_endpoint_signing_alg,
token_endpoint_auth_method,
created_at
FROM upstream_oauth_providers
WHERE upstream_oauth_provider_id = $1
"#,
Uuid::from(id),
)
.fetch_one(&mut *self.conn)
.await
.to_option()?;
Ok(res) let res = res
} .map(UpstreamOAuthProvider::try_from)
.transpose()
.map_err(DatabaseError::from)?;
#[tracing::instrument( Ok(res)
skip_all, }
fields(
upstream_oauth_provider.id,
upstream_oauth_provider.issuer = %issuer,
upstream_oauth_provider.client_id = %client_id,
),
err,
)]
#[allow(clippy::too_many_arguments)]
pub async fn add_provider(
executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
issuer: String,
scope: Scope,
token_endpoint_auth_method: OAuthClientAuthenticationMethod,
token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
client_id: String,
encrypted_client_secret: Option<String>,
) -> Result<UpstreamOAuthProvider, sqlx::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("upstream_oauth_provider.id", tracing::field::display(id));
sqlx::query!( #[tracing::instrument(
r#" skip_all,
fields(
upstream_oauth_provider.id,
upstream_oauth_provider.issuer = %issuer,
upstream_oauth_provider.client_id = %client_id,
),
err,
)]
#[allow(clippy::too_many_arguments)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
issuer: String,
scope: Scope,
token_endpoint_auth_method: OAuthClientAuthenticationMethod,
token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
client_id: String,
encrypted_client_secret: Option<String>,
) -> Result<UpstreamOAuthProvider, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("upstream_oauth_provider.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO upstream_oauth_providers ( INSERT INTO upstream_oauth_providers (
upstream_oauth_provider_id, upstream_oauth_provider_id,
issuer, issuer,
@ -155,94 +201,95 @@ pub async fn add_provider(
created_at created_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
"#, "#,
Uuid::from(id), Uuid::from(id),
&issuer, &issuer,
scope.to_string(), scope.to_string(),
token_endpoint_auth_method.to_string(), token_endpoint_auth_method.to_string(),
token_endpoint_signing_alg.as_ref().map(ToString::to_string), token_endpoint_signing_alg.as_ref().map(ToString::to_string),
&client_id, &client_id,
encrypted_client_secret.as_deref(), encrypted_client_secret.as_deref(),
created_at, created_at,
) )
.execute(executor) .execute(&mut *self.conn)
.await?;
Ok(UpstreamOAuthProvider {
id,
issuer,
scope,
client_id,
encrypted_client_secret,
token_endpoint_signing_alg,
token_endpoint_auth_method,
created_at,
})
}
#[tracing::instrument(skip_all, err)]
pub async fn get_paginated_providers(
executor: impl PgExecutor<'_>,
before: Option<Ulid>,
after: Option<Ulid>,
first: Option<usize>,
last: Option<usize>,
) -> Result<(bool, bool, Vec<UpstreamOAuthProvider>), DatabaseError> {
let mut query = QueryBuilder::new(
r#"
SELECT
upstream_oauth_provider_id,
issuer,
scope,
client_id,
encrypted_client_secret,
token_endpoint_signing_alg,
token_endpoint_auth_method,
created_at
FROM upstream_oauth_providers
WHERE 1 = 1
"#,
);
query.generate_pagination("upstream_oauth_provider_id", before, after, first, last)?;
let span = info_span!(
"Fetch paginated upstream OAuth 2.0 providers",
db.statement = query.sql()
);
let page: Vec<ProviderLookup> = query
.build_query_as()
.fetch_all(executor)
.instrument(span)
.await?; .await?;
let (has_previous_page, has_next_page, page) = process_page(page, first, last)?; Ok(UpstreamOAuthProvider {
id,
issuer,
scope,
client_id,
encrypted_client_secret,
token_endpoint_signing_alg,
token_endpoint_auth_method,
created_at,
})
}
let page: Result<Vec<_>, _> = page.into_iter().map(TryInto::try_into).collect(); async fn list_paginated(
Ok((has_previous_page, has_next_page, page?)) &mut self,
} before: Option<Ulid>,
after: Option<Ulid>,
#[tracing::instrument(skip_all, err)] first: Option<usize>,
pub async fn get_providers( last: Option<usize>,
executor: impl PgExecutor<'_>, ) -> Result<Page<UpstreamOAuthProvider>, Self::Error> {
) -> Result<Vec<UpstreamOAuthProvider>, DatabaseError> { let mut query = QueryBuilder::new(
let res = sqlx::query_as!( r#"
ProviderLookup, SELECT
r#" upstream_oauth_provider_id,
SELECT issuer,
upstream_oauth_provider_id, scope,
issuer, client_id,
scope, encrypted_client_secret,
client_id, token_endpoint_signing_alg,
encrypted_client_secret, token_endpoint_auth_method,
token_endpoint_signing_alg, created_at
token_endpoint_auth_method, FROM upstream_oauth_providers
created_at WHERE 1 = 1
FROM upstream_oauth_providers "#,
"#, );
)
.fetch_all(executor) query.generate_pagination("upstream_oauth_provider_id", before, after, first, last)?;
.await?;
let span = info_span!(
let res: Result<Vec<_>, _> = res.into_iter().map(TryInto::try_into).collect(); "Fetch paginated upstream OAuth 2.0 providers",
Ok(res?) db.statement = query.sql()
);
let page: Vec<ProviderLookup> = query
.build_query_as()
.fetch_all(&mut *self.conn)
.instrument(span)
.await?;
let (has_previous_page, has_next_page, edges) = process_page(page, first, last)?;
let edges: Result<Vec<_>, _> = edges.into_iter().map(TryInto::try_into).collect();
Ok(Page {
has_next_page,
has_previous_page,
edges: edges?,
})
}
#[tracing::instrument(skip_all, err)]
async fn all(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error> {
let res = sqlx::query_as!(
ProviderLookup,
r#"
SELECT
upstream_oauth_provider_id,
issuer,
scope,
client_id,
encrypted_client_secret,
token_endpoint_signing_alg,
token_endpoint_auth_method,
created_at
FROM upstream_oauth_providers
"#,
)
.fetch_all(&mut *self.conn)
.await?;
let res: Result<Vec<_>, _> = res.into_iter().map(TryInto::try_into).collect();
Ok(res?)
}
} }