You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
storage: upstream oauth session repository + unit tests
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -3104,6 +3104,7 @@ dependencies = [
|
|||||||
"mas-jose",
|
"mas-jose",
|
||||||
"oauth2-types",
|
"oauth2-types",
|
||||||
"rand 0.8.5",
|
"rand 0.8.5",
|
||||||
|
"rand_chacha 0.3.1",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"sqlx",
|
"sqlx",
|
||||||
|
@ -22,7 +22,10 @@ use mas_axum_utils::http_client_factory::HttpClientFactory;
|
|||||||
use mas_keystore::Encrypter;
|
use mas_keystore::Encrypter;
|
||||||
use mas_oidc_client::requests::authorization_code::AuthorizationRequestData;
|
use mas_oidc_client::requests::authorization_code::AuthorizationRequestData;
|
||||||
use mas_router::UrlBuilder;
|
use mas_router::UrlBuilder;
|
||||||
use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, Repository};
|
use mas_storage::{
|
||||||
|
upstream_oauth2::{UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository},
|
||||||
|
Repository,
|
||||||
|
};
|
||||||
use sqlx::PgPool;
|
use sqlx::PgPool;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use ulid::Ulid;
|
use ulid::Ulid;
|
||||||
@ -97,16 +100,17 @@ pub(crate) async fn get(
|
|||||||
&mut rng,
|
&mut rng,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let session = mas_storage::upstream_oauth2::add_session(
|
let session = txn
|
||||||
&mut txn,
|
.upstream_oauth_session()
|
||||||
&mut rng,
|
.add(
|
||||||
&clock,
|
&mut rng,
|
||||||
&provider,
|
&clock,
|
||||||
data.state.clone(),
|
&provider,
|
||||||
data.code_challenge_verifier,
|
data.state.clone(),
|
||||||
data.nonce,
|
data.code_challenge_verifier,
|
||||||
)
|
data.nonce,
|
||||||
.await?;
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
let cookie_jar = UpstreamSessionsCookie::load(&cookie_jar)
|
let cookie_jar = UpstreamSessionsCookie::load(&cookie_jar)
|
||||||
.add(session.id, provider.id, data.state, query.post_auth_action)
|
.add(session.id, provider.id, data.state, query.post_auth_action)
|
||||||
|
@ -26,7 +26,7 @@ use mas_oidc_client::requests::{
|
|||||||
};
|
};
|
||||||
use mas_router::{Route, UrlBuilder};
|
use mas_router::{Route, UrlBuilder};
|
||||||
use mas_storage::{
|
use mas_storage::{
|
||||||
upstream_oauth2::{complete_session, lookup_session},
|
upstream_oauth2::{UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository},
|
||||||
Repository, UpstreamOAuthLinkRepository,
|
Repository, UpstreamOAuthLinkRepository,
|
||||||
};
|
};
|
||||||
use oauth2_types::errors::ClientErrorCode;
|
use oauth2_types::errors::ClientErrorCode;
|
||||||
@ -65,6 +65,9 @@ pub(crate) enum RouteError {
|
|||||||
#[error("Session not found")]
|
#[error("Session not found")]
|
||||||
SessionNotFound,
|
SessionNotFound,
|
||||||
|
|
||||||
|
#[error("Provider not found")]
|
||||||
|
ProviderNotFound,
|
||||||
|
|
||||||
#[error("Provider mismatch")]
|
#[error("Provider mismatch")]
|
||||||
ProviderMismatch,
|
ProviderMismatch,
|
||||||
|
|
||||||
@ -105,6 +108,7 @@ impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
|
|||||||
impl IntoResponse for RouteError {
|
impl IntoResponse for RouteError {
|
||||||
fn into_response(self) -> axum::response::Response {
|
fn into_response(self) -> axum::response::Response {
|
||||||
match self {
|
match self {
|
||||||
|
Self::ProviderNotFound => (StatusCode::NOT_FOUND, "Provider not found").into_response(),
|
||||||
Self::SessionNotFound => (StatusCode::NOT_FOUND, "Session not found").into_response(),
|
Self::SessionNotFound => (StatusCode::NOT_FOUND, "Session not found").into_response(),
|
||||||
Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
|
Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
|
||||||
e => (StatusCode::BAD_REQUEST, e.to_string()).into_response(),
|
e => (StatusCode::BAD_REQUEST, e.to_string()).into_response(),
|
||||||
@ -127,16 +131,24 @@ pub(crate) async fn get(
|
|||||||
|
|
||||||
let mut txn = pool.begin().await?;
|
let mut txn = pool.begin().await?;
|
||||||
|
|
||||||
|
let provider = txn
|
||||||
|
.upstream_oauth_provider()
|
||||||
|
.lookup(provider_id)
|
||||||
|
.await?
|
||||||
|
.ok_or(RouteError::ProviderNotFound)?;
|
||||||
|
|
||||||
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
|
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
|
||||||
let (session_id, _post_auth_action) = sessions_cookie
|
let (session_id, _post_auth_action) = sessions_cookie
|
||||||
.find_session(provider_id, ¶ms.state)
|
.find_session(provider_id, ¶ms.state)
|
||||||
.map_err(|_| RouteError::MissingCookie)?;
|
.map_err(|_| RouteError::MissingCookie)?;
|
||||||
|
|
||||||
let (provider, session) = lookup_session(&mut txn, session_id)
|
let session = txn
|
||||||
|
.upstream_oauth_session()
|
||||||
|
.lookup(session_id)
|
||||||
.await?
|
.await?
|
||||||
.ok_or(RouteError::SessionNotFound)?;
|
.ok_or(RouteError::SessionNotFound)?;
|
||||||
|
|
||||||
if provider.id != provider_id {
|
if provider.id != session.provider_id {
|
||||||
// The provider in the session cookie should match the one from the URL
|
// The provider in the session cookie should match the one from the URL
|
||||||
return Err(RouteError::ProviderMismatch);
|
return Err(RouteError::ProviderMismatch);
|
||||||
}
|
}
|
||||||
@ -245,7 +257,11 @@ pub(crate) async fn get(
|
|||||||
.await?
|
.await?
|
||||||
};
|
};
|
||||||
|
|
||||||
let session = complete_session(&mut txn, &clock, session, &link, response.id_token).await?;
|
let session = txn
|
||||||
|
.upstream_oauth_session()
|
||||||
|
.complete_with_link(&clock, session, &link, response.id_token)
|
||||||
|
.await?;
|
||||||
|
|
||||||
let cookie_jar = sessions_cookie
|
let cookie_jar = sessions_cookie
|
||||||
.add_link_to_session(session.id, link.id)?
|
.add_link_to_session(session.id, link.id)?
|
||||||
.save(cookie_jar, clock.now());
|
.save(cookie_jar, clock.now());
|
||||||
|
@ -25,7 +25,7 @@ use mas_axum_utils::{
|
|||||||
};
|
};
|
||||||
use mas_keystore::Encrypter;
|
use mas_keystore::Encrypter;
|
||||||
use mas_storage::{
|
use mas_storage::{
|
||||||
upstream_oauth2::{consume_session, lookup_session_on_link},
|
upstream_oauth2::UpstreamOAuthSessionRepository,
|
||||||
user::{add_user, authenticate_session_with_upstream, lookup_user, start_session},
|
user::{add_user, authenticate_session_with_upstream, lookup_user, start_session},
|
||||||
Repository, UpstreamOAuthLinkRepository,
|
Repository, UpstreamOAuthLinkRepository,
|
||||||
};
|
};
|
||||||
@ -109,12 +109,18 @@ pub(crate) async fn get(
|
|||||||
.await?
|
.await?
|
||||||
.ok_or(RouteError::LinkNotFound)?;
|
.ok_or(RouteError::LinkNotFound)?;
|
||||||
|
|
||||||
// This checks that we're in a browser session which is allowed to consume this
|
let upstream_session = txn
|
||||||
// link: the upstream auth session should have been started in this browser.
|
.upstream_oauth_session()
|
||||||
let upstream_session = lookup_session_on_link(&mut txn, &link, session_id)
|
.lookup(session_id)
|
||||||
.await?
|
.await?
|
||||||
.ok_or(RouteError::SessionNotFound)?;
|
.ok_or(RouteError::SessionNotFound)?;
|
||||||
|
|
||||||
|
// This checks that we're in a browser session which is allowed to consume this
|
||||||
|
// link: the upstream auth session should have been started in this browser.
|
||||||
|
if upstream_session.link_id != Some(link.id) {
|
||||||
|
return Err(RouteError::SessionNotFound);
|
||||||
|
}
|
||||||
|
|
||||||
if upstream_session.consumed() {
|
if upstream_session.consumed() {
|
||||||
return Err(RouteError::SessionConsumed);
|
return Err(RouteError::SessionConsumed);
|
||||||
}
|
}
|
||||||
@ -127,7 +133,10 @@ pub(crate) async fn get(
|
|||||||
(Some(mut session), Some(user_id)) if session.user.id == user_id => {
|
(Some(mut session), Some(user_id)) if session.user.id == user_id => {
|
||||||
// Session already linked, and link matches the currently logged
|
// Session already linked, and link matches the currently logged
|
||||||
// user. Mark the session as consumed and renew the authentication.
|
// user. Mark the session as consumed and renew the authentication.
|
||||||
consume_session(&mut txn, &clock, upstream_session).await?;
|
txn.upstream_oauth_session()
|
||||||
|
.consume(&clock, upstream_session)
|
||||||
|
.await?;
|
||||||
|
|
||||||
authenticate_session_with_upstream(&mut txn, &mut rng, &clock, &mut session, &link)
|
authenticate_session_with_upstream(&mut txn, &mut rng, &clock, &mut session, &link)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
@ -212,12 +221,18 @@ pub(crate) async fn post(
|
|||||||
.await?
|
.await?
|
||||||
.ok_or(RouteError::LinkNotFound)?;
|
.ok_or(RouteError::LinkNotFound)?;
|
||||||
|
|
||||||
// This checks that we're in a browser session which is allowed to consume this
|
let upstream_session = txn
|
||||||
// link: the upstream auth session should have been started in this browser.
|
.upstream_oauth_session()
|
||||||
let upstream_session = lookup_session_on_link(&mut txn, &link, session_id)
|
.lookup(session_id)
|
||||||
.await?
|
.await?
|
||||||
.ok_or(RouteError::SessionNotFound)?;
|
.ok_or(RouteError::SessionNotFound)?;
|
||||||
|
|
||||||
|
// This checks that we're in a browser session which is allowed to consume this
|
||||||
|
// link: the upstream auth session should have been started in this browser.
|
||||||
|
if upstream_session.link_id != Some(link.id) {
|
||||||
|
return Err(RouteError::SessionNotFound);
|
||||||
|
}
|
||||||
|
|
||||||
if upstream_session.consumed() {
|
if upstream_session.consumed() {
|
||||||
return Err(RouteError::SessionConsumed);
|
return Err(RouteError::SessionConsumed);
|
||||||
}
|
}
|
||||||
@ -251,7 +266,10 @@ pub(crate) async fn post(
|
|||||||
_ => return Err(RouteError::InvalidFormAction),
|
_ => return Err(RouteError::InvalidFormAction),
|
||||||
};
|
};
|
||||||
|
|
||||||
consume_session(&mut txn, &clock, upstream_session).await?;
|
txn.upstream_oauth_session()
|
||||||
|
.consume(&clock, upstream_session)
|
||||||
|
.await?;
|
||||||
|
|
||||||
authenticate_session_with_upstream(&mut txn, &mut rng, &clock, &mut session, &link).await?;
|
authenticate_session_with_upstream(&mut txn, &mut rng, &clock, &mut session, &link).await?;
|
||||||
|
|
||||||
let cookie_jar = sessions_cookie
|
let cookie_jar = sessions_cookie
|
||||||
|
@ -14,8 +14,8 @@ serde_json = "1.0.91"
|
|||||||
thiserror = "1.0.38"
|
thiserror = "1.0.38"
|
||||||
tracing = "0.1.37"
|
tracing = "0.1.37"
|
||||||
|
|
||||||
# Password hashing
|
|
||||||
rand = "0.8.5"
|
rand = "0.8.5"
|
||||||
|
rand_chacha = "0.3.1"
|
||||||
url = { version = "2.3.1", features = ["serde"] }
|
url = { version = "2.3.1", features = ["serde"] }
|
||||||
uuid = "1.2.2"
|
uuid = "1.2.2"
|
||||||
ulid = { version = "1.0.0", features = ["uuid", "serde"] }
|
ulid = { version = "1.0.0", features = ["uuid", "serde"] }
|
||||||
|
@ -521,81 +521,6 @@
|
|||||||
},
|
},
|
||||||
"query": "\n INSERT INTO users (user_id, username, created_at)\n VALUES ($1, $2, $3)\n "
|
"query": "\n INSERT INTO users (user_id, username, created_at)\n VALUES ($1, $2, $3)\n "
|
||||||
},
|
},
|
||||||
"2ca7b990c11e84db62fb7887a2bc3410ec1eee2f6a0ec124db36575111970ca9": {
|
|
||||||
"describe": {
|
|
||||||
"columns": [
|
|
||||||
{
|
|
||||||
"name": "upstream_oauth_authorization_session_id",
|
|
||||||
"ordinal": 0,
|
|
||||||
"type_info": "Uuid"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "upstream_oauth_provider_id",
|
|
||||||
"ordinal": 1,
|
|
||||||
"type_info": "Uuid"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "upstream_oauth_link_id",
|
|
||||||
"ordinal": 2,
|
|
||||||
"type_info": "Uuid"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "state",
|
|
||||||
"ordinal": 3,
|
|
||||||
"type_info": "Text"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "code_challenge_verifier",
|
|
||||||
"ordinal": 4,
|
|
||||||
"type_info": "Text"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "nonce",
|
|
||||||
"ordinal": 5,
|
|
||||||
"type_info": "Text"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "id_token",
|
|
||||||
"ordinal": 6,
|
|
||||||
"type_info": "Text"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "created_at",
|
|
||||||
"ordinal": 7,
|
|
||||||
"type_info": "Timestamptz"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "completed_at",
|
|
||||||
"ordinal": 8,
|
|
||||||
"type_info": "Timestamptz"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "consumed_at",
|
|
||||||
"ordinal": 9,
|
|
||||||
"type_info": "Timestamptz"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"nullable": [
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
true,
|
|
||||||
false,
|
|
||||||
true,
|
|
||||||
false,
|
|
||||||
true,
|
|
||||||
false,
|
|
||||||
true,
|
|
||||||
true
|
|
||||||
],
|
|
||||||
"parameters": {
|
|
||||||
"Left": [
|
|
||||||
"Uuid",
|
|
||||||
"Uuid"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"query": "\n SELECT\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n upstream_oauth_link_id,\n state,\n code_challenge_verifier,\n nonce,\n id_token,\n created_at,\n completed_at,\n consumed_at\n FROM upstream_oauth_authorization_sessions\n WHERE upstream_oauth_authorization_session_id = $1\n AND upstream_oauth_link_id = $2\n "
|
|
||||||
},
|
|
||||||
"2e581d57db471b96091860cd0252361d16332deeffabab0dace405ead55324be": {
|
"2e581d57db471b96091860cd0252361d16332deeffabab0dace405ead55324be": {
|
||||||
"describe": {
|
"describe": {
|
||||||
"columns": [
|
"columns": [
|
||||||
@ -708,21 +633,6 @@
|
|||||||
},
|
},
|
||||||
"query": "\n UPDATE compat_sso_logins\n SET\n exchanged_at = $2\n WHERE\n compat_sso_login_id = $1\n "
|
"query": "\n UPDATE compat_sso_logins\n SET\n exchanged_at = $2\n WHERE\n compat_sso_login_id = $1\n "
|
||||||
},
|
},
|
||||||
"2fb8f1aef96571a6f3f6260d7836de99ff24ba1947747e08b0e8d64097507442": {
|
|
||||||
"describe": {
|
|
||||||
"columns": [],
|
|
||||||
"nullable": [],
|
|
||||||
"parameters": {
|
|
||||||
"Left": [
|
|
||||||
"Uuid",
|
|
||||||
"Timestamptz",
|
|
||||||
"Text",
|
|
||||||
"Uuid"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"query": "\n UPDATE upstream_oauth_authorization_sessions\n SET upstream_oauth_link_id = $1,\n completed_at = $2,\n id_token = $3\n WHERE upstream_oauth_authorization_session_id = $4\n "
|
|
||||||
},
|
|
||||||
"360466ff599c67c9af2ac75399c0b536a22c1178972a0172b707bcc81d47357b": {
|
"360466ff599c67c9af2ac75399c0b536a22c1178972a0172b707bcc81d47357b": {
|
||||||
"describe": {
|
"describe": {
|
||||||
"columns": [],
|
"columns": [],
|
||||||
@ -1388,7 +1298,24 @@
|
|||||||
},
|
},
|
||||||
"query": "\n UPDATE user_sessions\n SET finished_at = $1\n WHERE user_session_id = $2\n "
|
"query": "\n UPDATE user_sessions\n SET finished_at = $1\n WHERE user_session_id = $2\n "
|
||||||
},
|
},
|
||||||
"65c7600f1af07cb6ea49d89ae6fbca5374a57c5a866c8aadd7b75ed1d2d1d0cd": {
|
"64e6ea47c2e877c1ebe4338d64d9ad8a6c1c777d1daea024b8ca2e7f0dd75b0f": {
|
||||||
|
"describe": {
|
||||||
|
"columns": [],
|
||||||
|
"nullable": [],
|
||||||
|
"parameters": {
|
||||||
|
"Left": [
|
||||||
|
"Uuid",
|
||||||
|
"Uuid",
|
||||||
|
"Text",
|
||||||
|
"Text",
|
||||||
|
"Text",
|
||||||
|
"Timestamptz"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"query": "\n INSERT INTO upstream_oauth_authorization_sessions (\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n state,\n code_challenge_verifier,\n nonce,\n created_at,\n completed_at,\n consumed_at,\n id_token\n ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL)\n "
|
||||||
|
},
|
||||||
|
"67ab838035946ddc15b43dd2f79d10b233d07e863b3a5c776c5db97cff263c8c": {
|
||||||
"describe": {
|
"describe": {
|
||||||
"columns": [
|
"columns": [
|
||||||
{
|
{
|
||||||
@ -1440,41 +1367,6 @@
|
|||||||
"name": "consumed_at",
|
"name": "consumed_at",
|
||||||
"ordinal": 9,
|
"ordinal": 9,
|
||||||
"type_info": "Timestamptz"
|
"type_info": "Timestamptz"
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "provider_issuer",
|
|
||||||
"ordinal": 10,
|
|
||||||
"type_info": "Text"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "provider_scope",
|
|
||||||
"ordinal": 11,
|
|
||||||
"type_info": "Text"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "provider_client_id",
|
|
||||||
"ordinal": 12,
|
|
||||||
"type_info": "Text"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "provider_encrypted_client_secret",
|
|
||||||
"ordinal": 13,
|
|
||||||
"type_info": "Text"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "provider_token_endpoint_auth_method",
|
|
||||||
"ordinal": 14,
|
|
||||||
"type_info": "Text"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "provider_token_endpoint_signing_alg",
|
|
||||||
"ordinal": 15,
|
|
||||||
"type_info": "Text"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "provider_created_at",
|
|
||||||
"ordinal": 16,
|
|
||||||
"type_info": "Timestamptz"
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"nullable": [
|
"nullable": [
|
||||||
@ -1487,14 +1379,7 @@
|
|||||||
true,
|
true,
|
||||||
false,
|
false,
|
||||||
true,
|
true,
|
||||||
true,
|
true
|
||||||
false,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
true,
|
|
||||||
false,
|
|
||||||
true,
|
|
||||||
false
|
|
||||||
],
|
],
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"Left": [
|
"Left": [
|
||||||
@ -1502,7 +1387,20 @@
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"query": "\n SELECT\n ua.upstream_oauth_authorization_session_id,\n ua.upstream_oauth_provider_id,\n ua.upstream_oauth_link_id,\n ua.state,\n ua.code_challenge_verifier,\n ua.nonce,\n ua.id_token,\n ua.created_at,\n ua.completed_at,\n ua.consumed_at,\n up.issuer AS \"provider_issuer\",\n up.scope AS \"provider_scope\",\n up.client_id AS \"provider_client_id\",\n up.encrypted_client_secret AS \"provider_encrypted_client_secret\",\n up.token_endpoint_auth_method AS \"provider_token_endpoint_auth_method\",\n up.token_endpoint_signing_alg AS \"provider_token_endpoint_signing_alg\",\n up.created_at AS \"provider_created_at\"\n FROM upstream_oauth_authorization_sessions ua\n INNER JOIN upstream_oauth_providers up\n USING (upstream_oauth_provider_id)\n WHERE upstream_oauth_authorization_session_id = $1\n "
|
"query": "\n SELECT\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n upstream_oauth_link_id,\n state,\n code_challenge_verifier,\n nonce,\n id_token,\n created_at,\n completed_at,\n consumed_at\n FROM upstream_oauth_authorization_sessions\n WHERE upstream_oauth_authorization_session_id = $1\n "
|
||||||
|
},
|
||||||
|
"689ffbfc5137ec788e89062ad679bbe6b23a8861c09a7246dc1659c28f12bf8d": {
|
||||||
|
"describe": {
|
||||||
|
"columns": [],
|
||||||
|
"nullable": [],
|
||||||
|
"parameters": {
|
||||||
|
"Left": [
|
||||||
|
"Timestamptz",
|
||||||
|
"Uuid"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"query": "\n UPDATE upstream_oauth_authorization_sessions\n SET consumed_at = $1\n WHERE upstream_oauth_authorization_session_id = $2\n "
|
||||||
},
|
},
|
||||||
"6bf0da5ba3dd07b499193a2e0ddeea6e712f9df8f7f28874ff56a952a9f10e54": {
|
"6bf0da5ba3dd07b499193a2e0ddeea6e712f9df8f7f28874ff56a952a9f10e54": {
|
||||||
"describe": {
|
"describe": {
|
||||||
@ -2420,6 +2318,21 @@
|
|||||||
},
|
},
|
||||||
"query": "\n SELECT\n ue.user_email_id,\n ue.email AS \"user_email\",\n ue.created_at AS \"user_email_created_at\",\n ue.confirmed_at AS \"user_email_confirmed_at\"\n FROM user_emails ue\n\n WHERE ue.user_id = $1\n AND ue.user_email_id = $2\n "
|
"query": "\n SELECT\n ue.user_email_id,\n ue.email AS \"user_email\",\n ue.created_at AS \"user_email_created_at\",\n ue.confirmed_at AS \"user_email_confirmed_at\"\n FROM user_emails ue\n\n WHERE ue.user_id = $1\n AND ue.user_email_id = $2\n "
|
||||||
},
|
},
|
||||||
|
"b9875a270f7e753e48075ccae233df6e24a91775ceb877735508c1d5b2300d64": {
|
||||||
|
"describe": {
|
||||||
|
"columns": [],
|
||||||
|
"nullable": [],
|
||||||
|
"parameters": {
|
||||||
|
"Left": [
|
||||||
|
"Uuid",
|
||||||
|
"Timestamptz",
|
||||||
|
"Text",
|
||||||
|
"Uuid"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"query": "\n UPDATE upstream_oauth_authorization_sessions\n SET upstream_oauth_link_id = $1,\n completed_at = $2,\n id_token = $3\n WHERE upstream_oauth_authorization_session_id = $4\n "
|
||||||
|
},
|
||||||
"bc768c63a7737818967bc28560de714bbbd262bdf3ab73d297263bb73dcd9f5e": {
|
"bc768c63a7737818967bc28560de714bbbd262bdf3ab73d297263bb73dcd9f5e": {
|
||||||
"describe": {
|
"describe": {
|
||||||
"columns": [],
|
"columns": [],
|
||||||
@ -2702,19 +2615,6 @@
|
|||||||
},
|
},
|
||||||
"query": "\n DELETE FROM user_emails\n WHERE user_emails.user_email_id = $1\n "
|
"query": "\n DELETE FROM user_emails\n WHERE user_emails.user_email_id = $1\n "
|
||||||
},
|
},
|
||||||
"e30562e9637d3a723a91adca6336a8d083657ce6d7fe9551fcd6a9d672453d3c": {
|
|
||||||
"describe": {
|
|
||||||
"columns": [],
|
|
||||||
"nullable": [],
|
|
||||||
"parameters": {
|
|
||||||
"Left": [
|
|
||||||
"Timestamptz",
|
|
||||||
"Uuid"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"query": "\n UPDATE upstream_oauth_authorization_sessions\n SET consumed_at = $1\n WHERE upstream_oauth_authorization_session_id = $2\n "
|
|
||||||
},
|
|
||||||
"e446e37d48c8838ef2e0d0fd82f8f7b04893c84ad46747cdf193ebd83755ceb2": {
|
"e446e37d48c8838ef2e0d0fd82f8f7b04893c84ad46747cdf193ebd83755ceb2": {
|
||||||
"describe": {
|
"describe": {
|
||||||
"columns": [],
|
"columns": [],
|
||||||
@ -2773,22 +2673,5 @@
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"query": "\n SELECT\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n FROM upstream_oauth_links\n WHERE upstream_oauth_provider_id = $1\n AND subject = $2\n "
|
"query": "\n SELECT\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n FROM upstream_oauth_links\n WHERE upstream_oauth_provider_id = $1\n AND subject = $2\n "
|
||||||
},
|
|
||||||
"fb71ac6539039313fd90b29ac943330e54c7b62b2778727726e2f60a554f9c5a": {
|
|
||||||
"describe": {
|
|
||||||
"columns": [],
|
|
||||||
"nullable": [],
|
|
||||||
"parameters": {
|
|
||||||
"Left": [
|
|
||||||
"Uuid",
|
|
||||||
"Uuid",
|
|
||||||
"Text",
|
|
||||||
"Text",
|
|
||||||
"Text",
|
|
||||||
"Timestamptz"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"query": "\n INSERT INTO upstream_oauth_authorization_sessions (\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n state,\n code_challenge_verifier,\n nonce,\n created_at,\n completed_at,\n consumed_at,\n id_token\n ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL)\n "
|
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -14,7 +14,10 @@
|
|||||||
|
|
||||||
use sqlx::{PgConnection, Postgres, Transaction};
|
use sqlx::{PgConnection, Postgres, Transaction};
|
||||||
|
|
||||||
use crate::upstream_oauth2::{PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository};
|
use crate::upstream_oauth2::{
|
||||||
|
PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository,
|
||||||
|
PgUpstreamOAuthSessionRepository,
|
||||||
|
};
|
||||||
|
|
||||||
pub trait Repository {
|
pub trait Repository {
|
||||||
type UpstreamOAuthLinkRepository<'c>
|
type UpstreamOAuthLinkRepository<'c>
|
||||||
@ -25,13 +28,19 @@ pub trait Repository {
|
|||||||
where
|
where
|
||||||
Self: 'c;
|
Self: 'c;
|
||||||
|
|
||||||
|
type UpstreamOAuthSessionRepository<'c>
|
||||||
|
where
|
||||||
|
Self: 'c;
|
||||||
|
|
||||||
fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_>;
|
fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_>;
|
||||||
fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_>;
|
fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_>;
|
||||||
|
fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_>;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Repository for PgConnection {
|
impl Repository for PgConnection {
|
||||||
type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c;
|
type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c;
|
||||||
type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c;
|
type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c;
|
||||||
|
type UpstreamOAuthSessionRepository<'c> = PgUpstreamOAuthSessionRepository<'c> where Self: 'c;
|
||||||
|
|
||||||
fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> {
|
fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> {
|
||||||
PgUpstreamOAuthLinkRepository::new(self)
|
PgUpstreamOAuthLinkRepository::new(self)
|
||||||
@ -40,11 +49,16 @@ impl Repository for PgConnection {
|
|||||||
fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> {
|
fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> {
|
||||||
PgUpstreamOAuthProviderRepository::new(self)
|
PgUpstreamOAuthProviderRepository::new(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_> {
|
||||||
|
PgUpstreamOAuthSessionRepository::new(self)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'t> Repository for Transaction<'t, Postgres> {
|
impl<'t> Repository for Transaction<'t, Postgres> {
|
||||||
type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c;
|
type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c;
|
||||||
type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c;
|
type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c;
|
||||||
|
type UpstreamOAuthSessionRepository<'c> = PgUpstreamOAuthSessionRepository<'c> where Self: 'c;
|
||||||
|
|
||||||
fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> {
|
fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> {
|
||||||
PgUpstreamOAuthLinkRepository::new(self)
|
PgUpstreamOAuthLinkRepository::new(self)
|
||||||
@ -53,4 +67,8 @@ impl<'t> Repository for Transaction<'t, Postgres> {
|
|||||||
fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> {
|
fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> {
|
||||||
PgUpstreamOAuthProviderRepository::new(self)
|
PgUpstreamOAuthProviderRepository::new(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_> {
|
||||||
|
PgUpstreamOAuthSessionRepository::new(self)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -19,7 +19,111 @@ mod session;
|
|||||||
pub use self::{
|
pub use self::{
|
||||||
link::{PgUpstreamOAuthLinkRepository, UpstreamOAuthLinkRepository},
|
link::{PgUpstreamOAuthLinkRepository, UpstreamOAuthLinkRepository},
|
||||||
provider::{PgUpstreamOAuthProviderRepository, UpstreamOAuthProviderRepository},
|
provider::{PgUpstreamOAuthProviderRepository, UpstreamOAuthProviderRepository},
|
||||||
session::{
|
session::{PgUpstreamOAuthSessionRepository, UpstreamOAuthSessionRepository},
|
||||||
add_session, complete_session, consume_session, lookup_session, lookup_session_on_link,
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use oauth2_types::scope::{Scope, OPENID};
|
||||||
|
use rand::SeedableRng;
|
||||||
|
use sqlx::PgPool;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
use crate::{Clock, Repository};
|
||||||
|
|
||||||
|
#[sqlx::test(migrator = "crate::MIGRATOR")]
|
||||||
|
async fn test_repository(pool: PgPool) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
|
||||||
|
let clock = Clock::default();
|
||||||
|
let mut conn = pool.acquire().await?;
|
||||||
|
|
||||||
|
// The provider list should be empty at the start
|
||||||
|
let all_providers = conn.upstream_oauth_provider().all().await?;
|
||||||
|
assert!(all_providers.is_empty());
|
||||||
|
|
||||||
|
// Let's add a provider
|
||||||
|
let provider = conn
|
||||||
|
.upstream_oauth_provider()
|
||||||
|
.add(
|
||||||
|
&mut rng,
|
||||||
|
&clock,
|
||||||
|
"https://example.com/".to_owned(),
|
||||||
|
Scope::from_iter([OPENID]),
|
||||||
|
mas_iana::oauth::OAuthClientAuthenticationMethod::None,
|
||||||
|
None,
|
||||||
|
"client-id".to_owned(),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// Look it up in the database
|
||||||
|
let provider = conn
|
||||||
|
.upstream_oauth_provider()
|
||||||
|
.lookup(provider.id)
|
||||||
|
.await?
|
||||||
|
.expect("provider to be found in the database");
|
||||||
|
assert_eq!(provider.issuer, "https://example.com/");
|
||||||
|
assert_eq!(provider.client_id, "client-id");
|
||||||
|
|
||||||
|
// Start a session
|
||||||
|
let session = conn
|
||||||
|
.upstream_oauth_session()
|
||||||
|
.add(
|
||||||
|
&mut rng,
|
||||||
|
&clock,
|
||||||
|
&provider,
|
||||||
|
"some-state".to_owned(),
|
||||||
|
None,
|
||||||
|
"some-nonce".to_owned(),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// Look it up in the database
|
||||||
|
let session = conn
|
||||||
|
.upstream_oauth_session()
|
||||||
|
.lookup(session.id)
|
||||||
|
.await?
|
||||||
|
.expect("session to be found in the database");
|
||||||
|
assert_eq!(session.provider_id, provider.id);
|
||||||
|
assert_eq!(session.link_id, None);
|
||||||
|
assert!(!session.completed());
|
||||||
|
assert!(!session.consumed());
|
||||||
|
|
||||||
|
// Create a link
|
||||||
|
let link = conn
|
||||||
|
.upstream_oauth_link()
|
||||||
|
.add(&mut rng, &clock, &provider, "a-subject".to_owned())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// We can look it up by its ID
|
||||||
|
conn.upstream_oauth_link()
|
||||||
|
.lookup(link.id)
|
||||||
|
.await?
|
||||||
|
.expect("link to be found in database");
|
||||||
|
|
||||||
|
// or by its subject
|
||||||
|
let link = conn
|
||||||
|
.upstream_oauth_link()
|
||||||
|
.find_by_subject(&provider, "a-subject")
|
||||||
|
.await?
|
||||||
|
.expect("link to be found in database");
|
||||||
|
assert_eq!(link.subject, "a-subject");
|
||||||
|
assert_eq!(link.provider_id, provider.id);
|
||||||
|
|
||||||
|
let session = conn
|
||||||
|
.upstream_oauth_session()
|
||||||
|
.complete_with_link(&clock, session, &link, None)
|
||||||
|
.await?;
|
||||||
|
assert!(session.completed());
|
||||||
|
assert!(!session.consumed());
|
||||||
|
assert_eq!(session.link_id, Some(link.id));
|
||||||
|
|
||||||
|
let session = conn
|
||||||
|
.upstream_oauth_session()
|
||||||
|
.consume(&clock, session)
|
||||||
|
.await?;
|
||||||
|
assert!(session.consumed());
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -12,261 +12,62 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider};
|
use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider};
|
||||||
use rand::Rng;
|
use rand::RngCore;
|
||||||
use sqlx::PgExecutor;
|
use sqlx::PgConnection;
|
||||||
use ulid::Ulid;
|
use ulid::Ulid;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::{Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt};
|
use crate::{Clock, DatabaseError, LookupResultExt};
|
||||||
|
|
||||||
struct SessionAndProviderLookup {
|
#[async_trait]
|
||||||
upstream_oauth_authorization_session_id: Uuid,
|
pub trait UpstreamOAuthSessionRepository: Send + Sync {
|
||||||
upstream_oauth_provider_id: Uuid,
|
type Error;
|
||||||
upstream_oauth_link_id: Option<Uuid>,
|
|
||||||
state: String,
|
/// Lookup a session by its ID
|
||||||
code_challenge_verifier: Option<String>,
|
async fn lookup(
|
||||||
nonce: String,
|
&mut self,
|
||||||
id_token: Option<String>,
|
id: Ulid,
|
||||||
created_at: DateTime<Utc>,
|
) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error>;
|
||||||
completed_at: Option<DateTime<Utc>>,
|
|
||||||
consumed_at: Option<DateTime<Utc>>,
|
/// Add a session to the database
|
||||||
provider_issuer: String,
|
async fn add(
|
||||||
provider_scope: String,
|
&mut self,
|
||||||
provider_client_id: String,
|
rng: &mut (dyn RngCore + Send),
|
||||||
provider_encrypted_client_secret: Option<String>,
|
clock: &Clock,
|
||||||
provider_token_endpoint_auth_method: String,
|
upstream_oauth_provider: &UpstreamOAuthProvider,
|
||||||
provider_token_endpoint_signing_alg: Option<String>,
|
state: String,
|
||||||
provider_created_at: DateTime<Utc>,
|
code_challenge_verifier: Option<String>,
|
||||||
|
nonce: String,
|
||||||
|
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
|
||||||
|
|
||||||
|
/// Mark a session as completed and associate the given link
|
||||||
|
async fn complete_with_link(
|
||||||
|
&mut self,
|
||||||
|
clock: &Clock,
|
||||||
|
upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
|
||||||
|
upstream_oauth_link: &UpstreamOAuthLink,
|
||||||
|
id_token: Option<String>,
|
||||||
|
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
|
||||||
|
|
||||||
|
/// Mark a session as consumed
|
||||||
|
async fn consume(
|
||||||
|
&mut self,
|
||||||
|
clock: &Clock,
|
||||||
|
upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
|
||||||
|
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Lookup a session and its provider by its ID
|
pub struct PgUpstreamOAuthSessionRepository<'c> {
|
||||||
#[tracing::instrument(
|
conn: &'c mut PgConnection,
|
||||||
skip_all,
|
|
||||||
fields(upstream_oauth_authorization_session.id = %id),
|
|
||||||
err,
|
|
||||||
)]
|
|
||||||
pub async fn lookup_session(
|
|
||||||
executor: impl PgExecutor<'_>,
|
|
||||||
id: Ulid,
|
|
||||||
) -> Result<Option<(UpstreamOAuthProvider, UpstreamOAuthAuthorizationSession)>, DatabaseError> {
|
|
||||||
let res = sqlx::query_as!(
|
|
||||||
SessionAndProviderLookup,
|
|
||||||
r#"
|
|
||||||
SELECT
|
|
||||||
ua.upstream_oauth_authorization_session_id,
|
|
||||||
ua.upstream_oauth_provider_id,
|
|
||||||
ua.upstream_oauth_link_id,
|
|
||||||
ua.state,
|
|
||||||
ua.code_challenge_verifier,
|
|
||||||
ua.nonce,
|
|
||||||
ua.id_token,
|
|
||||||
ua.created_at,
|
|
||||||
ua.completed_at,
|
|
||||||
ua.consumed_at,
|
|
||||||
up.issuer AS "provider_issuer",
|
|
||||||
up.scope AS "provider_scope",
|
|
||||||
up.client_id AS "provider_client_id",
|
|
||||||
up.encrypted_client_secret AS "provider_encrypted_client_secret",
|
|
||||||
up.token_endpoint_auth_method AS "provider_token_endpoint_auth_method",
|
|
||||||
up.token_endpoint_signing_alg AS "provider_token_endpoint_signing_alg",
|
|
||||||
up.created_at AS "provider_created_at"
|
|
||||||
FROM upstream_oauth_authorization_sessions ua
|
|
||||||
INNER JOIN upstream_oauth_providers up
|
|
||||||
USING (upstream_oauth_provider_id)
|
|
||||||
WHERE upstream_oauth_authorization_session_id = $1
|
|
||||||
"#,
|
|
||||||
Uuid::from(id),
|
|
||||||
)
|
|
||||||
.fetch_one(executor)
|
|
||||||
.await
|
|
||||||
.to_option()?;
|
|
||||||
|
|
||||||
let Some(res) = res else { return Ok(None) };
|
|
||||||
|
|
||||||
let id = res.upstream_oauth_provider_id.into();
|
|
||||||
let provider = UpstreamOAuthProvider {
|
|
||||||
id,
|
|
||||||
issuer: res.provider_issuer,
|
|
||||||
scope: res.provider_scope.parse().map_err(|e| {
|
|
||||||
DatabaseInconsistencyError::on("upstream_oauth_providers")
|
|
||||||
.column("scope")
|
|
||||||
.row(id)
|
|
||||||
.source(e)
|
|
||||||
})?,
|
|
||||||
client_id: res.provider_client_id,
|
|
||||||
encrypted_client_secret: res.provider_encrypted_client_secret,
|
|
||||||
token_endpoint_auth_method: res.provider_token_endpoint_auth_method.parse().map_err(
|
|
||||||
|e| {
|
|
||||||
DatabaseInconsistencyError::on("upstream_oauth_providers")
|
|
||||||
.column("token_endpoint_auth_method")
|
|
||||||
.row(id)
|
|
||||||
.source(e)
|
|
||||||
},
|
|
||||||
)?,
|
|
||||||
token_endpoint_signing_alg: res
|
|
||||||
.provider_token_endpoint_signing_alg
|
|
||||||
.map(|x| x.parse())
|
|
||||||
.transpose()
|
|
||||||
.map_err(|e| {
|
|
||||||
DatabaseInconsistencyError::on("upstream_oauth_providers")
|
|
||||||
.column("token_endpoint_signing_alg")
|
|
||||||
.row(id)
|
|
||||||
.source(e)
|
|
||||||
})?,
|
|
||||||
created_at: res.provider_created_at,
|
|
||||||
};
|
|
||||||
|
|
||||||
let session = UpstreamOAuthAuthorizationSession {
|
|
||||||
id: res.upstream_oauth_authorization_session_id.into(),
|
|
||||||
provider_id: provider.id,
|
|
||||||
link_id: res.upstream_oauth_link_id.map(Ulid::from),
|
|
||||||
state: res.state,
|
|
||||||
code_challenge_verifier: res.code_challenge_verifier,
|
|
||||||
nonce: res.nonce,
|
|
||||||
id_token: res.id_token,
|
|
||||||
created_at: res.created_at,
|
|
||||||
completed_at: res.completed_at,
|
|
||||||
consumed_at: res.consumed_at,
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(Some((provider, session)))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add a session to the database
|
impl<'c> PgUpstreamOAuthSessionRepository<'c> {
|
||||||
#[tracing::instrument(
|
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||||
skip_all,
|
Self { conn }
|
||||||
fields(
|
}
|
||||||
%upstream_oauth_provider.id,
|
|
||||||
%upstream_oauth_provider.issuer,
|
|
||||||
%upstream_oauth_provider.client_id,
|
|
||||||
upstream_oauth_authorization_session.id,
|
|
||||||
),
|
|
||||||
err,
|
|
||||||
)]
|
|
||||||
pub async fn add_session(
|
|
||||||
executor: impl PgExecutor<'_>,
|
|
||||||
mut rng: impl Rng + Send,
|
|
||||||
clock: &Clock,
|
|
||||||
upstream_oauth_provider: &UpstreamOAuthProvider,
|
|
||||||
state: String,
|
|
||||||
code_challenge_verifier: Option<String>,
|
|
||||||
nonce: String,
|
|
||||||
) -> Result<UpstreamOAuthAuthorizationSession, sqlx::Error> {
|
|
||||||
let created_at = clock.now();
|
|
||||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
|
||||||
tracing::Span::current().record(
|
|
||||||
"upstream_oauth_authorization_session.id",
|
|
||||||
tracing::field::display(id),
|
|
||||||
);
|
|
||||||
|
|
||||||
sqlx::query!(
|
|
||||||
r#"
|
|
||||||
INSERT INTO upstream_oauth_authorization_sessions (
|
|
||||||
upstream_oauth_authorization_session_id,
|
|
||||||
upstream_oauth_provider_id,
|
|
||||||
state,
|
|
||||||
code_challenge_verifier,
|
|
||||||
nonce,
|
|
||||||
created_at,
|
|
||||||
completed_at,
|
|
||||||
consumed_at,
|
|
||||||
id_token
|
|
||||||
) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL)
|
|
||||||
"#,
|
|
||||||
Uuid::from(id),
|
|
||||||
Uuid::from(upstream_oauth_provider.id),
|
|
||||||
&state,
|
|
||||||
code_challenge_verifier.as_deref(),
|
|
||||||
nonce,
|
|
||||||
created_at,
|
|
||||||
)
|
|
||||||
.execute(executor)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(UpstreamOAuthAuthorizationSession {
|
|
||||||
id,
|
|
||||||
provider_id: upstream_oauth_provider.id,
|
|
||||||
link_id: None,
|
|
||||||
state,
|
|
||||||
code_challenge_verifier,
|
|
||||||
nonce,
|
|
||||||
id_token: None,
|
|
||||||
created_at,
|
|
||||||
completed_at: None,
|
|
||||||
consumed_at: None,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Mark a session as completed and associate the given link
|
|
||||||
#[tracing::instrument(
|
|
||||||
skip_all,
|
|
||||||
fields(
|
|
||||||
%upstream_oauth_authorization_session.id,
|
|
||||||
%upstream_oauth_link.id,
|
|
||||||
),
|
|
||||||
err,
|
|
||||||
)]
|
|
||||||
pub async fn complete_session(
|
|
||||||
executor: impl PgExecutor<'_>,
|
|
||||||
clock: &Clock,
|
|
||||||
mut upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
|
|
||||||
upstream_oauth_link: &UpstreamOAuthLink,
|
|
||||||
id_token: Option<String>,
|
|
||||||
) -> Result<UpstreamOAuthAuthorizationSession, sqlx::Error> {
|
|
||||||
let completed_at = clock.now();
|
|
||||||
sqlx::query!(
|
|
||||||
r#"
|
|
||||||
UPDATE upstream_oauth_authorization_sessions
|
|
||||||
SET upstream_oauth_link_id = $1,
|
|
||||||
completed_at = $2,
|
|
||||||
id_token = $3
|
|
||||||
WHERE upstream_oauth_authorization_session_id = $4
|
|
||||||
"#,
|
|
||||||
Uuid::from(upstream_oauth_link.id),
|
|
||||||
completed_at,
|
|
||||||
id_token,
|
|
||||||
Uuid::from(upstream_oauth_authorization_session.id),
|
|
||||||
)
|
|
||||||
.execute(executor)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
upstream_oauth_authorization_session.completed_at = Some(completed_at);
|
|
||||||
upstream_oauth_authorization_session.id_token = id_token;
|
|
||||||
|
|
||||||
Ok(upstream_oauth_authorization_session)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Mark a session as consumed
|
|
||||||
#[tracing::instrument(
|
|
||||||
skip_all,
|
|
||||||
fields(
|
|
||||||
%upstream_oauth_authorization_session.id,
|
|
||||||
),
|
|
||||||
err,
|
|
||||||
)]
|
|
||||||
pub async fn consume_session(
|
|
||||||
executor: impl PgExecutor<'_>,
|
|
||||||
clock: &Clock,
|
|
||||||
mut upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
|
|
||||||
) -> Result<UpstreamOAuthAuthorizationSession, sqlx::Error> {
|
|
||||||
let consumed_at = clock.now();
|
|
||||||
sqlx::query!(
|
|
||||||
r#"
|
|
||||||
UPDATE upstream_oauth_authorization_sessions
|
|
||||||
SET consumed_at = $1
|
|
||||||
WHERE upstream_oauth_authorization_session_id = $2
|
|
||||||
"#,
|
|
||||||
consumed_at,
|
|
||||||
Uuid::from(upstream_oauth_authorization_session.id),
|
|
||||||
)
|
|
||||||
.execute(executor)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
upstream_oauth_authorization_session.consumed_at = Some(consumed_at);
|
|
||||||
|
|
||||||
Ok(upstream_oauth_authorization_session)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
struct SessionLookup {
|
struct SessionLookup {
|
||||||
@ -282,57 +83,191 @@ struct SessionLookup {
|
|||||||
consumed_at: Option<DateTime<Utc>>,
|
consumed_at: Option<DateTime<Utc>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Lookup a session, which belongs to a link, by its ID
|
#[async_trait]
|
||||||
#[tracing::instrument(
|
impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> {
|
||||||
skip_all,
|
type Error = DatabaseError;
|
||||||
fields(
|
|
||||||
upstream_oauth_authorization_session.id = %id,
|
|
||||||
%upstream_oauth_link.id,
|
|
||||||
),
|
|
||||||
err,
|
|
||||||
)]
|
|
||||||
pub async fn lookup_session_on_link(
|
|
||||||
executor: impl PgExecutor<'_>,
|
|
||||||
upstream_oauth_link: &UpstreamOAuthLink,
|
|
||||||
id: Ulid,
|
|
||||||
) -> Result<Option<UpstreamOAuthAuthorizationSession>, sqlx::Error> {
|
|
||||||
let res = sqlx::query_as!(
|
|
||||||
SessionLookup,
|
|
||||||
r#"
|
|
||||||
SELECT
|
|
||||||
upstream_oauth_authorization_session_id,
|
|
||||||
upstream_oauth_provider_id,
|
|
||||||
upstream_oauth_link_id,
|
|
||||||
state,
|
|
||||||
code_challenge_verifier,
|
|
||||||
nonce,
|
|
||||||
id_token,
|
|
||||||
created_at,
|
|
||||||
completed_at,
|
|
||||||
consumed_at
|
|
||||||
FROM upstream_oauth_authorization_sessions
|
|
||||||
WHERE upstream_oauth_authorization_session_id = $1
|
|
||||||
AND upstream_oauth_link_id = $2
|
|
||||||
"#,
|
|
||||||
Uuid::from(id),
|
|
||||||
Uuid::from(upstream_oauth_link.id),
|
|
||||||
)
|
|
||||||
.fetch_one(executor)
|
|
||||||
.await
|
|
||||||
.to_option()?;
|
|
||||||
|
|
||||||
let Some(res) = res else { return Ok(None) };
|
#[tracing::instrument(
|
||||||
|
skip_all,
|
||||||
|
fields(upstream_oauth_provider.id = %id),
|
||||||
|
err,
|
||||||
|
)]
|
||||||
|
async fn lookup(
|
||||||
|
&mut self,
|
||||||
|
id: Ulid,
|
||||||
|
) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error> {
|
||||||
|
let res = sqlx::query_as!(
|
||||||
|
SessionLookup,
|
||||||
|
r#"
|
||||||
|
SELECT
|
||||||
|
upstream_oauth_authorization_session_id,
|
||||||
|
upstream_oauth_provider_id,
|
||||||
|
upstream_oauth_link_id,
|
||||||
|
state,
|
||||||
|
code_challenge_verifier,
|
||||||
|
nonce,
|
||||||
|
id_token,
|
||||||
|
created_at,
|
||||||
|
completed_at,
|
||||||
|
consumed_at
|
||||||
|
FROM upstream_oauth_authorization_sessions
|
||||||
|
WHERE upstream_oauth_authorization_session_id = $1
|
||||||
|
"#,
|
||||||
|
Uuid::from(id),
|
||||||
|
)
|
||||||
|
.fetch_one(&mut *self.conn)
|
||||||
|
.await
|
||||||
|
.to_option()?;
|
||||||
|
|
||||||
Ok(Some(UpstreamOAuthAuthorizationSession {
|
let Some(res) = res else { return Ok(None) };
|
||||||
id: res.upstream_oauth_authorization_session_id.into(),
|
|
||||||
provider_id: res.upstream_oauth_provider_id.into(),
|
let session = UpstreamOAuthAuthorizationSession {
|
||||||
link_id: res.upstream_oauth_link_id.map(Ulid::from),
|
id: res.upstream_oauth_authorization_session_id.into(),
|
||||||
state: res.state,
|
provider_id: res.upstream_oauth_provider_id.into(),
|
||||||
code_challenge_verifier: res.code_challenge_verifier,
|
link_id: res.upstream_oauth_link_id.map(Ulid::from),
|
||||||
nonce: res.nonce,
|
state: res.state,
|
||||||
id_token: res.id_token,
|
code_challenge_verifier: res.code_challenge_verifier,
|
||||||
created_at: res.created_at,
|
nonce: res.nonce,
|
||||||
completed_at: res.completed_at,
|
id_token: res.id_token,
|
||||||
consumed_at: res.consumed_at,
|
created_at: res.created_at,
|
||||||
}))
|
completed_at: res.completed_at,
|
||||||
|
consumed_at: res.consumed_at,
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Some(session))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(
|
||||||
|
skip_all,
|
||||||
|
fields(
|
||||||
|
%upstream_oauth_provider.id,
|
||||||
|
%upstream_oauth_provider.issuer,
|
||||||
|
%upstream_oauth_provider.client_id,
|
||||||
|
upstream_oauth_authorization_session.id,
|
||||||
|
),
|
||||||
|
err,
|
||||||
|
)]
|
||||||
|
async fn add(
|
||||||
|
&mut self,
|
||||||
|
rng: &mut (dyn RngCore + Send),
|
||||||
|
clock: &Clock,
|
||||||
|
upstream_oauth_provider: &UpstreamOAuthProvider,
|
||||||
|
state: String,
|
||||||
|
code_challenge_verifier: Option<String>,
|
||||||
|
nonce: String,
|
||||||
|
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
|
||||||
|
let created_at = clock.now();
|
||||||
|
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||||
|
tracing::Span::current().record(
|
||||||
|
"upstream_oauth_authorization_session.id",
|
||||||
|
tracing::field::display(id),
|
||||||
|
);
|
||||||
|
|
||||||
|
sqlx::query!(
|
||||||
|
r#"
|
||||||
|
INSERT INTO upstream_oauth_authorization_sessions (
|
||||||
|
upstream_oauth_authorization_session_id,
|
||||||
|
upstream_oauth_provider_id,
|
||||||
|
state,
|
||||||
|
code_challenge_verifier,
|
||||||
|
nonce,
|
||||||
|
created_at,
|
||||||
|
completed_at,
|
||||||
|
consumed_at,
|
||||||
|
id_token
|
||||||
|
) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL)
|
||||||
|
"#,
|
||||||
|
Uuid::from(id),
|
||||||
|
Uuid::from(upstream_oauth_provider.id),
|
||||||
|
&state,
|
||||||
|
code_challenge_verifier.as_deref(),
|
||||||
|
nonce,
|
||||||
|
created_at,
|
||||||
|
)
|
||||||
|
.execute(&mut *self.conn)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(UpstreamOAuthAuthorizationSession {
|
||||||
|
id,
|
||||||
|
provider_id: upstream_oauth_provider.id,
|
||||||
|
link_id: None,
|
||||||
|
state,
|
||||||
|
code_challenge_verifier,
|
||||||
|
nonce,
|
||||||
|
id_token: None,
|
||||||
|
created_at,
|
||||||
|
completed_at: None,
|
||||||
|
consumed_at: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(
|
||||||
|
skip_all,
|
||||||
|
fields(
|
||||||
|
%upstream_oauth_authorization_session.id,
|
||||||
|
%upstream_oauth_link.id,
|
||||||
|
),
|
||||||
|
err,
|
||||||
|
)]
|
||||||
|
async fn complete_with_link(
|
||||||
|
&mut self,
|
||||||
|
clock: &Clock,
|
||||||
|
mut upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
|
||||||
|
upstream_oauth_link: &UpstreamOAuthLink,
|
||||||
|
id_token: Option<String>,
|
||||||
|
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
|
||||||
|
let completed_at = clock.now();
|
||||||
|
sqlx::query!(
|
||||||
|
r#"
|
||||||
|
UPDATE upstream_oauth_authorization_sessions
|
||||||
|
SET upstream_oauth_link_id = $1,
|
||||||
|
completed_at = $2,
|
||||||
|
id_token = $3
|
||||||
|
WHERE upstream_oauth_authorization_session_id = $4
|
||||||
|
"#,
|
||||||
|
Uuid::from(upstream_oauth_link.id),
|
||||||
|
completed_at,
|
||||||
|
id_token,
|
||||||
|
Uuid::from(upstream_oauth_authorization_session.id),
|
||||||
|
)
|
||||||
|
.execute(&mut *self.conn)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
upstream_oauth_authorization_session.completed_at = Some(completed_at);
|
||||||
|
upstream_oauth_authorization_session.id_token = id_token;
|
||||||
|
upstream_oauth_authorization_session.link_id = Some(upstream_oauth_link.id);
|
||||||
|
|
||||||
|
Ok(upstream_oauth_authorization_session)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mark a session as consumed
|
||||||
|
#[tracing::instrument(
|
||||||
|
skip_all,
|
||||||
|
fields(
|
||||||
|
%upstream_oauth_authorization_session.id,
|
||||||
|
),
|
||||||
|
err,
|
||||||
|
)]
|
||||||
|
async fn consume(
|
||||||
|
&mut self,
|
||||||
|
clock: &Clock,
|
||||||
|
mut upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
|
||||||
|
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
|
||||||
|
let consumed_at = clock.now();
|
||||||
|
sqlx::query!(
|
||||||
|
r#"
|
||||||
|
UPDATE upstream_oauth_authorization_sessions
|
||||||
|
SET consumed_at = $1
|
||||||
|
WHERE upstream_oauth_authorization_session_id = $2
|
||||||
|
"#,
|
||||||
|
consumed_at,
|
||||||
|
Uuid::from(upstream_oauth_authorization_session.id),
|
||||||
|
)
|
||||||
|
.execute(&mut *self.conn)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
upstream_oauth_authorization_session.consumed_at = Some(consumed_at);
|
||||||
|
|
||||||
|
Ok(upstream_oauth_authorization_session)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user