You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
storage: repository pattern for the compat layer
This commit is contained in:
@ -1,757 +0,0 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use mas_data_model::{
|
||||
CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession,
|
||||
CompatSessionState, CompatSsoLogin, CompatSsoLoginState, Device, User,
|
||||
};
|
||||
use rand::Rng;
|
||||
use sqlx::{Acquire, PgExecutor, Postgres, QueryBuilder};
|
||||
use tracing::{info_span, Instrument};
|
||||
use ulid::Ulid;
|
||||
use url::Url;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
pagination::{process_page, QueryBuilderExt},
|
||||
Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
|
||||
};
|
||||
|
||||
struct CompatSessionLookup {
|
||||
compat_session_id: Uuid,
|
||||
device_id: String,
|
||||
user_id: Uuid,
|
||||
created_at: DateTime<Utc>,
|
||||
finished_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, err)]
|
||||
pub async fn lookup_compat_session(
|
||||
executor: impl PgExecutor<'_>,
|
||||
session_id: Ulid,
|
||||
) -> Result<Option<CompatSession>, DatabaseError> {
|
||||
let res = sqlx::query_as!(
|
||||
CompatSessionLookup,
|
||||
r#"
|
||||
SELECT compat_session_id
|
||||
, device_id
|
||||
, user_id
|
||||
, created_at
|
||||
, finished_at
|
||||
FROM compat_sessions
|
||||
WHERE compat_session_id = $1
|
||||
"#,
|
||||
Uuid::from(session_id),
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
let id = res.compat_session_id.into();
|
||||
let device = Device::try_from(res.device_id).map_err(|e| {
|
||||
DatabaseInconsistencyError::on("compat_sessions")
|
||||
.column("device_id")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let state = match res.finished_at {
|
||||
None => CompatSessionState::Valid,
|
||||
Some(finished_at) => CompatSessionState::Finished { finished_at },
|
||||
};
|
||||
|
||||
let session = CompatSession {
|
||||
id,
|
||||
state,
|
||||
user_id: res.user_id.into(),
|
||||
device,
|
||||
created_at: res.created_at,
|
||||
};
|
||||
|
||||
Ok(Some(session))
|
||||
}
|
||||
|
||||
struct CompatAccessTokenLookup {
|
||||
compat_access_token_id: Uuid,
|
||||
access_token: String,
|
||||
created_at: DateTime<Utc>,
|
||||
expires_at: Option<DateTime<Utc>>,
|
||||
compat_session_id: Uuid,
|
||||
}
|
||||
|
||||
impl From<CompatAccessTokenLookup> for CompatAccessToken {
|
||||
fn from(value: CompatAccessTokenLookup) -> Self {
|
||||
Self {
|
||||
id: value.compat_access_token_id.into(),
|
||||
session_id: value.compat_session_id.into(),
|
||||
token: value.access_token,
|
||||
created_at: value.created_at,
|
||||
expires_at: value.expires_at,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, err)]
|
||||
pub async fn find_compat_access_token(
|
||||
executor: impl PgExecutor<'_>,
|
||||
token: &str,
|
||||
) -> Result<Option<CompatAccessToken>, DatabaseError> {
|
||||
let res = sqlx::query_as!(
|
||||
CompatAccessTokenLookup,
|
||||
r#"
|
||||
SELECT compat_access_token_id
|
||||
, access_token
|
||||
, created_at
|
||||
, expires_at
|
||||
, compat_session_id
|
||||
|
||||
FROM compat_access_tokens
|
||||
|
||||
WHERE access_token = $1
|
||||
"#,
|
||||
token,
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.into()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
compat_access_token.id = %id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn lookup_compat_access_token(
|
||||
executor: impl PgExecutor<'_>,
|
||||
id: Ulid,
|
||||
) -> Result<Option<CompatAccessToken>, DatabaseError> {
|
||||
let res = sqlx::query_as!(
|
||||
CompatAccessTokenLookup,
|
||||
r#"
|
||||
SELECT compat_access_token_id
|
||||
, access_token
|
||||
, created_at
|
||||
, expires_at
|
||||
, compat_session_id
|
||||
|
||||
FROM compat_access_tokens
|
||||
|
||||
WHERE compat_access_token_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.into()))
|
||||
}
|
||||
|
||||
pub struct CompatRefreshTokenLookup {
|
||||
compat_refresh_token_id: Uuid,
|
||||
refresh_token: String,
|
||||
created_at: DateTime<Utc>,
|
||||
consumed_at: Option<DateTime<Utc>>,
|
||||
compat_access_token_id: Uuid,
|
||||
compat_session_id: Uuid,
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, err)]
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub async fn find_compat_refresh_token(
|
||||
executor: impl PgExecutor<'_>,
|
||||
token: &str,
|
||||
) -> Result<Option<CompatRefreshToken>, DatabaseError> {
|
||||
let res = sqlx::query_as!(
|
||||
CompatRefreshTokenLookup,
|
||||
r#"
|
||||
SELECT compat_refresh_token_id
|
||||
, refresh_token
|
||||
, created_at
|
||||
, consumed_at
|
||||
, compat_session_id
|
||||
, compat_access_token_id
|
||||
|
||||
FROM compat_refresh_tokens
|
||||
|
||||
WHERE refresh_token = $1
|
||||
"#,
|
||||
token,
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None); };
|
||||
|
||||
let state = match res.consumed_at {
|
||||
None => CompatRefreshTokenState::Valid,
|
||||
Some(consumed_at) => CompatRefreshTokenState::Consumed { consumed_at },
|
||||
};
|
||||
|
||||
let refresh_token = CompatRefreshToken {
|
||||
id: res.compat_refresh_token_id.into(),
|
||||
state,
|
||||
session_id: res.compat_session_id.into(),
|
||||
access_token_id: res.compat_access_token_id.into(),
|
||||
token: res.refresh_token,
|
||||
created_at: res.created_at,
|
||||
};
|
||||
|
||||
Ok(Some(refresh_token))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
compat_session.id = %session.id,
|
||||
compat_session.device.id = session.device.as_str(),
|
||||
compat_access_token.id,
|
||||
user.id = %session.user_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn add_compat_access_token(
|
||||
executor: impl PgExecutor<'_>,
|
||||
mut rng: impl Rng + Send,
|
||||
clock: &Clock,
|
||||
session: &CompatSession,
|
||||
token: String,
|
||||
expires_after: Option<Duration>,
|
||||
) -> Result<CompatAccessToken, sqlx::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||
tracing::Span::current().record("compat_access_token.id", tracing::field::display(id));
|
||||
|
||||
let expires_at = expires_after.map(|expires_after| created_at + expires_after);
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO compat_access_tokens
|
||||
(compat_access_token_id, compat_session_id, access_token, created_at, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(session.id),
|
||||
token,
|
||||
created_at,
|
||||
expires_at,
|
||||
)
|
||||
.execute(executor)
|
||||
.instrument(tracing::info_span!("Insert compat access token"))
|
||||
.await?;
|
||||
|
||||
Ok(CompatAccessToken {
|
||||
id,
|
||||
session_id: session.id,
|
||||
token,
|
||||
created_at,
|
||||
expires_at,
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
compat_access_token.id = %access_token.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn expire_compat_access_token(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
access_token: CompatAccessToken,
|
||||
) -> Result<(), DatabaseError> {
|
||||
let expires_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE compat_access_tokens
|
||||
SET expires_at = $2
|
||||
WHERE compat_access_token_id = $1
|
||||
"#,
|
||||
Uuid::from(access_token.id),
|
||||
expires_at,
|
||||
)
|
||||
.execute(executor)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
compat_session.id = %session.id,
|
||||
compat_session.device.id = session.device.as_str(),
|
||||
compat_access_token.id = %access_token.id,
|
||||
compat_refresh_token.id,
|
||||
user.id = %session.user_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn add_compat_refresh_token(
|
||||
executor: impl PgExecutor<'_>,
|
||||
mut rng: impl Rng + Send,
|
||||
clock: &Clock,
|
||||
session: &CompatSession,
|
||||
access_token: &CompatAccessToken,
|
||||
token: String,
|
||||
) -> Result<CompatRefreshToken, sqlx::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||
tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO compat_refresh_tokens
|
||||
(compat_refresh_token_id, compat_session_id,
|
||||
compat_access_token_id, refresh_token, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(session.id),
|
||||
Uuid::from(access_token.id),
|
||||
token,
|
||||
created_at,
|
||||
)
|
||||
.execute(executor)
|
||||
.instrument(tracing::info_span!("Insert compat refresh token"))
|
||||
.await?;
|
||||
|
||||
Ok(CompatRefreshToken {
|
||||
id,
|
||||
state: CompatRefreshTokenState::default(),
|
||||
session_id: session.id,
|
||||
access_token_id: access_token.id,
|
||||
token,
|
||||
created_at,
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(%compat_session.id),
|
||||
err,
|
||||
)]
|
||||
pub async fn end_compat_session(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
compat_session: CompatSession,
|
||||
) -> Result<CompatSession, DatabaseError> {
|
||||
let finished_at = clock.now();
|
||||
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE compat_sessions cs
|
||||
SET finished_at = $2
|
||||
WHERE compat_session_id = $1
|
||||
"#,
|
||||
Uuid::from(compat_session.id),
|
||||
finished_at,
|
||||
)
|
||||
.execute(executor)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
let compat_session = compat_session
|
||||
.finish(finished_at)
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
|
||||
Ok(compat_session)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
compat_refresh_token.id = %refresh_token.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn consume_compat_refresh_token(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
refresh_token: CompatRefreshToken,
|
||||
) -> Result<(), DatabaseError> {
|
||||
let consumed_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE compat_refresh_tokens
|
||||
SET consumed_at = $2
|
||||
WHERE compat_refresh_token_id = $1
|
||||
"#,
|
||||
Uuid::from(refresh_token.id),
|
||||
consumed_at,
|
||||
)
|
||||
.execute(executor)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
compat_sso_login.id,
|
||||
compat_sso_login.redirect_uri = %redirect_uri,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn insert_compat_sso_login(
|
||||
executor: impl PgExecutor<'_>,
|
||||
mut rng: impl Rng + Send,
|
||||
clock: &Clock,
|
||||
login_token: String,
|
||||
redirect_uri: Url,
|
||||
) -> Result<CompatSsoLogin, sqlx::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||
tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO compat_sso_logins
|
||||
(compat_sso_login_id, login_token, redirect_uri, created_at)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
&login_token,
|
||||
redirect_uri.as_str(),
|
||||
created_at,
|
||||
)
|
||||
.execute(executor)
|
||||
.instrument(tracing::info_span!("Insert compat SSO login"))
|
||||
.await?;
|
||||
|
||||
Ok(CompatSsoLogin {
|
||||
id,
|
||||
login_token,
|
||||
redirect_uri,
|
||||
created_at,
|
||||
state: CompatSsoLoginState::Pending,
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(sqlx::FromRow)]
|
||||
struct CompatSsoLoginLookup {
|
||||
compat_sso_login_id: Uuid,
|
||||
compat_sso_login_token: String,
|
||||
compat_sso_login_redirect_uri: String,
|
||||
compat_sso_login_created_at: DateTime<Utc>,
|
||||
compat_sso_login_fulfilled_at: Option<DateTime<Utc>>,
|
||||
compat_sso_login_exchanged_at: Option<DateTime<Utc>>,
|
||||
compat_session_id: Option<Uuid>,
|
||||
}
|
||||
|
||||
impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
|
||||
type Error = DatabaseInconsistencyError;
|
||||
|
||||
fn try_from(res: CompatSsoLoginLookup) -> Result<Self, Self::Error> {
|
||||
let id = res.compat_sso_login_id.into();
|
||||
let redirect_uri = Url::parse(&res.compat_sso_login_redirect_uri).map_err(|e| {
|
||||
DatabaseInconsistencyError::on("compat_sso_logins")
|
||||
.column("redirect_uri")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let state = match (
|
||||
res.compat_sso_login_fulfilled_at,
|
||||
res.compat_sso_login_exchanged_at,
|
||||
res.compat_session_id,
|
||||
) {
|
||||
(None, None, None) => CompatSsoLoginState::Pending,
|
||||
(Some(fulfilled_at), None, Some(session_id)) => CompatSsoLoginState::Fulfilled {
|
||||
fulfilled_at,
|
||||
session_id: session_id.into(),
|
||||
},
|
||||
(Some(fulfilled_at), Some(exchanged_at), Some(session_id)) => {
|
||||
CompatSsoLoginState::Exchanged {
|
||||
fulfilled_at,
|
||||
exchanged_at,
|
||||
session_id: session_id.into(),
|
||||
}
|
||||
}
|
||||
_ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
|
||||
};
|
||||
|
||||
Ok(CompatSsoLogin {
|
||||
id,
|
||||
login_token: res.compat_sso_login_token,
|
||||
redirect_uri,
|
||||
created_at: res.compat_sso_login_created_at,
|
||||
state,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
compat_sso_login.id = %id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn get_compat_sso_login_by_id(
|
||||
executor: impl PgExecutor<'_>,
|
||||
id: Ulid,
|
||||
) -> Result<Option<CompatSsoLogin>, DatabaseError> {
|
||||
let res = sqlx::query_as!(
|
||||
CompatSsoLoginLookup,
|
||||
r#"
|
||||
SELECT cl.compat_sso_login_id
|
||||
, cl.login_token AS "compat_sso_login_token"
|
||||
, cl.redirect_uri AS "compat_sso_login_redirect_uri"
|
||||
, cl.created_at AS "compat_sso_login_created_at"
|
||||
, cl.fulfilled_at AS "compat_sso_login_fulfilled_at"
|
||||
, cl.exchanged_at AS "compat_sso_login_exchanged_at"
|
||||
, cl.compat_session_id AS "compat_session_id"
|
||||
|
||||
FROM compat_sso_logins cl
|
||||
WHERE cl.compat_sso_login_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.instrument(tracing::info_span!("Lookup compat SSO login"))
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.try_into()?))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
%user.id,
|
||||
%user.username,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn get_paginated_user_compat_sso_logins(
|
||||
executor: impl PgExecutor<'_>,
|
||||
user: &User,
|
||||
before: Option<Ulid>,
|
||||
after: Option<Ulid>,
|
||||
first: Option<usize>,
|
||||
last: Option<usize>,
|
||||
) -> Result<(bool, bool, Vec<CompatSsoLogin>), DatabaseError> {
|
||||
let mut query = QueryBuilder::new(
|
||||
r#"
|
||||
SELECT cl.compat_sso_login_id
|
||||
, cl.login_token AS "compat_sso_login_token"
|
||||
, cl.redirect_uri AS "compat_sso_login_redirect_uri"
|
||||
, cl.created_at AS "compat_sso_login_created_at"
|
||||
, cl.fulfilled_at AS "compat_sso_login_fulfilled_at"
|
||||
, cl.exchanged_at AS "compat_sso_login_exchanged_at"
|
||||
, cl.compat_session_id AS "compat_session_id"
|
||||
FROM compat_sso_logins cl
|
||||
"#,
|
||||
);
|
||||
|
||||
query
|
||||
.push(" WHERE cs.user_id = ")
|
||||
.push_bind(Uuid::from(user.id))
|
||||
.generate_pagination("cl.compat_sso_login_id", before, after, first, last)?;
|
||||
|
||||
let span = info_span!(
|
||||
"Fetch paginated user compat SSO logins",
|
||||
db.statement = query.sql()
|
||||
);
|
||||
let page: Vec<CompatSsoLoginLookup> = query
|
||||
.build_query_as()
|
||||
.fetch_all(executor)
|
||||
.instrument(span)
|
||||
.await?;
|
||||
|
||||
let (has_previous_page, has_next_page, page) = process_page(page, first, last)?;
|
||||
|
||||
let page: Result<Vec<_>, _> = page.into_iter().map(TryInto::try_into).collect();
|
||||
Ok((has_previous_page, has_next_page, page?))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, err)]
|
||||
pub async fn get_compat_sso_login_by_token(
|
||||
executor: impl PgExecutor<'_>,
|
||||
token: &str,
|
||||
) -> Result<Option<CompatSsoLogin>, DatabaseError> {
|
||||
let res = sqlx::query_as!(
|
||||
CompatSsoLoginLookup,
|
||||
r#"
|
||||
SELECT cl.compat_sso_login_id
|
||||
, cl.login_token AS "compat_sso_login_token"
|
||||
, cl.redirect_uri AS "compat_sso_login_redirect_uri"
|
||||
, cl.created_at AS "compat_sso_login_created_at"
|
||||
, cl.fulfilled_at AS "compat_sso_login_fulfilled_at"
|
||||
, cl.exchanged_at AS "compat_sso_login_exchanged_at"
|
||||
, cl.compat_session_id AS "compat_session_id"
|
||||
FROM compat_sso_logins cl
|
||||
WHERE cl.login_token = $1
|
||||
"#,
|
||||
token,
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.instrument(tracing::info_span!("Lookup compat SSO login"))
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.try_into()?))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
%user.id,
|
||||
compat_session.id,
|
||||
compat_session.device.id = device.as_str(),
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn start_compat_session(
|
||||
executor: impl PgExecutor<'_>,
|
||||
mut rng: impl Rng + Send,
|
||||
clock: &Clock,
|
||||
user: &User,
|
||||
device: Device,
|
||||
) -> Result<CompatSession, DatabaseError> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||
tracing::Span::current().record("compat_session.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(user.id),
|
||||
device.as_str(),
|
||||
created_at,
|
||||
)
|
||||
.execute(executor)
|
||||
.await?;
|
||||
|
||||
Ok(CompatSession {
|
||||
id,
|
||||
state: CompatSessionState::default(),
|
||||
user_id: user.id,
|
||||
device,
|
||||
created_at,
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
%user.id,
|
||||
%compat_sso_login.id,
|
||||
%compat_sso_login.redirect_uri,
|
||||
compat_session.id,
|
||||
compat_session.device.id = device.as_str(),
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn fullfill_compat_sso_login(
|
||||
conn: impl Acquire<'_, Database = Postgres> + Send,
|
||||
mut rng: impl Rng + Send,
|
||||
clock: &Clock,
|
||||
user: &User,
|
||||
compat_sso_login: CompatSsoLogin,
|
||||
device: Device,
|
||||
) -> Result<CompatSsoLogin, DatabaseError> {
|
||||
if !matches!(compat_sso_login.state, CompatSsoLoginState::Pending) {
|
||||
return Err(DatabaseError::invalid_operation());
|
||||
};
|
||||
|
||||
let mut txn = conn.begin().await?;
|
||||
|
||||
let session = start_compat_session(&mut txn, &mut rng, clock, user, device).await?;
|
||||
let session_id = session.id;
|
||||
|
||||
let fulfilled_at = clock.now();
|
||||
let compat_sso_login = compat_sso_login
|
||||
.fulfill(fulfilled_at, &session)
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
sqlx::query!(
|
||||
r#"
|
||||
UPDATE compat_sso_logins
|
||||
SET
|
||||
compat_session_id = $2,
|
||||
fulfilled_at = $3
|
||||
WHERE
|
||||
compat_sso_login_id = $1
|
||||
"#,
|
||||
Uuid::from(compat_sso_login.id),
|
||||
Uuid::from(session_id),
|
||||
fulfilled_at,
|
||||
)
|
||||
.execute(&mut txn)
|
||||
.instrument(tracing::info_span!("Update compat SSO login"))
|
||||
.await?;
|
||||
|
||||
txn.commit().await?;
|
||||
|
||||
Ok(compat_sso_login)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
%compat_sso_login.id,
|
||||
%compat_sso_login.redirect_uri,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn mark_compat_sso_login_as_exchanged(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
compat_sso_login: CompatSsoLogin,
|
||||
) -> Result<CompatSsoLogin, DatabaseError> {
|
||||
let exchanged_at = clock.now();
|
||||
let compat_sso_login = compat_sso_login
|
||||
.exchange(exchanged_at)
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
UPDATE compat_sso_logins
|
||||
SET
|
||||
exchanged_at = $2
|
||||
WHERE
|
||||
compat_sso_login_id = $1
|
||||
"#,
|
||||
Uuid::from(compat_sso_login.id),
|
||||
exchanged_at,
|
||||
)
|
||||
.execute(executor)
|
||||
.instrument(tracing::info_span!("Update compat SSO login"))
|
||||
.await?;
|
||||
|
||||
Ok(compat_sso_login)
|
||||
}
|
246
crates/storage/src/compat/access_token.rs
Normal file
246
crates/storage/src/compat/access_token.rs
Normal file
@ -0,0 +1,246 @@
|
||||
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use mas_data_model::{CompatAccessToken, CompatSession};
|
||||
use rand::RngCore;
|
||||
use sqlx::PgConnection;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt};
|
||||
|
||||
#[async_trait]
|
||||
pub trait CompatAccessTokenRepository: Send + Sync {
|
||||
type Error;
|
||||
|
||||
/// Lookup a compat access token by its ID
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatAccessToken>, Self::Error>;
|
||||
|
||||
/// Find a compat access token by its token
|
||||
async fn find_by_token(
|
||||
&mut self,
|
||||
access_token: &str,
|
||||
) -> Result<Option<CompatAccessToken>, Self::Error>;
|
||||
|
||||
/// Add a new compat access token to the database
|
||||
async fn add(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
compat_session: &CompatSession,
|
||||
token: String,
|
||||
expires_after: Option<Duration>,
|
||||
) -> Result<CompatAccessToken, Self::Error>;
|
||||
|
||||
/// Set the expiration time of the compat access token to now
|
||||
async fn expire(
|
||||
&mut self,
|
||||
clock: &Clock,
|
||||
compat_access_token: CompatAccessToken,
|
||||
) -> Result<CompatAccessToken, Self::Error>;
|
||||
}
|
||||
|
||||
pub struct PgCompatAccessTokenRepository<'c> {
|
||||
conn: &'c mut PgConnection,
|
||||
}
|
||||
|
||||
impl<'c> PgCompatAccessTokenRepository<'c> {
|
||||
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
Self { conn }
|
||||
}
|
||||
}
|
||||
|
||||
struct CompatAccessTokenLookup {
|
||||
compat_access_token_id: Uuid,
|
||||
access_token: String,
|
||||
created_at: DateTime<Utc>,
|
||||
expires_at: Option<DateTime<Utc>>,
|
||||
compat_session_id: Uuid,
|
||||
}
|
||||
|
||||
impl From<CompatAccessTokenLookup> for CompatAccessToken {
|
||||
fn from(value: CompatAccessTokenLookup) -> Self {
|
||||
Self {
|
||||
id: value.compat_access_token_id.into(),
|
||||
session_id: value.compat_session_id.into(),
|
||||
token: value.access_token,
|
||||
created_at: value.created_at,
|
||||
expires_at: value.expires_at,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> CompatAccessTokenRepository for PgCompatAccessTokenRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_access_token.lookup",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
compat_session.id = %id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatAccessToken>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
CompatAccessTokenLookup,
|
||||
r#"
|
||||
SELECT compat_access_token_id
|
||||
, access_token
|
||||
, created_at
|
||||
, expires_at
|
||||
, compat_session_id
|
||||
|
||||
FROM compat_access_tokens
|
||||
|
||||
WHERE compat_access_token_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.into()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_access_token.find_by_token",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn find_by_token(
|
||||
&mut self,
|
||||
access_token: &str,
|
||||
) -> Result<Option<CompatAccessToken>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
CompatAccessTokenLookup,
|
||||
r#"
|
||||
SELECT compat_access_token_id
|
||||
, access_token
|
||||
, created_at
|
||||
, expires_at
|
||||
, compat_session_id
|
||||
|
||||
FROM compat_access_tokens
|
||||
|
||||
WHERE access_token = $1
|
||||
"#,
|
||||
access_token,
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.into()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_access_token.add",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
compat_access_token.id,
|
||||
%compat_session.id,
|
||||
user.id = %compat_session.user_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn add(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
compat_session: &CompatSession,
|
||||
token: String,
|
||||
expires_after: Option<Duration>,
|
||||
) -> Result<CompatAccessToken, Self::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record("compat_access_token.id", tracing::field::display(id));
|
||||
|
||||
let expires_at = expires_after.map(|expires_after| created_at + expires_after);
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO compat_access_tokens
|
||||
(compat_access_token_id, compat_session_id, access_token, created_at, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(compat_session.id),
|
||||
token,
|
||||
created_at,
|
||||
expires_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(CompatAccessToken {
|
||||
id,
|
||||
session_id: compat_session.id,
|
||||
token,
|
||||
created_at,
|
||||
expires_at,
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_access_token.expire",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%compat_access_token.id,
|
||||
compat_session.id = %compat_access_token.session_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn expire(
|
||||
&mut self,
|
||||
clock: &Clock,
|
||||
mut compat_access_token: CompatAccessToken,
|
||||
) -> Result<CompatAccessToken, Self::Error> {
|
||||
let expires_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE compat_access_tokens
|
||||
SET expires_at = $2
|
||||
WHERE compat_access_token_id = $1
|
||||
"#,
|
||||
Uuid::from(compat_access_token.id),
|
||||
expires_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
compat_access_token.expires_at = Some(expires_at);
|
||||
Ok(compat_access_token)
|
||||
}
|
||||
}
|
25
crates/storage/src/compat/mod.rs
Normal file
25
crates/storage/src/compat/mod.rs
Normal file
@ -0,0 +1,25 @@
|
||||
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
mod access_token;
|
||||
mod refresh_token;
|
||||
mod session;
|
||||
mod sso_login;
|
||||
|
||||
pub use self::{
|
||||
access_token::{CompatAccessTokenRepository, PgCompatAccessTokenRepository},
|
||||
refresh_token::{CompatRefreshTokenRepository, PgCompatRefreshTokenRepository},
|
||||
session::{CompatSessionRepository, PgCompatSessionRepository},
|
||||
sso_login::{CompatSsoLoginRepository, PgCompatSsoLoginRepository},
|
||||
};
|
260
crates/storage/src/compat/refresh_token.rs
Normal file
260
crates/storage/src/compat/refresh_token.rs
Normal file
@ -0,0 +1,260 @@
|
||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{
|
||||
CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession,
|
||||
};
|
||||
use rand::RngCore;
|
||||
use sqlx::PgConnection;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt};
|
||||
|
||||
#[async_trait]
|
||||
pub trait CompatRefreshTokenRepository: Send + Sync {
|
||||
type Error;
|
||||
|
||||
/// Lookup a compat refresh token by its ID
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatRefreshToken>, Self::Error>;
|
||||
|
||||
/// Find a compat refresh token by its token
|
||||
async fn find_by_token(
|
||||
&mut self,
|
||||
refresh_token: &str,
|
||||
) -> Result<Option<CompatRefreshToken>, Self::Error>;
|
||||
|
||||
/// Add a new compat refresh token to the database
|
||||
async fn add(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
compat_session: &CompatSession,
|
||||
compat_access_token: &CompatAccessToken,
|
||||
token: String,
|
||||
) -> Result<CompatRefreshToken, Self::Error>;
|
||||
|
||||
/// Consume a compat refresh token
|
||||
async fn consume(
|
||||
&mut self,
|
||||
clock: &Clock,
|
||||
compat_refresh_token: CompatRefreshToken,
|
||||
) -> Result<CompatRefreshToken, Self::Error>;
|
||||
}
|
||||
|
||||
pub struct PgCompatRefreshTokenRepository<'c> {
|
||||
conn: &'c mut PgConnection,
|
||||
}
|
||||
|
||||
impl<'c> PgCompatRefreshTokenRepository<'c> {
|
||||
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
Self { conn }
|
||||
}
|
||||
}
|
||||
|
||||
struct CompatRefreshTokenLookup {
|
||||
compat_refresh_token_id: Uuid,
|
||||
refresh_token: String,
|
||||
created_at: DateTime<Utc>,
|
||||
consumed_at: Option<DateTime<Utc>>,
|
||||
compat_access_token_id: Uuid,
|
||||
compat_session_id: Uuid,
|
||||
}
|
||||
|
||||
impl From<CompatRefreshTokenLookup> for CompatRefreshToken {
|
||||
fn from(value: CompatRefreshTokenLookup) -> Self {
|
||||
let state = match value.consumed_at {
|
||||
Some(consumed_at) => CompatRefreshTokenState::Consumed { consumed_at },
|
||||
None => CompatRefreshTokenState::Valid,
|
||||
};
|
||||
|
||||
Self {
|
||||
id: value.compat_refresh_token_id.into(),
|
||||
state,
|
||||
session_id: value.compat_session_id.into(),
|
||||
token: value.refresh_token,
|
||||
created_at: value.created_at,
|
||||
access_token_id: value.compat_access_token_id.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> CompatRefreshTokenRepository for PgCompatRefreshTokenRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_refresh_token.lookup",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
compat_refresh_token.id = %id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatRefreshToken>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
CompatRefreshTokenLookup,
|
||||
r#"
|
||||
SELECT compat_refresh_token_id
|
||||
, refresh_token
|
||||
, created_at
|
||||
, consumed_at
|
||||
, compat_session_id
|
||||
, compat_access_token_id
|
||||
|
||||
FROM compat_refresh_tokens
|
||||
|
||||
WHERE compat_refresh_token_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.into()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_refresh_token.find_by_token",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn find_by_token(
|
||||
&mut self,
|
||||
refresh_token: &str,
|
||||
) -> Result<Option<CompatRefreshToken>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
CompatRefreshTokenLookup,
|
||||
r#"
|
||||
SELECT compat_refresh_token_id
|
||||
, refresh_token
|
||||
, created_at
|
||||
, consumed_at
|
||||
, compat_session_id
|
||||
, compat_access_token_id
|
||||
|
||||
FROM compat_refresh_tokens
|
||||
|
||||
WHERE refresh_token = $1
|
||||
"#,
|
||||
refresh_token,
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.into()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_refresh_token.add",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
compat_refresh_token.id,
|
||||
%compat_session.id,
|
||||
user.id = %compat_session.user_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn add(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
compat_session: &CompatSession,
|
||||
compat_access_token: &CompatAccessToken,
|
||||
token: String,
|
||||
) -> Result<CompatRefreshToken, Self::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO compat_refresh_tokens
|
||||
(compat_refresh_token_id, compat_session_id,
|
||||
compat_access_token_id, refresh_token, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(compat_session.id),
|
||||
Uuid::from(compat_access_token.id),
|
||||
token,
|
||||
created_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(CompatRefreshToken {
|
||||
id,
|
||||
state: CompatRefreshTokenState::default(),
|
||||
session_id: compat_session.id,
|
||||
access_token_id: compat_access_token.id,
|
||||
token,
|
||||
created_at,
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_refresh_token.consume",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%compat_refresh_token.id,
|
||||
compat_session.id = %compat_refresh_token.session_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn consume(
|
||||
&mut self,
|
||||
clock: &Clock,
|
||||
compat_refresh_token: CompatRefreshToken,
|
||||
) -> Result<CompatRefreshToken, Self::Error> {
|
||||
let consumed_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE compat_refresh_tokens
|
||||
SET consumed_at = $2
|
||||
WHERE compat_refresh_token_id = $1
|
||||
"#,
|
||||
Uuid::from(compat_refresh_token.id),
|
||||
consumed_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
let compat_refresh_token = compat_refresh_token
|
||||
.consume(consumed_at)
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
|
||||
Ok(compat_refresh_token)
|
||||
}
|
||||
}
|
220
crates/storage/src/compat/session.rs
Normal file
220
crates/storage/src/compat/session.rs
Normal file
@ -0,0 +1,220 @@
|
||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{CompatSession, CompatSessionState, Device, User};
|
||||
use rand::RngCore;
|
||||
use sqlx::PgConnection;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
|
||||
};
|
||||
|
||||
#[async_trait]
|
||||
pub trait CompatSessionRepository: Send + Sync {
|
||||
type Error;
|
||||
|
||||
/// Lookup a compat session by its ID
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSession>, Self::Error>;
|
||||
|
||||
/// Start a new compat session
|
||||
async fn add(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
user: &User,
|
||||
device: Device,
|
||||
) -> Result<CompatSession, Self::Error>;
|
||||
|
||||
/// End a compat session
|
||||
async fn finish(
|
||||
&mut self,
|
||||
clock: &Clock,
|
||||
compat_session: CompatSession,
|
||||
) -> Result<CompatSession, Self::Error>;
|
||||
}
|
||||
|
||||
pub struct PgCompatSessionRepository<'c> {
|
||||
conn: &'c mut PgConnection,
|
||||
}
|
||||
|
||||
impl<'c> PgCompatSessionRepository<'c> {
|
||||
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
Self { conn }
|
||||
}
|
||||
}
|
||||
|
||||
struct CompatSessionLookup {
|
||||
compat_session_id: Uuid,
|
||||
device_id: String,
|
||||
user_id: Uuid,
|
||||
created_at: DateTime<Utc>,
|
||||
finished_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
impl TryFrom<CompatSessionLookup> for CompatSession {
|
||||
type Error = DatabaseInconsistencyError;
|
||||
|
||||
fn try_from(value: CompatSessionLookup) -> Result<Self, Self::Error> {
|
||||
let id = value.compat_session_id.into();
|
||||
let device = Device::try_from(value.device_id).map_err(|e| {
|
||||
DatabaseInconsistencyError::on("compat_sessions")
|
||||
.column("device_id")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let state = match value.finished_at {
|
||||
None => CompatSessionState::Valid,
|
||||
Some(finished_at) => CompatSessionState::Finished { finished_at },
|
||||
};
|
||||
|
||||
let session = CompatSession {
|
||||
id,
|
||||
state,
|
||||
user_id: value.user_id.into(),
|
||||
device,
|
||||
created_at: value.created_at,
|
||||
};
|
||||
|
||||
Ok(session)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_session.lookup",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
compat_session.id = %id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSession>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
CompatSessionLookup,
|
||||
r#"
|
||||
SELECT compat_session_id
|
||||
, device_id
|
||||
, user_id
|
||||
, created_at
|
||||
, finished_at
|
||||
FROM compat_sessions
|
||||
WHERE compat_session_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.try_into()?))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_session.add",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
compat_session.id,
|
||||
%user.id,
|
||||
%user.username,
|
||||
compat_session.device.id = device.as_str(),
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn add(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
user: &User,
|
||||
device: Device,
|
||||
) -> Result<CompatSession, Self::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record("compat_session.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(user.id),
|
||||
device.as_str(),
|
||||
created_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(CompatSession {
|
||||
id,
|
||||
state: CompatSessionState::default(),
|
||||
user_id: user.id,
|
||||
device,
|
||||
created_at,
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_session.finish",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%compat_session.id,
|
||||
user.id = %compat_session.user_id,
|
||||
compat_session.device.id = compat_session.device.as_str(),
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn finish(
|
||||
&mut self,
|
||||
clock: &Clock,
|
||||
compat_session: CompatSession,
|
||||
) -> Result<CompatSession, Self::Error> {
|
||||
let finished_at = clock.now();
|
||||
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE compat_sessions cs
|
||||
SET finished_at = $2
|
||||
WHERE compat_session_id = $1
|
||||
"#,
|
||||
Uuid::from(compat_session.id),
|
||||
finished_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
let compat_session = compat_session
|
||||
.finish(finished_at)
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
|
||||
Ok(compat_session)
|
||||
}
|
||||
}
|
397
crates/storage/src/compat/sso_login.rs
Normal file
397
crates/storage/src/compat/sso_login.rs
Normal file
@ -0,0 +1,397 @@
|
||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{CompatSession, CompatSsoLogin, CompatSsoLoginState, User};
|
||||
use rand::RngCore;
|
||||
use sqlx::{PgConnection, QueryBuilder};
|
||||
use ulid::Ulid;
|
||||
use url::Url;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
pagination::{process_page, Page, QueryBuilderExt},
|
||||
tracing::ExecuteExt,
|
||||
Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
|
||||
};
|
||||
|
||||
#[async_trait]
|
||||
pub trait CompatSsoLoginRepository: Send + Sync {
|
||||
type Error;
|
||||
|
||||
/// Lookup a compat SSO login by its ID
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSsoLogin>, Self::Error>;
|
||||
|
||||
/// Find a compat SSO login by its login token
|
||||
async fn find_by_token(
|
||||
&mut self,
|
||||
login_token: &str,
|
||||
) -> Result<Option<CompatSsoLogin>, Self::Error>;
|
||||
|
||||
/// Start a new compat SSO login token
|
||||
async fn add(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
login_token: String,
|
||||
redirect_uri: Url,
|
||||
) -> Result<CompatSsoLogin, Self::Error>;
|
||||
|
||||
/// Fulfill a compat SSO login by providing a compat session
|
||||
async fn fulfill(
|
||||
&mut self,
|
||||
clock: &Clock,
|
||||
compat_sso_login: CompatSsoLogin,
|
||||
compat_session: &CompatSession,
|
||||
) -> Result<CompatSsoLogin, Self::Error>;
|
||||
|
||||
/// Mark a compat SSO login as exchanged
|
||||
async fn exchange(
|
||||
&mut self,
|
||||
clock: &Clock,
|
||||
compat_sso_login: CompatSsoLogin,
|
||||
) -> Result<CompatSsoLogin, Self::Error>;
|
||||
|
||||
/// Get a paginated list of compat SSO logins for a user
|
||||
async fn list_paginated(
|
||||
&mut self,
|
||||
user: &User,
|
||||
before: Option<Ulid>,
|
||||
after: Option<Ulid>,
|
||||
first: Option<usize>,
|
||||
last: Option<usize>,
|
||||
) -> Result<Page<CompatSsoLogin>, Self::Error>;
|
||||
}
|
||||
|
||||
pub struct PgCompatSsoLoginRepository<'c> {
|
||||
conn: &'c mut PgConnection,
|
||||
}
|
||||
|
||||
impl<'c> PgCompatSsoLoginRepository<'c> {
|
||||
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
Self { conn }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(sqlx::FromRow)]
|
||||
struct CompatSsoLoginLookup {
|
||||
compat_sso_login_id: Uuid,
|
||||
login_token: String,
|
||||
redirect_uri: String,
|
||||
created_at: DateTime<Utc>,
|
||||
fulfilled_at: Option<DateTime<Utc>>,
|
||||
exchanged_at: Option<DateTime<Utc>>,
|
||||
compat_session_id: Option<Uuid>,
|
||||
}
|
||||
|
||||
impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
|
||||
type Error = DatabaseInconsistencyError;
|
||||
|
||||
fn try_from(res: CompatSsoLoginLookup) -> Result<Self, Self::Error> {
|
||||
let id = res.compat_sso_login_id.into();
|
||||
let redirect_uri = Url::parse(&res.redirect_uri).map_err(|e| {
|
||||
DatabaseInconsistencyError::on("compat_sso_logins")
|
||||
.column("redirect_uri")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let state = match (res.fulfilled_at, res.exchanged_at, res.compat_session_id) {
|
||||
(None, None, None) => CompatSsoLoginState::Pending,
|
||||
(Some(fulfilled_at), None, Some(session_id)) => CompatSsoLoginState::Fulfilled {
|
||||
fulfilled_at,
|
||||
session_id: session_id.into(),
|
||||
},
|
||||
(Some(fulfilled_at), Some(exchanged_at), Some(session_id)) => {
|
||||
CompatSsoLoginState::Exchanged {
|
||||
fulfilled_at,
|
||||
exchanged_at,
|
||||
session_id: session_id.into(),
|
||||
}
|
||||
}
|
||||
_ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
|
||||
};
|
||||
|
||||
Ok(CompatSsoLogin {
|
||||
id,
|
||||
login_token: res.login_token,
|
||||
redirect_uri,
|
||||
created_at: res.created_at,
|
||||
state,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> CompatSsoLoginRepository for PgCompatSsoLoginRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_sso_login.lookup",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
compat_sso_login.id = %id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSsoLogin>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
CompatSsoLoginLookup,
|
||||
r#"
|
||||
SELECT compat_sso_login_id
|
||||
, login_token
|
||||
, redirect_uri
|
||||
, created_at
|
||||
, fulfilled_at
|
||||
, exchanged_at
|
||||
, compat_session_id
|
||||
|
||||
FROM compat_sso_logins
|
||||
WHERE compat_sso_login_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.try_into()?))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_sso_login.find_by_token",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn find_by_token(
|
||||
&mut self,
|
||||
login_token: &str,
|
||||
) -> Result<Option<CompatSsoLogin>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
CompatSsoLoginLookup,
|
||||
r#"
|
||||
SELECT compat_sso_login_id
|
||||
, login_token
|
||||
, redirect_uri
|
||||
, created_at
|
||||
, fulfilled_at
|
||||
, exchanged_at
|
||||
, compat_session_id
|
||||
|
||||
FROM compat_sso_logins
|
||||
WHERE login_token = $1
|
||||
"#,
|
||||
login_token,
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.try_into()?))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_sso_login.add",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
compat_sso_login.id,
|
||||
compat_sso_login.redirect_uri = %redirect_uri,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn add(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
login_token: String,
|
||||
redirect_uri: Url,
|
||||
) -> Result<CompatSsoLogin, Self::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO compat_sso_logins
|
||||
(compat_sso_login_id, login_token, redirect_uri, created_at)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
&login_token,
|
||||
redirect_uri.as_str(),
|
||||
created_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(CompatSsoLogin {
|
||||
id,
|
||||
login_token,
|
||||
redirect_uri,
|
||||
created_at,
|
||||
state: CompatSsoLoginState::default(),
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_sso_login.fulfill",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%compat_sso_login.id,
|
||||
%compat_session.id,
|
||||
compat_session.device.id = compat_session.device.as_str(),
|
||||
user.id = %compat_session.user_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn fulfill(
|
||||
&mut self,
|
||||
clock: &Clock,
|
||||
compat_sso_login: CompatSsoLogin,
|
||||
compat_session: &CompatSession,
|
||||
) -> Result<CompatSsoLogin, Self::Error> {
|
||||
let fulfilled_at = clock.now();
|
||||
let compat_sso_login = compat_sso_login
|
||||
.fulfill(fulfilled_at, compat_session)
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE compat_sso_logins
|
||||
SET
|
||||
compat_session_id = $2,
|
||||
fulfilled_at = $3
|
||||
WHERE
|
||||
compat_sso_login_id = $1
|
||||
"#,
|
||||
Uuid::from(compat_sso_login.id),
|
||||
Uuid::from(compat_session.id),
|
||||
fulfilled_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
Ok(compat_sso_login)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_sso_login.exchange",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%compat_sso_login.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn exchange(
|
||||
&mut self,
|
||||
clock: &Clock,
|
||||
compat_sso_login: CompatSsoLogin,
|
||||
) -> Result<CompatSsoLogin, Self::Error> {
|
||||
let exchanged_at = clock.now();
|
||||
let compat_sso_login = compat_sso_login
|
||||
.exchange(exchanged_at)
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE compat_sso_logins
|
||||
SET
|
||||
exchanged_at = $2
|
||||
WHERE
|
||||
compat_sso_login_id = $1
|
||||
"#,
|
||||
Uuid::from(compat_sso_login.id),
|
||||
exchanged_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
Ok(compat_sso_login)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.compat_sso_login.list_paginated",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user.id,
|
||||
%user.username,
|
||||
),
|
||||
err
|
||||
)]
|
||||
async fn list_paginated(
|
||||
&mut self,
|
||||
user: &User,
|
||||
before: Option<Ulid>,
|
||||
after: Option<Ulid>,
|
||||
first: Option<usize>,
|
||||
last: Option<usize>,
|
||||
) -> Result<Page<CompatSsoLogin>, Self::Error> {
|
||||
let mut query = QueryBuilder::new(
|
||||
r#"
|
||||
SELECT cl.compat_sso_login_id
|
||||
, cl.login_token
|
||||
, cl.redirect_uri
|
||||
, cl.created_at
|
||||
, cl.fulfilled_at
|
||||
, cl.exchanged_at
|
||||
, cl.compat_session_id
|
||||
|
||||
FROM compat_sso_logins cl
|
||||
INNER JOIN compat_sessions ON compat_session_id
|
||||
"#,
|
||||
);
|
||||
|
||||
query
|
||||
.push(" WHERE user_id = ")
|
||||
.push_bind(Uuid::from(user.id))
|
||||
.generate_pagination("cl.compat_sso_login_id", before, after, first, last)?;
|
||||
|
||||
let page: Vec<CompatSsoLoginLookup> = query
|
||||
.build_query_as()
|
||||
.traced()
|
||||
.fetch_all(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
let (has_previous_page, has_next_page, edges) = process_page(page, first, last)?;
|
||||
|
||||
let edges: Result<Vec<_>, _> = edges.into_iter().map(TryInto::try_into).collect();
|
||||
Ok(Page {
|
||||
has_next_page,
|
||||
has_previous_page,
|
||||
edges: edges?,
|
||||
})
|
||||
}
|
||||
}
|
@ -15,6 +15,10 @@
|
||||
use sqlx::{PgConnection, Postgres, Transaction};
|
||||
|
||||
use crate::{
|
||||
compat::{
|
||||
PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository, PgCompatSessionRepository,
|
||||
PgCompatSsoLoginRepository,
|
||||
},
|
||||
oauth2::{PgOAuth2ClientRepository, PgOAuth2SessionRepository},
|
||||
upstream_oauth2::{
|
||||
PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository,
|
||||
@ -63,6 +67,22 @@ pub trait Repository {
|
||||
where
|
||||
Self: 'c;
|
||||
|
||||
type CompatSessionRepository<'c>
|
||||
where
|
||||
Self: 'c;
|
||||
|
||||
type CompatSsoLoginRepository<'c>
|
||||
where
|
||||
Self: 'c;
|
||||
|
||||
type CompatAccessTokenRepository<'c>
|
||||
where
|
||||
Self: 'c;
|
||||
|
||||
type CompatRefreshTokenRepository<'c>
|
||||
where
|
||||
Self: 'c;
|
||||
|
||||
fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_>;
|
||||
fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_>;
|
||||
fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_>;
|
||||
@ -72,6 +92,10 @@ pub trait Repository {
|
||||
fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_>;
|
||||
fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_>;
|
||||
fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_>;
|
||||
fn compat_session(&mut self) -> Self::CompatSessionRepository<'_>;
|
||||
fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_>;
|
||||
fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_>;
|
||||
fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_>;
|
||||
}
|
||||
|
||||
impl Repository for PgConnection {
|
||||
@ -84,6 +108,10 @@ impl Repository for PgConnection {
|
||||
type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c;
|
||||
type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c;
|
||||
type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c;
|
||||
type CompatSessionRepository<'c> = PgCompatSessionRepository<'c> where Self: 'c;
|
||||
type CompatSsoLoginRepository<'c> = PgCompatSsoLoginRepository<'c> where Self: 'c;
|
||||
type CompatAccessTokenRepository<'c> = PgCompatAccessTokenRepository<'c> where Self: 'c;
|
||||
type CompatRefreshTokenRepository<'c> = PgCompatRefreshTokenRepository<'c> where Self: 'c;
|
||||
|
||||
fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> {
|
||||
PgUpstreamOAuthLinkRepository::new(self)
|
||||
@ -120,6 +148,22 @@ impl Repository for PgConnection {
|
||||
fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> {
|
||||
PgOAuth2SessionRepository::new(self)
|
||||
}
|
||||
|
||||
fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> {
|
||||
PgCompatSessionRepository::new(self)
|
||||
}
|
||||
|
||||
fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_> {
|
||||
PgCompatSsoLoginRepository::new(self)
|
||||
}
|
||||
|
||||
fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_> {
|
||||
PgCompatAccessTokenRepository::new(self)
|
||||
}
|
||||
|
||||
fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_> {
|
||||
PgCompatRefreshTokenRepository::new(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'t> Repository for Transaction<'t, Postgres> {
|
||||
@ -132,6 +176,10 @@ impl<'t> Repository for Transaction<'t, Postgres> {
|
||||
type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c;
|
||||
type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c;
|
||||
type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c;
|
||||
type CompatSessionRepository<'c> = PgCompatSessionRepository<'c> where Self: 'c;
|
||||
type CompatSsoLoginRepository<'c> = PgCompatSsoLoginRepository<'c> where Self: 'c;
|
||||
type CompatAccessTokenRepository<'c> = PgCompatAccessTokenRepository<'c> where Self: 'c;
|
||||
type CompatRefreshTokenRepository<'c> = PgCompatRefreshTokenRepository<'c> where Self: 'c;
|
||||
|
||||
fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> {
|
||||
PgUpstreamOAuthLinkRepository::new(self)
|
||||
@ -168,4 +216,20 @@ impl<'t> Repository for Transaction<'t, Postgres> {
|
||||
fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> {
|
||||
PgOAuth2SessionRepository::new(self)
|
||||
}
|
||||
|
||||
fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> {
|
||||
PgCompatSessionRepository::new(self)
|
||||
}
|
||||
|
||||
fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_> {
|
||||
PgCompatSsoLoginRepository::new(self)
|
||||
}
|
||||
|
||||
fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_> {
|
||||
PgCompatAccessTokenRepository::new(self)
|
||||
}
|
||||
|
||||
fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_> {
|
||||
PgCompatRefreshTokenRepository::new(self)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user