1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

storage: repository pattern for the compat layer

This commit is contained in:
Quentin Gliech
2023-01-12 15:41:26 +01:00
parent 9f0c9f1466
commit 36396c0b45
18 changed files with 1738 additions and 1191 deletions

View File

@ -1,757 +0,0 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use chrono::{DateTime, Duration, Utc};
use mas_data_model::{
CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession,
CompatSessionState, CompatSsoLogin, CompatSsoLoginState, Device, User,
};
use rand::Rng;
use sqlx::{Acquire, PgExecutor, Postgres, QueryBuilder};
use tracing::{info_span, Instrument};
use ulid::Ulid;
use url::Url;
use uuid::Uuid;
use crate::{
pagination::{process_page, QueryBuilderExt},
Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
};
struct CompatSessionLookup {
compat_session_id: Uuid,
device_id: String,
user_id: Uuid,
created_at: DateTime<Utc>,
finished_at: Option<DateTime<Utc>>,
}
#[tracing::instrument(skip_all, err)]
pub async fn lookup_compat_session(
executor: impl PgExecutor<'_>,
session_id: Ulid,
) -> Result<Option<CompatSession>, DatabaseError> {
let res = sqlx::query_as!(
CompatSessionLookup,
r#"
SELECT compat_session_id
, device_id
, user_id
, created_at
, finished_at
FROM compat_sessions
WHERE compat_session_id = $1
"#,
Uuid::from(session_id),
)
.fetch_one(executor)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
let id = res.compat_session_id.into();
let device = Device::try_from(res.device_id).map_err(|e| {
DatabaseInconsistencyError::on("compat_sessions")
.column("device_id")
.row(id)
.source(e)
})?;
let state = match res.finished_at {
None => CompatSessionState::Valid,
Some(finished_at) => CompatSessionState::Finished { finished_at },
};
let session = CompatSession {
id,
state,
user_id: res.user_id.into(),
device,
created_at: res.created_at,
};
Ok(Some(session))
}
struct CompatAccessTokenLookup {
compat_access_token_id: Uuid,
access_token: String,
created_at: DateTime<Utc>,
expires_at: Option<DateTime<Utc>>,
compat_session_id: Uuid,
}
impl From<CompatAccessTokenLookup> for CompatAccessToken {
fn from(value: CompatAccessTokenLookup) -> Self {
Self {
id: value.compat_access_token_id.into(),
session_id: value.compat_session_id.into(),
token: value.access_token,
created_at: value.created_at,
expires_at: value.expires_at,
}
}
}
#[tracing::instrument(skip_all, err)]
pub async fn find_compat_access_token(
executor: impl PgExecutor<'_>,
token: &str,
) -> Result<Option<CompatAccessToken>, DatabaseError> {
let res = sqlx::query_as!(
CompatAccessTokenLookup,
r#"
SELECT compat_access_token_id
, access_token
, created_at
, expires_at
, compat_session_id
FROM compat_access_tokens
WHERE access_token = $1
"#,
token,
)
.fetch_one(executor)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
skip_all,
fields(
compat_access_token.id = %id,
),
err,
)]
pub async fn lookup_compat_access_token(
executor: impl PgExecutor<'_>,
id: Ulid,
) -> Result<Option<CompatAccessToken>, DatabaseError> {
let res = sqlx::query_as!(
CompatAccessTokenLookup,
r#"
SELECT compat_access_token_id
, access_token
, created_at
, expires_at
, compat_session_id
FROM compat_access_tokens
WHERE compat_access_token_id = $1
"#,
Uuid::from(id),
)
.fetch_one(executor)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
pub struct CompatRefreshTokenLookup {
compat_refresh_token_id: Uuid,
refresh_token: String,
created_at: DateTime<Utc>,
consumed_at: Option<DateTime<Utc>>,
compat_access_token_id: Uuid,
compat_session_id: Uuid,
}
#[tracing::instrument(skip_all, err)]
#[allow(clippy::type_complexity)]
pub async fn find_compat_refresh_token(
executor: impl PgExecutor<'_>,
token: &str,
) -> Result<Option<CompatRefreshToken>, DatabaseError> {
let res = sqlx::query_as!(
CompatRefreshTokenLookup,
r#"
SELECT compat_refresh_token_id
, refresh_token
, created_at
, consumed_at
, compat_session_id
, compat_access_token_id
FROM compat_refresh_tokens
WHERE refresh_token = $1
"#,
token,
)
.fetch_one(executor)
.await
.to_option()?;
let Some(res) = res else { return Ok(None); };
let state = match res.consumed_at {
None => CompatRefreshTokenState::Valid,
Some(consumed_at) => CompatRefreshTokenState::Consumed { consumed_at },
};
let refresh_token = CompatRefreshToken {
id: res.compat_refresh_token_id.into(),
state,
session_id: res.compat_session_id.into(),
access_token_id: res.compat_access_token_id.into(),
token: res.refresh_token,
created_at: res.created_at,
};
Ok(Some(refresh_token))
}
#[tracing::instrument(
skip_all,
fields(
compat_session.id = %session.id,
compat_session.device.id = session.device.as_str(),
compat_access_token.id,
user.id = %session.user_id,
),
err,
)]
pub async fn add_compat_access_token(
executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
session: &CompatSession,
token: String,
expires_after: Option<Duration>,
) -> Result<CompatAccessToken, sqlx::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("compat_access_token.id", tracing::field::display(id));
let expires_at = expires_after.map(|expires_after| created_at + expires_after);
sqlx::query!(
r#"
INSERT INTO compat_access_tokens
(compat_access_token_id, compat_session_id, access_token, created_at, expires_at)
VALUES ($1, $2, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(session.id),
token,
created_at,
expires_at,
)
.execute(executor)
.instrument(tracing::info_span!("Insert compat access token"))
.await?;
Ok(CompatAccessToken {
id,
session_id: session.id,
token,
created_at,
expires_at,
})
}
#[tracing::instrument(
skip_all,
fields(
compat_access_token.id = %access_token.id,
),
err,
)]
pub async fn expire_compat_access_token(
executor: impl PgExecutor<'_>,
clock: &Clock,
access_token: CompatAccessToken,
) -> Result<(), DatabaseError> {
let expires_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE compat_access_tokens
SET expires_at = $2
WHERE compat_access_token_id = $1
"#,
Uuid::from(access_token.id),
expires_at,
)
.execute(executor)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)
}
#[tracing::instrument(
skip_all,
fields(
compat_session.id = %session.id,
compat_session.device.id = session.device.as_str(),
compat_access_token.id = %access_token.id,
compat_refresh_token.id,
user.id = %session.user_id,
),
err,
)]
pub async fn add_compat_refresh_token(
executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
session: &CompatSession,
access_token: &CompatAccessToken,
token: String,
) -> Result<CompatRefreshToken, sqlx::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO compat_refresh_tokens
(compat_refresh_token_id, compat_session_id,
compat_access_token_id, refresh_token, created_at)
VALUES ($1, $2, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(session.id),
Uuid::from(access_token.id),
token,
created_at,
)
.execute(executor)
.instrument(tracing::info_span!("Insert compat refresh token"))
.await?;
Ok(CompatRefreshToken {
id,
state: CompatRefreshTokenState::default(),
session_id: session.id,
access_token_id: access_token.id,
token,
created_at,
})
}
#[tracing::instrument(
skip_all,
fields(%compat_session.id),
err,
)]
pub async fn end_compat_session(
executor: impl PgExecutor<'_>,
clock: &Clock,
compat_session: CompatSession,
) -> Result<CompatSession, DatabaseError> {
let finished_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE compat_sessions cs
SET finished_at = $2
WHERE compat_session_id = $1
"#,
Uuid::from(compat_session.id),
finished_at,
)
.execute(executor)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
let compat_session = compat_session
.finish(finished_at)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(compat_session)
}
#[tracing::instrument(
skip_all,
fields(
compat_refresh_token.id = %refresh_token.id,
),
err,
)]
pub async fn consume_compat_refresh_token(
executor: impl PgExecutor<'_>,
clock: &Clock,
refresh_token: CompatRefreshToken,
) -> Result<(), DatabaseError> {
let consumed_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE compat_refresh_tokens
SET consumed_at = $2
WHERE compat_refresh_token_id = $1
"#,
Uuid::from(refresh_token.id),
consumed_at,
)
.execute(executor)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)
}
#[tracing::instrument(
skip_all,
fields(
compat_sso_login.id,
compat_sso_login.redirect_uri = %redirect_uri,
),
err,
)]
pub async fn insert_compat_sso_login(
executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
login_token: String,
redirect_uri: Url,
) -> Result<CompatSsoLogin, sqlx::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO compat_sso_logins
(compat_sso_login_id, login_token, redirect_uri, created_at)
VALUES ($1, $2, $3, $4)
"#,
Uuid::from(id),
&login_token,
redirect_uri.as_str(),
created_at,
)
.execute(executor)
.instrument(tracing::info_span!("Insert compat SSO login"))
.await?;
Ok(CompatSsoLogin {
id,
login_token,
redirect_uri,
created_at,
state: CompatSsoLoginState::Pending,
})
}
#[derive(sqlx::FromRow)]
struct CompatSsoLoginLookup {
compat_sso_login_id: Uuid,
compat_sso_login_token: String,
compat_sso_login_redirect_uri: String,
compat_sso_login_created_at: DateTime<Utc>,
compat_sso_login_fulfilled_at: Option<DateTime<Utc>>,
compat_sso_login_exchanged_at: Option<DateTime<Utc>>,
compat_session_id: Option<Uuid>,
}
impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
type Error = DatabaseInconsistencyError;
fn try_from(res: CompatSsoLoginLookup) -> Result<Self, Self::Error> {
let id = res.compat_sso_login_id.into();
let redirect_uri = Url::parse(&res.compat_sso_login_redirect_uri).map_err(|e| {
DatabaseInconsistencyError::on("compat_sso_logins")
.column("redirect_uri")
.row(id)
.source(e)
})?;
let state = match (
res.compat_sso_login_fulfilled_at,
res.compat_sso_login_exchanged_at,
res.compat_session_id,
) {
(None, None, None) => CompatSsoLoginState::Pending,
(Some(fulfilled_at), None, Some(session_id)) => CompatSsoLoginState::Fulfilled {
fulfilled_at,
session_id: session_id.into(),
},
(Some(fulfilled_at), Some(exchanged_at), Some(session_id)) => {
CompatSsoLoginState::Exchanged {
fulfilled_at,
exchanged_at,
session_id: session_id.into(),
}
}
_ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
};
Ok(CompatSsoLogin {
id,
login_token: res.compat_sso_login_token,
redirect_uri,
created_at: res.compat_sso_login_created_at,
state,
})
}
}
#[tracing::instrument(
skip_all,
fields(
compat_sso_login.id = %id,
),
err,
)]
pub async fn get_compat_sso_login_by_id(
executor: impl PgExecutor<'_>,
id: Ulid,
) -> Result<Option<CompatSsoLogin>, DatabaseError> {
let res = sqlx::query_as!(
CompatSsoLoginLookup,
r#"
SELECT cl.compat_sso_login_id
, cl.login_token AS "compat_sso_login_token"
, cl.redirect_uri AS "compat_sso_login_redirect_uri"
, cl.created_at AS "compat_sso_login_created_at"
, cl.fulfilled_at AS "compat_sso_login_fulfilled_at"
, cl.exchanged_at AS "compat_sso_login_exchanged_at"
, cl.compat_session_id AS "compat_session_id"
FROM compat_sso_logins cl
WHERE cl.compat_sso_login_id = $1
"#,
Uuid::from(id),
)
.fetch_one(executor)
.instrument(tracing::info_span!("Lookup compat SSO login"))
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
skip_all,
fields(
%user.id,
%user.username,
),
err,
)]
pub async fn get_paginated_user_compat_sso_logins(
executor: impl PgExecutor<'_>,
user: &User,
before: Option<Ulid>,
after: Option<Ulid>,
first: Option<usize>,
last: Option<usize>,
) -> Result<(bool, bool, Vec<CompatSsoLogin>), DatabaseError> {
let mut query = QueryBuilder::new(
r#"
SELECT cl.compat_sso_login_id
, cl.login_token AS "compat_sso_login_token"
, cl.redirect_uri AS "compat_sso_login_redirect_uri"
, cl.created_at AS "compat_sso_login_created_at"
, cl.fulfilled_at AS "compat_sso_login_fulfilled_at"
, cl.exchanged_at AS "compat_sso_login_exchanged_at"
, cl.compat_session_id AS "compat_session_id"
FROM compat_sso_logins cl
"#,
);
query
.push(" WHERE cs.user_id = ")
.push_bind(Uuid::from(user.id))
.generate_pagination("cl.compat_sso_login_id", before, after, first, last)?;
let span = info_span!(
"Fetch paginated user compat SSO logins",
db.statement = query.sql()
);
let page: Vec<CompatSsoLoginLookup> = query
.build_query_as()
.fetch_all(executor)
.instrument(span)
.await?;
let (has_previous_page, has_next_page, page) = process_page(page, first, last)?;
let page: Result<Vec<_>, _> = page.into_iter().map(TryInto::try_into).collect();
Ok((has_previous_page, has_next_page, page?))
}
#[tracing::instrument(skip_all, err)]
pub async fn get_compat_sso_login_by_token(
executor: impl PgExecutor<'_>,
token: &str,
) -> Result<Option<CompatSsoLogin>, DatabaseError> {
let res = sqlx::query_as!(
CompatSsoLoginLookup,
r#"
SELECT cl.compat_sso_login_id
, cl.login_token AS "compat_sso_login_token"
, cl.redirect_uri AS "compat_sso_login_redirect_uri"
, cl.created_at AS "compat_sso_login_created_at"
, cl.fulfilled_at AS "compat_sso_login_fulfilled_at"
, cl.exchanged_at AS "compat_sso_login_exchanged_at"
, cl.compat_session_id AS "compat_session_id"
FROM compat_sso_logins cl
WHERE cl.login_token = $1
"#,
token,
)
.fetch_one(executor)
.instrument(tracing::info_span!("Lookup compat SSO login"))
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
skip_all,
fields(
%user.id,
compat_session.id,
compat_session.device.id = device.as_str(),
),
err,
)]
pub async fn start_compat_session(
executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
user: &User,
device: Device,
) -> Result<CompatSession, DatabaseError> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("compat_session.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at)
VALUES ($1, $2, $3, $4)
"#,
Uuid::from(id),
Uuid::from(user.id),
device.as_str(),
created_at,
)
.execute(executor)
.await?;
Ok(CompatSession {
id,
state: CompatSessionState::default(),
user_id: user.id,
device,
created_at,
})
}
#[tracing::instrument(
skip_all,
fields(
%user.id,
%compat_sso_login.id,
%compat_sso_login.redirect_uri,
compat_session.id,
compat_session.device.id = device.as_str(),
),
err,
)]
pub async fn fullfill_compat_sso_login(
conn: impl Acquire<'_, Database = Postgres> + Send,
mut rng: impl Rng + Send,
clock: &Clock,
user: &User,
compat_sso_login: CompatSsoLogin,
device: Device,
) -> Result<CompatSsoLogin, DatabaseError> {
if !matches!(compat_sso_login.state, CompatSsoLoginState::Pending) {
return Err(DatabaseError::invalid_operation());
};
let mut txn = conn.begin().await?;
let session = start_compat_session(&mut txn, &mut rng, clock, user, device).await?;
let session_id = session.id;
let fulfilled_at = clock.now();
let compat_sso_login = compat_sso_login
.fulfill(fulfilled_at, &session)
.map_err(DatabaseError::to_invalid_operation)?;
sqlx::query!(
r#"
UPDATE compat_sso_logins
SET
compat_session_id = $2,
fulfilled_at = $3
WHERE
compat_sso_login_id = $1
"#,
Uuid::from(compat_sso_login.id),
Uuid::from(session_id),
fulfilled_at,
)
.execute(&mut txn)
.instrument(tracing::info_span!("Update compat SSO login"))
.await?;
txn.commit().await?;
Ok(compat_sso_login)
}
#[tracing::instrument(
skip_all,
fields(
%compat_sso_login.id,
%compat_sso_login.redirect_uri,
),
err,
)]
pub async fn mark_compat_sso_login_as_exchanged(
executor: impl PgExecutor<'_>,
clock: &Clock,
compat_sso_login: CompatSsoLogin,
) -> Result<CompatSsoLogin, DatabaseError> {
let exchanged_at = clock.now();
let compat_sso_login = compat_sso_login
.exchange(exchanged_at)
.map_err(DatabaseError::to_invalid_operation)?;
sqlx::query!(
r#"
UPDATE compat_sso_logins
SET
exchanged_at = $2
WHERE
compat_sso_login_id = $1
"#,
Uuid::from(compat_sso_login.id),
exchanged_at,
)
.execute(executor)
.instrument(tracing::info_span!("Update compat SSO login"))
.await?;
Ok(compat_sso_login)
}

View File

@ -0,0 +1,246 @@
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};
use mas_data_model::{CompatAccessToken, CompatSession};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt};
#[async_trait]
pub trait CompatAccessTokenRepository: Send + Sync {
type Error;
/// Lookup a compat access token by its ID
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatAccessToken>, Self::Error>;
/// Find a compat access token by its token
async fn find_by_token(
&mut self,
access_token: &str,
) -> Result<Option<CompatAccessToken>, Self::Error>;
/// Add a new compat access token to the database
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
compat_session: &CompatSession,
token: String,
expires_after: Option<Duration>,
) -> Result<CompatAccessToken, Self::Error>;
/// Set the expiration time of the compat access token to now
async fn expire(
&mut self,
clock: &Clock,
compat_access_token: CompatAccessToken,
) -> Result<CompatAccessToken, Self::Error>;
}
pub struct PgCompatAccessTokenRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgCompatAccessTokenRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
struct CompatAccessTokenLookup {
compat_access_token_id: Uuid,
access_token: String,
created_at: DateTime<Utc>,
expires_at: Option<DateTime<Utc>>,
compat_session_id: Uuid,
}
impl From<CompatAccessTokenLookup> for CompatAccessToken {
fn from(value: CompatAccessTokenLookup) -> Self {
Self {
id: value.compat_access_token_id.into(),
session_id: value.compat_session_id.into(),
token: value.access_token,
created_at: value.created_at,
expires_at: value.expires_at,
}
}
}
#[async_trait]
impl<'c> CompatAccessTokenRepository for PgCompatAccessTokenRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.compat_access_token.lookup",
skip_all,
fields(
db.statement,
compat_session.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatAccessToken>, Self::Error> {
let res = sqlx::query_as!(
CompatAccessTokenLookup,
r#"
SELECT compat_access_token_id
, access_token
, created_at
, expires_at
, compat_session_id
FROM compat_access_tokens
WHERE compat_access_token_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.compat_access_token.find_by_token",
skip_all,
fields(
db.statement,
),
err,
)]
async fn find_by_token(
&mut self,
access_token: &str,
) -> Result<Option<CompatAccessToken>, Self::Error> {
let res = sqlx::query_as!(
CompatAccessTokenLookup,
r#"
SELECT compat_access_token_id
, access_token
, created_at
, expires_at
, compat_session_id
FROM compat_access_tokens
WHERE access_token = $1
"#,
access_token,
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.compat_access_token.add",
skip_all,
fields(
db.statement,
compat_access_token.id,
%compat_session.id,
user.id = %compat_session.user_id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
compat_session: &CompatSession,
token: String,
expires_after: Option<Duration>,
) -> Result<CompatAccessToken, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("compat_access_token.id", tracing::field::display(id));
let expires_at = expires_after.map(|expires_after| created_at + expires_after);
sqlx::query!(
r#"
INSERT INTO compat_access_tokens
(compat_access_token_id, compat_session_id, access_token, created_at, expires_at)
VALUES ($1, $2, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(compat_session.id),
token,
created_at,
expires_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(CompatAccessToken {
id,
session_id: compat_session.id,
token,
created_at,
expires_at,
})
}
#[tracing::instrument(
name = "db.compat_access_token.expire",
skip_all,
fields(
db.statement,
%compat_access_token.id,
compat_session.id = %compat_access_token.session_id,
),
err,
)]
async fn expire(
&mut self,
clock: &Clock,
mut compat_access_token: CompatAccessToken,
) -> Result<CompatAccessToken, Self::Error> {
let expires_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE compat_access_tokens
SET expires_at = $2
WHERE compat_access_token_id = $1
"#,
Uuid::from(compat_access_token.id),
expires_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
compat_access_token.expires_at = Some(expires_at);
Ok(compat_access_token)
}
}

View File

@ -0,0 +1,25 @@
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
mod access_token;
mod refresh_token;
mod session;
mod sso_login;
pub use self::{
access_token::{CompatAccessTokenRepository, PgCompatAccessTokenRepository},
refresh_token::{CompatRefreshTokenRepository, PgCompatRefreshTokenRepository},
session::{CompatSessionRepository, PgCompatSessionRepository},
sso_login::{CompatSsoLoginRepository, PgCompatSsoLoginRepository},
};

View File

@ -0,0 +1,260 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{
CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession,
};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt};
#[async_trait]
pub trait CompatRefreshTokenRepository: Send + Sync {
type Error;
/// Lookup a compat refresh token by its ID
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatRefreshToken>, Self::Error>;
/// Find a compat refresh token by its token
async fn find_by_token(
&mut self,
refresh_token: &str,
) -> Result<Option<CompatRefreshToken>, Self::Error>;
/// Add a new compat refresh token to the database
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
compat_session: &CompatSession,
compat_access_token: &CompatAccessToken,
token: String,
) -> Result<CompatRefreshToken, Self::Error>;
/// Consume a compat refresh token
async fn consume(
&mut self,
clock: &Clock,
compat_refresh_token: CompatRefreshToken,
) -> Result<CompatRefreshToken, Self::Error>;
}
pub struct PgCompatRefreshTokenRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgCompatRefreshTokenRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
struct CompatRefreshTokenLookup {
compat_refresh_token_id: Uuid,
refresh_token: String,
created_at: DateTime<Utc>,
consumed_at: Option<DateTime<Utc>>,
compat_access_token_id: Uuid,
compat_session_id: Uuid,
}
impl From<CompatRefreshTokenLookup> for CompatRefreshToken {
fn from(value: CompatRefreshTokenLookup) -> Self {
let state = match value.consumed_at {
Some(consumed_at) => CompatRefreshTokenState::Consumed { consumed_at },
None => CompatRefreshTokenState::Valid,
};
Self {
id: value.compat_refresh_token_id.into(),
state,
session_id: value.compat_session_id.into(),
token: value.refresh_token,
created_at: value.created_at,
access_token_id: value.compat_access_token_id.into(),
}
}
}
#[async_trait]
impl<'c> CompatRefreshTokenRepository for PgCompatRefreshTokenRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.compat_refresh_token.lookup",
skip_all,
fields(
db.statement,
compat_refresh_token.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatRefreshToken>, Self::Error> {
let res = sqlx::query_as!(
CompatRefreshTokenLookup,
r#"
SELECT compat_refresh_token_id
, refresh_token
, created_at
, consumed_at
, compat_session_id
, compat_access_token_id
FROM compat_refresh_tokens
WHERE compat_refresh_token_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.compat_refresh_token.find_by_token",
skip_all,
fields(
db.statement,
),
err,
)]
async fn find_by_token(
&mut self,
refresh_token: &str,
) -> Result<Option<CompatRefreshToken>, Self::Error> {
let res = sqlx::query_as!(
CompatRefreshTokenLookup,
r#"
SELECT compat_refresh_token_id
, refresh_token
, created_at
, consumed_at
, compat_session_id
, compat_access_token_id
FROM compat_refresh_tokens
WHERE refresh_token = $1
"#,
refresh_token,
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.compat_refresh_token.add",
skip_all,
fields(
db.statement,
compat_refresh_token.id,
%compat_session.id,
user.id = %compat_session.user_id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
compat_session: &CompatSession,
compat_access_token: &CompatAccessToken,
token: String,
) -> Result<CompatRefreshToken, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO compat_refresh_tokens
(compat_refresh_token_id, compat_session_id,
compat_access_token_id, refresh_token, created_at)
VALUES ($1, $2, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(compat_session.id),
Uuid::from(compat_access_token.id),
token,
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(CompatRefreshToken {
id,
state: CompatRefreshTokenState::default(),
session_id: compat_session.id,
access_token_id: compat_access_token.id,
token,
created_at,
})
}
#[tracing::instrument(
name = "db.compat_refresh_token.consume",
skip_all,
fields(
db.statement,
%compat_refresh_token.id,
compat_session.id = %compat_refresh_token.session_id,
),
err,
)]
async fn consume(
&mut self,
clock: &Clock,
compat_refresh_token: CompatRefreshToken,
) -> Result<CompatRefreshToken, Self::Error> {
let consumed_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE compat_refresh_tokens
SET consumed_at = $2
WHERE compat_refresh_token_id = $1
"#,
Uuid::from(compat_refresh_token.id),
consumed_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
let compat_refresh_token = compat_refresh_token
.consume(consumed_at)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(compat_refresh_token)
}
}

View File

@ -0,0 +1,220 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{CompatSession, CompatSessionState, Device, User};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{
tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
};
#[async_trait]
pub trait CompatSessionRepository: Send + Sync {
type Error;
/// Lookup a compat session by its ID
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSession>, Self::Error>;
/// Start a new compat session
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
user: &User,
device: Device,
) -> Result<CompatSession, Self::Error>;
/// End a compat session
async fn finish(
&mut self,
clock: &Clock,
compat_session: CompatSession,
) -> Result<CompatSession, Self::Error>;
}
pub struct PgCompatSessionRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgCompatSessionRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
struct CompatSessionLookup {
compat_session_id: Uuid,
device_id: String,
user_id: Uuid,
created_at: DateTime<Utc>,
finished_at: Option<DateTime<Utc>>,
}
impl TryFrom<CompatSessionLookup> for CompatSession {
type Error = DatabaseInconsistencyError;
fn try_from(value: CompatSessionLookup) -> Result<Self, Self::Error> {
let id = value.compat_session_id.into();
let device = Device::try_from(value.device_id).map_err(|e| {
DatabaseInconsistencyError::on("compat_sessions")
.column("device_id")
.row(id)
.source(e)
})?;
let state = match value.finished_at {
None => CompatSessionState::Valid,
Some(finished_at) => CompatSessionState::Finished { finished_at },
};
let session = CompatSession {
id,
state,
user_id: value.user_id.into(),
device,
created_at: value.created_at,
};
Ok(session)
}
}
#[async_trait]
impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.compat_session.lookup",
skip_all,
fields(
db.statement,
compat_session.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSession>, Self::Error> {
let res = sqlx::query_as!(
CompatSessionLookup,
r#"
SELECT compat_session_id
, device_id
, user_id
, created_at
, finished_at
FROM compat_sessions
WHERE compat_session_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
name = "db.compat_session.add",
skip_all,
fields(
db.statement,
compat_session.id,
%user.id,
%user.username,
compat_session.device.id = device.as_str(),
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
user: &User,
device: Device,
) -> Result<CompatSession, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("compat_session.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at)
VALUES ($1, $2, $3, $4)
"#,
Uuid::from(id),
Uuid::from(user.id),
device.as_str(),
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(CompatSession {
id,
state: CompatSessionState::default(),
user_id: user.id,
device,
created_at,
})
}
#[tracing::instrument(
name = "db.compat_session.finish",
skip_all,
fields(
db.statement,
%compat_session.id,
user.id = %compat_session.user_id,
compat_session.device.id = compat_session.device.as_str(),
),
err,
)]
async fn finish(
&mut self,
clock: &Clock,
compat_session: CompatSession,
) -> Result<CompatSession, Self::Error> {
let finished_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE compat_sessions cs
SET finished_at = $2
WHERE compat_session_id = $1
"#,
Uuid::from(compat_session.id),
finished_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
let compat_session = compat_session
.finish(finished_at)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(compat_session)
}
}

View File

@ -0,0 +1,397 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{CompatSession, CompatSsoLogin, CompatSsoLoginState, User};
use rand::RngCore;
use sqlx::{PgConnection, QueryBuilder};
use ulid::Ulid;
use url::Url;
use uuid::Uuid;
use crate::{
pagination::{process_page, Page, QueryBuilderExt},
tracing::ExecuteExt,
Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
};
#[async_trait]
pub trait CompatSsoLoginRepository: Send + Sync {
type Error;
/// Lookup a compat SSO login by its ID
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSsoLogin>, Self::Error>;
/// Find a compat SSO login by its login token
async fn find_by_token(
&mut self,
login_token: &str,
) -> Result<Option<CompatSsoLogin>, Self::Error>;
/// Start a new compat SSO login token
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
login_token: String,
redirect_uri: Url,
) -> Result<CompatSsoLogin, Self::Error>;
/// Fulfill a compat SSO login by providing a compat session
async fn fulfill(
&mut self,
clock: &Clock,
compat_sso_login: CompatSsoLogin,
compat_session: &CompatSession,
) -> Result<CompatSsoLogin, Self::Error>;
/// Mark a compat SSO login as exchanged
async fn exchange(
&mut self,
clock: &Clock,
compat_sso_login: CompatSsoLogin,
) -> Result<CompatSsoLogin, Self::Error>;
/// Get a paginated list of compat SSO logins for a user
async fn list_paginated(
&mut self,
user: &User,
before: Option<Ulid>,
after: Option<Ulid>,
first: Option<usize>,
last: Option<usize>,
) -> Result<Page<CompatSsoLogin>, Self::Error>;
}
pub struct PgCompatSsoLoginRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgCompatSsoLoginRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
#[derive(sqlx::FromRow)]
struct CompatSsoLoginLookup {
compat_sso_login_id: Uuid,
login_token: String,
redirect_uri: String,
created_at: DateTime<Utc>,
fulfilled_at: Option<DateTime<Utc>>,
exchanged_at: Option<DateTime<Utc>>,
compat_session_id: Option<Uuid>,
}
impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
type Error = DatabaseInconsistencyError;
fn try_from(res: CompatSsoLoginLookup) -> Result<Self, Self::Error> {
let id = res.compat_sso_login_id.into();
let redirect_uri = Url::parse(&res.redirect_uri).map_err(|e| {
DatabaseInconsistencyError::on("compat_sso_logins")
.column("redirect_uri")
.row(id)
.source(e)
})?;
let state = match (res.fulfilled_at, res.exchanged_at, res.compat_session_id) {
(None, None, None) => CompatSsoLoginState::Pending,
(Some(fulfilled_at), None, Some(session_id)) => CompatSsoLoginState::Fulfilled {
fulfilled_at,
session_id: session_id.into(),
},
(Some(fulfilled_at), Some(exchanged_at), Some(session_id)) => {
CompatSsoLoginState::Exchanged {
fulfilled_at,
exchanged_at,
session_id: session_id.into(),
}
}
_ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
};
Ok(CompatSsoLogin {
id,
login_token: res.login_token,
redirect_uri,
created_at: res.created_at,
state,
})
}
}
#[async_trait]
impl<'c> CompatSsoLoginRepository for PgCompatSsoLoginRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.compat_sso_login.lookup",
skip_all,
fields(
db.statement,
compat_sso_login.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSsoLogin>, Self::Error> {
let res = sqlx::query_as!(
CompatSsoLoginLookup,
r#"
SELECT compat_sso_login_id
, login_token
, redirect_uri
, created_at
, fulfilled_at
, exchanged_at
, compat_session_id
FROM compat_sso_logins
WHERE compat_sso_login_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
name = "db.compat_sso_login.find_by_token",
skip_all,
fields(
db.statement,
),
err,
)]
async fn find_by_token(
&mut self,
login_token: &str,
) -> Result<Option<CompatSsoLogin>, Self::Error> {
let res = sqlx::query_as!(
CompatSsoLoginLookup,
r#"
SELECT compat_sso_login_id
, login_token
, redirect_uri
, created_at
, fulfilled_at
, exchanged_at
, compat_session_id
FROM compat_sso_logins
WHERE login_token = $1
"#,
login_token,
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
name = "db.compat_sso_login.add",
skip_all,
fields(
db.statement,
compat_sso_login.id,
compat_sso_login.redirect_uri = %redirect_uri,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
login_token: String,
redirect_uri: Url,
) -> Result<CompatSsoLogin, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO compat_sso_logins
(compat_sso_login_id, login_token, redirect_uri, created_at)
VALUES ($1, $2, $3, $4)
"#,
Uuid::from(id),
&login_token,
redirect_uri.as_str(),
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(CompatSsoLogin {
id,
login_token,
redirect_uri,
created_at,
state: CompatSsoLoginState::default(),
})
}
#[tracing::instrument(
name = "db.compat_sso_login.fulfill",
skip_all,
fields(
db.statement,
%compat_sso_login.id,
%compat_session.id,
compat_session.device.id = compat_session.device.as_str(),
user.id = %compat_session.user_id,
),
err,
)]
async fn fulfill(
&mut self,
clock: &Clock,
compat_sso_login: CompatSsoLogin,
compat_session: &CompatSession,
) -> Result<CompatSsoLogin, Self::Error> {
let fulfilled_at = clock.now();
let compat_sso_login = compat_sso_login
.fulfill(fulfilled_at, compat_session)
.map_err(DatabaseError::to_invalid_operation)?;
let res = sqlx::query!(
r#"
UPDATE compat_sso_logins
SET
compat_session_id = $2,
fulfilled_at = $3
WHERE
compat_sso_login_id = $1
"#,
Uuid::from(compat_sso_login.id),
Uuid::from(compat_session.id),
fulfilled_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
Ok(compat_sso_login)
}
#[tracing::instrument(
name = "db.compat_sso_login.exchange",
skip_all,
fields(
db.statement,
%compat_sso_login.id,
),
err,
)]
async fn exchange(
&mut self,
clock: &Clock,
compat_sso_login: CompatSsoLogin,
) -> Result<CompatSsoLogin, Self::Error> {
let exchanged_at = clock.now();
let compat_sso_login = compat_sso_login
.exchange(exchanged_at)
.map_err(DatabaseError::to_invalid_operation)?;
let res = sqlx::query!(
r#"
UPDATE compat_sso_logins
SET
exchanged_at = $2
WHERE
compat_sso_login_id = $1
"#,
Uuid::from(compat_sso_login.id),
exchanged_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
Ok(compat_sso_login)
}
#[tracing::instrument(
name = "db.compat_sso_login.list_paginated",
skip_all,
fields(
db.statement,
%user.id,
%user.username,
),
err
)]
async fn list_paginated(
&mut self,
user: &User,
before: Option<Ulid>,
after: Option<Ulid>,
first: Option<usize>,
last: Option<usize>,
) -> Result<Page<CompatSsoLogin>, Self::Error> {
let mut query = QueryBuilder::new(
r#"
SELECT cl.compat_sso_login_id
, cl.login_token
, cl.redirect_uri
, cl.created_at
, cl.fulfilled_at
, cl.exchanged_at
, cl.compat_session_id
FROM compat_sso_logins cl
INNER JOIN compat_sessions ON compat_session_id
"#,
);
query
.push(" WHERE user_id = ")
.push_bind(Uuid::from(user.id))
.generate_pagination("cl.compat_sso_login_id", before, after, first, last)?;
let page: Vec<CompatSsoLoginLookup> = query
.build_query_as()
.traced()
.fetch_all(&mut *self.conn)
.await?;
let (has_previous_page, has_next_page, edges) = process_page(page, first, last)?;
let edges: Result<Vec<_>, _> = edges.into_iter().map(TryInto::try_into).collect();
Ok(Page {
has_next_page,
has_previous_page,
edges: edges?,
})
}
}

View File

@ -15,6 +15,10 @@
use sqlx::{PgConnection, Postgres, Transaction};
use crate::{
compat::{
PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository, PgCompatSessionRepository,
PgCompatSsoLoginRepository,
},
oauth2::{PgOAuth2ClientRepository, PgOAuth2SessionRepository},
upstream_oauth2::{
PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository,
@ -63,6 +67,22 @@ pub trait Repository {
where
Self: 'c;
type CompatSessionRepository<'c>
where
Self: 'c;
type CompatSsoLoginRepository<'c>
where
Self: 'c;
type CompatAccessTokenRepository<'c>
where
Self: 'c;
type CompatRefreshTokenRepository<'c>
where
Self: 'c;
fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_>;
fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_>;
fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_>;
@ -72,6 +92,10 @@ pub trait Repository {
fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_>;
fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_>;
fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_>;
fn compat_session(&mut self) -> Self::CompatSessionRepository<'_>;
fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_>;
fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_>;
fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_>;
}
impl Repository for PgConnection {
@ -84,6 +108,10 @@ impl Repository for PgConnection {
type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c;
type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c;
type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c;
type CompatSessionRepository<'c> = PgCompatSessionRepository<'c> where Self: 'c;
type CompatSsoLoginRepository<'c> = PgCompatSsoLoginRepository<'c> where Self: 'c;
type CompatAccessTokenRepository<'c> = PgCompatAccessTokenRepository<'c> where Self: 'c;
type CompatRefreshTokenRepository<'c> = PgCompatRefreshTokenRepository<'c> where Self: 'c;
fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> {
PgUpstreamOAuthLinkRepository::new(self)
@ -120,6 +148,22 @@ impl Repository for PgConnection {
fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> {
PgOAuth2SessionRepository::new(self)
}
fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> {
PgCompatSessionRepository::new(self)
}
fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_> {
PgCompatSsoLoginRepository::new(self)
}
fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_> {
PgCompatAccessTokenRepository::new(self)
}
fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_> {
PgCompatRefreshTokenRepository::new(self)
}
}
impl<'t> Repository for Transaction<'t, Postgres> {
@ -132,6 +176,10 @@ impl<'t> Repository for Transaction<'t, Postgres> {
type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c;
type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c;
type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c;
type CompatSessionRepository<'c> = PgCompatSessionRepository<'c> where Self: 'c;
type CompatSsoLoginRepository<'c> = PgCompatSsoLoginRepository<'c> where Self: 'c;
type CompatAccessTokenRepository<'c> = PgCompatAccessTokenRepository<'c> where Self: 'c;
type CompatRefreshTokenRepository<'c> = PgCompatRefreshTokenRepository<'c> where Self: 'c;
fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> {
PgUpstreamOAuthLinkRepository::new(self)
@ -168,4 +216,20 @@ impl<'t> Repository for Transaction<'t, Postgres> {
fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> {
PgOAuth2SessionRepository::new(self)
}
fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> {
PgCompatSessionRepository::new(self)
}
fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_> {
PgCompatSsoLoginRepository::new(self)
}
fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_> {
PgCompatAccessTokenRepository::new(self)
}
fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_> {
PgCompatRefreshTokenRepository::new(self)
}
}