You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
Better OAuth 2.0 sessions pagination and filtering
This commit is contained in:
@ -19,10 +19,6 @@ use mas_storage::{user::BrowserSessionRepository, RepositoryAccess};
|
||||
use super::{NodeType, User};
|
||||
use crate::state::ContextExt;
|
||||
|
||||
/// A browser session represents a logged in user in a browser.
|
||||
#[derive(Description)]
|
||||
pub struct BrowserSession(pub mas_data_model::BrowserSession);
|
||||
|
||||
/// The state of a browser session.
|
||||
#[derive(Enum, Copy, Clone, Eq, PartialEq)]
|
||||
pub enum BrowserSessionState {
|
||||
@ -33,6 +29,10 @@ pub enum BrowserSessionState {
|
||||
Finished,
|
||||
}
|
||||
|
||||
/// A browser session represents a logged in user in a browser.
|
||||
#[derive(Description)]
|
||||
pub struct BrowserSession(pub mas_data_model::BrowserSession);
|
||||
|
||||
impl From<mas_data_model::BrowserSession> for BrowserSession {
|
||||
fn from(v: mas_data_model::BrowserSession) -> Self {
|
||||
Self(v)
|
||||
|
@ -13,7 +13,7 @@
|
||||
// limitations under the License.
|
||||
|
||||
use anyhow::Context as _;
|
||||
use async_graphql::{Context, Description, Object, ID};
|
||||
use async_graphql::{Context, Description, Enum, Object, ID};
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::SessionState;
|
||||
use mas_storage::{oauth2::OAuth2ClientRepository, user::BrowserSessionRepository};
|
||||
@ -24,6 +24,16 @@ use url::Url;
|
||||
use super::{BrowserSession, NodeType, User};
|
||||
use crate::state::ContextExt;
|
||||
|
||||
/// The state of an OAuth 2.0 session.
|
||||
#[derive(Enum, Copy, Clone, Eq, PartialEq)]
|
||||
pub enum OAuth2SessionState {
|
||||
/// The session is active.
|
||||
Active,
|
||||
|
||||
/// The session is no longer active.
|
||||
Finished,
|
||||
}
|
||||
|
||||
/// An OAuth 2.0 session represents a client session which used the OAuth APIs
|
||||
/// to login.
|
||||
#[derive(Description)]
|
||||
|
@ -19,7 +19,7 @@ use async_graphql::{
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_storage::{
|
||||
compat::{CompatSessionFilter, CompatSsoLoginFilter, CompatSsoLoginRepository},
|
||||
oauth2::OAuth2SessionRepository,
|
||||
oauth2::{OAuth2SessionFilter, OAuth2SessionRepository},
|
||||
upstream_oauth2::UpstreamOAuthLinkRepository,
|
||||
user::{BrowserSessionFilter, BrowserSessionRepository, UserEmailFilter, UserEmailRepository},
|
||||
Pagination, RepositoryAccess,
|
||||
@ -34,6 +34,7 @@ use crate::{
|
||||
browser_sessions::BrowserSessionState,
|
||||
compat_sessions::{CompatSessionState, CompatSessionType},
|
||||
matrix::MatrixUser,
|
||||
oauth::OAuth2SessionState,
|
||||
CompatSession,
|
||||
},
|
||||
state::ContextExt,
|
||||
@ -365,17 +366,23 @@ impl User {
|
||||
}
|
||||
|
||||
/// Get the list of OAuth 2.0 sessions, chronologically sorted
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn oauth2_sessions(
|
||||
&self,
|
||||
ctx: &Context<'_>,
|
||||
|
||||
#[graphql(name = "state", desc = "List only sessions in the given state.")]
|
||||
state_param: Option<OAuth2SessionState>,
|
||||
|
||||
#[graphql(desc = "List only sessions for the given client.")] client: Option<ID>,
|
||||
|
||||
#[graphql(desc = "Returns the elements in the list that come after the cursor.")]
|
||||
after: Option<String>,
|
||||
#[graphql(desc = "Returns the elements in the list that come before the cursor.")]
|
||||
before: Option<String>,
|
||||
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
|
||||
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
|
||||
) -> Result<Connection<Cursor, OAuth2Session>, async_graphql::Error> {
|
||||
) -> Result<Connection<Cursor, OAuth2Session, PreloadedTotalCount>, async_graphql::Error> {
|
||||
let state = ctx.state();
|
||||
let mut repo = state.repository().await?;
|
||||
|
||||
@ -393,14 +400,49 @@ impl User {
|
||||
.transpose()?;
|
||||
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
|
||||
|
||||
let page = repo
|
||||
.oauth2_session()
|
||||
.list_paginated(&self.0, pagination)
|
||||
.await?;
|
||||
let client = if let Some(id) = client {
|
||||
// Load the client if we're filtering by it
|
||||
let id = NodeType::OAuth2Client.extract_ulid(&id)?;
|
||||
let client = repo
|
||||
.oauth2_client()
|
||||
.lookup(id)
|
||||
.await?
|
||||
.ok_or(async_graphql::Error::new("Unknown client ID"))?;
|
||||
|
||||
Some(client)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let filter = OAuth2SessionFilter::new().for_user(&self.0);
|
||||
|
||||
let filter = match state_param {
|
||||
Some(OAuth2SessionState::Active) => filter.active_only(),
|
||||
Some(OAuth2SessionState::Finished) => filter.finished_only(),
|
||||
None => filter,
|
||||
};
|
||||
|
||||
let filter = match client.as_ref() {
|
||||
Some(client) => filter.for_client(client),
|
||||
None => filter,
|
||||
};
|
||||
|
||||
let page = repo.oauth2_session().list(filter, pagination).await?;
|
||||
|
||||
let count = if ctx.look_ahead().field("totalCount").exists() {
|
||||
Some(repo.oauth2_session().count(filter).await?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
repo.cancel().await?;
|
||||
|
||||
let mut connection = Connection::new(page.has_previous_page, page.has_next_page);
|
||||
let mut connection = Connection::with_additional_fields(
|
||||
page.has_previous_page,
|
||||
page.has_next_page,
|
||||
PreloadedTotalCount(count),
|
||||
);
|
||||
|
||||
connection.edges.extend(page.edges.into_iter().map(|s| {
|
||||
Edge::new(
|
||||
OpaqueCursor(NodeCursor(NodeType::OAuth2Session, s.id)),
|
||||
|
@ -63,3 +63,17 @@ pub enum CompatSsoLogins {
|
||||
FulfilledAt,
|
||||
ExchangedAt,
|
||||
}
|
||||
|
||||
#[derive(sea_query::Iden)]
|
||||
#[iden = "oauth2_sessions"]
|
||||
pub enum OAuth2Sessions {
|
||||
Table,
|
||||
#[iden = "oauth2_session_id"]
|
||||
OAuth2SessionId,
|
||||
UserSessionId,
|
||||
#[iden = "oauth2_client_id"]
|
||||
OAuth2ClientId,
|
||||
Scope,
|
||||
CreatedAt,
|
||||
FinishedAt,
|
||||
}
|
||||
|
@ -31,7 +31,11 @@ pub use self::{
|
||||
mod tests {
|
||||
use chrono::Duration;
|
||||
use mas_data_model::AuthorizationCode;
|
||||
use mas_storage::{clock::MockClock, Clock, Pagination, Repository};
|
||||
use mas_storage::{
|
||||
clock::MockClock,
|
||||
oauth2::{OAuth2SessionFilter, OAuth2SessionRepository},
|
||||
Clock, Pagination, Repository,
|
||||
};
|
||||
use oauth2_types::{
|
||||
requests::{GrantType, ResponseMode},
|
||||
scope::{Scope, OPENID},
|
||||
@ -364,14 +368,279 @@ mod tests {
|
||||
assert!(session.is_valid());
|
||||
let session = repo.oauth2_session().finish(&clock, session).await.unwrap();
|
||||
assert!(!session.is_valid());
|
||||
}
|
||||
|
||||
// The session should appear in the paginated list of sessions for the user
|
||||
let sessions = repo
|
||||
.oauth2_session()
|
||||
.list_paginated(&user, Pagination::first(10))
|
||||
/// Test the [`OAuth2SessionRepository::list`] and
|
||||
/// [`OAuth2SessionRepository::count`] methods.
|
||||
#[sqlx::test(migrator = "crate::MIGRATOR")]
|
||||
async fn test_list_sessions(pool: PgPool) {
|
||||
let mut rng = ChaChaRng::seed_from_u64(42);
|
||||
let clock = MockClock::default();
|
||||
let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
|
||||
|
||||
// Create two users and their corresponding browser sessions
|
||||
let user1 = repo
|
||||
.user()
|
||||
.add(&mut rng, &clock, "alice".to_owned())
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!sessions.has_next_page);
|
||||
assert_eq!(sessions.edges, vec![session]);
|
||||
let user1_session = repo
|
||||
.browser_session()
|
||||
.add(&mut rng, &clock, &user1)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let user2 = repo
|
||||
.user()
|
||||
.add(&mut rng, &clock, "bob".to_owned())
|
||||
.await
|
||||
.unwrap();
|
||||
let user2_session = repo
|
||||
.browser_session()
|
||||
.add(&mut rng, &clock, &user2)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Create two clients
|
||||
let client1 = repo
|
||||
.oauth2_client()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
vec!["https://first.example.com/redirect".parse().unwrap()],
|
||||
None,
|
||||
vec![GrantType::AuthorizationCode],
|
||||
Vec::new(), // TODO: contacts are not yet saved
|
||||
// vec!["contact@first.example.com".to_owned()],
|
||||
Some("First client".to_owned()),
|
||||
Some("https://first.example.com/logo.png".parse().unwrap()),
|
||||
Some("https://first.example.com/".parse().unwrap()),
|
||||
Some("https://first.example.com/policy".parse().unwrap()),
|
||||
Some("https://first.example.com/tos".parse().unwrap()),
|
||||
Some("https://first.example.com/jwks.json".parse().unwrap()),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
Some("https://first.example.com/login".parse().unwrap()),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let client2 = repo
|
||||
.oauth2_client()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
vec!["https://second.example.com/redirect".parse().unwrap()],
|
||||
None,
|
||||
vec![GrantType::AuthorizationCode],
|
||||
Vec::new(), // TODO: contacts are not yet saved
|
||||
// vec!["contact@second.example.com".to_owned()],
|
||||
Some("Second client".to_owned()),
|
||||
Some("https://second.example.com/logo.png".parse().unwrap()),
|
||||
Some("https://second.example.com/".parse().unwrap()),
|
||||
Some("https://second.example.com/policy".parse().unwrap()),
|
||||
Some("https://second.example.com/tos".parse().unwrap()),
|
||||
Some("https://second.example.com/jwks.json".parse().unwrap()),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
Some("https://second.example.com/login".parse().unwrap()),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let scope = Scope::from_iter([OPENID]);
|
||||
|
||||
// Create two sessions for each user, one with each client
|
||||
// We're moving the clock forward by 1 minute between each session to ensure
|
||||
// we're getting consistent ordering in lists.
|
||||
let session11 = repo
|
||||
.oauth2_session()
|
||||
.add(&mut rng, &clock, &client1, &user1_session, scope.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
clock.advance(Duration::minutes(1));
|
||||
|
||||
let session12 = repo
|
||||
.oauth2_session()
|
||||
.add(&mut rng, &clock, &client1, &user2_session, scope.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
clock.advance(Duration::minutes(1));
|
||||
|
||||
let session21 = repo
|
||||
.oauth2_session()
|
||||
.add(&mut rng, &clock, &client2, &user1_session, scope.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
clock.advance(Duration::minutes(1));
|
||||
|
||||
let session22 = repo
|
||||
.oauth2_session()
|
||||
.add(&mut rng, &clock, &client2, &user2_session, scope.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
clock.advance(Duration::minutes(1));
|
||||
|
||||
// We're also finishing two of the sessions
|
||||
let session11 = repo
|
||||
.oauth2_session()
|
||||
.finish(&clock, session11)
|
||||
.await
|
||||
.unwrap();
|
||||
let session22 = repo
|
||||
.oauth2_session()
|
||||
.finish(&clock, session22)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let pagination = Pagination::first(10);
|
||||
|
||||
// First, list all the sessions
|
||||
let filter = OAuth2SessionFilter::new();
|
||||
let list = repo
|
||||
.oauth2_session()
|
||||
.list(filter, pagination)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!list.has_next_page);
|
||||
assert_eq!(list.edges.len(), 4);
|
||||
assert_eq!(list.edges[0], session11);
|
||||
assert_eq!(list.edges[1], session12);
|
||||
assert_eq!(list.edges[2], session21);
|
||||
assert_eq!(list.edges[3], session22);
|
||||
|
||||
assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 4);
|
||||
|
||||
// Now filter for only one user
|
||||
let filter = OAuth2SessionFilter::new().for_user(&user1);
|
||||
let list = repo
|
||||
.oauth2_session()
|
||||
.list(filter, pagination)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!list.has_next_page);
|
||||
assert_eq!(list.edges.len(), 2);
|
||||
assert_eq!(list.edges[0], session11);
|
||||
assert_eq!(list.edges[1], session21);
|
||||
|
||||
assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 2);
|
||||
|
||||
// Filter for only one client
|
||||
let filter = OAuth2SessionFilter::new().for_client(&client1);
|
||||
let list = repo
|
||||
.oauth2_session()
|
||||
.list(filter, pagination)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!list.has_next_page);
|
||||
assert_eq!(list.edges.len(), 2);
|
||||
assert_eq!(list.edges[0], session11);
|
||||
assert_eq!(list.edges[1], session12);
|
||||
|
||||
assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 2);
|
||||
|
||||
// Filter for both a user and a client
|
||||
let filter = OAuth2SessionFilter::new()
|
||||
.for_user(&user2)
|
||||
.for_client(&client2);
|
||||
let list = repo
|
||||
.oauth2_session()
|
||||
.list(filter, pagination)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!list.has_next_page);
|
||||
assert_eq!(list.edges.len(), 1);
|
||||
assert_eq!(list.edges[0], session22);
|
||||
|
||||
assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1);
|
||||
|
||||
// Filter for active sessions
|
||||
let filter = OAuth2SessionFilter::new().active_only();
|
||||
let list = repo
|
||||
.oauth2_session()
|
||||
.list(filter, pagination)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!list.has_next_page);
|
||||
assert_eq!(list.edges.len(), 2);
|
||||
assert_eq!(list.edges[0], session12);
|
||||
assert_eq!(list.edges[1], session21);
|
||||
|
||||
assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 2);
|
||||
|
||||
// Filter for finished sessions
|
||||
let filter = OAuth2SessionFilter::new().finished_only();
|
||||
let list = repo
|
||||
.oauth2_session()
|
||||
.list(filter, pagination)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!list.has_next_page);
|
||||
assert_eq!(list.edges.len(), 2);
|
||||
assert_eq!(list.edges[0], session11);
|
||||
assert_eq!(list.edges[1], session22);
|
||||
|
||||
assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 2);
|
||||
|
||||
// Combine the finished filter with the user filter
|
||||
let filter = OAuth2SessionFilter::new().finished_only().for_user(&user2);
|
||||
let list = repo
|
||||
.oauth2_session()
|
||||
.list(filter, pagination)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!list.has_next_page);
|
||||
assert_eq!(list.edges.len(), 1);
|
||||
assert_eq!(list.edges[0], session22);
|
||||
|
||||
assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1);
|
||||
|
||||
// Combine the finished filter with the client filter
|
||||
let filter = OAuth2SessionFilter::new()
|
||||
.finished_only()
|
||||
.for_client(&client2);
|
||||
let list = repo
|
||||
.oauth2_session()
|
||||
.list(filter, pagination)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!list.has_next_page);
|
||||
assert_eq!(list.edges.len(), 1);
|
||||
assert_eq!(list.edges[0], session22);
|
||||
|
||||
assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1);
|
||||
|
||||
// Combine the active filter with the user filter
|
||||
let filter = OAuth2SessionFilter::new().active_only().for_user(&user2);
|
||||
let list = repo
|
||||
.oauth2_session()
|
||||
.list(filter, pagination)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!list.has_next_page);
|
||||
assert_eq!(list.edges.len(), 1);
|
||||
assert_eq!(list.edges[0], session12);
|
||||
|
||||
assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1);
|
||||
|
||||
// Combine the active filter with the client filter
|
||||
let filter = OAuth2SessionFilter::new()
|
||||
.active_only()
|
||||
.for_client(&client2);
|
||||
let list = repo
|
||||
.oauth2_session()
|
||||
.list(filter, pagination)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!list.has_next_page);
|
||||
assert_eq!(list.edges.len(), 1);
|
||||
assert_eq!(list.edges[0], session21);
|
||||
|
||||
assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1);
|
||||
}
|
||||
}
|
||||
|
@ -14,16 +14,24 @@
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{BrowserSession, Client, Session, SessionState, User};
|
||||
use mas_storage::{oauth2::OAuth2SessionRepository, Clock, Page, Pagination};
|
||||
use mas_data_model::{BrowserSession, Client, Session, SessionState};
|
||||
use mas_storage::{
|
||||
oauth2::{OAuth2SessionFilter, OAuth2SessionRepository},
|
||||
Clock, Page, Pagination,
|
||||
};
|
||||
use oauth2_types::scope::Scope;
|
||||
use rand::RngCore;
|
||||
use sqlx::{PgConnection, QueryBuilder};
|
||||
use sea_query::{enum_def, Expr, IntoColumnRef, PostgresQueryBuilder, Query};
|
||||
use sqlx::PgConnection;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError,
|
||||
iden::{OAuth2Sessions, UserSessions},
|
||||
pagination::QueryBuilderExt,
|
||||
sea_query_sqlx::map_values,
|
||||
tracing::ExecuteExt,
|
||||
DatabaseError, DatabaseInconsistencyError,
|
||||
};
|
||||
|
||||
/// An implementation of [`OAuth2SessionRepository`] for a PostgreSQL connection
|
||||
@ -40,6 +48,7 @@ impl<'c> PgOAuth2SessionRepository<'c> {
|
||||
}
|
||||
|
||||
#[derive(sqlx::FromRow)]
|
||||
#[enum_def]
|
||||
struct OAuthSessionLookup {
|
||||
oauth2_session_id: Uuid,
|
||||
user_session_id: Uuid,
|
||||
@ -211,45 +220,143 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_session.list_paginated",
|
||||
name = "db.oauth2_session.list",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user.id,
|
||||
%user.username,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn list_paginated(
|
||||
async fn list(
|
||||
&mut self,
|
||||
user: &User,
|
||||
filter: OAuth2SessionFilter<'_>,
|
||||
pagination: Pagination,
|
||||
) -> Result<Page<Session>, Self::Error> {
|
||||
let mut query = QueryBuilder::new(
|
||||
r#"
|
||||
SELECT oauth2_session_id
|
||||
, user_session_id
|
||||
, oauth2_client_id
|
||||
, scope
|
||||
, os.created_at
|
||||
, os.finished_at
|
||||
FROM oauth2_sessions os
|
||||
INNER JOIN user_sessions USING (user_session_id)
|
||||
"#,
|
||||
);
|
||||
let (sql, values) = Query::select()
|
||||
.expr_as(
|
||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)),
|
||||
OAuthSessionLookupIden::Oauth2SessionId,
|
||||
)
|
||||
.expr_as(
|
||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)),
|
||||
OAuthSessionLookupIden::UserSessionId,
|
||||
)
|
||||
.expr_as(
|
||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId)),
|
||||
OAuthSessionLookupIden::Oauth2ClientId,
|
||||
)
|
||||
.expr_as(
|
||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::Scope)),
|
||||
OAuthSessionLookupIden::Scope,
|
||||
)
|
||||
.expr_as(
|
||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::CreatedAt)),
|
||||
OAuthSessionLookupIden::CreatedAt,
|
||||
)
|
||||
.expr_as(
|
||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)),
|
||||
OAuthSessionLookupIden::FinishedAt,
|
||||
)
|
||||
.from(OAuth2Sessions::Table)
|
||||
.and_where_option(filter.user().map(|user| {
|
||||
// Check for user ownership by querying the user_sessions table
|
||||
// The query plan is the same as if we were joining the tables instead
|
||||
Expr::exists(
|
||||
Query::select()
|
||||
.expr(Expr::cust("1"))
|
||||
.from(UserSessions::Table)
|
||||
.and_where(
|
||||
Expr::col((UserSessions::Table, UserSessions::UserId))
|
||||
.eq(Uuid::from(user.id)),
|
||||
)
|
||||
.and_where(
|
||||
Expr::col((UserSessions::Table, UserSessions::UserSessionId))
|
||||
.equals((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)),
|
||||
)
|
||||
.take(),
|
||||
)
|
||||
}))
|
||||
.and_where_option(filter.client().map(|client| {
|
||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
|
||||
.eq(Uuid::from(client.id))
|
||||
}))
|
||||
.and_where_option(filter.state().map(|state| {
|
||||
if state.is_active() {
|
||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
|
||||
} else {
|
||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
|
||||
}
|
||||
}))
|
||||
.generate_pagination(
|
||||
(OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId).into_column_ref(),
|
||||
pagination,
|
||||
)
|
||||
.build(PostgresQueryBuilder);
|
||||
|
||||
query
|
||||
.push(" WHERE user_id = ")
|
||||
.push_bind(Uuid::from(user.id))
|
||||
.generate_pagination("oauth2_session_id", pagination);
|
||||
let arguments = map_values(values);
|
||||
|
||||
let edges: Vec<OAuthSessionLookup> = query
|
||||
.build_query_as()
|
||||
let edges: Vec<OAuthSessionLookup> = sqlx::query_as_with(&sql, arguments)
|
||||
.traced()
|
||||
.fetch_all(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
let page = pagination.process(edges).try_map(Session::try_from)?;
|
||||
|
||||
Ok(page)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_session.count",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result<usize, Self::Error> {
|
||||
let (sql, values) = Query::select()
|
||||
.expr(Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)).count())
|
||||
.from(OAuth2Sessions::Table)
|
||||
.and_where_option(filter.user().map(|user| {
|
||||
// Check for user ownership by querying the user_sessions table
|
||||
// The query plan is the same as if we were joining the tables instead
|
||||
Expr::exists(
|
||||
Query::select()
|
||||
.expr(Expr::cust("1"))
|
||||
.from(UserSessions::Table)
|
||||
.and_where(
|
||||
Expr::col((UserSessions::Table, UserSessions::UserId))
|
||||
.eq(Uuid::from(user.id)),
|
||||
)
|
||||
.and_where(
|
||||
Expr::col((UserSessions::Table, UserSessions::UserSessionId))
|
||||
.equals((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)),
|
||||
)
|
||||
.take(),
|
||||
)
|
||||
}))
|
||||
.and_where_option(filter.client().map(|client| {
|
||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
|
||||
.eq(Uuid::from(client.id))
|
||||
}))
|
||||
.and_where_option(filter.state().map(|state| {
|
||||
if state.is_active() {
|
||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
|
||||
} else {
|
||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
|
||||
}
|
||||
}))
|
||||
.build(PostgresQueryBuilder);
|
||||
|
||||
let arguments = map_values(values);
|
||||
|
||||
let count: i64 = sqlx::query_scalar_with(&sql, arguments)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
count
|
||||
.try_into()
|
||||
.map_err(DatabaseError::to_invalid_operation)
|
||||
}
|
||||
}
|
||||
|
@ -22,6 +22,8 @@ mod session;
|
||||
|
||||
pub use self::{
|
||||
access_token::OAuth2AccessTokenRepository,
|
||||
authorization_grant::OAuth2AuthorizationGrantRepository, client::OAuth2ClientRepository,
|
||||
refresh_token::OAuth2RefreshTokenRepository, session::OAuth2SessionRepository,
|
||||
authorization_grant::OAuth2AuthorizationGrantRepository,
|
||||
client::OAuth2ClientRepository,
|
||||
refresh_token::OAuth2RefreshTokenRepository,
|
||||
session::{OAuth2SessionFilter, OAuth2SessionRepository},
|
||||
};
|
||||
|
@ -20,6 +20,90 @@ use ulid::Ulid;
|
||||
|
||||
use crate::{pagination::Page, repository_impl, Clock, Pagination};
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum OAuth2SessionState {
|
||||
Active,
|
||||
Finished,
|
||||
}
|
||||
|
||||
impl OAuth2SessionState {
|
||||
pub fn is_active(self) -> bool {
|
||||
matches!(self, Self::Active)
|
||||
}
|
||||
|
||||
pub fn is_finished(self) -> bool {
|
||||
matches!(self, Self::Finished)
|
||||
}
|
||||
}
|
||||
|
||||
/// Filter parameters for listing OAuth 2.0 sessions
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
|
||||
pub struct OAuth2SessionFilter<'a> {
|
||||
user: Option<&'a User>,
|
||||
client: Option<&'a Client>,
|
||||
state: Option<OAuth2SessionState>,
|
||||
}
|
||||
|
||||
impl<'a> OAuth2SessionFilter<'a> {
|
||||
/// Create a new [`OAuth2SessionFilter`] with default values
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// List sessions for a specific user
|
||||
#[must_use]
|
||||
pub fn for_user(mut self, user: &'a User) -> Self {
|
||||
self.user = Some(user);
|
||||
self
|
||||
}
|
||||
|
||||
/// Get the user filter
|
||||
///
|
||||
/// Returns [`None`] if no user filter was set
|
||||
#[must_use]
|
||||
pub fn user(&self) -> Option<&User> {
|
||||
self.user
|
||||
}
|
||||
|
||||
/// List sessions for a specific client
|
||||
#[must_use]
|
||||
pub fn for_client(mut self, client: &'a Client) -> Self {
|
||||
self.client = Some(client);
|
||||
self
|
||||
}
|
||||
|
||||
/// Get the client filter
|
||||
///
|
||||
/// Returns [`None`] if no client filter was set
|
||||
#[must_use]
|
||||
pub fn client(&self) -> Option<&Client> {
|
||||
self.client
|
||||
}
|
||||
|
||||
/// Only return active sessions
|
||||
#[must_use]
|
||||
pub fn active_only(mut self) -> Self {
|
||||
self.state = Some(OAuth2SessionState::Active);
|
||||
self
|
||||
}
|
||||
|
||||
/// Only return finished sessions
|
||||
#[must_use]
|
||||
pub fn finished_only(mut self) -> Self {
|
||||
self.state = Some(OAuth2SessionState::Finished);
|
||||
self
|
||||
}
|
||||
|
||||
/// Get the state filter
|
||||
///
|
||||
/// Returns [`None`] if no state filter was set
|
||||
#[must_use]
|
||||
pub fn state(&self) -> Option<OAuth2SessionState> {
|
||||
self.state
|
||||
}
|
||||
}
|
||||
|
||||
/// An [`OAuth2SessionRepository`] helps interacting with [`Session`]
|
||||
/// saved in the storage backend
|
||||
#[async_trait]
|
||||
@ -80,21 +164,32 @@ pub trait OAuth2SessionRepository: Send + Sync {
|
||||
async fn finish(&mut self, clock: &dyn Clock, session: Session)
|
||||
-> Result<Session, Self::Error>;
|
||||
|
||||
/// Get a paginated list of [`Session`]s for a [`User`]
|
||||
/// List [`Session`]s matching the given filter and pagination parameters
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// * `user`: The [`User`] to get the [`Session`]s for
|
||||
/// * `filter`: The filter parameters
|
||||
/// * `pagination`: The pagination parameters
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`Self::Error`] if the underlying repository fails
|
||||
async fn list_paginated(
|
||||
async fn list(
|
||||
&mut self,
|
||||
user: &User,
|
||||
filter: OAuth2SessionFilter<'_>,
|
||||
pagination: Pagination,
|
||||
) -> Result<Page<Session>, Self::Error>;
|
||||
|
||||
/// Count [`Session`]s matching the given filter
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// * `filter`: The filter parameters
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`Self::Error`] if the underlying repository fails
|
||||
async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result<usize, Self::Error>;
|
||||
}
|
||||
|
||||
repository_impl!(OAuth2SessionRepository:
|
||||
@ -112,9 +207,11 @@ repository_impl!(OAuth2SessionRepository:
|
||||
async fn finish(&mut self, clock: &dyn Clock, session: Session)
|
||||
-> Result<Session, Self::Error>;
|
||||
|
||||
async fn list_paginated(
|
||||
async fn list(
|
||||
&mut self,
|
||||
user: &User,
|
||||
filter: OAuth2SessionFilter<'_>,
|
||||
pagination: Pagination,
|
||||
) -> Result<Page<Session>, Self::Error>;
|
||||
|
||||
async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result<usize, Self::Error>;
|
||||
);
|
||||
|
Reference in New Issue
Block a user