1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

storage: upstream oauth session repository + unit tests

This commit is contained in:
Quentin Gliech
2022-12-30 15:39:51 +01:00
parent 0faf08fce2
commit 870a37151f
9 changed files with 469 additions and 490 deletions

1
Cargo.lock generated
View File

@ -3104,6 +3104,7 @@ dependencies = [
"mas-jose", "mas-jose",
"oauth2-types", "oauth2-types",
"rand 0.8.5", "rand 0.8.5",
"rand_chacha 0.3.1",
"serde", "serde",
"serde_json", "serde_json",
"sqlx", "sqlx",

View File

@ -22,7 +22,10 @@ use mas_axum_utils::http_client_factory::HttpClientFactory;
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_oidc_client::requests::authorization_code::AuthorizationRequestData; use mas_oidc_client::requests::authorization_code::AuthorizationRequestData;
use mas_router::UrlBuilder; use mas_router::UrlBuilder;
use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, Repository}; use mas_storage::{
upstream_oauth2::{UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository},
Repository,
};
use sqlx::PgPool; use sqlx::PgPool;
use thiserror::Error; use thiserror::Error;
use ulid::Ulid; use ulid::Ulid;
@ -97,16 +100,17 @@ pub(crate) async fn get(
&mut rng, &mut rng,
)?; )?;
let session = mas_storage::upstream_oauth2::add_session( let session = txn
&mut txn, .upstream_oauth_session()
&mut rng, .add(
&clock, &mut rng,
&provider, &clock,
data.state.clone(), &provider,
data.code_challenge_verifier, data.state.clone(),
data.nonce, data.code_challenge_verifier,
) data.nonce,
.await?; )
.await?;
let cookie_jar = UpstreamSessionsCookie::load(&cookie_jar) let cookie_jar = UpstreamSessionsCookie::load(&cookie_jar)
.add(session.id, provider.id, data.state, query.post_auth_action) .add(session.id, provider.id, data.state, query.post_auth_action)

View File

@ -26,7 +26,7 @@ use mas_oidc_client::requests::{
}; };
use mas_router::{Route, UrlBuilder}; use mas_router::{Route, UrlBuilder};
use mas_storage::{ use mas_storage::{
upstream_oauth2::{complete_session, lookup_session}, upstream_oauth2::{UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository},
Repository, UpstreamOAuthLinkRepository, Repository, UpstreamOAuthLinkRepository,
}; };
use oauth2_types::errors::ClientErrorCode; use oauth2_types::errors::ClientErrorCode;
@ -65,6 +65,9 @@ pub(crate) enum RouteError {
#[error("Session not found")] #[error("Session not found")]
SessionNotFound, SessionNotFound,
#[error("Provider not found")]
ProviderNotFound,
#[error("Provider mismatch")] #[error("Provider mismatch")]
ProviderMismatch, ProviderMismatch,
@ -105,6 +108,7 @@ impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
impl IntoResponse for RouteError { impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response { fn into_response(self) -> axum::response::Response {
match self { match self {
Self::ProviderNotFound => (StatusCode::NOT_FOUND, "Provider not found").into_response(),
Self::SessionNotFound => (StatusCode::NOT_FOUND, "Session not found").into_response(), Self::SessionNotFound => (StatusCode::NOT_FOUND, "Session not found").into_response(),
Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(), Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
e => (StatusCode::BAD_REQUEST, e.to_string()).into_response(), e => (StatusCode::BAD_REQUEST, e.to_string()).into_response(),
@ -127,16 +131,24 @@ pub(crate) async fn get(
let mut txn = pool.begin().await?; let mut txn = pool.begin().await?;
let provider = txn
.upstream_oauth_provider()
.lookup(provider_id)
.await?
.ok_or(RouteError::ProviderNotFound)?;
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar); let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
let (session_id, _post_auth_action) = sessions_cookie let (session_id, _post_auth_action) = sessions_cookie
.find_session(provider_id, &params.state) .find_session(provider_id, &params.state)
.map_err(|_| RouteError::MissingCookie)?; .map_err(|_| RouteError::MissingCookie)?;
let (provider, session) = lookup_session(&mut txn, session_id) let session = txn
.upstream_oauth_session()
.lookup(session_id)
.await? .await?
.ok_or(RouteError::SessionNotFound)?; .ok_or(RouteError::SessionNotFound)?;
if provider.id != provider_id { if provider.id != session.provider_id {
// The provider in the session cookie should match the one from the URL // The provider in the session cookie should match the one from the URL
return Err(RouteError::ProviderMismatch); return Err(RouteError::ProviderMismatch);
} }
@ -245,7 +257,11 @@ pub(crate) async fn get(
.await? .await?
}; };
let session = complete_session(&mut txn, &clock, session, &link, response.id_token).await?; let session = txn
.upstream_oauth_session()
.complete_with_link(&clock, session, &link, response.id_token)
.await?;
let cookie_jar = sessions_cookie let cookie_jar = sessions_cookie
.add_link_to_session(session.id, link.id)? .add_link_to_session(session.id, link.id)?
.save(cookie_jar, clock.now()); .save(cookie_jar, clock.now());

View File

@ -25,7 +25,7 @@ use mas_axum_utils::{
}; };
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_storage::{ use mas_storage::{
upstream_oauth2::{consume_session, lookup_session_on_link}, upstream_oauth2::UpstreamOAuthSessionRepository,
user::{add_user, authenticate_session_with_upstream, lookup_user, start_session}, user::{add_user, authenticate_session_with_upstream, lookup_user, start_session},
Repository, UpstreamOAuthLinkRepository, Repository, UpstreamOAuthLinkRepository,
}; };
@ -109,12 +109,18 @@ pub(crate) async fn get(
.await? .await?
.ok_or(RouteError::LinkNotFound)?; .ok_or(RouteError::LinkNotFound)?;
// This checks that we're in a browser session which is allowed to consume this let upstream_session = txn
// link: the upstream auth session should have been started in this browser. .upstream_oauth_session()
let upstream_session = lookup_session_on_link(&mut txn, &link, session_id) .lookup(session_id)
.await? .await?
.ok_or(RouteError::SessionNotFound)?; .ok_or(RouteError::SessionNotFound)?;
// This checks that we're in a browser session which is allowed to consume this
// link: the upstream auth session should have been started in this browser.
if upstream_session.link_id != Some(link.id) {
return Err(RouteError::SessionNotFound);
}
if upstream_session.consumed() { if upstream_session.consumed() {
return Err(RouteError::SessionConsumed); return Err(RouteError::SessionConsumed);
} }
@ -127,7 +133,10 @@ pub(crate) async fn get(
(Some(mut session), Some(user_id)) if session.user.id == user_id => { (Some(mut session), Some(user_id)) if session.user.id == user_id => {
// Session already linked, and link matches the currently logged // Session already linked, and link matches the currently logged
// user. Mark the session as consumed and renew the authentication. // user. Mark the session as consumed and renew the authentication.
consume_session(&mut txn, &clock, upstream_session).await?; txn.upstream_oauth_session()
.consume(&clock, upstream_session)
.await?;
authenticate_session_with_upstream(&mut txn, &mut rng, &clock, &mut session, &link) authenticate_session_with_upstream(&mut txn, &mut rng, &clock, &mut session, &link)
.await?; .await?;
@ -212,12 +221,18 @@ pub(crate) async fn post(
.await? .await?
.ok_or(RouteError::LinkNotFound)?; .ok_or(RouteError::LinkNotFound)?;
// This checks that we're in a browser session which is allowed to consume this let upstream_session = txn
// link: the upstream auth session should have been started in this browser. .upstream_oauth_session()
let upstream_session = lookup_session_on_link(&mut txn, &link, session_id) .lookup(session_id)
.await? .await?
.ok_or(RouteError::SessionNotFound)?; .ok_or(RouteError::SessionNotFound)?;
// This checks that we're in a browser session which is allowed to consume this
// link: the upstream auth session should have been started in this browser.
if upstream_session.link_id != Some(link.id) {
return Err(RouteError::SessionNotFound);
}
if upstream_session.consumed() { if upstream_session.consumed() {
return Err(RouteError::SessionConsumed); return Err(RouteError::SessionConsumed);
} }
@ -251,7 +266,10 @@ pub(crate) async fn post(
_ => return Err(RouteError::InvalidFormAction), _ => return Err(RouteError::InvalidFormAction),
}; };
consume_session(&mut txn, &clock, upstream_session).await?; txn.upstream_oauth_session()
.consume(&clock, upstream_session)
.await?;
authenticate_session_with_upstream(&mut txn, &mut rng, &clock, &mut session, &link).await?; authenticate_session_with_upstream(&mut txn, &mut rng, &clock, &mut session, &link).await?;
let cookie_jar = sessions_cookie let cookie_jar = sessions_cookie

View File

@ -14,8 +14,8 @@ serde_json = "1.0.91"
thiserror = "1.0.38" thiserror = "1.0.38"
tracing = "0.1.37" tracing = "0.1.37"
# Password hashing
rand = "0.8.5" rand = "0.8.5"
rand_chacha = "0.3.1"
url = { version = "2.3.1", features = ["serde"] } url = { version = "2.3.1", features = ["serde"] }
uuid = "1.2.2" uuid = "1.2.2"
ulid = { version = "1.0.0", features = ["uuid", "serde"] } ulid = { version = "1.0.0", features = ["uuid", "serde"] }

View File

@ -521,81 +521,6 @@
}, },
"query": "\n INSERT INTO users (user_id, username, created_at)\n VALUES ($1, $2, $3)\n " "query": "\n INSERT INTO users (user_id, username, created_at)\n VALUES ($1, $2, $3)\n "
}, },
"2ca7b990c11e84db62fb7887a2bc3410ec1eee2f6a0ec124db36575111970ca9": {
"describe": {
"columns": [
{
"name": "upstream_oauth_authorization_session_id",
"ordinal": 0,
"type_info": "Uuid"
},
{
"name": "upstream_oauth_provider_id",
"ordinal": 1,
"type_info": "Uuid"
},
{
"name": "upstream_oauth_link_id",
"ordinal": 2,
"type_info": "Uuid"
},
{
"name": "state",
"ordinal": 3,
"type_info": "Text"
},
{
"name": "code_challenge_verifier",
"ordinal": 4,
"type_info": "Text"
},
{
"name": "nonce",
"ordinal": 5,
"type_info": "Text"
},
{
"name": "id_token",
"ordinal": 6,
"type_info": "Text"
},
{
"name": "created_at",
"ordinal": 7,
"type_info": "Timestamptz"
},
{
"name": "completed_at",
"ordinal": 8,
"type_info": "Timestamptz"
},
{
"name": "consumed_at",
"ordinal": 9,
"type_info": "Timestamptz"
}
],
"nullable": [
false,
false,
true,
false,
true,
false,
true,
false,
true,
true
],
"parameters": {
"Left": [
"Uuid",
"Uuid"
]
}
},
"query": "\n SELECT\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n upstream_oauth_link_id,\n state,\n code_challenge_verifier,\n nonce,\n id_token,\n created_at,\n completed_at,\n consumed_at\n FROM upstream_oauth_authorization_sessions\n WHERE upstream_oauth_authorization_session_id = $1\n AND upstream_oauth_link_id = $2\n "
},
"2e581d57db471b96091860cd0252361d16332deeffabab0dace405ead55324be": { "2e581d57db471b96091860cd0252361d16332deeffabab0dace405ead55324be": {
"describe": { "describe": {
"columns": [ "columns": [
@ -708,21 +633,6 @@
}, },
"query": "\n UPDATE compat_sso_logins\n SET\n exchanged_at = $2\n WHERE\n compat_sso_login_id = $1\n " "query": "\n UPDATE compat_sso_logins\n SET\n exchanged_at = $2\n WHERE\n compat_sso_login_id = $1\n "
}, },
"2fb8f1aef96571a6f3f6260d7836de99ff24ba1947747e08b0e8d64097507442": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid",
"Timestamptz",
"Text",
"Uuid"
]
}
},
"query": "\n UPDATE upstream_oauth_authorization_sessions\n SET upstream_oauth_link_id = $1,\n completed_at = $2,\n id_token = $3\n WHERE upstream_oauth_authorization_session_id = $4\n "
},
"360466ff599c67c9af2ac75399c0b536a22c1178972a0172b707bcc81d47357b": { "360466ff599c67c9af2ac75399c0b536a22c1178972a0172b707bcc81d47357b": {
"describe": { "describe": {
"columns": [], "columns": [],
@ -1388,7 +1298,24 @@
}, },
"query": "\n UPDATE user_sessions\n SET finished_at = $1\n WHERE user_session_id = $2\n " "query": "\n UPDATE user_sessions\n SET finished_at = $1\n WHERE user_session_id = $2\n "
}, },
"65c7600f1af07cb6ea49d89ae6fbca5374a57c5a866c8aadd7b75ed1d2d1d0cd": { "64e6ea47c2e877c1ebe4338d64d9ad8a6c1c777d1daea024b8ca2e7f0dd75b0f": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid",
"Uuid",
"Text",
"Text",
"Text",
"Timestamptz"
]
}
},
"query": "\n INSERT INTO upstream_oauth_authorization_sessions (\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n state,\n code_challenge_verifier,\n nonce,\n created_at,\n completed_at,\n consumed_at,\n id_token\n ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL)\n "
},
"67ab838035946ddc15b43dd2f79d10b233d07e863b3a5c776c5db97cff263c8c": {
"describe": { "describe": {
"columns": [ "columns": [
{ {
@ -1440,41 +1367,6 @@
"name": "consumed_at", "name": "consumed_at",
"ordinal": 9, "ordinal": 9,
"type_info": "Timestamptz" "type_info": "Timestamptz"
},
{
"name": "provider_issuer",
"ordinal": 10,
"type_info": "Text"
},
{
"name": "provider_scope",
"ordinal": 11,
"type_info": "Text"
},
{
"name": "provider_client_id",
"ordinal": 12,
"type_info": "Text"
},
{
"name": "provider_encrypted_client_secret",
"ordinal": 13,
"type_info": "Text"
},
{
"name": "provider_token_endpoint_auth_method",
"ordinal": 14,
"type_info": "Text"
},
{
"name": "provider_token_endpoint_signing_alg",
"ordinal": 15,
"type_info": "Text"
},
{
"name": "provider_created_at",
"ordinal": 16,
"type_info": "Timestamptz"
} }
], ],
"nullable": [ "nullable": [
@ -1487,14 +1379,7 @@
true, true,
false, false,
true, true,
true, true
false,
false,
false,
true,
false,
true,
false
], ],
"parameters": { "parameters": {
"Left": [ "Left": [
@ -1502,7 +1387,20 @@
] ]
} }
}, },
"query": "\n SELECT\n ua.upstream_oauth_authorization_session_id,\n ua.upstream_oauth_provider_id,\n ua.upstream_oauth_link_id,\n ua.state,\n ua.code_challenge_verifier,\n ua.nonce,\n ua.id_token,\n ua.created_at,\n ua.completed_at,\n ua.consumed_at,\n up.issuer AS \"provider_issuer\",\n up.scope AS \"provider_scope\",\n up.client_id AS \"provider_client_id\",\n up.encrypted_client_secret AS \"provider_encrypted_client_secret\",\n up.token_endpoint_auth_method AS \"provider_token_endpoint_auth_method\",\n up.token_endpoint_signing_alg AS \"provider_token_endpoint_signing_alg\",\n up.created_at AS \"provider_created_at\"\n FROM upstream_oauth_authorization_sessions ua\n INNER JOIN upstream_oauth_providers up\n USING (upstream_oauth_provider_id)\n WHERE upstream_oauth_authorization_session_id = $1\n " "query": "\n SELECT\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n upstream_oauth_link_id,\n state,\n code_challenge_verifier,\n nonce,\n id_token,\n created_at,\n completed_at,\n consumed_at\n FROM upstream_oauth_authorization_sessions\n WHERE upstream_oauth_authorization_session_id = $1\n "
},
"689ffbfc5137ec788e89062ad679bbe6b23a8861c09a7246dc1659c28f12bf8d": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Timestamptz",
"Uuid"
]
}
},
"query": "\n UPDATE upstream_oauth_authorization_sessions\n SET consumed_at = $1\n WHERE upstream_oauth_authorization_session_id = $2\n "
}, },
"6bf0da5ba3dd07b499193a2e0ddeea6e712f9df8f7f28874ff56a952a9f10e54": { "6bf0da5ba3dd07b499193a2e0ddeea6e712f9df8f7f28874ff56a952a9f10e54": {
"describe": { "describe": {
@ -2420,6 +2318,21 @@
}, },
"query": "\n SELECT\n ue.user_email_id,\n ue.email AS \"user_email\",\n ue.created_at AS \"user_email_created_at\",\n ue.confirmed_at AS \"user_email_confirmed_at\"\n FROM user_emails ue\n\n WHERE ue.user_id = $1\n AND ue.user_email_id = $2\n " "query": "\n SELECT\n ue.user_email_id,\n ue.email AS \"user_email\",\n ue.created_at AS \"user_email_created_at\",\n ue.confirmed_at AS \"user_email_confirmed_at\"\n FROM user_emails ue\n\n WHERE ue.user_id = $1\n AND ue.user_email_id = $2\n "
}, },
"b9875a270f7e753e48075ccae233df6e24a91775ceb877735508c1d5b2300d64": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid",
"Timestamptz",
"Text",
"Uuid"
]
}
},
"query": "\n UPDATE upstream_oauth_authorization_sessions\n SET upstream_oauth_link_id = $1,\n completed_at = $2,\n id_token = $3\n WHERE upstream_oauth_authorization_session_id = $4\n "
},
"bc768c63a7737818967bc28560de714bbbd262bdf3ab73d297263bb73dcd9f5e": { "bc768c63a7737818967bc28560de714bbbd262bdf3ab73d297263bb73dcd9f5e": {
"describe": { "describe": {
"columns": [], "columns": [],
@ -2702,19 +2615,6 @@
}, },
"query": "\n DELETE FROM user_emails\n WHERE user_emails.user_email_id = $1\n " "query": "\n DELETE FROM user_emails\n WHERE user_emails.user_email_id = $1\n "
}, },
"e30562e9637d3a723a91adca6336a8d083657ce6d7fe9551fcd6a9d672453d3c": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Timestamptz",
"Uuid"
]
}
},
"query": "\n UPDATE upstream_oauth_authorization_sessions\n SET consumed_at = $1\n WHERE upstream_oauth_authorization_session_id = $2\n "
},
"e446e37d48c8838ef2e0d0fd82f8f7b04893c84ad46747cdf193ebd83755ceb2": { "e446e37d48c8838ef2e0d0fd82f8f7b04893c84ad46747cdf193ebd83755ceb2": {
"describe": { "describe": {
"columns": [], "columns": [],
@ -2773,22 +2673,5 @@
} }
}, },
"query": "\n SELECT\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n FROM upstream_oauth_links\n WHERE upstream_oauth_provider_id = $1\n AND subject = $2\n " "query": "\n SELECT\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n FROM upstream_oauth_links\n WHERE upstream_oauth_provider_id = $1\n AND subject = $2\n "
},
"fb71ac6539039313fd90b29ac943330e54c7b62b2778727726e2f60a554f9c5a": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid",
"Uuid",
"Text",
"Text",
"Text",
"Timestamptz"
]
}
},
"query": "\n INSERT INTO upstream_oauth_authorization_sessions (\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n state,\n code_challenge_verifier,\n nonce,\n created_at,\n completed_at,\n consumed_at,\n id_token\n ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL)\n "
} }
} }

View File

@ -14,7 +14,10 @@
use sqlx::{PgConnection, Postgres, Transaction}; use sqlx::{PgConnection, Postgres, Transaction};
use crate::upstream_oauth2::{PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository}; use crate::upstream_oauth2::{
PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository,
PgUpstreamOAuthSessionRepository,
};
pub trait Repository { pub trait Repository {
type UpstreamOAuthLinkRepository<'c> type UpstreamOAuthLinkRepository<'c>
@ -25,13 +28,19 @@ pub trait Repository {
where where
Self: 'c; Self: 'c;
type UpstreamOAuthSessionRepository<'c>
where
Self: 'c;
fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_>; fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_>;
fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_>; fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_>;
fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_>;
} }
impl Repository for PgConnection { impl Repository for PgConnection {
type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c; type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c;
type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c; type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c;
type UpstreamOAuthSessionRepository<'c> = PgUpstreamOAuthSessionRepository<'c> where Self: 'c;
fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> {
PgUpstreamOAuthLinkRepository::new(self) PgUpstreamOAuthLinkRepository::new(self)
@ -40,11 +49,16 @@ impl Repository for PgConnection {
fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> { fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> {
PgUpstreamOAuthProviderRepository::new(self) PgUpstreamOAuthProviderRepository::new(self)
} }
fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_> {
PgUpstreamOAuthSessionRepository::new(self)
}
} }
impl<'t> Repository for Transaction<'t, Postgres> { impl<'t> Repository for Transaction<'t, Postgres> {
type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c; type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c;
type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c; type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c;
type UpstreamOAuthSessionRepository<'c> = PgUpstreamOAuthSessionRepository<'c> where Self: 'c;
fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> {
PgUpstreamOAuthLinkRepository::new(self) PgUpstreamOAuthLinkRepository::new(self)
@ -53,4 +67,8 @@ impl<'t> Repository for Transaction<'t, Postgres> {
fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> { fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> {
PgUpstreamOAuthProviderRepository::new(self) PgUpstreamOAuthProviderRepository::new(self)
} }
fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_> {
PgUpstreamOAuthSessionRepository::new(self)
}
} }

View File

@ -19,7 +19,111 @@ mod session;
pub use self::{ pub use self::{
link::{PgUpstreamOAuthLinkRepository, UpstreamOAuthLinkRepository}, link::{PgUpstreamOAuthLinkRepository, UpstreamOAuthLinkRepository},
provider::{PgUpstreamOAuthProviderRepository, UpstreamOAuthProviderRepository}, provider::{PgUpstreamOAuthProviderRepository, UpstreamOAuthProviderRepository},
session::{ session::{PgUpstreamOAuthSessionRepository, UpstreamOAuthSessionRepository},
add_session, complete_session, consume_session, lookup_session, lookup_session_on_link,
},
}; };
#[cfg(test)]
mod tests {
use oauth2_types::scope::{Scope, OPENID};
use rand::SeedableRng;
use sqlx::PgPool;
use super::*;
use crate::{Clock, Repository};
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_repository(pool: PgPool) -> Result<(), Box<dyn std::error::Error>> {
let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
let clock = Clock::default();
let mut conn = pool.acquire().await?;
// The provider list should be empty at the start
let all_providers = conn.upstream_oauth_provider().all().await?;
assert!(all_providers.is_empty());
// Let's add a provider
let provider = conn
.upstream_oauth_provider()
.add(
&mut rng,
&clock,
"https://example.com/".to_owned(),
Scope::from_iter([OPENID]),
mas_iana::oauth::OAuthClientAuthenticationMethod::None,
None,
"client-id".to_owned(),
None,
)
.await?;
// Look it up in the database
let provider = conn
.upstream_oauth_provider()
.lookup(provider.id)
.await?
.expect("provider to be found in the database");
assert_eq!(provider.issuer, "https://example.com/");
assert_eq!(provider.client_id, "client-id");
// Start a session
let session = conn
.upstream_oauth_session()
.add(
&mut rng,
&clock,
&provider,
"some-state".to_owned(),
None,
"some-nonce".to_owned(),
)
.await?;
// Look it up in the database
let session = conn
.upstream_oauth_session()
.lookup(session.id)
.await?
.expect("session to be found in the database");
assert_eq!(session.provider_id, provider.id);
assert_eq!(session.link_id, None);
assert!(!session.completed());
assert!(!session.consumed());
// Create a link
let link = conn
.upstream_oauth_link()
.add(&mut rng, &clock, &provider, "a-subject".to_owned())
.await?;
// We can look it up by its ID
conn.upstream_oauth_link()
.lookup(link.id)
.await?
.expect("link to be found in database");
// or by its subject
let link = conn
.upstream_oauth_link()
.find_by_subject(&provider, "a-subject")
.await?
.expect("link to be found in database");
assert_eq!(link.subject, "a-subject");
assert_eq!(link.provider_id, provider.id);
let session = conn
.upstream_oauth_session()
.complete_with_link(&clock, session, &link, None)
.await?;
assert!(session.completed());
assert!(!session.consumed());
assert_eq!(session.link_id, Some(link.id));
let session = conn
.upstream_oauth_session()
.consume(&clock, session)
.await?;
assert!(session.consumed());
Ok(())
}
}

View File

@ -12,261 +12,62 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider}; use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider};
use rand::Rng; use rand::RngCore;
use sqlx::PgExecutor; use sqlx::PgConnection;
use ulid::Ulid; use ulid::Ulid;
use uuid::Uuid; use uuid::Uuid;
use crate::{Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; use crate::{Clock, DatabaseError, LookupResultExt};
struct SessionAndProviderLookup { #[async_trait]
upstream_oauth_authorization_session_id: Uuid, pub trait UpstreamOAuthSessionRepository: Send + Sync {
upstream_oauth_provider_id: Uuid, type Error;
upstream_oauth_link_id: Option<Uuid>,
state: String, /// Lookup a session by its ID
code_challenge_verifier: Option<String>, async fn lookup(
nonce: String, &mut self,
id_token: Option<String>, id: Ulid,
created_at: DateTime<Utc>, ) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error>;
completed_at: Option<DateTime<Utc>>,
consumed_at: Option<DateTime<Utc>>, /// Add a session to the database
provider_issuer: String, async fn add(
provider_scope: String, &mut self,
provider_client_id: String, rng: &mut (dyn RngCore + Send),
provider_encrypted_client_secret: Option<String>, clock: &Clock,
provider_token_endpoint_auth_method: String, upstream_oauth_provider: &UpstreamOAuthProvider,
provider_token_endpoint_signing_alg: Option<String>, state: String,
provider_created_at: DateTime<Utc>, code_challenge_verifier: Option<String>,
nonce: String,
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
/// Mark a session as completed and associate the given link
async fn complete_with_link(
&mut self,
clock: &Clock,
upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
upstream_oauth_link: &UpstreamOAuthLink,
id_token: Option<String>,
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
/// Mark a session as consumed
async fn consume(
&mut self,
clock: &Clock,
upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
} }
/// Lookup a session and its provider by its ID pub struct PgUpstreamOAuthSessionRepository<'c> {
#[tracing::instrument( conn: &'c mut PgConnection,
skip_all,
fields(upstream_oauth_authorization_session.id = %id),
err,
)]
pub async fn lookup_session(
executor: impl PgExecutor<'_>,
id: Ulid,
) -> Result<Option<(UpstreamOAuthProvider, UpstreamOAuthAuthorizationSession)>, DatabaseError> {
let res = sqlx::query_as!(
SessionAndProviderLookup,
r#"
SELECT
ua.upstream_oauth_authorization_session_id,
ua.upstream_oauth_provider_id,
ua.upstream_oauth_link_id,
ua.state,
ua.code_challenge_verifier,
ua.nonce,
ua.id_token,
ua.created_at,
ua.completed_at,
ua.consumed_at,
up.issuer AS "provider_issuer",
up.scope AS "provider_scope",
up.client_id AS "provider_client_id",
up.encrypted_client_secret AS "provider_encrypted_client_secret",
up.token_endpoint_auth_method AS "provider_token_endpoint_auth_method",
up.token_endpoint_signing_alg AS "provider_token_endpoint_signing_alg",
up.created_at AS "provider_created_at"
FROM upstream_oauth_authorization_sessions ua
INNER JOIN upstream_oauth_providers up
USING (upstream_oauth_provider_id)
WHERE upstream_oauth_authorization_session_id = $1
"#,
Uuid::from(id),
)
.fetch_one(executor)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
let id = res.upstream_oauth_provider_id.into();
let provider = UpstreamOAuthProvider {
id,
issuer: res.provider_issuer,
scope: res.provider_scope.parse().map_err(|e| {
DatabaseInconsistencyError::on("upstream_oauth_providers")
.column("scope")
.row(id)
.source(e)
})?,
client_id: res.provider_client_id,
encrypted_client_secret: res.provider_encrypted_client_secret,
token_endpoint_auth_method: res.provider_token_endpoint_auth_method.parse().map_err(
|e| {
DatabaseInconsistencyError::on("upstream_oauth_providers")
.column("token_endpoint_auth_method")
.row(id)
.source(e)
},
)?,
token_endpoint_signing_alg: res
.provider_token_endpoint_signing_alg
.map(|x| x.parse())
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("upstream_oauth_providers")
.column("token_endpoint_signing_alg")
.row(id)
.source(e)
})?,
created_at: res.provider_created_at,
};
let session = UpstreamOAuthAuthorizationSession {
id: res.upstream_oauth_authorization_session_id.into(),
provider_id: provider.id,
link_id: res.upstream_oauth_link_id.map(Ulid::from),
state: res.state,
code_challenge_verifier: res.code_challenge_verifier,
nonce: res.nonce,
id_token: res.id_token,
created_at: res.created_at,
completed_at: res.completed_at,
consumed_at: res.consumed_at,
};
Ok(Some((provider, session)))
} }
/// Add a session to the database impl<'c> PgUpstreamOAuthSessionRepository<'c> {
#[tracing::instrument( pub fn new(conn: &'c mut PgConnection) -> Self {
skip_all, Self { conn }
fields( }
%upstream_oauth_provider.id,
%upstream_oauth_provider.issuer,
%upstream_oauth_provider.client_id,
upstream_oauth_authorization_session.id,
),
err,
)]
pub async fn add_session(
executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
upstream_oauth_provider: &UpstreamOAuthProvider,
state: String,
code_challenge_verifier: Option<String>,
nonce: String,
) -> Result<UpstreamOAuthAuthorizationSession, sqlx::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record(
"upstream_oauth_authorization_session.id",
tracing::field::display(id),
);
sqlx::query!(
r#"
INSERT INTO upstream_oauth_authorization_sessions (
upstream_oauth_authorization_session_id,
upstream_oauth_provider_id,
state,
code_challenge_verifier,
nonce,
created_at,
completed_at,
consumed_at,
id_token
) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL)
"#,
Uuid::from(id),
Uuid::from(upstream_oauth_provider.id),
&state,
code_challenge_verifier.as_deref(),
nonce,
created_at,
)
.execute(executor)
.await?;
Ok(UpstreamOAuthAuthorizationSession {
id,
provider_id: upstream_oauth_provider.id,
link_id: None,
state,
code_challenge_verifier,
nonce,
id_token: None,
created_at,
completed_at: None,
consumed_at: None,
})
}
/// Mark a session as completed and associate the given link
#[tracing::instrument(
skip_all,
fields(
%upstream_oauth_authorization_session.id,
%upstream_oauth_link.id,
),
err,
)]
pub async fn complete_session(
executor: impl PgExecutor<'_>,
clock: &Clock,
mut upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
upstream_oauth_link: &UpstreamOAuthLink,
id_token: Option<String>,
) -> Result<UpstreamOAuthAuthorizationSession, sqlx::Error> {
let completed_at = clock.now();
sqlx::query!(
r#"
UPDATE upstream_oauth_authorization_sessions
SET upstream_oauth_link_id = $1,
completed_at = $2,
id_token = $3
WHERE upstream_oauth_authorization_session_id = $4
"#,
Uuid::from(upstream_oauth_link.id),
completed_at,
id_token,
Uuid::from(upstream_oauth_authorization_session.id),
)
.execute(executor)
.await?;
upstream_oauth_authorization_session.completed_at = Some(completed_at);
upstream_oauth_authorization_session.id_token = id_token;
Ok(upstream_oauth_authorization_session)
}
/// Mark a session as consumed
#[tracing::instrument(
skip_all,
fields(
%upstream_oauth_authorization_session.id,
),
err,
)]
pub async fn consume_session(
executor: impl PgExecutor<'_>,
clock: &Clock,
mut upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
) -> Result<UpstreamOAuthAuthorizationSession, sqlx::Error> {
let consumed_at = clock.now();
sqlx::query!(
r#"
UPDATE upstream_oauth_authorization_sessions
SET consumed_at = $1
WHERE upstream_oauth_authorization_session_id = $2
"#,
consumed_at,
Uuid::from(upstream_oauth_authorization_session.id),
)
.execute(executor)
.await?;
upstream_oauth_authorization_session.consumed_at = Some(consumed_at);
Ok(upstream_oauth_authorization_session)
} }
struct SessionLookup { struct SessionLookup {
@ -282,57 +83,191 @@ struct SessionLookup {
consumed_at: Option<DateTime<Utc>>, consumed_at: Option<DateTime<Utc>>,
} }
/// Lookup a session, which belongs to a link, by its ID #[async_trait]
#[tracing::instrument( impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> {
skip_all, type Error = DatabaseError;
fields(
upstream_oauth_authorization_session.id = %id,
%upstream_oauth_link.id,
),
err,
)]
pub async fn lookup_session_on_link(
executor: impl PgExecutor<'_>,
upstream_oauth_link: &UpstreamOAuthLink,
id: Ulid,
) -> Result<Option<UpstreamOAuthAuthorizationSession>, sqlx::Error> {
let res = sqlx::query_as!(
SessionLookup,
r#"
SELECT
upstream_oauth_authorization_session_id,
upstream_oauth_provider_id,
upstream_oauth_link_id,
state,
code_challenge_verifier,
nonce,
id_token,
created_at,
completed_at,
consumed_at
FROM upstream_oauth_authorization_sessions
WHERE upstream_oauth_authorization_session_id = $1
AND upstream_oauth_link_id = $2
"#,
Uuid::from(id),
Uuid::from(upstream_oauth_link.id),
)
.fetch_one(executor)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) }; #[tracing::instrument(
skip_all,
fields(upstream_oauth_provider.id = %id),
err,
)]
async fn lookup(
&mut self,
id: Ulid,
) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error> {
let res = sqlx::query_as!(
SessionLookup,
r#"
SELECT
upstream_oauth_authorization_session_id,
upstream_oauth_provider_id,
upstream_oauth_link_id,
state,
code_challenge_verifier,
nonce,
id_token,
created_at,
completed_at,
consumed_at
FROM upstream_oauth_authorization_sessions
WHERE upstream_oauth_authorization_session_id = $1
"#,
Uuid::from(id),
)
.fetch_one(&mut *self.conn)
.await
.to_option()?;
Ok(Some(UpstreamOAuthAuthorizationSession { let Some(res) = res else { return Ok(None) };
id: res.upstream_oauth_authorization_session_id.into(),
provider_id: res.upstream_oauth_provider_id.into(), let session = UpstreamOAuthAuthorizationSession {
link_id: res.upstream_oauth_link_id.map(Ulid::from), id: res.upstream_oauth_authorization_session_id.into(),
state: res.state, provider_id: res.upstream_oauth_provider_id.into(),
code_challenge_verifier: res.code_challenge_verifier, link_id: res.upstream_oauth_link_id.map(Ulid::from),
nonce: res.nonce, state: res.state,
id_token: res.id_token, code_challenge_verifier: res.code_challenge_verifier,
created_at: res.created_at, nonce: res.nonce,
completed_at: res.completed_at, id_token: res.id_token,
consumed_at: res.consumed_at, created_at: res.created_at,
})) completed_at: res.completed_at,
consumed_at: res.consumed_at,
};
Ok(Some(session))
}
#[tracing::instrument(
skip_all,
fields(
%upstream_oauth_provider.id,
%upstream_oauth_provider.issuer,
%upstream_oauth_provider.client_id,
upstream_oauth_authorization_session.id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
upstream_oauth_provider: &UpstreamOAuthProvider,
state: String,
code_challenge_verifier: Option<String>,
nonce: String,
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record(
"upstream_oauth_authorization_session.id",
tracing::field::display(id),
);
sqlx::query!(
r#"
INSERT INTO upstream_oauth_authorization_sessions (
upstream_oauth_authorization_session_id,
upstream_oauth_provider_id,
state,
code_challenge_verifier,
nonce,
created_at,
completed_at,
consumed_at,
id_token
) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL)
"#,
Uuid::from(id),
Uuid::from(upstream_oauth_provider.id),
&state,
code_challenge_verifier.as_deref(),
nonce,
created_at,
)
.execute(&mut *self.conn)
.await?;
Ok(UpstreamOAuthAuthorizationSession {
id,
provider_id: upstream_oauth_provider.id,
link_id: None,
state,
code_challenge_verifier,
nonce,
id_token: None,
created_at,
completed_at: None,
consumed_at: None,
})
}
#[tracing::instrument(
skip_all,
fields(
%upstream_oauth_authorization_session.id,
%upstream_oauth_link.id,
),
err,
)]
async fn complete_with_link(
&mut self,
clock: &Clock,
mut upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
upstream_oauth_link: &UpstreamOAuthLink,
id_token: Option<String>,
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
let completed_at = clock.now();
sqlx::query!(
r#"
UPDATE upstream_oauth_authorization_sessions
SET upstream_oauth_link_id = $1,
completed_at = $2,
id_token = $3
WHERE upstream_oauth_authorization_session_id = $4
"#,
Uuid::from(upstream_oauth_link.id),
completed_at,
id_token,
Uuid::from(upstream_oauth_authorization_session.id),
)
.execute(&mut *self.conn)
.await?;
upstream_oauth_authorization_session.completed_at = Some(completed_at);
upstream_oauth_authorization_session.id_token = id_token;
upstream_oauth_authorization_session.link_id = Some(upstream_oauth_link.id);
Ok(upstream_oauth_authorization_session)
}
/// Mark a session as consumed
#[tracing::instrument(
skip_all,
fields(
%upstream_oauth_authorization_session.id,
),
err,
)]
async fn consume(
&mut self,
clock: &Clock,
mut upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
let consumed_at = clock.now();
sqlx::query!(
r#"
UPDATE upstream_oauth_authorization_sessions
SET consumed_at = $1
WHERE upstream_oauth_authorization_session_id = $2
"#,
consumed_at,
Uuid::from(upstream_oauth_authorization_session.id),
)
.execute(&mut *self.conn)
.await?;
upstream_oauth_authorization_session.consumed_at = Some(consumed_at);
Ok(upstream_oauth_authorization_session)
}
} }