1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-11-20 12:02:22 +03:00
Files
authentication-service/crates/storage-pg/src/compat/mod.rs
Quentin Gliech ed5893eb20 Save which user session created a compat session
This also exposes the user session in the GraphQL API, and allow
filtering on browser session ID on the app session list.
2024-02-21 11:55:58 +01:00

689 lines
23 KiB
Rust

// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! A module containing PostgreSQL implementation of repositories for the
//! compatibility layer
mod access_token;
mod refresh_token;
mod session;
mod sso_login;
pub use self::{
access_token::PgCompatAccessTokenRepository, refresh_token::PgCompatRefreshTokenRepository,
session::PgCompatSessionRepository, sso_login::PgCompatSsoLoginRepository,
};
#[cfg(test)]
mod tests {
use chrono::Duration;
use mas_data_model::Device;
use mas_storage::{
clock::MockClock,
compat::{
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionFilter,
CompatSessionRepository, CompatSsoLoginFilter,
},
user::UserRepository,
Clock, Pagination, Repository, RepositoryAccess,
};
use rand::SeedableRng;
use rand_chacha::ChaChaRng;
use sqlx::PgPool;
use ulid::Ulid;
use crate::PgRepository;
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_session_repository(pool: PgPool) {
let mut rng = ChaChaRng::seed_from_u64(42);
let clock = MockClock::default();
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
// Create a user
let user = repo
.user()
.add(&mut rng, &clock, "john".to_owned())
.await
.unwrap();
let all = CompatSessionFilter::new().for_user(&user);
let active = all.active_only();
let finished = all.finished_only();
let pagination = Pagination::first(10);
assert_eq!(repo.compat_session().count(all).await.unwrap(), 0);
assert_eq!(repo.compat_session().count(active).await.unwrap(), 0);
assert_eq!(repo.compat_session().count(finished).await.unwrap(), 0);
let full_list = repo.compat_session().list(all, pagination).await.unwrap();
assert!(full_list.edges.is_empty());
let active_list = repo
.compat_session()
.list(active, pagination)
.await
.unwrap();
assert!(active_list.edges.is_empty());
let finished_list = repo
.compat_session()
.list(finished, pagination)
.await
.unwrap();
assert!(finished_list.edges.is_empty());
// Start a compat session for that user
let device = Device::generate(&mut rng);
let device_str = device.as_str().to_owned();
let session = repo
.compat_session()
.add(&mut rng, &clock, &user, device.clone(), None, false)
.await
.unwrap();
assert_eq!(session.user_id, user.id);
assert_eq!(session.device.as_str(), device_str);
assert!(session.is_valid());
assert!(!session.is_finished());
assert_eq!(repo.compat_session().count(all).await.unwrap(), 1);
assert_eq!(repo.compat_session().count(active).await.unwrap(), 1);
assert_eq!(repo.compat_session().count(finished).await.unwrap(), 0);
let full_list = repo.compat_session().list(all, pagination).await.unwrap();
assert_eq!(full_list.edges.len(), 1);
assert_eq!(full_list.edges[0].0.id, session.id);
let active_list = repo
.compat_session()
.list(active, pagination)
.await
.unwrap();
assert_eq!(active_list.edges.len(), 1);
assert_eq!(active_list.edges[0].0.id, session.id);
let finished_list = repo
.compat_session()
.list(finished, pagination)
.await
.unwrap();
assert!(finished_list.edges.is_empty());
// Lookup the session and check it didn't change
let session_lookup = repo
.compat_session()
.lookup(session.id)
.await
.unwrap()
.expect("compat session not found");
assert_eq!(session_lookup.id, session.id);
assert_eq!(session_lookup.user_id, user.id);
assert_eq!(session_lookup.device.as_str(), device_str);
assert!(session_lookup.is_valid());
assert!(!session_lookup.is_finished());
// Look up the session by device
let list = repo
.compat_session()
.list(
CompatSessionFilter::new()
.for_user(&user)
.for_device(&device),
pagination,
)
.await
.unwrap();
assert_eq!(list.edges.len(), 1);
let session_lookup = &list.edges[0].0;
assert_eq!(session_lookup.id, session.id);
assert_eq!(session_lookup.user_id, user.id);
assert_eq!(session_lookup.device.as_str(), device_str);
assert!(session_lookup.is_valid());
assert!(!session_lookup.is_finished());
// Finish the session
let session = repo.compat_session().finish(&clock, session).await.unwrap();
assert!(!session.is_valid());
assert!(session.is_finished());
assert_eq!(repo.compat_session().count(all).await.unwrap(), 1);
assert_eq!(repo.compat_session().count(active).await.unwrap(), 0);
assert_eq!(repo.compat_session().count(finished).await.unwrap(), 1);
let full_list = repo.compat_session().list(all, pagination).await.unwrap();
assert_eq!(full_list.edges.len(), 1);
assert_eq!(full_list.edges[0].0.id, session.id);
let active_list = repo
.compat_session()
.list(active, pagination)
.await
.unwrap();
assert!(active_list.edges.is_empty());
let finished_list = repo
.compat_session()
.list(finished, pagination)
.await
.unwrap();
assert_eq!(finished_list.edges.len(), 1);
assert_eq!(finished_list.edges[0].0.id, session.id);
// Reload the session and check again
let session_lookup = repo
.compat_session()
.lookup(session.id)
.await
.unwrap()
.expect("compat session not found");
assert!(!session_lookup.is_valid());
assert!(session_lookup.is_finished());
// Now add another session, with an SSO login this time
let unknown_session = session;
// Start a new SSO login
let login = repo
.compat_sso_login()
.add(
&mut rng,
&clock,
"login-token".to_owned(),
"https://example.com/callback".parse().unwrap(),
)
.await
.unwrap();
assert!(login.is_pending());
// Start a compat session for that user
let device = Device::generate(&mut rng);
let sso_login_session = repo
.compat_session()
.add(&mut rng, &clock, &user, device, None, false)
.await
.unwrap();
// Associate the login with the session
let login = repo
.compat_sso_login()
.fulfill(&clock, login, &sso_login_session)
.await
.unwrap();
assert!(login.is_fulfilled());
// Now query the session list with both the unknown and SSO login session type
// filter
let all = CompatSessionFilter::new().for_user(&user);
let sso_login = all.sso_login_only();
let unknown = all.unknown_only();
assert_eq!(repo.compat_session().count(all).await.unwrap(), 2);
assert_eq!(repo.compat_session().count(sso_login).await.unwrap(), 1);
assert_eq!(repo.compat_session().count(unknown).await.unwrap(), 1);
let list = repo
.compat_session()
.list(sso_login, pagination)
.await
.unwrap();
assert_eq!(list.edges.len(), 1);
assert_eq!(list.edges[0].0.id, sso_login_session.id);
let list = repo
.compat_session()
.list(unknown, pagination)
.await
.unwrap();
assert_eq!(list.edges.len(), 1);
assert_eq!(list.edges[0].0.id, unknown_session.id);
// Check that combining the two filters works
// At this point, there is one active SSO login session and one finished unknown
// session
assert_eq!(
repo.compat_session()
.count(all.sso_login_only().active_only())
.await
.unwrap(),
1
);
assert_eq!(
repo.compat_session()
.count(all.sso_login_only().finished_only())
.await
.unwrap(),
0
);
assert_eq!(
repo.compat_session()
.count(all.unknown_only().active_only())
.await
.unwrap(),
0
);
assert_eq!(
repo.compat_session()
.count(all.unknown_only().finished_only())
.await
.unwrap(),
1
);
}
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_access_token_repository(pool: PgPool) {
const FIRST_TOKEN: &str = "first_access_token";
const SECOND_TOKEN: &str = "second_access_token";
let mut rng = ChaChaRng::seed_from_u64(42);
let clock = MockClock::default();
let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
// Create a user
let user = repo
.user()
.add(&mut rng, &clock, "john".to_owned())
.await
.unwrap();
// Start a compat session for that user
let device = Device::generate(&mut rng);
let session = repo
.compat_session()
.add(&mut rng, &clock, &user, device, None, false)
.await
.unwrap();
// Add an access token to that session
let token = repo
.compat_access_token()
.add(
&mut rng,
&clock,
&session,
FIRST_TOKEN.to_owned(),
Some(Duration::minutes(1)),
)
.await
.unwrap();
assert_eq!(token.session_id, session.id);
assert_eq!(token.token, FIRST_TOKEN);
// Commit the txn and grab a new transaction, to test a conflict
repo.save().await.unwrap();
{
let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
// Adding the same token a second time should conflict
assert!(repo
.compat_access_token()
.add(
&mut rng,
&clock,
&session,
FIRST_TOKEN.to_owned(),
Some(Duration::minutes(1)),
)
.await
.is_err());
repo.cancel().await.unwrap();
}
// Grab a new repo
let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
// Looking up via ID works
let token_lookup = repo
.compat_access_token()
.lookup(token.id)
.await
.unwrap()
.expect("compat access token not found");
assert_eq!(token.id, token_lookup.id);
assert_eq!(token_lookup.session_id, session.id);
// Looking up via the token value works
let token_lookup = repo
.compat_access_token()
.find_by_token(FIRST_TOKEN)
.await
.unwrap()
.expect("compat access token not found");
assert_eq!(token.id, token_lookup.id);
assert_eq!(token_lookup.session_id, session.id);
// Token is currently valid
assert!(token.is_valid(clock.now()));
clock.advance(Duration::minutes(1));
// Token should have expired
assert!(!token.is_valid(clock.now()));
// Add a second access token, this time without expiration
let token = repo
.compat_access_token()
.add(&mut rng, &clock, &session, SECOND_TOKEN.to_owned(), None)
.await
.unwrap();
assert_eq!(token.session_id, session.id);
assert_eq!(token.token, SECOND_TOKEN);
// Token is currently valid
assert!(token.is_valid(clock.now()));
// Make it expire
repo.compat_access_token()
.expire(&clock, token)
.await
.unwrap();
// Reload it
let token = repo
.compat_access_token()
.find_by_token(SECOND_TOKEN)
.await
.unwrap()
.expect("compat access token not found");
// Token is not valid anymore
assert!(!token.is_valid(clock.now()));
repo.save().await.unwrap();
}
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_refresh_token_repository(pool: PgPool) {
const ACCESS_TOKEN: &str = "access_token";
const REFRESH_TOKEN: &str = "refresh_token";
let mut rng = ChaChaRng::seed_from_u64(42);
let clock = MockClock::default();
let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
// Create a user
let user = repo
.user()
.add(&mut rng, &clock, "john".to_owned())
.await
.unwrap();
// Start a compat session for that user
let device = Device::generate(&mut rng);
let session = repo
.compat_session()
.add(&mut rng, &clock, &user, device, None, false)
.await
.unwrap();
// Add an access token to that session
let access_token = repo
.compat_access_token()
.add(&mut rng, &clock, &session, ACCESS_TOKEN.to_owned(), None)
.await
.unwrap();
let refresh_token = repo
.compat_refresh_token()
.add(
&mut rng,
&clock,
&session,
&access_token,
REFRESH_TOKEN.to_owned(),
)
.await
.unwrap();
assert_eq!(refresh_token.session_id, session.id);
assert_eq!(refresh_token.access_token_id, access_token.id);
assert_eq!(refresh_token.token, REFRESH_TOKEN);
assert!(refresh_token.is_valid());
assert!(!refresh_token.is_consumed());
// Look it up by ID and check everything matches
let refresh_token_lookup = repo
.compat_refresh_token()
.lookup(refresh_token.id)
.await
.unwrap()
.expect("refresh token not found");
assert_eq!(refresh_token_lookup.id, refresh_token.id);
assert_eq!(refresh_token_lookup.session_id, session.id);
assert_eq!(refresh_token_lookup.access_token_id, access_token.id);
assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN);
assert!(refresh_token_lookup.is_valid());
assert!(!refresh_token_lookup.is_consumed());
// Look it up by token and check everything matches
let refresh_token_lookup = repo
.compat_refresh_token()
.find_by_token(REFRESH_TOKEN)
.await
.unwrap()
.expect("refresh token not found");
assert_eq!(refresh_token_lookup.id, refresh_token.id);
assert_eq!(refresh_token_lookup.session_id, session.id);
assert_eq!(refresh_token_lookup.access_token_id, access_token.id);
assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN);
assert!(refresh_token_lookup.is_valid());
assert!(!refresh_token_lookup.is_consumed());
// Consume it
let refresh_token = repo
.compat_refresh_token()
.consume(&clock, refresh_token)
.await
.unwrap();
assert!(!refresh_token.is_valid());
assert!(refresh_token.is_consumed());
// Reload it and check again
let refresh_token_lookup = repo
.compat_refresh_token()
.find_by_token(REFRESH_TOKEN)
.await
.unwrap()
.expect("refresh token not found");
assert!(!refresh_token_lookup.is_valid());
assert!(refresh_token_lookup.is_consumed());
// Consuming it again should not work
assert!(repo
.compat_refresh_token()
.consume(&clock, refresh_token)
.await
.is_err());
repo.save().await.unwrap();
}
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_compat_sso_login_repository(pool: PgPool) {
let mut rng = ChaChaRng::seed_from_u64(42);
let clock = MockClock::default();
let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
// Create a user
let user = repo
.user()
.add(&mut rng, &clock, "john".to_owned())
.await
.unwrap();
// Lookup an unknown SSO login
let login = repo.compat_sso_login().lookup(Ulid::nil()).await.unwrap();
assert_eq!(login, None);
let all = CompatSsoLoginFilter::new();
let for_user = all.for_user(&user);
let pending = all.pending_only();
let fulfilled = all.fulfilled_only();
let exchanged = all.exchanged_only();
// Check the initial counts
assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 0);
assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 0);
assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 0);
assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 0);
assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 0);
// Lookup an unknown login token
let login = repo
.compat_sso_login()
.find_by_token("login-token")
.await
.unwrap();
assert_eq!(login, None);
// Start a new SSO login
let login = repo
.compat_sso_login()
.add(
&mut rng,
&clock,
"login-token".to_owned(),
"https://example.com/callback".parse().unwrap(),
)
.await
.unwrap();
assert!(login.is_pending());
// Check the counts
assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 1);
assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 0);
assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 1);
assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 0);
assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 0);
// Lookup the login by ID
let login_lookup = repo
.compat_sso_login()
.lookup(login.id)
.await
.unwrap()
.expect("login not found");
assert_eq!(login_lookup, login);
// Find the login by token
let login_lookup = repo
.compat_sso_login()
.find_by_token("login-token")
.await
.unwrap()
.expect("login not found");
assert_eq!(login_lookup, login);
// Exchanging before fulfilling should not work
// Note: It should also not poison the SQL transaction
let res = repo
.compat_sso_login()
.exchange(&clock, login.clone())
.await;
assert!(res.is_err());
// Start a compat session for that user
let device = Device::generate(&mut rng);
let session = repo
.compat_session()
.add(&mut rng, &clock, &user, device, None, false)
.await
.unwrap();
// Associate the login with the session
let login = repo
.compat_sso_login()
.fulfill(&clock, login, &session)
.await
.unwrap();
assert!(login.is_fulfilled());
// Check the counts
assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 1);
assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 1);
assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 0);
assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 1);
assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 0);
// Fulfilling again should not work
// Note: It should also not poison the SQL transaction
let res = repo
.compat_sso_login()
.fulfill(&clock, login.clone(), &session)
.await;
assert!(res.is_err());
// Exchange that login
let login = repo
.compat_sso_login()
.exchange(&clock, login)
.await
.unwrap();
assert!(login.is_exchanged());
// Check the counts
assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 1);
assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 1);
assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 0);
assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 0);
assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 1);
// Exchange again should not work
// Note: It should also not poison the SQL transaction
let res = repo
.compat_sso_login()
.exchange(&clock, login.clone())
.await;
assert!(res.is_err());
// Fulfilling after exchanging should not work
// Note: It should also not poison the SQL transaction
let res = repo
.compat_sso_login()
.fulfill(&clock, login.clone(), &session)
.await;
assert!(res.is_err());
let pagination = Pagination::first(10);
// List all logins
let logins = repo.compat_sso_login().list(all, pagination).await.unwrap();
assert!(!logins.has_next_page);
assert_eq!(logins.edges, &[login.clone()]);
// List the logins for the user
let logins = repo
.compat_sso_login()
.list(for_user, pagination)
.await
.unwrap();
assert!(!logins.has_next_page);
assert_eq!(logins.edges, &[login.clone()]);
// List only the pending logins for the user
let logins = repo
.compat_sso_login()
.list(for_user.pending_only(), pagination)
.await
.unwrap();
assert!(!logins.has_next_page);
assert!(logins.edges.is_empty());
// List only the fulfilled logins for the user
let logins = repo
.compat_sso_login()
.list(for_user.fulfilled_only(), pagination)
.await
.unwrap();
assert!(!logins.has_next_page);
assert!(logins.edges.is_empty());
// List only the exchanged logins for the user
let logins = repo
.compat_sso_login()
.list(for_user.exchanged_only(), pagination)
.await
.unwrap();
assert!(!logins.has_next_page);
assert_eq!(logins.edges, &[login]);
}
}