1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

data-model: have more structs use a state machine

This commit is contained in:
Quentin Gliech
2023-01-09 18:02:32 +01:00
parent 39cd9a2578
commit 35787aa072
21 changed files with 1148 additions and 621 deletions

View File

@ -85,9 +85,10 @@ mod tests {
.await?
.expect("session to be found in the database");
assert_eq!(session.provider_id, provider.id);
assert_eq!(session.link_id, None);
assert!(!session.completed());
assert!(!session.consumed());
assert_eq!(session.link_id(), None);
assert!(session.is_pending());
assert!(!session.is_completed());
assert!(!session.is_consumed());
// Create a link
let link = conn
@ -114,15 +115,15 @@ mod tests {
.upstream_oauth_session()
.complete_with_link(&clock, session, &link, None)
.await?;
assert!(session.completed());
assert!(!session.consumed());
assert_eq!(session.link_id, Some(link.id));
assert!(session.is_completed());
assert!(!session.is_consumed());
assert_eq!(session.link_id(), Some(link.id));
let session = conn
.upstream_oauth_session()
.consume(&clock, session)
.await?;
assert!(session.consumed());
assert!(session.is_consumed());
Ok(())
}

View File

@ -14,13 +14,18 @@
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider};
use mas_data_model::{
UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState, UpstreamOAuthLink,
UpstreamOAuthProvider,
};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt};
use crate::{
tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
};
#[async_trait]
pub trait UpstreamOAuthSessionRepository: Send + Sync {
@ -83,6 +88,52 @@ struct SessionLookup {
consumed_at: Option<DateTime<Utc>>,
}
impl TryFrom<SessionLookup> for UpstreamOAuthAuthorizationSession {
type Error = DatabaseInconsistencyError;
fn try_from(value: SessionLookup) -> Result<Self, Self::Error> {
let id = value.upstream_oauth_authorization_session_id.into();
let state = match (
value.upstream_oauth_link_id,
value.id_token,
value.completed_at,
value.consumed_at,
) {
(None, None, None, None) => UpstreamOAuthAuthorizationSessionState::Pending,
(Some(link_id), id_token, Some(completed_at), None) => {
UpstreamOAuthAuthorizationSessionState::Completed {
completed_at,
link_id: link_id.into(),
id_token,
}
}
(Some(link_id), id_token, Some(completed_at), Some(consumed_at)) => {
UpstreamOAuthAuthorizationSessionState::Consumed {
completed_at,
link_id: link_id.into(),
id_token,
consumed_at,
}
}
_ => {
return Err(
DatabaseInconsistencyError::on("upstream_oauth_authorization_sessions").row(id),
)
}
};
Ok(Self {
id,
provider_id: value.upstream_oauth_provider_id.into(),
state_str: value.state,
nonce: value.nonce,
code_challenge_verifier: value.code_challenge_verifier,
created_at: value.created_at,
state,
})
}
}
#[async_trait]
impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> {
type Error = DatabaseError;
@ -126,20 +177,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c>
let Some(res) = res else { return Ok(None) };
let session = UpstreamOAuthAuthorizationSession {
id: res.upstream_oauth_authorization_session_id.into(),
provider_id: res.upstream_oauth_provider_id.into(),
link_id: res.upstream_oauth_link_id.map(Ulid::from),
state: res.state,
code_challenge_verifier: res.code_challenge_verifier,
nonce: res.nonce,
id_token: res.id_token,
created_at: res.created_at,
completed_at: res.completed_at,
consumed_at: res.consumed_at,
};
Ok(Some(session))
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
@ -159,7 +197,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c>
rng: &mut (dyn RngCore + Send),
clock: &Clock,
upstream_oauth_provider: &UpstreamOAuthProvider,
state: String,
state_str: String,
code_challenge_verifier: Option<String>,
nonce: String,
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
@ -186,7 +224,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c>
"#,
Uuid::from(id),
Uuid::from(upstream_oauth_provider.id),
&state,
&state_str,
code_challenge_verifier.as_deref(),
nonce,
created_at,
@ -197,15 +235,12 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c>
Ok(UpstreamOAuthAuthorizationSession {
id,
state: UpstreamOAuthAuthorizationSessionState::default(),
provider_id: upstream_oauth_provider.id,
link_id: None,
state,
state_str,
code_challenge_verifier,
nonce,
id_token: None,
created_at,
completed_at: None,
consumed_at: None,
})
}
@ -222,11 +257,12 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c>
async fn complete_with_link(
&mut self,
clock: &Clock,
mut upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
upstream_oauth_link: &UpstreamOAuthLink,
id_token: Option<String>,
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
let completed_at = clock.now();
sqlx::query!(
r#"
UPDATE upstream_oauth_authorization_sessions
@ -244,9 +280,9 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c>
.execute(&mut *self.conn)
.await?;
upstream_oauth_authorization_session.completed_at = Some(completed_at);
upstream_oauth_authorization_session.id_token = id_token;
upstream_oauth_authorization_session.link_id = Some(upstream_oauth_link.id);
let upstream_oauth_authorization_session = upstream_oauth_authorization_session
.complete(completed_at, upstream_oauth_link, id_token)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(upstream_oauth_authorization_session)
}
@ -264,7 +300,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c>
async fn consume(
&mut self,
clock: &Clock,
mut upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
let consumed_at = clock.now();
sqlx::query!(
@ -280,7 +316,9 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c>
.execute(&mut *self.conn)
.await?;
upstream_oauth_authorization_session.consumed_at = Some(consumed_at);
let upstream_oauth_authorization_session = upstream_oauth_authorization_session
.consume(consumed_at)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(upstream_oauth_authorization_session)
}