You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
WIP: handle account linking
This commit is contained in:
@ -23,11 +23,50 @@ use crate::{Clock, GenericLookupError};
|
||||
|
||||
struct LinkLookup {
|
||||
upstream_oauth_link_id: Uuid,
|
||||
upstream_oauth_provider_id: Uuid,
|
||||
user_id: Option<Uuid>,
|
||||
subject: String,
|
||||
created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(upstream_oauth_link.id = %id),
|
||||
err,
|
||||
)]
|
||||
pub async fn lookup_link(
|
||||
executor: impl PgExecutor<'_>,
|
||||
id: Ulid,
|
||||
) -> Result<(UpstreamOAuthLink, Ulid, Option<Ulid>), GenericLookupError> {
|
||||
let res = sqlx::query_as!(
|
||||
LinkLookup,
|
||||
r#"
|
||||
SELECT
|
||||
upstream_oauth_link_id,
|
||||
upstream_oauth_provider_id,
|
||||
user_id,
|
||||
subject,
|
||||
created_at
|
||||
FROM upstream_oauth_links
|
||||
WHERE upstream_oauth_link_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.map_err(GenericLookupError::what("Upstream OAuth 2.0 link"))?;
|
||||
|
||||
Ok((
|
||||
UpstreamOAuthLink {
|
||||
id: Ulid::from(res.upstream_oauth_link_id),
|
||||
subject: res.subject,
|
||||
created_at: res.created_at,
|
||||
},
|
||||
Ulid::from(res.upstream_oauth_provider_id),
|
||||
res.user_id.map(Ulid::from),
|
||||
))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
@ -48,6 +87,7 @@ pub async fn lookup_link_by_subject(
|
||||
r#"
|
||||
SELECT
|
||||
upstream_oauth_link_id,
|
||||
upstream_oauth_provider_id,
|
||||
user_id,
|
||||
subject,
|
||||
created_at
|
||||
|
@ -17,7 +17,9 @@ mod provider;
|
||||
mod session;
|
||||
|
||||
pub use self::{
|
||||
link::{add_link, lookup_link_by_subject},
|
||||
link::{add_link, lookup_link, lookup_link_by_subject},
|
||||
provider::{add_provider, lookup_provider, ProviderLookupError},
|
||||
session::{add_session, complete_session, lookup_session, SessionLookupError},
|
||||
session::{
|
||||
add_session, complete_session, lookup_session, lookup_session_on_link, SessionLookupError,
|
||||
},
|
||||
};
|
||||
|
@ -20,7 +20,7 @@ use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{Clock, DatabaseInconsistencyError, LookupError};
|
||||
use crate::{Clock, DatabaseInconsistencyError, GenericLookupError, LookupError};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("Failed to lookup upstream OAuth 2.0 authorization session")]
|
||||
@ -35,7 +35,7 @@ impl LookupError for SessionLookupError {
|
||||
}
|
||||
}
|
||||
|
||||
struct SessionLookup {
|
||||
struct SessionAndProviderLookup {
|
||||
upstream_oauth_authorization_session_id: Uuid,
|
||||
upstream_oauth_provider_id: Uuid,
|
||||
state: String,
|
||||
@ -52,6 +52,7 @@ struct SessionLookup {
|
||||
provider_created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Lookup a session and its provider by its ID
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(upstream_oauth_authorization_session.id = %id),
|
||||
@ -62,7 +63,7 @@ pub async fn lookup_session(
|
||||
id: Ulid,
|
||||
) -> Result<(UpstreamOAuthProvider, UpstreamOAuthAuthorizationSession), SessionLookupError> {
|
||||
let res = sqlx::query_as!(
|
||||
SessionLookup,
|
||||
SessionAndProviderLookup,
|
||||
r#"
|
||||
SELECT
|
||||
ua.upstream_oauth_authorization_session_id,
|
||||
@ -125,6 +126,7 @@ pub async fn lookup_session(
|
||||
Ok((provider, session))
|
||||
}
|
||||
|
||||
/// Add a session to the database
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
@ -183,6 +185,7 @@ pub async fn add_session(
|
||||
})
|
||||
}
|
||||
|
||||
/// Mark a session as completed and associate the given link
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
@ -214,3 +217,59 @@ pub async fn complete_session(
|
||||
|
||||
Ok(upstream_oauth_authorization_session)
|
||||
}
|
||||
|
||||
struct SessionLookup {
|
||||
upstream_oauth_authorization_session_id: Uuid,
|
||||
state: String,
|
||||
code_challenge_verifier: Option<String>,
|
||||
nonce: String,
|
||||
created_at: DateTime<Utc>,
|
||||
completed_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
/// Lookup a session, which belongs to a link, by its ID
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
upstream_oauth_authorization_session.id = %id,
|
||||
%upstream_oauth_link.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn lookup_session_on_link(
|
||||
executor: impl PgExecutor<'_>,
|
||||
upstream_oauth_link: &UpstreamOAuthLink,
|
||||
id: Ulid,
|
||||
) -> Result<UpstreamOAuthAuthorizationSession, GenericLookupError> {
|
||||
let res = sqlx::query_as!(
|
||||
SessionLookup,
|
||||
r#"
|
||||
SELECT
|
||||
upstream_oauth_authorization_session_id,
|
||||
state,
|
||||
code_challenge_verifier,
|
||||
nonce,
|
||||
created_at,
|
||||
completed_at
|
||||
FROM upstream_oauth_authorization_sessions
|
||||
WHERE upstream_oauth_authorization_session_id = $1
|
||||
AND upstream_oauth_link_id = $2
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(upstream_oauth_link.id),
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.map_err(GenericLookupError::what(
|
||||
"Upstream OAuth 2.0 session on link",
|
||||
))?;
|
||||
|
||||
Ok(UpstreamOAuthAuthorizationSession {
|
||||
id: res.upstream_oauth_authorization_session_id.into(),
|
||||
state: res.state,
|
||||
code_challenge_verifier: res.code_challenge_verifier,
|
||||
nonce: res.nonce,
|
||||
created_at: res.created_at,
|
||||
completed_at: res.completed_at,
|
||||
})
|
||||
}
|
||||
|
Reference in New Issue
Block a user