1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Better OAuth 2.0 sessions pagination and filtering

This commit is contained in:
Quentin Gliech
2023-07-21 17:56:51 +02:00
parent 59c79276bc
commit 6767c93a75
11 changed files with 652 additions and 54 deletions

View File

@ -22,6 +22,8 @@ mod session;
pub use self::{
access_token::OAuth2AccessTokenRepository,
authorization_grant::OAuth2AuthorizationGrantRepository, client::OAuth2ClientRepository,
refresh_token::OAuth2RefreshTokenRepository, session::OAuth2SessionRepository,
authorization_grant::OAuth2AuthorizationGrantRepository,
client::OAuth2ClientRepository,
refresh_token::OAuth2RefreshTokenRepository,
session::{OAuth2SessionFilter, OAuth2SessionRepository},
};

View File

@ -20,6 +20,90 @@ use ulid::Ulid;
use crate::{pagination::Page, repository_impl, Clock, Pagination};
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OAuth2SessionState {
Active,
Finished,
}
impl OAuth2SessionState {
pub fn is_active(self) -> bool {
matches!(self, Self::Active)
}
pub fn is_finished(self) -> bool {
matches!(self, Self::Finished)
}
}
/// Filter parameters for listing OAuth 2.0 sessions
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub struct OAuth2SessionFilter<'a> {
user: Option<&'a User>,
client: Option<&'a Client>,
state: Option<OAuth2SessionState>,
}
impl<'a> OAuth2SessionFilter<'a> {
/// Create a new [`OAuth2SessionFilter`] with default values
#[must_use]
pub fn new() -> Self {
Self::default()
}
/// List sessions for a specific user
#[must_use]
pub fn for_user(mut self, user: &'a User) -> Self {
self.user = Some(user);
self
}
/// Get the user filter
///
/// Returns [`None`] if no user filter was set
#[must_use]
pub fn user(&self) -> Option<&User> {
self.user
}
/// List sessions for a specific client
#[must_use]
pub fn for_client(mut self, client: &'a Client) -> Self {
self.client = Some(client);
self
}
/// Get the client filter
///
/// Returns [`None`] if no client filter was set
#[must_use]
pub fn client(&self) -> Option<&Client> {
self.client
}
/// Only return active sessions
#[must_use]
pub fn active_only(mut self) -> Self {
self.state = Some(OAuth2SessionState::Active);
self
}
/// Only return finished sessions
#[must_use]
pub fn finished_only(mut self) -> Self {
self.state = Some(OAuth2SessionState::Finished);
self
}
/// Get the state filter
///
/// Returns [`None`] if no state filter was set
#[must_use]
pub fn state(&self) -> Option<OAuth2SessionState> {
self.state
}
}
/// An [`OAuth2SessionRepository`] helps interacting with [`Session`]
/// saved in the storage backend
#[async_trait]
@ -80,21 +164,32 @@ pub trait OAuth2SessionRepository: Send + Sync {
async fn finish(&mut self, clock: &dyn Clock, session: Session)
-> Result<Session, Self::Error>;
/// Get a paginated list of [`Session`]s for a [`User`]
/// List [`Session`]s matching the given filter and pagination parameters
///
/// # Parameters
///
/// * `user`: The [`User`] to get the [`Session`]s for
/// * `filter`: The filter parameters
/// * `pagination`: The pagination parameters
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn list_paginated(
async fn list(
&mut self,
user: &User,
filter: OAuth2SessionFilter<'_>,
pagination: Pagination,
) -> Result<Page<Session>, Self::Error>;
/// Count [`Session`]s matching the given filter
///
/// # Parameters
///
/// * `filter`: The filter parameters
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result<usize, Self::Error>;
}
repository_impl!(OAuth2SessionRepository:
@ -112,9 +207,11 @@ repository_impl!(OAuth2SessionRepository:
async fn finish(&mut self, clock: &dyn Clock, session: Session)
-> Result<Session, Self::Error>;
async fn list_paginated(
async fn list(
&mut self,
user: &User,
filter: OAuth2SessionFilter<'_>,
pagination: Pagination,
) -> Result<Page<Session>, Self::Error>;
async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result<usize, Self::Error>;
);