You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-11-20 12:02:22 +03:00
storage: upstream oauth session repository + unit tests
This commit is contained in:
@@ -22,7 +22,10 @@ use mas_axum_utils::http_client_factory::HttpClientFactory;
|
||||
use mas_keystore::Encrypter;
|
||||
use mas_oidc_client::requests::authorization_code::AuthorizationRequestData;
|
||||
use mas_router::UrlBuilder;
|
||||
use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, Repository};
|
||||
use mas_storage::{
|
||||
upstream_oauth2::{UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository},
|
||||
Repository,
|
||||
};
|
||||
use sqlx::PgPool;
|
||||
use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
@@ -97,16 +100,17 @@ pub(crate) async fn get(
|
||||
&mut rng,
|
||||
)?;
|
||||
|
||||
let session = mas_storage::upstream_oauth2::add_session(
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
&provider,
|
||||
data.state.clone(),
|
||||
data.code_challenge_verifier,
|
||||
data.nonce,
|
||||
)
|
||||
.await?;
|
||||
let session = txn
|
||||
.upstream_oauth_session()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&provider,
|
||||
data.state.clone(),
|
||||
data.code_challenge_verifier,
|
||||
data.nonce,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let cookie_jar = UpstreamSessionsCookie::load(&cookie_jar)
|
||||
.add(session.id, provider.id, data.state, query.post_auth_action)
|
||||
|
||||
@@ -26,7 +26,7 @@ use mas_oidc_client::requests::{
|
||||
};
|
||||
use mas_router::{Route, UrlBuilder};
|
||||
use mas_storage::{
|
||||
upstream_oauth2::{complete_session, lookup_session},
|
||||
upstream_oauth2::{UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository},
|
||||
Repository, UpstreamOAuthLinkRepository,
|
||||
};
|
||||
use oauth2_types::errors::ClientErrorCode;
|
||||
@@ -65,6 +65,9 @@ pub(crate) enum RouteError {
|
||||
#[error("Session not found")]
|
||||
SessionNotFound,
|
||||
|
||||
#[error("Provider not found")]
|
||||
ProviderNotFound,
|
||||
|
||||
#[error("Provider mismatch")]
|
||||
ProviderMismatch,
|
||||
|
||||
@@ -105,6 +108,7 @@ impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
match self {
|
||||
Self::ProviderNotFound => (StatusCode::NOT_FOUND, "Provider not found").into_response(),
|
||||
Self::SessionNotFound => (StatusCode::NOT_FOUND, "Session not found").into_response(),
|
||||
Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
|
||||
e => (StatusCode::BAD_REQUEST, e.to_string()).into_response(),
|
||||
@@ -127,16 +131,24 @@ pub(crate) async fn get(
|
||||
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
let provider = txn
|
||||
.upstream_oauth_provider()
|
||||
.lookup(provider_id)
|
||||
.await?
|
||||
.ok_or(RouteError::ProviderNotFound)?;
|
||||
|
||||
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
|
||||
let (session_id, _post_auth_action) = sessions_cookie
|
||||
.find_session(provider_id, ¶ms.state)
|
||||
.map_err(|_| RouteError::MissingCookie)?;
|
||||
|
||||
let (provider, session) = lookup_session(&mut txn, session_id)
|
||||
let session = txn
|
||||
.upstream_oauth_session()
|
||||
.lookup(session_id)
|
||||
.await?
|
||||
.ok_or(RouteError::SessionNotFound)?;
|
||||
|
||||
if provider.id != provider_id {
|
||||
if provider.id != session.provider_id {
|
||||
// The provider in the session cookie should match the one from the URL
|
||||
return Err(RouteError::ProviderMismatch);
|
||||
}
|
||||
@@ -245,7 +257,11 @@ pub(crate) async fn get(
|
||||
.await?
|
||||
};
|
||||
|
||||
let session = complete_session(&mut txn, &clock, session, &link, response.id_token).await?;
|
||||
let session = txn
|
||||
.upstream_oauth_session()
|
||||
.complete_with_link(&clock, session, &link, response.id_token)
|
||||
.await?;
|
||||
|
||||
let cookie_jar = sessions_cookie
|
||||
.add_link_to_session(session.id, link.id)?
|
||||
.save(cookie_jar, clock.now());
|
||||
|
||||
@@ -25,7 +25,7 @@ use mas_axum_utils::{
|
||||
};
|
||||
use mas_keystore::Encrypter;
|
||||
use mas_storage::{
|
||||
upstream_oauth2::{consume_session, lookup_session_on_link},
|
||||
upstream_oauth2::UpstreamOAuthSessionRepository,
|
||||
user::{add_user, authenticate_session_with_upstream, lookup_user, start_session},
|
||||
Repository, UpstreamOAuthLinkRepository,
|
||||
};
|
||||
@@ -109,12 +109,18 @@ pub(crate) async fn get(
|
||||
.await?
|
||||
.ok_or(RouteError::LinkNotFound)?;
|
||||
|
||||
// This checks that we're in a browser session which is allowed to consume this
|
||||
// link: the upstream auth session should have been started in this browser.
|
||||
let upstream_session = lookup_session_on_link(&mut txn, &link, session_id)
|
||||
let upstream_session = txn
|
||||
.upstream_oauth_session()
|
||||
.lookup(session_id)
|
||||
.await?
|
||||
.ok_or(RouteError::SessionNotFound)?;
|
||||
|
||||
// This checks that we're in a browser session which is allowed to consume this
|
||||
// link: the upstream auth session should have been started in this browser.
|
||||
if upstream_session.link_id != Some(link.id) {
|
||||
return Err(RouteError::SessionNotFound);
|
||||
}
|
||||
|
||||
if upstream_session.consumed() {
|
||||
return Err(RouteError::SessionConsumed);
|
||||
}
|
||||
@@ -127,7 +133,10 @@ pub(crate) async fn get(
|
||||
(Some(mut session), Some(user_id)) if session.user.id == user_id => {
|
||||
// Session already linked, and link matches the currently logged
|
||||
// user. Mark the session as consumed and renew the authentication.
|
||||
consume_session(&mut txn, &clock, upstream_session).await?;
|
||||
txn.upstream_oauth_session()
|
||||
.consume(&clock, upstream_session)
|
||||
.await?;
|
||||
|
||||
authenticate_session_with_upstream(&mut txn, &mut rng, &clock, &mut session, &link)
|
||||
.await?;
|
||||
|
||||
@@ -212,12 +221,18 @@ pub(crate) async fn post(
|
||||
.await?
|
||||
.ok_or(RouteError::LinkNotFound)?;
|
||||
|
||||
// This checks that we're in a browser session which is allowed to consume this
|
||||
// link: the upstream auth session should have been started in this browser.
|
||||
let upstream_session = lookup_session_on_link(&mut txn, &link, session_id)
|
||||
let upstream_session = txn
|
||||
.upstream_oauth_session()
|
||||
.lookup(session_id)
|
||||
.await?
|
||||
.ok_or(RouteError::SessionNotFound)?;
|
||||
|
||||
// This checks that we're in a browser session which is allowed to consume this
|
||||
// link: the upstream auth session should have been started in this browser.
|
||||
if upstream_session.link_id != Some(link.id) {
|
||||
return Err(RouteError::SessionNotFound);
|
||||
}
|
||||
|
||||
if upstream_session.consumed() {
|
||||
return Err(RouteError::SessionConsumed);
|
||||
}
|
||||
@@ -251,7 +266,10 @@ pub(crate) async fn post(
|
||||
_ => return Err(RouteError::InvalidFormAction),
|
||||
};
|
||||
|
||||
consume_session(&mut txn, &clock, upstream_session).await?;
|
||||
txn.upstream_oauth_session()
|
||||
.consume(&clock, upstream_session)
|
||||
.await?;
|
||||
|
||||
authenticate_session_with_upstream(&mut txn, &mut rng, &clock, &mut session, &link).await?;
|
||||
|
||||
let cookie_jar = sessions_cookie
|
||||
|
||||
Reference in New Issue
Block a user