1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

storage: repository pattern for upstream oauth2 providers

This commit is contained in:
Quentin Gliech
2022-12-30 10:55:37 +01:00
parent 5969b574e2
commit 0faf08fce2
11 changed files with 380 additions and 309 deletions

View File

@ -12,21 +12,66 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::UpstreamOAuthProvider;
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use oauth2_types::scope::Scope;
use rand::Rng;
use sqlx::{PgExecutor, QueryBuilder};
use rand::RngCore;
use sqlx::{PgConnection, QueryBuilder};
use tracing::{info_span, Instrument};
use ulid::Ulid;
use uuid::Uuid;
use crate::{
pagination::{process_page, QueryBuilderExt},
pagination::{process_page, Page, QueryBuilderExt},
Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
};
#[async_trait]
pub trait UpstreamOAuthProviderRepository: Send + Sync {
type Error;
/// Lookup an upstream OAuth provider by its ID
async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error>;
/// Add a new upstream OAuth provider
#[allow(clippy::too_many_arguments)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
issuer: String,
scope: Scope,
token_endpoint_auth_method: OAuthClientAuthenticationMethod,
token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
client_id: String,
encrypted_client_secret: Option<String>,
) -> Result<UpstreamOAuthProvider, Self::Error>;
/// Get a paginated list of upstream OAuth providers
async fn list_paginated(
&mut self,
before: Option<Ulid>,
after: Option<Ulid>,
first: Option<usize>,
last: Option<usize>,
) -> Result<Page<UpstreamOAuthProvider>, Self::Error>;
/// Get all upstream OAuth providers
async fn all(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
}
pub struct PgUpstreamOAuthProviderRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgUpstreamOAuthProviderRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
#[derive(sqlx::FromRow)]
struct ProviderLookup {
upstream_oauth_provider_id: Uuid,
@ -79,71 +124,72 @@ impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
}
}
#[tracing::instrument(
skip_all,
fields(upstream_oauth_provider.id = %id),
err,
)]
pub async fn lookup_provider(
executor: impl PgExecutor<'_>,
id: Ulid,
) -> Result<Option<UpstreamOAuthProvider>, DatabaseError> {
let res = sqlx::query_as!(
ProviderLookup,
r#"
SELECT
upstream_oauth_provider_id,
issuer,
scope,
client_id,
encrypted_client_secret,
token_endpoint_signing_alg,
token_endpoint_auth_method,
created_at
FROM upstream_oauth_providers
WHERE upstream_oauth_provider_id = $1
"#,
Uuid::from(id),
)
.fetch_one(executor)
.await
.to_option()?;
#[async_trait]
impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'c> {
type Error = DatabaseError;
let res = res
.map(UpstreamOAuthProvider::try_from)
.transpose()
.map_err(DatabaseError::from)?;
#[tracing::instrument(
skip_all,
fields(upstream_oauth_provider.id = %id),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error> {
let res = sqlx::query_as!(
ProviderLookup,
r#"
SELECT
upstream_oauth_provider_id,
issuer,
scope,
client_id,
encrypted_client_secret,
token_endpoint_signing_alg,
token_endpoint_auth_method,
created_at
FROM upstream_oauth_providers
WHERE upstream_oauth_provider_id = $1
"#,
Uuid::from(id),
)
.fetch_one(&mut *self.conn)
.await
.to_option()?;
Ok(res)
}
let res = res
.map(UpstreamOAuthProvider::try_from)
.transpose()
.map_err(DatabaseError::from)?;
#[tracing::instrument(
skip_all,
fields(
upstream_oauth_provider.id,
upstream_oauth_provider.issuer = %issuer,
upstream_oauth_provider.client_id = %client_id,
),
err,
)]
#[allow(clippy::too_many_arguments)]
pub async fn add_provider(
executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
issuer: String,
scope: Scope,
token_endpoint_auth_method: OAuthClientAuthenticationMethod,
token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
client_id: String,
encrypted_client_secret: Option<String>,
) -> Result<UpstreamOAuthProvider, sqlx::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("upstream_oauth_provider.id", tracing::field::display(id));
Ok(res)
}
sqlx::query!(
r#"
#[tracing::instrument(
skip_all,
fields(
upstream_oauth_provider.id,
upstream_oauth_provider.issuer = %issuer,
upstream_oauth_provider.client_id = %client_id,
),
err,
)]
#[allow(clippy::too_many_arguments)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
issuer: String,
scope: Scope,
token_endpoint_auth_method: OAuthClientAuthenticationMethod,
token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
client_id: String,
encrypted_client_secret: Option<String>,
) -> Result<UpstreamOAuthProvider, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("upstream_oauth_provider.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO upstream_oauth_providers (
upstream_oauth_provider_id,
issuer,
@ -155,94 +201,95 @@ pub async fn add_provider(
created_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
"#,
Uuid::from(id),
&issuer,
scope.to_string(),
token_endpoint_auth_method.to_string(),
token_endpoint_signing_alg.as_ref().map(ToString::to_string),
&client_id,
encrypted_client_secret.as_deref(),
created_at,
)
.execute(executor)
.await?;
Ok(UpstreamOAuthProvider {
id,
issuer,
scope,
client_id,
encrypted_client_secret,
token_endpoint_signing_alg,
token_endpoint_auth_method,
created_at,
})
}
#[tracing::instrument(skip_all, err)]
pub async fn get_paginated_providers(
executor: impl PgExecutor<'_>,
before: Option<Ulid>,
after: Option<Ulid>,
first: Option<usize>,
last: Option<usize>,
) -> Result<(bool, bool, Vec<UpstreamOAuthProvider>), DatabaseError> {
let mut query = QueryBuilder::new(
r#"
SELECT
upstream_oauth_provider_id,
issuer,
scope,
client_id,
encrypted_client_secret,
token_endpoint_signing_alg,
token_endpoint_auth_method,
created_at
FROM upstream_oauth_providers
WHERE 1 = 1
"#,
);
query.generate_pagination("upstream_oauth_provider_id", before, after, first, last)?;
let span = info_span!(
"Fetch paginated upstream OAuth 2.0 providers",
db.statement = query.sql()
);
let page: Vec<ProviderLookup> = query
.build_query_as()
.fetch_all(executor)
.instrument(span)
Uuid::from(id),
&issuer,
scope.to_string(),
token_endpoint_auth_method.to_string(),
token_endpoint_signing_alg.as_ref().map(ToString::to_string),
&client_id,
encrypted_client_secret.as_deref(),
created_at,
)
.execute(&mut *self.conn)
.await?;
let (has_previous_page, has_next_page, page) = process_page(page, first, last)?;
Ok(UpstreamOAuthProvider {
id,
issuer,
scope,
client_id,
encrypted_client_secret,
token_endpoint_signing_alg,
token_endpoint_auth_method,
created_at,
})
}
let page: Result<Vec<_>, _> = page.into_iter().map(TryInto::try_into).collect();
Ok((has_previous_page, has_next_page, page?))
}
#[tracing::instrument(skip_all, err)]
pub async fn get_providers(
executor: impl PgExecutor<'_>,
) -> Result<Vec<UpstreamOAuthProvider>, DatabaseError> {
let res = sqlx::query_as!(
ProviderLookup,
r#"
SELECT
upstream_oauth_provider_id,
issuer,
scope,
client_id,
encrypted_client_secret,
token_endpoint_signing_alg,
token_endpoint_auth_method,
created_at
FROM upstream_oauth_providers
"#,
)
.fetch_all(executor)
.await?;
let res: Result<Vec<_>, _> = res.into_iter().map(TryInto::try_into).collect();
Ok(res?)
async fn list_paginated(
&mut self,
before: Option<Ulid>,
after: Option<Ulid>,
first: Option<usize>,
last: Option<usize>,
) -> Result<Page<UpstreamOAuthProvider>, Self::Error> {
let mut query = QueryBuilder::new(
r#"
SELECT
upstream_oauth_provider_id,
issuer,
scope,
client_id,
encrypted_client_secret,
token_endpoint_signing_alg,
token_endpoint_auth_method,
created_at
FROM upstream_oauth_providers
WHERE 1 = 1
"#,
);
query.generate_pagination("upstream_oauth_provider_id", before, after, first, last)?;
let span = info_span!(
"Fetch paginated upstream OAuth 2.0 providers",
db.statement = query.sql()
);
let page: Vec<ProviderLookup> = query
.build_query_as()
.fetch_all(&mut *self.conn)
.instrument(span)
.await?;
let (has_previous_page, has_next_page, edges) = process_page(page, first, last)?;
let edges: Result<Vec<_>, _> = edges.into_iter().map(TryInto::try_into).collect();
Ok(Page {
has_next_page,
has_previous_page,
edges: edges?,
})
}
#[tracing::instrument(skip_all, err)]
async fn all(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error> {
let res = sqlx::query_as!(
ProviderLookup,
r#"
SELECT
upstream_oauth_provider_id,
issuer,
scope,
client_id,
encrypted_client_secret,
token_endpoint_signing_alg,
token_endpoint_auth_method,
created_at
FROM upstream_oauth_providers
"#,
)
.fetch_all(&mut *self.conn)
.await?;
let res: Result<Vec<_>, _> = res.into_iter().map(TryInto::try_into).collect();
Ok(res?)
}
}