1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

WIP: handle account linking

This commit is contained in:
Quentin Gliech
2022-11-23 17:26:59 +01:00
parent cde9187adc
commit 22a337cd45
18 changed files with 848 additions and 50 deletions

View File

@ -23,11 +23,50 @@ use crate::{Clock, GenericLookupError};
struct LinkLookup {
upstream_oauth_link_id: Uuid,
upstream_oauth_provider_id: Uuid,
user_id: Option<Uuid>,
subject: String,
created_at: DateTime<Utc>,
}
#[tracing::instrument(
skip_all,
fields(upstream_oauth_link.id = %id),
err,
)]
pub async fn lookup_link(
executor: impl PgExecutor<'_>,
id: Ulid,
) -> Result<(UpstreamOAuthLink, Ulid, Option<Ulid>), GenericLookupError> {
let res = sqlx::query_as!(
LinkLookup,
r#"
SELECT
upstream_oauth_link_id,
upstream_oauth_provider_id,
user_id,
subject,
created_at
FROM upstream_oauth_links
WHERE upstream_oauth_link_id = $1
"#,
Uuid::from(id),
)
.fetch_one(executor)
.await
.map_err(GenericLookupError::what("Upstream OAuth 2.0 link"))?;
Ok((
UpstreamOAuthLink {
id: Ulid::from(res.upstream_oauth_link_id),
subject: res.subject,
created_at: res.created_at,
},
Ulid::from(res.upstream_oauth_provider_id),
res.user_id.map(Ulid::from),
))
}
#[tracing::instrument(
skip_all,
fields(
@ -48,6 +87,7 @@ pub async fn lookup_link_by_subject(
r#"
SELECT
upstream_oauth_link_id,
upstream_oauth_provider_id,
user_id,
subject,
created_at

View File

@ -17,7 +17,9 @@ mod provider;
mod session;
pub use self::{
link::{add_link, lookup_link_by_subject},
link::{add_link, lookup_link, lookup_link_by_subject},
provider::{add_provider, lookup_provider, ProviderLookupError},
session::{add_session, complete_session, lookup_session, SessionLookupError},
session::{
add_session, complete_session, lookup_session, lookup_session_on_link, SessionLookupError,
},
};

View File

@ -20,7 +20,7 @@ use thiserror::Error;
use ulid::Ulid;
use uuid::Uuid;
use crate::{Clock, DatabaseInconsistencyError, LookupError};
use crate::{Clock, DatabaseInconsistencyError, GenericLookupError, LookupError};
#[derive(Debug, Error)]
#[error("Failed to lookup upstream OAuth 2.0 authorization session")]
@ -35,7 +35,7 @@ impl LookupError for SessionLookupError {
}
}
struct SessionLookup {
struct SessionAndProviderLookup {
upstream_oauth_authorization_session_id: Uuid,
upstream_oauth_provider_id: Uuid,
state: String,
@ -52,6 +52,7 @@ struct SessionLookup {
provider_created_at: DateTime<Utc>,
}
/// Lookup a session and its provider by its ID
#[tracing::instrument(
skip_all,
fields(upstream_oauth_authorization_session.id = %id),
@ -62,7 +63,7 @@ pub async fn lookup_session(
id: Ulid,
) -> Result<(UpstreamOAuthProvider, UpstreamOAuthAuthorizationSession), SessionLookupError> {
let res = sqlx::query_as!(
SessionLookup,
SessionAndProviderLookup,
r#"
SELECT
ua.upstream_oauth_authorization_session_id,
@ -125,6 +126,7 @@ pub async fn lookup_session(
Ok((provider, session))
}
/// Add a session to the database
#[tracing::instrument(
skip_all,
fields(
@ -183,6 +185,7 @@ pub async fn add_session(
})
}
/// Mark a session as completed and associate the given link
#[tracing::instrument(
skip_all,
fields(
@ -214,3 +217,59 @@ pub async fn complete_session(
Ok(upstream_oauth_authorization_session)
}
struct SessionLookup {
upstream_oauth_authorization_session_id: Uuid,
state: String,
code_challenge_verifier: Option<String>,
nonce: String,
created_at: DateTime<Utc>,
completed_at: Option<DateTime<Utc>>,
}
/// Lookup a session, which belongs to a link, by its ID
#[tracing::instrument(
skip_all,
fields(
upstream_oauth_authorization_session.id = %id,
%upstream_oauth_link.id,
),
err,
)]
pub async fn lookup_session_on_link(
executor: impl PgExecutor<'_>,
upstream_oauth_link: &UpstreamOAuthLink,
id: Ulid,
) -> Result<UpstreamOAuthAuthorizationSession, GenericLookupError> {
let res = sqlx::query_as!(
SessionLookup,
r#"
SELECT
upstream_oauth_authorization_session_id,
state,
code_challenge_verifier,
nonce,
created_at,
completed_at
FROM upstream_oauth_authorization_sessions
WHERE upstream_oauth_authorization_session_id = $1
AND upstream_oauth_link_id = $2
"#,
Uuid::from(id),
Uuid::from(upstream_oauth_link.id),
)
.fetch_one(executor)
.await
.map_err(GenericLookupError::what(
"Upstream OAuth 2.0 session on link",
))?;
Ok(UpstreamOAuthAuthorizationSession {
id: res.upstream_oauth_authorization_session_id.into(),
state: res.state,
code_challenge_verifier: res.code_challenge_verifier,
nonce: res.nonce,
created_at: res.created_at,
completed_at: res.completed_at,
})
}