You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-07 22:41:18 +03:00
storage: unify most oauth2 related errors
This commit is contained in:
@ -31,7 +31,7 @@ use mas_http::HttpServiceExt;
|
|||||||
use mas_iana::oauth::OAuthClientAuthenticationMethod;
|
use mas_iana::oauth::OAuthClientAuthenticationMethod;
|
||||||
use mas_jose::{jwk::PublicJsonWebKeySet, jwt::Jwt};
|
use mas_jose::{jwk::PublicJsonWebKeySet, jwt::Jwt};
|
||||||
use mas_keystore::Encrypter;
|
use mas_keystore::Encrypter;
|
||||||
use mas_storage::oauth2::client::{lookup_client_by_client_id, ClientFetchError};
|
use mas_storage::{oauth2::client::lookup_client_by_client_id, DatabaseError};
|
||||||
use serde::{de::DeserializeOwned, Deserialize};
|
use serde::{de::DeserializeOwned, Deserialize};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use sqlx::PgExecutor;
|
use sqlx::PgExecutor;
|
||||||
@ -73,7 +73,10 @@ pub enum Credentials {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Credentials {
|
impl Credentials {
|
||||||
pub async fn fetch(&self, executor: impl PgExecutor<'_>) -> Result<Client, ClientFetchError> {
|
pub async fn fetch(
|
||||||
|
&self,
|
||||||
|
executor: impl PgExecutor<'_>,
|
||||||
|
) -> Result<Option<Client>, DatabaseError> {
|
||||||
let client_id = match self {
|
let client_id = match self {
|
||||||
Credentials::None { client_id }
|
Credentials::None { client_id }
|
||||||
| Credentials::ClientSecretBasic { client_id, .. }
|
| Credentials::ClientSecretBasic { client_id, .. }
|
||||||
|
@ -27,10 +27,7 @@ use axum::{
|
|||||||
use headers::{authorization::Bearer, Authorization, Header, HeaderMapExt, HeaderName};
|
use headers::{authorization::Bearer, Authorization, Header, HeaderMapExt, HeaderName};
|
||||||
use http::{header::WWW_AUTHENTICATE, HeaderMap, HeaderValue, Request, StatusCode};
|
use http::{header::WWW_AUTHENTICATE, HeaderMap, HeaderValue, Request, StatusCode};
|
||||||
use mas_data_model::Session;
|
use mas_data_model::Session;
|
||||||
use mas_storage::{
|
use mas_storage::{oauth2::access_token::lookup_active_access_token, DatabaseError};
|
||||||
oauth2::access_token::{lookup_active_access_token, AccessTokenLookupError},
|
|
||||||
LookupError,
|
|
||||||
};
|
|
||||||
use serde::{de::DeserializeOwned, Deserialize};
|
use serde::{de::DeserializeOwned, Deserialize};
|
||||||
use sqlx::PgConnection;
|
use sqlx::PgConnection;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
@ -61,7 +58,9 @@ impl AccessToken {
|
|||||||
AccessToken::None => return Err(AuthorizationVerificationError::MissingToken),
|
AccessToken::None => return Err(AuthorizationVerificationError::MissingToken),
|
||||||
};
|
};
|
||||||
|
|
||||||
let (token, session) = lookup_active_access_token(conn, token.as_str()).await?;
|
let (token, session) = lookup_active_access_token(conn, token.as_str())
|
||||||
|
.await?
|
||||||
|
.ok_or(AuthorizationVerificationError::InvalidToken)?;
|
||||||
|
|
||||||
Ok((token, session))
|
Ok((token, session))
|
||||||
}
|
}
|
||||||
@ -119,17 +118,7 @@ pub enum AuthorizationVerificationError {
|
|||||||
MissingForm,
|
MissingForm,
|
||||||
|
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
Internal(Box<dyn Error>),
|
Internal(#[from] DatabaseError),
|
||||||
}
|
|
||||||
|
|
||||||
impl From<AccessTokenLookupError> for AuthorizationVerificationError {
|
|
||||||
fn from(e: AccessTokenLookupError) -> Self {
|
|
||||||
if e.not_found() {
|
|
||||||
Self::InvalidToken
|
|
||||||
} else {
|
|
||||||
Self::Internal(Box::new(e))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
enum BearerError {
|
enum BearerError {
|
||||||
|
@ -24,7 +24,7 @@ use mas_storage::{
|
|||||||
lookup_user_by_username, lookup_user_email, mark_user_email_as_verified, register_user,
|
lookup_user_by_username, lookup_user_email, mark_user_email_as_verified, register_user,
|
||||||
set_password,
|
set_password,
|
||||||
},
|
},
|
||||||
Clock, LookupError,
|
Clock,
|
||||||
};
|
};
|
||||||
use oauth2_types::scope::Scope;
|
use oauth2_types::scope::Scope;
|
||||||
use rand::SeedableRng;
|
use rand::SeedableRng;
|
||||||
@ -259,15 +259,11 @@ impl Options {
|
|||||||
|
|
||||||
for client in config.clients.iter() {
|
for client in config.clients.iter() {
|
||||||
let client_id = client.client_id;
|
let client_id = client.client_id;
|
||||||
let res = lookup_client(&mut txn, client_id).await;
|
let res = lookup_client(&mut txn, client_id).await?;
|
||||||
match res {
|
if res.is_some() {
|
||||||
Ok(_) => {
|
|
||||||
warn!(%client_id, "Skipping already imported client");
|
warn!(%client_id, "Skipping already imported client");
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
Err(e) if e.not_found() => {}
|
|
||||||
Err(e) => anyhow::bail!(e),
|
|
||||||
}
|
|
||||||
|
|
||||||
info!(%client_id, "Importing client");
|
info!(%client_id, "Importing client");
|
||||||
let client_secret = client.client_secret();
|
let client_secret = client.client_secret();
|
||||||
|
@ -27,7 +27,6 @@ use async_graphql::{
|
|||||||
Context, Description, EmptyMutation, EmptySubscription, ID,
|
Context, Description, EmptyMutation, EmptySubscription, ID,
|
||||||
};
|
};
|
||||||
use mas_axum_utils::SessionInfo;
|
use mas_axum_utils::SessionInfo;
|
||||||
use mas_storage::LookupResultExt;
|
|
||||||
use sqlx::PgPool;
|
use sqlx::PgPool;
|
||||||
|
|
||||||
use self::model::{
|
use self::model::{
|
||||||
@ -96,9 +95,7 @@ impl RootQuery {
|
|||||||
let database = ctx.data::<PgPool>()?;
|
let database = ctx.data::<PgPool>()?;
|
||||||
let mut conn = database.acquire().await?;
|
let mut conn = database.acquire().await?;
|
||||||
|
|
||||||
let client = mas_storage::oauth2::client::lookup_client(&mut conn, id)
|
let client = mas_storage::oauth2::client::lookup_client(&mut conn, id).await?;
|
||||||
.await
|
|
||||||
.to_option()?;
|
|
||||||
|
|
||||||
Ok(client.map(OAuth2Client))
|
Ok(client.map(OAuth2Client))
|
||||||
}
|
}
|
||||||
|
@ -12,6 +12,7 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
use anyhow::Context as _;
|
||||||
use async_graphql::{Context, Description, Object, ID};
|
use async_graphql::{Context, Description, Object, ID};
|
||||||
use mas_storage::oauth2::client::lookup_client;
|
use mas_storage::oauth2::client::lookup_client;
|
||||||
use oauth2_types::scope::Scope;
|
use oauth2_types::scope::Scope;
|
||||||
@ -114,7 +115,9 @@ impl OAuth2Consent {
|
|||||||
/// OAuth 2.0 client for which the user granted access.
|
/// OAuth 2.0 client for which the user granted access.
|
||||||
pub async fn client(&self, ctx: &Context<'_>) -> Result<OAuth2Client, async_graphql::Error> {
|
pub async fn client(&self, ctx: &Context<'_>) -> Result<OAuth2Client, async_graphql::Error> {
|
||||||
let mut conn = ctx.data::<PgPool>()?.acquire().await?;
|
let mut conn = ctx.data::<PgPool>()?.acquire().await?;
|
||||||
let client = lookup_client(&mut conn, self.client_id).await?;
|
let client = lookup_client(&mut conn, self.client_id)
|
||||||
|
.await?
|
||||||
|
.context("Could not load client")?;
|
||||||
Ok(OAuth2Client(client))
|
Ok(OAuth2Client(client))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -26,12 +26,8 @@ use mas_data_model::{AuthorizationCode, Pkce};
|
|||||||
use mas_keystore::Encrypter;
|
use mas_keystore::Encrypter;
|
||||||
use mas_policy::PolicyFactory;
|
use mas_policy::PolicyFactory;
|
||||||
use mas_router::{PostAuthAction, Route};
|
use mas_router::{PostAuthAction, Route};
|
||||||
use mas_storage::{
|
use mas_storage::oauth2::{
|
||||||
oauth2::{
|
authorization_grant::new_authorization_grant, client::lookup_client_by_client_id,
|
||||||
authorization_grant::new_authorization_grant,
|
|
||||||
client::{lookup_client_by_client_id, ClientFetchError},
|
|
||||||
},
|
|
||||||
LookupError,
|
|
||||||
};
|
};
|
||||||
use mas_templates::Templates;
|
use mas_templates::Templates;
|
||||||
use oauth2_types::{
|
use oauth2_types::{
|
||||||
@ -46,6 +42,7 @@ use sqlx::PgPool;
|
|||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
use self::{callback::CallbackDestination, complete::GrantCompletionError};
|
use self::{callback::CallbackDestination, complete::GrantCompletionError};
|
||||||
|
use crate::impl_from_error_for_route;
|
||||||
|
|
||||||
mod callback;
|
mod callback;
|
||||||
pub mod complete;
|
pub mod complete;
|
||||||
@ -56,7 +53,7 @@ pub enum RouteError {
|
|||||||
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
|
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
|
||||||
|
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
Anyhow(anyhow::Error),
|
Anyhow(#[from] anyhow::Error),
|
||||||
|
|
||||||
#[error("could not find client")]
|
#[error("could not find client")]
|
||||||
ClientNotFound,
|
ClientNotFound,
|
||||||
@ -93,33 +90,9 @@ impl IntoResponse for RouteError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<sqlx::Error> for RouteError {
|
impl_from_error_for_route!(sqlx::Error);
|
||||||
fn from(e: sqlx::Error) -> Self {
|
impl_from_error_for_route!(mas_storage::DatabaseError);
|
||||||
Self::Internal(Box::new(e))
|
impl_from_error_for_route!(self::callback::CallbackDestinationError);
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<self::callback::CallbackDestinationError> for RouteError {
|
|
||||||
fn from(e: self::callback::CallbackDestinationError) -> Self {
|
|
||||||
Self::Internal(Box::new(e))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<ClientFetchError> for RouteError {
|
|
||||||
fn from(e: ClientFetchError) -> Self {
|
|
||||||
if e.not_found() {
|
|
||||||
Self::ClientNotFound
|
|
||||||
} else {
|
|
||||||
Self::Internal(Box::new(e))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<anyhow::Error> for RouteError {
|
|
||||||
fn from(e: anyhow::Error) -> Self {
|
|
||||||
Self::Anyhow(e)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
pub(crate) struct Params {
|
pub(crate) struct Params {
|
||||||
@ -166,7 +139,9 @@ pub(crate) async fn get(
|
|||||||
let mut txn = pool.begin().await?;
|
let mut txn = pool.begin().await?;
|
||||||
|
|
||||||
// First, figure out what client it is
|
// First, figure out what client it is
|
||||||
let client = lookup_client_by_client_id(&mut txn, ¶ms.auth.client_id).await?;
|
let client = lookup_client_by_client_id(&mut txn, ¶ms.auth.client_id)
|
||||||
|
.await?
|
||||||
|
.ok_or(RouteError::ClientNotFound)?;
|
||||||
|
|
||||||
// And resolve the redirect_uri and response_mode
|
// And resolve the redirect_uri and response_mode
|
||||||
let redirect_uri = client
|
let redirect_uri = client
|
||||||
|
@ -24,11 +24,9 @@ use mas_keystore::Encrypter;
|
|||||||
use mas_storage::{
|
use mas_storage::{
|
||||||
compat::{lookup_active_compat_access_token, lookup_active_compat_refresh_token},
|
compat::{lookup_active_compat_access_token, lookup_active_compat_refresh_token},
|
||||||
oauth2::{
|
oauth2::{
|
||||||
access_token::{lookup_active_access_token, AccessTokenLookupError},
|
access_token::lookup_active_access_token, refresh_token::lookup_active_refresh_token,
|
||||||
client::ClientFetchError,
|
|
||||||
refresh_token::{lookup_active_refresh_token, RefreshTokenLookupError},
|
|
||||||
},
|
},
|
||||||
Clock, LookupError,
|
Clock,
|
||||||
};
|
};
|
||||||
use oauth2_types::requests::{IntrospectionRequest, IntrospectionResponse};
|
use oauth2_types::requests::{IntrospectionRequest, IntrospectionResponse};
|
||||||
use sqlx::PgPool;
|
use sqlx::PgPool;
|
||||||
@ -87,36 +85,6 @@ impl From<TokenFormatError> for RouteError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<ClientFetchError> for RouteError {
|
|
||||||
fn from(e: ClientFetchError) -> Self {
|
|
||||||
if e.not_found() {
|
|
||||||
Self::ClientNotFound
|
|
||||||
} else {
|
|
||||||
Self::Internal(Box::new(e))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<AccessTokenLookupError> for RouteError {
|
|
||||||
fn from(e: AccessTokenLookupError) -> Self {
|
|
||||||
if e.not_found() {
|
|
||||||
Self::UnknownToken
|
|
||||||
} else {
|
|
||||||
Self::Internal(Box::new(e))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<RefreshTokenLookupError> for RouteError {
|
|
||||||
fn from(e: RefreshTokenLookupError) -> Self {
|
|
||||||
if e.not_found() {
|
|
||||||
Self::UnknownToken
|
|
||||||
} else {
|
|
||||||
Self::Internal(Box::new(e))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const INACTIVE: IntrospectionResponse = IntrospectionResponse {
|
const INACTIVE: IntrospectionResponse = IntrospectionResponse {
|
||||||
active: false,
|
active: false,
|
||||||
scope: None,
|
scope: None,
|
||||||
@ -142,7 +110,11 @@ pub(crate) async fn post(
|
|||||||
let clock = Clock::default();
|
let clock = Clock::default();
|
||||||
let mut conn = pool.acquire().await?;
|
let mut conn = pool.acquire().await?;
|
||||||
|
|
||||||
let client = client_authorization.credentials.fetch(&mut conn).await?;
|
let client = client_authorization
|
||||||
|
.credentials
|
||||||
|
.fetch(&mut conn)
|
||||||
|
.await?
|
||||||
|
.ok_or(RouteError::ClientNotFound)?;
|
||||||
|
|
||||||
let method = match &client.token_endpoint_auth_method {
|
let method = match &client.token_endpoint_auth_method {
|
||||||
None | Some(OAuthClientAuthenticationMethod::None) => {
|
None | Some(OAuthClientAuthenticationMethod::None) => {
|
||||||
@ -172,7 +144,9 @@ pub(crate) async fn post(
|
|||||||
|
|
||||||
let reply = match token_type {
|
let reply = match token_type {
|
||||||
TokenType::AccessToken => {
|
TokenType::AccessToken => {
|
||||||
let (token, session) = lookup_active_access_token(&mut conn, token).await?;
|
let (token, session) = lookup_active_access_token(&mut conn, token)
|
||||||
|
.await?
|
||||||
|
.ok_or(RouteError::UnknownToken)?;
|
||||||
|
|
||||||
IntrospectionResponse {
|
IntrospectionResponse {
|
||||||
active: true,
|
active: true,
|
||||||
@ -190,7 +164,9 @@ pub(crate) async fn post(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
TokenType::RefreshToken => {
|
TokenType::RefreshToken => {
|
||||||
let (token, session) = lookup_active_refresh_token(&mut conn, token).await?;
|
let (token, session) = lookup_active_refresh_token(&mut conn, token)
|
||||||
|
.await?
|
||||||
|
.ok_or(RouteError::UnknownToken)?;
|
||||||
|
|
||||||
IntrospectionResponse {
|
IntrospectionResponse {
|
||||||
active: true,
|
active: true,
|
||||||
|
@ -26,9 +26,9 @@ use mas_axum_utils::{
|
|||||||
use mas_data_model::{AuthorizationGrantStage, Client, TokenType};
|
use mas_data_model::{AuthorizationGrantStage, Client, TokenType};
|
||||||
use mas_iana::jose::JsonWebSignatureAlg;
|
use mas_iana::jose::JsonWebSignatureAlg;
|
||||||
use mas_jose::{
|
use mas_jose::{
|
||||||
claims::{self, hash_token, ClaimError, TokenHashError},
|
claims::{self, hash_token},
|
||||||
constraints::Constrainable,
|
constraints::Constrainable,
|
||||||
jwt::{JsonWebSignatureHeader, Jwt, JwtSignatureError},
|
jwt::{JsonWebSignatureHeader, Jwt},
|
||||||
};
|
};
|
||||||
use mas_keystore::{Encrypter, Keystore};
|
use mas_keystore::{Encrypter, Keystore};
|
||||||
use mas_router::UrlBuilder;
|
use mas_router::UrlBuilder;
|
||||||
@ -36,14 +36,10 @@ use mas_storage::{
|
|||||||
oauth2::{
|
oauth2::{
|
||||||
access_token::{add_access_token, revoke_access_token},
|
access_token::{add_access_token, revoke_access_token},
|
||||||
authorization_grant::{exchange_grant, lookup_grant_by_code},
|
authorization_grant::{exchange_grant, lookup_grant_by_code},
|
||||||
client::ClientFetchError,
|
|
||||||
end_oauth_session,
|
end_oauth_session,
|
||||||
refresh_token::{
|
refresh_token::{add_refresh_token, consume_refresh_token, lookup_active_refresh_token},
|
||||||
add_refresh_token, consume_refresh_token, lookup_active_refresh_token,
|
|
||||||
RefreshTokenLookupError,
|
|
||||||
},
|
},
|
||||||
},
|
DatabaseInconsistencyError,
|
||||||
DatabaseInconsistencyError, LookupError,
|
|
||||||
};
|
};
|
||||||
use oauth2_types::{
|
use oauth2_types::{
|
||||||
errors::{ClientError, ClientErrorCode},
|
errors::{ClientError, ClientErrorCode},
|
||||||
@ -60,6 +56,8 @@ use thiserror::Error;
|
|||||||
use tracing::debug;
|
use tracing::debug;
|
||||||
use url::Url;
|
use url::Url;
|
||||||
|
|
||||||
|
use crate::impl_from_error_for_route;
|
||||||
|
|
||||||
#[serde_as]
|
#[serde_as]
|
||||||
#[skip_serializing_none]
|
#[skip_serializing_none]
|
||||||
#[derive(Serialize, Debug)]
|
#[derive(Serialize, Debug)]
|
||||||
@ -107,26 +105,6 @@ pub(crate) enum RouteError {
|
|||||||
UnauthorizedClient,
|
UnauthorizedClient,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<ClientFetchError> for RouteError {
|
|
||||||
fn from(e: ClientFetchError) -> Self {
|
|
||||||
if e.not_found() {
|
|
||||||
Self::ClientNotFound
|
|
||||||
} else {
|
|
||||||
Self::Internal(Box::new(e))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<RefreshTokenLookupError> for RouteError {
|
|
||||||
fn from(e: RefreshTokenLookupError) -> Self {
|
|
||||||
if e.not_found() {
|
|
||||||
Self::InvalidGrant
|
|
||||||
} else {
|
|
||||||
Self::Internal(Box::new(e))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl IntoResponse for RouteError {
|
impl IntoResponse for RouteError {
|
||||||
fn into_response(self) -> axum::response::Response {
|
fn into_response(self) -> axum::response::Response {
|
||||||
match self {
|
match self {
|
||||||
@ -162,35 +140,12 @@ impl IntoResponse for RouteError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<mas_keystore::WrongAlgorithmError> for RouteError {
|
impl_from_error_for_route!(sqlx::Error);
|
||||||
fn from(e: mas_keystore::WrongAlgorithmError) -> Self {
|
impl_from_error_for_route!(mas_storage::DatabaseError);
|
||||||
Self::Internal(Box::new(e))
|
impl_from_error_for_route!(mas_keystore::WrongAlgorithmError);
|
||||||
}
|
impl_from_error_for_route!(mas_jose::claims::ClaimError);
|
||||||
}
|
impl_from_error_for_route!(mas_jose::claims::TokenHashError);
|
||||||
|
impl_from_error_for_route!(mas_jose::jwt::JwtSignatureError);
|
||||||
impl From<sqlx::Error> for RouteError {
|
|
||||||
fn from(e: sqlx::Error) -> Self {
|
|
||||||
Self::Internal(Box::new(e))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<ClaimError> for RouteError {
|
|
||||||
fn from(e: ClaimError) -> Self {
|
|
||||||
Self::Internal(Box::new(e))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<TokenHashError> for RouteError {
|
|
||||||
fn from(e: TokenHashError) -> Self {
|
|
||||||
Self::Internal(Box::new(e))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<JwtSignatureError> for RouteError {
|
|
||||||
fn from(e: JwtSignatureError) -> Self {
|
|
||||||
Self::Internal(Box::new(e))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tracing::instrument(skip_all, err)]
|
#[tracing::instrument(skip_all, err)]
|
||||||
pub(crate) async fn post(
|
pub(crate) async fn post(
|
||||||
@ -203,7 +158,11 @@ pub(crate) async fn post(
|
|||||||
) -> Result<impl IntoResponse, RouteError> {
|
) -> Result<impl IntoResponse, RouteError> {
|
||||||
let mut txn = pool.begin().await?;
|
let mut txn = pool.begin().await?;
|
||||||
|
|
||||||
let client = client_authorization.credentials.fetch(&mut txn).await?;
|
let client = client_authorization
|
||||||
|
.credentials
|
||||||
|
.fetch(&mut txn)
|
||||||
|
.await?
|
||||||
|
.ok_or(RouteError::ClientNotFound)?;
|
||||||
|
|
||||||
let method = client
|
let method = client
|
||||||
.token_endpoint_auth_method
|
.token_endpoint_auth_method
|
||||||
@ -396,8 +355,9 @@ async fn refresh_token_grant(
|
|||||||
) -> Result<AccessTokenResponse, RouteError> {
|
) -> Result<AccessTokenResponse, RouteError> {
|
||||||
let (clock, mut rng) = crate::rng_and_clock()?;
|
let (clock, mut rng) = crate::rng_and_clock()?;
|
||||||
|
|
||||||
let (refresh_token, session) =
|
let (refresh_token, session) = lookup_active_refresh_token(&mut txn, &grant.refresh_token)
|
||||||
lookup_active_refresh_token(&mut txn, &grant.refresh_token).await?;
|
.await?
|
||||||
|
.ok_or(RouteError::InvalidGrant)?;
|
||||||
|
|
||||||
if client.client_id != session.client.client_id {
|
if client.client_id != session.client.client_id {
|
||||||
// As per https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
|
// As per https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
|
||||||
|
@ -25,9 +25,8 @@ use mas_oidc_client::requests::{
|
|||||||
authorization_code::AuthorizationValidationData, jose::JwtVerificationData,
|
authorization_code::AuthorizationValidationData, jose::JwtVerificationData,
|
||||||
};
|
};
|
||||||
use mas_router::{Route, UrlBuilder};
|
use mas_router::{Route, UrlBuilder};
|
||||||
use mas_storage::{
|
use mas_storage::upstream_oauth2::{
|
||||||
upstream_oauth2::{add_link, complete_session, lookup_link_by_subject, lookup_session},
|
add_link, complete_session, lookup_link_by_subject, lookup_session,
|
||||||
LookupResultExt,
|
|
||||||
};
|
};
|
||||||
use oauth2_types::errors::ClientErrorCode;
|
use oauth2_types::errors::ClientErrorCode;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
@ -97,8 +96,6 @@ pub(crate) enum RouteError {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl_from_error_for_route!(mas_storage::DatabaseError);
|
impl_from_error_for_route!(mas_storage::DatabaseError);
|
||||||
impl_from_error_for_route!(mas_storage::GenericLookupError);
|
|
||||||
impl_from_error_for_route!(mas_storage::upstream_oauth2::SessionLookupError);
|
|
||||||
impl_from_error_for_route!(mas_http::ClientInitError);
|
impl_from_error_for_route!(mas_http::ClientInitError);
|
||||||
impl_from_error_for_route!(sqlx::Error);
|
impl_from_error_for_route!(sqlx::Error);
|
||||||
impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError);
|
impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError);
|
||||||
@ -141,8 +138,7 @@ pub(crate) async fn get(
|
|||||||
.map_err(|_| RouteError::MissingCookie)?;
|
.map_err(|_| RouteError::MissingCookie)?;
|
||||||
|
|
||||||
let (provider, session) = lookup_session(&mut txn, session_id)
|
let (provider, session) = lookup_session(&mut txn, session_id)
|
||||||
.await
|
.await?
|
||||||
.to_option()?
|
|
||||||
.ok_or(RouteError::SessionNotFound)?;
|
.ok_or(RouteError::SessionNotFound)?;
|
||||||
|
|
||||||
if provider.id != provider_id {
|
if provider.id != provider_id {
|
||||||
|
@ -17,12 +17,11 @@ use chrono::{DateTime, Duration, Utc};
|
|||||||
use mas_data_model::{AccessToken, Authentication, BrowserSession, Session, User, UserEmail};
|
use mas_data_model::{AccessToken, Authentication, BrowserSession, Session, User, UserEmail};
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
use sqlx::{PgConnection, PgExecutor};
|
use sqlx::{PgConnection, PgExecutor};
|
||||||
use thiserror::Error;
|
|
||||||
use ulid::Ulid;
|
use ulid::Ulid;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use super::client::{lookup_client, ClientFetchError};
|
use super::client::lookup_client;
|
||||||
use crate::{Clock, DatabaseInconsistencyError, LookupError};
|
use crate::{Clock, DatabaseError, DatabaseInconsistencyError2};
|
||||||
|
|
||||||
#[tracing::instrument(
|
#[tracing::instrument(
|
||||||
skip_all,
|
skip_all,
|
||||||
@ -95,25 +94,11 @@ pub struct OAuth2AccessTokenLookup {
|
|||||||
user_email_confirmed_at: Option<DateTime<Utc>>,
|
user_email_confirmed_at: Option<DateTime<Utc>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
|
||||||
#[error("failed to lookup access token")]
|
|
||||||
pub enum AccessTokenLookupError {
|
|
||||||
Database(#[from] sqlx::Error),
|
|
||||||
ClientFetch(#[from] ClientFetchError),
|
|
||||||
Inconsistency(#[from] DatabaseInconsistencyError),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl LookupError for AccessTokenLookupError {
|
|
||||||
fn not_found(&self) -> bool {
|
|
||||||
matches!(self, Self::Database(sqlx::Error::RowNotFound))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(clippy::too_many_lines)]
|
#[allow(clippy::too_many_lines)]
|
||||||
pub async fn lookup_active_access_token(
|
pub async fn lookup_active_access_token(
|
||||||
conn: &mut PgConnection,
|
conn: &mut PgConnection,
|
||||||
token: &str,
|
token: &str,
|
||||||
) -> Result<(AccessToken, Session), AccessTokenLookupError> {
|
) -> Result<Option<(AccessToken, Session)>, DatabaseError> {
|
||||||
let res = sqlx::query_as!(
|
let res = sqlx::query_as!(
|
||||||
OAuth2AccessTokenLookup,
|
OAuth2AccessTokenLookup,
|
||||||
r#"
|
r#"
|
||||||
@ -160,17 +145,25 @@ pub async fn lookup_active_access_token(
|
|||||||
.fetch_one(&mut *conn)
|
.fetch_one(&mut *conn)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let id = Ulid::from(res.oauth2_access_token_id);
|
let access_token_id = Ulid::from(res.oauth2_access_token_id);
|
||||||
let access_token = AccessToken {
|
let access_token = AccessToken {
|
||||||
id,
|
id: access_token_id,
|
||||||
jti: id.to_string(),
|
jti: access_token_id.to_string(),
|
||||||
access_token: res.oauth2_access_token,
|
access_token: res.oauth2_access_token,
|
||||||
created_at: res.oauth2_access_token_created_at,
|
created_at: res.oauth2_access_token_created_at,
|
||||||
expires_at: res.oauth2_access_token_expires_at,
|
expires_at: res.oauth2_access_token_expires_at,
|
||||||
};
|
};
|
||||||
|
|
||||||
let client = lookup_client(&mut *conn, res.oauth2_client_id.into()).await?;
|
let session_id = res.oauth2_session_id.into();
|
||||||
|
let client = lookup_client(&mut *conn, res.oauth2_client_id.into())
|
||||||
|
.await?
|
||||||
|
.ok_or_else(|| {
|
||||||
|
DatabaseInconsistencyError2::on("oauth2_sessions")
|
||||||
|
.column("client_id")
|
||||||
|
.row(session_id)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let user_id = Ulid::from(res.user_id);
|
||||||
let primary_email = match (
|
let primary_email = match (
|
||||||
res.user_email_id,
|
res.user_email_id,
|
||||||
res.user_email,
|
res.user_email,
|
||||||
@ -184,14 +177,18 @@ pub async fn lookup_active_access_token(
|
|||||||
confirmed_at,
|
confirmed_at,
|
||||||
}),
|
}),
|
||||||
(None, None, None, None) => None,
|
(None, None, None, None) => None,
|
||||||
_ => return Err(DatabaseInconsistencyError.into()),
|
_ => {
|
||||||
|
return Err(DatabaseInconsistencyError2::on("users")
|
||||||
|
.column("primary_user_email_id")
|
||||||
|
.row(user_id)
|
||||||
|
.into())
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let id = Ulid::from(res.user_id);
|
|
||||||
let user = User {
|
let user = User {
|
||||||
id,
|
id: user_id,
|
||||||
username: res.user_username,
|
username: res.user_username,
|
||||||
sub: id.to_string(),
|
sub: user_id.to_string(),
|
||||||
primary_email,
|
primary_email,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -204,7 +201,7 @@ pub async fn lookup_active_access_token(
|
|||||||
id: id.into(),
|
id: id.into(),
|
||||||
created_at,
|
created_at,
|
||||||
}),
|
}),
|
||||||
_ => return Err(DatabaseInconsistencyError.into()),
|
_ => return Err(DatabaseInconsistencyError2::on("user_session_authentications").into()),
|
||||||
};
|
};
|
||||||
|
|
||||||
let browser_session = BrowserSession {
|
let browser_session = BrowserSession {
|
||||||
@ -214,28 +211,33 @@ pub async fn lookup_active_access_token(
|
|||||||
last_authentication,
|
last_authentication,
|
||||||
};
|
};
|
||||||
|
|
||||||
let scope = res.scope.parse().map_err(|_e| DatabaseInconsistencyError)?;
|
let scope = res.scope.parse().map_err(|e| {
|
||||||
|
DatabaseInconsistencyError2::on("oauth2_sessions")
|
||||||
|
.column("scope")
|
||||||
|
.row(session_id)
|
||||||
|
.source(e)
|
||||||
|
})?;
|
||||||
|
|
||||||
let session = Session {
|
let session = Session {
|
||||||
id: res.oauth2_session_id.into(),
|
id: session_id,
|
||||||
client,
|
client,
|
||||||
browser_session,
|
browser_session,
|
||||||
scope,
|
scope,
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok((access_token, session))
|
Ok(Some((access_token, session)))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(
|
#[tracing::instrument(
|
||||||
skip_all,
|
skip_all,
|
||||||
fields(%access_token.id),
|
fields(%access_token.id),
|
||||||
err(Debug),
|
err,
|
||||||
)]
|
)]
|
||||||
pub async fn revoke_access_token(
|
pub async fn revoke_access_token(
|
||||||
executor: impl PgExecutor<'_>,
|
executor: impl PgExecutor<'_>,
|
||||||
clock: &Clock,
|
clock: &Clock,
|
||||||
access_token: AccessToken,
|
access_token: AccessToken,
|
||||||
) -> anyhow::Result<()> {
|
) -> Result<(), DatabaseError> {
|
||||||
let revoked_at = clock.now();
|
let revoked_at = clock.now();
|
||||||
let res = sqlx::query!(
|
let res = sqlx::query!(
|
||||||
r#"
|
r#"
|
||||||
@ -247,17 +249,15 @@ pub async fn revoke_access_token(
|
|||||||
revoked_at,
|
revoked_at,
|
||||||
)
|
)
|
||||||
.execute(executor)
|
.execute(executor)
|
||||||
.await
|
.await?;
|
||||||
.context("could not revoke access tokens")?;
|
|
||||||
|
|
||||||
if res.rows_affected() == 1 {
|
DatabaseError::ensure_affected_rows(&res, 1)
|
||||||
Ok(())
|
|
||||||
} else {
|
|
||||||
Err(anyhow::anyhow!("no row were affected when revoking token"))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn cleanup_expired(executor: impl PgExecutor<'_>, clock: &Clock) -> anyhow::Result<u64> {
|
pub async fn cleanup_expired(
|
||||||
|
executor: impl PgExecutor<'_>,
|
||||||
|
clock: &Clock,
|
||||||
|
) -> Result<u64, sqlx::Error> {
|
||||||
// Cleanup token which expired more than 15 minutes ago
|
// Cleanup token which expired more than 15 minutes ago
|
||||||
let threshold = clock.now() - Duration::minutes(15);
|
let threshold = clock.now() - Duration::minutes(15);
|
||||||
let res = sqlx::query!(
|
let res = sqlx::query!(
|
||||||
@ -268,8 +268,7 @@ pub async fn cleanup_expired(executor: impl PgExecutor<'_>, clock: &Clock) -> an
|
|||||||
threshold,
|
threshold,
|
||||||
)
|
)
|
||||||
.execute(executor)
|
.execute(executor)
|
||||||
.await
|
.await?;
|
||||||
.context("could not cleanup expired access tokens")?;
|
|
||||||
|
|
||||||
Ok(res.rows_affected())
|
Ok(res.rows_affected())
|
||||||
}
|
}
|
||||||
|
@ -180,6 +180,7 @@ impl GrantLookup {
|
|||||||
// TODO: don't unwrap
|
// TODO: don't unwrap
|
||||||
let client = lookup_client(executor, self.oauth2_client_id.into())
|
let client = lookup_client(executor, self.oauth2_client_id.into())
|
||||||
.await
|
.await
|
||||||
|
.unwrap()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let last_authentication = match (
|
let last_authentication = match (
|
||||||
|
@ -23,12 +23,11 @@ use mas_jose::jwk::PublicJsonWebKeySet;
|
|||||||
use oauth2_types::requests::GrantType;
|
use oauth2_types::requests::GrantType;
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
use sqlx::{PgConnection, PgExecutor};
|
use sqlx::{PgConnection, PgExecutor};
|
||||||
use thiserror::Error;
|
|
||||||
use ulid::Ulid;
|
use ulid::Ulid;
|
||||||
use url::Url;
|
use url::Url;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::{Clock, LookupError};
|
use crate::{Clock, DatabaseError, DatabaseInconsistencyError2, LookupResultExt};
|
||||||
|
|
||||||
// XXX: response_types & contacts
|
// XXX: response_types & contacts
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@ -54,52 +53,20 @@ pub struct OAuth2ClientLookup {
|
|||||||
initiate_login_uri: Option<String>,
|
initiate_login_uri: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
|
||||||
pub enum ClientFetchError {
|
|
||||||
#[error("invalid client ID")]
|
|
||||||
InvalidClientId(#[from] ulid::DecodeError),
|
|
||||||
|
|
||||||
#[error("malformed jwks column")]
|
|
||||||
MalformedJwks(#[source] serde_json::Error),
|
|
||||||
|
|
||||||
#[error("entry has both a jwks and a jwks_uri")]
|
|
||||||
BothJwksAndJwksUri,
|
|
||||||
|
|
||||||
#[error("could not parse URL in field {field:?}")]
|
|
||||||
ParseUrl {
|
|
||||||
field: &'static str,
|
|
||||||
source: url::ParseError,
|
|
||||||
},
|
|
||||||
|
|
||||||
#[error("could not parse field {field:?}")]
|
|
||||||
ParseField {
|
|
||||||
field: &'static str,
|
|
||||||
source: mas_iana::ParseError,
|
|
||||||
},
|
|
||||||
|
|
||||||
#[error(transparent)]
|
|
||||||
Database(#[from] sqlx::Error),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl LookupError for ClientFetchError {
|
|
||||||
fn not_found(&self) -> bool {
|
|
||||||
matches!(
|
|
||||||
self,
|
|
||||||
Self::Database(sqlx::Error::RowNotFound) | Self::InvalidClientId(_)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TryInto<Client> for OAuth2ClientLookup {
|
impl TryInto<Client> for OAuth2ClientLookup {
|
||||||
type Error = ClientFetchError;
|
type Error = DatabaseInconsistencyError2;
|
||||||
|
|
||||||
#[allow(clippy::too_many_lines)] // TODO: refactor some of the field parsing
|
#[allow(clippy::too_many_lines)] // TODO: refactor some of the field parsing
|
||||||
fn try_into(self) -> Result<Client, Self::Error> {
|
fn try_into(self) -> Result<Client, Self::Error> {
|
||||||
|
let id = Ulid::from(self.oauth2_client_id);
|
||||||
|
|
||||||
let redirect_uris: Result<Vec<Url>, _> =
|
let redirect_uris: Result<Vec<Url>, _> =
|
||||||
self.redirect_uris.iter().map(|s| s.parse()).collect();
|
self.redirect_uris.iter().map(|s| s.parse()).collect();
|
||||||
let redirect_uris = redirect_uris.map_err(|source| ClientFetchError::ParseUrl {
|
let redirect_uris = redirect_uris.map_err(|e| {
|
||||||
field: "redirect_uris",
|
DatabaseInconsistencyError2::on("oauth2_clients")
|
||||||
source,
|
.column("redirect_uris")
|
||||||
|
.row(id)
|
||||||
|
.source(e)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let response_types = vec![
|
let response_types = vec![
|
||||||
@ -124,107 +91,125 @@ impl TryInto<Client> for OAuth2ClientLookup {
|
|||||||
grant_types.push(GrantType::RefreshToken);
|
grant_types.push(GrantType::RefreshToken);
|
||||||
}
|
}
|
||||||
|
|
||||||
let logo_uri = self
|
let logo_uri = self.logo_uri.map(|s| s.parse()).transpose().map_err(|e| {
|
||||||
.logo_uri
|
DatabaseInconsistencyError2::on("oauth2_clients")
|
||||||
.map(|s| s.parse())
|
.column("logo_uri")
|
||||||
.transpose()
|
.row(id)
|
||||||
.map_err(|source| ClientFetchError::ParseUrl {
|
.source(e)
|
||||||
field: "logo_uri",
|
|
||||||
source,
|
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let client_uri = self
|
let client_uri = self
|
||||||
.client_uri
|
.client_uri
|
||||||
.map(|s| s.parse())
|
.map(|s| s.parse())
|
||||||
.transpose()
|
.transpose()
|
||||||
.map_err(|source| ClientFetchError::ParseUrl {
|
.map_err(|e| {
|
||||||
field: "client_uri",
|
DatabaseInconsistencyError2::on("oauth2_clients")
|
||||||
source,
|
.column("client_uri")
|
||||||
|
.row(id)
|
||||||
|
.source(e)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let policy_uri = self
|
let policy_uri = self
|
||||||
.policy_uri
|
.policy_uri
|
||||||
.map(|s| s.parse())
|
.map(|s| s.parse())
|
||||||
.transpose()
|
.transpose()
|
||||||
.map_err(|source| ClientFetchError::ParseUrl {
|
.map_err(|e| {
|
||||||
field: "policy_uri",
|
DatabaseInconsistencyError2::on("oauth2_clients")
|
||||||
source,
|
.column("policy_uri")
|
||||||
|
.row(id)
|
||||||
|
.source(e)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let tos_uri = self
|
let tos_uri = self.tos_uri.map(|s| s.parse()).transpose().map_err(|e| {
|
||||||
.tos_uri
|
DatabaseInconsistencyError2::on("oauth2_clients")
|
||||||
.map(|s| s.parse())
|
.column("tos_uri")
|
||||||
.transpose()
|
.row(id)
|
||||||
.map_err(|source| ClientFetchError::ParseUrl {
|
.source(e)
|
||||||
field: "tos_uri",
|
|
||||||
source,
|
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let id_token_signed_response_alg = self
|
let id_token_signed_response_alg = self
|
||||||
.id_token_signed_response_alg
|
.id_token_signed_response_alg
|
||||||
.map(|s| s.parse())
|
.map(|s| s.parse())
|
||||||
.transpose()
|
.transpose()
|
||||||
.map_err(|source| ClientFetchError::ParseField {
|
.map_err(|e| {
|
||||||
field: "id_token_signed_response_alg",
|
DatabaseInconsistencyError2::on("oauth2_clients")
|
||||||
source,
|
.column("id_token_signed_response_alg")
|
||||||
|
.row(id)
|
||||||
|
.source(e)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let userinfo_signed_response_alg = self
|
let userinfo_signed_response_alg = self
|
||||||
.userinfo_signed_response_alg
|
.userinfo_signed_response_alg
|
||||||
.map(|s| s.parse())
|
.map(|s| s.parse())
|
||||||
.transpose()
|
.transpose()
|
||||||
.map_err(|source| ClientFetchError::ParseField {
|
.map_err(|e| {
|
||||||
field: "userinfo_signed_response_alg",
|
DatabaseInconsistencyError2::on("oauth2_clients")
|
||||||
source,
|
.column("userinfo_signed_response_alg")
|
||||||
|
.row(id)
|
||||||
|
.source(e)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let token_endpoint_auth_method = self
|
let token_endpoint_auth_method = self
|
||||||
.token_endpoint_auth_method
|
.token_endpoint_auth_method
|
||||||
.map(|s| s.parse())
|
.map(|s| s.parse())
|
||||||
.transpose()
|
.transpose()
|
||||||
.map_err(|source| ClientFetchError::ParseField {
|
.map_err(|e| {
|
||||||
field: "token_endpoint_auth_method",
|
DatabaseInconsistencyError2::on("oauth2_clients")
|
||||||
source,
|
.column("token_endpoint_auth_method")
|
||||||
|
.row(id)
|
||||||
|
.source(e)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let token_endpoint_auth_signing_alg = self
|
let token_endpoint_auth_signing_alg = self
|
||||||
.token_endpoint_auth_signing_alg
|
.token_endpoint_auth_signing_alg
|
||||||
.map(|s| s.parse())
|
.map(|s| s.parse())
|
||||||
.transpose()
|
.transpose()
|
||||||
.map_err(|source| ClientFetchError::ParseField {
|
.map_err(|e| {
|
||||||
field: "token_endpoint_auth_signing_alg",
|
DatabaseInconsistencyError2::on("oauth2_clients")
|
||||||
source,
|
.column("token_endpoint_auth_signing_alg")
|
||||||
|
.row(id)
|
||||||
|
.source(e)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let initiate_login_uri = self
|
let initiate_login_uri = self
|
||||||
.initiate_login_uri
|
.initiate_login_uri
|
||||||
.map(|s| s.parse())
|
.map(|s| s.parse())
|
||||||
.transpose()
|
.transpose()
|
||||||
.map_err(|source| ClientFetchError::ParseUrl {
|
.map_err(|e| {
|
||||||
field: "initiate_login_uri",
|
DatabaseInconsistencyError2::on("oauth2_clients")
|
||||||
source,
|
.column("initiate_login_uri")
|
||||||
|
.row(id)
|
||||||
|
.source(e)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let jwks = match (self.jwks, self.jwks_uri) {
|
let jwks = match (self.jwks, self.jwks_uri) {
|
||||||
(None, None) => None,
|
(None, None) => None,
|
||||||
(Some(jwks), None) => {
|
(Some(jwks), None) => {
|
||||||
let jwks = serde_json::from_value(jwks).map_err(ClientFetchError::MalformedJwks)?;
|
let jwks = serde_json::from_value(jwks).map_err(|e| {
|
||||||
|
DatabaseInconsistencyError2::on("oauth2_clients")
|
||||||
|
.column("jwks")
|
||||||
|
.row(id)
|
||||||
|
.source(e)
|
||||||
|
})?;
|
||||||
Some(JwksOrJwksUri::Jwks(jwks))
|
Some(JwksOrJwksUri::Jwks(jwks))
|
||||||
}
|
}
|
||||||
(None, Some(jwks_uri)) => {
|
(None, Some(jwks_uri)) => {
|
||||||
let jwks_uri = jwks_uri
|
let jwks_uri = jwks_uri.parse().map_err(|e| {
|
||||||
.parse()
|
DatabaseInconsistencyError2::on("oauth2_clients")
|
||||||
.map_err(|source| ClientFetchError::ParseUrl {
|
.column("jwks_uri")
|
||||||
field: "jwks_uri",
|
.row(id)
|
||||||
source,
|
.source(e)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
Some(JwksOrJwksUri::JwksUri(jwks_uri))
|
Some(JwksOrJwksUri::JwksUri(jwks_uri))
|
||||||
}
|
}
|
||||||
_ => return Err(ClientFetchError::BothJwksAndJwksUri),
|
_ => {
|
||||||
|
return Err(DatabaseInconsistencyError2::on("oauth2_clients")
|
||||||
|
.column("jwks(_uri)")
|
||||||
|
.row(id))
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let id = Ulid::from(self.oauth2_client_id);
|
|
||||||
Ok(Client {
|
Ok(Client {
|
||||||
id,
|
id,
|
||||||
client_id: id.to_string(),
|
client_id: id.to_string(),
|
||||||
@ -253,7 +238,7 @@ impl TryInto<Client> for OAuth2ClientLookup {
|
|||||||
pub async fn lookup_clients(
|
pub async fn lookup_clients(
|
||||||
executor: impl PgExecutor<'_>,
|
executor: impl PgExecutor<'_>,
|
||||||
ids: impl IntoIterator<Item = Ulid> + Send,
|
ids: impl IntoIterator<Item = Ulid> + Send,
|
||||||
) -> Result<HashMap<Ulid, Client>, ClientFetchError> {
|
) -> Result<HashMap<Ulid, Client>, DatabaseError> {
|
||||||
let ids: Vec<Uuid> = ids.into_iter().map(Uuid::from).collect();
|
let ids: Vec<Uuid> = ids.into_iter().map(Uuid::from).collect();
|
||||||
let res = sqlx::query_as!(
|
let res = sqlx::query_as!(
|
||||||
OAuth2ClientLookup,
|
OAuth2ClientLookup,
|
||||||
@ -289,12 +274,13 @@ pub async fn lookup_clients(
|
|||||||
.fetch_all(executor)
|
.fetch_all(executor)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let clients: Result<HashMap<Ulid, Client>, _> = res
|
res.into_iter()
|
||||||
.into_iter()
|
.map(|r| {
|
||||||
.map(|r| r.try_into().map(|c: Client| (c.id, c)))
|
r.try_into()
|
||||||
.collect();
|
.map(|c: Client| (c.id, c))
|
||||||
|
.map_err(DatabaseError::from)
|
||||||
clients
|
})
|
||||||
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(
|
#[tracing::instrument(
|
||||||
@ -305,7 +291,7 @@ pub async fn lookup_clients(
|
|||||||
pub async fn lookup_client(
|
pub async fn lookup_client(
|
||||||
executor: impl PgExecutor<'_>,
|
executor: impl PgExecutor<'_>,
|
||||||
id: Ulid,
|
id: Ulid,
|
||||||
) -> Result<Client, ClientFetchError> {
|
) -> Result<Option<Client>, DatabaseError> {
|
||||||
let res = sqlx::query_as!(
|
let res = sqlx::query_as!(
|
||||||
OAuth2ClientLookup,
|
OAuth2ClientLookup,
|
||||||
r#"
|
r#"
|
||||||
@ -338,11 +324,12 @@ pub async fn lookup_client(
|
|||||||
Uuid::from(id),
|
Uuid::from(id),
|
||||||
)
|
)
|
||||||
.fetch_one(executor)
|
.fetch_one(executor)
|
||||||
.await?;
|
.await
|
||||||
|
.to_option()?;
|
||||||
|
|
||||||
let client = res.try_into()?;
|
let Some(res) = res else { return Ok(None) };
|
||||||
|
|
||||||
Ok(client)
|
Ok(Some(res.try_into()?))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(
|
#[tracing::instrument(
|
||||||
@ -353,8 +340,8 @@ pub async fn lookup_client(
|
|||||||
pub async fn lookup_client_by_client_id(
|
pub async fn lookup_client_by_client_id(
|
||||||
executor: impl PgExecutor<'_>,
|
executor: impl PgExecutor<'_>,
|
||||||
client_id: &str,
|
client_id: &str,
|
||||||
) -> Result<Client, ClientFetchError> {
|
) -> Result<Option<Client>, DatabaseError> {
|
||||||
let id: Ulid = client_id.parse()?;
|
let Ok(id) = client_id.parse() else { return Ok(None) };
|
||||||
lookup_client(executor, id).await
|
lookup_client(executor, id).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -19,12 +19,11 @@ use mas_data_model::{
|
|||||||
};
|
};
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
use sqlx::{PgConnection, PgExecutor};
|
use sqlx::{PgConnection, PgExecutor};
|
||||||
use thiserror::Error;
|
|
||||||
use ulid::Ulid;
|
use ulid::Ulid;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use super::client::{lookup_client, ClientFetchError};
|
use super::client::lookup_client;
|
||||||
use crate::{Clock, DatabaseInconsistencyError, LookupError};
|
use crate::{Clock, DatabaseError, DatabaseInconsistencyError2};
|
||||||
|
|
||||||
#[tracing::instrument(
|
#[tracing::instrument(
|
||||||
skip_all,
|
skip_all,
|
||||||
@ -98,26 +97,12 @@ struct OAuth2RefreshTokenLookup {
|
|||||||
user_email_confirmed_at: Option<DateTime<Utc>>,
|
user_email_confirmed_at: Option<DateTime<Utc>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Error, Debug)]
|
|
||||||
#[error("could not lookup refresh token")]
|
|
||||||
pub enum RefreshTokenLookupError {
|
|
||||||
Fetch(#[from] sqlx::Error),
|
|
||||||
ClientFetch(#[from] ClientFetchError),
|
|
||||||
Conversion(#[from] DatabaseInconsistencyError),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl LookupError for RefreshTokenLookupError {
|
|
||||||
fn not_found(&self) -> bool {
|
|
||||||
matches!(self, Self::Fetch(sqlx::Error::RowNotFound))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tracing::instrument(skip_all, err)]
|
#[tracing::instrument(skip_all, err)]
|
||||||
#[allow(clippy::too_many_lines)]
|
#[allow(clippy::too_many_lines)]
|
||||||
pub async fn lookup_active_refresh_token(
|
pub async fn lookup_active_refresh_token(
|
||||||
conn: &mut PgConnection,
|
conn: &mut PgConnection,
|
||||||
token: &str,
|
token: &str,
|
||||||
) -> Result<(RefreshToken, Session), RefreshTokenLookupError> {
|
) -> Result<Option<(RefreshToken, Session)>, DatabaseError> {
|
||||||
let res = sqlx::query_as!(
|
let res = sqlx::query_as!(
|
||||||
OAuth2RefreshTokenLookup,
|
OAuth2RefreshTokenLookup,
|
||||||
r#"
|
r#"
|
||||||
@ -187,7 +172,7 @@ pub async fn lookup_active_refresh_token(
|
|||||||
expires_at,
|
expires_at,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
_ => return Err(DatabaseInconsistencyError.into()),
|
_ => return Err(DatabaseInconsistencyError2::on("oauth2_access_tokens").into()),
|
||||||
};
|
};
|
||||||
|
|
||||||
let refresh_token = RefreshToken {
|
let refresh_token = RefreshToken {
|
||||||
@ -197,8 +182,16 @@ pub async fn lookup_active_refresh_token(
|
|||||||
access_token,
|
access_token,
|
||||||
};
|
};
|
||||||
|
|
||||||
let client = lookup_client(&mut *conn, res.oauth2_client_id.into()).await?;
|
let session_id = res.oauth2_session_id.into();
|
||||||
|
let client = lookup_client(&mut *conn, res.oauth2_client_id.into())
|
||||||
|
.await?
|
||||||
|
.ok_or_else(|| {
|
||||||
|
DatabaseInconsistencyError2::on("oauth2_sessions")
|
||||||
|
.column("client_id")
|
||||||
|
.row(session_id)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let user_id = Ulid::from(res.user_id);
|
||||||
let primary_email = match (
|
let primary_email = match (
|
||||||
res.user_email_id,
|
res.user_email_id,
|
||||||
res.user_email,
|
res.user_email,
|
||||||
@ -212,14 +205,18 @@ pub async fn lookup_active_refresh_token(
|
|||||||
confirmed_at,
|
confirmed_at,
|
||||||
}),
|
}),
|
||||||
(None, None, None, None) => None,
|
(None, None, None, None) => None,
|
||||||
_ => return Err(DatabaseInconsistencyError.into()),
|
_ => {
|
||||||
|
return Err(DatabaseInconsistencyError2::on("users")
|
||||||
|
.column("primary_user_email_id")
|
||||||
|
.row(user_id)
|
||||||
|
.into())
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let id = Ulid::from(res.user_id);
|
|
||||||
let user = User {
|
let user = User {
|
||||||
id,
|
id: user_id,
|
||||||
username: res.user_username,
|
username: res.user_username,
|
||||||
sub: id.to_string(),
|
sub: user_id.to_string(),
|
||||||
primary_email,
|
primary_email,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -232,7 +229,7 @@ pub async fn lookup_active_refresh_token(
|
|||||||
id: id.into(),
|
id: id.into(),
|
||||||
created_at,
|
created_at,
|
||||||
}),
|
}),
|
||||||
_ => return Err(DatabaseInconsistencyError.into()),
|
_ => return Err(DatabaseInconsistencyError2::on("user_session_authentications").into()),
|
||||||
};
|
};
|
||||||
|
|
||||||
let browser_session = BrowserSession {
|
let browser_session = BrowserSession {
|
||||||
@ -242,19 +239,21 @@ pub async fn lookup_active_refresh_token(
|
|||||||
last_authentication,
|
last_authentication,
|
||||||
};
|
};
|
||||||
|
|
||||||
let scope = res
|
let scope = res.oauth2_session_scope.parse().map_err(|e| {
|
||||||
.oauth2_session_scope
|
DatabaseInconsistencyError2::on("oauth2_sessions")
|
||||||
.parse()
|
.column("scope")
|
||||||
.map_err(|_e| DatabaseInconsistencyError)?;
|
.row(session_id)
|
||||||
|
.source(e)
|
||||||
|
})?;
|
||||||
|
|
||||||
let session = Session {
|
let session = Session {
|
||||||
id: res.oauth2_session_id.into(),
|
id: session_id,
|
||||||
client,
|
client,
|
||||||
browser_session,
|
browser_session,
|
||||||
scope,
|
scope,
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok((refresh_token, session))
|
Ok(Some((refresh_token, session)))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(
|
#[tracing::instrument(
|
||||||
@ -268,7 +267,7 @@ pub async fn consume_refresh_token(
|
|||||||
executor: impl PgExecutor<'_>,
|
executor: impl PgExecutor<'_>,
|
||||||
clock: &Clock,
|
clock: &Clock,
|
||||||
refresh_token: &RefreshToken,
|
refresh_token: &RefreshToken,
|
||||||
) -> Result<(), anyhow::Error> {
|
) -> Result<(), DatabaseError> {
|
||||||
let consumed_at = clock.now();
|
let consumed_at = clock.now();
|
||||||
let res = sqlx::query!(
|
let res = sqlx::query!(
|
||||||
r#"
|
r#"
|
||||||
@ -280,14 +279,7 @@ pub async fn consume_refresh_token(
|
|||||||
consumed_at,
|
consumed_at,
|
||||||
)
|
)
|
||||||
.execute(executor)
|
.execute(executor)
|
||||||
.await
|
.await?;
|
||||||
.context("failed to update oauth2 refresh token")?;
|
|
||||||
|
|
||||||
if res.rows_affected() == 1 {
|
DatabaseError::ensure_affected_rows(&res, 1)
|
||||||
Ok(())
|
|
||||||
} else {
|
|
||||||
Err(anyhow::anyhow!(
|
|
||||||
"no row were affected when updating refresh token"
|
|
||||||
))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@ -24,6 +24,5 @@ pub use self::{
|
|||||||
provider::{add_provider, get_paginated_providers, get_providers, lookup_provider},
|
provider::{add_provider, get_paginated_providers, get_providers, lookup_provider},
|
||||||
session::{
|
session::{
|
||||||
add_session, complete_session, consume_session, lookup_session, lookup_session_on_link,
|
add_session, complete_session, consume_session, lookup_session, lookup_session_on_link,
|
||||||
SessionLookupError,
|
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
@ -16,24 +16,12 @@ use chrono::{DateTime, Utc};
|
|||||||
use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider};
|
use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider};
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
use sqlx::PgExecutor;
|
use sqlx::PgExecutor;
|
||||||
use thiserror::Error;
|
|
||||||
use ulid::Ulid;
|
use ulid::Ulid;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::{Clock, DatabaseInconsistencyError, GenericLookupError, LookupError};
|
use crate::{
|
||||||
|
Clock, DatabaseError, DatabaseInconsistencyError2, GenericLookupError, LookupResultExt,
|
||||||
#[derive(Debug, Error)]
|
};
|
||||||
#[error("Failed to lookup upstream OAuth 2.0 authorization session")]
|
|
||||||
pub enum SessionLookupError {
|
|
||||||
Driver(#[from] sqlx::Error),
|
|
||||||
Inconcistency(#[from] DatabaseInconsistencyError),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl LookupError for SessionLookupError {
|
|
||||||
fn not_found(&self) -> bool {
|
|
||||||
matches!(self, Self::Driver(sqlx::Error::RowNotFound))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct SessionAndProviderLookup {
|
struct SessionAndProviderLookup {
|
||||||
upstream_oauth_authorization_session_id: Uuid,
|
upstream_oauth_authorization_session_id: Uuid,
|
||||||
@ -64,7 +52,7 @@ struct SessionAndProviderLookup {
|
|||||||
pub async fn lookup_session(
|
pub async fn lookup_session(
|
||||||
executor: impl PgExecutor<'_>,
|
executor: impl PgExecutor<'_>,
|
||||||
id: Ulid,
|
id: Ulid,
|
||||||
) -> Result<(UpstreamOAuthProvider, UpstreamOAuthAuthorizationSession), SessionLookupError> {
|
) -> Result<Option<(UpstreamOAuthProvider, UpstreamOAuthAuthorizationSession)>, DatabaseError> {
|
||||||
let res = sqlx::query_as!(
|
let res = sqlx::query_as!(
|
||||||
SessionAndProviderLookup,
|
SessionAndProviderLookup,
|
||||||
r#"
|
r#"
|
||||||
@ -94,29 +82,41 @@ pub async fn lookup_session(
|
|||||||
Uuid::from(id),
|
Uuid::from(id),
|
||||||
)
|
)
|
||||||
.fetch_one(executor)
|
.fetch_one(executor)
|
||||||
.await?;
|
.await
|
||||||
|
.to_option()?;
|
||||||
|
|
||||||
|
let Some(res) = res else { return Ok(None) };
|
||||||
|
|
||||||
|
let id = res.upstream_oauth_provider_id.into();
|
||||||
let provider = UpstreamOAuthProvider {
|
let provider = UpstreamOAuthProvider {
|
||||||
id: res.upstream_oauth_provider_id.into(),
|
id,
|
||||||
issuer: res
|
issuer: res.provider_issuer,
|
||||||
.provider_issuer
|
scope: res.provider_scope.parse().map_err(|e| {
|
||||||
.parse()
|
DatabaseInconsistencyError2::on("upstream_oauth_providers")
|
||||||
.map_err(|_| DatabaseInconsistencyError)?,
|
.column("scope")
|
||||||
scope: res
|
.row(id)
|
||||||
.provider_scope
|
.source(e)
|
||||||
.parse()
|
})?,
|
||||||
.map_err(|_| DatabaseInconsistencyError)?,
|
|
||||||
client_id: res.provider_client_id,
|
client_id: res.provider_client_id,
|
||||||
encrypted_client_secret: res.provider_encrypted_client_secret,
|
encrypted_client_secret: res.provider_encrypted_client_secret,
|
||||||
token_endpoint_auth_method: res
|
token_endpoint_auth_method: res.provider_token_endpoint_auth_method.parse().map_err(
|
||||||
.provider_token_endpoint_auth_method
|
|e| {
|
||||||
.parse()
|
DatabaseInconsistencyError2::on("upstream_oauth_providers")
|
||||||
.map_err(|_| DatabaseInconsistencyError)?,
|
.column("token_endpoint_auth_method")
|
||||||
|
.row(id)
|
||||||
|
.source(e)
|
||||||
|
},
|
||||||
|
)?,
|
||||||
token_endpoint_signing_alg: res
|
token_endpoint_signing_alg: res
|
||||||
.provider_token_endpoint_signing_alg
|
.provider_token_endpoint_signing_alg
|
||||||
.map(|x| x.parse())
|
.map(|x| x.parse())
|
||||||
.transpose()
|
.transpose()
|
||||||
.map_err(|_| DatabaseInconsistencyError)?,
|
.map_err(|e| {
|
||||||
|
DatabaseInconsistencyError2::on("upstream_oauth_providers")
|
||||||
|
.column("token_endpoint_signing_alg")
|
||||||
|
.row(id)
|
||||||
|
.source(e)
|
||||||
|
})?,
|
||||||
created_at: res.provider_created_at,
|
created_at: res.provider_created_at,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -133,7 +133,7 @@ pub async fn lookup_session(
|
|||||||
consumed_at: res.consumed_at,
|
consumed_at: res.consumed_at,
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok((provider, session))
|
Ok(Some((provider, session)))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add a session to the database
|
/// Add a session to the database
|
||||||
|
Reference in New Issue
Block a user