You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-28 11:02:02 +03:00
Lookup and save upstream links
This commit is contained in:
120
crates/storage/src/upstream_oauth2/link.rs
Normal file
120
crates/storage/src/upstream_oauth2/link.rs
Normal file
@ -0,0 +1,120 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{UpstreamOAuthLink, UpstreamOAuthProvider};
|
||||
use rand::Rng;
|
||||
use sqlx::PgExecutor;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{Clock, GenericLookupError};
|
||||
|
||||
struct LinkLookup {
|
||||
upstream_oauth_link_id: Uuid,
|
||||
user_id: Option<Uuid>,
|
||||
subject: String,
|
||||
created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
upstream_oauth_link.subject = subject,
|
||||
%upstream_oauth_provider.id,
|
||||
%upstream_oauth_provider.issuer,
|
||||
%upstream_oauth_provider.client_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn lookup_link_by_subject(
|
||||
executor: impl PgExecutor<'_>,
|
||||
upstream_oauth_provider: &UpstreamOAuthProvider,
|
||||
subject: &str,
|
||||
) -> Result<(UpstreamOAuthLink, Option<Ulid>), GenericLookupError> {
|
||||
let res = sqlx::query_as!(
|
||||
LinkLookup,
|
||||
r#"
|
||||
SELECT
|
||||
upstream_oauth_link_id,
|
||||
user_id,
|
||||
subject,
|
||||
created_at
|
||||
FROM upstream_oauth_links
|
||||
WHERE upstream_oauth_provider_id = $1
|
||||
AND subject = $2
|
||||
"#,
|
||||
Uuid::from(upstream_oauth_provider.id),
|
||||
subject,
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.map_err(GenericLookupError::what("Upstream OAuth 2.0 link"))?;
|
||||
|
||||
Ok((
|
||||
UpstreamOAuthLink {
|
||||
id: Ulid::from(res.upstream_oauth_link_id),
|
||||
subject: res.subject,
|
||||
created_at: res.created_at,
|
||||
},
|
||||
res.user_id.map(Ulid::from),
|
||||
))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
upstream_oauth_link.id,
|
||||
upstream_oauth_link.subject = subject,
|
||||
%upstream_oauth_provider.id,
|
||||
%upstream_oauth_provider.issuer,
|
||||
%upstream_oauth_provider.client_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn add_link(
|
||||
executor: impl PgExecutor<'_>,
|
||||
mut rng: impl Rng + Send,
|
||||
clock: &Clock,
|
||||
upstream_oauth_provider: &UpstreamOAuthProvider,
|
||||
subject: String,
|
||||
) -> Result<UpstreamOAuthLink, sqlx::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||
tracing::Span::current().record("upstream_oauth_link.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO upstream_oauth_links (
|
||||
upstream_oauth_link_id,
|
||||
upstream_oauth_provider_id,
|
||||
user_id,
|
||||
subject,
|
||||
created_at
|
||||
) VALUES ($1, $2, NULL, $3, $4)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(upstream_oauth_provider.id),
|
||||
&subject,
|
||||
created_at,
|
||||
)
|
||||
.execute(executor)
|
||||
.await?;
|
||||
|
||||
Ok(UpstreamOAuthLink {
|
||||
id,
|
||||
subject,
|
||||
created_at,
|
||||
})
|
||||
}
|
@ -12,10 +12,12 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
mod link;
|
||||
mod provider;
|
||||
mod session;
|
||||
|
||||
pub use self::{
|
||||
link::{add_link, lookup_link_by_subject},
|
||||
provider::{add_provider, lookup_provider, ProviderLookupError},
|
||||
session::{add_session, lookup_session, SessionLookupError},
|
||||
session::{add_session, complete_session, lookup_session, SessionLookupError},
|
||||
};
|
||||
|
@ -13,7 +13,7 @@
|
||||
// limitations under the License.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthProvider};
|
||||
use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider};
|
||||
use rand::Rng;
|
||||
use sqlx::PgExecutor;
|
||||
use thiserror::Error;
|
||||
@ -128,9 +128,9 @@ pub async fn lookup_session(
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
upstream_oauth_provider.id = %provider.id,
|
||||
upstream_oauth_provider.issuer = %provider.issuer,
|
||||
upstream_oauth_provider.client_id = %provider.client_id,
|
||||
%upstream_oauth_provider.id,
|
||||
%upstream_oauth_provider.issuer,
|
||||
%upstream_oauth_provider.client_id,
|
||||
upstream_oauth_authorization_session.id,
|
||||
),
|
||||
err,
|
||||
@ -139,7 +139,7 @@ pub async fn add_session(
|
||||
executor: impl PgExecutor<'_>,
|
||||
mut rng: impl Rng + Send,
|
||||
clock: &Clock,
|
||||
provider: &UpstreamOAuthProvider,
|
||||
upstream_oauth_provider: &UpstreamOAuthProvider,
|
||||
state: String,
|
||||
code_challenge_verifier: Option<String>,
|
||||
nonce: String,
|
||||
@ -164,7 +164,7 @@ pub async fn add_session(
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, NULL)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(provider.id),
|
||||
Uuid::from(upstream_oauth_provider.id),
|
||||
&state,
|
||||
code_challenge_verifier.as_deref(),
|
||||
nonce,
|
||||
@ -182,3 +182,35 @@ pub async fn add_session(
|
||||
completed_at: None,
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
%upstream_oauth_authorization_session.id,
|
||||
%upstream_oauth_link.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn complete_session(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
mut upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
|
||||
upstream_oauth_link: &UpstreamOAuthLink,
|
||||
) -> Result<UpstreamOAuthAuthorizationSession, sqlx::Error> {
|
||||
let completed_at = clock.now();
|
||||
sqlx::query!(
|
||||
r#"
|
||||
UPDATE upstream_oauth_authorization_sessions
|
||||
SET upstream_oauth_link_id = $1,
|
||||
completed_at = $2
|
||||
"#,
|
||||
Uuid::from(upstream_oauth_link.id),
|
||||
completed_at,
|
||||
)
|
||||
.execute(executor)
|
||||
.await?;
|
||||
|
||||
upstream_oauth_authorization_session.completed_at = Some(completed_at);
|
||||
|
||||
Ok(upstream_oauth_authorization_session)
|
||||
}
|
||||
|
Reference in New Issue
Block a user