You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-31 09:24:31 +03:00
storage: unify the compat login errors
This commit is contained in:
@ -224,7 +224,7 @@ pub enum ClientAuthorizationError {
|
|||||||
MissingCredentials,
|
MissingCredentials,
|
||||||
InvalidRequest,
|
InvalidRequest,
|
||||||
InvalidAssertion,
|
InvalidAssertion,
|
||||||
InternalError(Box<dyn std::error::Error>),
|
Internal(Box<dyn std::error::Error>),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl IntoResponse for ClientAuthorizationError {
|
impl IntoResponse for ClientAuthorizationError {
|
||||||
@ -289,7 +289,7 @@ where
|
|||||||
return Err(ClientAuthorizationError::BadForm(err))
|
return Err(ClientAuthorizationError::BadForm(err))
|
||||||
}
|
}
|
||||||
// Other errors (body read twice, byte stream broke) return an internal error
|
// Other errors (body read twice, byte stream broke) return an internal error
|
||||||
Err(e) => return Err(ClientAuthorizationError::InternalError(Box::new(e))),
|
Err(e) => return Err(ClientAuthorizationError::Internal(Box::new(e))),
|
||||||
};
|
};
|
||||||
|
|
||||||
// And now, figure out the actual auth method
|
// And now, figure out the actual auth method
|
||||||
|
@ -104,7 +104,7 @@ pub enum UserAuthorizationError {
|
|||||||
InvalidHeader,
|
InvalidHeader,
|
||||||
TokenInFormAndHeader,
|
TokenInFormAndHeader,
|
||||||
BadForm(FailedToDeserializeForm),
|
BadForm(FailedToDeserializeForm),
|
||||||
InternalError(Box<dyn Error>),
|
Internal(Box<dyn Error>),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
@ -119,7 +119,7 @@ pub enum AuthorizationVerificationError {
|
|||||||
MissingForm,
|
MissingForm,
|
||||||
|
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
InternalError(Box<dyn Error>),
|
Internal(Box<dyn Error>),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<AccessTokenLookupError> for AuthorizationVerificationError {
|
impl From<AccessTokenLookupError> for AuthorizationVerificationError {
|
||||||
@ -127,7 +127,7 @@ impl From<AccessTokenLookupError> for AuthorizationVerificationError {
|
|||||||
if e.not_found() {
|
if e.not_found() {
|
||||||
Self::InvalidToken
|
Self::InvalidToken
|
||||||
} else {
|
} else {
|
||||||
Self::InternalError(Box::new(e))
|
Self::Internal(Box::new(e))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -232,9 +232,7 @@ impl IntoResponse for UserAuthorizationError {
|
|||||||
});
|
});
|
||||||
(StatusCode::BAD_REQUEST, headers).into_response()
|
(StatusCode::BAD_REQUEST, headers).into_response()
|
||||||
}
|
}
|
||||||
Self::InternalError(e) => {
|
Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
|
||||||
(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -262,9 +260,7 @@ impl IntoResponse for AuthorizationVerificationError {
|
|||||||
});
|
});
|
||||||
(StatusCode::BAD_REQUEST, headers).into_response()
|
(StatusCode::BAD_REQUEST, headers).into_response()
|
||||||
}
|
}
|
||||||
Self::InternalError(e) => {
|
Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
|
||||||
(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -309,7 +305,7 @@ where
|
|||||||
return Err(UserAuthorizationError::BadForm(err))
|
return Err(UserAuthorizationError::BadForm(err))
|
||||||
}
|
}
|
||||||
// Other errors (body read twice, byte stream broke) return an internal error
|
// Other errors (body read twice, byte stream broke) return an internal error
|
||||||
Err(e) => return Err(UserAuthorizationError::InternalError(Box::new(e))),
|
Err(e) => return Err(UserAuthorizationError::Internal(Box::new(e))),
|
||||||
};
|
};
|
||||||
|
|
||||||
let access_token = match (token_from_header, token_from_form) {
|
let access_token = match (token_from_header, token_from_form) {
|
||||||
|
@ -20,9 +20,8 @@ use mas_storage::{
|
|||||||
compat::{
|
compat::{
|
||||||
add_compat_access_token, add_compat_refresh_token, compat_login,
|
add_compat_access_token, add_compat_refresh_token, compat_login,
|
||||||
get_compat_sso_login_by_token, mark_compat_sso_login_as_exchanged,
|
get_compat_sso_login_by_token, mark_compat_sso_login_as_exchanged,
|
||||||
CompatSsoLoginLookupError,
|
|
||||||
},
|
},
|
||||||
Clock, LookupError,
|
Clock,
|
||||||
};
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_with::{serde_as, skip_serializing_none, DurationMilliSeconds};
|
use serde_with::{serde_as, skip_serializing_none, DurationMilliSeconds};
|
||||||
@ -30,6 +29,7 @@ use sqlx::{PgPool, Postgres, Transaction};
|
|||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
use super::{MatrixError, MatrixHomeserver};
|
use super::{MatrixError, MatrixHomeserver};
|
||||||
|
use crate::impl_from_error_for_route;
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
#[serde(tag = "type")]
|
#[serde(tag = "type")]
|
||||||
@ -145,21 +145,8 @@ pub enum RouteError {
|
|||||||
InvalidLoginToken,
|
InvalidLoginToken,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<sqlx::Error> for RouteError {
|
impl_from_error_for_route!(sqlx::Error);
|
||||||
fn from(e: sqlx::Error) -> Self {
|
impl_from_error_for_route!(mas_storage::DatabaseError);
|
||||||
Self::Internal(Box::new(e))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<CompatSsoLoginLookupError> for RouteError {
|
|
||||||
fn from(e: CompatSsoLoginLookupError) -> Self {
|
|
||||||
if e.not_found() {
|
|
||||||
Self::InvalidLoginToken
|
|
||||||
} else {
|
|
||||||
Self::Internal(Box::new(e))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl IntoResponse for RouteError {
|
impl IntoResponse for RouteError {
|
||||||
fn into_response(self) -> axum::response::Response {
|
fn into_response(self) -> axum::response::Response {
|
||||||
@ -268,7 +255,9 @@ async fn token_login(
|
|||||||
clock: &Clock,
|
clock: &Clock,
|
||||||
token: &str,
|
token: &str,
|
||||||
) -> Result<CompatSession, RouteError> {
|
) -> Result<CompatSession, RouteError> {
|
||||||
let login = get_compat_sso_login_by_token(&mut *txn, token).await?;
|
let login = get_compat_sso_login_by_token(&mut *txn, token)
|
||||||
|
.await?
|
||||||
|
.ok_or(RouteError::InvalidLoginToken)?;
|
||||||
|
|
||||||
let now = clock.now();
|
let now = clock.now();
|
||||||
match login.state {
|
match login.state {
|
||||||
|
@ -15,6 +15,7 @@
|
|||||||
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use anyhow::Context;
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{Form, Path, Query, State},
|
extract::{Form, Path, Query, State},
|
||||||
response::{Html, IntoResponse, Redirect, Response},
|
response::{Html, IntoResponse, Redirect, Response},
|
||||||
@ -92,7 +93,9 @@ pub async fn get(
|
|||||||
return Ok((cookie_jar, destination.go()).into_response());
|
return Ok((cookie_jar, destination.go()).into_response());
|
||||||
}
|
}
|
||||||
|
|
||||||
let login = get_compat_sso_login_by_id(&mut conn, id).await?;
|
let login = get_compat_sso_login_by_id(&mut conn, id)
|
||||||
|
.await?
|
||||||
|
.context("Could not find compat SSO login")?;
|
||||||
|
|
||||||
// Bail out if that login session is more than 30min old
|
// Bail out if that login session is more than 30min old
|
||||||
if clock.now() > login.created_at + Duration::minutes(30) {
|
if clock.now() > login.created_at + Duration::minutes(30) {
|
||||||
@ -158,7 +161,9 @@ pub async fn post(
|
|||||||
return Ok((cookie_jar, destination.go()).into_response());
|
return Ok((cookie_jar, destination.go()).into_response());
|
||||||
}
|
}
|
||||||
|
|
||||||
let login = get_compat_sso_login_by_id(&mut txn, id).await?;
|
let login = get_compat_sso_login_by_id(&mut txn, id)
|
||||||
|
.await?
|
||||||
|
.context("Could not find compat SSO login")?;
|
||||||
|
|
||||||
// Bail out if that login session is more than 30min old
|
// Bail out if that login session is more than 30min old
|
||||||
if clock.now() > login.created_at + Duration::minutes(30) {
|
if clock.now() > login.created_at + Duration::minutes(30) {
|
||||||
|
@ -16,13 +16,9 @@ use axum::{extract::State, response::IntoResponse, Json};
|
|||||||
use chrono::Duration;
|
use chrono::Duration;
|
||||||
use hyper::StatusCode;
|
use hyper::StatusCode;
|
||||||
use mas_data_model::{TokenFormatError, TokenType};
|
use mas_data_model::{TokenFormatError, TokenType};
|
||||||
use mas_storage::{
|
use mas_storage::compat::{
|
||||||
compat::{
|
add_compat_access_token, add_compat_refresh_token, consume_compat_refresh_token,
|
||||||
add_compat_access_token, add_compat_refresh_token, consume_compat_refresh_token,
|
expire_compat_access_token, lookup_active_compat_refresh_token,
|
||||||
expire_compat_access_token, lookup_active_compat_refresh_token,
|
|
||||||
CompatRefreshTokenLookupError,
|
|
||||||
},
|
|
||||||
LookupError,
|
|
||||||
};
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_with::{serde_as, DurationMilliSeconds};
|
use serde_with::{serde_as, DurationMilliSeconds};
|
||||||
@ -30,6 +26,7 @@ use sqlx::PgPool;
|
|||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
use super::MatrixError;
|
use super::MatrixError;
|
||||||
|
use crate::impl_from_error_for_route;
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
pub struct RequestBody {
|
pub struct RequestBody {
|
||||||
@ -66,11 +63,8 @@ impl IntoResponse for RouteError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<sqlx::Error> for RouteError {
|
impl_from_error_for_route!(sqlx::Error);
|
||||||
fn from(e: sqlx::Error) -> Self {
|
impl_from_error_for_route!(mas_storage::DatabaseError);
|
||||||
Self::Internal(Box::new(e))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<TokenFormatError> for RouteError {
|
impl From<TokenFormatError> for RouteError {
|
||||||
fn from(_e: TokenFormatError) -> Self {
|
fn from(_e: TokenFormatError) -> Self {
|
||||||
@ -78,16 +72,6 @@ impl From<TokenFormatError> for RouteError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<CompatRefreshTokenLookupError> for RouteError {
|
|
||||||
fn from(e: CompatRefreshTokenLookupError) -> Self {
|
|
||||||
if e.not_found() {
|
|
||||||
Self::InvalidToken
|
|
||||||
} else {
|
|
||||||
Self::Internal(Box::new(e))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[serde_as]
|
#[serde_as]
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
pub struct ResponseBody {
|
pub struct ResponseBody {
|
||||||
@ -111,7 +95,9 @@ pub(crate) async fn post(
|
|||||||
}
|
}
|
||||||
|
|
||||||
let (refresh_token, access_token, session) =
|
let (refresh_token, access_token, session) =
|
||||||
lookup_active_compat_refresh_token(&mut txn, &input.refresh_token).await?;
|
lookup_active_compat_refresh_token(&mut txn, &input.refresh_token)
|
||||||
|
.await?
|
||||||
|
.ok_or(RouteError::InvalidToken)?;
|
||||||
|
|
||||||
let new_refresh_token_str = TokenType::CompatRefreshToken.generate(&mut rng);
|
let new_refresh_token_str = TokenType::CompatRefreshToken.generate(&mut rng);
|
||||||
let new_access_token_str = TokenType::CompatAccessToken.generate(&mut rng);
|
let new_access_token_str = TokenType::CompatAccessToken.generate(&mut rng);
|
||||||
|
@ -62,7 +62,7 @@ macro_rules! impl_from_error_for_route {
|
|||||||
($error:ty) => {
|
($error:ty) => {
|
||||||
impl From<$error> for self::RouteError {
|
impl From<$error> for self::RouteError {
|
||||||
fn from(e: $error) -> Self {
|
fn from(e: $error) -> Self {
|
||||||
Self::InternalError(Box::new(e))
|
Self::Internal(Box::new(e))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -22,10 +22,7 @@ use mas_data_model::{TokenFormatError, TokenType};
|
|||||||
use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint};
|
use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint};
|
||||||
use mas_keystore::Encrypter;
|
use mas_keystore::Encrypter;
|
||||||
use mas_storage::{
|
use mas_storage::{
|
||||||
compat::{
|
compat::{lookup_active_compat_access_token, lookup_active_compat_refresh_token},
|
||||||
lookup_active_compat_access_token, lookup_active_compat_refresh_token,
|
|
||||||
CompatAccessTokenLookupError, CompatRefreshTokenLookupError,
|
|
||||||
},
|
|
||||||
oauth2::{
|
oauth2::{
|
||||||
access_token::{lookup_active_access_token, AccessTokenLookupError},
|
access_token::{lookup_active_access_token, AccessTokenLookupError},
|
||||||
client::ClientFetchError,
|
client::ClientFetchError,
|
||||||
@ -37,6 +34,8 @@ use oauth2_types::requests::{IntrospectionRequest, IntrospectionResponse};
|
|||||||
use sqlx::PgPool;
|
use sqlx::PgPool;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
|
use crate::impl_from_error_for_route;
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
pub enum RouteError {
|
pub enum RouteError {
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
@ -79,11 +78,8 @@ impl IntoResponse for RouteError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<sqlx::Error> for RouteError {
|
impl_from_error_for_route!(sqlx::Error);
|
||||||
fn from(e: sqlx::Error) -> Self {
|
impl_from_error_for_route!(mas_storage::DatabaseError);
|
||||||
Self::Internal(Box::new(e))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<TokenFormatError> for RouteError {
|
impl From<TokenFormatError> for RouteError {
|
||||||
fn from(_e: TokenFormatError) -> Self {
|
fn from(_e: TokenFormatError) -> Self {
|
||||||
@ -111,16 +107,6 @@ impl From<AccessTokenLookupError> for RouteError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<CompatAccessTokenLookupError> for RouteError {
|
|
||||||
fn from(e: CompatAccessTokenLookupError) -> Self {
|
|
||||||
if e.not_found() {
|
|
||||||
Self::UnknownToken
|
|
||||||
} else {
|
|
||||||
Self::Internal(Box::new(e))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<RefreshTokenLookupError> for RouteError {
|
impl From<RefreshTokenLookupError> for RouteError {
|
||||||
fn from(e: RefreshTokenLookupError) -> Self {
|
fn from(e: RefreshTokenLookupError) -> Self {
|
||||||
if e.not_found() {
|
if e.not_found() {
|
||||||
@ -131,16 +117,6 @@ impl From<RefreshTokenLookupError> for RouteError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<CompatRefreshTokenLookupError> for RouteError {
|
|
||||||
fn from(e: CompatRefreshTokenLookupError) -> Self {
|
|
||||||
if e.not_found() {
|
|
||||||
Self::UnknownToken
|
|
||||||
} else {
|
|
||||||
Self::Internal(Box::new(e))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const INACTIVE: IntrospectionResponse = IntrospectionResponse {
|
const INACTIVE: IntrospectionResponse = IntrospectionResponse {
|
||||||
active: false,
|
active: false,
|
||||||
scope: None,
|
scope: None,
|
||||||
@ -232,8 +208,9 @@ pub(crate) async fn post(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
TokenType::CompatAccessToken => {
|
TokenType::CompatAccessToken => {
|
||||||
let (token, session) =
|
let (token, session) = lookup_active_compat_access_token(&mut conn, &clock, token)
|
||||||
lookup_active_compat_access_token(&mut conn, &clock, token).await?;
|
.await?
|
||||||
|
.ok_or(RouteError::UnknownToken)?;
|
||||||
|
|
||||||
let device_scope = session.device.to_scope_token();
|
let device_scope = session.device.to_scope_token();
|
||||||
let scope = [device_scope].into_iter().collect();
|
let scope = [device_scope].into_iter().collect();
|
||||||
@ -255,7 +232,9 @@ pub(crate) async fn post(
|
|||||||
}
|
}
|
||||||
TokenType::CompatRefreshToken => {
|
TokenType::CompatRefreshToken => {
|
||||||
let (refresh_token, _access_token, session) =
|
let (refresh_token, _access_token, session) =
|
||||||
lookup_active_compat_refresh_token(&mut conn, token).await?;
|
lookup_active_compat_refresh_token(&mut conn, token)
|
||||||
|
.await?
|
||||||
|
.ok_or(RouteError::UnknownToken)?;
|
||||||
|
|
||||||
let device_scope = session.device.to_scope_token();
|
let device_scope = session.device.to_scope_token();
|
||||||
let scope = [device_scope].into_iter().collect();
|
let scope = [device_scope].into_iter().collect();
|
||||||
|
@ -36,7 +36,7 @@ pub(crate) enum RouteError {
|
|||||||
ProviderNotFound,
|
ProviderNotFound,
|
||||||
|
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
InternalError(Box<dyn std::error::Error>),
|
Internal(Box<dyn std::error::Error>),
|
||||||
|
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
Anyhow(#[from] anyhow::Error),
|
Anyhow(#[from] anyhow::Error),
|
||||||
@ -52,9 +52,7 @@ impl IntoResponse for RouteError {
|
|||||||
fn into_response(self) -> axum::response::Response {
|
fn into_response(self) -> axum::response::Response {
|
||||||
match self {
|
match self {
|
||||||
Self::ProviderNotFound => (StatusCode::NOT_FOUND, "Provider not found").into_response(),
|
Self::ProviderNotFound => (StatusCode::NOT_FOUND, "Provider not found").into_response(),
|
||||||
Self::InternalError(e) => {
|
Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
|
||||||
(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response()
|
|
||||||
}
|
|
||||||
Self::Anyhow(e) => {
|
Self::Anyhow(e) => {
|
||||||
(StatusCode::INTERNAL_SERVER_ERROR, format!("{e:?}")).into_response()
|
(StatusCode::INTERNAL_SERVER_ERROR, format!("{e:?}")).into_response()
|
||||||
}
|
}
|
||||||
|
@ -90,7 +90,7 @@ pub(crate) enum RouteError {
|
|||||||
MissingCookie,
|
MissingCookie,
|
||||||
|
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
InternalError(Box<dyn std::error::Error>),
|
Internal(Box<dyn std::error::Error>),
|
||||||
|
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
Anyhow(#[from] anyhow::Error),
|
Anyhow(#[from] anyhow::Error),
|
||||||
@ -111,9 +111,7 @@ impl IntoResponse for RouteError {
|
|||||||
fn into_response(self) -> axum::response::Response {
|
fn into_response(self) -> axum::response::Response {
|
||||||
match self {
|
match self {
|
||||||
Self::SessionNotFound => (StatusCode::NOT_FOUND, "Session not found").into_response(),
|
Self::SessionNotFound => (StatusCode::NOT_FOUND, "Session not found").into_response(),
|
||||||
Self::InternalError(e) => {
|
Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
|
||||||
(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response()
|
|
||||||
}
|
|
||||||
Self::Anyhow(e) => {
|
Self::Anyhow(e) => {
|
||||||
(StatusCode::INTERNAL_SERVER_ERROR, format!("{e:?}")).into_response()
|
(StatusCode::INTERNAL_SERVER_ERROR, format!("{e:?}")).into_response()
|
||||||
}
|
}
|
||||||
|
@ -66,7 +66,7 @@ pub(crate) enum RouteError {
|
|||||||
InvalidFormAction,
|
InvalidFormAction,
|
||||||
|
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
InternalError(Box<dyn std::error::Error>),
|
Internal(Box<dyn std::error::Error>),
|
||||||
|
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
Anyhow(#[from] anyhow::Error),
|
Anyhow(#[from] anyhow::Error),
|
||||||
@ -85,9 +85,7 @@ impl IntoResponse for RouteError {
|
|||||||
fn into_response(self) -> axum::response::Response {
|
fn into_response(self) -> axum::response::Response {
|
||||||
match self {
|
match self {
|
||||||
Self::LinkNotFound => (StatusCode::NOT_FOUND, "Link not found").into_response(),
|
Self::LinkNotFound => (StatusCode::NOT_FOUND, "Link not found").into_response(),
|
||||||
Self::InternalError(e) => {
|
Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
|
||||||
(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response()
|
|
||||||
}
|
|
||||||
Self::Anyhow(e) => {
|
Self::Anyhow(e) => {
|
||||||
(StatusCode::INTERNAL_SERVER_ERROR, format!("{e:?}")).into_response()
|
(StatusCode::INTERNAL_SERVER_ERROR, format!("{e:?}")).into_response()
|
||||||
}
|
}
|
||||||
|
@ -51,7 +51,9 @@ impl OptionalPostAuthAction {
|
|||||||
}
|
}
|
||||||
|
|
||||||
PostAuthAction::ContinueCompatSsoLogin { data } => {
|
PostAuthAction::ContinueCompatSsoLogin { data } => {
|
||||||
let login = get_compat_sso_login_by_id(conn, data).await?;
|
let login = get_compat_sso_login_by_id(conn, data)
|
||||||
|
.await?
|
||||||
|
.context("Failed to load compat SSO login")?;
|
||||||
let login = Box::new(login);
|
let login = Box::new(login);
|
||||||
PostAuthContextInner::ContinueCompatSsoLogin { login }
|
PostAuthContextInner::ContinueCompatSsoLogin { login }
|
||||||
}
|
}
|
||||||
|
@ -1,103 +1,5 @@
|
|||||||
{
|
{
|
||||||
"db": "PostgreSQL",
|
"db": "PostgreSQL",
|
||||||
"0157f14a089d100bdfe245e51082526326b2f84b11da7901ca6c0aaae9e43efd": {
|
|
||||||
"describe": {
|
|
||||||
"columns": [
|
|
||||||
{
|
|
||||||
"name": "compat_access_token_id",
|
|
||||||
"ordinal": 0,
|
|
||||||
"type_info": "Uuid"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "compat_access_token",
|
|
||||||
"ordinal": 1,
|
|
||||||
"type_info": "Text"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "compat_access_token_created_at",
|
|
||||||
"ordinal": 2,
|
|
||||||
"type_info": "Timestamptz"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "compat_access_token_expires_at",
|
|
||||||
"ordinal": 3,
|
|
||||||
"type_info": "Timestamptz"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "compat_session_id",
|
|
||||||
"ordinal": 4,
|
|
||||||
"type_info": "Uuid"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "compat_session_created_at",
|
|
||||||
"ordinal": 5,
|
|
||||||
"type_info": "Timestamptz"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "compat_session_finished_at",
|
|
||||||
"ordinal": 6,
|
|
||||||
"type_info": "Timestamptz"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "compat_session_device_id",
|
|
||||||
"ordinal": 7,
|
|
||||||
"type_info": "Text"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "user_id!",
|
|
||||||
"ordinal": 8,
|
|
||||||
"type_info": "Uuid"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "user_username!",
|
|
||||||
"ordinal": 9,
|
|
||||||
"type_info": "Text"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "user_email_id?",
|
|
||||||
"ordinal": 10,
|
|
||||||
"type_info": "Uuid"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "user_email?",
|
|
||||||
"ordinal": 11,
|
|
||||||
"type_info": "Text"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "user_email_created_at?",
|
|
||||||
"ordinal": 12,
|
|
||||||
"type_info": "Timestamptz"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "user_email_confirmed_at?",
|
|
||||||
"ordinal": 13,
|
|
||||||
"type_info": "Timestamptz"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"nullable": [
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
true,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
true,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
true
|
|
||||||
],
|
|
||||||
"parameters": {
|
|
||||||
"Left": [
|
|
||||||
"Text"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"query": "\n SELECT\n ct.compat_access_token_id,\n ct.access_token AS \"compat_access_token\",\n ct.created_at AS \"compat_access_token_created_at\",\n ct.expires_at AS \"compat_access_token_expires_at\",\n cs.compat_session_id,\n cs.created_at AS \"compat_session_created_at\",\n cs.finished_at AS \"compat_session_finished_at\",\n cs.device_id AS \"compat_session_device_id\",\n u.user_id AS \"user_id!\",\n u.username AS \"user_username!\",\n ue.user_email_id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n\n FROM compat_access_tokens ct\n INNER JOIN compat_sessions cs\n USING (compat_session_id)\n INNER JOIN users u\n USING (user_id)\n LEFT JOIN user_emails ue\n ON ue.user_email_id = u.primary_user_email_id\n\n WHERE ct.access_token = $1 AND cs.finished_at IS NULL\n "
|
|
||||||
},
|
|
||||||
"05b50b7ae0109063c50fe70e83635a31920e44a7fbaa2b4f07552ba2f83a28d7": {
|
"05b50b7ae0109063c50fe70e83635a31920e44a7fbaa2b4f07552ba2f83a28d7": {
|
||||||
"describe": {
|
"describe": {
|
||||||
"columns": [
|
"columns": [
|
||||||
@ -2169,6 +2071,105 @@
|
|||||||
},
|
},
|
||||||
"query": "\n SELECT COUNT(*) as \"count!\"\n FROM user_sessions s\n WHERE s.user_id = $1 AND s.finished_at IS NULL\n "
|
"query": "\n SELECT COUNT(*) as \"count!\"\n FROM user_sessions s\n WHERE s.user_id = $1 AND s.finished_at IS NULL\n "
|
||||||
},
|
},
|
||||||
|
"a0ef64e3de97dc2d24efe235c289557018448957a4776197445eafec8b5fb7a9": {
|
||||||
|
"describe": {
|
||||||
|
"columns": [
|
||||||
|
{
|
||||||
|
"name": "compat_access_token_id",
|
||||||
|
"ordinal": 0,
|
||||||
|
"type_info": "Uuid"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "compat_access_token",
|
||||||
|
"ordinal": 1,
|
||||||
|
"type_info": "Text"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "compat_access_token_created_at",
|
||||||
|
"ordinal": 2,
|
||||||
|
"type_info": "Timestamptz"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "compat_access_token_expires_at",
|
||||||
|
"ordinal": 3,
|
||||||
|
"type_info": "Timestamptz"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "compat_session_id",
|
||||||
|
"ordinal": 4,
|
||||||
|
"type_info": "Uuid"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "compat_session_created_at",
|
||||||
|
"ordinal": 5,
|
||||||
|
"type_info": "Timestamptz"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "compat_session_finished_at",
|
||||||
|
"ordinal": 6,
|
||||||
|
"type_info": "Timestamptz"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "compat_session_device_id",
|
||||||
|
"ordinal": 7,
|
||||||
|
"type_info": "Text"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "user_id!",
|
||||||
|
"ordinal": 8,
|
||||||
|
"type_info": "Uuid"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "user_username!",
|
||||||
|
"ordinal": 9,
|
||||||
|
"type_info": "Text"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "user_email_id?",
|
||||||
|
"ordinal": 10,
|
||||||
|
"type_info": "Uuid"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "user_email?",
|
||||||
|
"ordinal": 11,
|
||||||
|
"type_info": "Text"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "user_email_created_at?",
|
||||||
|
"ordinal": 12,
|
||||||
|
"type_info": "Timestamptz"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "user_email_confirmed_at?",
|
||||||
|
"ordinal": 13,
|
||||||
|
"type_info": "Timestamptz"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"nullable": [
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
true,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
true,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
true
|
||||||
|
],
|
||||||
|
"parameters": {
|
||||||
|
"Left": [
|
||||||
|
"Text",
|
||||||
|
"Timestamptz"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"query": "\n SELECT\n ct.compat_access_token_id,\n ct.access_token AS \"compat_access_token\",\n ct.created_at AS \"compat_access_token_created_at\",\n ct.expires_at AS \"compat_access_token_expires_at\",\n cs.compat_session_id,\n cs.created_at AS \"compat_session_created_at\",\n cs.finished_at AS \"compat_session_finished_at\",\n cs.device_id AS \"compat_session_device_id\",\n u.user_id AS \"user_id!\",\n u.username AS \"user_username!\",\n ue.user_email_id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n\n FROM compat_access_tokens ct\n INNER JOIN compat_sessions cs\n USING (compat_session_id)\n INNER JOIN users u\n USING (user_id)\n LEFT JOIN user_emails ue\n ON ue.user_email_id = u.primary_user_email_id\n\n WHERE ct.access_token = $1\n AND ct.expires_at < $2\n AND cs.finished_at IS NULL \n "
|
||||||
|
},
|
||||||
"a5a7dad633396e087239d5629092e4a305908ffce9c2610db07372f719070546": {
|
"a5a7dad633396e087239d5629092e4a305908ffce9c2610db07372f719070546": {
|
||||||
"describe": {
|
"describe": {
|
||||||
"columns": [],
|
"columns": [],
|
||||||
|
@ -12,7 +12,7 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use anyhow::{bail, Context};
|
use anyhow::Context;
|
||||||
use argon2::{Argon2, PasswordHash};
|
use argon2::{Argon2, PasswordHash};
|
||||||
use chrono::{DateTime, Duration, Utc};
|
use chrono::{DateTime, Duration, Utc};
|
||||||
use mas_data_model::{
|
use mas_data_model::{
|
||||||
@ -21,7 +21,6 @@ use mas_data_model::{
|
|||||||
};
|
};
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
use sqlx::{Acquire, PgExecutor, Postgres, QueryBuilder};
|
use sqlx::{Acquire, PgExecutor, Postgres, QueryBuilder};
|
||||||
use thiserror::Error;
|
|
||||||
use tokio::task;
|
use tokio::task;
|
||||||
use tracing::{info_span, Instrument};
|
use tracing::{info_span, Instrument};
|
||||||
use ulid::Ulid;
|
use ulid::Ulid;
|
||||||
@ -31,7 +30,7 @@ use uuid::Uuid;
|
|||||||
use crate::{
|
use crate::{
|
||||||
pagination::{process_page, QueryBuilderExt},
|
pagination::{process_page, QueryBuilderExt},
|
||||||
user::lookup_user_by_username,
|
user::lookup_user_by_username,
|
||||||
Clock, DatabaseInconsistencyError, LookupError,
|
Clock, DatabaseError, DatabaseInconsistencyError2, LookupResultExt,
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CompatAccessTokenLookup {
|
struct CompatAccessTokenLookup {
|
||||||
@ -51,29 +50,12 @@ struct CompatAccessTokenLookup {
|
|||||||
user_email_confirmed_at: Option<DateTime<Utc>>,
|
user_email_confirmed_at: Option<DateTime<Utc>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
|
||||||
#[error("failed to lookup compat access token")]
|
|
||||||
pub enum CompatAccessTokenLookupError {
|
|
||||||
Expired { when: DateTime<Utc> },
|
|
||||||
Database(#[from] sqlx::Error),
|
|
||||||
Inconsistency(#[from] DatabaseInconsistencyError),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl LookupError for CompatAccessTokenLookupError {
|
|
||||||
fn not_found(&self) -> bool {
|
|
||||||
matches!(
|
|
||||||
self,
|
|
||||||
Self::Database(sqlx::Error::RowNotFound) | Self::Expired { .. }
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tracing::instrument(skip_all, err)]
|
#[tracing::instrument(skip_all, err)]
|
||||||
pub async fn lookup_active_compat_access_token(
|
pub async fn lookup_active_compat_access_token(
|
||||||
executor: impl PgExecutor<'_>,
|
executor: impl PgExecutor<'_>,
|
||||||
clock: &Clock,
|
clock: &Clock,
|
||||||
token: &str,
|
token: &str,
|
||||||
) -> Result<(CompatAccessToken, CompatSession), CompatAccessTokenLookupError> {
|
) -> Result<Option<(CompatAccessToken, CompatSession)>, DatabaseError> {
|
||||||
let res = sqlx::query_as!(
|
let res = sqlx::query_as!(
|
||||||
CompatAccessTokenLookup,
|
CompatAccessTokenLookup,
|
||||||
r#"
|
r#"
|
||||||
@ -101,20 +83,19 @@ pub async fn lookup_active_compat_access_token(
|
|||||||
LEFT JOIN user_emails ue
|
LEFT JOIN user_emails ue
|
||||||
ON ue.user_email_id = u.primary_user_email_id
|
ON ue.user_email_id = u.primary_user_email_id
|
||||||
|
|
||||||
WHERE ct.access_token = $1 AND cs.finished_at IS NULL
|
WHERE ct.access_token = $1
|
||||||
|
AND ct.expires_at < $2
|
||||||
|
AND cs.finished_at IS NULL
|
||||||
"#,
|
"#,
|
||||||
token,
|
token,
|
||||||
|
clock.now(),
|
||||||
)
|
)
|
||||||
.fetch_one(executor)
|
.fetch_one(executor)
|
||||||
.instrument(info_span!("Fetch compat access token"))
|
.instrument(info_span!("Fetch compat access token"))
|
||||||
.await?;
|
.await
|
||||||
|
.to_option()?;
|
||||||
|
|
||||||
// Check for token expiration
|
let Some(res) = res else { return Ok(None) };
|
||||||
if let Some(expires_at) = res.compat_access_token_expires_at {
|
|
||||||
if expires_at < clock.now() {
|
|
||||||
return Err(CompatAccessTokenLookupError::Expired { when: expires_at });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let token = CompatAccessToken {
|
let token = CompatAccessToken {
|
||||||
id: res.compat_access_token_id.into(),
|
id: res.compat_access_token_id.into(),
|
||||||
@ -123,6 +104,7 @@ pub async fn lookup_active_compat_access_token(
|
|||||||
expires_at: res.compat_access_token_expires_at,
|
expires_at: res.compat_access_token_expires_at,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let user_id = Ulid::from(res.user_id);
|
||||||
let primary_email = match (
|
let primary_email = match (
|
||||||
res.user_email_id,
|
res.user_email_id,
|
||||||
res.user_email,
|
res.user_email,
|
||||||
@ -136,28 +118,38 @@ pub async fn lookup_active_compat_access_token(
|
|||||||
confirmed_at,
|
confirmed_at,
|
||||||
}),
|
}),
|
||||||
(None, None, None, None) => None,
|
(None, None, None, None) => None,
|
||||||
_ => return Err(DatabaseInconsistencyError.into()),
|
_ => {
|
||||||
|
return Err(DatabaseInconsistencyError2::on("compat_sessions")
|
||||||
|
.column("user_id")
|
||||||
|
.row(user_id)
|
||||||
|
.into())
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let id = Ulid::from(res.user_id);
|
|
||||||
let user = User {
|
let user = User {
|
||||||
id,
|
id: user_id,
|
||||||
username: res.user_username,
|
username: res.user_username,
|
||||||
sub: id.to_string(),
|
sub: user_id.to_string(),
|
||||||
primary_email,
|
primary_email,
|
||||||
};
|
};
|
||||||
|
|
||||||
let device = Device::try_from(res.compat_session_device_id).unwrap();
|
let id = res.compat_session_id.into();
|
||||||
|
let device = Device::try_from(res.compat_session_device_id).map_err(|e| {
|
||||||
|
DatabaseInconsistencyError2::on("compat_sessions")
|
||||||
|
.column("device_id")
|
||||||
|
.row(id)
|
||||||
|
.source(e)
|
||||||
|
})?;
|
||||||
|
|
||||||
let session = CompatSession {
|
let session = CompatSession {
|
||||||
id: res.compat_session_id.into(),
|
id,
|
||||||
user,
|
user,
|
||||||
device,
|
device,
|
||||||
created_at: res.compat_session_created_at,
|
created_at: res.compat_session_created_at,
|
||||||
finished_at: res.compat_session_finished_at,
|
finished_at: res.compat_session_finished_at,
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok((token, session))
|
Ok(Some((token, session)))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct CompatRefreshTokenLookup {
|
pub struct CompatRefreshTokenLookup {
|
||||||
@ -180,25 +172,12 @@ pub struct CompatRefreshTokenLookup {
|
|||||||
user_email_confirmed_at: Option<DateTime<Utc>>,
|
user_email_confirmed_at: Option<DateTime<Utc>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
|
||||||
#[error("failed to lookup compat refresh token")]
|
|
||||||
pub enum CompatRefreshTokenLookupError {
|
|
||||||
Database(#[from] sqlx::Error),
|
|
||||||
Inconsistency(#[from] DatabaseInconsistencyError),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl LookupError for CompatRefreshTokenLookupError {
|
|
||||||
fn not_found(&self) -> bool {
|
|
||||||
matches!(self, Self::Database(sqlx::Error::RowNotFound))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tracing::instrument(skip_all, err)]
|
#[tracing::instrument(skip_all, err)]
|
||||||
#[allow(clippy::type_complexity)]
|
#[allow(clippy::type_complexity)]
|
||||||
pub async fn lookup_active_compat_refresh_token(
|
pub async fn lookup_active_compat_refresh_token(
|
||||||
executor: impl PgExecutor<'_>,
|
executor: impl PgExecutor<'_>,
|
||||||
token: &str,
|
token: &str,
|
||||||
) -> Result<(CompatRefreshToken, CompatAccessToken, CompatSession), CompatRefreshTokenLookupError> {
|
) -> Result<Option<(CompatRefreshToken, CompatAccessToken, CompatSession)>, DatabaseError> {
|
||||||
let res = sqlx::query_as!(
|
let res = sqlx::query_as!(
|
||||||
CompatRefreshTokenLookup,
|
CompatRefreshTokenLookup,
|
||||||
r#"
|
r#"
|
||||||
@ -239,7 +218,10 @@ pub async fn lookup_active_compat_refresh_token(
|
|||||||
)
|
)
|
||||||
.fetch_one(executor)
|
.fetch_one(executor)
|
||||||
.instrument(info_span!("Fetch compat refresh token"))
|
.instrument(info_span!("Fetch compat refresh token"))
|
||||||
.await?;
|
.await
|
||||||
|
.to_option()?;
|
||||||
|
|
||||||
|
let Some(res) = res else { return Ok(None); };
|
||||||
|
|
||||||
let refresh_token = CompatRefreshToken {
|
let refresh_token = CompatRefreshToken {
|
||||||
id: res.compat_refresh_token_id.into(),
|
id: res.compat_refresh_token_id.into(),
|
||||||
@ -254,6 +236,7 @@ pub async fn lookup_active_compat_refresh_token(
|
|||||||
expires_at: res.compat_access_token_expires_at,
|
expires_at: res.compat_access_token_expires_at,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let user_id = Ulid::from(res.user_id);
|
||||||
let primary_email = match (
|
let primary_email = match (
|
||||||
res.user_email_id,
|
res.user_email_id,
|
||||||
res.user_email,
|
res.user_email,
|
||||||
@ -267,28 +250,38 @@ pub async fn lookup_active_compat_refresh_token(
|
|||||||
confirmed_at,
|
confirmed_at,
|
||||||
}),
|
}),
|
||||||
(None, None, None, None) => None,
|
(None, None, None, None) => None,
|
||||||
_ => return Err(DatabaseInconsistencyError.into()),
|
_ => {
|
||||||
|
return Err(DatabaseInconsistencyError2::on("users")
|
||||||
|
.column("primary_user_email_id")
|
||||||
|
.row(user_id)
|
||||||
|
.into())
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let id = Ulid::from(res.user_id);
|
|
||||||
let user = User {
|
let user = User {
|
||||||
id,
|
id: user_id,
|
||||||
username: res.user_username,
|
username: res.user_username,
|
||||||
sub: id.to_string(),
|
sub: user_id.to_string(),
|
||||||
primary_email,
|
primary_email,
|
||||||
};
|
};
|
||||||
|
|
||||||
let device = Device::try_from(res.compat_session_device_id).unwrap();
|
let session_id = res.compat_session_id.into();
|
||||||
|
let device = Device::try_from(res.compat_session_device_id).map_err(|e| {
|
||||||
|
DatabaseInconsistencyError2::on("compat_sessions")
|
||||||
|
.column("device_id")
|
||||||
|
.row(session_id)
|
||||||
|
.source(e)
|
||||||
|
})?;
|
||||||
|
|
||||||
let session = CompatSession {
|
let session = CompatSession {
|
||||||
id: res.compat_session_id.into(),
|
id: session_id,
|
||||||
user,
|
user,
|
||||||
device,
|
device,
|
||||||
created_at: res.compat_session_created_at,
|
created_at: res.compat_session_created_at,
|
||||||
finished_at: res.compat_session_finished_at,
|
finished_at: res.compat_session_finished_at,
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok((refresh_token, access_token, session))
|
Ok(Some((refresh_token, access_token, session)))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(
|
#[tracing::instrument(
|
||||||
@ -299,7 +292,7 @@ pub async fn lookup_active_compat_refresh_token(
|
|||||||
compat_session.id,
|
compat_session.id,
|
||||||
compat_session.device.id = device.as_str(),
|
compat_session.device.id = device.as_str(),
|
||||||
),
|
),
|
||||||
err(Display),
|
err(Debug),
|
||||||
)]
|
)]
|
||||||
pub async fn compat_login(
|
pub async fn compat_login(
|
||||||
conn: impl Acquire<'_, Database = Postgres> + Send,
|
conn: impl Acquire<'_, Database = Postgres> + Send,
|
||||||
@ -309,6 +302,7 @@ pub async fn compat_login(
|
|||||||
password: &str,
|
password: &str,
|
||||||
device: Device,
|
device: Device,
|
||||||
) -> Result<CompatSession, anyhow::Error> {
|
) -> Result<CompatSession, anyhow::Error> {
|
||||||
|
// TODO: that should be split and not verify the password hash here
|
||||||
let mut txn = conn.begin().await.context("could not start transaction")?;
|
let mut txn = conn.begin().await.context("could not start transaction")?;
|
||||||
|
|
||||||
// First, lookup the user
|
// First, lookup the user
|
||||||
@ -381,7 +375,7 @@ pub async fn compat_login(
|
|||||||
compat_access_token.id,
|
compat_access_token.id,
|
||||||
user.id = %session.user.id,
|
user.id = %session.user.id,
|
||||||
),
|
),
|
||||||
err(Display),
|
err,
|
||||||
)]
|
)]
|
||||||
pub async fn add_compat_access_token(
|
pub async fn add_compat_access_token(
|
||||||
executor: impl PgExecutor<'_>,
|
executor: impl PgExecutor<'_>,
|
||||||
@ -390,7 +384,7 @@ pub async fn add_compat_access_token(
|
|||||||
session: &CompatSession,
|
session: &CompatSession,
|
||||||
token: String,
|
token: String,
|
||||||
expires_after: Option<Duration>,
|
expires_after: Option<Duration>,
|
||||||
) -> Result<CompatAccessToken, anyhow::Error> {
|
) -> Result<CompatAccessToken, sqlx::Error> {
|
||||||
let created_at = clock.now();
|
let created_at = clock.now();
|
||||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||||
tracing::Span::current().record("compat_access_token.id", tracing::field::display(id));
|
tracing::Span::current().record("compat_access_token.id", tracing::field::display(id));
|
||||||
@ -411,8 +405,7 @@ pub async fn add_compat_access_token(
|
|||||||
)
|
)
|
||||||
.execute(executor)
|
.execute(executor)
|
||||||
.instrument(tracing::info_span!("Insert compat access token"))
|
.instrument(tracing::info_span!("Insert compat access token"))
|
||||||
.await
|
.await?;
|
||||||
.context("could not insert compat access token")?;
|
|
||||||
|
|
||||||
Ok(CompatAccessToken {
|
Ok(CompatAccessToken {
|
||||||
id,
|
id,
|
||||||
@ -427,13 +420,13 @@ pub async fn add_compat_access_token(
|
|||||||
fields(
|
fields(
|
||||||
compat_access_token.id = %access_token.id,
|
compat_access_token.id = %access_token.id,
|
||||||
),
|
),
|
||||||
err(Display),
|
err,
|
||||||
)]
|
)]
|
||||||
pub async fn expire_compat_access_token(
|
pub async fn expire_compat_access_token(
|
||||||
executor: impl PgExecutor<'_>,
|
executor: impl PgExecutor<'_>,
|
||||||
clock: &Clock,
|
clock: &Clock,
|
||||||
access_token: CompatAccessToken,
|
access_token: CompatAccessToken,
|
||||||
) -> Result<(), anyhow::Error> {
|
) -> Result<(), DatabaseError> {
|
||||||
let expires_at = clock.now();
|
let expires_at = clock.now();
|
||||||
let res = sqlx::query!(
|
let res = sqlx::query!(
|
||||||
r#"
|
r#"
|
||||||
@ -445,16 +438,9 @@ pub async fn expire_compat_access_token(
|
|||||||
expires_at,
|
expires_at,
|
||||||
)
|
)
|
||||||
.execute(executor)
|
.execute(executor)
|
||||||
.await
|
.await?;
|
||||||
.context("failed to update compat access token")?;
|
|
||||||
|
|
||||||
if res.rows_affected() == 1 {
|
DatabaseError::ensure_affected_rows(&res, 1)
|
||||||
Ok(())
|
|
||||||
} else {
|
|
||||||
Err(anyhow::anyhow!(
|
|
||||||
"no row were affected when updating access token"
|
|
||||||
))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(
|
#[tracing::instrument(
|
||||||
@ -466,7 +452,7 @@ pub async fn expire_compat_access_token(
|
|||||||
compat_refresh_token.id,
|
compat_refresh_token.id,
|
||||||
user.id = %session.user.id,
|
user.id = %session.user.id,
|
||||||
),
|
),
|
||||||
err(Display),
|
err,
|
||||||
)]
|
)]
|
||||||
pub async fn add_compat_refresh_token(
|
pub async fn add_compat_refresh_token(
|
||||||
executor: impl PgExecutor<'_>,
|
executor: impl PgExecutor<'_>,
|
||||||
@ -475,7 +461,7 @@ pub async fn add_compat_refresh_token(
|
|||||||
session: &CompatSession,
|
session: &CompatSession,
|
||||||
access_token: &CompatAccessToken,
|
access_token: &CompatAccessToken,
|
||||||
token: String,
|
token: String,
|
||||||
) -> Result<CompatRefreshToken, anyhow::Error> {
|
) -> Result<CompatRefreshToken, sqlx::Error> {
|
||||||
let created_at = clock.now();
|
let created_at = clock.now();
|
||||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||||
tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id));
|
tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id));
|
||||||
@ -495,8 +481,7 @@ pub async fn add_compat_refresh_token(
|
|||||||
)
|
)
|
||||||
.execute(executor)
|
.execute(executor)
|
||||||
.instrument(tracing::info_span!("Insert compat refresh token"))
|
.instrument(tracing::info_span!("Insert compat refresh token"))
|
||||||
.await
|
.await?;
|
||||||
.context("could not insert compat refresh token")?;
|
|
||||||
|
|
||||||
Ok(CompatRefreshToken {
|
Ok(CompatRefreshToken {
|
||||||
id,
|
id,
|
||||||
@ -508,13 +493,13 @@ pub async fn add_compat_refresh_token(
|
|||||||
#[tracing::instrument(
|
#[tracing::instrument(
|
||||||
skip_all,
|
skip_all,
|
||||||
fields(compat_session.id),
|
fields(compat_session.id),
|
||||||
err(Display),
|
err,
|
||||||
)]
|
)]
|
||||||
pub async fn compat_logout(
|
pub async fn compat_logout(
|
||||||
executor: impl PgExecutor<'_>,
|
executor: impl PgExecutor<'_>,
|
||||||
clock: &Clock,
|
clock: &Clock,
|
||||||
token: &str,
|
token: &str,
|
||||||
) -> Result<(), anyhow::Error> {
|
) -> Result<(), sqlx::Error> {
|
||||||
let finished_at = clock.now();
|
let finished_at = clock.now();
|
||||||
// TODO: this does not check for token expiration
|
// TODO: this does not check for token expiration
|
||||||
let compat_session_id = sqlx::query_scalar!(
|
let compat_session_id = sqlx::query_scalar!(
|
||||||
@ -531,8 +516,7 @@ pub async fn compat_logout(
|
|||||||
finished_at,
|
finished_at,
|
||||||
)
|
)
|
||||||
.fetch_one(executor)
|
.fetch_one(executor)
|
||||||
.await
|
.await?;
|
||||||
.context("could not update compat access token")?;
|
|
||||||
|
|
||||||
tracing::Span::current().record(
|
tracing::Span::current().record(
|
||||||
"compat_session.id",
|
"compat_session.id",
|
||||||
@ -547,13 +531,13 @@ pub async fn compat_logout(
|
|||||||
fields(
|
fields(
|
||||||
compat_refresh_token.id = %refresh_token.id,
|
compat_refresh_token.id = %refresh_token.id,
|
||||||
),
|
),
|
||||||
err(Display),
|
err,
|
||||||
)]
|
)]
|
||||||
pub async fn consume_compat_refresh_token(
|
pub async fn consume_compat_refresh_token(
|
||||||
executor: impl PgExecutor<'_>,
|
executor: impl PgExecutor<'_>,
|
||||||
clock: &Clock,
|
clock: &Clock,
|
||||||
refresh_token: CompatRefreshToken,
|
refresh_token: CompatRefreshToken,
|
||||||
) -> Result<(), anyhow::Error> {
|
) -> Result<(), DatabaseError> {
|
||||||
let consumed_at = clock.now();
|
let consumed_at = clock.now();
|
||||||
let res = sqlx::query!(
|
let res = sqlx::query!(
|
||||||
r#"
|
r#"
|
||||||
@ -565,16 +549,9 @@ pub async fn consume_compat_refresh_token(
|
|||||||
consumed_at,
|
consumed_at,
|
||||||
)
|
)
|
||||||
.execute(executor)
|
.execute(executor)
|
||||||
.await
|
.await?;
|
||||||
.context("failed to update compat refresh token")?;
|
|
||||||
|
|
||||||
if res.rows_affected() == 1 {
|
DatabaseError::ensure_affected_rows(&res, 1)
|
||||||
Ok(())
|
|
||||||
} else {
|
|
||||||
Err(anyhow::anyhow!(
|
|
||||||
"no row were affected when updating refresh token"
|
|
||||||
))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(
|
#[tracing::instrument(
|
||||||
@ -583,7 +560,7 @@ pub async fn consume_compat_refresh_token(
|
|||||||
compat_sso_login.id,
|
compat_sso_login.id,
|
||||||
compat_sso_login.redirect_uri = %redirect_uri,
|
compat_sso_login.redirect_uri = %redirect_uri,
|
||||||
),
|
),
|
||||||
err(Display),
|
err,
|
||||||
)]
|
)]
|
||||||
pub async fn insert_compat_sso_login(
|
pub async fn insert_compat_sso_login(
|
||||||
executor: impl PgExecutor<'_>,
|
executor: impl PgExecutor<'_>,
|
||||||
@ -591,7 +568,7 @@ pub async fn insert_compat_sso_login(
|
|||||||
clock: &Clock,
|
clock: &Clock,
|
||||||
login_token: String,
|
login_token: String,
|
||||||
redirect_uri: Url,
|
redirect_uri: Url,
|
||||||
) -> Result<CompatSsoLogin, anyhow::Error> {
|
) -> Result<CompatSsoLogin, sqlx::Error> {
|
||||||
let created_at = clock.now();
|
let created_at = clock.now();
|
||||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||||
tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id));
|
tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id));
|
||||||
@ -609,8 +586,7 @@ pub async fn insert_compat_sso_login(
|
|||||||
)
|
)
|
||||||
.execute(executor)
|
.execute(executor)
|
||||||
.instrument(tracing::info_span!("Insert compat SSO login"))
|
.instrument(tracing::info_span!("Insert compat SSO login"))
|
||||||
.await
|
.await?;
|
||||||
.context("could not insert compat SSO login")?;
|
|
||||||
|
|
||||||
Ok(CompatSsoLogin {
|
Ok(CompatSsoLogin {
|
||||||
id,
|
id,
|
||||||
@ -642,11 +618,16 @@ struct CompatSsoLoginLookup {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
|
impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
|
||||||
type Error = DatabaseInconsistencyError;
|
type Error = DatabaseInconsistencyError2;
|
||||||
|
|
||||||
fn try_from(res: CompatSsoLoginLookup) -> Result<Self, Self::Error> {
|
fn try_from(res: CompatSsoLoginLookup) -> Result<Self, Self::Error> {
|
||||||
let redirect_uri = Url::parse(&res.compat_sso_login_redirect_uri)
|
let id = res.compat_sso_login_id.into();
|
||||||
.map_err(|_| DatabaseInconsistencyError)?;
|
let redirect_uri = Url::parse(&res.compat_sso_login_redirect_uri).map_err(|e| {
|
||||||
|
DatabaseInconsistencyError2::on("compat_sso_logins")
|
||||||
|
.column("redirect_uri")
|
||||||
|
.row(id)
|
||||||
|
.source(e)
|
||||||
|
})?;
|
||||||
|
|
||||||
let primary_email = match (
|
let primary_email = match (
|
||||||
res.user_email_id,
|
res.user_email_id,
|
||||||
@ -661,7 +642,9 @@ impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
|
|||||||
confirmed_at,
|
confirmed_at,
|
||||||
}),
|
}),
|
||||||
(None, None, None, None) => None,
|
(None, None, None, None) => None,
|
||||||
_ => return Err(DatabaseInconsistencyError),
|
_ => {
|
||||||
|
return Err(DatabaseInconsistencyError2::on("users").column("primary_user_email_id"))
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let user = match (res.user_id, res.user_username, primary_email) {
|
let user = match (res.user_id, res.user_username, primary_email) {
|
||||||
@ -676,7 +659,7 @@ impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
|
|||||||
}
|
}
|
||||||
|
|
||||||
(None, None, None) => None,
|
(None, None, None) => None,
|
||||||
_ => return Err(DatabaseInconsistencyError),
|
_ => return Err(DatabaseInconsistencyError2::on("compat_sessions").column("user_id")),
|
||||||
};
|
};
|
||||||
|
|
||||||
let session = match (
|
let session = match (
|
||||||
@ -687,9 +670,15 @@ impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
|
|||||||
user,
|
user,
|
||||||
) {
|
) {
|
||||||
(Some(id), Some(device_id), Some(created_at), finished_at, Some(user)) => {
|
(Some(id), Some(device_id), Some(created_at), finished_at, Some(user)) => {
|
||||||
let device = Device::try_from(device_id).map_err(|_| DatabaseInconsistencyError)?;
|
let id = id.into();
|
||||||
|
let device = Device::try_from(device_id).map_err(|e| {
|
||||||
|
DatabaseInconsistencyError2::on("compat_sessions")
|
||||||
|
.column("device")
|
||||||
|
.row(id)
|
||||||
|
.source(e)
|
||||||
|
})?;
|
||||||
Some(CompatSession {
|
Some(CompatSession {
|
||||||
id: id.into(),
|
id,
|
||||||
user,
|
user,
|
||||||
device,
|
device,
|
||||||
created_at,
|
created_at,
|
||||||
@ -697,7 +686,11 @@ impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
(None, None, None, None, None) => None,
|
(None, None, None, None, None) => None,
|
||||||
_ => return Err(DatabaseInconsistencyError),
|
_ => {
|
||||||
|
return Err(DatabaseInconsistencyError2::on("compat_sso_logins")
|
||||||
|
.column("compat_session_id")
|
||||||
|
.row(id))
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let state = match (
|
let state = match (
|
||||||
@ -717,11 +710,11 @@ impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
|
|||||||
session,
|
session,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => return Err(DatabaseInconsistencyError),
|
_ => return Err(DatabaseInconsistencyError2::on("compat_sso_logins").row(id)),
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(CompatSsoLogin {
|
Ok(CompatSsoLogin {
|
||||||
id: res.compat_sso_login_id.into(),
|
id,
|
||||||
login_token: res.compat_sso_login_token,
|
login_token: res.compat_sso_login_token,
|
||||||
redirect_uri,
|
redirect_uri,
|
||||||
created_at: res.compat_sso_login_created_at,
|
created_at: res.compat_sso_login_created_at,
|
||||||
@ -730,19 +723,6 @@ impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
|
||||||
#[error("failed to lookup compat SSO login")]
|
|
||||||
pub enum CompatSsoLoginLookupError {
|
|
||||||
Database(#[from] sqlx::Error),
|
|
||||||
Inconsistency(#[from] DatabaseInconsistencyError),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl LookupError for CompatSsoLoginLookupError {
|
|
||||||
fn not_found(&self) -> bool {
|
|
||||||
matches!(self, Self::Database(sqlx::Error::RowNotFound))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tracing::instrument(
|
#[tracing::instrument(
|
||||||
skip_all,
|
skip_all,
|
||||||
fields(
|
fields(
|
||||||
@ -753,7 +733,7 @@ impl LookupError for CompatSsoLoginLookupError {
|
|||||||
pub async fn get_compat_sso_login_by_id(
|
pub async fn get_compat_sso_login_by_id(
|
||||||
executor: impl PgExecutor<'_>,
|
executor: impl PgExecutor<'_>,
|
||||||
id: Ulid,
|
id: Ulid,
|
||||||
) -> Result<CompatSsoLogin, CompatSsoLoginLookupError> {
|
) -> Result<Option<CompatSsoLogin>, DatabaseError> {
|
||||||
let res = sqlx::query_as!(
|
let res = sqlx::query_as!(
|
||||||
CompatSsoLoginLookup,
|
CompatSsoLoginLookup,
|
||||||
r#"
|
r#"
|
||||||
@ -787,9 +767,12 @@ pub async fn get_compat_sso_login_by_id(
|
|||||||
)
|
)
|
||||||
.fetch_one(executor)
|
.fetch_one(executor)
|
||||||
.instrument(tracing::info_span!("Lookup compat SSO login"))
|
.instrument(tracing::info_span!("Lookup compat SSO login"))
|
||||||
.await?;
|
.await
|
||||||
|
.to_option()?;
|
||||||
|
|
||||||
Ok(res.try_into()?)
|
let Some(res) = res else { return Ok(None) };
|
||||||
|
|
||||||
|
Ok(Some(res.try_into()?))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(
|
#[tracing::instrument(
|
||||||
@ -798,7 +781,7 @@ pub async fn get_compat_sso_login_by_id(
|
|||||||
%user.id,
|
%user.id,
|
||||||
%user.username,
|
%user.username,
|
||||||
),
|
),
|
||||||
err(Display),
|
err,
|
||||||
)]
|
)]
|
||||||
pub async fn get_paginated_user_compat_sso_logins(
|
pub async fn get_paginated_user_compat_sso_logins(
|
||||||
executor: impl PgExecutor<'_>,
|
executor: impl PgExecutor<'_>,
|
||||||
@ -807,7 +790,7 @@ pub async fn get_paginated_user_compat_sso_logins(
|
|||||||
after: Option<Ulid>,
|
after: Option<Ulid>,
|
||||||
first: Option<usize>,
|
first: Option<usize>,
|
||||||
last: Option<usize>,
|
last: Option<usize>,
|
||||||
) -> Result<(bool, bool, Vec<CompatSsoLogin>), anyhow::Error> {
|
) -> Result<(bool, bool, Vec<CompatSsoLogin>), DatabaseError> {
|
||||||
// TODO: this queries too much (like user info) which we probably don't need
|
// TODO: this queries too much (like user info) which we probably don't need
|
||||||
// because we already have them
|
// because we already have them
|
||||||
let mut query = QueryBuilder::new(
|
let mut query = QueryBuilder::new(
|
||||||
@ -864,7 +847,7 @@ pub async fn get_paginated_user_compat_sso_logins(
|
|||||||
pub async fn get_compat_sso_login_by_token(
|
pub async fn get_compat_sso_login_by_token(
|
||||||
executor: impl PgExecutor<'_>,
|
executor: impl PgExecutor<'_>,
|
||||||
token: &str,
|
token: &str,
|
||||||
) -> Result<CompatSsoLogin, CompatSsoLoginLookupError> {
|
) -> Result<Option<CompatSsoLogin>, DatabaseError> {
|
||||||
let res = sqlx::query_as!(
|
let res = sqlx::query_as!(
|
||||||
CompatSsoLoginLookup,
|
CompatSsoLoginLookup,
|
||||||
r#"
|
r#"
|
||||||
@ -898,35 +881,38 @@ pub async fn get_compat_sso_login_by_token(
|
|||||||
)
|
)
|
||||||
.fetch_one(executor)
|
.fetch_one(executor)
|
||||||
.instrument(tracing::info_span!("Lookup compat SSO login"))
|
.instrument(tracing::info_span!("Lookup compat SSO login"))
|
||||||
.await?;
|
.await
|
||||||
|
.to_option()?;
|
||||||
|
|
||||||
Ok(res.try_into()?)
|
let Some(res) = res else { return Ok(None) };
|
||||||
|
|
||||||
|
Ok(Some(res.try_into()?))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(
|
#[tracing::instrument(
|
||||||
skip_all,
|
skip_all,
|
||||||
fields(
|
fields(
|
||||||
%user.id,
|
%user.id,
|
||||||
compat_sso_login.id = %login.id,
|
%compat_sso_login.id,
|
||||||
compat_sso_login.redirect_uri = %login.redirect_uri,
|
%compat_sso_login.redirect_uri,
|
||||||
compat_session.id,
|
compat_session.id,
|
||||||
compat_session.device.id = device.as_str(),
|
compat_session.device.id = device.as_str(),
|
||||||
),
|
),
|
||||||
err(Display),
|
err,
|
||||||
)]
|
)]
|
||||||
pub async fn fullfill_compat_sso_login(
|
pub async fn fullfill_compat_sso_login(
|
||||||
conn: impl Acquire<'_, Database = Postgres> + Send,
|
conn: impl Acquire<'_, Database = Postgres> + Send,
|
||||||
mut rng: impl Rng + Send,
|
mut rng: impl Rng + Send,
|
||||||
clock: &Clock,
|
clock: &Clock,
|
||||||
user: User,
|
user: User,
|
||||||
mut login: CompatSsoLogin,
|
mut compat_sso_login: CompatSsoLogin,
|
||||||
device: Device,
|
device: Device,
|
||||||
) -> Result<CompatSsoLogin, anyhow::Error> {
|
) -> Result<CompatSsoLogin, DatabaseError> {
|
||||||
if !matches!(login.state, CompatSsoLoginState::Pending) {
|
if !matches!(compat_sso_login.state, CompatSsoLoginState::Pending) {
|
||||||
bail!("sso login in wrong state");
|
return Err(DatabaseError::InvalidOperation);
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut txn = conn.begin().await.context("could not start transaction")?;
|
let mut txn = conn.begin().await?;
|
||||||
|
|
||||||
let created_at = clock.now();
|
let created_at = clock.now();
|
||||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||||
@ -944,8 +930,7 @@ pub async fn fullfill_compat_sso_login(
|
|||||||
)
|
)
|
||||||
.execute(&mut txn)
|
.execute(&mut txn)
|
||||||
.instrument(tracing::info_span!("Insert compat session"))
|
.instrument(tracing::info_span!("Insert compat session"))
|
||||||
.await
|
.await?;
|
||||||
.context("could not insert compat session")?;
|
|
||||||
|
|
||||||
let session = CompatSession {
|
let session = CompatSession {
|
||||||
id,
|
id,
|
||||||
@ -965,46 +950,41 @@ pub async fn fullfill_compat_sso_login(
|
|||||||
WHERE
|
WHERE
|
||||||
compat_sso_login_id = $1
|
compat_sso_login_id = $1
|
||||||
"#,
|
"#,
|
||||||
Uuid::from(login.id),
|
Uuid::from(compat_sso_login.id),
|
||||||
Uuid::from(session.id),
|
Uuid::from(session.id),
|
||||||
fulfilled_at,
|
fulfilled_at,
|
||||||
)
|
)
|
||||||
.execute(&mut txn)
|
.execute(&mut txn)
|
||||||
.instrument(tracing::info_span!("Update compat SSO login"))
|
.instrument(tracing::info_span!("Update compat SSO login"))
|
||||||
.await
|
.await?;
|
||||||
.context("could not update compat SSO login")?;
|
|
||||||
|
|
||||||
let state = CompatSsoLoginState::Fulfilled {
|
let state = CompatSsoLoginState::Fulfilled {
|
||||||
fulfilled_at,
|
fulfilled_at,
|
||||||
session,
|
session,
|
||||||
};
|
};
|
||||||
|
|
||||||
login.state = state;
|
compat_sso_login.state = state;
|
||||||
|
|
||||||
txn.commit().await?;
|
txn.commit().await?;
|
||||||
|
|
||||||
Ok(login)
|
Ok(compat_sso_login)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(
|
#[tracing::instrument(
|
||||||
skip_all,
|
skip_all,
|
||||||
fields(
|
fields(
|
||||||
compat_sso_login.id = %login.id,
|
%compat_sso_login.id,
|
||||||
compat_sso_login.redirect_uri = %login.redirect_uri,
|
%compat_sso_login.redirect_uri,
|
||||||
),
|
),
|
||||||
err(Display),
|
err,
|
||||||
)]
|
)]
|
||||||
pub async fn mark_compat_sso_login_as_exchanged(
|
pub async fn mark_compat_sso_login_as_exchanged(
|
||||||
executor: impl PgExecutor<'_>,
|
executor: impl PgExecutor<'_>,
|
||||||
clock: &Clock,
|
clock: &Clock,
|
||||||
mut login: CompatSsoLogin,
|
mut compat_sso_login: CompatSsoLogin,
|
||||||
) -> Result<CompatSsoLogin, anyhow::Error> {
|
) -> Result<CompatSsoLogin, DatabaseError> {
|
||||||
let (fulfilled_at, session) = match login.state {
|
let CompatSsoLoginState::Fulfilled { fulfilled_at, session } = compat_sso_login.state else {
|
||||||
CompatSsoLoginState::Fulfilled {
|
return Err(DatabaseError::InvalidOperation);
|
||||||
fulfilled_at,
|
|
||||||
session,
|
|
||||||
} => (fulfilled_at, session),
|
|
||||||
_ => bail!("sso login in wrong state"),
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let exchanged_at = clock.now();
|
let exchanged_at = clock.now();
|
||||||
@ -1016,19 +996,18 @@ pub async fn mark_compat_sso_login_as_exchanged(
|
|||||||
WHERE
|
WHERE
|
||||||
compat_sso_login_id = $1
|
compat_sso_login_id = $1
|
||||||
"#,
|
"#,
|
||||||
Uuid::from(login.id),
|
Uuid::from(compat_sso_login.id),
|
||||||
exchanged_at,
|
exchanged_at,
|
||||||
)
|
)
|
||||||
.execute(executor)
|
.execute(executor)
|
||||||
.instrument(tracing::info_span!("Update compat SSO login"))
|
.instrument(tracing::info_span!("Update compat SSO login"))
|
||||||
.await
|
.await?;
|
||||||
.context("could not update compat SSO login")?;
|
|
||||||
|
|
||||||
let state = CompatSsoLoginState::Exchanged {
|
let state = CompatSsoLoginState::Exchanged {
|
||||||
fulfilled_at,
|
fulfilled_at,
|
||||||
exchanged_at,
|
exchanged_at,
|
||||||
session,
|
session,
|
||||||
};
|
};
|
||||||
login.state = state;
|
compat_sso_login.state = state;
|
||||||
Ok(login)
|
Ok(compat_sso_login)
|
||||||
}
|
}
|
||||||
|
@ -30,7 +30,7 @@
|
|||||||
|
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
use pagination::InvalidPagination;
|
use pagination::InvalidPagination;
|
||||||
use sqlx::migrate::Migrator;
|
use sqlx::{migrate::Migrator, postgres::PgQueryResult};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use ulid::Ulid;
|
use ulid::Ulid;
|
||||||
|
|
||||||
@ -100,6 +100,30 @@ pub enum DatabaseError {
|
|||||||
|
|
||||||
/// An error which occured while generating the paginated query
|
/// An error which occured while generating the paginated query
|
||||||
Pagination(#[from] InvalidPagination),
|
Pagination(#[from] InvalidPagination),
|
||||||
|
|
||||||
|
/// An error which happened because the requested database operation is
|
||||||
|
/// invalid
|
||||||
|
#[error("Invalid database operation")]
|
||||||
|
InvalidOperation,
|
||||||
|
|
||||||
|
/// An error which happens when an operation affects not enough or too many
|
||||||
|
/// rows
|
||||||
|
#[error("Expected {expected} rows to be affected, but {actual} rows were affected")]
|
||||||
|
RowsAffected { expected: u64, actual: u64 },
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DatabaseError {
|
||||||
|
pub(crate) fn ensure_affected_rows(
|
||||||
|
result: &PgQueryResult,
|
||||||
|
expected: u64,
|
||||||
|
) -> Result<(), DatabaseError> {
|
||||||
|
let actual = result.rows_affected();
|
||||||
|
if actual == expected {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(DatabaseError::RowsAffected { expected, actual })
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
|
Reference in New Issue
Block a user