diff --git a/crates/storage/src/upstream_oauth2/mod.rs b/crates/storage/src/upstream_oauth2/mod.rs index d72e5f48..14848935 100644 --- a/crates/storage/src/upstream_oauth2/mod.rs +++ b/crates/storage/src/upstream_oauth2/mod.rs @@ -24,6 +24,7 @@ pub use self::{ #[cfg(test)] mod tests { + use chrono::Duration; use oauth2_types::scope::{Scope, OPENID}; use rand::SeedableRng; use sqlx::PgPool; @@ -32,13 +33,13 @@ mod tests { use crate::{user::UserRepository, Clock, Pagination, PgRepository, Repository}; #[sqlx::test(migrator = "crate::MIGRATOR")] - async fn test_repository(pool: PgPool) -> Result<(), Box> { + async fn test_repository(pool: PgPool) { let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42); let clock = Clock::mock(); - let mut repo = PgRepository::from_pool(&pool).await?; + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); // The provider list should be empty at the start - let all_providers = repo.upstream_oauth_provider().all().await?; + let all_providers = repo.upstream_oauth_provider().all().await.unwrap(); assert!(all_providers.is_empty()); // Let's add a provider @@ -54,13 +55,15 @@ mod tests { "client-id".to_owned(), None, ) - .await?; + .await + .unwrap(); // Look it up in the database let provider = repo .upstream_oauth_provider() .lookup(provider.id) - .await? + .await + .unwrap() .expect("provider to be found in the database"); assert_eq!(provider.issuer, "https://example.com/"); assert_eq!(provider.client_id, "client-id"); @@ -76,13 +79,15 @@ mod tests { None, "some-nonce".to_owned(), ) - .await?; + .await + .unwrap(); // Look it up in the database let session = repo .upstream_oauth_session() .lookup(session.id) - .await? + .await + .unwrap() .expect("session to be found in the database"); assert_eq!(session.provider_id, provider.id); assert_eq!(session.link_id(), None); @@ -94,19 +99,22 @@ mod tests { let link = repo .upstream_oauth_link() .add(&mut rng, &clock, &provider, "a-subject".to_owned()) - .await?; + .await + .unwrap(); // We can look it up by its ID repo.upstream_oauth_link() .lookup(link.id) - .await? + .await + .unwrap() .expect("link to be found in database"); // or by its subject let link = repo .upstream_oauth_link() .find_by_subject(&provider, "a-subject") - .await? + .await + .unwrap() .expect("link to be found in database"); assert_eq!(link.subject, "a-subject"); assert_eq!(link.provider_id, provider.id); @@ -114,12 +122,14 @@ mod tests { let session = repo .upstream_oauth_session() .complete_with_link(&clock, session, &link, None) - .await?; + .await + .unwrap(); // Reload the session let session = repo .upstream_oauth_session() .lookup(session.id) - .await? + .await + .unwrap() .expect("session to be found in the database"); assert!(session.is_completed()); assert!(!session.is_consumed()); @@ -128,30 +138,128 @@ mod tests { let session = repo .upstream_oauth_session() .consume(&clock, session) - .await?; + .await + .unwrap(); // Reload the session let session = repo .upstream_oauth_session() .lookup(session.id) - .await? + .await + .unwrap() .expect("session to be found in the database"); assert!(session.is_consumed()); - let user = repo.user().add(&mut rng, &clock, "john".to_owned()).await?; + let user = repo + .user() + .add(&mut rng, &clock, "john".to_owned()) + .await + .unwrap(); repo.upstream_oauth_link() .associate_to_user(&link, &user) - .await?; + .await + .unwrap(); let links = repo .upstream_oauth_link() .list_paginated(&user, &Pagination::first(10)) - .await?; + .await + .unwrap(); assert!(!links.has_previous_page); assert!(!links.has_next_page); assert_eq!(links.edges.len(), 1); assert_eq!(links.edges[0].id, link.id); assert_eq!(links.edges[0].user_id, Some(user.id)); + } - Ok(()) + #[sqlx::test(migrator = "crate::MIGRATOR")] + async fn test_provider_repository_pagination(pool: PgPool) { + const ISSUER: &str = "https://example.com/"; + let scope = Scope::from_iter([OPENID]); + + let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42); + let clock = Clock::mock(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + + let mut ids = Vec::with_capacity(20); + // Create 20 providers + for idx in 0..20 { + let client_id = format!("client-{idx}"); + let provider = repo + .upstream_oauth_provider() + .add( + &mut rng, + &clock, + ISSUER.to_owned(), + scope.clone(), + mas_iana::oauth::OAuthClientAuthenticationMethod::None, + None, + client_id, + None, + ) + .await + .unwrap(); + ids.push(provider.id); + clock.advance(Duration::seconds(10)); + } + + // Lookup the first 10 items + let page = repo + .upstream_oauth_provider() + .list_paginated(&Pagination::first(10)) + .await + .unwrap(); + + // It returned the first 10 items + assert!(page.has_next_page); + let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); + assert_eq!(&edge_ids, &ids[..10]); + + // Lookup the next 10 items + let page = repo + .upstream_oauth_provider() + .list_paginated(&Pagination::first(10).after(ids[9])) + .await + .unwrap(); + + // It returned the next 10 items + assert!(!page.has_next_page); + let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); + assert_eq!(&edge_ids, &ids[10..]); + + // Lookup the last 10 items + let page = repo + .upstream_oauth_provider() + .list_paginated(&Pagination::last(10)) + .await + .unwrap(); + + // It returned the last 10 items + assert!(page.has_previous_page); + let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); + assert_eq!(&edge_ids, &ids[10..]); + + // Lookup the previous 10 items + let page = repo + .upstream_oauth_provider() + .list_paginated(&Pagination::last(10).before(ids[10])) + .await + .unwrap(); + + // It returned the previous 10 items + assert!(!page.has_previous_page); + let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); + assert_eq!(&edge_ids, &ids[..10]); + + // Lookup 10 items between two IDs + let page = repo + .upstream_oauth_provider() + .list_paginated(&Pagination::first(10).after(ids[5]).before(ids[8])) + .await + .unwrap(); + + // It returned the items in between + assert!(!page.has_next_page); + let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); + assert_eq!(&edge_ids, &ids[6..8]); } }