1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-11-20 12:02:22 +03:00

storage: wrap the postgres repository in a struct

This commit is contained in:
Quentin Gliech
2023-01-13 18:03:37 +01:00
parent 488a666a8d
commit 195203823a
44 changed files with 505 additions and 548 deletions

View File

@@ -27,11 +27,11 @@ use mas_policy::PolicyFactory;
use mas_router::{PostAuthAction, Route};
use mas_storage::{
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2SessionRepository},
Repository,
PgRepository, Repository,
};
use mas_templates::Templates;
use oauth2_types::requests::{AccessTokenResponse, AuthorizationResponse};
use sqlx::{PgPool, Postgres, Transaction};
use sqlx::PgPool;
use thiserror::Error;
use ulid::Ulid;
@@ -84,13 +84,13 @@ pub(crate) async fn get(
cookie_jar: PrivateCookieJar<Encrypter>,
Path(grant_id): Path<Ulid>,
) -> Result<Response, RouteError> {
let mut txn = pool.begin().await?;
let mut repo = PgRepository::from_pool(&pool).await?;
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut txn).await?;
let maybe_session = session_info.load_session(&mut repo).await?;
let grant = txn
let grant = repo
.oauth2_authorization_grant()
.lookup(grant_id)
.await?
@@ -107,7 +107,7 @@ pub(crate) async fn get(
return Ok((cookie_jar, mas_router::Login::and_then(continue_grant).go()).into_response());
};
match complete(grant, session, &policy_factory, txn).await {
match complete(grant, session, &policy_factory, repo).await {
Ok(params) => {
let res = callback_destination.go(&templates, params).await?;
Ok((cookie_jar, res).into_response())
@@ -159,7 +159,7 @@ pub(crate) async fn complete(
grant: AuthorizationGrant,
browser_session: BrowserSession,
policy_factory: &PolicyFactory,
mut txn: Transaction<'_, Postgres>,
mut repo: PgRepository,
) -> Result<AuthorizationResponse<Option<AccessTokenResponse>>, GrantCompletionError> {
let (clock, mut rng) = crate::clock_and_rng();
@@ -170,7 +170,7 @@ pub(crate) async fn complete(
// Check if the authentication is fresh enough
if !browser_session.was_authenticated_after(grant.max_auth_time()) {
txn.commit().await?;
repo.save().await?;
return Err(GrantCompletionError::RequiresReauth);
}
@@ -184,13 +184,13 @@ pub(crate) async fn complete(
return Err(GrantCompletionError::PolicyViolation);
}
let client = txn
let client = repo
.oauth2_client()
.lookup(grant.client_id)
.await?
.ok_or(GrantCompletionError::NoSuchClient)?;
let current_consent = txn
let current_consent = repo
.oauth2_client()
.get_consent_for_user(&client, &browser_session.user)
.await?;
@@ -202,17 +202,17 @@ pub(crate) async fn complete(
// Check if the client lacks consent *or* if consent was explicitely asked
if lacks_consent || grant.requires_consent {
txn.commit().await?;
repo.save().await?;
return Err(GrantCompletionError::RequiresConsent);
}
// All good, let's start the session
let session = txn
let session = repo
.oauth2_session()
.create_from_grant(&mut rng, &clock, &grant, &browser_session)
.await?;
let grant = txn
let grant = repo
.oauth2_authorization_grant()
.fulfill(&clock, &session, grant)
.await?;
@@ -233,6 +233,6 @@ pub(crate) async fn complete(
));
}
txn.commit().await?;
repo.save().await?;
Ok(params)
}

View File

@@ -27,7 +27,7 @@ use mas_policy::PolicyFactory;
use mas_router::{PostAuthAction, Route};
use mas_storage::{
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository},
Repository,
PgRepository, Repository,
};
use mas_templates::Templates;
use oauth2_types::{
@@ -139,10 +139,10 @@ pub(crate) async fn get(
Form(params): Form<Params>,
) -> Result<Response, RouteError> {
let (clock, mut rng) = crate::clock_and_rng();
let mut txn = pool.begin().await?;
let mut repo = PgRepository::from_pool(&pool).await?;
// First, figure out what client it is
let client = txn
let client = repo
.oauth2_client()
.find_by_client_id(&params.auth.client_id)
.await?
@@ -170,7 +170,7 @@ pub(crate) async fn get(
let templates = templates.clone();
let callback_destination = callback_destination.clone();
async move {
let maybe_session = session_info.load_session(&mut txn).await?;
let maybe_session = session_info.load_session(&mut repo).await?;
let prompt = params.auth.prompt.as_deref().unwrap_or_default();
// Check if the request/request_uri/registration params are used. If so, reply
@@ -275,7 +275,7 @@ pub(crate) async fn get(
let requires_consent = prompt.contains(&Prompt::Consent);
let grant = txn
let grant = repo
.oauth2_authorization_grant()
.add(
&mut rng,
@@ -302,7 +302,7 @@ pub(crate) async fn get(
}
None if prompt.contains(&Prompt::Create) => {
// Client asked for a registration, show the registration prompt
txn.commit().await?;
repo.save().await?;
mas_router::Register::and_then(continue_grant)
.go()
@@ -310,7 +310,7 @@ pub(crate) async fn get(
}
None => {
// Other cases where we don't have a session, ask for a login
txn.commit().await?;
repo.save().await?;
mas_router::Login::and_then(continue_grant)
.go()
@@ -323,7 +323,7 @@ pub(crate) async fn get(
|| prompt.contains(&Prompt::SelectAccount) =>
{
// TODO: better pages here
txn.commit().await?;
repo.save().await?;
mas_router::Reauth::and_then(continue_grant)
.go()
@@ -333,7 +333,7 @@ pub(crate) async fn get(
// Else, we immediately try to complete the authorization grant
Some(user_session) if prompt.contains(&Prompt::None) => {
// With prompt=none, we should get back to the client immediately
match self::complete::complete(grant, user_session, &policy_factory, txn).await
match self::complete::complete(grant, user_session, &policy_factory, repo).await
{
Ok(params) => callback_destination.go(&templates, params).await?,
Err(GrantCompletionError::RequiresConsent) => {
@@ -372,7 +372,7 @@ pub(crate) async fn get(
Some(user_session) => {
let grant_id = grant.id;
// Else, we show the relevant reauth/consent page if necessary
match self::complete::complete(grant, user_session, &policy_factory, txn).await
match self::complete::complete(grant, user_session, &policy_factory, repo).await
{
Ok(params) => callback_destination.go(&templates, params).await?,
Err(

View File

@@ -30,7 +30,7 @@ use mas_policy::PolicyFactory;
use mas_router::{PostAuthAction, Route};
use mas_storage::{
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository},
Repository,
PgRepository, Repository,
};
use mas_templates::{ConsentContext, PolicyViolationContext, TemplateContext, Templates};
use sqlx::PgPool;
@@ -81,13 +81,13 @@ pub(crate) async fn get(
Path(grant_id): Path<Ulid>,
) -> Result<Response, RouteError> {
let (clock, mut rng) = crate::clock_and_rng();
let mut conn = pool.acquire().await?;
let mut repo = PgRepository::from_pool(&pool).await?;
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut conn).await?;
let maybe_session = session_info.load_session(&mut repo).await?;
let grant = conn
let grant = repo
.oauth2_authorization_grant()
.lookup(grant_id)
.await?
@@ -136,15 +136,15 @@ pub(crate) async fn post(
Form(form): Form<ProtectedForm<()>>,
) -> Result<Response, RouteError> {
let (clock, mut rng) = crate::clock_and_rng();
let mut txn = pool.begin().await?;
let mut repo = PgRepository::from_pool(&pool).await?;
cookie_jar.verify_form(clock.now(), form)?;
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut txn).await?;
let maybe_session = session_info.load_session(&mut repo).await?;
let grant = txn
let grant = repo
.oauth2_authorization_grant()
.lookup(grant_id)
.await?
@@ -167,7 +167,7 @@ pub(crate) async fn post(
return Err(RouteError::PolicyViolation);
}
let client = txn
let client = repo
.oauth2_client()
.lookup(grant.client_id)
.await?
@@ -180,7 +180,7 @@ pub(crate) async fn post(
.filter(|s| !s.starts_with("urn:matrix:org.matrix.msc2967.client:device:"))
.cloned()
.collect();
txn.oauth2_client()
repo.oauth2_client()
.give_consent_for_user(
&mut rng,
&clock,
@@ -190,9 +190,11 @@ pub(crate) async fn post(
)
.await?;
txn.oauth2_authorization_grant().give_consent(grant).await?;
repo.oauth2_authorization_grant()
.give_consent(grant)
.await?;
txn.commit().await?;
repo.save().await?;
Ok((cookie_jar, next.go_next()).into_response())
}

View File

@@ -25,7 +25,7 @@ use mas_storage::{
compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository},
oauth2::{OAuth2AccessTokenRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository},
user::{BrowserSessionRepository, UserRepository},
Clock, Repository,
Clock, PgRepository, Repository,
};
use oauth2_types::{
errors::{ClientError, ClientErrorCode},
@@ -130,12 +130,13 @@ pub(crate) async fn post(
client_authorization: ClientAuthorization<IntrospectionRequest>,
) -> Result<impl IntoResponse, RouteError> {
let clock = Clock::default();
let mut conn = pool.acquire().await?;
let mut repo = PgRepository::from_pool(&pool).await?;
let client = client_authorization
.credentials
.fetch(&mut conn)
.await?
.fetch(&mut repo)
.await
.unwrap()
.ok_or(RouteError::ClientNotFound)?;
let method = match &client.token_endpoint_auth_method {
@@ -166,14 +167,14 @@ pub(crate) async fn post(
let reply = match token_type {
TokenType::AccessToken => {
let token = conn
let token = repo
.oauth2_access_token()
.find_by_token(token)
.await?
.filter(|t| t.is_valid(clock.now()))
.ok_or(RouteError::UnknownToken)?;
let session = conn
let session = repo
.oauth2_session()
.lookup(token.session_id)
.await?
@@ -181,7 +182,7 @@ pub(crate) async fn post(
// XXX: is that the right error to bubble up?
.ok_or(RouteError::UnknownToken)?;
let browser_session = conn
let browser_session = repo
.browser_session()
.lookup(session.user_session_id)
.await?
@@ -205,14 +206,14 @@ pub(crate) async fn post(
}
TokenType::RefreshToken => {
let token = conn
let token = repo
.oauth2_refresh_token()
.find_by_token(token)
.await?
.filter(|t| t.is_valid())
.ok_or(RouteError::UnknownToken)?;
let session = conn
let session = repo
.oauth2_session()
.lookup(token.session_id)
.await?
@@ -220,7 +221,7 @@ pub(crate) async fn post(
// XXX: is that the right error to bubble up?
.ok_or(RouteError::UnknownToken)?;
let browser_session = conn
let browser_session = repo
.browser_session()
.lookup(session.user_session_id)
.await?
@@ -244,21 +245,21 @@ pub(crate) async fn post(
}
TokenType::CompatAccessToken => {
let access_token = conn
let access_token = repo
.compat_access_token()
.find_by_token(token)
.await?
.filter(|t| t.is_valid(clock.now()))
.ok_or(RouteError::UnknownToken)?;
let session = conn
let session = repo
.compat_session()
.lookup(access_token.session_id)
.await?
.filter(|s| s.is_valid())
.ok_or(RouteError::UnknownToken)?;
let user = conn
let user = repo
.user()
.lookup(session.user_id)
.await?
@@ -285,21 +286,21 @@ pub(crate) async fn post(
}
TokenType::CompatRefreshToken => {
let refresh_token = conn
let refresh_token = repo
.compat_refresh_token()
.find_by_token(token)
.await?
.filter(|t| t.is_valid())
.ok_or(RouteError::UnknownToken)?;
let session = conn
let session = repo
.compat_session()
.lookup(refresh_token.session_id)
.await?
.filter(|s| s.is_valid())
.ok_or(RouteError::UnknownToken)?;
let user = conn
let user = repo
.user()
.lookup(session.user_id)
.await?

View File

@@ -19,7 +19,7 @@ use hyper::StatusCode;
use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_keystore::Encrypter;
use mas_policy::{PolicyFactory, Violation};
use mas_storage::{oauth2::OAuth2ClientRepository, Repository};
use mas_storage::{oauth2::OAuth2ClientRepository, PgRepository, Repository};
use oauth2_types::{
errors::{ClientError, ClientErrorCode},
registration::{
@@ -124,8 +124,7 @@ pub(crate) async fn post(
return Err(RouteError::PolicyDenied(res.violations));
}
// Grab a txn
let mut txn = pool.begin().await?;
let mut repo = PgRepository::from_pool(&pool).await?;
let (client_secret, encrypted_client_secret) = match metadata.token_endpoint_auth_method {
Some(
@@ -141,7 +140,7 @@ pub(crate) async fn post(
_ => (None, None),
};
let client = txn
let client = repo
.oauth2_client()
.add(
&mut rng,
@@ -170,7 +169,7 @@ pub(crate) async fn post(
)
.await?;
txn.commit().await?;
repo.save().await?;
let response = ClientRegistrationResponse {
client_id: client.client_id,

View File

@@ -37,7 +37,7 @@ use mas_storage::{
OAuth2RefreshTokenRepository, OAuth2SessionRepository,
},
user::BrowserSessionRepository,
Repository,
PgRepository, Repository,
};
use oauth2_types::{
errors::{ClientError, ClientErrorCode},
@@ -49,7 +49,7 @@ use oauth2_types::{
};
use serde::Serialize;
use serde_with::{serde_as, skip_serializing_none};
use sqlx::{PgPool, Postgres, Transaction};
use sqlx::PgPool;
use thiserror::Error;
use tracing::debug;
use url::Url;
@@ -166,11 +166,11 @@ pub(crate) async fn post(
State(encrypter): State<Encrypter>,
client_authorization: ClientAuthorization<AccessTokenRequest>,
) -> Result<impl IntoResponse, RouteError> {
let mut txn = pool.begin().await?;
let mut repo = PgRepository::from_pool(&pool).await?;
let client = client_authorization
.credentials
.fetch(&mut txn)
.fetch(&mut repo)
.await?
.ok_or(RouteError::ClientNotFound)?;
@@ -188,10 +188,10 @@ pub(crate) async fn post(
let reply = match form {
AccessTokenRequest::AuthorizationCode(grant) => {
authorization_code_grant(&grant, &client, &key_store, &url_builder, txn).await?
authorization_code_grant(&grant, &client, &key_store, &url_builder, repo).await?
}
AccessTokenRequest::RefreshToken(grant) => {
refresh_token_grant(&grant, &client, txn).await?
refresh_token_grant(&grant, &client, repo).await?
}
_ => {
return Err(RouteError::InvalidGrant);
@@ -211,11 +211,11 @@ async fn authorization_code_grant(
client: &Client,
key_store: &Keystore,
url_builder: &UrlBuilder,
mut txn: Transaction<'_, Postgres>,
mut repo: PgRepository,
) -> Result<AccessTokenResponse, RouteError> {
let (clock, mut rng) = crate::clock_and_rng();
let authz_grant = txn
let authz_grant = repo
.oauth2_authorization_grant()
.find_by_code(&grant.code)
.await?
@@ -238,13 +238,13 @@ async fn authorization_code_grant(
// Ending the session if the token was already exchanged more than 20s ago
if now - exchanged_at > Duration::seconds(20) {
debug!("Ending potentially compromised session");
let session = txn
let session = repo
.oauth2_session()
.lookup(session_id)
.await?
.ok_or(RouteError::NoSuchOAuthSession)?;
txn.oauth2_session().finish(&clock, session).await?;
txn.commit().await?;
repo.oauth2_session().finish(&clock, session).await?;
repo.save().await?;
}
return Err(RouteError::InvalidGrant);
@@ -266,7 +266,7 @@ async fn authorization_code_grant(
}
};
let session = txn
let session = repo
.oauth2_session()
.lookup(session_id)
.await?
@@ -289,7 +289,7 @@ async fn authorization_code_grant(
}
};
let browser_session = txn
let browser_session = repo
.browser_session()
.lookup(session.user_session_id)
.await?
@@ -299,12 +299,12 @@ async fn authorization_code_grant(
let access_token_str = TokenType::AccessToken.generate(&mut rng);
let refresh_token_str = TokenType::RefreshToken.generate(&mut rng);
let access_token = txn
let access_token = repo
.oauth2_access_token()
.add(&mut rng, &clock, &session, access_token_str, ttl)
.await?;
let refresh_token = txn
let refresh_token = repo
.oauth2_refresh_token()
.add(&mut rng, &clock, &session, &access_token, refresh_token_str)
.await?;
@@ -355,11 +355,11 @@ async fn authorization_code_grant(
params = params.with_id_token(id_token);
}
txn.oauth2_authorization_grant()
repo.oauth2_authorization_grant()
.exchange(&clock, authz_grant)
.await?;
txn.commit().await?;
repo.save().await?;
Ok(params)
}
@@ -367,17 +367,17 @@ async fn authorization_code_grant(
async fn refresh_token_grant(
grant: &RefreshTokenGrant,
client: &Client,
mut txn: Transaction<'_, Postgres>,
mut repo: PgRepository,
) -> Result<AccessTokenResponse, RouteError> {
let (clock, mut rng) = crate::clock_and_rng();
let refresh_token = txn
let refresh_token = repo
.oauth2_refresh_token()
.find_by_token(&grant.refresh_token)
.await?
.ok_or(RouteError::InvalidGrant)?;
let session = txn
let session = repo
.oauth2_session()
.lookup(refresh_token.session_id)
.await?
@@ -396,12 +396,12 @@ async fn refresh_token_grant(
let access_token_str = TokenType::AccessToken.generate(&mut rng);
let refresh_token_str = TokenType::RefreshToken.generate(&mut rng);
let new_access_token = txn
let new_access_token = repo
.oauth2_access_token()
.add(&mut rng, &clock, &session, access_token_str.clone(), ttl)
.await?;
let new_refresh_token = txn
let new_refresh_token = repo
.oauth2_refresh_token()
.add(
&mut rng,
@@ -412,14 +412,14 @@ async fn refresh_token_grant(
)
.await?;
let refresh_token = txn
let refresh_token = repo
.oauth2_refresh_token()
.consume(&clock, refresh_token)
.await?;
if let Some(access_token_id) = refresh_token.access_token_id {
if let Some(access_token) = txn.oauth2_access_token().lookup(access_token_id).await? {
txn.oauth2_access_token()
if let Some(access_token) = repo.oauth2_access_token().lookup(access_token_id).await? {
repo.oauth2_access_token()
.revoke(&clock, access_token)
.await?;
}
@@ -430,7 +430,7 @@ async fn refresh_token_grant(
.with_refresh_token(new_refresh_token.refresh_token)
.with_scope(session.scope);
txn.commit().await?;
repo.save().await?;
Ok(params)
}

View File

@@ -31,7 +31,7 @@ use mas_router::UrlBuilder;
use mas_storage::{
oauth2::OAuth2ClientRepository,
user::{BrowserSessionRepository, UserEmailRepository},
Repository,
DatabaseError, PgRepository, Repository,
};
use oauth2_types::scope;
use serde::Serialize;
@@ -64,7 +64,7 @@ pub enum RouteError {
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("failed to authenticate")]
AuthorizationVerificationError(#[from] AuthorizationVerificationError),
AuthorizationVerificationError(#[from] AuthorizationVerificationError<DatabaseError>),
#[error("no suitable key found for signing")]
InvalidSigningKey,
@@ -102,11 +102,11 @@ pub async fn get(
user_authorization: UserAuthorization,
) -> Result<Response, RouteError> {
let (clock, mut rng) = crate::clock_and_rng();
let mut conn = pool.acquire().await?;
let mut repo = PgRepository::from_pool(&pool).await?;
let session = user_authorization.protected(&mut conn, clock.now()).await?;
let session = user_authorization.protected(&mut repo, clock.now()).await?;
let browser_session = conn
let browser_session = repo
.browser_session()
.lookup(session.user_session_id)
.await?
@@ -115,7 +115,7 @@ pub async fn get(
let user = browser_session.user;
let user_email = if session.scope.contains(&scope::EMAIL) {
conn.user_email().get_primary(&user).await?
repo.user_email().get_primary(&user).await?
} else {
None
};
@@ -127,7 +127,7 @@ pub async fn get(
email: user_email.map(|u| u.email),
};
let client = conn
let client = repo
.oauth2_client()
.lookup(session.client_id)
.await?