1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

storage: do less joins on authorization grants and refresh tokens

This commit is contained in:
Quentin Gliech
2023-01-05 16:49:19 +01:00
parent 603a26eabd
commit fb7c6f4dd1
7 changed files with 140 additions and 149 deletions

View File

@ -63,11 +63,11 @@ pub enum AuthorizationGrantStage {
#[default] #[default]
Pending, Pending,
Fulfilled { Fulfilled {
session: Session, session_id: Ulid,
fulfilled_at: DateTime<Utc>, fulfilled_at: DateTime<Utc>,
}, },
Exchanged { Exchanged {
session: Session, session_id: Ulid,
fulfilled_at: DateTime<Utc>, fulfilled_at: DateTime<Utc>,
exchanged_at: DateTime<Utc>, exchanged_at: DateTime<Utc>,
}, },
@ -85,12 +85,12 @@ impl AuthorizationGrantStage {
pub fn fulfill( pub fn fulfill(
self, self,
fulfilled_at: DateTime<Utc>, fulfilled_at: DateTime<Utc>,
session: Session, session: &Session,
) -> Result<Self, InvalidTransitionError> { ) -> Result<Self, InvalidTransitionError> {
match self { match self {
Self::Pending => Ok(Self::Fulfilled { Self::Pending => Ok(Self::Fulfilled {
fulfilled_at, fulfilled_at,
session, session_id: session.id,
}), }),
_ => Err(InvalidTransitionError), _ => Err(InvalidTransitionError),
} }
@ -100,11 +100,11 @@ impl AuthorizationGrantStage {
match self { match self {
Self::Fulfilled { Self::Fulfilled {
fulfilled_at, fulfilled_at,
session, session_id,
} => Ok(Self::Exchanged { } => Ok(Self::Exchanged {
fulfilled_at, fulfilled_at,
exchanged_at, exchanged_at,
session, session_id,
}), }),
_ => Err(InvalidTransitionError), _ => Err(InvalidTransitionError),
} }

View File

@ -33,7 +33,7 @@ pub struct RefreshToken {
pub id: Ulid, pub id: Ulid,
pub refresh_token: String, pub refresh_token: String,
pub created_at: DateTime<Utc>, pub created_at: DateTime<Utc>,
pub access_token: Option<AccessToken>, pub access_token_id: Option<Ulid>,
} }
/// Type of token to generate or validate /// Type of token to generate or validate

View File

@ -109,12 +109,18 @@ pub(crate) enum RouteError {
#[error("failed to load browser session")] #[error("failed to load browser session")]
NoSuchBrowserSession, NoSuchBrowserSession,
#[error("failed to load oauth session")]
NoSuchOAuthSession,
} }
impl IntoResponse for RouteError { impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response { fn into_response(self) -> axum::response::Response {
match self { match self {
Self::Internal(_) | Self::InvalidSigningKey | Self::NoSuchBrowserSession => ( Self::Internal(_)
| Self::InvalidSigningKey
| Self::NoSuchBrowserSession
| Self::NoSuchOAuthSession => (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(ClientError::from(ClientErrorCode::ServerError)), Json(ClientError::from(ClientErrorCode::ServerError)),
), ),
@ -219,7 +225,7 @@ async fn authorization_code_grant(
let now = clock.now(); let now = clock.now();
let session = match authz_grant.stage { let session_id = match authz_grant.stage {
AuthorizationGrantStage::Cancelled { cancelled_at } => { AuthorizationGrantStage::Cancelled { cancelled_at } => {
debug!(%cancelled_at, "Authorization grant was cancelled"); debug!(%cancelled_at, "Authorization grant was cancelled");
return Err(RouteError::InvalidGrant); return Err(RouteError::InvalidGrant);
@ -227,13 +233,18 @@ async fn authorization_code_grant(
AuthorizationGrantStage::Exchanged { AuthorizationGrantStage::Exchanged {
exchanged_at, exchanged_at,
fulfilled_at, fulfilled_at,
session, session_id,
} => { } => {
debug!(%exchanged_at, %fulfilled_at, "Authorization code was already exchanged"); debug!(%exchanged_at, %fulfilled_at, "Authorization code was already exchanged");
// Ending the session if the token was already exchanged more than 20s ago // Ending the session if the token was already exchanged more than 20s ago
if now - exchanged_at > Duration::seconds(20) { if now - exchanged_at > Duration::seconds(20) {
debug!("Ending potentially compromised session"); debug!("Ending potentially compromised session");
let session = txn
.oauth2_session()
.lookup(session_id)
.await?
.ok_or(RouteError::NoSuchOAuthSession)?;
txn.oauth2_session().finish(&clock, session).await?; txn.oauth2_session().finish(&clock, session).await?;
txn.commit().await?; txn.commit().await?;
} }
@ -245,7 +256,7 @@ async fn authorization_code_grant(
return Err(RouteError::InvalidGrant); return Err(RouteError::InvalidGrant);
} }
AuthorizationGrantStage::Fulfilled { AuthorizationGrantStage::Fulfilled {
ref session, session_id,
fulfilled_at, fulfilled_at,
} => { } => {
if now - fulfilled_at > Duration::minutes(10) { if now - fulfilled_at > Duration::minutes(10) {
@ -253,10 +264,16 @@ async fn authorization_code_grant(
return Err(RouteError::InvalidGrant); return Err(RouteError::InvalidGrant);
} }
session session_id
} }
}; };
let session = txn
.oauth2_session()
.lookup(session_id)
.await?
.ok_or(RouteError::NoSuchOAuthSession)?;
// This should never happen, since we looked up in the database using the code // This should never happen, since we looked up in the database using the code
let code = authz_grant.code.as_ref().ok_or(RouteError::InvalidGrant)?; let code = authz_grant.code.as_ref().ok_or(RouteError::InvalidGrant)?;
@ -284,23 +301,16 @@ async fn authorization_code_grant(
let access_token_str = TokenType::AccessToken.generate(&mut rng); let access_token_str = TokenType::AccessToken.generate(&mut rng);
let refresh_token_str = TokenType::RefreshToken.generate(&mut rng); let refresh_token_str = TokenType::RefreshToken.generate(&mut rng);
let access_token = add_access_token( let access_token =
&mut txn, add_access_token(&mut txn, &mut rng, &clock, &session, access_token_str, ttl).await?;
&mut rng,
&clock,
session,
access_token_str.clone(),
ttl,
)
.await?;
let _refresh_token = add_refresh_token( let refresh_token = add_refresh_token(
&mut txn, &mut txn,
&mut rng, &mut rng,
&clock, &clock,
session, &session,
access_token, &access_token,
refresh_token_str.clone(), refresh_token_str,
) )
.await?; .await?;
@ -328,7 +338,7 @@ async fn authorization_code_grant(
.signing_key_for_algorithm(&alg) .signing_key_for_algorithm(&alg)
.ok_or(RouteError::InvalidSigningKey)?; .ok_or(RouteError::InvalidSigningKey)?;
claims::AT_HASH.insert(&mut claims, hash_token(&alg, &access_token_str)?)?; claims::AT_HASH.insert(&mut claims, hash_token(&alg, &access_token.access_token)?)?;
claims::C_HASH.insert(&mut claims, hash_token(&alg, &grant.code)?)?; claims::C_HASH.insert(&mut claims, hash_token(&alg, &grant.code)?)?;
let signer = key.params().signing_key_for_alg(&alg)?; let signer = key.params().signing_key_for_alg(&alg)?;
@ -341,9 +351,9 @@ async fn authorization_code_grant(
None None
}; };
let mut params = AccessTokenResponse::new(access_token_str) let mut params = AccessTokenResponse::new(access_token.access_token)
.with_expires_in(ttl) .with_expires_in(ttl)
.with_refresh_token(refresh_token_str) .with_refresh_token(refresh_token.refresh_token)
.with_scope(session.scope.clone()); .with_scope(session.scope.clone());
if let Some(id_token) = id_token { if let Some(id_token) = id_token {
@ -392,15 +402,15 @@ async fn refresh_token_grant(
&mut rng, &mut rng,
&clock, &clock,
&session, &session,
new_access_token, &new_access_token,
refresh_token_str, refresh_token_str,
) )
.await?; .await?;
consume_refresh_token(&mut txn, &clock, &refresh_token).await?; consume_refresh_token(&mut txn, &clock, &refresh_token).await?;
if let Some(access_token) = refresh_token.access_token { if let Some(access_token_id) = refresh_token.access_token_id {
revoke_access_token(&mut txn, &clock, access_token).await?; revoke_access_token(&mut txn, &clock, access_token_id).await?;
} }
let params = AccessTokenResponse::new(access_token_str) let params = AccessTokenResponse::new(access_token_str)

View File

@ -142,13 +142,13 @@ pub async fn lookup_active_access_token(
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
fields(%access_token.id), fields(access_token.id = %access_token_id),
err, err,
)] )]
pub async fn revoke_access_token( pub async fn revoke_access_token(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
clock: &Clock, clock: &Clock,
access_token: AccessToken, access_token_id: Ulid,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
let revoked_at = clock.now(); let revoked_at = clock.now();
let res = sqlx::query!( let res = sqlx::query!(
@ -157,7 +157,7 @@ pub async fn revoke_access_token(
SET revoked_at = $2 SET revoked_at = $2
WHERE oauth2_access_token_id = $1 WHERE oauth2_access_token_id = $1
"#, "#,
Uuid::from(access_token.id), Uuid::from(access_token_id),
revoked_at, revoked_at,
) )
.execute(executor) .execute(executor)

View File

@ -149,7 +149,6 @@ struct GrantLookup {
oauth2_authorization_grant_requires_consent: bool, oauth2_authorization_grant_requires_consent: bool,
oauth2_client_id: Uuid, oauth2_client_id: Uuid,
oauth2_session_id: Option<Uuid>, oauth2_session_id: Option<Uuid>,
user_session_id: Option<Uuid>,
} }
impl GrantLookup { impl GrantLookup {
@ -176,45 +175,22 @@ impl GrantLookup {
.row(id) .row(id)
})?; })?;
let session = match (self.oauth2_session_id, self.user_session_id) {
(Some(session_id), Some(user_session_id)) => {
let scope = scope.clone();
let session = Session {
id: session_id.into(),
client_id: client.id,
user_session_id: user_session_id.into(),
scope,
finished_at: None,
};
Some(session)
}
(None, None) => None,
_ => {
return Err(
DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("oauth2_session_id")
.row(id)
.into(),
)
}
};
let stage = match ( let stage = match (
self.oauth2_authorization_grant_fulfilled_at, self.oauth2_authorization_grant_fulfilled_at,
self.oauth2_authorization_grant_exchanged_at, self.oauth2_authorization_grant_exchanged_at,
self.oauth2_authorization_grant_cancelled_at, self.oauth2_authorization_grant_cancelled_at,
session, self.oauth2_session_id,
) { ) {
(None, None, None, None) => AuthorizationGrantStage::Pending, (None, None, None, None) => AuthorizationGrantStage::Pending,
(Some(fulfilled_at), None, None, Some(session)) => AuthorizationGrantStage::Fulfilled { (Some(fulfilled_at), None, None, Some(session_id)) => {
session, AuthorizationGrantStage::Fulfilled {
fulfilled_at, session_id: session_id.into(),
}, fulfilled_at,
(Some(fulfilled_at), Some(exchanged_at), None, Some(session)) => { }
}
(Some(fulfilled_at), Some(exchanged_at), None, Some(session_id)) => {
AuthorizationGrantStage::Exchanged { AuthorizationGrantStage::Exchanged {
session, session_id: session_id.into(),
fulfilled_at, fulfilled_at,
exchanged_at, exchanged_at,
} }
@ -343,32 +319,29 @@ pub async fn get_grant_by_id(
let res = sqlx::query_as!( let res = sqlx::query_as!(
GrantLookup, GrantLookup,
r#" r#"
SELECT og.oauth2_authorization_grant_id SELECT oauth2_authorization_grant_id
, og.created_at AS oauth2_authorization_grant_created_at , created_at AS oauth2_authorization_grant_created_at
, og.cancelled_at AS oauth2_authorization_grant_cancelled_at , cancelled_at AS oauth2_authorization_grant_cancelled_at
, og.fulfilled_at AS oauth2_authorization_grant_fulfilled_at , fulfilled_at AS oauth2_authorization_grant_fulfilled_at
, og.exchanged_at AS oauth2_authorization_grant_exchanged_at , exchanged_at AS oauth2_authorization_grant_exchanged_at
, og.scope AS oauth2_authorization_grant_scope , scope AS oauth2_authorization_grant_scope
, og.state AS oauth2_authorization_grant_state , state AS oauth2_authorization_grant_state
, og.redirect_uri AS oauth2_authorization_grant_redirect_uri , redirect_uri AS oauth2_authorization_grant_redirect_uri
, og.response_mode AS oauth2_authorization_grant_response_mode , response_mode AS oauth2_authorization_grant_response_mode
, og.nonce AS oauth2_authorization_grant_nonce , nonce AS oauth2_authorization_grant_nonce
, og.max_age AS oauth2_authorization_grant_max_age , max_age AS oauth2_authorization_grant_max_age
, og.oauth2_client_id AS oauth2_client_id , oauth2_client_id AS oauth2_client_id
, og.authorization_code AS oauth2_authorization_grant_code , authorization_code AS oauth2_authorization_grant_code
, og.response_type_code AS oauth2_authorization_grant_response_type_code , response_type_code AS oauth2_authorization_grant_response_type_code
, og.response_type_id_token AS oauth2_authorization_grant_response_type_id_token , response_type_id_token AS oauth2_authorization_grant_response_type_id_token
, og.code_challenge AS oauth2_authorization_grant_code_challenge , code_challenge AS oauth2_authorization_grant_code_challenge
, og.code_challenge_method AS oauth2_authorization_grant_code_challenge_method , code_challenge_method AS oauth2_authorization_grant_code_challenge_method
, og.requires_consent AS oauth2_authorization_grant_requires_consent , requires_consent AS oauth2_authorization_grant_requires_consent
, os.oauth2_session_id AS "oauth2_session_id?" , oauth2_session_id AS "oauth2_session_id?"
, os.user_session_id AS "user_session_id?"
FROM FROM
oauth2_authorization_grants og oauth2_authorization_grants
LEFT JOIN oauth2_sessions os
USING (oauth2_session_id)
WHERE og.oauth2_authorization_grant_id = $1 WHERE oauth2_authorization_grant_id = $1
"#, "#,
Uuid::from(id), Uuid::from(id),
) )
@ -391,32 +364,29 @@ pub async fn lookup_grant_by_code(
let res = sqlx::query_as!( let res = sqlx::query_as!(
GrantLookup, GrantLookup,
r#" r#"
SELECT og.oauth2_authorization_grant_id SELECT oauth2_authorization_grant_id
, og.created_at AS oauth2_authorization_grant_created_at , created_at AS oauth2_authorization_grant_created_at
, og.cancelled_at AS oauth2_authorization_grant_cancelled_at , cancelled_at AS oauth2_authorization_grant_cancelled_at
, og.fulfilled_at AS oauth2_authorization_grant_fulfilled_at , fulfilled_at AS oauth2_authorization_grant_fulfilled_at
, og.exchanged_at AS oauth2_authorization_grant_exchanged_at , exchanged_at AS oauth2_authorization_grant_exchanged_at
, og.scope AS oauth2_authorization_grant_scope , scope AS oauth2_authorization_grant_scope
, og.state AS oauth2_authorization_grant_state , state AS oauth2_authorization_grant_state
, og.redirect_uri AS oauth2_authorization_grant_redirect_uri , redirect_uri AS oauth2_authorization_grant_redirect_uri
, og.response_mode AS oauth2_authorization_grant_response_mode , response_mode AS oauth2_authorization_grant_response_mode
, og.nonce AS oauth2_authorization_grant_nonce , nonce AS oauth2_authorization_grant_nonce
, og.max_age AS oauth2_authorization_grant_max_age , max_age AS oauth2_authorization_grant_max_age
, og.oauth2_client_id AS oauth2_client_id , oauth2_client_id AS oauth2_client_id
, og.authorization_code AS oauth2_authorization_grant_code , authorization_code AS oauth2_authorization_grant_code
, og.response_type_code AS oauth2_authorization_grant_response_type_code , response_type_code AS oauth2_authorization_grant_response_type_code
, og.response_type_id_token AS oauth2_authorization_grant_response_type_id_token , response_type_id_token AS oauth2_authorization_grant_response_type_id_token
, og.code_challenge AS oauth2_authorization_grant_code_challenge , code_challenge AS oauth2_authorization_grant_code_challenge
, og.code_challenge_method AS oauth2_authorization_grant_code_challenge_method , code_challenge_method AS oauth2_authorization_grant_code_challenge_method
, og.requires_consent AS oauth2_authorization_grant_requires_consent , requires_consent AS oauth2_authorization_grant_requires_consent
, os.oauth2_session_id AS "oauth2_session_id?" , oauth2_session_id AS "oauth2_session_id?"
, os.user_session_id AS "user_session_id?"
FROM FROM
oauth2_authorization_grants og oauth2_authorization_grants
LEFT JOIN oauth2_sessions os
USING (oauth2_session_id)
WHERE og.authorization_code = $1 WHERE authorization_code = $1
"#, "#,
code, code,
) )
@ -466,7 +436,7 @@ pub async fn fulfill_grant(
grant.stage = grant grant.stage = grant
.stage .stage
.fulfill(fulfilled_at, session) .fulfill(fulfilled_at, &session)
.map_err(DatabaseError::to_invalid_operation)?; .map_err(DatabaseError::to_invalid_operation)?;
Ok(grant) Ok(grant)

View File

@ -36,7 +36,7 @@ pub async fn add_refresh_token(
mut rng: impl Rng + Send, mut rng: impl Rng + Send,
clock: &Clock, clock: &Clock,
session: &Session, session: &Session,
access_token: AccessToken, access_token: &AccessToken,
refresh_token: String, refresh_token: String,
) -> Result<RefreshToken, sqlx::Error> { ) -> Result<RefreshToken, sqlx::Error> {
let created_at = clock.now(); let created_at = clock.now();
@ -63,7 +63,7 @@ pub async fn add_refresh_token(
Ok(RefreshToken { Ok(RefreshToken {
id, id,
refresh_token, refresh_token,
access_token: Some(access_token), access_token_id: Some(access_token.id),
created_at, created_at,
}) })
} }
@ -73,9 +73,6 @@ struct OAuth2RefreshTokenLookup {
oauth2_refresh_token: String, oauth2_refresh_token: String,
oauth2_refresh_token_created_at: DateTime<Utc>, oauth2_refresh_token_created_at: DateTime<Utc>,
oauth2_access_token_id: Option<Uuid>, oauth2_access_token_id: Option<Uuid>,
oauth2_access_token: Option<String>,
oauth2_access_token_created_at: Option<DateTime<Utc>>,
oauth2_access_token_expires_at: Option<DateTime<Utc>>,
oauth2_session_id: Uuid, oauth2_session_id: Uuid,
oauth2_client_id: Uuid, oauth2_client_id: Uuid,
oauth2_session_scope: String, oauth2_session_scope: String,
@ -94,10 +91,7 @@ pub async fn lookup_active_refresh_token(
SELECT rt.oauth2_refresh_token_id SELECT rt.oauth2_refresh_token_id
, rt.refresh_token AS oauth2_refresh_token , rt.refresh_token AS oauth2_refresh_token
, rt.created_at AS oauth2_refresh_token_created_at , rt.created_at AS oauth2_refresh_token_created_at
, at.oauth2_access_token_id AS "oauth2_access_token_id?" , rt.oauth2_access_token_id AS "oauth2_access_token_id?"
, at.access_token AS "oauth2_access_token?"
, at.created_at AS "oauth2_access_token_created_at?"
, at.expires_at AS "oauth2_access_token_expires_at?"
, os.oauth2_session_id AS "oauth2_session_id!" , os.oauth2_session_id AS "oauth2_session_id!"
, os.oauth2_client_id AS "oauth2_client_id!" , os.oauth2_client_id AS "oauth2_client_id!"
, os.scope AS "oauth2_session_scope!" , os.scope AS "oauth2_session_scope!"
@ -105,8 +99,6 @@ pub async fn lookup_active_refresh_token(
FROM oauth2_refresh_tokens rt FROM oauth2_refresh_tokens rt
INNER JOIN oauth2_sessions os INNER JOIN oauth2_sessions os
USING (oauth2_session_id) USING (oauth2_session_id)
LEFT JOIN oauth2_access_tokens at
USING (oauth2_access_token_id)
WHERE rt.refresh_token = $1 WHERE rt.refresh_token = $1
AND rt.consumed_at IS NULL AND rt.consumed_at IS NULL
@ -118,31 +110,11 @@ pub async fn lookup_active_refresh_token(
.fetch_one(&mut *conn) .fetch_one(&mut *conn)
.await?; .await?;
let access_token = match (
res.oauth2_access_token_id,
res.oauth2_access_token,
res.oauth2_access_token_created_at,
res.oauth2_access_token_expires_at,
) {
(None, None, None, None) => None,
(Some(id), Some(access_token), Some(created_at), Some(expires_at)) => {
let id = Ulid::from(id);
Some(AccessToken {
id,
jti: id.to_string(),
access_token,
created_at,
expires_at,
})
}
_ => return Err(DatabaseInconsistencyError::on("oauth2_access_tokens").into()),
};
let refresh_token = RefreshToken { let refresh_token = RefreshToken {
id: res.oauth2_refresh_token_id.into(), id: res.oauth2_refresh_token_id.into(),
refresh_token: res.oauth2_refresh_token, refresh_token: res.oauth2_refresh_token,
created_at: res.oauth2_refresh_token_created_at, created_at: res.oauth2_refresh_token_created_at,
access_token, access_token_id: res.oauth2_access_token_id.map(Ulid::from),
}; };
let session_id = res.oauth2_session_id.into(); let session_id = res.oauth2_session_id.into();

View File

@ -23,13 +23,15 @@ use uuid::Uuid;
use crate::{ use crate::{
pagination::{process_page, Page, QueryBuilderExt}, pagination::{process_page, Page, QueryBuilderExt},
tracing::ExecuteExt, tracing::ExecuteExt,
Clock, DatabaseError, DatabaseInconsistencyError, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
}; };
#[async_trait] #[async_trait]
pub trait OAuth2SessionRepository { pub trait OAuth2SessionRepository {
type Error; type Error;
async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error>;
async fn create_from_grant( async fn create_from_grant(
&mut self, &mut self,
rng: &mut (dyn RngCore + Send), rng: &mut (dyn RngCore + Send),
@ -66,6 +68,8 @@ struct OAuthSessionLookup {
user_session_id: Uuid, user_session_id: Uuid,
oauth2_client_id: Uuid, oauth2_client_id: Uuid,
scope: String, scope: String,
#[allow(dead_code)]
created_at: DateTime<Utc>,
finished_at: Option<DateTime<Utc>>, finished_at: Option<DateTime<Utc>>,
} }
@ -95,6 +99,41 @@ impl TryFrom<OAuthSessionLookup> for Session {
impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
type Error = DatabaseError; type Error = DatabaseError;
#[tracing::instrument(
name = "db.oauth2_session.lookup",
skip_all,
fields(
db.statement,
session.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error> {
let res = sqlx::query_as!(
OAuthSessionLookup,
r#"
SELECT oauth2_session_id
, user_session_id
, oauth2_client_id
, scope
, created_at
, finished_at
FROM oauth2_sessions
WHERE oauth2_session_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(session) = res else { return Ok(None) };
Ok(Some(session.try_into()?))
}
#[tracing::instrument( #[tracing::instrument(
name = "db.oauth2_session.create_from_grant", name = "db.oauth2_session.create_from_grant",
skip_all, skip_all,