You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
Split the storage trait from the implementation
This commit is contained in:
@ -13,14 +13,12 @@
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use chrono::Duration;
|
||||
use mas_data_model::{CompatAccessToken, CompatSession};
|
||||
use rand::RngCore;
|
||||
use sqlx::PgConnection;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt};
|
||||
use crate::Clock;
|
||||
|
||||
#[async_trait]
|
||||
pub trait CompatAccessTokenRepository: Send + Sync {
|
||||
@ -52,195 +50,3 @@ pub trait CompatAccessTokenRepository: Send + Sync {
|
||||
compat_access_token: CompatAccessToken,
|
||||
) -> Result<CompatAccessToken, Self::Error>;
|
||||
}
|
||||
|
||||
pub struct PgCompatAccessTokenRepository<'c> {
|
||||
conn: &'c mut PgConnection,
|
||||
}
|
||||
|
||||
impl<'c> PgCompatAccessTokenRepository<'c> {
|
||||
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
Self { conn }
|
||||
}
|
||||
}
|
||||
|
||||
struct CompatAccessTokenLookup {
|
||||
compat_access_token_id: Uuid,
|
||||
access_token: String,
|
||||
created_at: DateTime<Utc>,
|
||||
expires_at: Option<DateTime<Utc>>,
|
||||
compat_session_id: Uuid,
|
||||
}
|
||||
|
||||
impl From<CompatAccessTokenLookup> for CompatAccessToken {
|
||||
fn from(value: CompatAccessTokenLookup) -> Self {
|
||||
Self {
|
||||
id: value.compat_access_token_id.into(),
|
||||
session_id: value.compat_session_id.into(),
|
||||
token: value.access_token,
|
||||
created_at: value.created_at,
|
||||
expires_at: value.expires_at,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> CompatAccessTokenRepository for PgCompatAccessTokenRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_access_token.lookup",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
compat_session.id = %id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatAccessToken>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
CompatAccessTokenLookup,
|
||||
r#"
|
||||
SELECT compat_access_token_id
|
||||
, access_token
|
||||
, created_at
|
||||
, expires_at
|
||||
, compat_session_id
|
||||
|
||||
FROM compat_access_tokens
|
||||
|
||||
WHERE compat_access_token_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.into()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_access_token.find_by_token",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn find_by_token(
|
||||
&mut self,
|
||||
access_token: &str,
|
||||
) -> Result<Option<CompatAccessToken>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
CompatAccessTokenLookup,
|
||||
r#"
|
||||
SELECT compat_access_token_id
|
||||
, access_token
|
||||
, created_at
|
||||
, expires_at
|
||||
, compat_session_id
|
||||
|
||||
FROM compat_access_tokens
|
||||
|
||||
WHERE access_token = $1
|
||||
"#,
|
||||
access_token,
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.into()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_access_token.add",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
compat_access_token.id,
|
||||
%compat_session.id,
|
||||
user.id = %compat_session.user_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn add(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
compat_session: &CompatSession,
|
||||
token: String,
|
||||
expires_after: Option<Duration>,
|
||||
) -> Result<CompatAccessToken, Self::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record("compat_access_token.id", tracing::field::display(id));
|
||||
|
||||
let expires_at = expires_after.map(|expires_after| created_at + expires_after);
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO compat_access_tokens
|
||||
(compat_access_token_id, compat_session_id, access_token, created_at, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(compat_session.id),
|
||||
token,
|
||||
created_at,
|
||||
expires_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(CompatAccessToken {
|
||||
id,
|
||||
session_id: compat_session.id,
|
||||
token,
|
||||
created_at,
|
||||
expires_at,
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_access_token.expire",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%compat_access_token.id,
|
||||
compat_session.id = %compat_access_token.session_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn expire(
|
||||
&mut self,
|
||||
clock: &Clock,
|
||||
mut compat_access_token: CompatAccessToken,
|
||||
) -> Result<CompatAccessToken, Self::Error> {
|
||||
let expires_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE compat_access_tokens
|
||||
SET expires_at = $2
|
||||
WHERE compat_access_token_id = $1
|
||||
"#,
|
||||
Uuid::from(compat_access_token.id),
|
||||
expires_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
compat_access_token.expires_at = Some(expires_at);
|
||||
Ok(compat_access_token)
|
||||
}
|
||||
}
|
||||
|
@ -18,301 +18,6 @@ mod session;
|
||||
mod sso_login;
|
||||
|
||||
pub use self::{
|
||||
access_token::{CompatAccessTokenRepository, PgCompatAccessTokenRepository},
|
||||
refresh_token::{CompatRefreshTokenRepository, PgCompatRefreshTokenRepository},
|
||||
session::{CompatSessionRepository, PgCompatSessionRepository},
|
||||
sso_login::{CompatSsoLoginRepository, PgCompatSsoLoginRepository},
|
||||
access_token::CompatAccessTokenRepository, refresh_token::CompatRefreshTokenRepository,
|
||||
session::CompatSessionRepository, sso_login::CompatSsoLoginRepository,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use chrono::Duration;
|
||||
use mas_data_model::Device;
|
||||
use rand::SeedableRng;
|
||||
use rand_chacha::ChaChaRng;
|
||||
use sqlx::PgPool;
|
||||
|
||||
use super::*;
|
||||
use crate::{user::UserRepository, Clock, PgRepository, Repository};
|
||||
|
||||
#[sqlx::test(migrator = "crate::MIGRATOR")]
|
||||
async fn test_session_repository(pool: PgPool) {
|
||||
const FIRST_TOKEN: &str = "first_access_token";
|
||||
const SECOND_TOKEN: &str = "second_access_token";
|
||||
let mut rng = ChaChaRng::seed_from_u64(42);
|
||||
let clock = Clock::mock();
|
||||
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
|
||||
|
||||
// Create a user
|
||||
let user = repo
|
||||
.user()
|
||||
.add(&mut rng, &clock, "john".to_owned())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Start a compat session for that user
|
||||
let device = Device::generate(&mut rng);
|
||||
let device_str = device.as_str().to_owned();
|
||||
let session = repo
|
||||
.compat_session()
|
||||
.add(&mut rng, &clock, &user, device)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(session.user_id, user.id);
|
||||
assert_eq!(session.device.as_str(), device_str);
|
||||
assert!(session.is_valid());
|
||||
assert!(!session.is_finished());
|
||||
|
||||
// Lookup the session and check it didn't change
|
||||
let session_lookup = repo
|
||||
.compat_session()
|
||||
.lookup(session.id)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("compat session not found");
|
||||
assert_eq!(session_lookup.id, session.id);
|
||||
assert_eq!(session_lookup.user_id, user.id);
|
||||
assert_eq!(session_lookup.device.as_str(), device_str);
|
||||
assert!(session_lookup.is_valid());
|
||||
assert!(!session_lookup.is_finished());
|
||||
|
||||
// Finish the session
|
||||
let session = repo.compat_session().finish(&clock, session).await.unwrap();
|
||||
assert!(!session.is_valid());
|
||||
assert!(session.is_finished());
|
||||
|
||||
// Reload the session and check again
|
||||
let session_lookup = repo
|
||||
.compat_session()
|
||||
.lookup(session.id)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("compat session not found");
|
||||
assert!(!session_lookup.is_valid());
|
||||
assert!(session_lookup.is_finished());
|
||||
}
|
||||
|
||||
#[sqlx::test(migrator = "crate::MIGRATOR")]
|
||||
async fn test_access_token_repository(pool: PgPool) {
|
||||
const FIRST_TOKEN: &str = "first_access_token";
|
||||
const SECOND_TOKEN: &str = "second_access_token";
|
||||
let mut rng = ChaChaRng::seed_from_u64(42);
|
||||
let clock = Clock::mock();
|
||||
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
|
||||
|
||||
// Create a user
|
||||
let user = repo
|
||||
.user()
|
||||
.add(&mut rng, &clock, "john".to_owned())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Start a compat session for that user
|
||||
let device = Device::generate(&mut rng);
|
||||
let session = repo
|
||||
.compat_session()
|
||||
.add(&mut rng, &clock, &user, device)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Add an access token to that session
|
||||
let token = repo
|
||||
.compat_access_token()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&session,
|
||||
FIRST_TOKEN.to_owned(),
|
||||
Some(Duration::minutes(1)),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(token.session_id, session.id);
|
||||
assert_eq!(token.token, FIRST_TOKEN);
|
||||
|
||||
// Commit the txn and grab a new transaction, to test a conflict
|
||||
repo.save().await.unwrap();
|
||||
|
||||
{
|
||||
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
|
||||
// Adding the same token a second time should conflict
|
||||
assert!(repo
|
||||
.compat_access_token()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&session,
|
||||
FIRST_TOKEN.to_owned(),
|
||||
Some(Duration::minutes(1)),
|
||||
)
|
||||
.await
|
||||
.is_err());
|
||||
repo.cancel().await.unwrap();
|
||||
}
|
||||
|
||||
// Grab a new repo
|
||||
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
|
||||
|
||||
// Looking up via ID works
|
||||
let token_lookup = repo
|
||||
.compat_access_token()
|
||||
.lookup(token.id)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("compat access token not found");
|
||||
assert_eq!(token.id, token_lookup.id);
|
||||
assert_eq!(token_lookup.session_id, session.id);
|
||||
|
||||
// Looking up via the token value works
|
||||
let token_lookup = repo
|
||||
.compat_access_token()
|
||||
.find_by_token(FIRST_TOKEN)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("compat access token not found");
|
||||
assert_eq!(token.id, token_lookup.id);
|
||||
assert_eq!(token_lookup.session_id, session.id);
|
||||
|
||||
// Token is currently valid
|
||||
assert!(token.is_valid(clock.now()));
|
||||
|
||||
clock.advance(Duration::minutes(1));
|
||||
// Token should have expired
|
||||
assert!(!token.is_valid(clock.now()));
|
||||
|
||||
// Add a second access token, this time without expiration
|
||||
let token = repo
|
||||
.compat_access_token()
|
||||
.add(&mut rng, &clock, &session, SECOND_TOKEN.to_owned(), None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(token.session_id, session.id);
|
||||
assert_eq!(token.token, SECOND_TOKEN);
|
||||
|
||||
// Token is currently valid
|
||||
assert!(token.is_valid(clock.now()));
|
||||
|
||||
// Make it expire
|
||||
repo.compat_access_token()
|
||||
.expire(&clock, token)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Reload it
|
||||
let token = repo
|
||||
.compat_access_token()
|
||||
.find_by_token(SECOND_TOKEN)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("compat access token not found");
|
||||
|
||||
// Token is not valid anymore
|
||||
assert!(!token.is_valid(clock.now()));
|
||||
|
||||
repo.save().await.unwrap();
|
||||
}
|
||||
|
||||
#[sqlx::test(migrator = "crate::MIGRATOR")]
|
||||
async fn test_refresh_token_repository(pool: PgPool) {
|
||||
const ACCESS_TOKEN: &str = "access_token";
|
||||
const REFRESH_TOKEN: &str = "refresh_token";
|
||||
let mut rng = ChaChaRng::seed_from_u64(42);
|
||||
let clock = Clock::mock();
|
||||
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
|
||||
|
||||
// Create a user
|
||||
let user = repo
|
||||
.user()
|
||||
.add(&mut rng, &clock, "john".to_owned())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Start a compat session for that user
|
||||
let device = Device::generate(&mut rng);
|
||||
let session = repo
|
||||
.compat_session()
|
||||
.add(&mut rng, &clock, &user, device)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Add an access token to that session
|
||||
let access_token = repo
|
||||
.compat_access_token()
|
||||
.add(&mut rng, &clock, &session, ACCESS_TOKEN.to_owned(), None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let refresh_token = repo
|
||||
.compat_refresh_token()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&session,
|
||||
&access_token,
|
||||
REFRESH_TOKEN.to_owned(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(refresh_token.session_id, session.id);
|
||||
assert_eq!(refresh_token.access_token_id, access_token.id);
|
||||
assert_eq!(refresh_token.token, REFRESH_TOKEN);
|
||||
assert!(refresh_token.is_valid());
|
||||
assert!(!refresh_token.is_consumed());
|
||||
|
||||
// Look it up by ID and check everything matches
|
||||
let refresh_token_lookup = repo
|
||||
.compat_refresh_token()
|
||||
.lookup(refresh_token.id)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("refresh token not found");
|
||||
assert_eq!(refresh_token_lookup.id, refresh_token.id);
|
||||
assert_eq!(refresh_token_lookup.session_id, session.id);
|
||||
assert_eq!(refresh_token_lookup.access_token_id, access_token.id);
|
||||
assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN);
|
||||
assert!(refresh_token_lookup.is_valid());
|
||||
assert!(!refresh_token_lookup.is_consumed());
|
||||
|
||||
// Look it up by token and check everything matches
|
||||
let refresh_token_lookup = repo
|
||||
.compat_refresh_token()
|
||||
.find_by_token(REFRESH_TOKEN)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("refresh token not found");
|
||||
assert_eq!(refresh_token_lookup.id, refresh_token.id);
|
||||
assert_eq!(refresh_token_lookup.session_id, session.id);
|
||||
assert_eq!(refresh_token_lookup.access_token_id, access_token.id);
|
||||
assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN);
|
||||
assert!(refresh_token_lookup.is_valid());
|
||||
assert!(!refresh_token_lookup.is_consumed());
|
||||
|
||||
// Consume it
|
||||
let refresh_token = repo
|
||||
.compat_refresh_token()
|
||||
.consume(&clock, refresh_token)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!refresh_token.is_valid());
|
||||
assert!(refresh_token.is_consumed());
|
||||
|
||||
// Reload it and check again
|
||||
let refresh_token_lookup = repo
|
||||
.compat_refresh_token()
|
||||
.find_by_token(REFRESH_TOKEN)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("refresh token not found");
|
||||
assert!(!refresh_token_lookup.is_valid());
|
||||
assert!(refresh_token_lookup.is_consumed());
|
||||
|
||||
// Consuming it again should not work
|
||||
assert!(repo
|
||||
.compat_refresh_token()
|
||||
.consume(&clock, refresh_token)
|
||||
.await
|
||||
.is_err());
|
||||
|
||||
repo.save().await.unwrap();
|
||||
}
|
||||
}
|
||||
|
@ -13,16 +13,11 @@
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{
|
||||
CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession,
|
||||
};
|
||||
use mas_data_model::{CompatAccessToken, CompatRefreshToken, CompatSession};
|
||||
use rand::RngCore;
|
||||
use sqlx::PgConnection;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt};
|
||||
use crate::Clock;
|
||||
|
||||
#[async_trait]
|
||||
pub trait CompatRefreshTokenRepository: Send + Sync {
|
||||
@ -54,207 +49,3 @@ pub trait CompatRefreshTokenRepository: Send + Sync {
|
||||
compat_refresh_token: CompatRefreshToken,
|
||||
) -> Result<CompatRefreshToken, Self::Error>;
|
||||
}
|
||||
|
||||
pub struct PgCompatRefreshTokenRepository<'c> {
|
||||
conn: &'c mut PgConnection,
|
||||
}
|
||||
|
||||
impl<'c> PgCompatRefreshTokenRepository<'c> {
|
||||
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
Self { conn }
|
||||
}
|
||||
}
|
||||
|
||||
struct CompatRefreshTokenLookup {
|
||||
compat_refresh_token_id: Uuid,
|
||||
refresh_token: String,
|
||||
created_at: DateTime<Utc>,
|
||||
consumed_at: Option<DateTime<Utc>>,
|
||||
compat_access_token_id: Uuid,
|
||||
compat_session_id: Uuid,
|
||||
}
|
||||
|
||||
impl From<CompatRefreshTokenLookup> for CompatRefreshToken {
|
||||
fn from(value: CompatRefreshTokenLookup) -> Self {
|
||||
let state = match value.consumed_at {
|
||||
Some(consumed_at) => CompatRefreshTokenState::Consumed { consumed_at },
|
||||
None => CompatRefreshTokenState::Valid,
|
||||
};
|
||||
|
||||
Self {
|
||||
id: value.compat_refresh_token_id.into(),
|
||||
state,
|
||||
session_id: value.compat_session_id.into(),
|
||||
token: value.refresh_token,
|
||||
created_at: value.created_at,
|
||||
access_token_id: value.compat_access_token_id.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> CompatRefreshTokenRepository for PgCompatRefreshTokenRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_refresh_token.lookup",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
compat_refresh_token.id = %id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatRefreshToken>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
CompatRefreshTokenLookup,
|
||||
r#"
|
||||
SELECT compat_refresh_token_id
|
||||
, refresh_token
|
||||
, created_at
|
||||
, consumed_at
|
||||
, compat_session_id
|
||||
, compat_access_token_id
|
||||
|
||||
FROM compat_refresh_tokens
|
||||
|
||||
WHERE compat_refresh_token_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.into()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_refresh_token.find_by_token",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn find_by_token(
|
||||
&mut self,
|
||||
refresh_token: &str,
|
||||
) -> Result<Option<CompatRefreshToken>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
CompatRefreshTokenLookup,
|
||||
r#"
|
||||
SELECT compat_refresh_token_id
|
||||
, refresh_token
|
||||
, created_at
|
||||
, consumed_at
|
||||
, compat_session_id
|
||||
, compat_access_token_id
|
||||
|
||||
FROM compat_refresh_tokens
|
||||
|
||||
WHERE refresh_token = $1
|
||||
"#,
|
||||
refresh_token,
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.into()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_refresh_token.add",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
compat_refresh_token.id,
|
||||
%compat_session.id,
|
||||
user.id = %compat_session.user_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn add(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
compat_session: &CompatSession,
|
||||
compat_access_token: &CompatAccessToken,
|
||||
token: String,
|
||||
) -> Result<CompatRefreshToken, Self::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO compat_refresh_tokens
|
||||
(compat_refresh_token_id, compat_session_id,
|
||||
compat_access_token_id, refresh_token, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(compat_session.id),
|
||||
Uuid::from(compat_access_token.id),
|
||||
token,
|
||||
created_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(CompatRefreshToken {
|
||||
id,
|
||||
state: CompatRefreshTokenState::default(),
|
||||
session_id: compat_session.id,
|
||||
access_token_id: compat_access_token.id,
|
||||
token,
|
||||
created_at,
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_refresh_token.consume",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%compat_refresh_token.id,
|
||||
compat_session.id = %compat_refresh_token.session_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn consume(
|
||||
&mut self,
|
||||
clock: &Clock,
|
||||
compat_refresh_token: CompatRefreshToken,
|
||||
) -> Result<CompatRefreshToken, Self::Error> {
|
||||
let consumed_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE compat_refresh_tokens
|
||||
SET consumed_at = $2
|
||||
WHERE compat_refresh_token_id = $1
|
||||
"#,
|
||||
Uuid::from(compat_refresh_token.id),
|
||||
consumed_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
let compat_refresh_token = compat_refresh_token
|
||||
.consume(consumed_at)
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
|
||||
Ok(compat_refresh_token)
|
||||
}
|
||||
}
|
||||
|
@ -13,16 +13,11 @@
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{CompatSession, CompatSessionState, Device, User};
|
||||
use mas_data_model::{CompatSession, Device, User};
|
||||
use rand::RngCore;
|
||||
use sqlx::PgConnection;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
|
||||
};
|
||||
use crate::Clock;
|
||||
|
||||
#[async_trait]
|
||||
pub trait CompatSessionRepository: Send + Sync {
|
||||
@ -47,174 +42,3 @@ pub trait CompatSessionRepository: Send + Sync {
|
||||
compat_session: CompatSession,
|
||||
) -> Result<CompatSession, Self::Error>;
|
||||
}
|
||||
|
||||
pub struct PgCompatSessionRepository<'c> {
|
||||
conn: &'c mut PgConnection,
|
||||
}
|
||||
|
||||
impl<'c> PgCompatSessionRepository<'c> {
|
||||
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
Self { conn }
|
||||
}
|
||||
}
|
||||
|
||||
struct CompatSessionLookup {
|
||||
compat_session_id: Uuid,
|
||||
device_id: String,
|
||||
user_id: Uuid,
|
||||
created_at: DateTime<Utc>,
|
||||
finished_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
impl TryFrom<CompatSessionLookup> for CompatSession {
|
||||
type Error = DatabaseInconsistencyError;
|
||||
|
||||
fn try_from(value: CompatSessionLookup) -> Result<Self, Self::Error> {
|
||||
let id = value.compat_session_id.into();
|
||||
let device = Device::try_from(value.device_id).map_err(|e| {
|
||||
DatabaseInconsistencyError::on("compat_sessions")
|
||||
.column("device_id")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let state = match value.finished_at {
|
||||
None => CompatSessionState::Valid,
|
||||
Some(finished_at) => CompatSessionState::Finished { finished_at },
|
||||
};
|
||||
|
||||
let session = CompatSession {
|
||||
id,
|
||||
state,
|
||||
user_id: value.user_id.into(),
|
||||
device,
|
||||
created_at: value.created_at,
|
||||
};
|
||||
|
||||
Ok(session)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_session.lookup",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
compat_session.id = %id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSession>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
CompatSessionLookup,
|
||||
r#"
|
||||
SELECT compat_session_id
|
||||
, device_id
|
||||
, user_id
|
||||
, created_at
|
||||
, finished_at
|
||||
FROM compat_sessions
|
||||
WHERE compat_session_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.try_into()?))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_session.add",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
compat_session.id,
|
||||
%user.id,
|
||||
%user.username,
|
||||
compat_session.device.id = device.as_str(),
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn add(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
user: &User,
|
||||
device: Device,
|
||||
) -> Result<CompatSession, Self::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record("compat_session.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(user.id),
|
||||
device.as_str(),
|
||||
created_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(CompatSession {
|
||||
id,
|
||||
state: CompatSessionState::default(),
|
||||
user_id: user.id,
|
||||
device,
|
||||
created_at,
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_session.finish",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%compat_session.id,
|
||||
user.id = %compat_session.user_id,
|
||||
compat_session.device.id = compat_session.device.as_str(),
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn finish(
|
||||
&mut self,
|
||||
clock: &Clock,
|
||||
compat_session: CompatSession,
|
||||
) -> Result<CompatSession, Self::Error> {
|
||||
let finished_at = clock.now();
|
||||
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE compat_sessions cs
|
||||
SET finished_at = $2
|
||||
WHERE compat_session_id = $1
|
||||
"#,
|
||||
Uuid::from(compat_session.id),
|
||||
finished_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
let compat_session = compat_session
|
||||
.finish(finished_at)
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
|
||||
Ok(compat_session)
|
||||
}
|
||||
}
|
||||
|
@ -13,19 +13,12 @@
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{CompatSession, CompatSsoLogin, CompatSsoLoginState, User};
|
||||
use mas_data_model::{CompatSession, CompatSsoLogin, User};
|
||||
use rand::RngCore;
|
||||
use sqlx::{PgConnection, QueryBuilder};
|
||||
use ulid::Ulid;
|
||||
use url::Url;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
pagination::{Page, QueryBuilderExt},
|
||||
tracing::ExecuteExt,
|
||||
Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Pagination,
|
||||
};
|
||||
use crate::{pagination::Page, Clock, Pagination};
|
||||
|
||||
#[async_trait]
|
||||
pub trait CompatSsoLoginRepository: Send + Sync {
|
||||
@ -71,317 +64,3 @@ pub trait CompatSsoLoginRepository: Send + Sync {
|
||||
pagination: Pagination,
|
||||
) -> Result<Page<CompatSsoLogin>, Self::Error>;
|
||||
}
|
||||
|
||||
pub struct PgCompatSsoLoginRepository<'c> {
|
||||
conn: &'c mut PgConnection,
|
||||
}
|
||||
|
||||
impl<'c> PgCompatSsoLoginRepository<'c> {
|
||||
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
Self { conn }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(sqlx::FromRow)]
|
||||
struct CompatSsoLoginLookup {
|
||||
compat_sso_login_id: Uuid,
|
||||
login_token: String,
|
||||
redirect_uri: String,
|
||||
created_at: DateTime<Utc>,
|
||||
fulfilled_at: Option<DateTime<Utc>>,
|
||||
exchanged_at: Option<DateTime<Utc>>,
|
||||
compat_session_id: Option<Uuid>,
|
||||
}
|
||||
|
||||
impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
|
||||
type Error = DatabaseInconsistencyError;
|
||||
|
||||
fn try_from(res: CompatSsoLoginLookup) -> Result<Self, Self::Error> {
|
||||
let id = res.compat_sso_login_id.into();
|
||||
let redirect_uri = Url::parse(&res.redirect_uri).map_err(|e| {
|
||||
DatabaseInconsistencyError::on("compat_sso_logins")
|
||||
.column("redirect_uri")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let state = match (res.fulfilled_at, res.exchanged_at, res.compat_session_id) {
|
||||
(None, None, None) => CompatSsoLoginState::Pending,
|
||||
(Some(fulfilled_at), None, Some(session_id)) => CompatSsoLoginState::Fulfilled {
|
||||
fulfilled_at,
|
||||
session_id: session_id.into(),
|
||||
},
|
||||
(Some(fulfilled_at), Some(exchanged_at), Some(session_id)) => {
|
||||
CompatSsoLoginState::Exchanged {
|
||||
fulfilled_at,
|
||||
exchanged_at,
|
||||
session_id: session_id.into(),
|
||||
}
|
||||
}
|
||||
_ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
|
||||
};
|
||||
|
||||
Ok(CompatSsoLogin {
|
||||
id,
|
||||
login_token: res.login_token,
|
||||
redirect_uri,
|
||||
created_at: res.created_at,
|
||||
state,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> CompatSsoLoginRepository for PgCompatSsoLoginRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_sso_login.lookup",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
compat_sso_login.id = %id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSsoLogin>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
CompatSsoLoginLookup,
|
||||
r#"
|
||||
SELECT compat_sso_login_id
|
||||
, login_token
|
||||
, redirect_uri
|
||||
, created_at
|
||||
, fulfilled_at
|
||||
, exchanged_at
|
||||
, compat_session_id
|
||||
|
||||
FROM compat_sso_logins
|
||||
WHERE compat_sso_login_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.try_into()?))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_sso_login.find_by_token",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn find_by_token(
|
||||
&mut self,
|
||||
login_token: &str,
|
||||
) -> Result<Option<CompatSsoLogin>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
CompatSsoLoginLookup,
|
||||
r#"
|
||||
SELECT compat_sso_login_id
|
||||
, login_token
|
||||
, redirect_uri
|
||||
, created_at
|
||||
, fulfilled_at
|
||||
, exchanged_at
|
||||
, compat_session_id
|
||||
|
||||
FROM compat_sso_logins
|
||||
WHERE login_token = $1
|
||||
"#,
|
||||
login_token,
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.try_into()?))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_sso_login.add",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
compat_sso_login.id,
|
||||
compat_sso_login.redirect_uri = %redirect_uri,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn add(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
login_token: String,
|
||||
redirect_uri: Url,
|
||||
) -> Result<CompatSsoLogin, Self::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO compat_sso_logins
|
||||
(compat_sso_login_id, login_token, redirect_uri, created_at)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
&login_token,
|
||||
redirect_uri.as_str(),
|
||||
created_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(CompatSsoLogin {
|
||||
id,
|
||||
login_token,
|
||||
redirect_uri,
|
||||
created_at,
|
||||
state: CompatSsoLoginState::default(),
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_sso_login.fulfill",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%compat_sso_login.id,
|
||||
%compat_session.id,
|
||||
compat_session.device.id = compat_session.device.as_str(),
|
||||
user.id = %compat_session.user_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn fulfill(
|
||||
&mut self,
|
||||
clock: &Clock,
|
||||
compat_sso_login: CompatSsoLogin,
|
||||
compat_session: &CompatSession,
|
||||
) -> Result<CompatSsoLogin, Self::Error> {
|
||||
let fulfilled_at = clock.now();
|
||||
let compat_sso_login = compat_sso_login
|
||||
.fulfill(fulfilled_at, compat_session)
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE compat_sso_logins
|
||||
SET
|
||||
compat_session_id = $2,
|
||||
fulfilled_at = $3
|
||||
WHERE
|
||||
compat_sso_login_id = $1
|
||||
"#,
|
||||
Uuid::from(compat_sso_login.id),
|
||||
Uuid::from(compat_session.id),
|
||||
fulfilled_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
Ok(compat_sso_login)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_sso_login.exchange",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%compat_sso_login.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn exchange(
|
||||
&mut self,
|
||||
clock: &Clock,
|
||||
compat_sso_login: CompatSsoLogin,
|
||||
) -> Result<CompatSsoLogin, Self::Error> {
|
||||
let exchanged_at = clock.now();
|
||||
let compat_sso_login = compat_sso_login
|
||||
.exchange(exchanged_at)
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE compat_sso_logins
|
||||
SET
|
||||
exchanged_at = $2
|
||||
WHERE
|
||||
compat_sso_login_id = $1
|
||||
"#,
|
||||
Uuid::from(compat_sso_login.id),
|
||||
exchanged_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
Ok(compat_sso_login)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_sso_login.list_paginated",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user.id,
|
||||
%user.username,
|
||||
),
|
||||
err
|
||||
)]
|
||||
async fn list_paginated(
|
||||
&mut self,
|
||||
user: &User,
|
||||
pagination: Pagination,
|
||||
) -> Result<Page<CompatSsoLogin>, Self::Error> {
|
||||
let mut query = QueryBuilder::new(
|
||||
r#"
|
||||
SELECT cl.compat_sso_login_id
|
||||
, cl.login_token
|
||||
, cl.redirect_uri
|
||||
, cl.created_at
|
||||
, cl.fulfilled_at
|
||||
, cl.exchanged_at
|
||||
, cl.compat_session_id
|
||||
|
||||
FROM compat_sso_logins cl
|
||||
INNER JOIN compat_sessions ON compat_session_id
|
||||
"#,
|
||||
);
|
||||
|
||||
query
|
||||
.push(" WHERE user_id = ")
|
||||
.push_bind(Uuid::from(user.id))
|
||||
.generate_pagination("cl.compat_sso_login_id", pagination);
|
||||
|
||||
let edges: Vec<CompatSsoLoginLookup> = query
|
||||
.build_query_as()
|
||||
.traced()
|
||||
.fetch_all(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
let page = pagination
|
||||
.process(edges)
|
||||
.try_map(CompatSsoLogin::try_from)?;
|
||||
Ok(page)
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -29,150 +29,19 @@
|
||||
)]
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use pagination::InvalidPagination;
|
||||
use sqlx::{migrate::Migrator, postgres::PgQueryResult};
|
||||
use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
|
||||
trait LookupResultExt {
|
||||
type Output;
|
||||
|
||||
/// Transform a [`Result`] from a sqlx query to transform "not found" errors
|
||||
/// into [`None`]
|
||||
fn to_option(self) -> Result<Option<Self::Output>, sqlx::Error>;
|
||||
}
|
||||
|
||||
impl<T> LookupResultExt for Result<T, sqlx::Error> {
|
||||
type Output = T;
|
||||
|
||||
fn to_option(self) -> Result<Option<Self::Output>, sqlx::Error> {
|
||||
match self {
|
||||
Ok(v) => Ok(Some(v)),
|
||||
Err(sqlx::Error::RowNotFound) => Ok(None),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generic error when interacting with the database
|
||||
#[derive(Debug, Error)]
|
||||
#[error(transparent)]
|
||||
pub enum DatabaseError {
|
||||
/// An error which came from the database itself
|
||||
Driver(#[from] sqlx::Error),
|
||||
|
||||
/// An error which occured while converting the data from the database
|
||||
Inconsistency(#[from] DatabaseInconsistencyError),
|
||||
|
||||
/// An error which occured while generating the paginated query
|
||||
Pagination(#[from] InvalidPagination),
|
||||
|
||||
/// An error which happened because the requested database operation is
|
||||
/// invalid
|
||||
#[error("Invalid database operation")]
|
||||
InvalidOperation {
|
||||
#[source]
|
||||
source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
|
||||
},
|
||||
|
||||
/// An error which happens when an operation affects not enough or too many
|
||||
/// rows
|
||||
#[error("Expected {expected} rows to be affected, but {actual} rows were affected")]
|
||||
RowsAffected { expected: u64, actual: u64 },
|
||||
}
|
||||
|
||||
impl DatabaseError {
|
||||
pub(crate) fn ensure_affected_rows(
|
||||
result: &PgQueryResult,
|
||||
expected: u64,
|
||||
) -> Result<(), DatabaseError> {
|
||||
let actual = result.rows_affected();
|
||||
if actual == expected {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(DatabaseError::RowsAffected { expected, actual })
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn to_invalid_operation<E: std::error::Error + Send + Sync + 'static>(e: E) -> Self {
|
||||
Self::InvalidOperation {
|
||||
source: Some(Box::new(e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) const fn invalid_operation() -> Self {
|
||||
Self::InvalidOperation { source: None }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub struct DatabaseInconsistencyError {
|
||||
table: &'static str,
|
||||
column: Option<&'static str>,
|
||||
row: Option<Ulid>,
|
||||
|
||||
#[source]
|
||||
source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for DatabaseInconsistencyError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "Database inconsistency on table {}", self.table)?;
|
||||
if let Some(column) = self.column {
|
||||
write!(f, " column {column}")?;
|
||||
}
|
||||
if let Some(row) = self.row {
|
||||
write!(f, " row {row}")?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl DatabaseInconsistencyError {
|
||||
#[must_use]
|
||||
pub(crate) const fn on(table: &'static str) -> Self {
|
||||
Self {
|
||||
table,
|
||||
column: None,
|
||||
row: None,
|
||||
source: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub(crate) const fn column(mut self, column: &'static str) -> Self {
|
||||
self.column = Some(column);
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub(crate) const fn row(mut self, row: Ulid) -> Self {
|
||||
self.row = Some(row);
|
||||
self
|
||||
}
|
||||
|
||||
pub(crate) fn source<E: std::error::Error + Send + Sync + 'static>(
|
||||
mut self,
|
||||
source: E,
|
||||
) -> Self {
|
||||
self.source = Some(Box::new(source));
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct Clock {
|
||||
_private: (),
|
||||
|
||||
#[cfg(test)]
|
||||
// #[cfg(test)]
|
||||
mock: Option<std::sync::Arc<std::sync::atomic::AtomicI64>>,
|
||||
}
|
||||
|
||||
impl Clock {
|
||||
#[must_use]
|
||||
pub fn now(&self) -> DateTime<Utc> {
|
||||
#[cfg(test)]
|
||||
// #[cfg(test)]
|
||||
if let Some(timestamp) = &self.mock {
|
||||
let timestamp = timestamp.load(std::sync::atomic::Ordering::Relaxed);
|
||||
return chrono::TimeZone::timestamp_opt(&Utc, timestamp, 0).unwrap();
|
||||
@ -183,13 +52,14 @@ impl Clock {
|
||||
Utc::now()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
// #[cfg(test)]
|
||||
#[must_use]
|
||||
pub fn mock() -> Self {
|
||||
use std::sync::{atomic::AtomicI64, Arc};
|
||||
|
||||
use chrono::TimeZone;
|
||||
|
||||
let datetime = Utc.with_ymd_and_hms(2022, 01, 16, 14, 40, 0).unwrap();
|
||||
let datetime = Utc.with_ymd_and_hms(2022, 1, 16, 14, 40, 0).unwrap();
|
||||
let timestamp = datetime.timestamp();
|
||||
|
||||
Self {
|
||||
@ -198,7 +68,7 @@ impl Clock {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
// #[cfg(test)]
|
||||
pub fn advance(&self, duration: chrono::Duration) {
|
||||
let timestamp = self
|
||||
.mock
|
||||
@ -247,16 +117,12 @@ mod tests {
|
||||
|
||||
pub mod compat;
|
||||
pub mod oauth2;
|
||||
pub(crate) mod pagination;
|
||||
pub mod pagination;
|
||||
pub(crate) mod repository;
|
||||
pub(crate) mod tracing;
|
||||
pub mod upstream_oauth2;
|
||||
pub mod user;
|
||||
|
||||
pub use self::{
|
||||
pagination::Pagination,
|
||||
repository::{PgRepository, Repository},
|
||||
pagination::{Page, Pagination},
|
||||
repository::Repository,
|
||||
};
|
||||
|
||||
/// Embedded migrations, allowing them to run on startup
|
||||
pub static MIGRATOR: Migrator = sqlx::migrate!();
|
||||
|
@ -13,14 +13,12 @@
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use mas_data_model::{AccessToken, AccessTokenState, Session};
|
||||
use chrono::Duration;
|
||||
use mas_data_model::{AccessToken, Session};
|
||||
use rand::RngCore;
|
||||
use sqlx::PgConnection;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt};
|
||||
use crate::Clock;
|
||||
|
||||
#[async_trait]
|
||||
pub trait OAuth2AccessTokenRepository: Send + Sync {
|
||||
@ -55,202 +53,3 @@ pub trait OAuth2AccessTokenRepository: Send + Sync {
|
||||
/// Cleanup expired access tokens
|
||||
async fn cleanup_expired(&mut self, clock: &Clock) -> Result<usize, Self::Error>;
|
||||
}
|
||||
|
||||
pub struct PgOAuth2AccessTokenRepository<'c> {
|
||||
conn: &'c mut PgConnection,
|
||||
}
|
||||
|
||||
impl<'c> PgOAuth2AccessTokenRepository<'c> {
|
||||
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
Self { conn }
|
||||
}
|
||||
}
|
||||
|
||||
struct OAuth2AccessTokenLookup {
|
||||
oauth2_access_token_id: Uuid,
|
||||
oauth2_session_id: Uuid,
|
||||
access_token: String,
|
||||
created_at: DateTime<Utc>,
|
||||
expires_at: DateTime<Utc>,
|
||||
revoked_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
impl From<OAuth2AccessTokenLookup> for AccessToken {
|
||||
fn from(value: OAuth2AccessTokenLookup) -> Self {
|
||||
let state = match value.revoked_at {
|
||||
None => AccessTokenState::Valid,
|
||||
Some(revoked_at) => AccessTokenState::Revoked { revoked_at },
|
||||
};
|
||||
|
||||
Self {
|
||||
id: value.oauth2_access_token_id.into(),
|
||||
state,
|
||||
session_id: value.oauth2_session_id.into(),
|
||||
access_token: value.access_token,
|
||||
created_at: value.created_at,
|
||||
expires_at: value.expires_at,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> OAuth2AccessTokenRepository for PgOAuth2AccessTokenRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<AccessToken>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
OAuth2AccessTokenLookup,
|
||||
r#"
|
||||
SELECT oauth2_access_token_id
|
||||
, access_token
|
||||
, created_at
|
||||
, expires_at
|
||||
, revoked_at
|
||||
, oauth2_session_id
|
||||
|
||||
FROM oauth2_access_tokens
|
||||
|
||||
WHERE oauth2_access_token_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.into()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_access_token.find_by_token",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn find_by_token(
|
||||
&mut self,
|
||||
access_token: &str,
|
||||
) -> Result<Option<AccessToken>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
OAuth2AccessTokenLookup,
|
||||
r#"
|
||||
SELECT oauth2_access_token_id
|
||||
, access_token
|
||||
, created_at
|
||||
, expires_at
|
||||
, revoked_at
|
||||
, oauth2_session_id
|
||||
|
||||
FROM oauth2_access_tokens
|
||||
|
||||
WHERE access_token = $1
|
||||
"#,
|
||||
access_token,
|
||||
)
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.into()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_access_token.add",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%session.id,
|
||||
user_session.id = %session.user_session_id,
|
||||
client.id = %session.client_id,
|
||||
access_token.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn add(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
session: &Session,
|
||||
access_token: String,
|
||||
expires_after: Duration,
|
||||
) -> Result<AccessToken, Self::Error> {
|
||||
let created_at = clock.now();
|
||||
let expires_at = created_at + expires_after;
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
|
||||
tracing::Span::current().record("access_token.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO oauth2_access_tokens
|
||||
(oauth2_access_token_id, oauth2_session_id, access_token, created_at, expires_at)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(session.id),
|
||||
&access_token,
|
||||
created_at,
|
||||
expires_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(AccessToken {
|
||||
id,
|
||||
state: AccessTokenState::default(),
|
||||
access_token,
|
||||
session_id: session.id,
|
||||
created_at,
|
||||
expires_at,
|
||||
})
|
||||
}
|
||||
|
||||
async fn revoke(
|
||||
&mut self,
|
||||
clock: &Clock,
|
||||
access_token: AccessToken,
|
||||
) -> Result<AccessToken, Self::Error> {
|
||||
let revoked_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE oauth2_access_tokens
|
||||
SET revoked_at = $2
|
||||
WHERE oauth2_access_token_id = $1
|
||||
"#,
|
||||
Uuid::from(access_token.id),
|
||||
revoked_at,
|
||||
)
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
access_token
|
||||
.revoke(revoked_at)
|
||||
.map_err(DatabaseError::to_invalid_operation)
|
||||
}
|
||||
|
||||
async fn cleanup_expired(&mut self, clock: &Clock) -> Result<usize, Self::Error> {
|
||||
// Cleanup token which expired more than 15 minutes ago
|
||||
let threshold = clock.now() - Duration::minutes(15);
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
DELETE FROM oauth2_access_tokens
|
||||
WHERE expires_at < $1
|
||||
"#,
|
||||
threshold,
|
||||
)
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -15,21 +15,13 @@
|
||||
use std::num::NonZeroU32;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{
|
||||
AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Pkce, Session,
|
||||
};
|
||||
use mas_iana::oauth::PkceCodeChallengeMethod;
|
||||
use mas_data_model::{AuthorizationCode, AuthorizationGrant, Client, Session};
|
||||
use oauth2_types::{requests::ResponseMode, scope::Scope};
|
||||
use rand::RngCore;
|
||||
use sqlx::PgConnection;
|
||||
use ulid::Ulid;
|
||||
use url::Url;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
|
||||
};
|
||||
use crate::Clock;
|
||||
|
||||
#[async_trait]
|
||||
pub trait OAuth2AuthorizationGrantRepository: Send + Sync {
|
||||
@ -75,482 +67,3 @@ pub trait OAuth2AuthorizationGrantRepository: Send + Sync {
|
||||
authorization_grant: AuthorizationGrant,
|
||||
) -> Result<AuthorizationGrant, Self::Error>;
|
||||
}
|
||||
|
||||
pub struct PgOAuth2AuthorizationGrantRepository<'c> {
|
||||
conn: &'c mut PgConnection,
|
||||
}
|
||||
|
||||
impl<'c> PgOAuth2AuthorizationGrantRepository<'c> {
|
||||
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
Self { conn }
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::struct_excessive_bools)]
|
||||
struct GrantLookup {
|
||||
oauth2_authorization_grant_id: Uuid,
|
||||
created_at: DateTime<Utc>,
|
||||
cancelled_at: Option<DateTime<Utc>>,
|
||||
fulfilled_at: Option<DateTime<Utc>>,
|
||||
exchanged_at: Option<DateTime<Utc>>,
|
||||
scope: String,
|
||||
state: Option<String>,
|
||||
nonce: Option<String>,
|
||||
redirect_uri: String,
|
||||
response_mode: String,
|
||||
max_age: Option<i32>,
|
||||
response_type_code: bool,
|
||||
response_type_id_token: bool,
|
||||
authorization_code: Option<String>,
|
||||
code_challenge: Option<String>,
|
||||
code_challenge_method: Option<String>,
|
||||
requires_consent: bool,
|
||||
oauth2_client_id: Uuid,
|
||||
oauth2_session_id: Option<Uuid>,
|
||||
}
|
||||
|
||||
impl TryFrom<GrantLookup> for AuthorizationGrant {
|
||||
type Error = DatabaseInconsistencyError;
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
fn try_from(value: GrantLookup) -> Result<Self, Self::Error> {
|
||||
let id = value.oauth2_authorization_grant_id.into();
|
||||
let scope: Scope = value.scope.parse().map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_authorization_grants")
|
||||
.column("scope")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let stage = match (
|
||||
value.fulfilled_at,
|
||||
value.exchanged_at,
|
||||
value.cancelled_at,
|
||||
value.oauth2_session_id,
|
||||
) {
|
||||
(None, None, None, None) => AuthorizationGrantStage::Pending,
|
||||
(Some(fulfilled_at), None, None, Some(session_id)) => {
|
||||
AuthorizationGrantStage::Fulfilled {
|
||||
session_id: session_id.into(),
|
||||
fulfilled_at,
|
||||
}
|
||||
}
|
||||
(Some(fulfilled_at), Some(exchanged_at), None, Some(session_id)) => {
|
||||
AuthorizationGrantStage::Exchanged {
|
||||
session_id: session_id.into(),
|
||||
fulfilled_at,
|
||||
exchanged_at,
|
||||
}
|
||||
}
|
||||
(None, None, Some(cancelled_at), None) => {
|
||||
AuthorizationGrantStage::Cancelled { cancelled_at }
|
||||
}
|
||||
_ => {
|
||||
return Err(
|
||||
DatabaseInconsistencyError::on("oauth2_authorization_grants")
|
||||
.column("stage")
|
||||
.row(id),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let pkce = match (value.code_challenge, value.code_challenge_method) {
|
||||
(Some(challenge), Some(challenge_method)) if challenge_method == "plain" => {
|
||||
Some(Pkce {
|
||||
challenge_method: PkceCodeChallengeMethod::Plain,
|
||||
challenge,
|
||||
})
|
||||
}
|
||||
(Some(challenge), Some(challenge_method)) if challenge_method == "S256" => Some(Pkce {
|
||||
challenge_method: PkceCodeChallengeMethod::S256,
|
||||
challenge,
|
||||
}),
|
||||
(None, None) => None,
|
||||
_ => {
|
||||
return Err(
|
||||
DatabaseInconsistencyError::on("oauth2_authorization_grants")
|
||||
.column("code_challenge_method")
|
||||
.row(id),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let code: Option<AuthorizationCode> =
|
||||
match (value.response_type_code, value.authorization_code, pkce) {
|
||||
(false, None, None) => None,
|
||||
(true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }),
|
||||
_ => {
|
||||
return Err(
|
||||
DatabaseInconsistencyError::on("oauth2_authorization_grants")
|
||||
.column("authorization_code")
|
||||
.row(id),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let redirect_uri = value.redirect_uri.parse().map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_authorization_grants")
|
||||
.column("redirect_uri")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let response_mode = value.response_mode.parse().map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_authorization_grants")
|
||||
.column("response_mode")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let max_age = value
|
||||
.max_age
|
||||
.map(u32::try_from)
|
||||
.transpose()
|
||||
.map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_authorization_grants")
|
||||
.column("max_age")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?
|
||||
.map(NonZeroU32::try_from)
|
||||
.transpose()
|
||||
.map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_authorization_grants")
|
||||
.column("max_age")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
Ok(AuthorizationGrant {
|
||||
id,
|
||||
stage,
|
||||
client_id: value.oauth2_client_id.into(),
|
||||
code,
|
||||
scope,
|
||||
state: value.state,
|
||||
nonce: value.nonce,
|
||||
max_age,
|
||||
response_mode,
|
||||
redirect_uri,
|
||||
created_at: value.created_at,
|
||||
response_type_id_token: value.response_type_id_token,
|
||||
requires_consent: value.requires_consent,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_authorization_grant.add",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
grant.id,
|
||||
grant.scope = %scope,
|
||||
%client.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn add(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
client: &Client,
|
||||
redirect_uri: Url,
|
||||
scope: Scope,
|
||||
code: Option<AuthorizationCode>,
|
||||
state: Option<String>,
|
||||
nonce: Option<String>,
|
||||
max_age: Option<NonZeroU32>,
|
||||
response_mode: ResponseMode,
|
||||
response_type_id_token: bool,
|
||||
requires_consent: bool,
|
||||
) -> Result<AuthorizationGrant, Self::Error> {
|
||||
let code_challenge = code
|
||||
.as_ref()
|
||||
.and_then(|c| c.pkce.as_ref())
|
||||
.map(|p| &p.challenge);
|
||||
let code_challenge_method = code
|
||||
.as_ref()
|
||||
.and_then(|c| c.pkce.as_ref())
|
||||
.map(|p| p.challenge_method.to_string());
|
||||
// TODO: this conversion is a bit ugly
|
||||
let max_age_i32 = max_age.map(|x| i32::try_from(u32::from(x)).unwrap_or(i32::MAX));
|
||||
let code_str = code.as_ref().map(|c| &c.code);
|
||||
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record("grant.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO oauth2_authorization_grants (
|
||||
oauth2_authorization_grant_id,
|
||||
oauth2_client_id,
|
||||
redirect_uri,
|
||||
scope,
|
||||
state,
|
||||
nonce,
|
||||
max_age,
|
||||
response_mode,
|
||||
code_challenge,
|
||||
code_challenge_method,
|
||||
response_type_code,
|
||||
response_type_id_token,
|
||||
authorization_code,
|
||||
requires_consent,
|
||||
created_at
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(client.id),
|
||||
redirect_uri.to_string(),
|
||||
scope.to_string(),
|
||||
state,
|
||||
nonce,
|
||||
max_age_i32,
|
||||
response_mode.to_string(),
|
||||
code_challenge,
|
||||
code_challenge_method,
|
||||
code.is_some(),
|
||||
response_type_id_token,
|
||||
code_str,
|
||||
requires_consent,
|
||||
created_at,
|
||||
)
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(AuthorizationGrant {
|
||||
id,
|
||||
stage: AuthorizationGrantStage::Pending,
|
||||
code,
|
||||
redirect_uri,
|
||||
client_id: client.id,
|
||||
scope,
|
||||
state,
|
||||
nonce,
|
||||
max_age,
|
||||
response_mode,
|
||||
created_at,
|
||||
response_type_id_token,
|
||||
requires_consent,
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_authorization_grant.lookup",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
grant.id = %id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<AuthorizationGrant>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
GrantLookup,
|
||||
r#"
|
||||
SELECT oauth2_authorization_grant_id
|
||||
, created_at
|
||||
, cancelled_at
|
||||
, fulfilled_at
|
||||
, exchanged_at
|
||||
, scope
|
||||
, state
|
||||
, redirect_uri
|
||||
, response_mode
|
||||
, nonce
|
||||
, max_age
|
||||
, oauth2_client_id
|
||||
, authorization_code
|
||||
, response_type_code
|
||||
, response_type_id_token
|
||||
, code_challenge
|
||||
, code_challenge_method
|
||||
, requires_consent
|
||||
, oauth2_session_id
|
||||
FROM
|
||||
oauth2_authorization_grants
|
||||
|
||||
WHERE oauth2_authorization_grant_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.try_into()?))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_authorization_grant.find_by_code",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn find_by_code(
|
||||
&mut self,
|
||||
code: &str,
|
||||
) -> Result<Option<AuthorizationGrant>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
GrantLookup,
|
||||
r#"
|
||||
SELECT oauth2_authorization_grant_id
|
||||
, created_at
|
||||
, cancelled_at
|
||||
, fulfilled_at
|
||||
, exchanged_at
|
||||
, scope
|
||||
, state
|
||||
, redirect_uri
|
||||
, response_mode
|
||||
, nonce
|
||||
, max_age
|
||||
, oauth2_client_id
|
||||
, authorization_code
|
||||
, response_type_code
|
||||
, response_type_id_token
|
||||
, code_challenge
|
||||
, code_challenge_method
|
||||
, requires_consent
|
||||
, oauth2_session_id
|
||||
FROM
|
||||
oauth2_authorization_grants
|
||||
|
||||
WHERE authorization_code = $1
|
||||
"#,
|
||||
code,
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.try_into()?))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_authorization_grant.fulfill",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%grant.id,
|
||||
client.id = %grant.client_id,
|
||||
%session.id,
|
||||
user_session.id = %session.user_session_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn fulfill(
|
||||
&mut self,
|
||||
clock: &Clock,
|
||||
session: &Session,
|
||||
grant: AuthorizationGrant,
|
||||
) -> Result<AuthorizationGrant, Self::Error> {
|
||||
let fulfilled_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE oauth2_authorization_grants
|
||||
SET fulfilled_at = $2
|
||||
, oauth2_session_id = $3
|
||||
WHERE oauth2_authorization_grant_id = $1
|
||||
"#,
|
||||
Uuid::from(grant.id),
|
||||
fulfilled_at,
|
||||
Uuid::from(session.id),
|
||||
)
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
// XXX: check affected rows & new methods
|
||||
let grant = grant
|
||||
.fulfill(fulfilled_at, session)
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
|
||||
Ok(grant)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_authorization_grant.exchange",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%grant.id,
|
||||
client.id = %grant.client_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn exchange(
|
||||
&mut self,
|
||||
clock: &Clock,
|
||||
grant: AuthorizationGrant,
|
||||
) -> Result<AuthorizationGrant, Self::Error> {
|
||||
let exchanged_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE oauth2_authorization_grants
|
||||
SET exchanged_at = $2
|
||||
WHERE oauth2_authorization_grant_id = $1
|
||||
"#,
|
||||
Uuid::from(grant.id),
|
||||
exchanged_at,
|
||||
)
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
let grant = grant
|
||||
.exchange(exchanged_at)
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
|
||||
Ok(grant)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_authorization_grant.give_consent",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%grant.id,
|
||||
client.id = %grant.client_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn give_consent(
|
||||
&mut self,
|
||||
mut grant: AuthorizationGrant,
|
||||
) -> Result<AuthorizationGrant, Self::Error> {
|
||||
sqlx::query!(
|
||||
r#"
|
||||
UPDATE oauth2_authorization_grants AS og
|
||||
SET
|
||||
requires_consent = 'f'
|
||||
WHERE
|
||||
og.oauth2_authorization_grant_id = $1
|
||||
"#,
|
||||
Uuid::from(grant.id),
|
||||
)
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
grant.requires_consent = false;
|
||||
|
||||
Ok(grant)
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -12,33 +12,18 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{
|
||||
collections::{BTreeMap, BTreeSet},
|
||||
str::FromStr,
|
||||
string::ToString,
|
||||
};
|
||||
use std::collections::{BTreeMap, BTreeSet};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use mas_data_model::{Client, JwksOrJwksUri, User};
|
||||
use mas_iana::{
|
||||
jose::JsonWebSignatureAlg,
|
||||
oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod},
|
||||
};
|
||||
use mas_data_model::{Client, User};
|
||||
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
|
||||
use mas_jose::jwk::PublicJsonWebKeySet;
|
||||
use oauth2_types::{
|
||||
requests::GrantType,
|
||||
scope::{Scope, ScopeToken},
|
||||
};
|
||||
use oauth2_types::{requests::GrantType, scope::Scope};
|
||||
use rand::{Rng, RngCore};
|
||||
use sqlx::PgConnection;
|
||||
use tracing::{info_span, Instrument};
|
||||
use ulid::Ulid;
|
||||
use url::Url;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
|
||||
};
|
||||
use crate::Clock;
|
||||
|
||||
#[async_trait]
|
||||
pub trait OAuth2ClientRepository: Send + Sync {
|
||||
@ -107,708 +92,3 @@ pub trait OAuth2ClientRepository: Send + Sync {
|
||||
scope: &Scope,
|
||||
) -> Result<(), Self::Error>;
|
||||
}
|
||||
|
||||
pub struct PgOAuth2ClientRepository<'c> {
|
||||
conn: &'c mut PgConnection,
|
||||
}
|
||||
|
||||
impl<'c> PgOAuth2ClientRepository<'c> {
|
||||
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
Self { conn }
|
||||
}
|
||||
}
|
||||
|
||||
// XXX: response_types & contacts
|
||||
#[derive(Debug)]
|
||||
struct OAuth2ClientLookup {
|
||||
oauth2_client_id: Uuid,
|
||||
encrypted_client_secret: Option<String>,
|
||||
redirect_uris: Vec<String>,
|
||||
// response_types: Vec<String>,
|
||||
grant_type_authorization_code: bool,
|
||||
grant_type_refresh_token: bool,
|
||||
// contacts: Vec<String>,
|
||||
client_name: Option<String>,
|
||||
logo_uri: Option<String>,
|
||||
client_uri: Option<String>,
|
||||
policy_uri: Option<String>,
|
||||
tos_uri: Option<String>,
|
||||
jwks_uri: Option<String>,
|
||||
jwks: Option<serde_json::Value>,
|
||||
id_token_signed_response_alg: Option<String>,
|
||||
userinfo_signed_response_alg: Option<String>,
|
||||
token_endpoint_auth_method: Option<String>,
|
||||
token_endpoint_auth_signing_alg: Option<String>,
|
||||
initiate_login_uri: Option<String>,
|
||||
}
|
||||
|
||||
impl TryInto<Client> for OAuth2ClientLookup {
|
||||
type Error = DatabaseInconsistencyError;
|
||||
|
||||
#[allow(clippy::too_many_lines)] // TODO: refactor some of the field parsing
|
||||
fn try_into(self) -> Result<Client, Self::Error> {
|
||||
let id = Ulid::from(self.oauth2_client_id);
|
||||
|
||||
let redirect_uris: Result<Vec<Url>, _> =
|
||||
self.redirect_uris.iter().map(|s| s.parse()).collect();
|
||||
let redirect_uris = redirect_uris.map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_clients")
|
||||
.column("redirect_uris")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let response_types = vec![
|
||||
OAuthAuthorizationEndpointResponseType::Code,
|
||||
OAuthAuthorizationEndpointResponseType::IdToken,
|
||||
OAuthAuthorizationEndpointResponseType::None,
|
||||
];
|
||||
/* XXX
|
||||
let response_types: Result<Vec<OAuthAuthorizationEndpointResponseType>, _> =
|
||||
self.response_types.iter().map(|s| s.parse()).collect();
|
||||
let response_types = response_types.map_err(|source| ClientFetchError::ParseField {
|
||||
field: "response_types",
|
||||
source,
|
||||
})?;
|
||||
*/
|
||||
|
||||
let mut grant_types = Vec::new();
|
||||
if self.grant_type_authorization_code {
|
||||
grant_types.push(GrantType::AuthorizationCode);
|
||||
}
|
||||
if self.grant_type_refresh_token {
|
||||
grant_types.push(GrantType::RefreshToken);
|
||||
}
|
||||
|
||||
let logo_uri = self.logo_uri.map(|s| s.parse()).transpose().map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_clients")
|
||||
.column("logo_uri")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let client_uri = self
|
||||
.client_uri
|
||||
.map(|s| s.parse())
|
||||
.transpose()
|
||||
.map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_clients")
|
||||
.column("client_uri")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let policy_uri = self
|
||||
.policy_uri
|
||||
.map(|s| s.parse())
|
||||
.transpose()
|
||||
.map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_clients")
|
||||
.column("policy_uri")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let tos_uri = self.tos_uri.map(|s| s.parse()).transpose().map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_clients")
|
||||
.column("tos_uri")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let id_token_signed_response_alg = self
|
||||
.id_token_signed_response_alg
|
||||
.map(|s| s.parse())
|
||||
.transpose()
|
||||
.map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_clients")
|
||||
.column("id_token_signed_response_alg")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let userinfo_signed_response_alg = self
|
||||
.userinfo_signed_response_alg
|
||||
.map(|s| s.parse())
|
||||
.transpose()
|
||||
.map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_clients")
|
||||
.column("userinfo_signed_response_alg")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let token_endpoint_auth_method = self
|
||||
.token_endpoint_auth_method
|
||||
.map(|s| s.parse())
|
||||
.transpose()
|
||||
.map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_clients")
|
||||
.column("token_endpoint_auth_method")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let token_endpoint_auth_signing_alg = self
|
||||
.token_endpoint_auth_signing_alg
|
||||
.map(|s| s.parse())
|
||||
.transpose()
|
||||
.map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_clients")
|
||||
.column("token_endpoint_auth_signing_alg")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let initiate_login_uri = self
|
||||
.initiate_login_uri
|
||||
.map(|s| s.parse())
|
||||
.transpose()
|
||||
.map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_clients")
|
||||
.column("initiate_login_uri")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let jwks = match (self.jwks, self.jwks_uri) {
|
||||
(None, None) => None,
|
||||
(Some(jwks), None) => {
|
||||
let jwks = serde_json::from_value(jwks).map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_clients")
|
||||
.column("jwks")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
Some(JwksOrJwksUri::Jwks(jwks))
|
||||
}
|
||||
(None, Some(jwks_uri)) => {
|
||||
let jwks_uri = jwks_uri.parse().map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_clients")
|
||||
.column("jwks_uri")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
Some(JwksOrJwksUri::JwksUri(jwks_uri))
|
||||
}
|
||||
_ => {
|
||||
return Err(DatabaseInconsistencyError::on("oauth2_clients")
|
||||
.column("jwks(_uri)")
|
||||
.row(id))
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Client {
|
||||
id,
|
||||
client_id: id.to_string(),
|
||||
encrypted_client_secret: self.encrypted_client_secret,
|
||||
redirect_uris,
|
||||
response_types,
|
||||
grant_types,
|
||||
// contacts: self.contacts,
|
||||
contacts: vec![],
|
||||
client_name: self.client_name,
|
||||
logo_uri,
|
||||
client_uri,
|
||||
policy_uri,
|
||||
tos_uri,
|
||||
jwks,
|
||||
id_token_signed_response_alg,
|
||||
userinfo_signed_response_alg,
|
||||
token_endpoint_auth_method,
|
||||
token_endpoint_auth_signing_alg,
|
||||
initiate_login_uri,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_client.lookup",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
oauth2_client.id = %id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<Client>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
OAuth2ClientLookup,
|
||||
r#"
|
||||
SELECT oauth2_client_id
|
||||
, encrypted_client_secret
|
||||
, ARRAY(
|
||||
SELECT redirect_uri
|
||||
FROM oauth2_client_redirect_uris r
|
||||
WHERE r.oauth2_client_id = c.oauth2_client_id
|
||||
) AS "redirect_uris!"
|
||||
, grant_type_authorization_code
|
||||
, grant_type_refresh_token
|
||||
, client_name
|
||||
, logo_uri
|
||||
, client_uri
|
||||
, policy_uri
|
||||
, tos_uri
|
||||
, jwks_uri
|
||||
, jwks
|
||||
, id_token_signed_response_alg
|
||||
, userinfo_signed_response_alg
|
||||
, token_endpoint_auth_method
|
||||
, token_endpoint_auth_signing_alg
|
||||
, initiate_login_uri
|
||||
FROM oauth2_clients c
|
||||
|
||||
WHERE oauth2_client_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.try_into()?))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_client.load_batch",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn load_batch(
|
||||
&mut self,
|
||||
ids: BTreeSet<Ulid>,
|
||||
) -> Result<BTreeMap<Ulid, Client>, Self::Error> {
|
||||
let ids: Vec<Uuid> = ids.into_iter().map(Uuid::from).collect();
|
||||
let res = sqlx::query_as!(
|
||||
OAuth2ClientLookup,
|
||||
r#"
|
||||
SELECT oauth2_client_id
|
||||
, encrypted_client_secret
|
||||
, ARRAY(
|
||||
SELECT redirect_uri
|
||||
FROM oauth2_client_redirect_uris r
|
||||
WHERE r.oauth2_client_id = c.oauth2_client_id
|
||||
) AS "redirect_uris!"
|
||||
, grant_type_authorization_code
|
||||
, grant_type_refresh_token
|
||||
, client_name
|
||||
, logo_uri
|
||||
, client_uri
|
||||
, policy_uri
|
||||
, tos_uri
|
||||
, jwks_uri
|
||||
, jwks
|
||||
, id_token_signed_response_alg
|
||||
, userinfo_signed_response_alg
|
||||
, token_endpoint_auth_method
|
||||
, token_endpoint_auth_signing_alg
|
||||
, initiate_login_uri
|
||||
FROM oauth2_clients c
|
||||
|
||||
WHERE oauth2_client_id = ANY($1::uuid[])
|
||||
"#,
|
||||
&ids,
|
||||
)
|
||||
.traced()
|
||||
.fetch_all(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
res.into_iter()
|
||||
.map(|r| {
|
||||
r.try_into()
|
||||
.map(|c: Client| (c.id, c))
|
||||
.map_err(DatabaseError::from)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_client.add",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
client.id,
|
||||
client.name = client_name
|
||||
),
|
||||
err,
|
||||
)]
|
||||
#[allow(clippy::too_many_lines)]
|
||||
async fn add(
|
||||
&mut self,
|
||||
mut rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
redirect_uris: Vec<Url>,
|
||||
encrypted_client_secret: Option<String>,
|
||||
grant_types: Vec<GrantType>,
|
||||
contacts: Vec<String>,
|
||||
client_name: Option<String>,
|
||||
logo_uri: Option<Url>,
|
||||
client_uri: Option<Url>,
|
||||
policy_uri: Option<Url>,
|
||||
tos_uri: Option<Url>,
|
||||
jwks_uri: Option<Url>,
|
||||
jwks: Option<PublicJsonWebKeySet>,
|
||||
id_token_signed_response_alg: Option<JsonWebSignatureAlg>,
|
||||
userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
|
||||
token_endpoint_auth_method: Option<OAuthClientAuthenticationMethod>,
|
||||
token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
|
||||
initiate_login_uri: Option<Url>,
|
||||
) -> Result<Client, Self::Error> {
|
||||
let now = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(now.into(), rng);
|
||||
tracing::Span::current().record("client.id", tracing::field::display(id));
|
||||
|
||||
let jwks_json = jwks
|
||||
.as_ref()
|
||||
.map(serde_json::to_value)
|
||||
.transpose()
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO oauth2_clients
|
||||
( oauth2_client_id
|
||||
, encrypted_client_secret
|
||||
, grant_type_authorization_code
|
||||
, grant_type_refresh_token
|
||||
, client_name
|
||||
, logo_uri
|
||||
, client_uri
|
||||
, policy_uri
|
||||
, tos_uri
|
||||
, jwks_uri
|
||||
, jwks
|
||||
, id_token_signed_response_alg
|
||||
, userinfo_signed_response_alg
|
||||
, token_endpoint_auth_method
|
||||
, token_endpoint_auth_signing_alg
|
||||
, initiate_login_uri
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
encrypted_client_secret,
|
||||
grant_types.contains(&GrantType::AuthorizationCode),
|
||||
grant_types.contains(&GrantType::RefreshToken),
|
||||
client_name,
|
||||
logo_uri.as_ref().map(Url::as_str),
|
||||
client_uri.as_ref().map(Url::as_str),
|
||||
policy_uri.as_ref().map(Url::as_str),
|
||||
tos_uri.as_ref().map(Url::as_str),
|
||||
jwks_uri.as_ref().map(Url::as_str),
|
||||
jwks_json,
|
||||
id_token_signed_response_alg
|
||||
.as_ref()
|
||||
.map(ToString::to_string),
|
||||
userinfo_signed_response_alg
|
||||
.as_ref()
|
||||
.map(ToString::to_string),
|
||||
token_endpoint_auth_method.as_ref().map(ToString::to_string),
|
||||
token_endpoint_auth_signing_alg
|
||||
.as_ref()
|
||||
.map(ToString::to_string),
|
||||
initiate_login_uri.as_ref().map(Url::as_str),
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
{
|
||||
let span = info_span!(
|
||||
"db.oauth2_client.add.redirect_uris",
|
||||
db.statement = tracing::field::Empty,
|
||||
client.id = %id,
|
||||
);
|
||||
|
||||
let (uri_ids, redirect_uris): (Vec<Uuid>, Vec<String>) = redirect_uris
|
||||
.iter()
|
||||
.map(|uri| {
|
||||
(
|
||||
Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)),
|
||||
uri.as_str().to_owned(),
|
||||
)
|
||||
})
|
||||
.unzip();
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO oauth2_client_redirect_uris
|
||||
(oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri)
|
||||
SELECT id, $2, redirect_uri
|
||||
FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri)
|
||||
"#,
|
||||
&uri_ids,
|
||||
Uuid::from(id),
|
||||
&redirect_uris,
|
||||
)
|
||||
.record(&span)
|
||||
.execute(&mut *self.conn)
|
||||
.instrument(span)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let jwks = match (jwks, jwks_uri) {
|
||||
(None, None) => None,
|
||||
(Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)),
|
||||
(None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)),
|
||||
_ => return Err(DatabaseError::invalid_operation()),
|
||||
};
|
||||
|
||||
Ok(Client {
|
||||
id,
|
||||
client_id: id.to_string(),
|
||||
encrypted_client_secret,
|
||||
redirect_uris,
|
||||
response_types: vec![
|
||||
OAuthAuthorizationEndpointResponseType::Code,
|
||||
OAuthAuthorizationEndpointResponseType::IdToken,
|
||||
OAuthAuthorizationEndpointResponseType::None,
|
||||
],
|
||||
grant_types,
|
||||
contacts,
|
||||
client_name,
|
||||
logo_uri,
|
||||
client_uri,
|
||||
policy_uri,
|
||||
tos_uri,
|
||||
jwks,
|
||||
id_token_signed_response_alg,
|
||||
userinfo_signed_response_alg,
|
||||
token_endpoint_auth_method,
|
||||
token_endpoint_auth_signing_alg,
|
||||
initiate_login_uri,
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_client.add_from_config",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
client.id = %client_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn add_from_config(
|
||||
&mut self,
|
||||
mut rng: impl Rng + Send,
|
||||
clock: &Clock,
|
||||
client_id: Ulid,
|
||||
client_auth_method: OAuthClientAuthenticationMethod,
|
||||
encrypted_client_secret: Option<String>,
|
||||
jwks: Option<PublicJsonWebKeySet>,
|
||||
jwks_uri: Option<Url>,
|
||||
redirect_uris: Vec<Url>,
|
||||
) -> Result<Client, Self::Error> {
|
||||
let jwks_json = jwks
|
||||
.as_ref()
|
||||
.map(serde_json::to_value)
|
||||
.transpose()
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
|
||||
let client_auth_method = client_auth_method.to_string();
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO oauth2_clients
|
||||
( oauth2_client_id
|
||||
, encrypted_client_secret
|
||||
, grant_type_authorization_code
|
||||
, grant_type_refresh_token
|
||||
, token_endpoint_auth_method
|
||||
, jwks
|
||||
, jwks_uri
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7)
|
||||
ON CONFLICT (oauth2_client_id)
|
||||
DO
|
||||
UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret
|
||||
, grant_type_authorization_code = EXCLUDED.grant_type_authorization_code
|
||||
, grant_type_refresh_token = EXCLUDED.grant_type_refresh_token
|
||||
, token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method
|
||||
, jwks = EXCLUDED.jwks
|
||||
, jwks_uri = EXCLUDED.jwks_uri
|
||||
"#,
|
||||
Uuid::from(client_id),
|
||||
encrypted_client_secret,
|
||||
true,
|
||||
true,
|
||||
client_auth_method,
|
||||
jwks_json,
|
||||
jwks_uri.as_ref().map(Url::as_str),
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
{
|
||||
let span = info_span!(
|
||||
"db.oauth2_client.add_from_config.redirect_uris",
|
||||
client.id = %client_id,
|
||||
db.statement = tracing::field::Empty,
|
||||
);
|
||||
|
||||
let now = clock.now();
|
||||
let (ids, redirect_uris): (Vec<Uuid>, Vec<String>) = redirect_uris
|
||||
.iter()
|
||||
.map(|uri| {
|
||||
(
|
||||
Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)),
|
||||
uri.as_str().to_owned(),
|
||||
)
|
||||
})
|
||||
.unzip();
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO oauth2_client_redirect_uris
|
||||
(oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri)
|
||||
SELECT id, $2, redirect_uri
|
||||
FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri)
|
||||
"#,
|
||||
&ids,
|
||||
Uuid::from(client_id),
|
||||
&redirect_uris,
|
||||
)
|
||||
.record(&span)
|
||||
.execute(&mut *self.conn)
|
||||
.instrument(span)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let jwks = match (jwks, jwks_uri) {
|
||||
(None, None) => None,
|
||||
(Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)),
|
||||
(None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)),
|
||||
_ => return Err(DatabaseError::invalid_operation()),
|
||||
};
|
||||
|
||||
Ok(Client {
|
||||
id: client_id,
|
||||
client_id: client_id.to_string(),
|
||||
encrypted_client_secret,
|
||||
redirect_uris,
|
||||
response_types: vec![
|
||||
OAuthAuthorizationEndpointResponseType::Code,
|
||||
OAuthAuthorizationEndpointResponseType::IdToken,
|
||||
OAuthAuthorizationEndpointResponseType::None,
|
||||
],
|
||||
grant_types: Vec::new(),
|
||||
contacts: Vec::new(),
|
||||
client_name: None,
|
||||
logo_uri: None,
|
||||
client_uri: None,
|
||||
policy_uri: None,
|
||||
tos_uri: None,
|
||||
jwks,
|
||||
id_token_signed_response_alg: None,
|
||||
userinfo_signed_response_alg: None,
|
||||
token_endpoint_auth_method: None,
|
||||
token_endpoint_auth_signing_alg: None,
|
||||
initiate_login_uri: None,
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_client.get_consent_for_user",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user.id,
|
||||
%client.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn get_consent_for_user(
|
||||
&mut self,
|
||||
client: &Client,
|
||||
user: &User,
|
||||
) -> Result<Scope, Self::Error> {
|
||||
let scope_tokens: Vec<String> = sqlx::query_scalar!(
|
||||
r#"
|
||||
SELECT scope_token
|
||||
FROM oauth2_consents
|
||||
WHERE user_id = $1 AND oauth2_client_id = $2
|
||||
"#,
|
||||
Uuid::from(user.id),
|
||||
Uuid::from(client.id),
|
||||
)
|
||||
.fetch_all(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
let scope: Result<Scope, _> = scope_tokens
|
||||
.into_iter()
|
||||
.map(|s| ScopeToken::from_str(&s))
|
||||
.collect();
|
||||
|
||||
let scope = scope.map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_consents")
|
||||
.column("scope_token")
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
Ok(scope)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user.id,
|
||||
%client.id,
|
||||
%scope,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn give_consent_for_user(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
client: &Client,
|
||||
user: &User,
|
||||
scope: &Scope,
|
||||
) -> Result<(), Self::Error> {
|
||||
let now = clock.now();
|
||||
let (tokens, ids): (Vec<String>, Vec<Uuid>) = scope
|
||||
.iter()
|
||||
.map(|token| {
|
||||
(
|
||||
token.to_string(),
|
||||
Uuid::from(Ulid::from_datetime_with_source(now.into(), rng)),
|
||||
)
|
||||
})
|
||||
.unzip();
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO oauth2_consents
|
||||
(oauth2_consent_id, user_id, oauth2_client_id, scope_token, created_at)
|
||||
SELECT id, $2, $3, scope_token, $5 FROM UNNEST($1::uuid[], $4::text[]) u(id, scope_token)
|
||||
ON CONFLICT (user_id, oauth2_client_id, scope_token) DO UPDATE SET refreshed_at = $5
|
||||
"#,
|
||||
&ids,
|
||||
Uuid::from(user.id),
|
||||
Uuid::from(client.id),
|
||||
&tokens,
|
||||
now,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -19,11 +19,7 @@ mod refresh_token;
|
||||
mod session;
|
||||
|
||||
pub use self::{
|
||||
access_token::{OAuth2AccessTokenRepository, PgOAuth2AccessTokenRepository},
|
||||
authorization_grant::{
|
||||
OAuth2AuthorizationGrantRepository, PgOAuth2AuthorizationGrantRepository,
|
||||
},
|
||||
client::{OAuth2ClientRepository, PgOAuth2ClientRepository},
|
||||
refresh_token::{OAuth2RefreshTokenRepository, PgOAuth2RefreshTokenRepository},
|
||||
session::{OAuth2SessionRepository, PgOAuth2SessionRepository},
|
||||
access_token::OAuth2AccessTokenRepository,
|
||||
authorization_grant::OAuth2AuthorizationGrantRepository, client::OAuth2ClientRepository,
|
||||
refresh_token::OAuth2RefreshTokenRepository, session::OAuth2SessionRepository,
|
||||
};
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -13,14 +13,11 @@
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{AccessToken, RefreshToken, RefreshTokenState, Session};
|
||||
use mas_data_model::{AccessToken, RefreshToken, Session};
|
||||
use rand::RngCore;
|
||||
use sqlx::PgConnection;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt};
|
||||
use crate::Clock;
|
||||
|
||||
#[async_trait]
|
||||
pub trait OAuth2RefreshTokenRepository: Send + Sync {
|
||||
@ -52,203 +49,3 @@ pub trait OAuth2RefreshTokenRepository: Send + Sync {
|
||||
refresh_token: RefreshToken,
|
||||
) -> Result<RefreshToken, Self::Error>;
|
||||
}
|
||||
|
||||
pub struct PgOAuth2RefreshTokenRepository<'c> {
|
||||
conn: &'c mut PgConnection,
|
||||
}
|
||||
|
||||
impl<'c> PgOAuth2RefreshTokenRepository<'c> {
|
||||
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
Self { conn }
|
||||
}
|
||||
}
|
||||
|
||||
struct OAuth2RefreshTokenLookup {
|
||||
oauth2_refresh_token_id: Uuid,
|
||||
refresh_token: String,
|
||||
created_at: DateTime<Utc>,
|
||||
consumed_at: Option<DateTime<Utc>>,
|
||||
oauth2_access_token_id: Option<Uuid>,
|
||||
oauth2_session_id: Uuid,
|
||||
}
|
||||
|
||||
impl From<OAuth2RefreshTokenLookup> for RefreshToken {
|
||||
fn from(value: OAuth2RefreshTokenLookup) -> Self {
|
||||
let state = match value.consumed_at {
|
||||
None => RefreshTokenState::Valid,
|
||||
Some(consumed_at) => RefreshTokenState::Consumed { consumed_at },
|
||||
};
|
||||
|
||||
RefreshToken {
|
||||
id: value.oauth2_refresh_token_id.into(),
|
||||
state,
|
||||
session_id: value.oauth2_session_id.into(),
|
||||
refresh_token: value.refresh_token,
|
||||
created_at: value.created_at,
|
||||
access_token_id: value.oauth2_access_token_id.map(Ulid::from),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> OAuth2RefreshTokenRepository for PgOAuth2RefreshTokenRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_refresh_token.lookup",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
refresh_token.id = %id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<RefreshToken>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
OAuth2RefreshTokenLookup,
|
||||
r#"
|
||||
SELECT oauth2_refresh_token_id
|
||||
, refresh_token
|
||||
, created_at
|
||||
, consumed_at
|
||||
, oauth2_access_token_id
|
||||
, oauth2_session_id
|
||||
FROM oauth2_refresh_tokens
|
||||
|
||||
WHERE oauth2_refresh_token_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.into()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_refresh_token.find_by_token",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn find_by_token(
|
||||
&mut self,
|
||||
refresh_token: &str,
|
||||
) -> Result<Option<RefreshToken>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
OAuth2RefreshTokenLookup,
|
||||
r#"
|
||||
SELECT oauth2_refresh_token_id
|
||||
, refresh_token
|
||||
, created_at
|
||||
, consumed_at
|
||||
, oauth2_access_token_id
|
||||
, oauth2_session_id
|
||||
FROM oauth2_refresh_tokens
|
||||
|
||||
WHERE refresh_token = $1
|
||||
"#,
|
||||
refresh_token,
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.into()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_refresh_token.add",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%session.id,
|
||||
user_session.id = %session.user_session_id,
|
||||
client.id = %session.client_id,
|
||||
refresh_token.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn add(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
session: &Session,
|
||||
access_token: &AccessToken,
|
||||
refresh_token: String,
|
||||
) -> Result<RefreshToken, Self::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record("refresh_token.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO oauth2_refresh_tokens
|
||||
(oauth2_refresh_token_id, oauth2_session_id, oauth2_access_token_id,
|
||||
refresh_token, created_at)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(session.id),
|
||||
Uuid::from(access_token.id),
|
||||
refresh_token,
|
||||
created_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(RefreshToken {
|
||||
id,
|
||||
state: RefreshTokenState::default(),
|
||||
session_id: session.id,
|
||||
refresh_token,
|
||||
access_token_id: Some(access_token.id),
|
||||
created_at,
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_refresh_token.consume",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%refresh_token.id,
|
||||
session.id = %refresh_token.session_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn consume(
|
||||
&mut self,
|
||||
clock: &Clock,
|
||||
refresh_token: RefreshToken,
|
||||
) -> Result<RefreshToken, Self::Error> {
|
||||
let consumed_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE oauth2_refresh_tokens
|
||||
SET consumed_at = $2
|
||||
WHERE oauth2_refresh_token_id = $1
|
||||
"#,
|
||||
Uuid::from(refresh_token.id),
|
||||
consumed_at,
|
||||
)
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
refresh_token
|
||||
.consume(consumed_at)
|
||||
.map_err(DatabaseError::to_invalid_operation)
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -13,18 +13,11 @@
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{AuthorizationGrant, BrowserSession, Session, SessionState, User};
|
||||
use mas_data_model::{AuthorizationGrant, BrowserSession, Session, User};
|
||||
use rand::RngCore;
|
||||
use sqlx::{PgConnection, QueryBuilder};
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
pagination::{Page, QueryBuilderExt},
|
||||
tracing::ExecuteExt,
|
||||
Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Pagination,
|
||||
};
|
||||
use crate::{pagination::Page, Clock, Pagination};
|
||||
|
||||
#[async_trait]
|
||||
pub trait OAuth2SessionRepository: Send + Sync {
|
||||
@ -48,224 +41,3 @@ pub trait OAuth2SessionRepository: Send + Sync {
|
||||
pagination: Pagination,
|
||||
) -> Result<Page<Session>, Self::Error>;
|
||||
}
|
||||
|
||||
pub struct PgOAuth2SessionRepository<'c> {
|
||||
conn: &'c mut PgConnection,
|
||||
}
|
||||
|
||||
impl<'c> PgOAuth2SessionRepository<'c> {
|
||||
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
Self { conn }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(sqlx::FromRow)]
|
||||
struct OAuthSessionLookup {
|
||||
oauth2_session_id: Uuid,
|
||||
user_session_id: Uuid,
|
||||
oauth2_client_id: Uuid,
|
||||
scope: String,
|
||||
#[allow(dead_code)]
|
||||
created_at: DateTime<Utc>,
|
||||
finished_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
impl TryFrom<OAuthSessionLookup> for Session {
|
||||
type Error = DatabaseInconsistencyError;
|
||||
|
||||
fn try_from(value: OAuthSessionLookup) -> Result<Self, Self::Error> {
|
||||
let id = Ulid::from(value.oauth2_session_id);
|
||||
let scope = value.scope.parse().map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_sessions")
|
||||
.column("scope")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let state = match value.finished_at {
|
||||
None => SessionState::Valid,
|
||||
Some(finished_at) => SessionState::Finished { finished_at },
|
||||
};
|
||||
|
||||
Ok(Session {
|
||||
id,
|
||||
state,
|
||||
created_at: value.created_at,
|
||||
client_id: value.oauth2_client_id.into(),
|
||||
user_session_id: value.user_session_id.into(),
|
||||
scope,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_session.lookup",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
session.id = %id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
OAuthSessionLookup,
|
||||
r#"
|
||||
SELECT oauth2_session_id
|
||||
, user_session_id
|
||||
, oauth2_client_id
|
||||
, scope
|
||||
, created_at
|
||||
, finished_at
|
||||
FROM oauth2_sessions
|
||||
|
||||
WHERE oauth2_session_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(session) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(session.try_into()?))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_session.create_from_grant",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user_session.id,
|
||||
user.id = %user_session.user.id,
|
||||
%grant.id,
|
||||
client.id = %grant.client_id,
|
||||
session.id,
|
||||
session.scope = %grant.scope,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn create_from_grant(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
grant: &AuthorizationGrant,
|
||||
user_session: &BrowserSession,
|
||||
) -> Result<Session, Self::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record("session.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO oauth2_sessions
|
||||
( oauth2_session_id
|
||||
, user_session_id
|
||||
, oauth2_client_id
|
||||
, scope
|
||||
, created_at
|
||||
)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(user_session.id),
|
||||
Uuid::from(grant.client_id),
|
||||
grant.scope.to_string(),
|
||||
created_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(Session {
|
||||
id,
|
||||
state: SessionState::Valid,
|
||||
created_at,
|
||||
user_session_id: user_session.id,
|
||||
client_id: grant.client_id,
|
||||
scope: grant.scope.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_session.finish",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%session.id,
|
||||
%session.scope,
|
||||
user_session.id = %session.user_session_id,
|
||||
client.id = %session.client_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn finish(&mut self, clock: &Clock, session: Session) -> Result<Session, Self::Error> {
|
||||
let finished_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE oauth2_sessions
|
||||
SET finished_at = $2
|
||||
WHERE oauth2_session_id = $1
|
||||
"#,
|
||||
Uuid::from(session.id),
|
||||
finished_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
session
|
||||
.finish(finished_at)
|
||||
.map_err(DatabaseError::to_invalid_operation)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_session.list_paginated",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user.id,
|
||||
%user.username,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn list_paginated(
|
||||
&mut self,
|
||||
user: &User,
|
||||
pagination: Pagination,
|
||||
) -> Result<Page<Session>, Self::Error> {
|
||||
let mut query = QueryBuilder::new(
|
||||
r#"
|
||||
SELECT oauth2_session_id
|
||||
, user_session_id
|
||||
, oauth2_client_id
|
||||
, scope
|
||||
, created_at
|
||||
, finished_at
|
||||
FROM oauth2_sessions os
|
||||
"#,
|
||||
);
|
||||
|
||||
query
|
||||
.push(" WHERE us.user_id = ")
|
||||
.push_bind(Uuid::from(user.id))
|
||||
.generate_pagination("oauth2_session_id", pagination);
|
||||
|
||||
let edges: Vec<OAuthSessionLookup> = query
|
||||
.build_query_as()
|
||||
.traced()
|
||||
.fetch_all(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
let page = pagination.process(edges).try_map(Session::try_from)?;
|
||||
Ok(page)
|
||||
}
|
||||
}
|
||||
|
@ -14,10 +14,8 @@
|
||||
|
||||
//! Utilities to manage paginated queries.
|
||||
|
||||
use sqlx::{Database, QueryBuilder};
|
||||
use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// An error returned when invalid pagination parameters are provided
|
||||
#[derive(Debug, Error)]
|
||||
@ -26,14 +24,14 @@ pub struct InvalidPagination;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct Pagination {
|
||||
before: Option<Ulid>,
|
||||
after: Option<Ulid>,
|
||||
count: usize,
|
||||
direction: PaginationDirection,
|
||||
pub before: Option<Ulid>,
|
||||
pub after: Option<Ulid>,
|
||||
pub count: usize,
|
||||
pub direction: PaginationDirection,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum PaginationDirection {
|
||||
pub enum PaginationDirection {
|
||||
Forward,
|
||||
Backward,
|
||||
}
|
||||
@ -101,60 +99,8 @@ impl Pagination {
|
||||
self
|
||||
}
|
||||
|
||||
/// Add cursor-based pagination to a query, as used in paginated GraphQL
|
||||
/// connections
|
||||
fn generate_pagination<'a, DB>(&self, query: &mut QueryBuilder<'a, DB>, id_field: &'static str)
|
||||
where
|
||||
DB: Database,
|
||||
Uuid: sqlx::Type<DB> + sqlx::Encode<'a, DB>,
|
||||
i64: sqlx::Type<DB> + sqlx::Encode<'a, DB>,
|
||||
{
|
||||
// ref: https://github.com/graphql/graphql-relay-js/issues/94#issuecomment-232410564
|
||||
// 1. Start from the greedy query: SELECT * FROM table
|
||||
|
||||
// 2. If the after argument is provided, add `id > parsed_cursor` to the `WHERE`
|
||||
// clause
|
||||
if let Some(after) = self.after {
|
||||
query
|
||||
.push(" AND ")
|
||||
.push(id_field)
|
||||
.push(" > ")
|
||||
.push_bind(Uuid::from(after));
|
||||
}
|
||||
|
||||
// 3. If the before argument is provided, add `id < parsed_cursor` to the
|
||||
// `WHERE` clause
|
||||
if let Some(before) = self.before {
|
||||
query
|
||||
.push(" AND ")
|
||||
.push(id_field)
|
||||
.push(" < ")
|
||||
.push_bind(Uuid::from(before));
|
||||
}
|
||||
|
||||
match self.direction {
|
||||
// 4. If the first argument is provided, add `ORDER BY id ASC LIMIT first+1` to the
|
||||
// query
|
||||
PaginationDirection::Forward => {
|
||||
query
|
||||
.push(" ORDER BY ")
|
||||
.push(id_field)
|
||||
.push(" ASC LIMIT ")
|
||||
.push_bind((self.count + 1) as i64);
|
||||
}
|
||||
// 5. If the first argument is provided, add `ORDER BY id DESC LIMIT last+1` to the
|
||||
// query
|
||||
PaginationDirection::Backward => {
|
||||
query
|
||||
.push(" ORDER BY ")
|
||||
.push(id_field)
|
||||
.push(" DESC LIMIT ")
|
||||
.push_bind((self.count + 1) as i64);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Process a page returned by a paginated query
|
||||
#[must_use]
|
||||
pub fn process<T>(&self, mut edges: Vec<T>) -> Page<T> {
|
||||
let is_full = edges.len() == (self.count + 1);
|
||||
if is_full {
|
||||
@ -198,7 +144,6 @@ impl<T> Page<T> {
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn try_map<F, E, T2>(self, f: F) -> Result<Page<T2>, E>
|
||||
where
|
||||
F: FnMut(T) -> Result<T2, E>,
|
||||
@ -211,23 +156,3 @@ impl<T> Page<T> {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// An extension trait to the `sqlx` [`QueryBuilder`], to help adding pagination
|
||||
/// to a query
|
||||
pub trait QueryBuilderExt {
|
||||
/// Add cursor-based pagination to a query, as used in paginated GraphQL
|
||||
/// connections
|
||||
fn generate_pagination(&mut self, id_field: &'static str, pagination: Pagination) -> &mut Self;
|
||||
}
|
||||
|
||||
impl<'a, DB> QueryBuilderExt for QueryBuilder<'a, DB>
|
||||
where
|
||||
DB: Database,
|
||||
Uuid: sqlx::Type<DB> + sqlx::Encode<'a, DB>,
|
||||
i64: sqlx::Type<DB> + sqlx::Encode<'a, DB>,
|
||||
{
|
||||
fn generate_pagination(&mut self, id_field: &'static str, pagination: Pagination) -> &mut Self {
|
||||
pagination.generate_pagination(self, id_field);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -12,31 +12,20 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use sqlx::{PgPool, Postgres, Transaction};
|
||||
|
||||
use crate::{
|
||||
compat::{
|
||||
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
|
||||
CompatSsoLoginRepository, PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository,
|
||||
PgCompatSessionRepository, PgCompatSsoLoginRepository,
|
||||
CompatSsoLoginRepository,
|
||||
},
|
||||
oauth2::{
|
||||
OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository,
|
||||
OAuth2RefreshTokenRepository, OAuth2SessionRepository, PgOAuth2AccessTokenRepository,
|
||||
PgOAuth2AuthorizationGrantRepository, PgOAuth2ClientRepository,
|
||||
PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository,
|
||||
OAuth2RefreshTokenRepository, OAuth2SessionRepository,
|
||||
},
|
||||
upstream_oauth2::{
|
||||
PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository,
|
||||
PgUpstreamOAuthSessionRepository, UpstreamOAuthLinkRepository,
|
||||
UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository,
|
||||
UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
|
||||
UpstreamOAuthSessionRepository,
|
||||
},
|
||||
user::{
|
||||
BrowserSessionRepository, PgBrowserSessionRepository, PgUserEmailRepository,
|
||||
PgUserPasswordRepository, PgUserRepository, UserEmailRepository, UserPasswordRepository,
|
||||
UserRepository,
|
||||
},
|
||||
DatabaseError,
|
||||
user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository},
|
||||
};
|
||||
|
||||
pub trait Repository: Send {
|
||||
@ -126,109 +115,3 @@ pub trait Repository: Send {
|
||||
fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_>;
|
||||
fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_>;
|
||||
}
|
||||
|
||||
pub struct PgRepository {
|
||||
txn: Transaction<'static, Postgres>,
|
||||
}
|
||||
|
||||
impl PgRepository {
|
||||
pub async fn from_pool(pool: &PgPool) -> Result<Self, DatabaseError> {
|
||||
let txn = pool.begin().await?;
|
||||
Ok(PgRepository { txn })
|
||||
}
|
||||
|
||||
pub async fn save(self) -> Result<(), DatabaseError> {
|
||||
self.txn.commit().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn cancel(self) -> Result<(), DatabaseError> {
|
||||
self.txn.rollback().await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Repository for PgRepository {
|
||||
type Error = DatabaseError;
|
||||
|
||||
type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c;
|
||||
type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c;
|
||||
type UpstreamOAuthSessionRepository<'c> = PgUpstreamOAuthSessionRepository<'c> where Self: 'c;
|
||||
type UserRepository<'c> = PgUserRepository<'c> where Self: 'c;
|
||||
type UserEmailRepository<'c> = PgUserEmailRepository<'c> where Self: 'c;
|
||||
type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c;
|
||||
type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c;
|
||||
type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c;
|
||||
type OAuth2AuthorizationGrantRepository<'c> = PgOAuth2AuthorizationGrantRepository<'c> where Self: 'c;
|
||||
type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c;
|
||||
type OAuth2AccessTokenRepository<'c> = PgOAuth2AccessTokenRepository<'c> where Self: 'c;
|
||||
type OAuth2RefreshTokenRepository<'c> = PgOAuth2RefreshTokenRepository<'c> where Self: 'c;
|
||||
type CompatSessionRepository<'c> = PgCompatSessionRepository<'c> where Self: 'c;
|
||||
type CompatSsoLoginRepository<'c> = PgCompatSsoLoginRepository<'c> where Self: 'c;
|
||||
type CompatAccessTokenRepository<'c> = PgCompatAccessTokenRepository<'c> where Self: 'c;
|
||||
type CompatRefreshTokenRepository<'c> = PgCompatRefreshTokenRepository<'c> where Self: 'c;
|
||||
|
||||
fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> {
|
||||
PgUpstreamOAuthLinkRepository::new(&mut self.txn)
|
||||
}
|
||||
|
||||
fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> {
|
||||
PgUpstreamOAuthProviderRepository::new(&mut self.txn)
|
||||
}
|
||||
|
||||
fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_> {
|
||||
PgUpstreamOAuthSessionRepository::new(&mut self.txn)
|
||||
}
|
||||
|
||||
fn user(&mut self) -> Self::UserRepository<'_> {
|
||||
PgUserRepository::new(&mut self.txn)
|
||||
}
|
||||
|
||||
fn user_email(&mut self) -> Self::UserEmailRepository<'_> {
|
||||
PgUserEmailRepository::new(&mut self.txn)
|
||||
}
|
||||
|
||||
fn user_password(&mut self) -> Self::UserPasswordRepository<'_> {
|
||||
PgUserPasswordRepository::new(&mut self.txn)
|
||||
}
|
||||
|
||||
fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_> {
|
||||
PgBrowserSessionRepository::new(&mut self.txn)
|
||||
}
|
||||
|
||||
fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_> {
|
||||
PgOAuth2ClientRepository::new(&mut self.txn)
|
||||
}
|
||||
|
||||
fn oauth2_authorization_grant(&mut self) -> Self::OAuth2AuthorizationGrantRepository<'_> {
|
||||
PgOAuth2AuthorizationGrantRepository::new(&mut self.txn)
|
||||
}
|
||||
|
||||
fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> {
|
||||
PgOAuth2SessionRepository::new(&mut self.txn)
|
||||
}
|
||||
|
||||
fn oauth2_access_token(&mut self) -> Self::OAuth2AccessTokenRepository<'_> {
|
||||
PgOAuth2AccessTokenRepository::new(&mut self.txn)
|
||||
}
|
||||
|
||||
fn oauth2_refresh_token(&mut self) -> Self::OAuth2RefreshTokenRepository<'_> {
|
||||
PgOAuth2RefreshTokenRepository::new(&mut self.txn)
|
||||
}
|
||||
|
||||
fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> {
|
||||
PgCompatSessionRepository::new(&mut self.txn)
|
||||
}
|
||||
|
||||
fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_> {
|
||||
PgCompatSsoLoginRepository::new(&mut self.txn)
|
||||
}
|
||||
|
||||
fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_> {
|
||||
PgCompatAccessTokenRepository::new(&mut self.txn)
|
||||
}
|
||||
|
||||
fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_> {
|
||||
PgCompatRefreshTokenRepository::new(&mut self.txn)
|
||||
}
|
||||
}
|
||||
|
@ -1,36 +0,0 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use tracing::Span;
|
||||
|
||||
pub trait ExecuteExt<'q, DB>: Sized {
|
||||
/// Records the statement as `db.statement` in the current span
|
||||
fn traced(self) -> Self {
|
||||
self.record(&Span::current())
|
||||
}
|
||||
|
||||
/// Records the statement as `db.statement` in the given span
|
||||
fn record(self, span: &Span) -> Self;
|
||||
}
|
||||
|
||||
impl<'q, DB, T> ExecuteExt<'q, DB> for T
|
||||
where
|
||||
T: sqlx::Execute<'q, DB>,
|
||||
DB: sqlx::Database,
|
||||
{
|
||||
fn record(self, span: &Span) -> Self {
|
||||
span.record("db.statement", self.sql());
|
||||
self
|
||||
}
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -13,18 +13,11 @@
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{UpstreamOAuthLink, UpstreamOAuthProvider, User};
|
||||
use rand::RngCore;
|
||||
use sqlx::{PgConnection, QueryBuilder};
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
pagination::{Page, QueryBuilderExt},
|
||||
tracing::ExecuteExt,
|
||||
Clock, DatabaseError, LookupResultExt, Pagination,
|
||||
};
|
||||
use crate::{pagination::Page, Clock, Pagination};
|
||||
|
||||
#[async_trait]
|
||||
pub trait UpstreamOAuthLinkRepository: Send + Sync {
|
||||
@ -63,241 +56,3 @@ pub trait UpstreamOAuthLinkRepository: Send + Sync {
|
||||
pagination: Pagination,
|
||||
) -> Result<Page<UpstreamOAuthLink>, Self::Error>;
|
||||
}
|
||||
|
||||
pub struct PgUpstreamOAuthLinkRepository<'c> {
|
||||
conn: &'c mut PgConnection,
|
||||
}
|
||||
|
||||
impl<'c> PgUpstreamOAuthLinkRepository<'c> {
|
||||
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
Self { conn }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(sqlx::FromRow)]
|
||||
struct LinkLookup {
|
||||
upstream_oauth_link_id: Uuid,
|
||||
upstream_oauth_provider_id: Uuid,
|
||||
user_id: Option<Uuid>,
|
||||
subject: String,
|
||||
created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl From<LinkLookup> for UpstreamOAuthLink {
|
||||
fn from(value: LinkLookup) -> Self {
|
||||
UpstreamOAuthLink {
|
||||
id: Ulid::from(value.upstream_oauth_link_id),
|
||||
provider_id: Ulid::from(value.upstream_oauth_provider_id),
|
||||
user_id: value.user_id.map(Ulid::from),
|
||||
subject: value.subject,
|
||||
created_at: value.created_at,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.upstream_oauth_link.lookup",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
upstream_oauth_link.id = %id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthLink>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
LinkLookup,
|
||||
r#"
|
||||
SELECT
|
||||
upstream_oauth_link_id,
|
||||
upstream_oauth_provider_id,
|
||||
user_id,
|
||||
subject,
|
||||
created_at
|
||||
FROM upstream_oauth_links
|
||||
WHERE upstream_oauth_link_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?
|
||||
.map(Into::into);
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.upstream_oauth_link.find_by_subject",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
upstream_oauth_link.subject = subject,
|
||||
%upstream_oauth_provider.id,
|
||||
%upstream_oauth_provider.issuer,
|
||||
%upstream_oauth_provider.client_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn find_by_subject(
|
||||
&mut self,
|
||||
upstream_oauth_provider: &UpstreamOAuthProvider,
|
||||
subject: &str,
|
||||
) -> Result<Option<UpstreamOAuthLink>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
LinkLookup,
|
||||
r#"
|
||||
SELECT
|
||||
upstream_oauth_link_id,
|
||||
upstream_oauth_provider_id,
|
||||
user_id,
|
||||
subject,
|
||||
created_at
|
||||
FROM upstream_oauth_links
|
||||
WHERE upstream_oauth_provider_id = $1
|
||||
AND subject = $2
|
||||
"#,
|
||||
Uuid::from(upstream_oauth_provider.id),
|
||||
subject,
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?
|
||||
.map(Into::into);
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.upstream_oauth_link.add",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
upstream_oauth_link.id,
|
||||
upstream_oauth_link.subject = subject,
|
||||
%upstream_oauth_provider.id,
|
||||
%upstream_oauth_provider.issuer,
|
||||
%upstream_oauth_provider.client_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn add(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
upstream_oauth_provider: &UpstreamOAuthProvider,
|
||||
subject: String,
|
||||
) -> Result<UpstreamOAuthLink, Self::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record("upstream_oauth_link.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO upstream_oauth_links (
|
||||
upstream_oauth_link_id,
|
||||
upstream_oauth_provider_id,
|
||||
user_id,
|
||||
subject,
|
||||
created_at
|
||||
) VALUES ($1, $2, NULL, $3, $4)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(upstream_oauth_provider.id),
|
||||
&subject,
|
||||
created_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(UpstreamOAuthLink {
|
||||
id,
|
||||
provider_id: upstream_oauth_provider.id,
|
||||
user_id: None,
|
||||
subject,
|
||||
created_at,
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.upstream_oauth_link.associate_to_user",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%upstream_oauth_link.id,
|
||||
%upstream_oauth_link.subject,
|
||||
%user.id,
|
||||
%user.username,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn associate_to_user(
|
||||
&mut self,
|
||||
upstream_oauth_link: &UpstreamOAuthLink,
|
||||
user: &User,
|
||||
) -> Result<(), Self::Error> {
|
||||
sqlx::query!(
|
||||
r#"
|
||||
UPDATE upstream_oauth_links
|
||||
SET user_id = $1
|
||||
WHERE upstream_oauth_link_id = $2
|
||||
"#,
|
||||
Uuid::from(user.id),
|
||||
Uuid::from(upstream_oauth_link.id),
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.upstream_oauth_link.list_paginated",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user.id,
|
||||
%user.username,
|
||||
),
|
||||
err
|
||||
)]
|
||||
async fn list_paginated(
|
||||
&mut self,
|
||||
user: &User,
|
||||
pagination: Pagination,
|
||||
) -> Result<Page<UpstreamOAuthLink>, Self::Error> {
|
||||
let mut query = QueryBuilder::new(
|
||||
r#"
|
||||
SELECT
|
||||
upstream_oauth_link_id,
|
||||
upstream_oauth_provider_id,
|
||||
user_id,
|
||||
subject,
|
||||
created_at
|
||||
FROM upstream_oauth_links
|
||||
"#,
|
||||
);
|
||||
|
||||
query
|
||||
.push(" WHERE user_id = ")
|
||||
.push_bind(Uuid::from(user.id))
|
||||
.generate_pagination("upstream_oauth_link_id", pagination);
|
||||
|
||||
let edges: Vec<LinkLookup> = query
|
||||
.build_query_as()
|
||||
.traced()
|
||||
.fetch_all(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
let page = pagination.process(edges).map(UpstreamOAuthLink::from);
|
||||
Ok(page)
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -17,249 +17,6 @@ mod provider;
|
||||
mod session;
|
||||
|
||||
pub use self::{
|
||||
link::{PgUpstreamOAuthLinkRepository, UpstreamOAuthLinkRepository},
|
||||
provider::{PgUpstreamOAuthProviderRepository, UpstreamOAuthProviderRepository},
|
||||
session::{PgUpstreamOAuthSessionRepository, UpstreamOAuthSessionRepository},
|
||||
link::UpstreamOAuthLinkRepository, provider::UpstreamOAuthProviderRepository,
|
||||
session::UpstreamOAuthSessionRepository,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use chrono::Duration;
|
||||
use oauth2_types::scope::{Scope, OPENID};
|
||||
use rand::SeedableRng;
|
||||
use sqlx::PgPool;
|
||||
|
||||
use super::*;
|
||||
use crate::{user::UserRepository, Clock, Pagination, PgRepository, Repository};
|
||||
|
||||
#[sqlx::test(migrator = "crate::MIGRATOR")]
|
||||
async fn test_repository(pool: PgPool) {
|
||||
let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
|
||||
let clock = Clock::mock();
|
||||
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
|
||||
|
||||
// The provider list should be empty at the start
|
||||
let all_providers = repo.upstream_oauth_provider().all().await.unwrap();
|
||||
assert!(all_providers.is_empty());
|
||||
|
||||
// Let's add a provider
|
||||
let provider = repo
|
||||
.upstream_oauth_provider()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
"https://example.com/".to_owned(),
|
||||
Scope::from_iter([OPENID]),
|
||||
mas_iana::oauth::OAuthClientAuthenticationMethod::None,
|
||||
None,
|
||||
"client-id".to_owned(),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Look it up in the database
|
||||
let provider = repo
|
||||
.upstream_oauth_provider()
|
||||
.lookup(provider.id)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("provider to be found in the database");
|
||||
assert_eq!(provider.issuer, "https://example.com/");
|
||||
assert_eq!(provider.client_id, "client-id");
|
||||
|
||||
// Start a session
|
||||
let session = repo
|
||||
.upstream_oauth_session()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&provider,
|
||||
"some-state".to_owned(),
|
||||
None,
|
||||
"some-nonce".to_owned(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Look it up in the database
|
||||
let session = repo
|
||||
.upstream_oauth_session()
|
||||
.lookup(session.id)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("session to be found in the database");
|
||||
assert_eq!(session.provider_id, provider.id);
|
||||
assert_eq!(session.link_id(), None);
|
||||
assert!(session.is_pending());
|
||||
assert!(!session.is_completed());
|
||||
assert!(!session.is_consumed());
|
||||
|
||||
// Create a link
|
||||
let link = repo
|
||||
.upstream_oauth_link()
|
||||
.add(&mut rng, &clock, &provider, "a-subject".to_owned())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// We can look it up by its ID
|
||||
repo.upstream_oauth_link()
|
||||
.lookup(link.id)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("link to be found in database");
|
||||
|
||||
// or by its subject
|
||||
let link = repo
|
||||
.upstream_oauth_link()
|
||||
.find_by_subject(&provider, "a-subject")
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("link to be found in database");
|
||||
assert_eq!(link.subject, "a-subject");
|
||||
assert_eq!(link.provider_id, provider.id);
|
||||
|
||||
let session = repo
|
||||
.upstream_oauth_session()
|
||||
.complete_with_link(&clock, session, &link, None)
|
||||
.await
|
||||
.unwrap();
|
||||
// Reload the session
|
||||
let session = repo
|
||||
.upstream_oauth_session()
|
||||
.lookup(session.id)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("session to be found in the database");
|
||||
assert!(session.is_completed());
|
||||
assert!(!session.is_consumed());
|
||||
assert_eq!(session.link_id(), Some(link.id));
|
||||
|
||||
let session = repo
|
||||
.upstream_oauth_session()
|
||||
.consume(&clock, session)
|
||||
.await
|
||||
.unwrap();
|
||||
// Reload the session
|
||||
let session = repo
|
||||
.upstream_oauth_session()
|
||||
.lookup(session.id)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("session to be found in the database");
|
||||
assert!(session.is_consumed());
|
||||
|
||||
let user = repo
|
||||
.user()
|
||||
.add(&mut rng, &clock, "john".to_owned())
|
||||
.await
|
||||
.unwrap();
|
||||
repo.upstream_oauth_link()
|
||||
.associate_to_user(&link, &user)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let links = repo
|
||||
.upstream_oauth_link()
|
||||
.list_paginated(&user, Pagination::first(10))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!links.has_previous_page);
|
||||
assert!(!links.has_next_page);
|
||||
assert_eq!(links.edges.len(), 1);
|
||||
assert_eq!(links.edges[0].id, link.id);
|
||||
assert_eq!(links.edges[0].user_id, Some(user.id));
|
||||
}
|
||||
|
||||
#[sqlx::test(migrator = "crate::MIGRATOR")]
|
||||
async fn test_provider_repository_pagination(pool: PgPool) {
|
||||
const ISSUER: &str = "https://example.com/";
|
||||
let scope = Scope::from_iter([OPENID]);
|
||||
|
||||
let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
|
||||
let clock = Clock::mock();
|
||||
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
|
||||
|
||||
let mut ids = Vec::with_capacity(20);
|
||||
// Create 20 providers
|
||||
for idx in 0..20 {
|
||||
let client_id = format!("client-{idx}");
|
||||
let provider = repo
|
||||
.upstream_oauth_provider()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
ISSUER.to_owned(),
|
||||
scope.clone(),
|
||||
mas_iana::oauth::OAuthClientAuthenticationMethod::None,
|
||||
None,
|
||||
client_id,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
ids.push(provider.id);
|
||||
clock.advance(Duration::seconds(10));
|
||||
}
|
||||
|
||||
// Lookup the first 10 items
|
||||
let page = repo
|
||||
.upstream_oauth_provider()
|
||||
.list_paginated(Pagination::first(10))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// It returned the first 10 items
|
||||
assert!(page.has_next_page);
|
||||
let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
|
||||
assert_eq!(&edge_ids, &ids[..10]);
|
||||
|
||||
// Lookup the next 10 items
|
||||
let page = repo
|
||||
.upstream_oauth_provider()
|
||||
.list_paginated(Pagination::first(10).after(ids[9]))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// It returned the next 10 items
|
||||
assert!(!page.has_next_page);
|
||||
let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
|
||||
assert_eq!(&edge_ids, &ids[10..]);
|
||||
|
||||
// Lookup the last 10 items
|
||||
let page = repo
|
||||
.upstream_oauth_provider()
|
||||
.list_paginated(Pagination::last(10))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// It returned the last 10 items
|
||||
assert!(page.has_previous_page);
|
||||
let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
|
||||
assert_eq!(&edge_ids, &ids[10..]);
|
||||
|
||||
// Lookup the previous 10 items
|
||||
let page = repo
|
||||
.upstream_oauth_provider()
|
||||
.list_paginated(Pagination::last(10).before(ids[10]))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// It returned the previous 10 items
|
||||
assert!(!page.has_previous_page);
|
||||
let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
|
||||
assert_eq!(&edge_ids, &ids[..10]);
|
||||
|
||||
// Lookup 10 items between two IDs
|
||||
let page = repo
|
||||
.upstream_oauth_provider()
|
||||
.list_paginated(Pagination::first(10).after(ids[5]).before(ids[8]))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// It returned the items in between
|
||||
assert!(!page.has_next_page);
|
||||
let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
|
||||
assert_eq!(&edge_ids, &ids[6..8]);
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -13,20 +13,13 @@
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::UpstreamOAuthProvider;
|
||||
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
|
||||
use oauth2_types::scope::Scope;
|
||||
use rand::RngCore;
|
||||
use sqlx::{PgConnection, QueryBuilder};
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
pagination::{Page, QueryBuilderExt},
|
||||
tracing::ExecuteExt,
|
||||
Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Pagination,
|
||||
};
|
||||
use crate::{pagination::Page, Clock, Pagination};
|
||||
|
||||
#[async_trait]
|
||||
pub trait UpstreamOAuthProviderRepository: Send + Sync {
|
||||
@ -58,247 +51,3 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync {
|
||||
/// Get all upstream OAuth providers
|
||||
async fn all(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
|
||||
}
|
||||
|
||||
pub struct PgUpstreamOAuthProviderRepository<'c> {
|
||||
conn: &'c mut PgConnection,
|
||||
}
|
||||
|
||||
impl<'c> PgUpstreamOAuthProviderRepository<'c> {
|
||||
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
Self { conn }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(sqlx::FromRow)]
|
||||
struct ProviderLookup {
|
||||
upstream_oauth_provider_id: Uuid,
|
||||
issuer: String,
|
||||
scope: String,
|
||||
client_id: String,
|
||||
encrypted_client_secret: Option<String>,
|
||||
token_endpoint_signing_alg: Option<String>,
|
||||
token_endpoint_auth_method: String,
|
||||
created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
|
||||
type Error = DatabaseInconsistencyError;
|
||||
fn try_from(value: ProviderLookup) -> Result<Self, Self::Error> {
|
||||
let id = value.upstream_oauth_provider_id.into();
|
||||
let scope = value.scope.parse().map_err(|e| {
|
||||
DatabaseInconsistencyError::on("upstream_oauth_providers")
|
||||
.column("scope")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
let token_endpoint_auth_method = value.token_endpoint_auth_method.parse().map_err(|e| {
|
||||
DatabaseInconsistencyError::on("upstream_oauth_providers")
|
||||
.column("token_endpoint_auth_method")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
let token_endpoint_signing_alg = value
|
||||
.token_endpoint_signing_alg
|
||||
.map(|x| x.parse())
|
||||
.transpose()
|
||||
.map_err(|e| {
|
||||
DatabaseInconsistencyError::on("upstream_oauth_providers")
|
||||
.column("token_endpoint_signing_alg")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
Ok(UpstreamOAuthProvider {
|
||||
id,
|
||||
issuer: value.issuer,
|
||||
scope,
|
||||
client_id: value.client_id,
|
||||
encrypted_client_secret: value.encrypted_client_secret,
|
||||
token_endpoint_auth_method,
|
||||
token_endpoint_signing_alg,
|
||||
created_at: value.created_at,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.upstream_oauth_provider.lookup",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
upstream_oauth_provider.id = %id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
ProviderLookup,
|
||||
r#"
|
||||
SELECT
|
||||
upstream_oauth_provider_id,
|
||||
issuer,
|
||||
scope,
|
||||
client_id,
|
||||
encrypted_client_secret,
|
||||
token_endpoint_signing_alg,
|
||||
token_endpoint_auth_method,
|
||||
created_at
|
||||
FROM upstream_oauth_providers
|
||||
WHERE upstream_oauth_provider_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let res = res
|
||||
.map(UpstreamOAuthProvider::try_from)
|
||||
.transpose()
|
||||
.map_err(DatabaseError::from)?;
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.upstream_oauth_provider.add",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
upstream_oauth_provider.id,
|
||||
upstream_oauth_provider.issuer = %issuer,
|
||||
upstream_oauth_provider.client_id = %client_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn add(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
issuer: String,
|
||||
scope: Scope,
|
||||
token_endpoint_auth_method: OAuthClientAuthenticationMethod,
|
||||
token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
|
||||
client_id: String,
|
||||
encrypted_client_secret: Option<String>,
|
||||
) -> Result<UpstreamOAuthProvider, Self::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record("upstream_oauth_provider.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO upstream_oauth_providers (
|
||||
upstream_oauth_provider_id,
|
||||
issuer,
|
||||
scope,
|
||||
token_endpoint_auth_method,
|
||||
token_endpoint_signing_alg,
|
||||
client_id,
|
||||
encrypted_client_secret,
|
||||
created_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
&issuer,
|
||||
scope.to_string(),
|
||||
token_endpoint_auth_method.to_string(),
|
||||
token_endpoint_signing_alg.as_ref().map(ToString::to_string),
|
||||
&client_id,
|
||||
encrypted_client_secret.as_deref(),
|
||||
created_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(UpstreamOAuthProvider {
|
||||
id,
|
||||
issuer,
|
||||
scope,
|
||||
client_id,
|
||||
encrypted_client_secret,
|
||||
token_endpoint_signing_alg,
|
||||
token_endpoint_auth_method,
|
||||
created_at,
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.upstream_oauth_provider.list_paginated",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn list_paginated(
|
||||
&mut self,
|
||||
pagination: Pagination,
|
||||
) -> Result<Page<UpstreamOAuthProvider>, Self::Error> {
|
||||
let mut query = QueryBuilder::new(
|
||||
r#"
|
||||
SELECT
|
||||
upstream_oauth_provider_id,
|
||||
issuer,
|
||||
scope,
|
||||
client_id,
|
||||
encrypted_client_secret,
|
||||
token_endpoint_signing_alg,
|
||||
token_endpoint_auth_method,
|
||||
created_at
|
||||
FROM upstream_oauth_providers
|
||||
WHERE 1 = 1
|
||||
"#,
|
||||
);
|
||||
|
||||
query.generate_pagination("upstream_oauth_provider_id", pagination);
|
||||
|
||||
let edges: Vec<ProviderLookup> = query
|
||||
.build_query_as()
|
||||
.traced()
|
||||
.fetch_all(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
let page = pagination.process(edges).try_map(TryInto::try_into)?;
|
||||
Ok(page)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.upstream_oauth_provider.all",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn all(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
ProviderLookup,
|
||||
r#"
|
||||
SELECT
|
||||
upstream_oauth_provider_id,
|
||||
issuer,
|
||||
scope,
|
||||
client_id,
|
||||
encrypted_client_secret,
|
||||
token_endpoint_signing_alg,
|
||||
token_endpoint_auth_method,
|
||||
created_at
|
||||
FROM upstream_oauth_providers
|
||||
"#,
|
||||
)
|
||||
.traced()
|
||||
.fetch_all(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
let res: Result<Vec<_>, _> = res.into_iter().map(TryInto::try_into).collect();
|
||||
Ok(res?)
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -13,19 +13,11 @@
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{
|
||||
UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState, UpstreamOAuthLink,
|
||||
UpstreamOAuthProvider,
|
||||
};
|
||||
use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider};
|
||||
use rand::RngCore;
|
||||
use sqlx::PgConnection;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
|
||||
};
|
||||
use crate::Clock;
|
||||
|
||||
#[async_trait]
|
||||
pub trait UpstreamOAuthSessionRepository: Send + Sync {
|
||||
@ -64,262 +56,3 @@ pub trait UpstreamOAuthSessionRepository: Send + Sync {
|
||||
upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
|
||||
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
|
||||
}
|
||||
|
||||
pub struct PgUpstreamOAuthSessionRepository<'c> {
|
||||
conn: &'c mut PgConnection,
|
||||
}
|
||||
|
||||
impl<'c> PgUpstreamOAuthSessionRepository<'c> {
|
||||
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
Self { conn }
|
||||
}
|
||||
}
|
||||
|
||||
struct SessionLookup {
|
||||
upstream_oauth_authorization_session_id: Uuid,
|
||||
upstream_oauth_provider_id: Uuid,
|
||||
upstream_oauth_link_id: Option<Uuid>,
|
||||
state: String,
|
||||
code_challenge_verifier: Option<String>,
|
||||
nonce: String,
|
||||
id_token: Option<String>,
|
||||
created_at: DateTime<Utc>,
|
||||
completed_at: Option<DateTime<Utc>>,
|
||||
consumed_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
impl TryFrom<SessionLookup> for UpstreamOAuthAuthorizationSession {
|
||||
type Error = DatabaseInconsistencyError;
|
||||
|
||||
fn try_from(value: SessionLookup) -> Result<Self, Self::Error> {
|
||||
let id = value.upstream_oauth_authorization_session_id.into();
|
||||
let state = match (
|
||||
value.upstream_oauth_link_id,
|
||||
value.id_token,
|
||||
value.completed_at,
|
||||
value.consumed_at,
|
||||
) {
|
||||
(None, None, None, None) => UpstreamOAuthAuthorizationSessionState::Pending,
|
||||
(Some(link_id), id_token, Some(completed_at), None) => {
|
||||
UpstreamOAuthAuthorizationSessionState::Completed {
|
||||
completed_at,
|
||||
link_id: link_id.into(),
|
||||
id_token,
|
||||
}
|
||||
}
|
||||
(Some(link_id), id_token, Some(completed_at), Some(consumed_at)) => {
|
||||
UpstreamOAuthAuthorizationSessionState::Consumed {
|
||||
completed_at,
|
||||
link_id: link_id.into(),
|
||||
id_token,
|
||||
consumed_at,
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
return Err(
|
||||
DatabaseInconsistencyError::on("upstream_oauth_authorization_sessions").row(id),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
id,
|
||||
provider_id: value.upstream_oauth_provider_id.into(),
|
||||
state_str: value.state,
|
||||
nonce: value.nonce,
|
||||
code_challenge_verifier: value.code_challenge_verifier,
|
||||
created_at: value.created_at,
|
||||
state,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.upstream_oauth_authorization_session.lookup",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
upstream_oauth_provider.id = %id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn lookup(
|
||||
&mut self,
|
||||
id: Ulid,
|
||||
) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
SessionLookup,
|
||||
r#"
|
||||
SELECT
|
||||
upstream_oauth_authorization_session_id,
|
||||
upstream_oauth_provider_id,
|
||||
upstream_oauth_link_id,
|
||||
state,
|
||||
code_challenge_verifier,
|
||||
nonce,
|
||||
id_token,
|
||||
created_at,
|
||||
completed_at,
|
||||
consumed_at
|
||||
FROM upstream_oauth_authorization_sessions
|
||||
WHERE upstream_oauth_authorization_session_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.try_into()?))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.upstream_oauth_authorization_session.add",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%upstream_oauth_provider.id,
|
||||
%upstream_oauth_provider.issuer,
|
||||
%upstream_oauth_provider.client_id,
|
||||
upstream_oauth_authorization_session.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn add(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
upstream_oauth_provider: &UpstreamOAuthProvider,
|
||||
state_str: String,
|
||||
code_challenge_verifier: Option<String>,
|
||||
nonce: String,
|
||||
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record(
|
||||
"upstream_oauth_authorization_session.id",
|
||||
tracing::field::display(id),
|
||||
);
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO upstream_oauth_authorization_sessions (
|
||||
upstream_oauth_authorization_session_id,
|
||||
upstream_oauth_provider_id,
|
||||
state,
|
||||
code_challenge_verifier,
|
||||
nonce,
|
||||
created_at,
|
||||
completed_at,
|
||||
consumed_at,
|
||||
id_token
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(upstream_oauth_provider.id),
|
||||
&state_str,
|
||||
code_challenge_verifier.as_deref(),
|
||||
nonce,
|
||||
created_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(UpstreamOAuthAuthorizationSession {
|
||||
id,
|
||||
state: UpstreamOAuthAuthorizationSessionState::default(),
|
||||
provider_id: upstream_oauth_provider.id,
|
||||
state_str,
|
||||
code_challenge_verifier,
|
||||
nonce,
|
||||
created_at,
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.upstream_oauth_authorization_session.complete_with_link",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%upstream_oauth_authorization_session.id,
|
||||
%upstream_oauth_link.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn complete_with_link(
|
||||
&mut self,
|
||||
clock: &Clock,
|
||||
upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
|
||||
upstream_oauth_link: &UpstreamOAuthLink,
|
||||
id_token: Option<String>,
|
||||
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
|
||||
let completed_at = clock.now();
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
UPDATE upstream_oauth_authorization_sessions
|
||||
SET upstream_oauth_link_id = $1,
|
||||
completed_at = $2,
|
||||
id_token = $3
|
||||
WHERE upstream_oauth_authorization_session_id = $4
|
||||
"#,
|
||||
Uuid::from(upstream_oauth_link.id),
|
||||
completed_at,
|
||||
id_token,
|
||||
Uuid::from(upstream_oauth_authorization_session.id),
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
let upstream_oauth_authorization_session = upstream_oauth_authorization_session
|
||||
.complete(completed_at, upstream_oauth_link, id_token)
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
|
||||
Ok(upstream_oauth_authorization_session)
|
||||
}
|
||||
|
||||
/// Mark a session as consumed
|
||||
#[tracing::instrument(
|
||||
name = "db.upstream_oauth_authorization_session.consume",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%upstream_oauth_authorization_session.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn consume(
|
||||
&mut self,
|
||||
clock: &Clock,
|
||||
upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
|
||||
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
|
||||
let consumed_at = clock.now();
|
||||
sqlx::query!(
|
||||
r#"
|
||||
UPDATE upstream_oauth_authorization_sessions
|
||||
SET consumed_at = $1
|
||||
WHERE upstream_oauth_authorization_session_id = $2
|
||||
"#,
|
||||
consumed_at,
|
||||
Uuid::from(upstream_oauth_authorization_session.id),
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
let upstream_oauth_authorization_session = upstream_oauth_authorization_session
|
||||
.consume(consumed_at)
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
|
||||
Ok(upstream_oauth_authorization_session)
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -13,19 +13,11 @@
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{User, UserEmail, UserEmailVerification, UserEmailVerificationState};
|
||||
use mas_data_model::{User, UserEmail, UserEmailVerification};
|
||||
use rand::RngCore;
|
||||
use sqlx::{PgConnection, QueryBuilder};
|
||||
use tracing::{info_span, Instrument};
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
pagination::{Page, QueryBuilderExt},
|
||||
tracing::ExecuteExt,
|
||||
Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Pagination,
|
||||
};
|
||||
use crate::{pagination::Page, Clock, Pagination};
|
||||
|
||||
#[async_trait]
|
||||
pub trait UserEmailRepository: Send + Sync {
|
||||
@ -82,529 +74,3 @@ pub trait UserEmailRepository: Send + Sync {
|
||||
verification: UserEmailVerification,
|
||||
) -> Result<UserEmailVerification, Self::Error>;
|
||||
}
|
||||
|
||||
pub struct PgUserEmailRepository<'c> {
|
||||
conn: &'c mut PgConnection,
|
||||
}
|
||||
|
||||
impl<'c> PgUserEmailRepository<'c> {
|
||||
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
Self { conn }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, sqlx::FromRow)]
|
||||
struct UserEmailLookup {
|
||||
user_email_id: Uuid,
|
||||
user_id: Uuid,
|
||||
email: String,
|
||||
created_at: DateTime<Utc>,
|
||||
confirmed_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
impl From<UserEmailLookup> for UserEmail {
|
||||
fn from(e: UserEmailLookup) -> UserEmail {
|
||||
UserEmail {
|
||||
id: e.user_email_id.into(),
|
||||
user_id: e.user_id.into(),
|
||||
email: e.email,
|
||||
created_at: e.created_at,
|
||||
confirmed_at: e.confirmed_at,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct UserEmailConfirmationCodeLookup {
|
||||
user_email_confirmation_code_id: Uuid,
|
||||
user_email_id: Uuid,
|
||||
code: String,
|
||||
created_at: DateTime<Utc>,
|
||||
expires_at: DateTime<Utc>,
|
||||
consumed_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
impl UserEmailConfirmationCodeLookup {
|
||||
fn into_verification(self, clock: &Clock) -> UserEmailVerification {
|
||||
let now = clock.now();
|
||||
let state = if let Some(when) = self.consumed_at {
|
||||
UserEmailVerificationState::AlreadyUsed { when }
|
||||
} else if self.expires_at < now {
|
||||
UserEmailVerificationState::Expired {
|
||||
when: self.expires_at,
|
||||
}
|
||||
} else {
|
||||
UserEmailVerificationState::Valid
|
||||
};
|
||||
|
||||
UserEmailVerification {
|
||||
id: self.user_email_confirmation_code_id.into(),
|
||||
user_email_id: self.user_email_id.into(),
|
||||
code: self.code,
|
||||
state,
|
||||
created_at: self.created_at,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> UserEmailRepository for PgUserEmailRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.user_email.lookup",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
user_email.id = %id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<UserEmail>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
UserEmailLookup,
|
||||
r#"
|
||||
SELECT user_email_id
|
||||
, user_id
|
||||
, email
|
||||
, created_at
|
||||
, confirmed_at
|
||||
FROM user_emails
|
||||
|
||||
WHERE user_email_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(user_email) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(user_email.into()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.user_email.find",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user.id,
|
||||
user_email.email = email,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn find(&mut self, user: &User, email: &str) -> Result<Option<UserEmail>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
UserEmailLookup,
|
||||
r#"
|
||||
SELECT user_email_id
|
||||
, user_id
|
||||
, email
|
||||
, created_at
|
||||
, confirmed_at
|
||||
FROM user_emails
|
||||
|
||||
WHERE user_id = $1 AND email = $2
|
||||
"#,
|
||||
Uuid::from(user.id),
|
||||
email,
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(user_email) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(user_email.into()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.user_email.get_primary",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn get_primary(&mut self, user: &User) -> Result<Option<UserEmail>, Self::Error> {
|
||||
let Some(id) = user.primary_user_email_id else { return Ok(None) };
|
||||
|
||||
let user_email = self.lookup(id).await?.ok_or_else(|| {
|
||||
DatabaseInconsistencyError::on("users")
|
||||
.column("primary_user_email_id")
|
||||
.row(user.id)
|
||||
})?;
|
||||
|
||||
Ok(Some(user_email))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.user_email.all",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn all(&mut self, user: &User) -> Result<Vec<UserEmail>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
UserEmailLookup,
|
||||
r#"
|
||||
SELECT user_email_id
|
||||
, user_id
|
||||
, email
|
||||
, created_at
|
||||
, confirmed_at
|
||||
FROM user_emails
|
||||
|
||||
WHERE user_id = $1
|
||||
|
||||
ORDER BY email ASC
|
||||
"#,
|
||||
Uuid::from(user.id),
|
||||
)
|
||||
.traced()
|
||||
.fetch_all(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(res.into_iter().map(Into::into).collect())
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.user_email.list_paginated",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn list_paginated(
|
||||
&mut self,
|
||||
user: &User,
|
||||
pagination: Pagination,
|
||||
) -> Result<Page<UserEmail>, DatabaseError> {
|
||||
let mut query = QueryBuilder::new(
|
||||
r#"
|
||||
SELECT user_email_id
|
||||
, user_id
|
||||
, email
|
||||
, created_at
|
||||
, confirmed_at
|
||||
FROM user_emails
|
||||
"#,
|
||||
);
|
||||
|
||||
query
|
||||
.push(" WHERE user_id = ")
|
||||
.push_bind(Uuid::from(user.id))
|
||||
.generate_pagination("ue.user_email_id", pagination);
|
||||
|
||||
let edges: Vec<UserEmailLookup> = query
|
||||
.build_query_as()
|
||||
.traced()
|
||||
.fetch_all(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
let page = pagination.process(edges).map(UserEmail::from);
|
||||
Ok(page)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.user_email.count",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn count(&mut self, user: &User) -> Result<usize, Self::Error> {
|
||||
let res = sqlx::query_scalar!(
|
||||
r#"
|
||||
SELECT COUNT(*)
|
||||
FROM user_emails
|
||||
WHERE user_id = $1
|
||||
"#,
|
||||
Uuid::from(user.id),
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
let res = res.unwrap_or_default();
|
||||
|
||||
Ok(res
|
||||
.try_into()
|
||||
.map_err(DatabaseError::to_invalid_operation)?)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.user_email.add",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user.id,
|
||||
user_email.id,
|
||||
user_email.email = email,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn add(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
user: &User,
|
||||
email: String,
|
||||
) -> Result<UserEmail, Self::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record("user_email.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO user_emails (user_email_id, user_id, email, created_at)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(user.id),
|
||||
&email,
|
||||
created_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(UserEmail {
|
||||
id,
|
||||
user_id: user.id,
|
||||
email,
|
||||
created_at,
|
||||
confirmed_at: None,
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.user_email.remove",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
user.id = %user_email.user_id,
|
||||
%user_email.id,
|
||||
%user_email.email,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn remove(&mut self, user_email: UserEmail) -> Result<(), Self::Error> {
|
||||
let span = info_span!(
|
||||
"db.user_email.remove.codes",
|
||||
db.statement = tracing::field::Empty
|
||||
);
|
||||
sqlx::query!(
|
||||
r#"
|
||||
DELETE FROM user_email_confirmation_codes
|
||||
WHERE user_email_id = $1
|
||||
"#,
|
||||
Uuid::from(user_email.id),
|
||||
)
|
||||
.record(&span)
|
||||
.execute(&mut *self.conn)
|
||||
.instrument(span)
|
||||
.await?;
|
||||
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
DELETE FROM user_emails
|
||||
WHERE user_email_id = $1
|
||||
"#,
|
||||
Uuid::from(user_email.id),
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn mark_as_verified(
|
||||
&mut self,
|
||||
clock: &Clock,
|
||||
mut user_email: UserEmail,
|
||||
) -> Result<UserEmail, Self::Error> {
|
||||
let confirmed_at = clock.now();
|
||||
sqlx::query!(
|
||||
r#"
|
||||
UPDATE user_emails
|
||||
SET confirmed_at = $2
|
||||
WHERE user_email_id = $1
|
||||
"#,
|
||||
Uuid::from(user_email.id),
|
||||
confirmed_at,
|
||||
)
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
user_email.confirmed_at = Some(confirmed_at);
|
||||
Ok(user_email)
|
||||
}
|
||||
|
||||
async fn set_as_primary(&mut self, user_email: &UserEmail) -> Result<(), Self::Error> {
|
||||
sqlx::query!(
|
||||
r#"
|
||||
UPDATE users
|
||||
SET primary_user_email_id = user_emails.user_email_id
|
||||
FROM user_emails
|
||||
WHERE user_emails.user_email_id = $1
|
||||
AND users.user_id = user_emails.user_id
|
||||
"#,
|
||||
Uuid::from(user_email.id),
|
||||
)
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.user_email.add_verification_code",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user_email.id,
|
||||
%user_email.email,
|
||||
user_email_verification.id,
|
||||
user_email_verification.code = code,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn add_verification_code(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
user_email: &UserEmail,
|
||||
max_age: chrono::Duration,
|
||||
code: String,
|
||||
) -> Result<UserEmailVerification, Self::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record("user_email_confirmation.id", tracing::field::display(id));
|
||||
let expires_at = created_at + max_age;
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO user_email_confirmation_codes
|
||||
(user_email_confirmation_code_id, user_email_id, code, created_at, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(user_email.id),
|
||||
code,
|
||||
created_at,
|
||||
expires_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
let verification = UserEmailVerification {
|
||||
id,
|
||||
user_email_id: user_email.id,
|
||||
code,
|
||||
created_at,
|
||||
state: UserEmailVerificationState::Valid,
|
||||
};
|
||||
|
||||
Ok(verification)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.user_email.find_verification_code",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user_email.id,
|
||||
user.id = %user_email.user_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn find_verification_code(
|
||||
&mut self,
|
||||
clock: &Clock,
|
||||
user_email: &UserEmail,
|
||||
code: &str,
|
||||
) -> Result<Option<UserEmailVerification>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
UserEmailConfirmationCodeLookup,
|
||||
r#"
|
||||
SELECT user_email_confirmation_code_id
|
||||
, user_email_id
|
||||
, code
|
||||
, created_at
|
||||
, expires_at
|
||||
, consumed_at
|
||||
FROM user_email_confirmation_codes
|
||||
WHERE code = $1
|
||||
AND user_email_id = $2
|
||||
"#,
|
||||
code,
|
||||
Uuid::from(user_email.id),
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.into_verification(clock)))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.user_email.consume_verification_code",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user_email_verification.id,
|
||||
user_email.id = %user_email_verification.user_email_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn consume_verification_code(
|
||||
&mut self,
|
||||
clock: &Clock,
|
||||
mut user_email_verification: UserEmailVerification,
|
||||
) -> Result<UserEmailVerification, Self::Error> {
|
||||
if !matches!(
|
||||
user_email_verification.state,
|
||||
UserEmailVerificationState::Valid
|
||||
) {
|
||||
return Err(DatabaseError::invalid_operation());
|
||||
}
|
||||
|
||||
let consumed_at = clock.now();
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
UPDATE user_email_confirmation_codes
|
||||
SET consumed_at = $2
|
||||
WHERE user_email_confirmation_code_id = $1
|
||||
"#,
|
||||
Uuid::from(user_email_verification.id),
|
||||
consumed_at
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
user_email_verification.state =
|
||||
UserEmailVerificationState::AlreadyUsed { when: consumed_at };
|
||||
|
||||
Ok(user_email_verification)
|
||||
}
|
||||
}
|
||||
|
@ -13,26 +13,18 @@
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::User;
|
||||
use rand::RngCore;
|
||||
use sqlx::PgConnection;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt};
|
||||
use crate::Clock;
|
||||
|
||||
mod email;
|
||||
mod password;
|
||||
mod session;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
pub use self::{
|
||||
email::{PgUserEmailRepository, UserEmailRepository},
|
||||
password::{PgUserPasswordRepository, UserPasswordRepository},
|
||||
session::{BrowserSessionRepository, PgBrowserSessionRepository},
|
||||
email::UserEmailRepository, password::UserPasswordRepository, session::BrowserSessionRepository,
|
||||
};
|
||||
|
||||
#[async_trait]
|
||||
@ -49,170 +41,3 @@ pub trait UserRepository: Send + Sync {
|
||||
) -> Result<User, Self::Error>;
|
||||
async fn exists(&mut self, username: &str) -> Result<bool, Self::Error>;
|
||||
}
|
||||
|
||||
pub struct PgUserRepository<'c> {
|
||||
conn: &'c mut PgConnection,
|
||||
}
|
||||
|
||||
impl<'c> PgUserRepository<'c> {
|
||||
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
Self { conn }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct UserLookup {
|
||||
user_id: Uuid,
|
||||
username: String,
|
||||
primary_user_email_id: Option<Uuid>,
|
||||
|
||||
#[allow(dead_code)]
|
||||
created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl From<UserLookup> for User {
|
||||
fn from(value: UserLookup) -> Self {
|
||||
let id = value.user_id.into();
|
||||
Self {
|
||||
id,
|
||||
username: value.username,
|
||||
sub: id.to_string(),
|
||||
primary_user_email_id: value.primary_user_email_id.map(Into::into),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> UserRepository for PgUserRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.user.lookup",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
user.id = %id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<User>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
UserLookup,
|
||||
r#"
|
||||
SELECT user_id
|
||||
, username
|
||||
, primary_user_email_id
|
||||
, created_at
|
||||
FROM users
|
||||
WHERE user_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.into()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.user.find_by_username",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
user.username = username,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn find_by_username(&mut self, username: &str) -> Result<Option<User>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
UserLookup,
|
||||
r#"
|
||||
SELECT user_id
|
||||
, username
|
||||
, primary_user_email_id
|
||||
, created_at
|
||||
FROM users
|
||||
WHERE username = $1
|
||||
"#,
|
||||
username,
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.into()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.user.add",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
user.username = username,
|
||||
user.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn add(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
username: String,
|
||||
) -> Result<User, Self::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record("user.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO users (user_id, username, created_at)
|
||||
VALUES ($1, $2, $3)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
username,
|
||||
created_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(User {
|
||||
id,
|
||||
username,
|
||||
sub: id.to_string(),
|
||||
primary_user_email_id: None,
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.user.exists",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
user.username = username,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn exists(&mut self, username: &str) -> Result<bool, Self::Error> {
|
||||
let exists = sqlx::query_scalar!(
|
||||
r#"
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM users WHERE username = $1
|
||||
) AS "exists!"
|
||||
"#,
|
||||
username
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(exists)
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -13,16 +13,10 @@
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{Password, User};
|
||||
use rand::RngCore;
|
||||
use sqlx::PgConnection;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
|
||||
};
|
||||
use crate::Clock;
|
||||
|
||||
#[async_trait]
|
||||
pub trait UserPasswordRepository: Send + Sync {
|
||||
@ -39,134 +33,3 @@ pub trait UserPasswordRepository: Send + Sync {
|
||||
upgraded_from: Option<&Password>,
|
||||
) -> Result<Password, Self::Error>;
|
||||
}
|
||||
|
||||
pub struct PgUserPasswordRepository<'c> {
|
||||
conn: &'c mut PgConnection,
|
||||
}
|
||||
|
||||
impl<'c> PgUserPasswordRepository<'c> {
|
||||
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
Self { conn }
|
||||
}
|
||||
}
|
||||
|
||||
struct UserPasswordLookup {
|
||||
user_password_id: Uuid,
|
||||
hashed_password: String,
|
||||
version: i32,
|
||||
upgraded_from_id: Option<Uuid>,
|
||||
created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> UserPasswordRepository for PgUserPasswordRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.user_password.active",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user.id,
|
||||
%user.username,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn active(&mut self, user: &User) -> Result<Option<Password>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
UserPasswordLookup,
|
||||
r#"
|
||||
SELECT up.user_password_id
|
||||
, up.hashed_password
|
||||
, up.version
|
||||
, up.upgraded_from_id
|
||||
, up.created_at
|
||||
FROM user_passwords up
|
||||
WHERE up.user_id = $1
|
||||
ORDER BY up.created_at DESC
|
||||
LIMIT 1
|
||||
"#,
|
||||
Uuid::from(user.id),
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
let id = Ulid::from(res.user_password_id);
|
||||
|
||||
let version = res.version.try_into().map_err(|e| {
|
||||
DatabaseInconsistencyError::on("user_passwords")
|
||||
.column("version")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let upgraded_from_id = res.upgraded_from_id.map(Ulid::from);
|
||||
let created_at = res.created_at;
|
||||
let hashed_password = res.hashed_password;
|
||||
|
||||
Ok(Some(Password {
|
||||
id,
|
||||
hashed_password,
|
||||
version,
|
||||
upgraded_from_id,
|
||||
created_at,
|
||||
}))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.user_password.add",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user.id,
|
||||
%user.username,
|
||||
user_password.id,
|
||||
user_password.version = version,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn add(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
user: &User,
|
||||
version: u16,
|
||||
hashed_password: String,
|
||||
upgraded_from: Option<&Password>,
|
||||
) -> Result<Password, Self::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record("user_password.id", tracing::field::display(id));
|
||||
|
||||
let upgraded_from_id = upgraded_from.map(|p| p.id);
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO user_passwords
|
||||
(user_password_id, user_id, hashed_password, version, upgraded_from_id, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(user.id),
|
||||
hashed_password,
|
||||
i32::from(version),
|
||||
upgraded_from_id.map(Uuid::from),
|
||||
created_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(Password {
|
||||
id,
|
||||
hashed_password,
|
||||
version,
|
||||
upgraded_from_id,
|
||||
created_at,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -13,18 +13,11 @@
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{Authentication, BrowserSession, Password, UpstreamOAuthLink, User};
|
||||
use mas_data_model::{BrowserSession, Password, UpstreamOAuthLink, User};
|
||||
use rand::RngCore;
|
||||
use sqlx::{PgConnection, QueryBuilder};
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
pagination::{Page, QueryBuilderExt},
|
||||
tracing::ExecuteExt,
|
||||
Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Pagination,
|
||||
};
|
||||
use crate::{pagination::Page, Clock, Pagination};
|
||||
|
||||
#[async_trait]
|
||||
pub trait BrowserSessionRepository: Send + Sync {
|
||||
@ -65,351 +58,3 @@ pub trait BrowserSessionRepository: Send + Sync {
|
||||
upstream_oauth_link: &UpstreamOAuthLink,
|
||||
) -> Result<BrowserSession, Self::Error>;
|
||||
}
|
||||
|
||||
pub struct PgBrowserSessionRepository<'c> {
|
||||
conn: &'c mut PgConnection,
|
||||
}
|
||||
|
||||
impl<'c> PgBrowserSessionRepository<'c> {
|
||||
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
Self { conn }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(sqlx::FromRow)]
|
||||
struct SessionLookup {
|
||||
user_session_id: Uuid,
|
||||
user_session_created_at: DateTime<Utc>,
|
||||
user_session_finished_at: Option<DateTime<Utc>>,
|
||||
user_id: Uuid,
|
||||
user_username: String,
|
||||
user_primary_user_email_id: Option<Uuid>,
|
||||
last_authentication_id: Option<Uuid>,
|
||||
last_authd_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
impl TryFrom<SessionLookup> for BrowserSession {
|
||||
type Error = DatabaseInconsistencyError;
|
||||
|
||||
fn try_from(value: SessionLookup) -> Result<Self, Self::Error> {
|
||||
let id = Ulid::from(value.user_id);
|
||||
let user = User {
|
||||
id,
|
||||
username: value.user_username,
|
||||
sub: id.to_string(),
|
||||
primary_user_email_id: value.user_primary_user_email_id.map(Into::into),
|
||||
};
|
||||
|
||||
let last_authentication = match (value.last_authentication_id, value.last_authd_at) {
|
||||
(Some(id), Some(created_at)) => Some(Authentication {
|
||||
id: id.into(),
|
||||
created_at,
|
||||
}),
|
||||
(None, None) => None,
|
||||
_ => {
|
||||
return Err(DatabaseInconsistencyError::on(
|
||||
"user_session_authentications",
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
Ok(BrowserSession {
|
||||
id: value.user_session_id.into(),
|
||||
user,
|
||||
created_at: value.user_session_created_at,
|
||||
finished_at: value.user_session_finished_at,
|
||||
last_authentication,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.browser_session.lookup",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
user_session.id = %id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<BrowserSession>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
SessionLookup,
|
||||
r#"
|
||||
SELECT s.user_session_id
|
||||
, s.created_at AS "user_session_created_at"
|
||||
, s.finished_at AS "user_session_finished_at"
|
||||
, u.user_id
|
||||
, u.username AS "user_username"
|
||||
, u.primary_user_email_id AS "user_primary_user_email_id"
|
||||
, a.user_session_authentication_id AS "last_authentication_id?"
|
||||
, a.created_at AS "last_authd_at?"
|
||||
FROM user_sessions s
|
||||
INNER JOIN users u
|
||||
USING (user_id)
|
||||
LEFT JOIN user_session_authentications a
|
||||
USING (user_session_id)
|
||||
WHERE s.user_session_id = $1
|
||||
ORDER BY a.created_at DESC
|
||||
LIMIT 1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.try_into()?))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.browser_session.add",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user.id,
|
||||
user_session.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn add(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
user: &User,
|
||||
) -> Result<BrowserSession, Self::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record("user_session.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO user_sessions (user_session_id, user_id, created_at)
|
||||
VALUES ($1, $2, $3)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(user.id),
|
||||
created_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
let session = BrowserSession {
|
||||
id,
|
||||
// XXX
|
||||
user: user.clone(),
|
||||
created_at,
|
||||
finished_at: None,
|
||||
last_authentication: None,
|
||||
};
|
||||
|
||||
Ok(session)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.browser_session.finish",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user_session.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn finish(
|
||||
&mut self,
|
||||
clock: &Clock,
|
||||
mut user_session: BrowserSession,
|
||||
) -> Result<BrowserSession, Self::Error> {
|
||||
let finished_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE user_sessions
|
||||
SET finished_at = $1
|
||||
WHERE user_session_id = $2
|
||||
"#,
|
||||
finished_at,
|
||||
Uuid::from(user_session.id),
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
user_session.finished_at = Some(finished_at);
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
Ok(user_session)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.browser_session.list_active_paginated",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn list_active_paginated(
|
||||
&mut self,
|
||||
user: &User,
|
||||
pagination: Pagination,
|
||||
) -> Result<Page<BrowserSession>, Self::Error> {
|
||||
// TODO: ordering of last authentication is wrong
|
||||
let mut query = QueryBuilder::new(
|
||||
r#"
|
||||
SELECT DISTINCT ON (s.user_session_id)
|
||||
s.user_session_id,
|
||||
u.user_id,
|
||||
u.username,
|
||||
s.created_at,
|
||||
a.user_session_authentication_id AS "last_authentication_id",
|
||||
a.created_at AS "last_authd_at",
|
||||
FROM user_sessions s
|
||||
INNER JOIN users u
|
||||
USING (user_id)
|
||||
LEFT JOIN user_session_authentications a
|
||||
USING (user_session_id)
|
||||
"#,
|
||||
);
|
||||
|
||||
query
|
||||
.push(" WHERE s.finished_at IS NULL AND s.user_id = ")
|
||||
.push_bind(Uuid::from(user.id))
|
||||
.generate_pagination("s.user_session_id", pagination);
|
||||
|
||||
let edges: Vec<SessionLookup> = query
|
||||
.build_query_as()
|
||||
.traced()
|
||||
.fetch_all(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
let page = pagination
|
||||
.process(edges)
|
||||
.try_map(BrowserSession::try_from)?;
|
||||
Ok(page)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.browser_session.count_active",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn count_active(&mut self, user: &User) -> Result<usize, Self::Error> {
|
||||
let res = sqlx::query_scalar!(
|
||||
r#"
|
||||
SELECT COUNT(*) as "count!"
|
||||
FROM user_sessions s
|
||||
WHERE s.user_id = $1 AND s.finished_at IS NULL
|
||||
"#,
|
||||
Uuid::from(user.id),
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
res.try_into().map_err(DatabaseError::to_invalid_operation)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.browser_session.authenticate_with_password",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user_session.id,
|
||||
%user_password.id,
|
||||
user_session_authentication.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn authenticate_with_password(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
mut user_session: BrowserSession,
|
||||
user_password: &Password,
|
||||
) -> Result<BrowserSession, Self::Error> {
|
||||
let _user_password = user_password;
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record(
|
||||
"user_session_authentication.id",
|
||||
tracing::field::display(id),
|
||||
);
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO user_session_authentications
|
||||
(user_session_authentication_id, user_session_id, created_at)
|
||||
VALUES ($1, $2, $3)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(user_session.id),
|
||||
created_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
user_session.last_authentication = Some(Authentication { id, created_at });
|
||||
|
||||
Ok(user_session)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.browser_session.authenticate_with_upstream",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user_session.id,
|
||||
%upstream_oauth_link.id,
|
||||
user_session_authentication.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn authenticate_with_upstream(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
mut user_session: BrowserSession,
|
||||
upstream_oauth_link: &UpstreamOAuthLink,
|
||||
) -> Result<BrowserSession, Self::Error> {
|
||||
let _upstream_oauth_link = upstream_oauth_link;
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record(
|
||||
"user_session_authentication.id",
|
||||
tracing::field::display(id),
|
||||
);
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO user_session_authentications
|
||||
(user_session_authentication_id, user_session_id, created_at)
|
||||
VALUES ($1, $2, $3)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(user_session.id),
|
||||
created_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
user_session.last_authentication = Some(Authentication { id, created_at });
|
||||
|
||||
Ok(user_session)
|
||||
}
|
||||
}
|
||||
|
@ -1,394 +0,0 @@
|
||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use chrono::Duration;
|
||||
use rand::SeedableRng;
|
||||
use rand_chacha::ChaChaRng;
|
||||
use sqlx::PgPool;
|
||||
|
||||
use crate::{
|
||||
user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository},
|
||||
Clock, PgRepository, Repository,
|
||||
};
|
||||
|
||||
/// Test the user repository, by adding and looking up a user
|
||||
#[sqlx::test(migrator = "crate::MIGRATOR")]
|
||||
async fn test_user_repo(pool: PgPool) {
|
||||
const USERNAME: &str = "john";
|
||||
|
||||
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
|
||||
let mut rng = ChaChaRng::seed_from_u64(42);
|
||||
let clock = Clock::mock();
|
||||
|
||||
// Initially, the user shouldn't exist
|
||||
assert!(!repo.user().exists(USERNAME).await.unwrap());
|
||||
assert!(repo
|
||||
.user()
|
||||
.find_by_username(USERNAME)
|
||||
.await
|
||||
.unwrap()
|
||||
.is_none());
|
||||
|
||||
// Adding the user should work
|
||||
let user = repo
|
||||
.user()
|
||||
.add(&mut rng, &clock, USERNAME.to_owned())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// And now it should exist
|
||||
assert!(repo.user().exists(USERNAME).await.unwrap());
|
||||
assert!(repo
|
||||
.user()
|
||||
.find_by_username(USERNAME)
|
||||
.await
|
||||
.unwrap()
|
||||
.is_some());
|
||||
assert!(repo.user().lookup(user.id).await.unwrap().is_some());
|
||||
|
||||
// Adding a second time should give a conflict
|
||||
assert!(repo
|
||||
.user()
|
||||
.add(&mut rng, &clock, USERNAME.to_owned())
|
||||
.await
|
||||
.is_err());
|
||||
|
||||
repo.save().await.unwrap();
|
||||
}
|
||||
|
||||
/// Test the user email repository, by trying out most of its methods
|
||||
#[sqlx::test(migrator = "crate::MIGRATOR")]
|
||||
async fn test_user_email_repo(pool: PgPool) {
|
||||
const USERNAME: &str = "john";
|
||||
const CODE: &str = "012345";
|
||||
const CODE2: &str = "543210";
|
||||
const EMAIL: &str = "john@example.com";
|
||||
|
||||
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
|
||||
let mut rng = ChaChaRng::seed_from_u64(42);
|
||||
let clock = Clock::mock();
|
||||
|
||||
let user = repo
|
||||
.user()
|
||||
.add(&mut rng, &clock, USERNAME.to_owned())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// The user email should not exist yet
|
||||
assert!(repo
|
||||
.user_email()
|
||||
.find(&user, &EMAIL)
|
||||
.await
|
||||
.unwrap()
|
||||
.is_none());
|
||||
|
||||
assert_eq!(repo.user_email().count(&user).await.unwrap(), 0);
|
||||
|
||||
let user_email = repo
|
||||
.user_email()
|
||||
.add(&mut rng, &clock, &user, EMAIL.to_owned())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(user_email.user_id, user.id);
|
||||
assert_eq!(user_email.email, EMAIL);
|
||||
assert!(user_email.confirmed_at.is_none());
|
||||
|
||||
assert_eq!(repo.user_email().count(&user).await.unwrap(), 1);
|
||||
|
||||
assert!(repo
|
||||
.user_email()
|
||||
.find(&user, &EMAIL)
|
||||
.await
|
||||
.unwrap()
|
||||
.is_some());
|
||||
|
||||
let user_email = repo
|
||||
.user_email()
|
||||
.lookup(user_email.id)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("user email was not found");
|
||||
|
||||
assert_eq!(user_email.user_id, user.id);
|
||||
assert_eq!(user_email.email, EMAIL);
|
||||
|
||||
let verification = repo
|
||||
.user_email()
|
||||
.add_verification_code(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&user_email,
|
||||
Duration::hours(8),
|
||||
CODE.to_owned(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let verification_id = verification.id;
|
||||
assert_eq!(verification.user_email_id, user_email.id);
|
||||
assert_eq!(verification.code, CODE);
|
||||
|
||||
// A single user email can have multiple verification at the same time
|
||||
let _verification2 = repo
|
||||
.user_email()
|
||||
.add_verification_code(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&user_email,
|
||||
Duration::hours(8),
|
||||
CODE2.to_owned(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let verification = repo
|
||||
.user_email()
|
||||
.find_verification_code(&clock, &user_email, CODE)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("user email verification was not found");
|
||||
|
||||
assert_eq!(verification.id, verification_id);
|
||||
assert_eq!(verification.user_email_id, user_email.id);
|
||||
assert_eq!(verification.code, CODE);
|
||||
|
||||
// Consuming the verification code
|
||||
repo.user_email()
|
||||
.consume_verification_code(&clock, verification)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Mark the email as verified
|
||||
repo.user_email()
|
||||
.mark_as_verified(&clock, user_email)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Reload the user_email
|
||||
let user_email = repo
|
||||
.user_email()
|
||||
.find(&user, &EMAIL)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("user email was not found");
|
||||
|
||||
// The email should be marked as verified now
|
||||
assert!(user_email.confirmed_at.is_some());
|
||||
|
||||
// Reload the verification
|
||||
let verification = repo
|
||||
.user_email()
|
||||
.find_verification_code(&clock, &user_email, CODE)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("user email verification was not found");
|
||||
|
||||
// Consuming a second time should not work
|
||||
assert!(repo
|
||||
.user_email()
|
||||
.consume_verification_code(&clock, verification)
|
||||
.await
|
||||
.is_err());
|
||||
|
||||
// The user shouldn't have a primary email yet
|
||||
assert!(repo
|
||||
.user_email()
|
||||
.get_primary(&user)
|
||||
.await
|
||||
.unwrap()
|
||||
.is_none());
|
||||
|
||||
repo.user_email().set_as_primary(&user_email).await.unwrap();
|
||||
|
||||
// Reload the user
|
||||
let user = repo
|
||||
.user()
|
||||
.lookup(user.id)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("user was not found");
|
||||
|
||||
// Now it should have one
|
||||
assert!(repo
|
||||
.user_email()
|
||||
.get_primary(&user)
|
||||
.await
|
||||
.unwrap()
|
||||
.is_some());
|
||||
|
||||
// Deleting the user email should work
|
||||
repo.user_email().remove(user_email).await.unwrap();
|
||||
assert_eq!(repo.user_email().count(&user).await.unwrap(), 0);
|
||||
|
||||
// Reload the user
|
||||
let user = repo
|
||||
.user()
|
||||
.lookup(user.id)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("user was not found");
|
||||
|
||||
// The primary user email should be gone
|
||||
assert!(repo
|
||||
.user_email()
|
||||
.get_primary(&user)
|
||||
.await
|
||||
.unwrap()
|
||||
.is_none());
|
||||
|
||||
repo.save().await.unwrap();
|
||||
}
|
||||
|
||||
#[sqlx::test(migrator = "crate::MIGRATOR")]
|
||||
async fn test_user_password_repo(pool: PgPool) {
|
||||
const USERNAME: &str = "john";
|
||||
const FIRST_PASSWORD_HASH: &str = "doesntmatter";
|
||||
const SECOND_PASSWORD_HASH: &str = "alsodoesntmatter";
|
||||
|
||||
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
|
||||
let mut rng = ChaChaRng::seed_from_u64(42);
|
||||
let clock = Clock::mock();
|
||||
|
||||
let user = repo
|
||||
.user()
|
||||
.add(&mut rng, &clock, USERNAME.to_owned())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// User should have no active password
|
||||
assert!(repo.user_password().active(&user).await.unwrap().is_none());
|
||||
|
||||
// Insert a first password
|
||||
let first_password = repo
|
||||
.user_password()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&user,
|
||||
1,
|
||||
FIRST_PASSWORD_HASH.to_owned(),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// User should now have an active password
|
||||
let first_password_lookup = repo
|
||||
.user_password()
|
||||
.active(&user)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("user should have an active password");
|
||||
|
||||
assert_eq!(first_password.id, first_password_lookup.id);
|
||||
assert_eq!(first_password_lookup.hashed_password, FIRST_PASSWORD_HASH);
|
||||
assert_eq!(first_password_lookup.version, 1);
|
||||
assert_eq!(first_password_lookup.upgraded_from_id, None);
|
||||
|
||||
// Getting the last inserted password is based on the clock, so we need to
|
||||
// advance it
|
||||
clock.advance(Duration::seconds(10));
|
||||
|
||||
let second_password = repo
|
||||
.user_password()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&user,
|
||||
2,
|
||||
SECOND_PASSWORD_HASH.to_owned(),
|
||||
Some(&first_password),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// User should now have an active password
|
||||
let second_password_lookup = repo
|
||||
.user_password()
|
||||
.active(&user)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("user should have an active password");
|
||||
|
||||
assert_eq!(second_password.id, second_password_lookup.id);
|
||||
assert_eq!(second_password_lookup.hashed_password, SECOND_PASSWORD_HASH);
|
||||
assert_eq!(second_password_lookup.version, 2);
|
||||
assert_eq!(
|
||||
second_password_lookup.upgraded_from_id,
|
||||
Some(first_password.id)
|
||||
);
|
||||
|
||||
repo.save().await.unwrap();
|
||||
}
|
||||
|
||||
#[sqlx::test(migrator = "crate::MIGRATOR")]
|
||||
async fn test_user_session(pool: PgPool) {
|
||||
const USERNAME: &str = "john";
|
||||
|
||||
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
|
||||
let mut rng = ChaChaRng::seed_from_u64(42);
|
||||
let clock = Clock::mock();
|
||||
|
||||
let user = repo
|
||||
.user()
|
||||
.add(&mut rng, &clock, USERNAME.to_owned())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(repo.browser_session().count_active(&user).await.unwrap(), 0);
|
||||
|
||||
let session = repo
|
||||
.browser_session()
|
||||
.add(&mut rng, &clock, &user)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(session.user.id, user.id);
|
||||
assert!(session.finished_at.is_none());
|
||||
|
||||
assert_eq!(repo.browser_session().count_active(&user).await.unwrap(), 1);
|
||||
|
||||
let session_lookup = repo
|
||||
.browser_session()
|
||||
.lookup(session.id)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("user session not found");
|
||||
|
||||
assert_eq!(session_lookup.id, session.id);
|
||||
assert_eq!(session_lookup.user.id, user.id);
|
||||
assert!(session_lookup.finished_at.is_none());
|
||||
|
||||
// Finish the session
|
||||
repo.browser_session()
|
||||
.finish(&clock, session_lookup)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// The active session counter is back to 0
|
||||
assert_eq!(repo.browser_session().count_active(&user).await.unwrap(), 0);
|
||||
|
||||
// Reload the session
|
||||
let session_lookup = repo
|
||||
.browser_session()
|
||||
.lookup(session.id)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("user session not found");
|
||||
|
||||
assert_eq!(session_lookup.id, session.id);
|
||||
assert_eq!(session_lookup.user.id, user.id);
|
||||
// This time the session is finished
|
||||
assert!(session_lookup.finished_at.is_some());
|
||||
}
|
Reference in New Issue
Block a user