1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

GraphQL: query upstream links from users

This commit is contained in:
Quentin Gliech
2022-12-05 19:09:45 +01:00
parent 23fd833d45
commit 1655080b8f
5 changed files with 167 additions and 17 deletions

View File

@ -84,6 +84,11 @@ impl UpstreamOAuth2Link {
self.link.created_at self.link.created_at
} }
/// Subject used for linking
pub async fn subject(&self) -> &str {
&self.link.subject
}
/// The provider for which this link is. /// The provider for which this link is.
pub async fn provider( pub async fn provider(
&self, &self,

View File

@ -22,6 +22,7 @@ use sqlx::PgPool;
use super::{ use super::{
compat_sessions::CompatSsoLogin, BrowserSession, Cursor, NodeCursor, NodeType, OAuth2Session, compat_sessions::CompatSsoLogin, BrowserSession, Cursor, NodeCursor, NodeType, OAuth2Session,
UpstreamOAuth2Link,
}; };
#[derive(Description)] #[derive(Description)]
@ -252,6 +253,58 @@ impl User {
) )
.await .await
} }
/// Get the list of upstream OAuth 2.0 links
async fn upstream_oauth2_links(
&self,
ctx: &Context<'_>,
#[graphql(desc = "Returns the elements in the list that come after the cursor.")]
after: Option<String>,
#[graphql(desc = "Returns the elements in the list that come before the cursor.")]
before: Option<String>,
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, UpstreamOAuth2Link>, async_graphql::Error> {
let database = ctx.data::<PgPool>()?;
query(
after,
before,
first,
last,
|after, before, first, last| async move {
let mut conn = database.acquire().await?;
let after_id = after
.map(|x: OpaqueCursor<NodeCursor>| {
x.extract_for_type(NodeType::UpstreamOAuth2Link)
})
.transpose()?;
let before_id = before
.map(|x: OpaqueCursor<NodeCursor>| {
x.extract_for_type(NodeType::UpstreamOAuth2Link)
})
.transpose()?;
let (has_previous_page, has_next_page, edges) =
mas_storage::upstream_oauth2::get_paginated_user_links(
&mut conn, &self.0, before_id, after_id, first, last,
)
.await?;
let mut connection = Connection::new(has_previous_page, has_next_page);
connection.edges.extend(edges.into_iter().map(|s| {
Edge::new(
OpaqueCursor(NodeCursor(NodeType::UpstreamOAuth2Link, s.id)),
UpstreamOAuth2Link::new(s),
)
}));
Ok::<_, async_graphql::Error>(connection)
},
)
.await
}
} }
/// A user email address /// A user email address

View File

@ -15,12 +15,17 @@
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use mas_data_model::{UpstreamOAuthLink, UpstreamOAuthProvider, User}; use mas_data_model::{UpstreamOAuthLink, UpstreamOAuthProvider, User};
use rand::Rng; use rand::Rng;
use sqlx::PgExecutor; use sqlx::{PgExecutor, QueryBuilder};
use tracing::{info_span, Instrument};
use ulid::Ulid; use ulid::Ulid;
use uuid::Uuid; use uuid::Uuid;
use crate::{Clock, GenericLookupError, PostgresqlBackend}; use crate::{
pagination::{process_page, QueryBuilderExt},
Clock, GenericLookupError, PostgresqlBackend,
};
#[derive(sqlx::FromRow)]
struct LinkLookup { struct LinkLookup {
upstream_oauth_link_id: Uuid, upstream_oauth_link_id: Uuid,
upstream_oauth_provider_id: Uuid, upstream_oauth_provider_id: Uuid,
@ -29,6 +34,18 @@ struct LinkLookup {
created_at: DateTime<Utc>, created_at: DateTime<Utc>,
} }
impl From<LinkLookup> for UpstreamOAuthLink {
fn from(value: LinkLookup) -> Self {
UpstreamOAuthLink {
id: Ulid::from(value.upstream_oauth_link_id),
provider_id: Ulid::from(value.upstream_oauth_provider_id),
user_id: value.user_id.map(Ulid::from),
subject: value.subject,
created_at: value.created_at,
}
}
}
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
fields(upstream_oauth_link.id = %id), fields(upstream_oauth_link.id = %id),
@ -56,13 +73,7 @@ pub async fn lookup_link(
.await .await
.map_err(GenericLookupError::what("Upstream OAuth 2.0 link"))?; .map_err(GenericLookupError::what("Upstream OAuth 2.0 link"))?;
Ok(UpstreamOAuthLink { Ok(res.into())
id: Ulid::from(res.upstream_oauth_link_id),
provider_id: Ulid::from(res.upstream_oauth_provider_id),
user_id: res.user_id.map(Ulid::from),
subject: res.subject,
created_at: res.created_at,
})
} }
#[tracing::instrument( #[tracing::instrument(
@ -100,13 +111,7 @@ pub async fn lookup_link_by_subject(
.await .await
.map_err(GenericLookupError::what("Upstream OAuth 2.0 link"))?; .map_err(GenericLookupError::what("Upstream OAuth 2.0 link"))?;
Ok(UpstreamOAuthLink { Ok(res.into())
id: Ulid::from(res.upstream_oauth_link_id),
provider_id: Ulid::from(res.upstream_oauth_provider_id),
user_id: res.user_id.map(Ulid::from),
subject: res.subject,
created_at: res.created_at,
})
} }
#[tracing::instrument( #[tracing::instrument(
@ -187,3 +192,45 @@ pub async fn associate_link_to_user(
Ok(()) Ok(())
} }
#[tracing::instrument(skip_all, err(Display))]
pub async fn get_paginated_user_links(
executor: impl PgExecutor<'_>,
user: &User<PostgresqlBackend>,
before: Option<Ulid>,
after: Option<Ulid>,
first: Option<usize>,
last: Option<usize>,
) -> Result<(bool, bool, Vec<UpstreamOAuthLink>), anyhow::Error> {
let mut query = QueryBuilder::new(
r#"
SELECT
upstream_oauth_link_id,
upstream_oauth_provider_id,
user_id,
subject,
created_at
FROM upstream_oauth_links
"#,
);
query
.push(" WHERE user_id = ")
.push_bind(Uuid::from(user.data))
.generate_pagination("upstream_oauth_link_id", before, after, first, last)?;
let span = info_span!(
"Fetch paginated upstream OAuth 2.0 user links",
db.statement = query.sql()
);
let page: Vec<LinkLookup> = query
.build_query_as()
.fetch_all(executor)
.instrument(span)
.await?;
let (has_previous_page, has_next_page, page) = process_page(page, first, last)?;
let page: Vec<_> = page.into_iter().map(Into::into).collect();
Ok((has_previous_page, has_next_page, page))
}

View File

@ -17,7 +17,10 @@ mod provider;
mod session; mod session;
pub use self::{ pub use self::{
link::{add_link, associate_link_to_user, lookup_link, lookup_link_by_subject}, link::{
add_link, associate_link_to_user, get_paginated_user_links, lookup_link,
lookup_link_by_subject,
},
provider::{ provider::{
add_provider, get_paginated_providers, get_providers, lookup_provider, ProviderLookupError, add_provider, get_paginated_providers, get_providers, lookup_provider, ProviderLookupError,
}, },

View File

@ -342,6 +342,10 @@ type UpstreamOAuth2Link implements Node {
""" """
createdAt: DateTime! createdAt: DateTime!
""" """
Subject used for linking
"""
subject: String!
"""
The provider for which this link is. The provider for which this link is.
""" """
provider: UpstreamOAuth2Provider! provider: UpstreamOAuth2Provider!
@ -351,6 +355,35 @@ type UpstreamOAuth2Link implements Node {
user: User user: User
} }
type UpstreamOAuth2LinkConnection {
"""
Information to aid in pagination.
"""
pageInfo: PageInfo!
"""
A list of edges.
"""
edges: [UpstreamOAuth2LinkEdge!]!
"""
A list of nodes.
"""
nodes: [UpstreamOAuth2Link!]!
}
"""
An edge in a connection.
"""
type UpstreamOAuth2LinkEdge {
"""
A cursor for use in pagination
"""
cursor: String!
"""
The item at the end of the edge
"""
node: UpstreamOAuth2Link!
}
type UpstreamOAuth2Provider implements Node { type UpstreamOAuth2Provider implements Node {
""" """
ID of the object. ID of the object.
@ -456,6 +489,15 @@ type User implements Node {
first: Int first: Int
last: Int last: Int
): Oauth2SessionConnection! ): Oauth2SessionConnection!
"""
Get the list of upstream OAuth 2.0 links
"""
upstreamOauth2Links(
after: String
before: String
first: Int
last: Int
): UpstreamOAuth2LinkConnection!
} }
""" """