1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-11-20 12:02:22 +03:00

storage: repository pattern for the compat layer

This commit is contained in:
Quentin Gliech
2023-01-12 15:41:26 +01:00
parent 9f0c9f1466
commit 36396c0b45
18 changed files with 1738 additions and 1191 deletions

View File

@@ -18,8 +18,8 @@ use hyper::StatusCode;
use mas_data_model::{CompatSession, CompatSsoLoginState, Device, TokenType, User};
use mas_storage::{
compat::{
add_compat_access_token, add_compat_refresh_token, get_compat_sso_login_by_token,
lookup_compat_session, mark_compat_sso_login_as_exchanged, start_compat_session,
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
CompatSsoLoginRepository,
},
user::{UserPasswordRepository, UserRepository},
Clock, Repository,
@@ -224,27 +224,17 @@ pub(crate) async fn post(
};
let access_token = TokenType::CompatAccessToken.generate(&mut rng);
let access_token = add_compat_access_token(
&mut txn,
&mut rng,
&clock,
&session,
access_token,
expires_in,
)
.await?;
let access_token = txn
.compat_access_token()
.add(&mut rng, &clock, &session, access_token, expires_in)
.await?;
let refresh_token = if input.refresh_token {
let refresh_token = TokenType::CompatRefreshToken.generate(&mut rng);
let refresh_token = add_compat_refresh_token(
&mut txn,
&mut rng,
&clock,
&session,
&access_token,
refresh_token,
)
.await?;
let refresh_token = txn
.compat_refresh_token()
.add(&mut rng, &clock, &session, &access_token, refresh_token)
.await?;
Some(refresh_token.token)
} else {
None
@@ -266,7 +256,9 @@ async fn token_login(
clock: &Clock,
token: &str,
) -> Result<(CompatSession, User), RouteError> {
let login = get_compat_sso_login_by_token(&mut *txn, token)
let login = txn
.compat_sso_login()
.find_by_token(token)
.await?
.ok_or(RouteError::InvalidLoginToken)?;
@@ -308,7 +300,9 @@ async fn token_login(
}
};
let session = lookup_compat_session(&mut *txn, session_id)
let session = txn
.compat_session()
.lookup(session_id)
.await?
.ok_or(RouteError::SessionNotFound)?;
@@ -318,7 +312,7 @@ async fn token_login(
.await?
.ok_or(RouteError::UserNotFound)?;
mark_compat_sso_login_as_exchanged(&mut *txn, clock, login).await?;
txn.compat_sso_login().exchange(clock, login).await?;
Ok((session, user))
}
@@ -374,7 +368,10 @@ async fn user_password_login(
// Now that the user credentials have been verified, start a new compat session
let device = Device::generate(&mut rng);
let session = start_compat_session(&mut *txn, &mut rng, &clock, &user, device).await?;
let session = txn
.compat_session()
.add(&mut rng, &clock, &user, device)
.await?;
Ok((session, user))
}

View File

@@ -29,7 +29,10 @@ use mas_axum_utils::{
use mas_data_model::Device;
use mas_keystore::Encrypter;
use mas_router::{CompatLoginSsoAction, PostAuthAction, Route};
use mas_storage::compat::{fullfill_compat_sso_login, get_compat_sso_login_by_id};
use mas_storage::{
compat::{CompatSessionRepository, CompatSsoLoginRepository},
Repository,
};
use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
@@ -87,7 +90,9 @@ pub async fn get(
return Ok((cookie_jar, destination.go()).into_response());
}
let login = get_compat_sso_login_by_id(&mut conn, id)
let login = conn
.compat_sso_login()
.lookup(id)
.await?
.context("Could not find compat SSO login")?;
@@ -149,7 +154,9 @@ pub async fn post(
return Ok((cookie_jar, destination.go()).into_response());
}
let login = get_compat_sso_login_by_id(&mut txn, id)
let login = txn
.compat_sso_login()
.lookup(id)
.await?
.context("Could not find compat SSO login")?;
@@ -181,8 +188,14 @@ pub async fn post(
};
let device = Device::generate(&mut rng);
let _login =
fullfill_compat_sso_login(&mut txn, &mut rng, &clock, &session.user, login, device).await?;
let compat_session = txn
.compat_session()
.add(&mut rng, &clock, &session.user, device)
.await?;
txn.compat_sso_login()
.fulfill(&clock, login, &compat_session)
.await?;
txn.commit().await?;

View File

@@ -19,7 +19,7 @@ use axum::{
};
use hyper::StatusCode;
use mas_router::{CompatLoginSsoAction, CompatLoginSsoComplete, UrlBuilder};
use mas_storage::compat::insert_compat_sso_login;
use mas_storage::{compat::CompatSsoLoginRepository, Repository};
use rand::distributions::{Alphanumeric, DistString};
use serde::Deserialize;
use serde_with::serde;
@@ -49,6 +49,7 @@ pub enum RouteError {
}
impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_storage::DatabaseError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
@@ -80,7 +81,10 @@ pub async fn get(
let token = Alphanumeric.sample_string(&mut rng, 32);
let mut conn = pool.acquire().await?;
let login = insert_compat_sso_login(&mut conn, &mut rng, &clock, token, redirect_url).await?;
let login = conn
.compat_sso_login()
.add(&mut rng, &clock, token, redirect_url)
.await?;
Ok(url_builder.absolute_redirect(&CompatLoginSsoComplete::new(login.id, params.action)))
}

View File

@@ -17,8 +17,8 @@ use headers::{authorization::Bearer, Authorization};
use hyper::StatusCode;
use mas_data_model::TokenType;
use mas_storage::{
compat::{end_compat_session, find_compat_access_token, lookup_compat_session},
Clock,
compat::{CompatAccessTokenRepository, CompatSessionRepository},
Clock, Repository,
};
use sqlx::PgPool;
use thiserror::Error;
@@ -83,17 +83,21 @@ pub(crate) async fn post(
return Err(RouteError::InvalidAuthorization);
}
let token = find_compat_access_token(&mut txn, token)
let token = txn
.compat_access_token()
.find_by_token(token)
.await?
.filter(|t| t.is_valid(clock.now()))
.ok_or(RouteError::InvalidAuthorization)?;
let session = lookup_compat_session(&mut txn, token.session_id)
let session = txn
.compat_session()
.lookup(token.session_id)
.await?
.filter(|s| s.is_valid())
.ok_or(RouteError::InvalidAuthorization)?;
end_compat_session(&mut txn, &clock, session).await?;
txn.compat_session().finish(&clock, session).await?;
txn.commit().await?;

View File

@@ -16,10 +16,9 @@ use axum::{extract::State, response::IntoResponse, Json};
use chrono::Duration;
use hyper::StatusCode;
use mas_data_model::{TokenFormatError, TokenType};
use mas_storage::compat::{
add_compat_access_token, add_compat_refresh_token, consume_compat_refresh_token,
expire_compat_access_token, find_compat_refresh_token, lookup_compat_access_token,
lookup_compat_session,
use mas_storage::{
compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository},
Repository,
};
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, DurationMilliSeconds};
@@ -101,7 +100,9 @@ pub(crate) async fn post(
return Err(RouteError::InvalidToken);
}
let refresh_token = find_compat_refresh_token(&mut txn, &input.refresh_token)
let refresh_token = txn
.compat_refresh_token()
.find_by_token(&input.refresh_token)
.await?
.ok_or(RouteError::InvalidToken)?;
@@ -109,7 +110,9 @@ pub(crate) async fn post(
return Err(RouteError::RefreshTokenConsumed);
}
let session = lookup_compat_session(&mut txn, refresh_token.session_id)
let session = txn
.compat_session()
.lookup(refresh_token.session_id)
.await?
.ok_or(RouteError::UnknownSession)?;
@@ -117,7 +120,9 @@ pub(crate) async fn post(
return Err(RouteError::InvalidSession);
}
let access_token = lookup_compat_access_token(&mut txn, refresh_token.access_token_id)
let access_token = txn
.compat_access_token()
.lookup(refresh_token.access_token_id)
.await?
.filter(|t| t.is_valid(clock.now()));
@@ -125,29 +130,35 @@ pub(crate) async fn post(
let new_access_token_str = TokenType::CompatAccessToken.generate(&mut rng);
let expires_in = Duration::minutes(5);
let new_access_token = add_compat_access_token(
&mut txn,
&mut rng,
&clock,
&session,
new_access_token_str,
Some(expires_in),
)
.await?;
let new_refresh_token = add_compat_refresh_token(
&mut txn,
&mut rng,
&clock,
&session,
&new_access_token,
new_refresh_token_str,
)
.await?;
let new_access_token = txn
.compat_access_token()
.add(
&mut rng,
&clock,
&session,
new_access_token_str,
Some(expires_in),
)
.await?;
let new_refresh_token = txn
.compat_refresh_token()
.add(
&mut rng,
&clock,
&session,
&new_access_token,
new_refresh_token_str,
)
.await?;
consume_compat_refresh_token(&mut txn, &clock, refresh_token).await?;
txn.compat_refresh_token()
.consume(&clock, refresh_token)
.await?;
if let Some(access_token) = access_token {
expire_compat_access_token(&mut txn, &clock, access_token).await?;
txn.compat_access_token()
.expire(&clock, access_token)
.await?;
}
txn.commit().await?;

View File

@@ -22,7 +22,7 @@ use mas_data_model::{TokenFormatError, TokenType};
use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint};
use mas_keystore::Encrypter;
use mas_storage::{
compat::{find_compat_access_token, find_compat_refresh_token, lookup_compat_session},
compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository},
oauth2::{
access_token::find_access_token, refresh_token::lookup_refresh_token,
OAuth2SessionRepository,
@@ -243,12 +243,16 @@ pub(crate) async fn post(
}
TokenType::CompatAccessToken => {
let token = find_compat_access_token(&mut conn, token)
let access_token = conn
.compat_access_token()
.find_by_token(token)
.await?
.filter(|t| t.is_valid(clock.now()))
.ok_or(RouteError::UnknownToken)?;
let session = lookup_compat_session(&mut conn, token.session_id)
let session = conn
.compat_session()
.lookup(access_token.session_id)
.await?
.filter(|s| s.is_valid())
.ok_or(RouteError::UnknownToken)?;
@@ -269,9 +273,9 @@ pub(crate) async fn post(
client_id: Some("legacy".into()),
username: Some(user.username),
token_type: Some(OAuthTokenTypeHint::AccessToken),
exp: token.expires_at,
iat: Some(token.created_at),
nbf: Some(token.created_at),
exp: access_token.expires_at,
iat: Some(access_token.created_at),
nbf: Some(access_token.created_at),
sub: Some(user.sub),
aud: None,
iss: None,
@@ -280,12 +284,16 @@ pub(crate) async fn post(
}
TokenType::CompatRefreshToken => {
let refresh_token = find_compat_refresh_token(&mut conn, token)
let refresh_token = conn
.compat_refresh_token()
.find_by_token(token)
.await?
.filter(|t| t.is_valid())
.ok_or(RouteError::UnknownToken)?;
let session = lookup_compat_session(&mut conn, refresh_token.session_id)
let session = conn
.compat_session()
.lookup(refresh_token.session_id)
.await?
.filter(|s| s.is_valid())
.ok_or(RouteError::UnknownToken)?;

View File

@@ -15,7 +15,7 @@
use anyhow::Context;
use mas_router::{PostAuthAction, Route};
use mas_storage::{
compat::get_compat_sso_login_by_id, oauth2::authorization_grant::get_grant_by_id,
compat::CompatSsoLoginRepository, oauth2::authorization_grant::get_grant_by_id,
upstream_oauth2::UpstreamOAuthProviderRepository, Repository, UpstreamOAuthLinkRepository,
};
use mas_templates::{PostAuthContext, PostAuthContextInner};
@@ -54,7 +54,9 @@ impl OptionalPostAuthAction {
}
PostAuthAction::ContinueCompatSsoLogin { id } => {
let login = get_compat_sso_login_by_id(conn, id)
let login = conn
.compat_sso_login()
.lookup(id)
.await?
.context("Failed to load compat SSO login")?;
let login = Box::new(login);