1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-01 20:26:56 +03:00

storage: Load with less joins

This is done to simplify some queries, to avoid loading more data than
necessary, and in preparation of a proper cache layer
This commit is contained in:
Quentin Gliech
2023-01-04 18:06:17 +01:00
parent a7883618be
commit e26f75246d
16 changed files with 824 additions and 1209 deletions

View File

@ -23,8 +23,6 @@ use thiserror::Error;
use ulid::Ulid; use ulid::Ulid;
use url::Url; use url::Url;
use crate::User;
static DEVICE_ID_LENGTH: usize = 10; static DEVICE_ID_LENGTH: usize = 10;
#[derive(Debug, Clone, PartialEq, Eq, Serialize)] #[derive(Debug, Clone, PartialEq, Eq, Serialize)]
@ -85,7 +83,7 @@ impl TryFrom<String> for Device {
#[derive(Debug, Clone, PartialEq, Eq, Serialize)] #[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct CompatSession { pub struct CompatSession {
pub id: Ulid, pub id: Ulid,
pub user: User, pub user_id: Ulid,
pub device: Device, pub device: Device,
pub created_at: DateTime<Utc>, pub created_at: DateTime<Utc>,
pub finished_at: Option<DateTime<Utc>>, pub finished_at: Option<DateTime<Utc>>,

View File

@ -16,13 +16,10 @@ use oauth2_types::scope::Scope;
use serde::Serialize; use serde::Serialize;
use ulid::Ulid; use ulid::Ulid;
use super::client::Client;
use crate::users::BrowserSession;
#[derive(Debug, Clone, PartialEq, Eq, Serialize)] #[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct Session { pub struct Session {
pub id: Ulid, pub id: Ulid,
pub browser_session: BrowserSession, pub user_session_id: Ulid,
pub client: Client, pub client_id: Ulid,
pub scope: Scope, pub scope: Scope,
} }

View File

@ -12,9 +12,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use async_graphql::{Description, Object, ID}; use anyhow::Context as _;
use async_graphql::{Context, Description, Object, ID};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use mas_data_model::CompatSsoLoginState; use mas_data_model::CompatSsoLoginState;
use mas_storage::{user::UserRepository, Repository};
use sqlx::PgPool;
use url::Url; use url::Url;
use super::{NodeType, User}; use super::{NodeType, User};
@ -32,8 +35,14 @@ impl CompatSession {
} }
/// The user authorized for this session. /// The user authorized for this session.
async fn user(&self) -> User { async fn user(&self, ctx: &Context<'_>) -> Result<User, async_graphql::Error> {
User(self.0.user.clone()) let mut conn = ctx.data::<PgPool>()?.acquire().await?;
let user = conn
.user()
.lookup(self.0.user_id)
.await?
.context("Could not load user")?;
Ok(User(user))
} }
/// The Matrix Device ID of this session. /// The Matrix Device ID of this session.

View File

@ -14,7 +14,9 @@
use anyhow::Context as _; use anyhow::Context as _;
use async_graphql::{Context, Description, Object, ID}; use async_graphql::{Context, Description, Object, ID};
use mas_storage::{oauth2::client::OAuth2ClientRepository, Repository}; use mas_storage::{
oauth2::client::OAuth2ClientRepository, user::BrowserSessionRepository, Repository,
};
use oauth2_types::scope::Scope; use oauth2_types::scope::Scope;
use sqlx::PgPool; use sqlx::PgPool;
use ulid::Ulid; use ulid::Ulid;
@ -35,8 +37,15 @@ impl OAuth2Session {
} }
/// OAuth 2.0 client used by this session. /// OAuth 2.0 client used by this session.
pub async fn client(&self) -> OAuth2Client { pub async fn client(&self, ctx: &Context<'_>) -> Result<OAuth2Client, async_graphql::Error> {
OAuth2Client(self.0.client.clone()) let mut conn = ctx.data::<PgPool>()?.acquire().await?;
let client = conn
.oauth2_client()
.lookup(self.0.client_id)
.await?
.context("Could not load client")?;
Ok(OAuth2Client(client))
} }
/// Scope granted for this session. /// Scope granted for this session.
@ -45,13 +54,30 @@ impl OAuth2Session {
} }
/// The browser session which started this OAuth 2.0 session. /// The browser session which started this OAuth 2.0 session.
pub async fn browser_session(&self) -> BrowserSession { pub async fn browser_session(
BrowserSession(self.0.browser_session.clone()) &self,
ctx: &Context<'_>,
) -> Result<BrowserSession, async_graphql::Error> {
let mut conn = ctx.data::<PgPool>()?.acquire().await?;
let browser_session = conn
.browser_session()
.lookup(self.0.user_session_id)
.await?
.context("Could not load browser session")?;
Ok(BrowserSession(browser_session))
} }
/// User authorized for this session. /// User authorized for this session.
pub async fn user(&self) -> User { pub async fn user(&self, ctx: &Context<'_>) -> Result<User, async_graphql::Error> {
User(self.0.browser_session.user.clone()) let mut conn = ctx.data::<PgPool>()?.acquire().await?;
let browser_session = conn
.browser_session()
.lookup(self.0.user_session_id)
.await?
.context("Could not load browser session")?;
Ok(User(browser_session.user))
} }
} }

View File

@ -15,7 +15,7 @@
use axum::{extract::State, response::IntoResponse, Json}; use axum::{extract::State, response::IntoResponse, Json};
use chrono::Duration; use chrono::Duration;
use hyper::StatusCode; use hyper::StatusCode;
use mas_data_model::{CompatSession, CompatSsoLoginState, Device, TokenType}; use mas_data_model::{CompatSession, CompatSsoLoginState, Device, TokenType, User};
use mas_storage::{ use mas_storage::{
compat::{ compat::{
add_compat_access_token, add_compat_refresh_token, get_compat_sso_login_by_token, add_compat_access_token, add_compat_refresh_token, get_compat_sso_login_by_token,
@ -197,7 +197,7 @@ pub(crate) async fn post(
) -> Result<impl IntoResponse, RouteError> { ) -> Result<impl IntoResponse, RouteError> {
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
let mut txn = pool.begin().await?; let mut txn = pool.begin().await?;
let session = match input.credentials { let (session, user) = match input.credentials {
Credentials::Password { Credentials::Password {
identifier: Identifier::User { user }, identifier: Identifier::User { user },
password, password,
@ -210,7 +210,7 @@ pub(crate) async fn post(
} }
}; };
let user_id = format!("@{username}:{homeserver}", username = session.user.username); let user_id = format!("@{username}:{homeserver}", username = user.username);
// If the client asked for a refreshable token, make it expire // If the client asked for a refreshable token, make it expire
let expires_in = if input.refresh_token { let expires_in = if input.refresh_token {
@ -262,13 +262,13 @@ async fn token_login(
txn: &mut Transaction<'_, Postgres>, txn: &mut Transaction<'_, Postgres>,
clock: &Clock, clock: &Clock,
token: &str, token: &str,
) -> Result<CompatSession, RouteError> { ) -> Result<(CompatSession, User), RouteError> {
let login = get_compat_sso_login_by_token(&mut *txn, token) let login = get_compat_sso_login_by_token(&mut *txn, token)
.await? .await?
.ok_or(RouteError::InvalidLoginToken)?; .ok_or(RouteError::InvalidLoginToken)?;
let now = clock.now(); let now = clock.now();
match login.state { let user_id = match login.state {
CompatSsoLoginState::Pending => { CompatSsoLoginState::Pending => {
tracing::error!( tracing::error!(
compat_sso_login.id = %login.id, compat_sso_login.id = %login.id,
@ -278,11 +278,14 @@ async fn token_login(
} }
CompatSsoLoginState::Fulfilled { CompatSsoLoginState::Fulfilled {
fulfilled_at: fullfilled_at, fulfilled_at: fullfilled_at,
ref session,
.. ..
} => { } => {
if now > fullfilled_at + Duration::seconds(30) { if now > fullfilled_at + Duration::seconds(30) {
return Err(RouteError::LoginTookTooLong); return Err(RouteError::LoginTookTooLong);
} }
session.user_id
} }
CompatSsoLoginState::Exchanged { exchanged_at, .. } => { CompatSsoLoginState::Exchanged { exchanged_at, .. } => {
if now > exchanged_at + Duration::seconds(30) { if now > exchanged_at + Duration::seconds(30) {
@ -295,12 +298,18 @@ async fn token_login(
return Err(RouteError::InvalidLoginToken); return Err(RouteError::InvalidLoginToken);
} }
} };
let user = txn
.user()
.lookup(user_id)
.await?
.ok_or(RouteError::UserNotFound)?;
let login = mark_compat_sso_login_as_exchanged(&mut *txn, clock, login).await?; let login = mark_compat_sso_login_as_exchanged(&mut *txn, clock, login).await?;
match login.state { match login.state {
CompatSsoLoginState::Exchanged { session, .. } => Ok(session), CompatSsoLoginState::Exchanged { session, .. } => Ok((session, user)),
_ => unreachable!(), _ => unreachable!(),
} }
} }
@ -310,7 +319,7 @@ async fn user_password_login(
txn: &mut Transaction<'_, Postgres>, txn: &mut Transaction<'_, Postgres>,
username: String, username: String,
password: String, password: String,
) -> Result<CompatSession, RouteError> { ) -> Result<(CompatSession, User), RouteError> {
let (clock, mut rng) = crate::clock_and_rng(); let (clock, mut rng) = crate::clock_and_rng();
// Find the user // Find the user
@ -356,7 +365,7 @@ async fn user_password_login(
// Now that the user credentials have been verified, start a new compat session // Now that the user credentials have been verified, start a new compat session
let device = Device::generate(&mut rng); let device = Device::generate(&mut rng);
let session = start_compat_session(&mut *txn, &mut rng, &clock, user, device).await?; let session = start_compat_session(&mut *txn, &mut rng, &clock, &user, device).await?;
Ok(session) Ok((session, user))
} }

View File

@ -182,7 +182,7 @@ pub async fn post(
let device = Device::generate(&mut rng); let device = Device::generate(&mut rng);
let _login = let _login =
fullfill_compat_sso_login(&mut txn, &mut rng, &clock, session.user, login, device).await?; fullfill_compat_sso_login(&mut txn, &mut rng, &clock, &session.user, login, device).await?;
txn.commit().await?; txn.commit().await?;

View File

@ -26,7 +26,8 @@ use mas_storage::{
oauth2::{ oauth2::{
access_token::lookup_active_access_token, refresh_token::lookup_active_refresh_token, access_token::lookup_active_access_token, refresh_token::lookup_active_refresh_token,
}, },
Clock, user::{BrowserSessionRepository, UserRepository},
Clock, Repository,
}; };
use oauth2_types::{ use oauth2_types::{
errors::{ClientError, ClientErrorCode}, errors::{ClientError, ClientErrorCode},
@ -171,16 +172,23 @@ pub(crate) async fn post(
.await? .await?
.ok_or(RouteError::UnknownToken)?; .ok_or(RouteError::UnknownToken)?;
let browser_session = conn
.browser_session()
.lookup(session.user_session_id)
.await?
// XXX: is that the right error to bubble up?
.ok_or(RouteError::UnknownToken)?;
IntrospectionResponse { IntrospectionResponse {
active: true, active: true,
scope: Some(session.scope), scope: Some(session.scope),
client_id: Some(session.client.client_id), client_id: Some(session.client_id.to_string()),
username: Some(session.browser_session.user.username), username: Some(browser_session.user.username),
token_type: Some(OAuthTokenTypeHint::AccessToken), token_type: Some(OAuthTokenTypeHint::AccessToken),
exp: Some(token.expires_at), exp: Some(token.expires_at),
iat: Some(token.created_at), iat: Some(token.created_at),
nbf: Some(token.created_at), nbf: Some(token.created_at),
sub: Some(session.browser_session.user.sub), sub: Some(browser_session.user.sub),
aud: None, aud: None,
iss: None, iss: None,
jti: None, jti: None,
@ -191,16 +199,23 @@ pub(crate) async fn post(
.await? .await?
.ok_or(RouteError::UnknownToken)?; .ok_or(RouteError::UnknownToken)?;
let browser_session = conn
.browser_session()
.lookup(session.user_session_id)
.await?
// XXX: is that the right error to bubble up?
.ok_or(RouteError::UnknownToken)?;
IntrospectionResponse { IntrospectionResponse {
active: true, active: true,
scope: Some(session.scope), scope: Some(session.scope),
client_id: Some(session.client.client_id), client_id: Some(session.client_id.to_string()),
username: Some(session.browser_session.user.username), username: Some(browser_session.user.username),
token_type: Some(OAuthTokenTypeHint::RefreshToken), token_type: Some(OAuthTokenTypeHint::RefreshToken),
exp: None, exp: None,
iat: Some(token.created_at), iat: Some(token.created_at),
nbf: Some(token.created_at), nbf: Some(token.created_at),
sub: Some(session.browser_session.user.sub), sub: Some(browser_session.user.sub),
aud: None, aud: None,
iss: None, iss: None,
jti: None, jti: None,
@ -211,6 +226,13 @@ pub(crate) async fn post(
.await? .await?
.ok_or(RouteError::UnknownToken)?; .ok_or(RouteError::UnknownToken)?;
let user = conn
.user()
.lookup(session.user_id)
.await?
// XXX: is that the right error to bubble up?
.ok_or(RouteError::UnknownToken)?;
let device_scope = session.device.to_scope_token(); let device_scope = session.device.to_scope_token();
let scope = [API_SCOPE, device_scope].into_iter().collect(); let scope = [API_SCOPE, device_scope].into_iter().collect();
@ -218,12 +240,12 @@ pub(crate) async fn post(
active: true, active: true,
scope: Some(scope), scope: Some(scope),
client_id: Some("legacy".into()), client_id: Some("legacy".into()),
username: Some(session.user.username), username: Some(user.username),
token_type: Some(OAuthTokenTypeHint::AccessToken), token_type: Some(OAuthTokenTypeHint::AccessToken),
exp: token.expires_at, exp: token.expires_at,
iat: Some(token.created_at), iat: Some(token.created_at),
nbf: Some(token.created_at), nbf: Some(token.created_at),
sub: Some(session.user.sub), sub: Some(user.sub),
aud: None, aud: None,
iss: None, iss: None,
jti: None, jti: None,
@ -235,6 +257,13 @@ pub(crate) async fn post(
.await? .await?
.ok_or(RouteError::UnknownToken)?; .ok_or(RouteError::UnknownToken)?;
let user = conn
.user()
.lookup(session.user_id)
.await?
// XXX: is that the right error to bubble up?
.ok_or(RouteError::UnknownToken)?;
let device_scope = session.device.to_scope_token(); let device_scope = session.device.to_scope_token();
let scope = [API_SCOPE, device_scope].into_iter().collect(); let scope = [API_SCOPE, device_scope].into_iter().collect();
@ -242,12 +271,12 @@ pub(crate) async fn post(
active: true, active: true,
scope: Some(scope), scope: Some(scope),
client_id: Some("legacy".into()), client_id: Some("legacy".into()),
username: Some(session.user.username), username: Some(user.username),
token_type: Some(OAuthTokenTypeHint::RefreshToken), token_type: Some(OAuthTokenTypeHint::RefreshToken),
exp: None, exp: None,
iat: Some(refresh_token.created_at), iat: Some(refresh_token.created_at),
nbf: Some(refresh_token.created_at), nbf: Some(refresh_token.created_at),
sub: Some(session.user.sub), sub: Some(user.sub),
aud: None, aud: None,
iss: None, iss: None,
jti: None, jti: None,

View File

@ -31,11 +31,15 @@ use mas_jose::{
}; };
use mas_keystore::{Encrypter, Keystore}; use mas_keystore::{Encrypter, Keystore};
use mas_router::UrlBuilder; use mas_router::UrlBuilder;
use mas_storage::oauth2::{ use mas_storage::{
access_token::{add_access_token, revoke_access_token}, oauth2::{
authorization_grant::{exchange_grant, lookup_grant_by_code}, access_token::{add_access_token, revoke_access_token},
end_oauth_session, authorization_grant::{exchange_grant, lookup_grant_by_code},
refresh_token::{add_refresh_token, consume_refresh_token, lookup_active_refresh_token}, end_oauth_session,
refresh_token::{add_refresh_token, consume_refresh_token, lookup_active_refresh_token},
},
user::BrowserSessionRepository,
Repository,
}; };
use oauth2_types::{ use oauth2_types::{
errors::{ClientError, ClientErrorCode}, errors::{ClientError, ClientErrorCode},
@ -102,12 +106,15 @@ pub(crate) enum RouteError {
#[error("no suitable key found for signing")] #[error("no suitable key found for signing")]
InvalidSigningKey, InvalidSigningKey,
#[error("failed to load browser session")]
NoSuchBrowserSession,
} }
impl IntoResponse for RouteError { impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response { fn into_response(self) -> axum::response::Response {
match self { match self {
Self::Internal(_) | Self::InvalidSigningKey => ( Self::Internal(_) | Self::InvalidSigningKey | Self::NoSuchBrowserSession => (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(ClientError::from(ClientErrorCode::ServerError)), Json(ClientError::from(ClientErrorCode::ServerError)),
), ),
@ -253,7 +260,7 @@ async fn authorization_code_grant(
// This should never happen, since we looked up in the database using the code // This should never happen, since we looked up in the database using the code
let code = authz_grant.code.as_ref().ok_or(RouteError::InvalidGrant)?; let code = authz_grant.code.as_ref().ok_or(RouteError::InvalidGrant)?;
if client.client_id != session.client.client_id { if client.id != session.client_id {
return Err(RouteError::UnauthorizedClient); return Err(RouteError::UnauthorizedClient);
} }
@ -267,7 +274,11 @@ async fn authorization_code_grant(
} }
}; };
let browser_session = &session.browser_session; let browser_session = txn
.browser_session()
.lookup(session.user_session_id)
.await?
.ok_or(RouteError::NoSuchBrowserSession)?;
let ttl = Duration::minutes(5); let ttl = Duration::minutes(5);
let access_token_str = TokenType::AccessToken.generate(&mut rng); let access_token_str = TokenType::AccessToken.generate(&mut rng);
@ -357,7 +368,7 @@ async fn refresh_token_grant(
.await? .await?
.ok_or(RouteError::InvalidGrant)?; .ok_or(RouteError::InvalidGrant)?;
if client.client_id != session.client.client_id { if client.id != session.client_id {
// As per https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 // As per https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
return Err(RouteError::InvalidGrant); return Err(RouteError::InvalidGrant);
} }

View File

@ -28,7 +28,11 @@ use mas_jose::{
}; };
use mas_keystore::Keystore; use mas_keystore::Keystore;
use mas_router::UrlBuilder; use mas_router::UrlBuilder;
use mas_storage::{user::UserEmailRepository, Repository}; use mas_storage::{
oauth2::client::OAuth2ClientRepository,
user::{BrowserSessionRepository, UserEmailRepository},
Repository,
};
use oauth2_types::scope; use oauth2_types::scope;
use serde::Serialize; use serde::Serialize;
use serde_with::skip_serializing_none; use serde_with::skip_serializing_none;
@ -64,6 +68,12 @@ pub enum RouteError {
#[error("no suitable key found for signing")] #[error("no suitable key found for signing")]
InvalidSigningKey, InvalidSigningKey,
#[error("failed to load client")]
NoSuchClient,
#[error("failed to load browser session")]
NoSuchBrowserSession,
} }
impl_from_error_for_route!(sqlx::Error); impl_from_error_for_route!(sqlx::Error);
@ -74,7 +84,10 @@ impl_from_error_for_route!(mas_jose::jwt::JwtSignatureError);
impl IntoResponse for RouteError { impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response { fn into_response(self) -> axum::response::Response {
match self { match self {
Self::Internal(_) | Self::InvalidSigningKey => { Self::Internal(_)
| Self::InvalidSigningKey
| Self::NoSuchClient
| Self::NoSuchBrowserSession => {
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response() (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response()
} }
Self::AuthorizationVerificationError(_e) => StatusCode::UNAUTHORIZED.into_response(), Self::AuthorizationVerificationError(_e) => StatusCode::UNAUTHORIZED.into_response(),
@ -93,7 +106,13 @@ pub async fn get(
let session = user_authorization.protected(&mut conn).await?; let session = user_authorization.protected(&mut conn).await?;
let user = session.browser_session.user; let browser_session = conn
.browser_session()
.lookup(session.user_session_id)
.await?
.ok_or(RouteError::NoSuchBrowserSession)?;
let user = browser_session.user;
let user_email = if session.scope.contains(&scope::EMAIL) { let user_email = if session.scope.contains(&scope::EMAIL) {
conn.user_email().get_primary(&user).await? conn.user_email().get_primary(&user).await?
@ -108,7 +127,13 @@ pub async fn get(
email: user_email.map(|u| u.email), email: user_email.map(|u| u.email),
}; };
if let Some(alg) = session.client.userinfo_signed_response_alg { let client = conn
.oauth2_client()
.lookup(session.client_id)
.await?
.ok_or(RouteError::NoSuchClient)?;
if let Some(alg) = client.userinfo_signed_response_alg {
let key = key_store let key = key_store
.signing_key_for_algorithm(&alg) .signing_key_for_algorithm(&alg)
.ok_or(RouteError::InvalidSigningKey)?; .ok_or(RouteError::InvalidSigningKey)?;
@ -119,7 +144,7 @@ pub async fn get(
let user_info = SignedUserInfo { let user_info = SignedUserInfo {
iss: url_builder.oidc_issuer().to_string(), iss: url_builder.oidc_issuer().to_string(),
aud: session.client.client_id, aud: client.client_id,
user_info, user_info,
}; };

File diff suppressed because it is too large Load Diff

View File

@ -39,8 +39,6 @@ struct CompatAccessTokenLookup {
compat_session_finished_at: Option<DateTime<Utc>>, compat_session_finished_at: Option<DateTime<Utc>>,
compat_session_device_id: String, compat_session_device_id: String,
user_id: Uuid, user_id: Uuid,
user_username: String,
user_primary_user_email_id: Option<Uuid>,
} }
#[tracing::instrument(skip_all, err)] #[tracing::instrument(skip_all, err)]
@ -52,24 +50,19 @@ pub async fn lookup_active_compat_access_token(
let res = sqlx::query_as!( let res = sqlx::query_as!(
CompatAccessTokenLookup, CompatAccessTokenLookup,
r#" r#"
SELECT SELECT ct.compat_access_token_id
ct.compat_access_token_id, , ct.access_token AS "compat_access_token"
ct.access_token AS "compat_access_token", , ct.created_at AS "compat_access_token_created_at"
ct.created_at AS "compat_access_token_created_at", , ct.expires_at AS "compat_access_token_expires_at"
ct.expires_at AS "compat_access_token_expires_at", , cs.compat_session_id
cs.compat_session_id, , cs.created_at AS "compat_session_created_at"
cs.created_at AS "compat_session_created_at", , cs.finished_at AS "compat_session_finished_at"
cs.finished_at AS "compat_session_finished_at", , cs.device_id AS "compat_session_device_id"
cs.device_id AS "compat_session_device_id", , cs.user_id AS "user_id!"
u.user_id AS "user_id!",
u.username AS "user_username!",
u.primary_user_email_id AS "user_primary_user_email_id"
FROM compat_access_tokens ct FROM compat_access_tokens ct
INNER JOIN compat_sessions cs INNER JOIN compat_sessions cs
USING (compat_session_id) USING (compat_session_id)
INNER JOIN users u
USING (user_id)
WHERE ct.access_token = $1 WHERE ct.access_token = $1
AND (ct.expires_at < $2 OR ct.expires_at IS NULL) AND (ct.expires_at < $2 OR ct.expires_at IS NULL)
@ -92,14 +85,6 @@ pub async fn lookup_active_compat_access_token(
expires_at: res.compat_access_token_expires_at, expires_at: res.compat_access_token_expires_at,
}; };
let user_id = Ulid::from(res.user_id);
let user = User {
id: user_id,
username: res.user_username,
sub: user_id.to_string(),
primary_user_email_id: res.user_primary_user_email_id.map(Into::into),
};
let id = res.compat_session_id.into(); let id = res.compat_session_id.into();
let device = Device::try_from(res.compat_session_device_id).map_err(|e| { let device = Device::try_from(res.compat_session_device_id).map_err(|e| {
DatabaseInconsistencyError::on("compat_sessions") DatabaseInconsistencyError::on("compat_sessions")
@ -110,7 +95,7 @@ pub async fn lookup_active_compat_access_token(
let session = CompatSession { let session = CompatSession {
id, id,
user, user_id: res.user_id.into(),
device, device,
created_at: res.compat_session_created_at, created_at: res.compat_session_created_at,
finished_at: res.compat_session_finished_at, finished_at: res.compat_session_finished_at,
@ -132,8 +117,6 @@ pub struct CompatRefreshTokenLookup {
compat_session_finished_at: Option<DateTime<Utc>>, compat_session_finished_at: Option<DateTime<Utc>>,
compat_session_device_id: String, compat_session_device_id: String,
user_id: Uuid, user_id: Uuid,
user_username: String,
user_primary_user_email_id: Option<Uuid>,
} }
#[tracing::instrument(skip_all, err)] #[tracing::instrument(skip_all, err)]
@ -145,29 +128,24 @@ pub async fn lookup_active_compat_refresh_token(
let res = sqlx::query_as!( let res = sqlx::query_as!(
CompatRefreshTokenLookup, CompatRefreshTokenLookup,
r#" r#"
SELECT SELECT cr.compat_refresh_token_id
cr.compat_refresh_token_id, , cr.refresh_token AS "compat_refresh_token"
cr.refresh_token AS "compat_refresh_token", , cr.created_at AS "compat_refresh_token_created_at"
cr.created_at AS "compat_refresh_token_created_at", , ct.compat_access_token_id
ct.compat_access_token_id, , ct.access_token AS "compat_access_token"
ct.access_token AS "compat_access_token", , ct.created_at AS "compat_access_token_created_at"
ct.created_at AS "compat_access_token_created_at", , ct.expires_at AS "compat_access_token_expires_at"
ct.expires_at AS "compat_access_token_expires_at", , cs.compat_session_id
cs.compat_session_id, , cs.created_at AS "compat_session_created_at"
cs.created_at AS "compat_session_created_at", , cs.finished_at AS "compat_session_finished_at"
cs.finished_at AS "compat_session_finished_at", , cs.device_id AS "compat_session_device_id"
cs.device_id AS "compat_session_device_id", , cs.user_id
u.user_id,
u.username AS "user_username!",
u.primary_user_email_id AS "user_primary_user_email_id"
FROM compat_refresh_tokens cr FROM compat_refresh_tokens cr
INNER JOIN compat_sessions cs INNER JOIN compat_sessions cs
USING (compat_session_id) USING (compat_session_id)
INNER JOIN compat_access_tokens ct INNER JOIN compat_access_tokens ct
USING (compat_access_token_id) USING (compat_access_token_id)
INNER JOIN users u
USING (user_id)
WHERE cr.refresh_token = $1 WHERE cr.refresh_token = $1
AND cr.consumed_at IS NULL AND cr.consumed_at IS NULL
@ -195,25 +173,17 @@ pub async fn lookup_active_compat_refresh_token(
expires_at: res.compat_access_token_expires_at, expires_at: res.compat_access_token_expires_at,
}; };
let user_id = Ulid::from(res.user_id); let id = res.compat_session_id.into();
let user = User {
id: user_id,
username: res.user_username,
sub: user_id.to_string(),
primary_user_email_id: res.user_primary_user_email_id.map(Into::into),
};
let session_id = res.compat_session_id.into();
let device = Device::try_from(res.compat_session_device_id).map_err(|e| { let device = Device::try_from(res.compat_session_device_id).map_err(|e| {
DatabaseInconsistencyError::on("compat_sessions") DatabaseInconsistencyError::on("compat_sessions")
.column("device_id") .column("device_id")
.row(session_id) .row(id)
.source(e) .source(e)
})?; })?;
let session = CompatSession { let session = CompatSession {
id: session_id, id,
user, user_id: res.user_id.into(),
device, device,
created_at: res.compat_session_created_at, created_at: res.compat_session_created_at,
finished_at: res.compat_session_finished_at, finished_at: res.compat_session_finished_at,
@ -228,7 +198,7 @@ pub async fn lookup_active_compat_refresh_token(
compat_session.id = %session.id, compat_session.id = %session.id,
compat_session.device.id = session.device.as_str(), compat_session.device.id = session.device.as_str(),
compat_access_token.id, compat_access_token.id,
user.id = %session.user.id, user.id = %session.user_id,
), ),
err, err,
)] )]
@ -305,7 +275,7 @@ pub async fn expire_compat_access_token(
compat_session.device.id = session.device.as_str(), compat_session.device.id = session.device.as_str(),
compat_access_token.id = %access_token.id, compat_access_token.id = %access_token.id,
compat_refresh_token.id, compat_refresh_token.id,
user.id = %session.user.id, user.id = %session.user_id,
), ),
err, err,
)] )]
@ -469,8 +439,6 @@ struct CompatSsoLoginLookup {
compat_session_finished_at: Option<DateTime<Utc>>, compat_session_finished_at: Option<DateTime<Utc>>,
compat_session_device_id: Option<String>, compat_session_device_id: Option<String>,
user_id: Option<Uuid>, user_id: Option<Uuid>,
user_username: Option<String>,
user_primary_user_email_id: Option<Uuid>,
} }
impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin { impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
@ -485,33 +453,14 @@ impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
.source(e) .source(e)
})?; })?;
let user = match (
res.user_id,
res.user_username,
res.user_primary_user_email_id,
) {
(Some(id), Some(username), primary_email_id) => {
let id = Ulid::from(id);
Some(User {
id,
username,
sub: id.to_string(),
primary_user_email_id: primary_email_id.map(Into::into),
})
}
(None, None, None) => None,
_ => return Err(DatabaseInconsistencyError::on("compat_sessions").column("user_id")),
};
let session = match ( let session = match (
res.compat_session_id, res.compat_session_id,
res.compat_session_device_id, res.compat_session_device_id,
res.compat_session_created_at, res.compat_session_created_at,
res.compat_session_finished_at, res.compat_session_finished_at,
user, res.user_id,
) { ) {
(Some(id), Some(device_id), Some(created_at), finished_at, Some(user)) => { (Some(id), Some(device_id), Some(created_at), finished_at, Some(user_id)) => {
let id = id.into(); let id = id.into();
let device = Device::try_from(device_id).map_err(|e| { let device = Device::try_from(device_id).map_err(|e| {
DatabaseInconsistencyError::on("compat_sessions") DatabaseInconsistencyError::on("compat_sessions")
@ -521,7 +470,7 @@ impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
})?; })?;
Some(CompatSession { Some(CompatSession {
id, id,
user, user_id: user_id.into(),
device, device,
created_at, created_at,
finished_at, finished_at,
@ -579,25 +528,21 @@ pub async fn get_compat_sso_login_by_id(
let res = sqlx::query_as!( let res = sqlx::query_as!(
CompatSsoLoginLookup, CompatSsoLoginLookup,
r#" r#"
SELECT SELECT cl.compat_sso_login_id
cl.compat_sso_login_id, , cl.login_token AS "compat_sso_login_token"
cl.login_token AS "compat_sso_login_token", , cl.redirect_uri AS "compat_sso_login_redirect_uri"
cl.redirect_uri AS "compat_sso_login_redirect_uri", , cl.created_at AS "compat_sso_login_created_at"
cl.created_at AS "compat_sso_login_created_at", , cl.fulfilled_at AS "compat_sso_login_fulfilled_at"
cl.fulfilled_at AS "compat_sso_login_fulfilled_at", , cl.exchanged_at AS "compat_sso_login_exchanged_at"
cl.exchanged_at AS "compat_sso_login_exchanged_at", , cs.compat_session_id AS "compat_session_id?"
cs.compat_session_id AS "compat_session_id?", , cs.created_at AS "compat_session_created_at?"
cs.created_at AS "compat_session_created_at?", , cs.finished_at AS "compat_session_finished_at?"
cs.finished_at AS "compat_session_finished_at?", , cs.device_id AS "compat_session_device_id?"
cs.device_id AS "compat_session_device_id?", , cs.user_id AS "user_id?"
u.user_id AS "user_id?",
u.username AS "user_username?",
u.primary_user_email_id AS "user_primary_user_email_id?"
FROM compat_sso_logins cl FROM compat_sso_logins cl
LEFT JOIN compat_sessions cs LEFT JOIN compat_sessions cs
USING (compat_session_id) USING (compat_session_id)
LEFT JOIN users u
USING (user_id)
WHERE cl.compat_sso_login_id = $1 WHERE cl.compat_sso_login_id = $1
"#, "#,
Uuid::from(id), Uuid::from(id),
@ -632,25 +577,20 @@ pub async fn get_paginated_user_compat_sso_logins(
// because we already have them // because we already have them
let mut query = QueryBuilder::new( let mut query = QueryBuilder::new(
r#" r#"
SELECT SELECT cl.compat_sso_login_id
cl.compat_sso_login_id, , cl.login_token AS "compat_sso_login_token"
cl.login_token AS "compat_sso_login_token", , cl.redirect_uri AS "compat_sso_login_redirect_uri"
cl.redirect_uri AS "compat_sso_login_redirect_uri", , cl.created_at AS "compat_sso_login_created_at"
cl.created_at AS "compat_sso_login_created_at", , cl.fulfilled_at AS "compat_sso_login_fulfilled_at"
cl.fulfilled_at AS "compat_sso_login_fulfilled_at", , cl.exchanged_at AS "compat_sso_login_exchanged_at"
cl.exchanged_at AS "compat_sso_login_exchanged_at", , cs.compat_session_id AS "compat_session_id"
cs.compat_session_id AS "compat_session_id", , cs.created_at AS "compat_session_created_at"
cs.created_at AS "compat_session_created_at", , cs.finished_at AS "compat_session_finished_at"
cs.finished_at AS "compat_session_finished_at", , cs.device_id AS "compat_session_device_id"
cs.device_id AS "compat_session_device_id", , cs.user_id
u.user_id AS "user_id",
u.username AS "user_username",
u.primary_user_email_id AS "user_primary_user_email_id?"
FROM compat_sso_logins cl FROM compat_sso_logins cl
LEFT JOIN compat_sessions cs LEFT JOIN compat_sessions cs
USING (compat_session_id) USING (compat_session_id)
LEFT JOIN users u
USING (user_id)
"#, "#,
); );
@ -683,25 +623,20 @@ pub async fn get_compat_sso_login_by_token(
let res = sqlx::query_as!( let res = sqlx::query_as!(
CompatSsoLoginLookup, CompatSsoLoginLookup,
r#" r#"
SELECT SELECT cl.compat_sso_login_id
cl.compat_sso_login_id, , cl.login_token AS "compat_sso_login_token"
cl.login_token AS "compat_sso_login_token", , cl.redirect_uri AS "compat_sso_login_redirect_uri"
cl.redirect_uri AS "compat_sso_login_redirect_uri", , cl.created_at AS "compat_sso_login_created_at"
cl.created_at AS "compat_sso_login_created_at", , cl.fulfilled_at AS "compat_sso_login_fulfilled_at"
cl.fulfilled_at AS "compat_sso_login_fulfilled_at", , cl.exchanged_at AS "compat_sso_login_exchanged_at"
cl.exchanged_at AS "compat_sso_login_exchanged_at", , cs.compat_session_id AS "compat_session_id?"
cs.compat_session_id AS "compat_session_id?", , cs.created_at AS "compat_session_created_at?"
cs.created_at AS "compat_session_created_at?", , cs.finished_at AS "compat_session_finished_at?"
cs.finished_at AS "compat_session_finished_at?", , cs.device_id AS "compat_session_device_id?"
cs.device_id AS "compat_session_device_id?", , cs.user_id AS "user_id?"
u.user_id AS "user_id?",
u.username AS "user_username?",
u.primary_user_email_id AS "user_primary_user_email_id?"
FROM compat_sso_logins cl FROM compat_sso_logins cl
LEFT JOIN compat_sessions cs LEFT JOIN compat_sessions cs
USING (compat_session_id) USING (compat_session_id)
LEFT JOIN users u
USING (user_id)
WHERE cl.login_token = $1 WHERE cl.login_token = $1
"#, "#,
token, token,
@ -729,7 +664,7 @@ pub async fn start_compat_session(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send, mut rng: impl Rng + Send,
clock: &Clock, clock: &Clock,
user: User, user: &User,
device: Device, device: Device,
) -> Result<CompatSession, DatabaseError> { ) -> Result<CompatSession, DatabaseError> {
let created_at = clock.now(); let created_at = clock.now();
@ -751,7 +686,7 @@ pub async fn start_compat_session(
Ok(CompatSession { Ok(CompatSession {
id, id,
user, user_id: user.id,
device, device,
created_at, created_at,
finished_at: None, finished_at: None,
@ -773,7 +708,7 @@ pub async fn fullfill_compat_sso_login(
conn: impl Acquire<'_, Database = Postgres> + Send, conn: impl Acquire<'_, Database = Postgres> + Send,
mut rng: impl Rng + Send, mut rng: impl Rng + Send,
clock: &Clock, clock: &Clock,
user: User, user: &User,
mut compat_sso_login: CompatSsoLogin, mut compat_sso_login: CompatSsoLogin,
device: Device, device: Device,
) -> Result<CompatSsoLogin, DatabaseError> { ) -> Result<CompatSsoLogin, DatabaseError> {

View File

@ -13,21 +13,20 @@
// limitations under the License. // limitations under the License.
use chrono::{DateTime, Duration, Utc}; use chrono::{DateTime, Duration, Utc};
use mas_data_model::{AccessToken, Authentication, BrowserSession, Session, User}; use mas_data_model::{AccessToken, Session};
use rand::Rng; use rand::Rng;
use sqlx::{PgConnection, PgExecutor}; use sqlx::{PgConnection, PgExecutor};
use ulid::Ulid; use ulid::Ulid;
use uuid::Uuid; use uuid::Uuid;
use super::client::OAuth2ClientRepository; use crate::{Clock, DatabaseError, DatabaseInconsistencyError};
use crate::{Clock, DatabaseError, DatabaseInconsistencyError, Repository};
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
fields( fields(
%session.id, %session.id,
client.id = %session.client.id, user_session.id = %session.user_session_id,
user.id = %session.browser_session.user.id, client.id = %session.client_id,
access_token.id, access_token.id,
), ),
err, err,
@ -81,12 +80,6 @@ pub struct OAuth2AccessTokenLookup {
oauth2_client_id: Uuid, oauth2_client_id: Uuid,
scope: String, scope: String,
user_session_id: Uuid, user_session_id: Uuid,
user_session_created_at: DateTime<Utc>,
user_id: Uuid,
user_username: String,
user_primary_user_email_id: Option<Uuid>,
user_session_last_authentication_id: Option<Uuid>,
user_session_last_authentication_created_at: Option<DateTime<Utc>>,
} }
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
@ -104,30 +97,15 @@ pub async fn lookup_active_access_token(
, os.oauth2_session_id AS "oauth2_session_id!" , os.oauth2_session_id AS "oauth2_session_id!"
, os.oauth2_client_id AS "oauth2_client_id!" , os.oauth2_client_id AS "oauth2_client_id!"
, os.scope AS "scope!" , os.scope AS "scope!"
, us.user_session_id AS "user_session_id!" , os.user_session_id AS "user_session_id!"
, us.created_at AS "user_session_created_at!"
, u.user_id AS "user_id!"
, u.username AS "user_username!"
, u.primary_user_email_id AS "user_primary_user_email_id"
, usa.user_session_authentication_id AS "user_session_last_authentication_id?"
, usa.created_at AS "user_session_last_authentication_created_at?"
FROM oauth2_access_tokens at FROM oauth2_access_tokens at
INNER JOIN oauth2_sessions os INNER JOIN oauth2_sessions os
USING (oauth2_session_id) USING (oauth2_session_id)
INNER JOIN user_sessions us
USING (user_session_id)
INNER JOIN users u
USING (user_id)
LEFT JOIN user_session_authentications usa
USING (user_session_id)
WHERE at.access_token = $1 WHERE at.access_token = $1
AND at.revoked_at IS NULL AND at.revoked_at IS NULL
AND os.finished_at IS NULL AND os.finished_at IS NULL
ORDER BY usa.created_at DESC
LIMIT 1
"#, "#,
token, token,
) )
@ -144,44 +122,6 @@ pub async fn lookup_active_access_token(
}; };
let session_id = res.oauth2_session_id.into(); let session_id = res.oauth2_session_id.into();
let client = conn
.oauth2_client()
.lookup(res.oauth2_client_id.into())
.await?
.ok_or_else(|| {
DatabaseInconsistencyError::on("oauth2_sessions")
.column("client_id")
.row(session_id)
})?;
let user_id = Ulid::from(res.user_id);
let user = User {
id: user_id,
username: res.user_username,
sub: user_id.to_string(),
primary_user_email_id: res.user_primary_user_email_id.map(Into::into),
};
let last_authentication = match (
res.user_session_last_authentication_id,
res.user_session_last_authentication_created_at,
) {
(None, None) => None,
(Some(id), Some(created_at)) => Some(Authentication {
id: id.into(),
created_at,
}),
_ => return Err(DatabaseInconsistencyError::on("user_session_authentications").into()),
};
let browser_session = BrowserSession {
id: res.user_session_id.into(),
created_at: res.user_session_created_at,
finished_at: None,
user,
last_authentication,
};
let scope = res.scope.parse().map_err(|e| { let scope = res.scope.parse().map_err(|e| {
DatabaseInconsistencyError::on("oauth2_sessions") DatabaseInconsistencyError::on("oauth2_sessions")
.column("scope") .column("scope")
@ -191,8 +131,8 @@ pub async fn lookup_active_access_token(
let session = Session { let session = Session {
id: session_id, id: session_id,
client, client_id: res.oauth2_client_id.into(),
browser_session, user_session_id: res.user_session_id.into(),
scope, scope,
}; };

View File

@ -16,8 +16,8 @@ use std::num::NonZeroU32;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use mas_data_model::{ use mas_data_model::{
Authentication, AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, BrowserSession, AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, BrowserSession, Client, Pkce,
Client, Pkce, Session, User, Session,
}; };
use mas_iana::oauth::PkceCodeChallengeMethod; use mas_iana::oauth::PkceCodeChallengeMethod;
use oauth2_types::{requests::ResponseMode, scope::Scope}; use oauth2_types::{requests::ResponseMode, scope::Scope};
@ -151,12 +151,6 @@ struct GrantLookup {
oauth2_client_id: Uuid, oauth2_client_id: Uuid,
oauth2_session_id: Option<Uuid>, oauth2_session_id: Option<Uuid>,
user_session_id: Option<Uuid>, user_session_id: Option<Uuid>,
user_session_created_at: Option<DateTime<Utc>>,
user_id: Option<Uuid>,
user_username: Option<String>,
user_primary_user_email_id: Option<Uuid>,
user_session_last_authentication_id: Option<Uuid>,
user_session_last_authentication_created_at: Option<DateTime<Utc>>,
} }
impl GrantLookup { impl GrantLookup {
@ -183,65 +177,20 @@ impl GrantLookup {
.row(id) .row(id)
})?; })?;
let last_authentication = match ( let session = match (self.oauth2_session_id, self.user_session_id) {
self.user_session_last_authentication_id, (Some(session_id), Some(user_session_id)) => {
self.user_session_last_authentication_created_at,
) {
(Some(id), Some(created_at)) => Some(Authentication {
id: id.into(),
created_at,
}),
(None, None) => None,
_ => return Err(DatabaseInconsistencyError::on("user_session_authentications").into()),
};
let session = match (
self.oauth2_session_id,
self.user_session_id,
self.user_session_created_at,
self.user_id,
self.user_username,
self.user_primary_user_email_id,
last_authentication,
) {
(
Some(session_id),
Some(user_session_id),
Some(user_session_created_at),
Some(user_id),
Some(user_username),
user_primary_user_email_id,
last_authentication,
) => {
let user_id = Ulid::from(user_id);
let user = User {
id: user_id,
username: user_username,
sub: user_id.to_string(),
primary_user_email_id: user_primary_user_email_id.map(Into::into),
};
let browser_session = BrowserSession {
id: user_session_id.into(),
user,
created_at: user_session_created_at,
finished_at: None,
last_authentication,
};
let client = client.clone();
let scope = scope.clone(); let scope = scope.clone();
let session = Session { let session = Session {
id: session_id.into(), id: session_id.into(),
client, client_id: client.id,
browser_session, user_session_id: user_session_id.into(),
scope, scope,
}; };
Some(session) Some(session)
} }
(None, None, None, None, None, None, None) => None, (None, None) => None,
_ => { _ => {
return Err( return Err(
DatabaseInconsistencyError::on("oauth2_authorization_grants") DatabaseInconsistencyError::on("oauth2_authorization_grants")
@ -394,48 +343,32 @@ pub async fn get_grant_by_id(
let res = sqlx::query_as!( let res = sqlx::query_as!(
GrantLookup, GrantLookup,
r#" r#"
SELECT SELECT og.oauth2_authorization_grant_id
og.oauth2_authorization_grant_id, , og.created_at AS oauth2_authorization_grant_created_at
og.created_at AS oauth2_authorization_grant_created_at, , og.cancelled_at AS oauth2_authorization_grant_cancelled_at
og.cancelled_at AS oauth2_authorization_grant_cancelled_at, , og.fulfilled_at AS oauth2_authorization_grant_fulfilled_at
og.fulfilled_at AS oauth2_authorization_grant_fulfilled_at, , og.exchanged_at AS oauth2_authorization_grant_exchanged_at
og.exchanged_at AS oauth2_authorization_grant_exchanged_at, , og.scope AS oauth2_authorization_grant_scope
og.scope AS oauth2_authorization_grant_scope, , og.state AS oauth2_authorization_grant_state
og.state AS oauth2_authorization_grant_state, , og.redirect_uri AS oauth2_authorization_grant_redirect_uri
og.redirect_uri AS oauth2_authorization_grant_redirect_uri, , og.response_mode AS oauth2_authorization_grant_response_mode
og.response_mode AS oauth2_authorization_grant_response_mode, , og.nonce AS oauth2_authorization_grant_nonce
og.nonce AS oauth2_authorization_grant_nonce, , og.max_age AS oauth2_authorization_grant_max_age
og.max_age AS oauth2_authorization_grant_max_age, , og.oauth2_client_id AS oauth2_client_id
og.oauth2_client_id AS oauth2_client_id, , og.authorization_code AS oauth2_authorization_grant_code
og.authorization_code AS oauth2_authorization_grant_code, , og.response_type_code AS oauth2_authorization_grant_response_type_code
og.response_type_code AS oauth2_authorization_grant_response_type_code, , og.response_type_id_token AS oauth2_authorization_grant_response_type_id_token
og.response_type_id_token AS oauth2_authorization_grant_response_type_id_token, , og.code_challenge AS oauth2_authorization_grant_code_challenge
og.code_challenge AS oauth2_authorization_grant_code_challenge, , og.code_challenge_method AS oauth2_authorization_grant_code_challenge_method
og.code_challenge_method AS oauth2_authorization_grant_code_challenge_method, , og.requires_consent AS oauth2_authorization_grant_requires_consent
og.requires_consent AS oauth2_authorization_grant_requires_consent, , os.oauth2_session_id AS "oauth2_session_id?"
os.oauth2_session_id AS "oauth2_session_id?", , os.user_session_id AS "user_session_id?"
us.user_session_id AS "user_session_id?",
us.created_at AS "user_session_created_at?",
u.user_id AS "user_id?",
u.username AS "user_username?",
u.primary_user_email_id AS "user_primary_user_email_id?",
usa.user_session_authentication_id AS "user_session_last_authentication_id?",
usa.created_at AS "user_session_last_authentication_created_at?"
FROM FROM
oauth2_authorization_grants og oauth2_authorization_grants og
LEFT JOIN oauth2_sessions os LEFT JOIN oauth2_sessions os
USING (oauth2_session_id) USING (oauth2_session_id)
LEFT JOIN user_sessions us
USING (user_session_id)
LEFT JOIN users u
USING (user_id)
LEFT JOIN user_session_authentications usa
USING (user_session_id)
WHERE og.oauth2_authorization_grant_id = $1 WHERE og.oauth2_authorization_grant_id = $1
ORDER BY usa.created_at DESC
LIMIT 1
"#, "#,
Uuid::from(id), Uuid::from(id),
) )
@ -458,48 +391,32 @@ pub async fn lookup_grant_by_code(
let res = sqlx::query_as!( let res = sqlx::query_as!(
GrantLookup, GrantLookup,
r#" r#"
SELECT SELECT og.oauth2_authorization_grant_id
og.oauth2_authorization_grant_id, , og.created_at AS oauth2_authorization_grant_created_at
og.created_at AS oauth2_authorization_grant_created_at, , og.cancelled_at AS oauth2_authorization_grant_cancelled_at
og.cancelled_at AS oauth2_authorization_grant_cancelled_at, , og.fulfilled_at AS oauth2_authorization_grant_fulfilled_at
og.fulfilled_at AS oauth2_authorization_grant_fulfilled_at, , og.exchanged_at AS oauth2_authorization_grant_exchanged_at
og.exchanged_at AS oauth2_authorization_grant_exchanged_at, , og.scope AS oauth2_authorization_grant_scope
og.scope AS oauth2_authorization_grant_scope, , og.state AS oauth2_authorization_grant_state
og.state AS oauth2_authorization_grant_state, , og.redirect_uri AS oauth2_authorization_grant_redirect_uri
og.redirect_uri AS oauth2_authorization_grant_redirect_uri, , og.response_mode AS oauth2_authorization_grant_response_mode
og.response_mode AS oauth2_authorization_grant_response_mode, , og.nonce AS oauth2_authorization_grant_nonce
og.nonce AS oauth2_authorization_grant_nonce, , og.max_age AS oauth2_authorization_grant_max_age
og.max_age AS oauth2_authorization_grant_max_age, , og.oauth2_client_id AS oauth2_client_id
og.oauth2_client_id AS oauth2_client_id, , og.authorization_code AS oauth2_authorization_grant_code
og.authorization_code AS oauth2_authorization_grant_code, , og.response_type_code AS oauth2_authorization_grant_response_type_code
og.response_type_code AS oauth2_authorization_grant_response_type_code, , og.response_type_id_token AS oauth2_authorization_grant_response_type_id_token
og.response_type_id_token AS oauth2_authorization_grant_response_type_id_token, , og.code_challenge AS oauth2_authorization_grant_code_challenge
og.code_challenge AS oauth2_authorization_grant_code_challenge, , og.code_challenge_method AS oauth2_authorization_grant_code_challenge_method
og.code_challenge_method AS oauth2_authorization_grant_code_challenge_method, , og.requires_consent AS oauth2_authorization_grant_requires_consent
og.requires_consent AS oauth2_authorization_grant_requires_consent, , os.oauth2_session_id AS "oauth2_session_id?"
os.oauth2_session_id AS "oauth2_session_id?", , os.user_session_id AS "user_session_id?"
us.user_session_id AS "user_session_id?",
us.created_at AS "user_session_created_at?",
u.user_id AS "user_id?",
u.username AS "user_username?",
u.primary_user_email_id AS "user_primary_user_email_id?",
usa.user_session_authentication_id AS "user_session_last_authentication_id?",
usa.created_at AS "user_session_last_authentication_created_at?"
FROM FROM
oauth2_authorization_grants og oauth2_authorization_grants og
LEFT JOIN oauth2_sessions os LEFT JOIN oauth2_sessions os
USING (oauth2_session_id) USING (oauth2_session_id)
LEFT JOIN user_sessions us
USING (user_session_id)
LEFT JOIN users u
USING (user_id)
LEFT JOIN user_session_authentications usa
USING (user_session_id)
WHERE og.authorization_code = $1 WHERE og.authorization_code = $1
ORDER BY usa.created_at DESC
LIMIT 1
"#, "#,
code, code,
) )
@ -561,8 +478,8 @@ pub async fn derive_session(
Ok(Session { Ok(Session {
id, id,
browser_session, user_session_id: browser_session.id,
client: grant.client.clone(), client_id: grant.client.id,
scope: grant.scope.clone(), scope: grant.scope.clone(),
}) })
} }
@ -573,8 +490,7 @@ pub async fn derive_session(
%grant.id, %grant.id,
client.id = %grant.client.id, client.id = %grant.client.id,
%session.id, %session.id,
user_session.id = %session.browser_session.id, user_session.id = %session.user_session_id,
user.id = %session.browser_session.user.id,
), ),
err, err,
)] )]

View File

@ -12,19 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::collections::{BTreeSet, HashMap}; use mas_data_model::{Session, User};
use mas_data_model::{BrowserSession, Session, User};
use sqlx::{PgConnection, PgExecutor, QueryBuilder}; use sqlx::{PgConnection, PgExecutor, QueryBuilder};
use tracing::{info_span, Instrument}; use tracing::{info_span, Instrument};
use ulid::Ulid; use ulid::Ulid;
use uuid::Uuid; use uuid::Uuid;
use self::client::OAuth2ClientRepository;
use crate::{ use crate::{
pagination::{process_page, QueryBuilderExt}, pagination::{process_page, QueryBuilderExt},
user::BrowserSessionRepository, Clock, DatabaseError, DatabaseInconsistencyError,
Clock, DatabaseError, DatabaseInconsistencyError, Repository,
}; };
pub mod access_token; pub mod access_token;
@ -32,14 +28,14 @@ pub mod authorization_grant;
pub mod client; pub mod client;
pub mod consent; pub mod consent;
pub mod refresh_token; pub mod refresh_token;
pub mod session;
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
fields( fields(
%session.id, %session.id,
user.id = %session.browser_session.user.id, user_session.id = %session.user_session_id,
user_session.id = %session.browser_session.id, client.id = %session.client_id,
client.id = %session.client.id,
), ),
err, err,
)] )]
@ -120,49 +116,10 @@ pub async fn get_paginated_user_oauth_sessions(
let (has_previous_page, has_next_page, page) = process_page(page, first, last)?; let (has_previous_page, has_next_page, page) = process_page(page, first, last)?;
let client_ids: BTreeSet<Ulid> = page
.iter()
.map(|i| Ulid::from(i.oauth2_client_id))
.collect();
let browser_session_ids: BTreeSet<Ulid> =
page.iter().map(|i| Ulid::from(i.user_session_id)).collect();
let clients = conn.oauth2_client().load_batch(client_ids).await?;
// TODO: this can generate N queries instead of batching. This is less than
// ideal
let mut browser_sessions: HashMap<Ulid, BrowserSession> = HashMap::new();
for id in browser_session_ids {
let v = conn.browser_session().lookup(id).await?.ok_or_else(|| {
DatabaseInconsistencyError::on("oauth2_sessions").column("user_session_id")
})?;
browser_sessions.insert(id, v);
}
let page: Result<Vec<_>, DatabaseInconsistencyError> = page let page: Result<Vec<_>, DatabaseInconsistencyError> = page
.into_iter() .into_iter()
.map(|item| { .map(|item| {
let id = Ulid::from(item.oauth2_session_id); let id = Ulid::from(item.oauth2_session_id);
let client = clients
.get(&Ulid::from(item.oauth2_client_id))
.ok_or_else(|| {
DatabaseInconsistencyError::on("oauth2_sessions")
.column("oauth2_client_id")
.row(id)
})?
.clone();
let browser_session = browser_sessions
.get(&Ulid::from(item.user_session_id))
.ok_or_else(|| {
DatabaseInconsistencyError::on("oauth2_sessions")
.column("user_session_id")
.row(id)
})?
.clone();
let scope = item.scope.parse().map_err(|e| { let scope = item.scope.parse().map_err(|e| {
DatabaseInconsistencyError::on("oauth2_sessions") DatabaseInconsistencyError::on("oauth2_sessions")
.column("scope") .column("scope")
@ -172,8 +129,8 @@ pub async fn get_paginated_user_oauth_sessions(
Ok(Session { Ok(Session {
id: Ulid::from(item.oauth2_session_id), id: Ulid::from(item.oauth2_session_id),
client, client_id: item.oauth2_client_id.into(),
browser_session, user_session_id: item.user_session_id.into(),
scope, scope,
}) })
}) })

View File

@ -13,22 +13,20 @@
// limitations under the License. // limitations under the License.
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use mas_data_model::{AccessToken, Authentication, BrowserSession, RefreshToken, Session, User}; use mas_data_model::{AccessToken, RefreshToken, Session};
use rand::Rng; use rand::Rng;
use sqlx::{PgConnection, PgExecutor}; use sqlx::{PgConnection, PgExecutor};
use ulid::Ulid; use ulid::Ulid;
use uuid::Uuid; use uuid::Uuid;
use super::client::OAuth2ClientRepository; use crate::{Clock, DatabaseError, DatabaseInconsistencyError};
use crate::{Clock, DatabaseError, DatabaseInconsistencyError, Repository};
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
fields( fields(
%session.id, %session.id,
user.id = %session.browser_session.user.id, user_session.id = %session.user_session_id,
user_session.id = %session.browser_session.id, client.id = %session.client_id,
client.id = %session.client.id,
refresh_token.id, refresh_token.id,
), ),
err, err,
@ -82,12 +80,6 @@ struct OAuth2RefreshTokenLookup {
oauth2_client_id: Uuid, oauth2_client_id: Uuid,
oauth2_session_scope: String, oauth2_session_scope: String,
user_session_id: Uuid, user_session_id: Uuid,
user_session_created_at: DateTime<Utc>,
user_id: Uuid,
user_username: String,
user_primary_user_email_id: Option<Uuid>,
user_session_last_authentication_id: Option<Uuid>,
user_session_last_authentication_created_at: Option<DateTime<Utc>>,
} }
#[tracing::instrument(skip_all, err)] #[tracing::instrument(skip_all, err)]
@ -99,46 +91,27 @@ pub async fn lookup_active_refresh_token(
let res = sqlx::query_as!( let res = sqlx::query_as!(
OAuth2RefreshTokenLookup, OAuth2RefreshTokenLookup,
r#" r#"
SELECT SELECT rt.oauth2_refresh_token_id
rt.oauth2_refresh_token_id, , rt.refresh_token AS oauth2_refresh_token
rt.refresh_token AS oauth2_refresh_token, , rt.created_at AS oauth2_refresh_token_created_at
rt.created_at AS oauth2_refresh_token_created_at, , at.oauth2_access_token_id AS "oauth2_access_token_id?"
at.oauth2_access_token_id AS "oauth2_access_token_id?", , at.access_token AS "oauth2_access_token?"
at.access_token AS "oauth2_access_token?", , at.created_at AS "oauth2_access_token_created_at?"
at.created_at AS "oauth2_access_token_created_at?", , at.expires_at AS "oauth2_access_token_expires_at?"
at.expires_at AS "oauth2_access_token_expires_at?", , os.oauth2_session_id AS "oauth2_session_id!"
os.oauth2_session_id AS "oauth2_session_id!", , os.oauth2_client_id AS "oauth2_client_id!"
os.oauth2_client_id AS "oauth2_client_id!", , os.scope AS "oauth2_session_scope!"
os.scope AS "oauth2_session_scope!", , os.user_session_id AS "user_session_id!"
us.user_session_id AS "user_session_id!",
us.created_at AS "user_session_created_at!",
u.user_id AS "user_id!",
u.username AS "user_username!",
u.primary_user_email_id AS "user_primary_user_email_id",
usa.user_session_authentication_id AS "user_session_last_authentication_id?",
usa.created_at AS "user_session_last_authentication_created_at?"
FROM oauth2_refresh_tokens rt FROM oauth2_refresh_tokens rt
INNER JOIN oauth2_sessions os INNER JOIN oauth2_sessions os
USING (oauth2_session_id) USING (oauth2_session_id)
LEFT JOIN oauth2_access_tokens at LEFT JOIN oauth2_access_tokens at
USING (oauth2_access_token_id) USING (oauth2_access_token_id)
INNER JOIN user_sessions us
USING (user_session_id)
INNER JOIN users u
USING (user_id)
LEFT JOIN user_session_authentications usa
USING (user_session_id)
LEFT JOIN user_emails ue
ON ue.user_email_id = u.primary_user_email_id
WHERE rt.refresh_token = $1 WHERE rt.refresh_token = $1
AND rt.consumed_at IS NULL AND rt.consumed_at IS NULL
AND rt.revoked_at IS NULL AND rt.revoked_at IS NULL
AND us.finished_at IS NULL
AND os.finished_at IS NULL AND os.finished_at IS NULL
ORDER BY usa.created_at DESC
LIMIT 1
"#, "#,
token, token,
) )
@ -173,44 +146,6 @@ pub async fn lookup_active_refresh_token(
}; };
let session_id = res.oauth2_session_id.into(); let session_id = res.oauth2_session_id.into();
let client = conn
.oauth2_client()
.lookup(res.oauth2_client_id.into())
.await?
.ok_or_else(|| {
DatabaseInconsistencyError::on("oauth2_sessions")
.column("client_id")
.row(session_id)
})?;
let user_id = Ulid::from(res.user_id);
let user = User {
id: user_id,
username: res.user_username,
sub: user_id.to_string(),
primary_user_email_id: res.user_primary_user_email_id.map(Into::into),
};
let last_authentication = match (
res.user_session_last_authentication_id,
res.user_session_last_authentication_created_at,
) {
(None, None) => None,
(Some(id), Some(created_at)) => Some(Authentication {
id: id.into(),
created_at,
}),
_ => return Err(DatabaseInconsistencyError::on("user_session_authentications").into()),
};
let browser_session = BrowserSession {
id: res.user_session_id.into(),
created_at: res.user_session_created_at,
finished_at: None,
user,
last_authentication,
};
let scope = res.oauth2_session_scope.parse().map_err(|e| { let scope = res.oauth2_session_scope.parse().map_err(|e| {
DatabaseInconsistencyError::on("oauth2_sessions") DatabaseInconsistencyError::on("oauth2_sessions")
.column("scope") .column("scope")
@ -220,8 +155,8 @@ pub async fn lookup_active_refresh_token(
let session = Session { let session = Session {
id: session_id, id: session_id,
client, client_id: res.oauth2_client_id.into(),
browser_session, user_session_id: res.user_session_id.into(),
scope, scope,
}; };

View File

@ -0,0 +1,20 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
#[async_trait]
pub trait OAuth2SessionRepository {
type Error;
}