You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-08-09 04:22:45 +03:00
Storage layer for a unified session list
This commit is contained in:
629
crates/storage-pg/src/app_session.rs
Normal file
629
crates/storage-pg/src/app_session.rs
Normal file
@@ -0,0 +1,629 @@
|
|||||||
|
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
//! A module containing PostgreSQL implementation of repositories for sessions
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use mas_data_model::{CompatSession, CompatSessionState, Device, Session, SessionState};
|
||||||
|
use mas_storage::{
|
||||||
|
app_session::{AppSession, AppSessionFilter, AppSessionRepository},
|
||||||
|
Page, Pagination,
|
||||||
|
};
|
||||||
|
use oauth2_types::scope::{Scope, ScopeToken};
|
||||||
|
use sea_query::{
|
||||||
|
Alias, ColumnRef, CommonTableExpression, Expr, PostgresQueryBuilder, Query, UnionType,
|
||||||
|
};
|
||||||
|
use sea_query_binder::SqlxBinder;
|
||||||
|
use sqlx::PgConnection;
|
||||||
|
use ulid::Ulid;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
errors::DatabaseInconsistencyError,
|
||||||
|
iden::{CompatSessions, OAuth2Sessions},
|
||||||
|
pagination::QueryBuilderExt,
|
||||||
|
DatabaseError, ExecuteExt,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// An implementation of [`AppSessionRepository`] for a PostgreSQL connection
|
||||||
|
pub struct PgAppSessionRepository<'c> {
|
||||||
|
conn: &'c mut PgConnection,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'c> PgAppSessionRepository<'c> {
|
||||||
|
/// Create a new [`PgAppSessionRepository`] from an active PostgreSQL
|
||||||
|
/// connection
|
||||||
|
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||||
|
Self { conn }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mod priv_ {
|
||||||
|
// The enum_def macro generates a public enum, which we don't want, because it
|
||||||
|
// triggers the missing docs warning
|
||||||
|
|
||||||
|
use std::net::IpAddr;
|
||||||
|
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use sea_query::enum_def;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
#[derive(sqlx::FromRow)]
|
||||||
|
#[enum_def]
|
||||||
|
pub(super) struct AppSessionLookup {
|
||||||
|
pub(super) cursor: Uuid,
|
||||||
|
pub(super) compat_session_id: Option<Uuid>,
|
||||||
|
pub(super) oauth2_session_id: Option<Uuid>,
|
||||||
|
pub(super) oauth2_client_id: Option<Uuid>,
|
||||||
|
pub(super) user_session_id: Option<Uuid>,
|
||||||
|
pub(super) user_id: Option<Uuid>,
|
||||||
|
pub(super) scope_list: Option<Vec<String>>,
|
||||||
|
pub(super) device_id: Option<String>,
|
||||||
|
pub(super) created_at: DateTime<Utc>,
|
||||||
|
pub(super) finished_at: Option<DateTime<Utc>>,
|
||||||
|
pub(super) is_synapse_admin: Option<bool>,
|
||||||
|
pub(super) last_active_at: Option<DateTime<Utc>>,
|
||||||
|
pub(super) last_active_ip: Option<IpAddr>,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
use priv_::{AppSessionLookup, AppSessionLookupIden};
|
||||||
|
|
||||||
|
impl TryFrom<AppSessionLookup> for AppSession {
|
||||||
|
type Error = DatabaseError;
|
||||||
|
|
||||||
|
fn try_from(value: AppSessionLookup) -> Result<Self, Self::Error> {
|
||||||
|
// This is annoying to do, but we have to match on all the fields to determine
|
||||||
|
// whether it's a compat session or an oauth2 session
|
||||||
|
let AppSessionLookup {
|
||||||
|
cursor,
|
||||||
|
compat_session_id,
|
||||||
|
oauth2_session_id,
|
||||||
|
oauth2_client_id,
|
||||||
|
user_session_id,
|
||||||
|
user_id,
|
||||||
|
scope_list,
|
||||||
|
device_id,
|
||||||
|
created_at,
|
||||||
|
finished_at,
|
||||||
|
is_synapse_admin,
|
||||||
|
last_active_at,
|
||||||
|
last_active_ip,
|
||||||
|
} = value;
|
||||||
|
|
||||||
|
match (
|
||||||
|
compat_session_id,
|
||||||
|
oauth2_session_id,
|
||||||
|
oauth2_client_id,
|
||||||
|
user_session_id,
|
||||||
|
user_id,
|
||||||
|
scope_list,
|
||||||
|
device_id,
|
||||||
|
is_synapse_admin,
|
||||||
|
) {
|
||||||
|
(
|
||||||
|
Some(compat_session_id),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
Some(user_id),
|
||||||
|
None,
|
||||||
|
Some(device_id),
|
||||||
|
Some(is_synapse_admin),
|
||||||
|
) => {
|
||||||
|
let id = compat_session_id.into();
|
||||||
|
let device = Device::try_from(device_id).map_err(|e| {
|
||||||
|
DatabaseInconsistencyError::on("compat_sessions")
|
||||||
|
.column("device_id")
|
||||||
|
.row(id)
|
||||||
|
.source(e)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let state = match finished_at {
|
||||||
|
None => CompatSessionState::Valid,
|
||||||
|
Some(finished_at) => CompatSessionState::Finished { finished_at },
|
||||||
|
};
|
||||||
|
|
||||||
|
let session = CompatSession {
|
||||||
|
id,
|
||||||
|
state,
|
||||||
|
user_id: user_id.into(),
|
||||||
|
device,
|
||||||
|
created_at,
|
||||||
|
is_synapse_admin,
|
||||||
|
last_active_at,
|
||||||
|
last_active_ip,
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(AppSession::Compat(Box::new(session)))
|
||||||
|
}
|
||||||
|
|
||||||
|
(
|
||||||
|
None,
|
||||||
|
Some(oauth2_session_id),
|
||||||
|
Some(oauth2_client_id),
|
||||||
|
user_session_id,
|
||||||
|
user_id,
|
||||||
|
Some(scope_list),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
) => {
|
||||||
|
let id = oauth2_session_id.into();
|
||||||
|
let scope: Result<Scope, _> =
|
||||||
|
scope_list.iter().map(|s| s.parse::<ScopeToken>()).collect();
|
||||||
|
let scope = scope.map_err(|e| {
|
||||||
|
DatabaseInconsistencyError::on("oauth2_sessions")
|
||||||
|
.column("scope")
|
||||||
|
.row(id)
|
||||||
|
.source(e)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let state = match value.finished_at {
|
||||||
|
None => SessionState::Valid,
|
||||||
|
Some(finished_at) => SessionState::Finished { finished_at },
|
||||||
|
};
|
||||||
|
|
||||||
|
let session = Session {
|
||||||
|
id,
|
||||||
|
state,
|
||||||
|
created_at,
|
||||||
|
client_id: oauth2_client_id.into(),
|
||||||
|
user_id: user_id.map(Ulid::from),
|
||||||
|
user_session_id: user_session_id.map(Ulid::from),
|
||||||
|
scope,
|
||||||
|
last_active_at,
|
||||||
|
last_active_ip,
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(AppSession::OAuth2(Box::new(session)))
|
||||||
|
}
|
||||||
|
|
||||||
|
_ => Err(DatabaseInconsistencyError::on("sessions")
|
||||||
|
.row(cursor.into())
|
||||||
|
.into()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<'c> AppSessionRepository for PgAppSessionRepository<'c> {
|
||||||
|
type Error = DatabaseError;
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_lines)]
|
||||||
|
async fn list(
|
||||||
|
&mut self,
|
||||||
|
filter: AppSessionFilter<'_>,
|
||||||
|
pagination: Pagination,
|
||||||
|
) -> Result<Page<AppSession>, Self::Error> {
|
||||||
|
let mut oauth2_session_select = Query::select()
|
||||||
|
.expr_as(
|
||||||
|
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)),
|
||||||
|
AppSessionLookupIden::Cursor,
|
||||||
|
)
|
||||||
|
.expr_as(Expr::cust("NULL"), AppSessionLookupIden::CompatSessionId)
|
||||||
|
.expr_as(
|
||||||
|
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)),
|
||||||
|
AppSessionLookupIden::Oauth2SessionId,
|
||||||
|
)
|
||||||
|
.expr_as(
|
||||||
|
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId)),
|
||||||
|
AppSessionLookupIden::Oauth2ClientId,
|
||||||
|
)
|
||||||
|
.expr_as(
|
||||||
|
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)),
|
||||||
|
AppSessionLookupIden::UserSessionId,
|
||||||
|
)
|
||||||
|
.expr_as(
|
||||||
|
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)),
|
||||||
|
AppSessionLookupIden::UserId,
|
||||||
|
)
|
||||||
|
.expr_as(
|
||||||
|
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)),
|
||||||
|
AppSessionLookupIden::ScopeList,
|
||||||
|
)
|
||||||
|
.expr_as(Expr::cust("NULL"), AppSessionLookupIden::DeviceId)
|
||||||
|
.expr_as(
|
||||||
|
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::CreatedAt)),
|
||||||
|
AppSessionLookupIden::CreatedAt,
|
||||||
|
)
|
||||||
|
.expr_as(
|
||||||
|
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)),
|
||||||
|
AppSessionLookupIden::FinishedAt,
|
||||||
|
)
|
||||||
|
.expr_as(Expr::cust("NULL"), AppSessionLookupIden::IsSynapseAdmin)
|
||||||
|
.expr_as(
|
||||||
|
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt)),
|
||||||
|
AppSessionLookupIden::LastActiveAt,
|
||||||
|
)
|
||||||
|
.expr_as(
|
||||||
|
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveIp)),
|
||||||
|
AppSessionLookupIden::LastActiveIp,
|
||||||
|
)
|
||||||
|
.from(OAuth2Sessions::Table)
|
||||||
|
.and_where_option(filter.user().map(|user| {
|
||||||
|
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id))
|
||||||
|
}))
|
||||||
|
.and_where_option(filter.state().map(|state| {
|
||||||
|
if state.is_active() {
|
||||||
|
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
|
||||||
|
} else {
|
||||||
|
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
.clone();
|
||||||
|
|
||||||
|
let compat_session_select = Query::select()
|
||||||
|
.expr_as(
|
||||||
|
Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)),
|
||||||
|
AppSessionLookupIden::Cursor,
|
||||||
|
)
|
||||||
|
.expr_as(
|
||||||
|
Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)),
|
||||||
|
AppSessionLookupIden::CompatSessionId,
|
||||||
|
)
|
||||||
|
.expr_as(Expr::cust("NULL"), AppSessionLookupIden::Oauth2SessionId)
|
||||||
|
.expr_as(Expr::cust("NULL"), AppSessionLookupIden::Oauth2ClientId)
|
||||||
|
.expr_as(Expr::cust("NULL"), AppSessionLookupIden::UserSessionId)
|
||||||
|
.expr_as(
|
||||||
|
Expr::col((CompatSessions::Table, CompatSessions::UserId)),
|
||||||
|
AppSessionLookupIden::UserId,
|
||||||
|
)
|
||||||
|
.expr_as(Expr::cust("NULL"), AppSessionLookupIden::ScopeList)
|
||||||
|
.expr_as(
|
||||||
|
Expr::col((CompatSessions::Table, CompatSessions::DeviceId)),
|
||||||
|
AppSessionLookupIden::DeviceId,
|
||||||
|
)
|
||||||
|
.expr_as(
|
||||||
|
Expr::col((CompatSessions::Table, CompatSessions::CreatedAt)),
|
||||||
|
AppSessionLookupIden::CreatedAt,
|
||||||
|
)
|
||||||
|
.expr_as(
|
||||||
|
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)),
|
||||||
|
AppSessionLookupIden::FinishedAt,
|
||||||
|
)
|
||||||
|
.expr_as(
|
||||||
|
Expr::col((CompatSessions::Table, CompatSessions::IsSynapseAdmin)),
|
||||||
|
AppSessionLookupIden::IsSynapseAdmin,
|
||||||
|
)
|
||||||
|
.expr_as(
|
||||||
|
Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt)),
|
||||||
|
AppSessionLookupIden::LastActiveAt,
|
||||||
|
)
|
||||||
|
.expr_as(
|
||||||
|
Expr::col((CompatSessions::Table, CompatSessions::LastActiveIp)),
|
||||||
|
AppSessionLookupIden::LastActiveIp,
|
||||||
|
)
|
||||||
|
.from(CompatSessions::Table)
|
||||||
|
.and_where_option(filter.user().map(|user| {
|
||||||
|
Expr::col((CompatSessions::Table, CompatSessions::UserId)).eq(Uuid::from(user.id))
|
||||||
|
}))
|
||||||
|
.and_where_option(filter.state().map(|state| {
|
||||||
|
if state.is_active() {
|
||||||
|
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null()
|
||||||
|
} else {
|
||||||
|
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_not_null()
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
.clone();
|
||||||
|
|
||||||
|
let common_table_expression = CommonTableExpression::new()
|
||||||
|
.query(
|
||||||
|
oauth2_session_select
|
||||||
|
.union(UnionType::All, compat_session_select)
|
||||||
|
.clone(),
|
||||||
|
)
|
||||||
|
.table_name(Alias::new("sessions"))
|
||||||
|
.clone();
|
||||||
|
|
||||||
|
let with_clause = Query::with().cte(common_table_expression).clone();
|
||||||
|
|
||||||
|
let select = Query::select()
|
||||||
|
.column(ColumnRef::Asterisk)
|
||||||
|
.from(Alias::new("sessions"))
|
||||||
|
.generate_pagination(AppSessionLookupIden::Cursor, pagination)
|
||||||
|
.clone();
|
||||||
|
|
||||||
|
let (sql, arguments) = with_clause.query(select).build_sqlx(PostgresQueryBuilder);
|
||||||
|
|
||||||
|
let edges: Vec<AppSessionLookup> = sqlx::query_as_with(&sql, arguments)
|
||||||
|
.traced()
|
||||||
|
.fetch_all(&mut *self.conn)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let page = pagination.process(edges).try_map(TryFrom::try_from)?;
|
||||||
|
|
||||||
|
Ok(page)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn count(&mut self, filter: AppSessionFilter<'_>) -> Result<usize, Self::Error> {
|
||||||
|
let mut oauth2_session_select = Query::select()
|
||||||
|
.expr(Expr::cust("1"))
|
||||||
|
.from(OAuth2Sessions::Table)
|
||||||
|
.and_where_option(filter.user().map(|user| {
|
||||||
|
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id))
|
||||||
|
}))
|
||||||
|
.and_where_option(filter.state().map(|state| {
|
||||||
|
if state.is_active() {
|
||||||
|
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
|
||||||
|
} else {
|
||||||
|
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
.clone();
|
||||||
|
|
||||||
|
let compat_session_select = Query::select()
|
||||||
|
.expr(Expr::cust("1"))
|
||||||
|
.from(CompatSessions::Table)
|
||||||
|
.and_where_option(filter.user().map(|user| {
|
||||||
|
Expr::col((CompatSessions::Table, CompatSessions::UserId)).eq(Uuid::from(user.id))
|
||||||
|
}))
|
||||||
|
.and_where_option(filter.state().map(|state| {
|
||||||
|
if state.is_active() {
|
||||||
|
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null()
|
||||||
|
} else {
|
||||||
|
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_not_null()
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
.clone();
|
||||||
|
|
||||||
|
let common_table_expression = CommonTableExpression::new()
|
||||||
|
.query(
|
||||||
|
oauth2_session_select
|
||||||
|
.union(UnionType::All, compat_session_select)
|
||||||
|
.clone(),
|
||||||
|
)
|
||||||
|
.table_name(Alias::new("sessions"))
|
||||||
|
.clone();
|
||||||
|
|
||||||
|
let with_clause = Query::with().cte(common_table_expression).clone();
|
||||||
|
|
||||||
|
let select = Query::select()
|
||||||
|
.expr(Expr::cust("COUNT(*)"))
|
||||||
|
.from(Alias::new("sessions"))
|
||||||
|
.clone();
|
||||||
|
|
||||||
|
let (sql, arguments) = with_clause.query(select).build_sqlx(PostgresQueryBuilder);
|
||||||
|
|
||||||
|
let count: i64 = sqlx::query_scalar_with(&sql, arguments)
|
||||||
|
.traced()
|
||||||
|
.fetch_one(&mut *self.conn)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
count
|
||||||
|
.try_into()
|
||||||
|
.map_err(DatabaseError::to_invalid_operation)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use chrono::Duration;
|
||||||
|
use mas_data_model::Device;
|
||||||
|
use mas_storage::{
|
||||||
|
app_session::{AppSession, AppSessionFilter},
|
||||||
|
clock::MockClock,
|
||||||
|
oauth2::OAuth2SessionRepository,
|
||||||
|
Pagination, RepositoryAccess,
|
||||||
|
};
|
||||||
|
use oauth2_types::{
|
||||||
|
requests::GrantType,
|
||||||
|
scope::{Scope, OPENID},
|
||||||
|
};
|
||||||
|
use rand::SeedableRng;
|
||||||
|
use rand_chacha::ChaChaRng;
|
||||||
|
use sqlx::PgPool;
|
||||||
|
|
||||||
|
use crate::PgRepository;
|
||||||
|
|
||||||
|
#[sqlx::test(migrator = "crate::MIGRATOR")]
|
||||||
|
async fn test_app_repo(pool: PgPool) {
|
||||||
|
let mut rng = ChaChaRng::seed_from_u64(42);
|
||||||
|
let clock = MockClock::default();
|
||||||
|
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
|
||||||
|
|
||||||
|
// Create a user
|
||||||
|
let user = repo
|
||||||
|
.user()
|
||||||
|
.add(&mut rng, &clock, "john".to_owned())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let all = AppSessionFilter::new().for_user(&user);
|
||||||
|
let active = all.active_only();
|
||||||
|
let finished = all.finished_only();
|
||||||
|
let pagination = Pagination::first(10);
|
||||||
|
|
||||||
|
assert_eq!(repo.app_session().count(all).await.unwrap(), 0);
|
||||||
|
assert_eq!(repo.app_session().count(active).await.unwrap(), 0);
|
||||||
|
assert_eq!(repo.app_session().count(finished).await.unwrap(), 0);
|
||||||
|
|
||||||
|
let full_list = repo.app_session().list(all, pagination).await.unwrap();
|
||||||
|
assert!(full_list.edges.is_empty());
|
||||||
|
let active_list = repo.app_session().list(active, pagination).await.unwrap();
|
||||||
|
assert!(active_list.edges.is_empty());
|
||||||
|
let finished_list = repo.app_session().list(finished, pagination).await.unwrap();
|
||||||
|
assert!(finished_list.edges.is_empty());
|
||||||
|
|
||||||
|
// Start a compat session for that user
|
||||||
|
let device = Device::generate(&mut rng);
|
||||||
|
let compat_session = repo
|
||||||
|
.compat_session()
|
||||||
|
.add(&mut rng, &clock, &user, device, false)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(repo.app_session().count(all).await.unwrap(), 1);
|
||||||
|
assert_eq!(repo.app_session().count(active).await.unwrap(), 1);
|
||||||
|
assert_eq!(repo.app_session().count(finished).await.unwrap(), 0);
|
||||||
|
|
||||||
|
let full_list = repo.app_session().list(all, pagination).await.unwrap();
|
||||||
|
assert_eq!(full_list.edges.len(), 1);
|
||||||
|
assert_eq!(
|
||||||
|
full_list.edges[0],
|
||||||
|
AppSession::Compat(Box::new(compat_session.clone()))
|
||||||
|
);
|
||||||
|
let active_list = repo.app_session().list(active, pagination).await.unwrap();
|
||||||
|
assert_eq!(active_list.edges.len(), 1);
|
||||||
|
assert_eq!(
|
||||||
|
active_list.edges[0],
|
||||||
|
AppSession::Compat(Box::new(compat_session.clone()))
|
||||||
|
);
|
||||||
|
let finished_list = repo.app_session().list(finished, pagination).await.unwrap();
|
||||||
|
assert!(finished_list.edges.is_empty());
|
||||||
|
|
||||||
|
// Finish the session
|
||||||
|
let compat_session = repo
|
||||||
|
.compat_session()
|
||||||
|
.finish(&clock, compat_session)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(repo.app_session().count(all).await.unwrap(), 1);
|
||||||
|
assert_eq!(repo.app_session().count(active).await.unwrap(), 0);
|
||||||
|
assert_eq!(repo.app_session().count(finished).await.unwrap(), 1);
|
||||||
|
|
||||||
|
let full_list = repo.app_session().list(all, pagination).await.unwrap();
|
||||||
|
assert_eq!(full_list.edges.len(), 1);
|
||||||
|
assert_eq!(
|
||||||
|
full_list.edges[0],
|
||||||
|
AppSession::Compat(Box::new(compat_session.clone()))
|
||||||
|
);
|
||||||
|
let active_list = repo.app_session().list(active, pagination).await.unwrap();
|
||||||
|
assert!(active_list.edges.is_empty());
|
||||||
|
let finished_list = repo.app_session().list(finished, pagination).await.unwrap();
|
||||||
|
assert_eq!(finished_list.edges.len(), 1);
|
||||||
|
assert_eq!(
|
||||||
|
finished_list.edges[0],
|
||||||
|
AppSession::Compat(Box::new(compat_session.clone()))
|
||||||
|
);
|
||||||
|
|
||||||
|
// Start an OAuth2 session
|
||||||
|
let client = repo
|
||||||
|
.oauth2_client()
|
||||||
|
.add(
|
||||||
|
&mut rng,
|
||||||
|
&clock,
|
||||||
|
vec!["https://example.com/redirect".parse().unwrap()],
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
vec![GrantType::AuthorizationCode],
|
||||||
|
Vec::new(), // TODO: contacts are not yet saved
|
||||||
|
// vec!["contact@example.com".to_owned()],
|
||||||
|
Some("First client".to_owned()),
|
||||||
|
Some("https://example.com/logo.png".parse().unwrap()),
|
||||||
|
Some("https://example.com/".parse().unwrap()),
|
||||||
|
Some("https://example.com/policy".parse().unwrap()),
|
||||||
|
Some("https://example.com/tos".parse().unwrap()),
|
||||||
|
Some("https://example.com/jwks.json".parse().unwrap()),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
Some("https://example.com/login".parse().unwrap()),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let scope = Scope::from_iter([OPENID]);
|
||||||
|
|
||||||
|
// We're moving the clock forward by 1 minute between each session to ensure
|
||||||
|
// we're getting consistent ordering in lists.
|
||||||
|
clock.advance(Duration::minutes(1));
|
||||||
|
|
||||||
|
let oauth_session = repo
|
||||||
|
.oauth2_session()
|
||||||
|
.add(&mut rng, &clock, &client, Some(&user), None, scope)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(repo.app_session().count(all).await.unwrap(), 2);
|
||||||
|
assert_eq!(repo.app_session().count(active).await.unwrap(), 1);
|
||||||
|
assert_eq!(repo.app_session().count(finished).await.unwrap(), 1);
|
||||||
|
|
||||||
|
let full_list = repo.app_session().list(all, pagination).await.unwrap();
|
||||||
|
assert_eq!(full_list.edges.len(), 2);
|
||||||
|
assert_eq!(
|
||||||
|
full_list.edges[0],
|
||||||
|
AppSession::Compat(Box::new(compat_session.clone()))
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
full_list.edges[1],
|
||||||
|
AppSession::OAuth2(Box::new(oauth_session.clone()))
|
||||||
|
);
|
||||||
|
|
||||||
|
let active_list = repo.app_session().list(active, pagination).await.unwrap();
|
||||||
|
assert_eq!(active_list.edges.len(), 1);
|
||||||
|
assert_eq!(
|
||||||
|
active_list.edges[0],
|
||||||
|
AppSession::OAuth2(Box::new(oauth_session.clone()))
|
||||||
|
);
|
||||||
|
|
||||||
|
let finished_list = repo.app_session().list(finished, pagination).await.unwrap();
|
||||||
|
assert_eq!(finished_list.edges.len(), 1);
|
||||||
|
assert_eq!(
|
||||||
|
finished_list.edges[0],
|
||||||
|
AppSession::Compat(Box::new(compat_session.clone()))
|
||||||
|
);
|
||||||
|
|
||||||
|
// Finish the session
|
||||||
|
let oauth_session = repo
|
||||||
|
.oauth2_session()
|
||||||
|
.finish(&clock, oauth_session)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(repo.app_session().count(all).await.unwrap(), 2);
|
||||||
|
assert_eq!(repo.app_session().count(active).await.unwrap(), 0);
|
||||||
|
assert_eq!(repo.app_session().count(finished).await.unwrap(), 2);
|
||||||
|
|
||||||
|
let full_list = repo.app_session().list(all, pagination).await.unwrap();
|
||||||
|
assert_eq!(full_list.edges.len(), 2);
|
||||||
|
assert_eq!(
|
||||||
|
full_list.edges[0],
|
||||||
|
AppSession::Compat(Box::new(compat_session.clone()))
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
full_list.edges[1],
|
||||||
|
AppSession::OAuth2(Box::new(oauth_session.clone()))
|
||||||
|
);
|
||||||
|
|
||||||
|
let active_list = repo.app_session().list(active, pagination).await.unwrap();
|
||||||
|
assert!(active_list.edges.is_empty());
|
||||||
|
|
||||||
|
let finished_list = repo.app_session().list(finished, pagination).await.unwrap();
|
||||||
|
assert_eq!(finished_list.edges.len(), 2);
|
||||||
|
assert_eq!(
|
||||||
|
finished_list.edges[0],
|
||||||
|
AppSession::Compat(Box::new(compat_session.clone()))
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
full_list.edges[1],
|
||||||
|
AppSession::OAuth2(Box::new(oauth_session.clone()))
|
||||||
|
);
|
||||||
|
|
||||||
|
// Create a second user
|
||||||
|
let user2 = repo
|
||||||
|
.user()
|
||||||
|
.add(&mut rng, &clock, "alice".to_owned())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// If we list/count for this user, we should get nothing
|
||||||
|
let filter = AppSessionFilter::new().for_user(&user2);
|
||||||
|
assert_eq!(repo.app_session().count(filter).await.unwrap(), 0);
|
||||||
|
let list = repo.app_session().list(filter, pagination).await.unwrap();
|
||||||
|
assert!(list.edges.is_empty());
|
||||||
|
}
|
||||||
|
}
|
@@ -177,6 +177,7 @@
|
|||||||
|
|
||||||
use sqlx::migrate::Migrator;
|
use sqlx::migrate::Migrator;
|
||||||
|
|
||||||
|
pub mod app_session;
|
||||||
pub mod compat;
|
pub mod compat;
|
||||||
pub mod job;
|
pub mod job;
|
||||||
pub mod oauth2;
|
pub mod oauth2;
|
||||||
|
@@ -16,6 +16,7 @@ use std::ops::{Deref, DerefMut};
|
|||||||
|
|
||||||
use futures_util::{future::BoxFuture, FutureExt, TryFutureExt};
|
use futures_util::{future::BoxFuture, FutureExt, TryFutureExt};
|
||||||
use mas_storage::{
|
use mas_storage::{
|
||||||
|
app_session::AppSessionRepository,
|
||||||
compat::{
|
compat::{
|
||||||
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
|
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
|
||||||
CompatSsoLoginRepository,
|
CompatSsoLoginRepository,
|
||||||
@@ -36,6 +37,7 @@ use sqlx::{PgConnection, PgPool, Postgres, Transaction};
|
|||||||
use tracing::Instrument;
|
use tracing::Instrument;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
app_session::PgAppSessionRepository,
|
||||||
compat::{
|
compat::{
|
||||||
PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository, PgCompatSessionRepository,
|
PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository, PgCompatSessionRepository,
|
||||||
PgCompatSsoLoginRepository,
|
PgCompatSsoLoginRepository,
|
||||||
@@ -182,6 +184,10 @@ where
|
|||||||
Box::new(PgBrowserSessionRepository::new(self.conn.as_mut()))
|
Box::new(PgBrowserSessionRepository::new(self.conn.as_mut()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn app_session<'c>(&'c mut self) -> Box<dyn AppSessionRepository<Error = Self::Error> + 'c> {
|
||||||
|
Box::new(PgAppSessionRepository::new(self.conn.as_mut()))
|
||||||
|
}
|
||||||
|
|
||||||
fn oauth2_client<'c>(
|
fn oauth2_client<'c>(
|
||||||
&'c mut self,
|
&'c mut self,
|
||||||
) -> Box<dyn OAuth2ClientRepository<Error = Self::Error> + 'c> {
|
) -> Box<dyn OAuth2ClientRepository<Error = Self::Error> + 'c> {
|
||||||
|
148
crates/storage/src/app_session.rs
Normal file
148
crates/storage/src/app_session.rs
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
//! Repositories to interact with all kinds of sessions
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use mas_data_model::{CompatSession, Session, User};
|
||||||
|
|
||||||
|
use crate::{repository_impl, Page, Pagination};
|
||||||
|
|
||||||
|
/// The state of a session
|
||||||
|
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||||
|
pub enum AppSessionState {
|
||||||
|
/// The session is active
|
||||||
|
Active,
|
||||||
|
/// The session is finished
|
||||||
|
Finished,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AppSessionState {
|
||||||
|
/// Returns [`true`] if we're looking for active sessions
|
||||||
|
#[must_use]
|
||||||
|
pub fn is_active(self) -> bool {
|
||||||
|
matches!(self, Self::Active)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns [`true`] if we're looking for finished sessions
|
||||||
|
#[must_use]
|
||||||
|
pub fn is_finished(self) -> bool {
|
||||||
|
matches!(self, Self::Finished)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An [`AppSession`] is either a [`CompatSession`] or an OAuth 2.0 [`Session`]
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub enum AppSession {
|
||||||
|
/// A compatibility layer session
|
||||||
|
Compat(Box<CompatSession>),
|
||||||
|
|
||||||
|
/// An OAuth 2.0 session
|
||||||
|
OAuth2(Box<Session>),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Filtering parameters for application sessions
|
||||||
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
|
||||||
|
pub struct AppSessionFilter<'a> {
|
||||||
|
user: Option<&'a User>,
|
||||||
|
state: Option<AppSessionState>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> AppSessionFilter<'a> {
|
||||||
|
/// Create a new [`AppSessionFilter`] with default values
|
||||||
|
#[must_use]
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the user who owns the compatibility sessions
|
||||||
|
#[must_use]
|
||||||
|
pub fn for_user(mut self, user: &'a User) -> Self {
|
||||||
|
self.user = Some(user);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the user filter
|
||||||
|
#[must_use]
|
||||||
|
pub fn user(&self) -> Option<&User> {
|
||||||
|
self.user
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Only return active compatibility sessions
|
||||||
|
#[must_use]
|
||||||
|
pub fn active_only(mut self) -> Self {
|
||||||
|
self.state = Some(AppSessionState::Active);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Only return finished compatibility sessions
|
||||||
|
#[must_use]
|
||||||
|
pub fn finished_only(mut self) -> Self {
|
||||||
|
self.state = Some(AppSessionState::Finished);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the state filter
|
||||||
|
#[must_use]
|
||||||
|
pub fn state(&self) -> Option<AppSessionState> {
|
||||||
|
self.state
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A [`AppSessionRepository`] helps interacting with both [`CompatSession`] and
|
||||||
|
/// OAuth 2.0 [`Session`] at the same time saved in the storage backend
|
||||||
|
#[async_trait]
|
||||||
|
pub trait AppSessionRepository: Send + Sync {
|
||||||
|
/// The error type returned by the repository
|
||||||
|
type Error;
|
||||||
|
|
||||||
|
/// List [`AppSession`] with the given filter and pagination
|
||||||
|
///
|
||||||
|
/// Returns a page of [`AppSession`] matching the given filter
|
||||||
|
///
|
||||||
|
/// # Parameters
|
||||||
|
///
|
||||||
|
/// * `filter`: The filter to apply
|
||||||
|
/// * `pagination`: The pagination parameters
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns [`Self::Error`] if the underlying repository fails
|
||||||
|
async fn list(
|
||||||
|
&mut self,
|
||||||
|
filter: AppSessionFilter<'_>,
|
||||||
|
pagination: Pagination,
|
||||||
|
) -> Result<Page<AppSession>, Self::Error>;
|
||||||
|
|
||||||
|
/// Count the number of [`AppSession`] with the given filter
|
||||||
|
///
|
||||||
|
/// # Parameters
|
||||||
|
///
|
||||||
|
/// * `filter`: The filter to apply
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns [`Self::Error`] if the underlying repository fails
|
||||||
|
async fn count(&mut self, filter: AppSessionFilter<'_>) -> Result<usize, Self::Error>;
|
||||||
|
}
|
||||||
|
|
||||||
|
repository_impl!(AppSessionRepository:
|
||||||
|
async fn list(
|
||||||
|
&mut self,
|
||||||
|
filter: AppSessionFilter<'_>,
|
||||||
|
pagination: Pagination,
|
||||||
|
) -> Result<Page<AppSession>, Self::Error>;
|
||||||
|
|
||||||
|
async fn count(&mut self, filter: AppSessionFilter<'_>) -> Result<usize, Self::Error>;
|
||||||
|
);
|
@@ -132,7 +132,7 @@ impl<'a> CompatSessionFilter<'a> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// A [`CompatSessionRepository`] helps interacting with
|
/// A [`CompatSessionRepository`] helps interacting with
|
||||||
/// [`CompatSessionRepository`] saved in the storage backend
|
/// [`CompatSession`] saved in the storage backend
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait CompatSessionRepository: Send + Sync {
|
pub trait CompatSessionRepository: Send + Sync {
|
||||||
/// The error type returned by the repository
|
/// The error type returned by the repository
|
||||||
|
@@ -149,6 +149,7 @@ pub mod pagination;
|
|||||||
pub(crate) mod repository;
|
pub(crate) mod repository;
|
||||||
mod utils;
|
mod utils;
|
||||||
|
|
||||||
|
pub mod app_session;
|
||||||
pub mod compat;
|
pub mod compat;
|
||||||
pub mod job;
|
pub mod job;
|
||||||
pub mod oauth2;
|
pub mod oauth2;
|
||||||
|
@@ -16,6 +16,7 @@ use futures_util::future::BoxFuture;
|
|||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
app_session::AppSessionRepository,
|
||||||
compat::{
|
compat::{
|
||||||
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
|
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
|
||||||
CompatSsoLoginRepository,
|
CompatSsoLoginRepository,
|
||||||
@@ -150,6 +151,9 @@ pub trait RepositoryAccess: Send {
|
|||||||
&'c mut self,
|
&'c mut self,
|
||||||
) -> Box<dyn BrowserSessionRepository<Error = Self::Error> + 'c>;
|
) -> Box<dyn BrowserSessionRepository<Error = Self::Error> + 'c>;
|
||||||
|
|
||||||
|
/// Get a [`AppSessionRepository`]
|
||||||
|
fn app_session<'c>(&'c mut self) -> Box<dyn AppSessionRepository<Error = Self::Error> + 'c>;
|
||||||
|
|
||||||
/// Get an [`OAuth2ClientRepository`]
|
/// Get an [`OAuth2ClientRepository`]
|
||||||
fn oauth2_client<'c>(&'c mut self)
|
fn oauth2_client<'c>(&'c mut self)
|
||||||
-> Box<dyn OAuth2ClientRepository<Error = Self::Error> + 'c>;
|
-> Box<dyn OAuth2ClientRepository<Error = Self::Error> + 'c>;
|
||||||
@@ -205,6 +209,7 @@ mod impls {
|
|||||||
|
|
||||||
use super::RepositoryAccess;
|
use super::RepositoryAccess;
|
||||||
use crate::{
|
use crate::{
|
||||||
|
app_session::AppSessionRepository,
|
||||||
compat::{
|
compat::{
|
||||||
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
|
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
|
||||||
CompatSsoLoginRepository,
|
CompatSsoLoginRepository,
|
||||||
@@ -310,6 +315,12 @@ mod impls {
|
|||||||
Box::new(MapErr::new(self.inner.browser_session(), &mut self.mapper))
|
Box::new(MapErr::new(self.inner.browser_session(), &mut self.mapper))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn app_session<'c>(
|
||||||
|
&'c mut self,
|
||||||
|
) -> Box<dyn AppSessionRepository<Error = Self::Error> + 'c> {
|
||||||
|
Box::new(MapErr::new(self.inner.app_session(), &mut self.mapper))
|
||||||
|
}
|
||||||
|
|
||||||
fn oauth2_client<'c>(
|
fn oauth2_client<'c>(
|
||||||
&'c mut self,
|
&'c mut self,
|
||||||
) -> Box<dyn OAuth2ClientRepository<Error = Self::Error> + 'c> {
|
) -> Box<dyn OAuth2ClientRepository<Error = Self::Error> + 'c> {
|
||||||
@@ -425,6 +436,12 @@ mod impls {
|
|||||||
(**self).browser_session()
|
(**self).browser_session()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn app_session<'c>(
|
||||||
|
&'c mut self,
|
||||||
|
) -> Box<dyn AppSessionRepository<Error = Self::Error> + 'c> {
|
||||||
|
(**self).app_session()
|
||||||
|
}
|
||||||
|
|
||||||
fn oauth2_client<'c>(
|
fn oauth2_client<'c>(
|
||||||
&'c mut self,
|
&'c mut self,
|
||||||
) -> Box<dyn OAuth2ClientRepository<Error = Self::Error> + 'c> {
|
) -> Box<dyn OAuth2ClientRepository<Error = Self::Error> + 'c> {
|
||||||
|
Reference in New Issue
Block a user