You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
data-model: have more structs use a state machine
This commit is contained in:
@ -14,8 +14,8 @@
|
||||
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use mas_data_model::{
|
||||
CompatAccessToken, CompatRefreshToken, CompatSession, CompatSsoLogin, CompatSsoLoginState,
|
||||
Device, User,
|
||||
CompatAccessToken, CompatRefreshToken, CompatSession, CompatSessionState, CompatSsoLogin,
|
||||
CompatSsoLoginState, Device, User,
|
||||
};
|
||||
use rand::Rng;
|
||||
use sqlx::{Acquire, PgExecutor, Postgres, QueryBuilder};
|
||||
@ -93,12 +93,17 @@ pub async fn lookup_active_compat_access_token(
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let state = match res.compat_session_finished_at {
|
||||
None => CompatSessionState::Valid,
|
||||
Some(finished_at) => CompatSessionState::Finished { finished_at },
|
||||
};
|
||||
|
||||
let session = CompatSession {
|
||||
id,
|
||||
state,
|
||||
user_id: res.user_id.into(),
|
||||
device,
|
||||
created_at: res.compat_session_created_at,
|
||||
finished_at: res.compat_session_finished_at,
|
||||
};
|
||||
|
||||
Ok(Some((token, session)))
|
||||
@ -181,12 +186,17 @@ pub async fn lookup_active_compat_refresh_token(
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let state = match res.compat_session_finished_at {
|
||||
None => CompatSessionState::Valid,
|
||||
Some(finished_at) => CompatSessionState::Finished { finished_at },
|
||||
};
|
||||
|
||||
let session = CompatSession {
|
||||
id,
|
||||
state,
|
||||
user_id: res.user_id.into(),
|
||||
device,
|
||||
created_at: res.compat_session_created_at,
|
||||
finished_at: res.compat_session_finished_at,
|
||||
};
|
||||
|
||||
Ok(Some((refresh_token, access_token, session)))
|
||||
@ -468,12 +478,18 @@ impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let state = match finished_at {
|
||||
None => CompatSessionState::Valid,
|
||||
Some(finished_at) => CompatSessionState::Finished { finished_at },
|
||||
};
|
||||
|
||||
Some(CompatSession {
|
||||
id,
|
||||
state,
|
||||
user_id: user_id.into(),
|
||||
device,
|
||||
created_at,
|
||||
finished_at,
|
||||
})
|
||||
}
|
||||
(None, None, None, None, None) => None,
|
||||
@ -686,10 +702,10 @@ pub async fn start_compat_session(
|
||||
|
||||
Ok(CompatSession {
|
||||
id,
|
||||
state: CompatSessionState::default(),
|
||||
user_id: user.id,
|
||||
device,
|
||||
created_at,
|
||||
finished_at: None,
|
||||
})
|
||||
}
|
||||
|
||||
@ -709,7 +725,7 @@ pub async fn fullfill_compat_sso_login(
|
||||
mut rng: impl Rng + Send,
|
||||
clock: &Clock,
|
||||
user: &User,
|
||||
mut compat_sso_login: CompatSsoLogin,
|
||||
compat_sso_login: CompatSsoLogin,
|
||||
device: Device,
|
||||
) -> Result<CompatSsoLogin, DatabaseError> {
|
||||
if !matches!(compat_sso_login.state, CompatSsoLoginState::Pending) {
|
||||
@ -719,8 +735,12 @@ pub async fn fullfill_compat_sso_login(
|
||||
let mut txn = conn.begin().await?;
|
||||
|
||||
let session = start_compat_session(&mut txn, &mut rng, clock, user, device).await?;
|
||||
let session_id = session.id;
|
||||
|
||||
let fulfilled_at = clock.now();
|
||||
let compat_sso_login = compat_sso_login
|
||||
.fulfill(fulfilled_at, session)
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
sqlx::query!(
|
||||
r#"
|
||||
UPDATE compat_sso_logins
|
||||
@ -731,20 +751,13 @@ pub async fn fullfill_compat_sso_login(
|
||||
compat_sso_login_id = $1
|
||||
"#,
|
||||
Uuid::from(compat_sso_login.id),
|
||||
Uuid::from(session.id),
|
||||
Uuid::from(session_id),
|
||||
fulfilled_at,
|
||||
)
|
||||
.execute(&mut txn)
|
||||
.instrument(tracing::info_span!("Update compat SSO login"))
|
||||
.await?;
|
||||
|
||||
let state = CompatSsoLoginState::Fulfilled {
|
||||
fulfilled_at,
|
||||
session,
|
||||
};
|
||||
|
||||
compat_sso_login.state = state;
|
||||
|
||||
txn.commit().await?;
|
||||
|
||||
Ok(compat_sso_login)
|
||||
@ -761,13 +774,13 @@ pub async fn fullfill_compat_sso_login(
|
||||
pub async fn mark_compat_sso_login_as_exchanged(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
mut compat_sso_login: CompatSsoLogin,
|
||||
compat_sso_login: CompatSsoLogin,
|
||||
) -> Result<CompatSsoLogin, DatabaseError> {
|
||||
let CompatSsoLoginState::Fulfilled { fulfilled_at, session } = compat_sso_login.state else {
|
||||
return Err(DatabaseError::invalid_operation());
|
||||
};
|
||||
|
||||
let exchanged_at = clock.now();
|
||||
let compat_sso_login = compat_sso_login
|
||||
.exchange(exchanged_at)
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
UPDATE compat_sso_logins
|
||||
@ -783,11 +796,5 @@ pub async fn mark_compat_sso_login_as_exchanged(
|
||||
.instrument(tracing::info_span!("Update compat SSO login"))
|
||||
.await?;
|
||||
|
||||
let state = CompatSsoLoginState::Exchanged {
|
||||
fulfilled_at,
|
||||
exchanged_at,
|
||||
session,
|
||||
};
|
||||
compat_sso_login.state = state;
|
||||
Ok(compat_sso_login)
|
||||
}
|
||||
|
@ -13,7 +13,7 @@
|
||||
// limitations under the License.
|
||||
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use mas_data_model::{AccessToken, Session};
|
||||
use mas_data_model::{AccessToken, Session, SessionState};
|
||||
use rand::Rng;
|
||||
use sqlx::{PgConnection, PgExecutor};
|
||||
use ulid::Ulid;
|
||||
@ -76,6 +76,7 @@ pub struct OAuth2AccessTokenLookup {
|
||||
oauth2_access_token: String,
|
||||
oauth2_access_token_created_at: DateTime<Utc>,
|
||||
oauth2_access_token_expires_at: DateTime<Utc>,
|
||||
oauth2_session_created_at: DateTime<Utc>,
|
||||
oauth2_session_id: Uuid,
|
||||
oauth2_client_id: Uuid,
|
||||
scope: String,
|
||||
@ -94,6 +95,7 @@ pub async fn lookup_active_access_token(
|
||||
, at.access_token AS "oauth2_access_token"
|
||||
, at.created_at AS "oauth2_access_token_created_at"
|
||||
, at.expires_at AS "oauth2_access_token_expires_at"
|
||||
, os.created_at AS "oauth2_session_created_at"
|
||||
, os.oauth2_session_id AS "oauth2_session_id!"
|
||||
, os.oauth2_client_id AS "oauth2_client_id!"
|
||||
, os.scope AS "scope!"
|
||||
@ -131,10 +133,11 @@ pub async fn lookup_active_access_token(
|
||||
|
||||
let session = Session {
|
||||
id: session_id,
|
||||
state: SessionState::Valid,
|
||||
created_at: res.oauth2_session_created_at,
|
||||
client_id: res.oauth2_client_id.into(),
|
||||
user_session_id: res.user_session_id.into(),
|
||||
scope,
|
||||
finished_at: None,
|
||||
};
|
||||
|
||||
Ok(Some((access_token, session)))
|
||||
|
@ -13,7 +13,7 @@
|
||||
// limitations under the License.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{AccessToken, RefreshToken, Session};
|
||||
use mas_data_model::{AccessToken, RefreshToken, Session, SessionState};
|
||||
use rand::Rng;
|
||||
use sqlx::{PgConnection, PgExecutor};
|
||||
use ulid::Ulid;
|
||||
@ -73,6 +73,7 @@ struct OAuth2RefreshTokenLookup {
|
||||
oauth2_refresh_token: String,
|
||||
oauth2_refresh_token_created_at: DateTime<Utc>,
|
||||
oauth2_access_token_id: Option<Uuid>,
|
||||
oauth2_session_created_at: DateTime<Utc>,
|
||||
oauth2_session_id: Uuid,
|
||||
oauth2_client_id: Uuid,
|
||||
oauth2_session_scope: String,
|
||||
@ -92,6 +93,7 @@ pub async fn lookup_active_refresh_token(
|
||||
, rt.refresh_token AS oauth2_refresh_token
|
||||
, rt.created_at AS oauth2_refresh_token_created_at
|
||||
, rt.oauth2_access_token_id AS "oauth2_access_token_id?"
|
||||
, os.created_at AS "oauth2_session_created_at"
|
||||
, os.oauth2_session_id AS "oauth2_session_id!"
|
||||
, os.oauth2_client_id AS "oauth2_client_id!"
|
||||
, os.scope AS "oauth2_session_scope!"
|
||||
@ -127,10 +129,11 @@ pub async fn lookup_active_refresh_token(
|
||||
|
||||
let session = Session {
|
||||
id: session_id,
|
||||
state: SessionState::Valid,
|
||||
created_at: res.oauth2_session_created_at,
|
||||
client_id: res.oauth2_client_id.into(),
|
||||
user_session_id: res.user_session_id.into(),
|
||||
scope,
|
||||
finished_at: None,
|
||||
};
|
||||
|
||||
Ok(Some((refresh_token, session)))
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{AuthorizationGrant, BrowserSession, Session, User};
|
||||
use mas_data_model::{AuthorizationGrant, BrowserSession, Session, SessionState, User};
|
||||
use rand::RngCore;
|
||||
use sqlx::{PgConnection, QueryBuilder};
|
||||
use ulid::Ulid;
|
||||
@ -85,12 +85,18 @@ impl TryFrom<OAuthSessionLookup> for Session {
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let state = match value.finished_at {
|
||||
None => SessionState::Valid,
|
||||
Some(finished_at) => SessionState::Finished { finished_at },
|
||||
};
|
||||
|
||||
Ok(Session {
|
||||
id,
|
||||
state,
|
||||
created_at: value.created_at,
|
||||
client_id: value.oauth2_client_id.into(),
|
||||
user_session_id: value.user_session_id.into(),
|
||||
scope,
|
||||
finished_at: value.finished_at,
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -182,10 +188,11 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
|
||||
|
||||
Ok(Session {
|
||||
id,
|
||||
state: SessionState::Valid,
|
||||
created_at,
|
||||
user_session_id: user_session.id,
|
||||
client_id: grant.client_id,
|
||||
scope: grant.scope.clone(),
|
||||
finished_at: None,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -85,9 +85,10 @@ mod tests {
|
||||
.await?
|
||||
.expect("session to be found in the database");
|
||||
assert_eq!(session.provider_id, provider.id);
|
||||
assert_eq!(session.link_id, None);
|
||||
assert!(!session.completed());
|
||||
assert!(!session.consumed());
|
||||
assert_eq!(session.link_id(), None);
|
||||
assert!(session.is_pending());
|
||||
assert!(!session.is_completed());
|
||||
assert!(!session.is_consumed());
|
||||
|
||||
// Create a link
|
||||
let link = conn
|
||||
@ -114,15 +115,15 @@ mod tests {
|
||||
.upstream_oauth_session()
|
||||
.complete_with_link(&clock, session, &link, None)
|
||||
.await?;
|
||||
assert!(session.completed());
|
||||
assert!(!session.consumed());
|
||||
assert_eq!(session.link_id, Some(link.id));
|
||||
assert!(session.is_completed());
|
||||
assert!(!session.is_consumed());
|
||||
assert_eq!(session.link_id(), Some(link.id));
|
||||
|
||||
let session = conn
|
||||
.upstream_oauth_session()
|
||||
.consume(&clock, session)
|
||||
.await?;
|
||||
assert!(session.consumed());
|
||||
assert!(session.is_consumed());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -14,13 +14,18 @@
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider};
|
||||
use mas_data_model::{
|
||||
UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState, UpstreamOAuthLink,
|
||||
UpstreamOAuthProvider,
|
||||
};
|
||||
use rand::RngCore;
|
||||
use sqlx::PgConnection;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt};
|
||||
use crate::{
|
||||
tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
|
||||
};
|
||||
|
||||
#[async_trait]
|
||||
pub trait UpstreamOAuthSessionRepository: Send + Sync {
|
||||
@ -83,6 +88,52 @@ struct SessionLookup {
|
||||
consumed_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
impl TryFrom<SessionLookup> for UpstreamOAuthAuthorizationSession {
|
||||
type Error = DatabaseInconsistencyError;
|
||||
|
||||
fn try_from(value: SessionLookup) -> Result<Self, Self::Error> {
|
||||
let id = value.upstream_oauth_authorization_session_id.into();
|
||||
let state = match (
|
||||
value.upstream_oauth_link_id,
|
||||
value.id_token,
|
||||
value.completed_at,
|
||||
value.consumed_at,
|
||||
) {
|
||||
(None, None, None, None) => UpstreamOAuthAuthorizationSessionState::Pending,
|
||||
(Some(link_id), id_token, Some(completed_at), None) => {
|
||||
UpstreamOAuthAuthorizationSessionState::Completed {
|
||||
completed_at,
|
||||
link_id: link_id.into(),
|
||||
id_token,
|
||||
}
|
||||
}
|
||||
(Some(link_id), id_token, Some(completed_at), Some(consumed_at)) => {
|
||||
UpstreamOAuthAuthorizationSessionState::Consumed {
|
||||
completed_at,
|
||||
link_id: link_id.into(),
|
||||
id_token,
|
||||
consumed_at,
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
return Err(
|
||||
DatabaseInconsistencyError::on("upstream_oauth_authorization_sessions").row(id),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
id,
|
||||
provider_id: value.upstream_oauth_provider_id.into(),
|
||||
state_str: value.state,
|
||||
nonce: value.nonce,
|
||||
code_challenge_verifier: value.code_challenge_verifier,
|
||||
created_at: value.created_at,
|
||||
state,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
@ -126,20 +177,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c>
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
let session = UpstreamOAuthAuthorizationSession {
|
||||
id: res.upstream_oauth_authorization_session_id.into(),
|
||||
provider_id: res.upstream_oauth_provider_id.into(),
|
||||
link_id: res.upstream_oauth_link_id.map(Ulid::from),
|
||||
state: res.state,
|
||||
code_challenge_verifier: res.code_challenge_verifier,
|
||||
nonce: res.nonce,
|
||||
id_token: res.id_token,
|
||||
created_at: res.created_at,
|
||||
completed_at: res.completed_at,
|
||||
consumed_at: res.consumed_at,
|
||||
};
|
||||
|
||||
Ok(Some(session))
|
||||
Ok(Some(res.try_into()?))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
@ -159,7 +197,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c>
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
upstream_oauth_provider: &UpstreamOAuthProvider,
|
||||
state: String,
|
||||
state_str: String,
|
||||
code_challenge_verifier: Option<String>,
|
||||
nonce: String,
|
||||
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
|
||||
@ -186,7 +224,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c>
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(upstream_oauth_provider.id),
|
||||
&state,
|
||||
&state_str,
|
||||
code_challenge_verifier.as_deref(),
|
||||
nonce,
|
||||
created_at,
|
||||
@ -197,15 +235,12 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c>
|
||||
|
||||
Ok(UpstreamOAuthAuthorizationSession {
|
||||
id,
|
||||
state: UpstreamOAuthAuthorizationSessionState::default(),
|
||||
provider_id: upstream_oauth_provider.id,
|
||||
link_id: None,
|
||||
state,
|
||||
state_str,
|
||||
code_challenge_verifier,
|
||||
nonce,
|
||||
id_token: None,
|
||||
created_at,
|
||||
completed_at: None,
|
||||
consumed_at: None,
|
||||
})
|
||||
}
|
||||
|
||||
@ -222,11 +257,12 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c>
|
||||
async fn complete_with_link(
|
||||
&mut self,
|
||||
clock: &Clock,
|
||||
mut upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
|
||||
upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
|
||||
upstream_oauth_link: &UpstreamOAuthLink,
|
||||
id_token: Option<String>,
|
||||
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
|
||||
let completed_at = clock.now();
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
UPDATE upstream_oauth_authorization_sessions
|
||||
@ -244,9 +280,9 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c>
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
upstream_oauth_authorization_session.completed_at = Some(completed_at);
|
||||
upstream_oauth_authorization_session.id_token = id_token;
|
||||
upstream_oauth_authorization_session.link_id = Some(upstream_oauth_link.id);
|
||||
let upstream_oauth_authorization_session = upstream_oauth_authorization_session
|
||||
.complete(completed_at, upstream_oauth_link, id_token)
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
|
||||
Ok(upstream_oauth_authorization_session)
|
||||
}
|
||||
@ -264,7 +300,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c>
|
||||
async fn consume(
|
||||
&mut self,
|
||||
clock: &Clock,
|
||||
mut upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
|
||||
upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
|
||||
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
|
||||
let consumed_at = clock.now();
|
||||
sqlx::query!(
|
||||
@ -280,7 +316,9 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c>
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
upstream_oauth_authorization_session.consumed_at = Some(consumed_at);
|
||||
let upstream_oauth_authorization_session = upstream_oauth_authorization_session
|
||||
.consume(consumed_at)
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
|
||||
Ok(upstream_oauth_authorization_session)
|
||||
}
|
||||
|
Reference in New Issue
Block a user