You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-12-03 22:51:11 +03:00
Better upstream OAuth provider pagination and filtering
This commit is contained in:
@@ -12,7 +12,7 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use async_graphql::Interface;
|
use async_graphql::{Interface, Object};
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
|
|
||||||
mod browser_sessions;
|
mod browser_sessions;
|
||||||
@@ -52,3 +52,14 @@ pub enum CreationEvent {
|
|||||||
UpstreamOAuth2Link(Box<UpstreamOAuth2Link>),
|
UpstreamOAuth2Link(Box<UpstreamOAuth2Link>),
|
||||||
OAuth2Session(Box<OAuth2Session>),
|
OAuth2Session(Box<OAuth2Session>),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct PreloadedTotalCount(pub Option<usize>);
|
||||||
|
|
||||||
|
#[Object]
|
||||||
|
impl PreloadedTotalCount {
|
||||||
|
/// Identifies the total count of items in the connection.
|
||||||
|
async fn total_count(&self) -> Result<usize, async_graphql::Error> {
|
||||||
|
self.0
|
||||||
|
.ok_or_else(|| async_graphql::Error::new("total count not preloaded"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -26,19 +26,14 @@ use mas_storage::{
|
|||||||
};
|
};
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
compat_sessions::CompatSsoLogin, BrowserSession, Cursor, NodeCursor, NodeType, OAuth2Session,
|
browser_sessions::BrowserSessionState,
|
||||||
UpstreamOAuth2Link,
|
compat_sessions::{CompatSessionState, CompatSessionType, CompatSsoLogin},
|
||||||
};
|
matrix::MatrixUser,
|
||||||
use crate::{
|
oauth::OAuth2SessionState,
|
||||||
model::{
|
BrowserSession, CompatSession, Cursor, NodeCursor, NodeType, OAuth2Session,
|
||||||
browser_sessions::BrowserSessionState,
|
PreloadedTotalCount, UpstreamOAuth2Link,
|
||||||
compat_sessions::{CompatSessionState, CompatSessionType},
|
|
||||||
matrix::MatrixUser,
|
|
||||||
oauth::OAuth2SessionState,
|
|
||||||
CompatSession,
|
|
||||||
},
|
|
||||||
state::ContextExt,
|
|
||||||
};
|
};
|
||||||
|
use crate::state::ContextExt;
|
||||||
|
|
||||||
#[derive(Description)]
|
#[derive(Description)]
|
||||||
/// A user is an individual's account.
|
/// A user is an individual's account.
|
||||||
@@ -511,17 +506,6 @@ impl User {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct PreloadedTotalCount(Option<usize>);
|
|
||||||
|
|
||||||
#[Object]
|
|
||||||
impl PreloadedTotalCount {
|
|
||||||
/// Identifies the total count of items in the connection.
|
|
||||||
async fn total_count(&self) -> Result<usize, async_graphql::Error> {
|
|
||||||
self.0
|
|
||||||
.ok_or_else(|| async_graphql::Error::new("total count not preloaded"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A user email address
|
/// A user email address
|
||||||
#[derive(Description)]
|
#[derive(Description)]
|
||||||
pub struct UserEmail(pub mas_data_model::UserEmail);
|
pub struct UserEmail(pub mas_data_model::UserEmail);
|
||||||
|
|||||||
@@ -16,10 +16,13 @@ use async_graphql::{
|
|||||||
connection::{query, Connection, Edge, OpaqueCursor},
|
connection::{query, Connection, Edge, OpaqueCursor},
|
||||||
Context, Object, ID,
|
Context, Object, ID,
|
||||||
};
|
};
|
||||||
use mas_storage::Pagination;
|
use mas_storage::{upstream_oauth2::UpstreamOAuthProviderFilter, Pagination, RepositoryAccess};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
model::{Cursor, NodeCursor, NodeType, UpstreamOAuth2Link, UpstreamOAuth2Provider},
|
model::{
|
||||||
|
Cursor, NodeCursor, NodeType, PreloadedTotalCount, UpstreamOAuth2Link,
|
||||||
|
UpstreamOAuth2Provider,
|
||||||
|
},
|
||||||
state::ContextExt,
|
state::ContextExt,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -78,7 +81,8 @@ impl UpstreamOAuthQuery {
|
|||||||
before: Option<String>,
|
before: Option<String>,
|
||||||
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
|
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
|
||||||
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
|
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
|
||||||
) -> Result<Connection<Cursor, UpstreamOAuth2Provider>, async_graphql::Error> {
|
) -> Result<Connection<Cursor, UpstreamOAuth2Provider, PreloadedTotalCount>, async_graphql::Error>
|
||||||
|
{
|
||||||
let state = ctx.state();
|
let state = ctx.state();
|
||||||
let mut repo = state.repository().await?;
|
let mut repo = state.repository().await?;
|
||||||
|
|
||||||
@@ -100,14 +104,27 @@ impl UpstreamOAuthQuery {
|
|||||||
.transpose()?;
|
.transpose()?;
|
||||||
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
|
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
|
||||||
|
|
||||||
|
let filter = UpstreamOAuthProviderFilter::new();
|
||||||
|
|
||||||
let page = repo
|
let page = repo
|
||||||
.upstream_oauth_provider()
|
.upstream_oauth_provider()
|
||||||
.list_paginated(pagination)
|
.list(filter, pagination)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
// Preload the total count if requested
|
||||||
|
let count = if ctx.look_ahead().field("totalCount").exists() {
|
||||||
|
Some(repo.upstream_oauth_provider().count(filter).await?)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
repo.cancel().await?;
|
repo.cancel().await?;
|
||||||
|
|
||||||
let mut connection = Connection::new(page.has_previous_page, page.has_next_page);
|
let mut connection = Connection::with_additional_fields(
|
||||||
|
page.has_previous_page,
|
||||||
|
page.has_next_page,
|
||||||
|
PreloadedTotalCount(count),
|
||||||
|
);
|
||||||
connection.edges.extend(page.edges.into_iter().map(|p| {
|
connection.edges.extend(page.edges.into_iter().map(|p| {
|
||||||
Edge::new(
|
Edge::new(
|
||||||
OpaqueCursor(NodeCursor(NodeType::UpstreamOAuth2Provider, p.id)),
|
OpaqueCursor(NodeCursor(NodeType::UpstreamOAuth2Provider, p.id)),
|
||||||
|
|||||||
@@ -1,22 +0,0 @@
|
|||||||
{
|
|
||||||
"db_name": "PostgreSQL",
|
|
||||||
"query": "\n SELECT COUNT(*)\n FROM user_emails\n WHERE user_id = $1\n ",
|
|
||||||
"describe": {
|
|
||||||
"columns": [
|
|
||||||
{
|
|
||||||
"ordinal": 0,
|
|
||||||
"name": "count",
|
|
||||||
"type_info": "Int8"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"parameters": {
|
|
||||||
"Left": [
|
|
||||||
"Uuid"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"nullable": [
|
|
||||||
null
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"hash": "d12a513b81b3ef658eae1f0a719933323f28c6ee260b52cafe337dd3d19e865c"
|
|
||||||
}
|
|
||||||
@@ -77,3 +77,19 @@ pub enum OAuth2Sessions {
|
|||||||
CreatedAt,
|
CreatedAt,
|
||||||
FinishedAt,
|
FinishedAt,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(sea_query::Iden)]
|
||||||
|
#[iden = "upstream_oauth_providers"]
|
||||||
|
pub enum UpstreamOAuthProviders {
|
||||||
|
Table,
|
||||||
|
#[iden = "upstream_oauth_provider_id"]
|
||||||
|
UpstreamOAuthProviderId,
|
||||||
|
Issuer,
|
||||||
|
Scope,
|
||||||
|
ClientId,
|
||||||
|
EncryptedClientSecret,
|
||||||
|
TokenEndpointSigningAlg,
|
||||||
|
TokenEndpointAuthMethod,
|
||||||
|
CreatedAt,
|
||||||
|
ClaimsImports,
|
||||||
|
}
|
||||||
|
|||||||
@@ -31,8 +31,8 @@ mod tests {
|
|||||||
use mas_storage::{
|
use mas_storage::{
|
||||||
clock::MockClock,
|
clock::MockClock,
|
||||||
upstream_oauth2::{
|
upstream_oauth2::{
|
||||||
UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
|
UpstreamOAuthLinkRepository, UpstreamOAuthProviderFilter,
|
||||||
UpstreamOAuthSessionRepository,
|
UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository,
|
||||||
},
|
},
|
||||||
user::UserRepository,
|
user::UserRepository,
|
||||||
Pagination, RepositoryAccess,
|
Pagination, RepositoryAccess,
|
||||||
@@ -208,6 +208,14 @@ mod tests {
|
|||||||
let clock = MockClock::default();
|
let clock = MockClock::default();
|
||||||
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
|
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
|
||||||
|
|
||||||
|
let filter = UpstreamOAuthProviderFilter::new();
|
||||||
|
|
||||||
|
// Count the number of providers before we start
|
||||||
|
assert_eq!(
|
||||||
|
repo.upstream_oauth_provider().count(filter).await.unwrap(),
|
||||||
|
0
|
||||||
|
);
|
||||||
|
|
||||||
let mut ids = Vec::with_capacity(20);
|
let mut ids = Vec::with_capacity(20);
|
||||||
// Create 20 providers
|
// Create 20 providers
|
||||||
for idx in 0..20 {
|
for idx in 0..20 {
|
||||||
@@ -231,10 +239,16 @@ mod tests {
|
|||||||
clock.advance(Duration::seconds(10));
|
clock.advance(Duration::seconds(10));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Now we have 20 providers
|
||||||
|
assert_eq!(
|
||||||
|
repo.upstream_oauth_provider().count(filter).await.unwrap(),
|
||||||
|
20
|
||||||
|
);
|
||||||
|
|
||||||
// Lookup the first 10 items
|
// Lookup the first 10 items
|
||||||
let page = repo
|
let page = repo
|
||||||
.upstream_oauth_provider()
|
.upstream_oauth_provider()
|
||||||
.list_paginated(Pagination::first(10))
|
.list(filter, Pagination::first(10))
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@@ -246,7 +260,7 @@ mod tests {
|
|||||||
// Lookup the next 10 items
|
// Lookup the next 10 items
|
||||||
let page = repo
|
let page = repo
|
||||||
.upstream_oauth_provider()
|
.upstream_oauth_provider()
|
||||||
.list_paginated(Pagination::first(10).after(ids[9]))
|
.list(filter, Pagination::first(10).after(ids[9]))
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@@ -258,7 +272,7 @@ mod tests {
|
|||||||
// Lookup the last 10 items
|
// Lookup the last 10 items
|
||||||
let page = repo
|
let page = repo
|
||||||
.upstream_oauth_provider()
|
.upstream_oauth_provider()
|
||||||
.list_paginated(Pagination::last(10))
|
.list(filter, Pagination::last(10))
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@@ -270,7 +284,7 @@ mod tests {
|
|||||||
// Lookup the previous 10 items
|
// Lookup the previous 10 items
|
||||||
let page = repo
|
let page = repo
|
||||||
.upstream_oauth_provider()
|
.upstream_oauth_provider()
|
||||||
.list_paginated(Pagination::last(10).before(ids[10]))
|
.list(filter, Pagination::last(10).before(ids[10]))
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@@ -282,7 +296,7 @@ mod tests {
|
|||||||
// Lookup 10 items between two IDs
|
// Lookup 10 items between two IDs
|
||||||
let page = repo
|
let page = repo
|
||||||
.upstream_oauth_provider()
|
.upstream_oauth_provider()
|
||||||
.list_paginated(Pagination::first(10).after(ids[5]).before(ids[8]))
|
.list(filter, Pagination::first(10).after(ids[5]).before(ids[8]))
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|||||||
@@ -16,16 +16,21 @@ use async_trait::async_trait;
|
|||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports};
|
use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports};
|
||||||
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
|
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
|
||||||
use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, Clock, Page, Pagination};
|
use mas_storage::{
|
||||||
|
upstream_oauth2::{UpstreamOAuthProviderFilter, UpstreamOAuthProviderRepository},
|
||||||
|
Clock, Page, Pagination,
|
||||||
|
};
|
||||||
use oauth2_types::scope::Scope;
|
use oauth2_types::scope::Scope;
|
||||||
use rand::RngCore;
|
use rand::RngCore;
|
||||||
use sqlx::{types::Json, PgConnection, QueryBuilder};
|
use sea_query::{enum_def, Expr, IntoColumnRef, PostgresQueryBuilder, Query};
|
||||||
|
use sqlx::{types::Json, PgConnection};
|
||||||
use tracing::{info_span, Instrument};
|
use tracing::{info_span, Instrument};
|
||||||
use ulid::Ulid;
|
use ulid::Ulid;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError,
|
iden::UpstreamOAuthProviders, pagination::QueryBuilderExt, sea_query_sqlx::map_values,
|
||||||
|
tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError,
|
||||||
};
|
};
|
||||||
|
|
||||||
/// An implementation of [`UpstreamOAuthProviderRepository`] for a PostgreSQL
|
/// An implementation of [`UpstreamOAuthProviderRepository`] for a PostgreSQL
|
||||||
@@ -43,6 +48,7 @@ impl<'c> PgUpstreamOAuthProviderRepository<'c> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(sqlx::FromRow)]
|
#[derive(sqlx::FromRow)]
|
||||||
|
#[enum_def]
|
||||||
struct ProviderLookup {
|
struct ProviderLookup {
|
||||||
upstream_oauth_provider_id: Uuid,
|
upstream_oauth_provider_id: Uuid,
|
||||||
issuer: String,
|
issuer: String,
|
||||||
@@ -209,6 +215,72 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(
|
||||||
|
name = "db.upstream_oauth_provider.delete_by_id",
|
||||||
|
skip_all,
|
||||||
|
fields(
|
||||||
|
db.statement,
|
||||||
|
upstream_oauth_provider.id = %id,
|
||||||
|
),
|
||||||
|
err,
|
||||||
|
)]
|
||||||
|
async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error> {
|
||||||
|
// Delete the authorization sessions first, as they have a foreign key
|
||||||
|
// constraint on the links and the providers.
|
||||||
|
{
|
||||||
|
let span = info_span!(
|
||||||
|
"db.oauth2_client.delete_by_id.authorization_sessions",
|
||||||
|
upstream_oauth_provider.id = %id,
|
||||||
|
db.statement = tracing::field::Empty,
|
||||||
|
);
|
||||||
|
sqlx::query!(
|
||||||
|
r#"
|
||||||
|
DELETE FROM upstream_oauth_authorization_sessions
|
||||||
|
WHERE upstream_oauth_provider_id = $1
|
||||||
|
"#,
|
||||||
|
Uuid::from(id),
|
||||||
|
)
|
||||||
|
.record(&span)
|
||||||
|
.execute(&mut *self.conn)
|
||||||
|
.instrument(span)
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete the links next, as they have a foreign key constraint on the
|
||||||
|
// providers.
|
||||||
|
{
|
||||||
|
let span = info_span!(
|
||||||
|
"db.oauth2_client.delete_by_id.links",
|
||||||
|
upstream_oauth_provider.id = %id,
|
||||||
|
db.statement = tracing::field::Empty,
|
||||||
|
);
|
||||||
|
sqlx::query!(
|
||||||
|
r#"
|
||||||
|
DELETE FROM upstream_oauth_links
|
||||||
|
WHERE upstream_oauth_provider_id = $1
|
||||||
|
"#,
|
||||||
|
Uuid::from(id),
|
||||||
|
)
|
||||||
|
.record(&span)
|
||||||
|
.execute(&mut *self.conn)
|
||||||
|
.instrument(span)
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let res = sqlx::query!(
|
||||||
|
r#"
|
||||||
|
DELETE FROM upstream_oauth_providers
|
||||||
|
WHERE upstream_oauth_provider_id = $1
|
||||||
|
"#,
|
||||||
|
Uuid::from(id),
|
||||||
|
)
|
||||||
|
.traced()
|
||||||
|
.execute(&mut *self.conn)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
DatabaseError::ensure_affected_rows(&res, 1)
|
||||||
|
}
|
||||||
|
|
||||||
#[tracing::instrument(
|
#[tracing::instrument(
|
||||||
name = "db.upstream_oauth_provider.add",
|
name = "db.upstream_oauth_provider.add",
|
||||||
skip_all,
|
skip_all,
|
||||||
@@ -288,110 +360,139 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(
|
#[tracing::instrument(
|
||||||
name = "db.upstream_oauth_provider.delete_by_id",
|
name = "db.upstream_oauth_provider.list",
|
||||||
skip_all,
|
|
||||||
fields(
|
|
||||||
db.statement,
|
|
||||||
upstream_oauth_provider.id = %id,
|
|
||||||
),
|
|
||||||
err,
|
|
||||||
)]
|
|
||||||
async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error> {
|
|
||||||
// Delete the authorization sessions first, as they have a foreign key
|
|
||||||
// constraint on the links and the providers.
|
|
||||||
{
|
|
||||||
let span = info_span!(
|
|
||||||
"db.oauth2_client.delete_by_id.authorization_sessions",
|
|
||||||
upstream_oauth_provider.id = %id,
|
|
||||||
db.statement = tracing::field::Empty,
|
|
||||||
);
|
|
||||||
sqlx::query!(
|
|
||||||
r#"
|
|
||||||
DELETE FROM upstream_oauth_authorization_sessions
|
|
||||||
WHERE upstream_oauth_provider_id = $1
|
|
||||||
"#,
|
|
||||||
Uuid::from(id),
|
|
||||||
)
|
|
||||||
.record(&span)
|
|
||||||
.execute(&mut *self.conn)
|
|
||||||
.instrument(span)
|
|
||||||
.await?;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete the links next, as they have a foreign key constraint on the
|
|
||||||
// providers.
|
|
||||||
{
|
|
||||||
let span = info_span!(
|
|
||||||
"db.oauth2_client.delete_by_id.links",
|
|
||||||
upstream_oauth_provider.id = %id,
|
|
||||||
db.statement = tracing::field::Empty,
|
|
||||||
);
|
|
||||||
sqlx::query!(
|
|
||||||
r#"
|
|
||||||
DELETE FROM upstream_oauth_links
|
|
||||||
WHERE upstream_oauth_provider_id = $1
|
|
||||||
"#,
|
|
||||||
Uuid::from(id),
|
|
||||||
)
|
|
||||||
.record(&span)
|
|
||||||
.execute(&mut *self.conn)
|
|
||||||
.instrument(span)
|
|
||||||
.await?;
|
|
||||||
}
|
|
||||||
|
|
||||||
let res = sqlx::query!(
|
|
||||||
r#"
|
|
||||||
DELETE FROM upstream_oauth_providers
|
|
||||||
WHERE upstream_oauth_provider_id = $1
|
|
||||||
"#,
|
|
||||||
Uuid::from(id),
|
|
||||||
)
|
|
||||||
.traced()
|
|
||||||
.execute(&mut *self.conn)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
DatabaseError::ensure_affected_rows(&res, 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tracing::instrument(
|
|
||||||
name = "db.upstream_oauth_provider.list_paginated",
|
|
||||||
skip_all,
|
skip_all,
|
||||||
fields(
|
fields(
|
||||||
db.statement,
|
db.statement,
|
||||||
),
|
),
|
||||||
err,
|
err,
|
||||||
)]
|
)]
|
||||||
async fn list_paginated(
|
async fn list(
|
||||||
&mut self,
|
&mut self,
|
||||||
|
_filter: UpstreamOAuthProviderFilter<'_>,
|
||||||
pagination: Pagination,
|
pagination: Pagination,
|
||||||
) -> Result<Page<UpstreamOAuthProvider>, Self::Error> {
|
) -> Result<Page<UpstreamOAuthProvider>, Self::Error> {
|
||||||
let mut query = QueryBuilder::new(
|
// XXX: the filter is currently ignored, as it does not have any fields
|
||||||
r#"
|
let (sql, values) = Query::select()
|
||||||
SELECT
|
.expr_as(
|
||||||
upstream_oauth_provider_id,
|
Expr::col((
|
||||||
issuer,
|
UpstreamOAuthProviders::Table,
|
||||||
scope,
|
UpstreamOAuthProviders::UpstreamOAuthProviderId,
|
||||||
client_id,
|
)),
|
||||||
encrypted_client_secret,
|
ProviderLookupIden::UpstreamOauthProviderId,
|
||||||
token_endpoint_signing_alg,
|
)
|
||||||
token_endpoint_auth_method,
|
.expr_as(
|
||||||
created_at,
|
Expr::col((
|
||||||
claims_imports
|
UpstreamOAuthProviders::Table,
|
||||||
FROM upstream_oauth_providers
|
UpstreamOAuthProviders::Issuer,
|
||||||
WHERE 1 = 1
|
)),
|
||||||
"#,
|
ProviderLookupIden::Issuer,
|
||||||
);
|
)
|
||||||
|
.expr_as(
|
||||||
|
Expr::col((UpstreamOAuthProviders::Table, UpstreamOAuthProviders::Scope)),
|
||||||
|
ProviderLookupIden::Scope,
|
||||||
|
)
|
||||||
|
.expr_as(
|
||||||
|
Expr::col((
|
||||||
|
UpstreamOAuthProviders::Table,
|
||||||
|
UpstreamOAuthProviders::ClientId,
|
||||||
|
)),
|
||||||
|
ProviderLookupIden::ClientId,
|
||||||
|
)
|
||||||
|
.expr_as(
|
||||||
|
Expr::col((
|
||||||
|
UpstreamOAuthProviders::Table,
|
||||||
|
UpstreamOAuthProviders::EncryptedClientSecret,
|
||||||
|
)),
|
||||||
|
ProviderLookupIden::EncryptedClientSecret,
|
||||||
|
)
|
||||||
|
.expr_as(
|
||||||
|
Expr::col((
|
||||||
|
UpstreamOAuthProviders::Table,
|
||||||
|
UpstreamOAuthProviders::TokenEndpointSigningAlg,
|
||||||
|
)),
|
||||||
|
ProviderLookupIden::TokenEndpointSigningAlg,
|
||||||
|
)
|
||||||
|
.expr_as(
|
||||||
|
Expr::col((
|
||||||
|
UpstreamOAuthProviders::Table,
|
||||||
|
UpstreamOAuthProviders::TokenEndpointAuthMethod,
|
||||||
|
)),
|
||||||
|
ProviderLookupIden::TokenEndpointAuthMethod,
|
||||||
|
)
|
||||||
|
.expr_as(
|
||||||
|
Expr::col((
|
||||||
|
UpstreamOAuthProviders::Table,
|
||||||
|
UpstreamOAuthProviders::CreatedAt,
|
||||||
|
)),
|
||||||
|
ProviderLookupIden::CreatedAt,
|
||||||
|
)
|
||||||
|
.expr_as(
|
||||||
|
Expr::col((
|
||||||
|
UpstreamOAuthProviders::Table,
|
||||||
|
UpstreamOAuthProviders::ClaimsImports,
|
||||||
|
)),
|
||||||
|
ProviderLookupIden::ClaimsImports,
|
||||||
|
)
|
||||||
|
.from(UpstreamOAuthProviders::Table)
|
||||||
|
.generate_pagination(
|
||||||
|
(
|
||||||
|
UpstreamOAuthProviders::Table,
|
||||||
|
UpstreamOAuthProviders::UpstreamOAuthProviderId,
|
||||||
|
)
|
||||||
|
.into_column_ref(),
|
||||||
|
pagination,
|
||||||
|
)
|
||||||
|
.build(PostgresQueryBuilder);
|
||||||
|
|
||||||
query.generate_pagination("upstream_oauth_provider_id", pagination);
|
let arguments = map_values(values);
|
||||||
|
|
||||||
let edges: Vec<ProviderLookup> = query
|
let edges: Vec<ProviderLookup> = sqlx::query_as_with(&sql, arguments)
|
||||||
.build_query_as()
|
|
||||||
.traced()
|
.traced()
|
||||||
.fetch_all(&mut *self.conn)
|
.fetch_all(&mut *self.conn)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let page = pagination.process(edges).try_map(TryInto::try_into)?;
|
let page = pagination
|
||||||
Ok(page)
|
.process(edges)
|
||||||
|
.try_map(UpstreamOAuthProvider::try_from)?;
|
||||||
|
|
||||||
|
return Ok(page);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(
|
||||||
|
name = "db.upstream_oauth_provider.count",
|
||||||
|
skip_all,
|
||||||
|
fields(
|
||||||
|
db.statement,
|
||||||
|
),
|
||||||
|
err,
|
||||||
|
)]
|
||||||
|
async fn count(
|
||||||
|
&mut self,
|
||||||
|
_filter: UpstreamOAuthProviderFilter<'_>,
|
||||||
|
) -> Result<usize, Self::Error> {
|
||||||
|
// XXX: the filter is currently ignored, as it does not have any fields
|
||||||
|
let (sql, values) = Query::select()
|
||||||
|
.expr(
|
||||||
|
Expr::col((
|
||||||
|
UpstreamOAuthProviders::Table,
|
||||||
|
UpstreamOAuthProviders::UpstreamOAuthProviderId,
|
||||||
|
))
|
||||||
|
.count(),
|
||||||
|
)
|
||||||
|
.from(UpstreamOAuthProviders::Table)
|
||||||
|
.build(PostgresQueryBuilder);
|
||||||
|
|
||||||
|
let arguments = map_values(values);
|
||||||
|
|
||||||
|
let count: i64 = sqlx::query_scalar_with(&sql, arguments)
|
||||||
|
.traced()
|
||||||
|
.fetch_one(&mut *self.conn)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
count
|
||||||
|
.try_into()
|
||||||
|
.map_err(DatabaseError::to_invalid_operation)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(
|
#[tracing::instrument(
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ mod provider;
|
|||||||
mod session;
|
mod session;
|
||||||
|
|
||||||
pub use self::{
|
pub use self::{
|
||||||
link::UpstreamOAuthLinkRepository, provider::UpstreamOAuthProviderRepository,
|
link::UpstreamOAuthLinkRepository,
|
||||||
|
provider::{UpstreamOAuthProviderFilter, UpstreamOAuthProviderRepository},
|
||||||
session::UpstreamOAuthSessionRepository,
|
session::UpstreamOAuthSessionRepository,
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -12,6 +12,8 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports};
|
use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports};
|
||||||
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
|
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
|
||||||
@@ -21,6 +23,20 @@ use ulid::Ulid;
|
|||||||
|
|
||||||
use crate::{pagination::Page, repository_impl, Clock, Pagination};
|
use crate::{pagination::Page, repository_impl, Clock, Pagination};
|
||||||
|
|
||||||
|
/// Filter parameters for listing upstream OAuth 2.0 providers
|
||||||
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
|
||||||
|
pub struct UpstreamOAuthProviderFilter<'a> {
|
||||||
|
_lifetime: PhantomData<&'a ()>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> UpstreamOAuthProviderFilter<'a> {
|
||||||
|
/// Create a new [`UpstreamOAuthProviderFilter`] with default values
|
||||||
|
#[must_use]
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self::default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// An [`UpstreamOAuthProviderRepository`] helps interacting with
|
/// An [`UpstreamOAuthProviderRepository`] helps interacting with
|
||||||
/// [`UpstreamOAuthProvider`] saved in the storage backend
|
/// [`UpstreamOAuthProvider`] saved in the storage backend
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
@@ -137,20 +153,36 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync {
|
|||||||
claims_imports: UpstreamOAuthProviderClaimsImports,
|
claims_imports: UpstreamOAuthProviderClaimsImports,
|
||||||
) -> Result<UpstreamOAuthProvider, Self::Error>;
|
) -> Result<UpstreamOAuthProvider, Self::Error>;
|
||||||
|
|
||||||
/// Get a paginated list of upstream OAuth providers
|
/// List [`UpstreamOAuthProvider`] with the given filter and pagination
|
||||||
///
|
///
|
||||||
/// # Parameters
|
/// # Parameters
|
||||||
///
|
///
|
||||||
|
/// * `filter`: The filter to apply
|
||||||
/// * `pagination`: The pagination parameters
|
/// * `pagination`: The pagination parameters
|
||||||
///
|
///
|
||||||
/// # Errors
|
/// # Errors
|
||||||
///
|
///
|
||||||
/// Returns [`Self::Error`] if the underlying repository fails
|
/// Returns [`Self::Error`] if the underlying repository fails
|
||||||
async fn list_paginated(
|
async fn list(
|
||||||
&mut self,
|
&mut self,
|
||||||
|
filter: UpstreamOAuthProviderFilter<'_>,
|
||||||
pagination: Pagination,
|
pagination: Pagination,
|
||||||
) -> Result<Page<UpstreamOAuthProvider>, Self::Error>;
|
) -> Result<Page<UpstreamOAuthProvider>, Self::Error>;
|
||||||
|
|
||||||
|
/// Count the number of [`UpstreamOAuthProvider`] with the given filter
|
||||||
|
///
|
||||||
|
/// # Parameters
|
||||||
|
///
|
||||||
|
/// * `filter`: The filter to apply
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns [`Self::Error`] if the underlying repository fails
|
||||||
|
async fn count(
|
||||||
|
&mut self,
|
||||||
|
filter: UpstreamOAuthProviderFilter<'_>,
|
||||||
|
) -> Result<usize, Self::Error>;
|
||||||
|
|
||||||
/// Get all upstream OAuth providers
|
/// Get all upstream OAuth providers
|
||||||
///
|
///
|
||||||
/// # Errors
|
/// # Errors
|
||||||
@@ -192,10 +224,16 @@ repository_impl!(UpstreamOAuthProviderRepository:
|
|||||||
|
|
||||||
async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;
|
async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;
|
||||||
|
|
||||||
async fn list_paginated(
|
async fn list(
|
||||||
&mut self,
|
&mut self,
|
||||||
|
filter: UpstreamOAuthProviderFilter<'_>,
|
||||||
pagination: Pagination
|
pagination: Pagination
|
||||||
) -> Result<Page<UpstreamOAuthProvider>, Self::Error>;
|
) -> Result<Page<UpstreamOAuthProvider>, Self::Error>;
|
||||||
|
|
||||||
|
async fn count(
|
||||||
|
&mut self,
|
||||||
|
filter: UpstreamOAuthProviderFilter<'_>
|
||||||
|
) -> Result<usize, Self::Error>;
|
||||||
|
|
||||||
async fn all(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
|
async fn all(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -880,6 +880,10 @@ type UpstreamOAuth2ProviderConnection {
|
|||||||
A list of nodes.
|
A list of nodes.
|
||||||
"""
|
"""
|
||||||
nodes: [UpstreamOAuth2Provider!]!
|
nodes: [UpstreamOAuth2Provider!]!
|
||||||
|
"""
|
||||||
|
Identifies the total count of items in the connection.
|
||||||
|
"""
|
||||||
|
totalCount: Int!
|
||||||
}
|
}
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -668,6 +668,8 @@ export type UpstreamOAuth2ProviderConnection = {
|
|||||||
nodes: Array<UpstreamOAuth2Provider>;
|
nodes: Array<UpstreamOAuth2Provider>;
|
||||||
/** Information to aid in pagination. */
|
/** Information to aid in pagination. */
|
||||||
pageInfo: PageInfo;
|
pageInfo: PageInfo;
|
||||||
|
/** Identifies the total count of items in the connection. */
|
||||||
|
totalCount: Scalars["Int"]["output"];
|
||||||
};
|
};
|
||||||
|
|
||||||
/** An edge in a connection. */
|
/** An edge in a connection. */
|
||||||
|
|||||||
@@ -1938,6 +1938,17 @@ export default {
|
|||||||
},
|
},
|
||||||
args: [],
|
args: [],
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "totalCount",
|
||||||
|
type: {
|
||||||
|
kind: "NON_NULL",
|
||||||
|
ofType: {
|
||||||
|
kind: "SCALAR",
|
||||||
|
name: "Any",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
args: [],
|
||||||
|
},
|
||||||
],
|
],
|
||||||
interfaces: [],
|
interfaces: [],
|
||||||
},
|
},
|
||||||
|
|||||||
Reference in New Issue
Block a user