You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-07 22:41:18 +03:00
storage: unify most oauth2 related errors
This commit is contained in:
@ -17,12 +17,11 @@ use chrono::{DateTime, Duration, Utc};
|
||||
use mas_data_model::{AccessToken, Authentication, BrowserSession, Session, User, UserEmail};
|
||||
use rand::Rng;
|
||||
use sqlx::{PgConnection, PgExecutor};
|
||||
use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::client::{lookup_client, ClientFetchError};
|
||||
use crate::{Clock, DatabaseInconsistencyError, LookupError};
|
||||
use super::client::lookup_client;
|
||||
use crate::{Clock, DatabaseError, DatabaseInconsistencyError2};
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
@ -95,25 +94,11 @@ pub struct OAuth2AccessTokenLookup {
|
||||
user_email_confirmed_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("failed to lookup access token")]
|
||||
pub enum AccessTokenLookupError {
|
||||
Database(#[from] sqlx::Error),
|
||||
ClientFetch(#[from] ClientFetchError),
|
||||
Inconsistency(#[from] DatabaseInconsistencyError),
|
||||
}
|
||||
|
||||
impl LookupError for AccessTokenLookupError {
|
||||
fn not_found(&self) -> bool {
|
||||
matches!(self, Self::Database(sqlx::Error::RowNotFound))
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub async fn lookup_active_access_token(
|
||||
conn: &mut PgConnection,
|
||||
token: &str,
|
||||
) -> Result<(AccessToken, Session), AccessTokenLookupError> {
|
||||
) -> Result<Option<(AccessToken, Session)>, DatabaseError> {
|
||||
let res = sqlx::query_as!(
|
||||
OAuth2AccessTokenLookup,
|
||||
r#"
|
||||
@ -160,17 +145,25 @@ pub async fn lookup_active_access_token(
|
||||
.fetch_one(&mut *conn)
|
||||
.await?;
|
||||
|
||||
let id = Ulid::from(res.oauth2_access_token_id);
|
||||
let access_token_id = Ulid::from(res.oauth2_access_token_id);
|
||||
let access_token = AccessToken {
|
||||
id,
|
||||
jti: id.to_string(),
|
||||
id: access_token_id,
|
||||
jti: access_token_id.to_string(),
|
||||
access_token: res.oauth2_access_token,
|
||||
created_at: res.oauth2_access_token_created_at,
|
||||
expires_at: res.oauth2_access_token_expires_at,
|
||||
};
|
||||
|
||||
let client = lookup_client(&mut *conn, res.oauth2_client_id.into()).await?;
|
||||
let session_id = res.oauth2_session_id.into();
|
||||
let client = lookup_client(&mut *conn, res.oauth2_client_id.into())
|
||||
.await?
|
||||
.ok_or_else(|| {
|
||||
DatabaseInconsistencyError2::on("oauth2_sessions")
|
||||
.column("client_id")
|
||||
.row(session_id)
|
||||
})?;
|
||||
|
||||
let user_id = Ulid::from(res.user_id);
|
||||
let primary_email = match (
|
||||
res.user_email_id,
|
||||
res.user_email,
|
||||
@ -184,14 +177,18 @@ pub async fn lookup_active_access_token(
|
||||
confirmed_at,
|
||||
}),
|
||||
(None, None, None, None) => None,
|
||||
_ => return Err(DatabaseInconsistencyError.into()),
|
||||
_ => {
|
||||
return Err(DatabaseInconsistencyError2::on("users")
|
||||
.column("primary_user_email_id")
|
||||
.row(user_id)
|
||||
.into())
|
||||
}
|
||||
};
|
||||
|
||||
let id = Ulid::from(res.user_id);
|
||||
let user = User {
|
||||
id,
|
||||
id: user_id,
|
||||
username: res.user_username,
|
||||
sub: id.to_string(),
|
||||
sub: user_id.to_string(),
|
||||
primary_email,
|
||||
};
|
||||
|
||||
@ -204,7 +201,7 @@ pub async fn lookup_active_access_token(
|
||||
id: id.into(),
|
||||
created_at,
|
||||
}),
|
||||
_ => return Err(DatabaseInconsistencyError.into()),
|
||||
_ => return Err(DatabaseInconsistencyError2::on("user_session_authentications").into()),
|
||||
};
|
||||
|
||||
let browser_session = BrowserSession {
|
||||
@ -214,28 +211,33 @@ pub async fn lookup_active_access_token(
|
||||
last_authentication,
|
||||
};
|
||||
|
||||
let scope = res.scope.parse().map_err(|_e| DatabaseInconsistencyError)?;
|
||||
let scope = res.scope.parse().map_err(|e| {
|
||||
DatabaseInconsistencyError2::on("oauth2_sessions")
|
||||
.column("scope")
|
||||
.row(session_id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let session = Session {
|
||||
id: res.oauth2_session_id.into(),
|
||||
id: session_id,
|
||||
client,
|
||||
browser_session,
|
||||
scope,
|
||||
};
|
||||
|
||||
Ok((access_token, session))
|
||||
Ok(Some((access_token, session)))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(%access_token.id),
|
||||
err(Debug),
|
||||
err,
|
||||
)]
|
||||
pub async fn revoke_access_token(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
access_token: AccessToken,
|
||||
) -> anyhow::Result<()> {
|
||||
) -> Result<(), DatabaseError> {
|
||||
let revoked_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
@ -247,17 +249,15 @@ pub async fn revoke_access_token(
|
||||
revoked_at,
|
||||
)
|
||||
.execute(executor)
|
||||
.await
|
||||
.context("could not revoke access tokens")?;
|
||||
.await?;
|
||||
|
||||
if res.rows_affected() == 1 {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow::anyhow!("no row were affected when revoking token"))
|
||||
}
|
||||
DatabaseError::ensure_affected_rows(&res, 1)
|
||||
}
|
||||
|
||||
pub async fn cleanup_expired(executor: impl PgExecutor<'_>, clock: &Clock) -> anyhow::Result<u64> {
|
||||
pub async fn cleanup_expired(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
) -> Result<u64, sqlx::Error> {
|
||||
// Cleanup token which expired more than 15 minutes ago
|
||||
let threshold = clock.now() - Duration::minutes(15);
|
||||
let res = sqlx::query!(
|
||||
@ -268,8 +268,7 @@ pub async fn cleanup_expired(executor: impl PgExecutor<'_>, clock: &Clock) -> an
|
||||
threshold,
|
||||
)
|
||||
.execute(executor)
|
||||
.await
|
||||
.context("could not cleanup expired access tokens")?;
|
||||
.await?;
|
||||
|
||||
Ok(res.rows_affected())
|
||||
}
|
||||
|
@ -180,6 +180,7 @@ impl GrantLookup {
|
||||
// TODO: don't unwrap
|
||||
let client = lookup_client(executor, self.oauth2_client_id.into())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
let last_authentication = match (
|
||||
|
@ -23,12 +23,11 @@ use mas_jose::jwk::PublicJsonWebKeySet;
|
||||
use oauth2_types::requests::GrantType;
|
||||
use rand::Rng;
|
||||
use sqlx::{PgConnection, PgExecutor};
|
||||
use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
use url::Url;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{Clock, LookupError};
|
||||
use crate::{Clock, DatabaseError, DatabaseInconsistencyError2, LookupResultExt};
|
||||
|
||||
// XXX: response_types & contacts
|
||||
#[derive(Debug)]
|
||||
@ -54,52 +53,20 @@ pub struct OAuth2ClientLookup {
|
||||
initiate_login_uri: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ClientFetchError {
|
||||
#[error("invalid client ID")]
|
||||
InvalidClientId(#[from] ulid::DecodeError),
|
||||
|
||||
#[error("malformed jwks column")]
|
||||
MalformedJwks(#[source] serde_json::Error),
|
||||
|
||||
#[error("entry has both a jwks and a jwks_uri")]
|
||||
BothJwksAndJwksUri,
|
||||
|
||||
#[error("could not parse URL in field {field:?}")]
|
||||
ParseUrl {
|
||||
field: &'static str,
|
||||
source: url::ParseError,
|
||||
},
|
||||
|
||||
#[error("could not parse field {field:?}")]
|
||||
ParseField {
|
||||
field: &'static str,
|
||||
source: mas_iana::ParseError,
|
||||
},
|
||||
|
||||
#[error(transparent)]
|
||||
Database(#[from] sqlx::Error),
|
||||
}
|
||||
|
||||
impl LookupError for ClientFetchError {
|
||||
fn not_found(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
Self::Database(sqlx::Error::RowNotFound) | Self::InvalidClientId(_)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryInto<Client> for OAuth2ClientLookup {
|
||||
type Error = ClientFetchError;
|
||||
type Error = DatabaseInconsistencyError2;
|
||||
|
||||
#[allow(clippy::too_many_lines)] // TODO: refactor some of the field parsing
|
||||
fn try_into(self) -> Result<Client, Self::Error> {
|
||||
let id = Ulid::from(self.oauth2_client_id);
|
||||
|
||||
let redirect_uris: Result<Vec<Url>, _> =
|
||||
self.redirect_uris.iter().map(|s| s.parse()).collect();
|
||||
let redirect_uris = redirect_uris.map_err(|source| ClientFetchError::ParseUrl {
|
||||
field: "redirect_uris",
|
||||
source,
|
||||
let redirect_uris = redirect_uris.map_err(|e| {
|
||||
DatabaseInconsistencyError2::on("oauth2_clients")
|
||||
.column("redirect_uris")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let response_types = vec![
|
||||
@ -124,107 +91,125 @@ impl TryInto<Client> for OAuth2ClientLookup {
|
||||
grant_types.push(GrantType::RefreshToken);
|
||||
}
|
||||
|
||||
let logo_uri = self
|
||||
.logo_uri
|
||||
.map(|s| s.parse())
|
||||
.transpose()
|
||||
.map_err(|source| ClientFetchError::ParseUrl {
|
||||
field: "logo_uri",
|
||||
source,
|
||||
})?;
|
||||
let logo_uri = self.logo_uri.map(|s| s.parse()).transpose().map_err(|e| {
|
||||
DatabaseInconsistencyError2::on("oauth2_clients")
|
||||
.column("logo_uri")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let client_uri = self
|
||||
.client_uri
|
||||
.map(|s| s.parse())
|
||||
.transpose()
|
||||
.map_err(|source| ClientFetchError::ParseUrl {
|
||||
field: "client_uri",
|
||||
source,
|
||||
.map_err(|e| {
|
||||
DatabaseInconsistencyError2::on("oauth2_clients")
|
||||
.column("client_uri")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let policy_uri = self
|
||||
.policy_uri
|
||||
.map(|s| s.parse())
|
||||
.transpose()
|
||||
.map_err(|source| ClientFetchError::ParseUrl {
|
||||
field: "policy_uri",
|
||||
source,
|
||||
.map_err(|e| {
|
||||
DatabaseInconsistencyError2::on("oauth2_clients")
|
||||
.column("policy_uri")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let tos_uri = self
|
||||
.tos_uri
|
||||
.map(|s| s.parse())
|
||||
.transpose()
|
||||
.map_err(|source| ClientFetchError::ParseUrl {
|
||||
field: "tos_uri",
|
||||
source,
|
||||
})?;
|
||||
let tos_uri = self.tos_uri.map(|s| s.parse()).transpose().map_err(|e| {
|
||||
DatabaseInconsistencyError2::on("oauth2_clients")
|
||||
.column("tos_uri")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let id_token_signed_response_alg = self
|
||||
.id_token_signed_response_alg
|
||||
.map(|s| s.parse())
|
||||
.transpose()
|
||||
.map_err(|source| ClientFetchError::ParseField {
|
||||
field: "id_token_signed_response_alg",
|
||||
source,
|
||||
.map_err(|e| {
|
||||
DatabaseInconsistencyError2::on("oauth2_clients")
|
||||
.column("id_token_signed_response_alg")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let userinfo_signed_response_alg = self
|
||||
.userinfo_signed_response_alg
|
||||
.map(|s| s.parse())
|
||||
.transpose()
|
||||
.map_err(|source| ClientFetchError::ParseField {
|
||||
field: "userinfo_signed_response_alg",
|
||||
source,
|
||||
.map_err(|e| {
|
||||
DatabaseInconsistencyError2::on("oauth2_clients")
|
||||
.column("userinfo_signed_response_alg")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let token_endpoint_auth_method = self
|
||||
.token_endpoint_auth_method
|
||||
.map(|s| s.parse())
|
||||
.transpose()
|
||||
.map_err(|source| ClientFetchError::ParseField {
|
||||
field: "token_endpoint_auth_method",
|
||||
source,
|
||||
.map_err(|e| {
|
||||
DatabaseInconsistencyError2::on("oauth2_clients")
|
||||
.column("token_endpoint_auth_method")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let token_endpoint_auth_signing_alg = self
|
||||
.token_endpoint_auth_signing_alg
|
||||
.map(|s| s.parse())
|
||||
.transpose()
|
||||
.map_err(|source| ClientFetchError::ParseField {
|
||||
field: "token_endpoint_auth_signing_alg",
|
||||
source,
|
||||
.map_err(|e| {
|
||||
DatabaseInconsistencyError2::on("oauth2_clients")
|
||||
.column("token_endpoint_auth_signing_alg")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let initiate_login_uri = self
|
||||
.initiate_login_uri
|
||||
.map(|s| s.parse())
|
||||
.transpose()
|
||||
.map_err(|source| ClientFetchError::ParseUrl {
|
||||
field: "initiate_login_uri",
|
||||
source,
|
||||
.map_err(|e| {
|
||||
DatabaseInconsistencyError2::on("oauth2_clients")
|
||||
.column("initiate_login_uri")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let jwks = match (self.jwks, self.jwks_uri) {
|
||||
(None, None) => None,
|
||||
(Some(jwks), None) => {
|
||||
let jwks = serde_json::from_value(jwks).map_err(ClientFetchError::MalformedJwks)?;
|
||||
let jwks = serde_json::from_value(jwks).map_err(|e| {
|
||||
DatabaseInconsistencyError2::on("oauth2_clients")
|
||||
.column("jwks")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
Some(JwksOrJwksUri::Jwks(jwks))
|
||||
}
|
||||
(None, Some(jwks_uri)) => {
|
||||
let jwks_uri = jwks_uri
|
||||
.parse()
|
||||
.map_err(|source| ClientFetchError::ParseUrl {
|
||||
field: "jwks_uri",
|
||||
source,
|
||||
})?;
|
||||
let jwks_uri = jwks_uri.parse().map_err(|e| {
|
||||
DatabaseInconsistencyError2::on("oauth2_clients")
|
||||
.column("jwks_uri")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
Some(JwksOrJwksUri::JwksUri(jwks_uri))
|
||||
}
|
||||
_ => return Err(ClientFetchError::BothJwksAndJwksUri),
|
||||
_ => {
|
||||
return Err(DatabaseInconsistencyError2::on("oauth2_clients")
|
||||
.column("jwks(_uri)")
|
||||
.row(id))
|
||||
}
|
||||
};
|
||||
|
||||
let id = Ulid::from(self.oauth2_client_id);
|
||||
Ok(Client {
|
||||
id,
|
||||
client_id: id.to_string(),
|
||||
@ -253,7 +238,7 @@ impl TryInto<Client> for OAuth2ClientLookup {
|
||||
pub async fn lookup_clients(
|
||||
executor: impl PgExecutor<'_>,
|
||||
ids: impl IntoIterator<Item = Ulid> + Send,
|
||||
) -> Result<HashMap<Ulid, Client>, ClientFetchError> {
|
||||
) -> Result<HashMap<Ulid, Client>, DatabaseError> {
|
||||
let ids: Vec<Uuid> = ids.into_iter().map(Uuid::from).collect();
|
||||
let res = sqlx::query_as!(
|
||||
OAuth2ClientLookup,
|
||||
@ -289,12 +274,13 @@ pub async fn lookup_clients(
|
||||
.fetch_all(executor)
|
||||
.await?;
|
||||
|
||||
let clients: Result<HashMap<Ulid, Client>, _> = res
|
||||
.into_iter()
|
||||
.map(|r| r.try_into().map(|c: Client| (c.id, c)))
|
||||
.collect();
|
||||
|
||||
clients
|
||||
res.into_iter()
|
||||
.map(|r| {
|
||||
r.try_into()
|
||||
.map(|c: Client| (c.id, c))
|
||||
.map_err(DatabaseError::from)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
@ -305,7 +291,7 @@ pub async fn lookup_clients(
|
||||
pub async fn lookup_client(
|
||||
executor: impl PgExecutor<'_>,
|
||||
id: Ulid,
|
||||
) -> Result<Client, ClientFetchError> {
|
||||
) -> Result<Option<Client>, DatabaseError> {
|
||||
let res = sqlx::query_as!(
|
||||
OAuth2ClientLookup,
|
||||
r#"
|
||||
@ -338,11 +324,12 @@ pub async fn lookup_client(
|
||||
Uuid::from(id),
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await?;
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let client = res.try_into()?;
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(client)
|
||||
Ok(Some(res.try_into()?))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
@ -353,8 +340,8 @@ pub async fn lookup_client(
|
||||
pub async fn lookup_client_by_client_id(
|
||||
executor: impl PgExecutor<'_>,
|
||||
client_id: &str,
|
||||
) -> Result<Client, ClientFetchError> {
|
||||
let id: Ulid = client_id.parse()?;
|
||||
) -> Result<Option<Client>, DatabaseError> {
|
||||
let Ok(id) = client_id.parse() else { return Ok(None) };
|
||||
lookup_client(executor, id).await
|
||||
}
|
||||
|
||||
|
@ -19,12 +19,11 @@ use mas_data_model::{
|
||||
};
|
||||
use rand::Rng;
|
||||
use sqlx::{PgConnection, PgExecutor};
|
||||
use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::client::{lookup_client, ClientFetchError};
|
||||
use crate::{Clock, DatabaseInconsistencyError, LookupError};
|
||||
use super::client::lookup_client;
|
||||
use crate::{Clock, DatabaseError, DatabaseInconsistencyError2};
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
@ -98,26 +97,12 @@ struct OAuth2RefreshTokenLookup {
|
||||
user_email_confirmed_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
#[error("could not lookup refresh token")]
|
||||
pub enum RefreshTokenLookupError {
|
||||
Fetch(#[from] sqlx::Error),
|
||||
ClientFetch(#[from] ClientFetchError),
|
||||
Conversion(#[from] DatabaseInconsistencyError),
|
||||
}
|
||||
|
||||
impl LookupError for RefreshTokenLookupError {
|
||||
fn not_found(&self) -> bool {
|
||||
matches!(self, Self::Fetch(sqlx::Error::RowNotFound))
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, err)]
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub async fn lookup_active_refresh_token(
|
||||
conn: &mut PgConnection,
|
||||
token: &str,
|
||||
) -> Result<(RefreshToken, Session), RefreshTokenLookupError> {
|
||||
) -> Result<Option<(RefreshToken, Session)>, DatabaseError> {
|
||||
let res = sqlx::query_as!(
|
||||
OAuth2RefreshTokenLookup,
|
||||
r#"
|
||||
@ -187,7 +172,7 @@ pub async fn lookup_active_refresh_token(
|
||||
expires_at,
|
||||
})
|
||||
}
|
||||
_ => return Err(DatabaseInconsistencyError.into()),
|
||||
_ => return Err(DatabaseInconsistencyError2::on("oauth2_access_tokens").into()),
|
||||
};
|
||||
|
||||
let refresh_token = RefreshToken {
|
||||
@ -197,8 +182,16 @@ pub async fn lookup_active_refresh_token(
|
||||
access_token,
|
||||
};
|
||||
|
||||
let client = lookup_client(&mut *conn, res.oauth2_client_id.into()).await?;
|
||||
let session_id = res.oauth2_session_id.into();
|
||||
let client = lookup_client(&mut *conn, res.oauth2_client_id.into())
|
||||
.await?
|
||||
.ok_or_else(|| {
|
||||
DatabaseInconsistencyError2::on("oauth2_sessions")
|
||||
.column("client_id")
|
||||
.row(session_id)
|
||||
})?;
|
||||
|
||||
let user_id = Ulid::from(res.user_id);
|
||||
let primary_email = match (
|
||||
res.user_email_id,
|
||||
res.user_email,
|
||||
@ -212,14 +205,18 @@ pub async fn lookup_active_refresh_token(
|
||||
confirmed_at,
|
||||
}),
|
||||
(None, None, None, None) => None,
|
||||
_ => return Err(DatabaseInconsistencyError.into()),
|
||||
_ => {
|
||||
return Err(DatabaseInconsistencyError2::on("users")
|
||||
.column("primary_user_email_id")
|
||||
.row(user_id)
|
||||
.into())
|
||||
}
|
||||
};
|
||||
|
||||
let id = Ulid::from(res.user_id);
|
||||
let user = User {
|
||||
id,
|
||||
id: user_id,
|
||||
username: res.user_username,
|
||||
sub: id.to_string(),
|
||||
sub: user_id.to_string(),
|
||||
primary_email,
|
||||
};
|
||||
|
||||
@ -232,7 +229,7 @@ pub async fn lookup_active_refresh_token(
|
||||
id: id.into(),
|
||||
created_at,
|
||||
}),
|
||||
_ => return Err(DatabaseInconsistencyError.into()),
|
||||
_ => return Err(DatabaseInconsistencyError2::on("user_session_authentications").into()),
|
||||
};
|
||||
|
||||
let browser_session = BrowserSession {
|
||||
@ -242,19 +239,21 @@ pub async fn lookup_active_refresh_token(
|
||||
last_authentication,
|
||||
};
|
||||
|
||||
let scope = res
|
||||
.oauth2_session_scope
|
||||
.parse()
|
||||
.map_err(|_e| DatabaseInconsistencyError)?;
|
||||
let scope = res.oauth2_session_scope.parse().map_err(|e| {
|
||||
DatabaseInconsistencyError2::on("oauth2_sessions")
|
||||
.column("scope")
|
||||
.row(session_id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let session = Session {
|
||||
id: res.oauth2_session_id.into(),
|
||||
id: session_id,
|
||||
client,
|
||||
browser_session,
|
||||
scope,
|
||||
};
|
||||
|
||||
Ok((refresh_token, session))
|
||||
Ok(Some((refresh_token, session)))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
@ -268,7 +267,7 @@ pub async fn consume_refresh_token(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
refresh_token: &RefreshToken,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
) -> Result<(), DatabaseError> {
|
||||
let consumed_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
@ -280,14 +279,7 @@ pub async fn consume_refresh_token(
|
||||
consumed_at,
|
||||
)
|
||||
.execute(executor)
|
||||
.await
|
||||
.context("failed to update oauth2 refresh token")?;
|
||||
.await?;
|
||||
|
||||
if res.rows_affected() == 1 {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow::anyhow!(
|
||||
"no row were affected when updating refresh token"
|
||||
))
|
||||
}
|
||||
DatabaseError::ensure_affected_rows(&res, 1)
|
||||
}
|
||||
|
@ -24,6 +24,5 @@ pub use self::{
|
||||
provider::{add_provider, get_paginated_providers, get_providers, lookup_provider},
|
||||
session::{
|
||||
add_session, complete_session, consume_session, lookup_session, lookup_session_on_link,
|
||||
SessionLookupError,
|
||||
},
|
||||
};
|
||||
|
@ -16,24 +16,12 @@ use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider};
|
||||
use rand::Rng;
|
||||
use sqlx::PgExecutor;
|
||||
use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{Clock, DatabaseInconsistencyError, GenericLookupError, LookupError};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("Failed to lookup upstream OAuth 2.0 authorization session")]
|
||||
pub enum SessionLookupError {
|
||||
Driver(#[from] sqlx::Error),
|
||||
Inconcistency(#[from] DatabaseInconsistencyError),
|
||||
}
|
||||
|
||||
impl LookupError for SessionLookupError {
|
||||
fn not_found(&self) -> bool {
|
||||
matches!(self, Self::Driver(sqlx::Error::RowNotFound))
|
||||
}
|
||||
}
|
||||
use crate::{
|
||||
Clock, DatabaseError, DatabaseInconsistencyError2, GenericLookupError, LookupResultExt,
|
||||
};
|
||||
|
||||
struct SessionAndProviderLookup {
|
||||
upstream_oauth_authorization_session_id: Uuid,
|
||||
@ -64,7 +52,7 @@ struct SessionAndProviderLookup {
|
||||
pub async fn lookup_session(
|
||||
executor: impl PgExecutor<'_>,
|
||||
id: Ulid,
|
||||
) -> Result<(UpstreamOAuthProvider, UpstreamOAuthAuthorizationSession), SessionLookupError> {
|
||||
) -> Result<Option<(UpstreamOAuthProvider, UpstreamOAuthAuthorizationSession)>, DatabaseError> {
|
||||
let res = sqlx::query_as!(
|
||||
SessionAndProviderLookup,
|
||||
r#"
|
||||
@ -94,29 +82,41 @@ pub async fn lookup_session(
|
||||
Uuid::from(id),
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await?;
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
let id = res.upstream_oauth_provider_id.into();
|
||||
let provider = UpstreamOAuthProvider {
|
||||
id: res.upstream_oauth_provider_id.into(),
|
||||
issuer: res
|
||||
.provider_issuer
|
||||
.parse()
|
||||
.map_err(|_| DatabaseInconsistencyError)?,
|
||||
scope: res
|
||||
.provider_scope
|
||||
.parse()
|
||||
.map_err(|_| DatabaseInconsistencyError)?,
|
||||
id,
|
||||
issuer: res.provider_issuer,
|
||||
scope: res.provider_scope.parse().map_err(|e| {
|
||||
DatabaseInconsistencyError2::on("upstream_oauth_providers")
|
||||
.column("scope")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?,
|
||||
client_id: res.provider_client_id,
|
||||
encrypted_client_secret: res.provider_encrypted_client_secret,
|
||||
token_endpoint_auth_method: res
|
||||
.provider_token_endpoint_auth_method
|
||||
.parse()
|
||||
.map_err(|_| DatabaseInconsistencyError)?,
|
||||
token_endpoint_auth_method: res.provider_token_endpoint_auth_method.parse().map_err(
|
||||
|e| {
|
||||
DatabaseInconsistencyError2::on("upstream_oauth_providers")
|
||||
.column("token_endpoint_auth_method")
|
||||
.row(id)
|
||||
.source(e)
|
||||
},
|
||||
)?,
|
||||
token_endpoint_signing_alg: res
|
||||
.provider_token_endpoint_signing_alg
|
||||
.map(|x| x.parse())
|
||||
.transpose()
|
||||
.map_err(|_| DatabaseInconsistencyError)?,
|
||||
.map_err(|e| {
|
||||
DatabaseInconsistencyError2::on("upstream_oauth_providers")
|
||||
.column("token_endpoint_signing_alg")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?,
|
||||
created_at: res.provider_created_at,
|
||||
};
|
||||
|
||||
@ -133,7 +133,7 @@ pub async fn lookup_session(
|
||||
consumed_at: res.consumed_at,
|
||||
};
|
||||
|
||||
Ok((provider, session))
|
||||
Ok(Some((provider, session)))
|
||||
}
|
||||
|
||||
/// Add a session to the database
|
||||
|
Reference in New Issue
Block a user