1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

GraphQL API

This commit is contained in:
Quentin Gliech
2022-12-02 16:25:23 +01:00
parent 07636dd9e7
commit 2e7112ef13
14 changed files with 645 additions and 223 deletions

View File

@ -37,7 +37,7 @@ struct LinkLookup {
pub async fn lookup_link(
executor: impl PgExecutor<'_>,
id: Ulid,
) -> Result<(UpstreamOAuthLink, Ulid, Option<Ulid>), GenericLookupError> {
) -> Result<UpstreamOAuthLink, GenericLookupError> {
let res = sqlx::query_as!(
LinkLookup,
r#"
@ -56,15 +56,13 @@ pub async fn lookup_link(
.await
.map_err(GenericLookupError::what("Upstream OAuth 2.0 link"))?;
Ok((
UpstreamOAuthLink {
id: Ulid::from(res.upstream_oauth_link_id),
subject: res.subject,
created_at: res.created_at,
},
Ulid::from(res.upstream_oauth_provider_id),
res.user_id.map(Ulid::from),
))
Ok(UpstreamOAuthLink {
id: Ulid::from(res.upstream_oauth_link_id),
provider_id: Ulid::from(res.upstream_oauth_provider_id),
user_id: res.user_id.map(Ulid::from),
subject: res.subject,
created_at: res.created_at,
})
}
#[tracing::instrument(
@ -81,7 +79,7 @@ pub async fn lookup_link_by_subject(
executor: impl PgExecutor<'_>,
upstream_oauth_provider: &UpstreamOAuthProvider,
subject: &str,
) -> Result<(UpstreamOAuthLink, Option<Ulid>), GenericLookupError> {
) -> Result<UpstreamOAuthLink, GenericLookupError> {
let res = sqlx::query_as!(
LinkLookup,
r#"
@ -102,14 +100,13 @@ pub async fn lookup_link_by_subject(
.await
.map_err(GenericLookupError::what("Upstream OAuth 2.0 link"))?;
Ok((
UpstreamOAuthLink {
id: Ulid::from(res.upstream_oauth_link_id),
subject: res.subject,
created_at: res.created_at,
},
res.user_id.map(Ulid::from),
))
Ok(UpstreamOAuthLink {
id: Ulid::from(res.upstream_oauth_link_id),
provider_id: Ulid::from(res.upstream_oauth_provider_id),
user_id: res.user_id.map(Ulid::from),
subject: res.subject,
created_at: res.created_at,
})
}
#[tracing::instrument(
@ -154,6 +151,8 @@ pub async fn add_link(
Ok(UpstreamOAuthLink {
id,
provider_id: upstream_oauth_provider.id,
user_id: None,
subject,
created_at,
})

View File

@ -18,7 +18,7 @@ mod session;
pub use self::{
link::{add_link, associate_link_to_user, lookup_link, lookup_link_by_subject},
provider::{add_provider, lookup_provider, ProviderLookupError},
provider::{add_provider, get_paginated_providers, lookup_provider, ProviderLookupError},
session::{
add_session, complete_session, consume_session, lookup_session, lookup_session_on_link,
SessionLookupError,

View File

@ -17,12 +17,16 @@ use mas_data_model::UpstreamOAuthProvider;
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use oauth2_types::scope::Scope;
use rand::Rng;
use sqlx::PgExecutor;
use sqlx::{PgExecutor, QueryBuilder};
use thiserror::Error;
use tracing::{info_span, Instrument};
use ulid::Ulid;
use uuid::Uuid;
use crate::{Clock, DatabaseInconsistencyError, LookupError};
use crate::{
pagination::{process_page, QueryBuilderExt},
Clock, DatabaseInconsistencyError, LookupError,
};
#[derive(Debug, Error)]
#[error("Failed to lookup upstream OAuth 2.0 provider")]
@ -37,6 +41,7 @@ impl LookupError for ProviderLookupError {
}
}
#[derive(sqlx::FromRow)]
struct ProviderLookup {
upstream_oauth_provider_id: Uuid,
issuer: String,
@ -48,6 +53,37 @@ struct ProviderLookup {
created_at: DateTime<Utc>,
}
impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
type Error = DatabaseInconsistencyError;
fn try_from(value: ProviderLookup) -> Result<Self, Self::Error> {
let id = value.upstream_oauth_provider_id.into();
let scope = value
.scope
.parse()
.map_err(|_| DatabaseInconsistencyError)?;
let token_endpoint_auth_method = value
.token_endpoint_auth_method
.parse()
.map_err(|_| DatabaseInconsistencyError)?;
let token_endpoint_signing_alg = value
.token_endpoint_signing_alg
.map(|x| x.parse())
.transpose()
.map_err(|_| DatabaseInconsistencyError)?;
Ok(UpstreamOAuthProvider {
id,
issuer: value.issuer,
scope,
client_id: value.client_id,
encrypted_client_secret: value.encrypted_client_secret,
token_endpoint_auth_method,
token_endpoint_signing_alg,
created_at: value.created_at,
})
}
}
#[tracing::instrument(
skip_all,
fields(upstream_oauth_provider.id = %id),
@ -77,23 +113,7 @@ pub async fn lookup_provider(
.fetch_one(executor)
.await?;
Ok(UpstreamOAuthProvider {
id: res.upstream_oauth_provider_id.into(),
issuer: res.issuer,
scope: res.scope.parse().map_err(|_| DatabaseInconsistencyError)?,
client_id: res.client_id,
encrypted_client_secret: res.encrypted_client_secret,
token_endpoint_auth_method: res
.token_endpoint_auth_method
.parse()
.map_err(|_| DatabaseInconsistencyError)?,
token_endpoint_signing_alg: res
.token_endpoint_signing_alg
.map(|x| x.parse())
.transpose()
.map_err(|_| DatabaseInconsistencyError)?,
created_at: res.created_at,
})
Ok(res.try_into()?)
}
#[tracing::instrument(
@ -157,3 +177,45 @@ pub async fn add_provider(
created_at,
})
}
#[tracing::instrument(skip_all, err(Display))]
pub async fn get_paginated_providers(
executor: impl PgExecutor<'_>,
before: Option<Ulid>,
after: Option<Ulid>,
first: Option<usize>,
last: Option<usize>,
) -> Result<(bool, bool, Vec<UpstreamOAuthProvider>), anyhow::Error> {
let mut query = QueryBuilder::new(
r#"
SELECT
upstream_oauth_provider_id,
issuer,
scope,
client_id,
encrypted_client_secret,
token_endpoint_signing_alg,
token_endpoint_auth_method,
created_at
FROM upstream_oauth_providers
WHERE 1 = 1
"#,
);
query.generate_pagination("upstream_oauth_provider_id", before, after, first, last)?;
let span = info_span!(
"Fetch paginated upstream OAuth 2.0 providers",
db.statement = query.sql()
);
let page: Vec<ProviderLookup> = query
.build_query_as()
.fetch_all(executor)
.instrument(span)
.await?;
let (has_previous_page, has_next_page, page) = process_page(page, first, last)?;
let page: Result<Vec<_>, _> = page.into_iter().map(TryInto::try_into).collect();
Ok((has_previous_page, has_next_page, page?))
}

View File

@ -38,6 +38,7 @@ impl LookupError for SessionLookupError {
struct SessionAndProviderLookup {
upstream_oauth_authorization_session_id: Uuid,
upstream_oauth_provider_id: Uuid,
upstream_oauth_link_id: Option<Uuid>,
state: String,
code_challenge_verifier: Option<String>,
nonce: String,
@ -70,6 +71,7 @@ pub async fn lookup_session(
SELECT
ua.upstream_oauth_authorization_session_id,
ua.upstream_oauth_provider_id,
ua.upstream_oauth_link_id,
ua.state,
ua.code_challenge_verifier,
ua.nonce,
@ -120,6 +122,8 @@ pub async fn lookup_session(
let session = UpstreamOAuthAuthorizationSession {
id: res.upstream_oauth_authorization_session_id.into(),
provider_id: provider.id,
link_id: res.upstream_oauth_link_id.map(Ulid::from),
state: res.state,
code_challenge_verifier: res.code_challenge_verifier,
nonce: res.nonce,
@ -185,6 +189,8 @@ pub async fn add_session(
Ok(UpstreamOAuthAuthorizationSession {
id,
provider_id: upstream_oauth_provider.id,
link_id: None,
state,
code_challenge_verifier,
nonce,
@ -267,6 +273,8 @@ pub async fn consume_session(
struct SessionLookup {
upstream_oauth_authorization_session_id: Uuid,
upstream_oauth_provider_id: Uuid,
upstream_oauth_link_id: Option<Uuid>,
state: String,
code_challenge_verifier: Option<String>,
nonce: String,
@ -295,6 +303,8 @@ pub async fn lookup_session_on_link(
r#"
SELECT
upstream_oauth_authorization_session_id,
upstream_oauth_provider_id,
upstream_oauth_link_id,
state,
code_challenge_verifier,
nonce,
@ -317,6 +327,8 @@ pub async fn lookup_session_on_link(
Ok(UpstreamOAuthAuthorizationSession {
id: res.upstream_oauth_authorization_session_id.into(),
provider_id: res.upstream_oauth_provider_id.into(),
link_id: res.upstream_oauth_link_id.map(Ulid::from),
state: res.state,
code_challenge_verifier: res.code_challenge_verifier,
nonce: res.nonce,