diff --git a/crates/data-model/src/lib.rs b/crates/data-model/src/lib.rs index d104642e..c5f4539c 100644 --- a/crates/data-model/src/lib.rs +++ b/crates/data-model/src/lib.rs @@ -23,12 +23,18 @@ clippy::type_repetition_in_bounds )] +use thiserror::Error; + pub(crate) mod compat; pub(crate) mod oauth2; pub(crate) mod tokens; pub(crate) mod upstream_oauth2; pub(crate) mod users; +#[derive(Debug, Error)] +#[error("invalid state transition")] +pub struct InvalidTransitionError; + pub use self::{ compat::{ CompatAccessToken, CompatRefreshToken, CompatSession, CompatSsoLogin, CompatSsoLoginState, diff --git a/crates/data-model/src/oauth2/authorization_grant.rs b/crates/data-model/src/oauth2/authorization_grant.rs index a7222cda..10f619c7 100644 --- a/crates/data-model/src/oauth2/authorization_grant.rs +++ b/crates/data-model/src/oauth2/authorization_grant.rs @@ -21,11 +21,11 @@ use oauth2_types::{ requests::ResponseMode, }; use serde::Serialize; -use thiserror::Error; use ulid::Ulid; use url::Url; -use super::{client::Client, session::Session}; +use super::session::Session; +use crate::InvalidTransitionError; #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct Pkce { @@ -53,10 +53,6 @@ pub struct AuthorizationCode { pub pkce: Option, } -#[derive(Debug, Error)] -#[error("invalid state transition")] -pub struct InvalidTransitionError; - #[derive(Debug, Clone, PartialEq, Eq, Serialize, Default)] #[serde(tag = "stage", rename_all = "lowercase")] pub enum AuthorizationGrantStage { @@ -132,7 +128,7 @@ pub struct AuthorizationGrant { #[serde(flatten)] pub stage: AuthorizationGrantStage, pub code: Option, - pub client: Client, + pub client_id: Ulid, pub redirect_uri: Url, pub scope: oauth2_types::scope::Scope, pub state: Option, diff --git a/crates/data-model/src/oauth2/session.rs b/crates/data-model/src/oauth2/session.rs index aec48ac1..bbadd3a7 100644 --- a/crates/data-model/src/oauth2/session.rs +++ b/crates/data-model/src/oauth2/session.rs @@ -17,6 +17,8 @@ use oauth2_types::scope::Scope; use serde::Serialize; use ulid::Ulid; +use crate::InvalidTransitionError; + #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct Session { pub id: Ulid, @@ -25,3 +27,14 @@ pub struct Session { pub scope: Scope, pub finished_at: Option>, } + +impl Session { + pub fn finish(mut self, finished_at: DateTime) -> Result { + if self.finished_at.is_some() { + return Err(InvalidTransitionError); + } + + self.finished_at = Some(finished_at); + Ok(self) + } +} diff --git a/crates/handlers/src/oauth2/authorization/complete.rs b/crates/handlers/src/oauth2/authorization/complete.rs index 01c89ff3..b5cfd6ff 100644 --- a/crates/handlers/src/oauth2/authorization/complete.rs +++ b/crates/handlers/src/oauth2/authorization/complete.rs @@ -29,7 +29,7 @@ use mas_storage::{ oauth2::{ authorization_grant::{fulfill_grant, get_grant_by_id}, consent::fetch_client_consent, - OAuth2SessionRepository, + OAuth2ClientRepository, OAuth2SessionRepository, }, Repository, }; @@ -125,6 +125,7 @@ pub(crate) async fn get( } Err(GrantCompletionError::NotPending) => Err(RouteError::NotPending), Err(GrantCompletionError::Internal(e)) => Err(RouteError::Internal(e)), + Err(e) => Err(RouteError::Internal(e.into())), } } @@ -144,6 +145,9 @@ pub enum GrantCompletionError { #[error("denied by the policy")] PolicyViolation, + + #[error("failed to load client")] + NoSuchClient, } impl_from_error_for_route!(GrantCompletionError: sqlx::Error); @@ -182,8 +186,13 @@ pub(crate) async fn complete( return Err(GrantCompletionError::PolicyViolation); } - let current_consent = - fetch_client_consent(&mut txn, &browser_session.user, &grant.client).await?; + let client = txn + .oauth2_client() + .lookup(grant.client_id) + .await? + .ok_or(GrantCompletionError::NoSuchClient)?; + + let current_consent = fetch_client_consent(&mut txn, &browser_session.user, &client).await?; let lacks_consent = grant .scope diff --git a/crates/handlers/src/oauth2/authorization/mod.rs b/crates/handlers/src/oauth2/authorization/mod.rs index cfcd936e..faf7015d 100644 --- a/crates/handlers/src/oauth2/authorization/mod.rs +++ b/crates/handlers/src/oauth2/authorization/mod.rs @@ -360,7 +360,10 @@ pub(crate) async fn get( Err(GrantCompletionError::Internal(e)) => { return Err(RouteError::Internal(e)) } - Err(e @ GrantCompletionError::NotPending) => { + Err( + e @ (GrantCompletionError::NotPending + | GrantCompletionError::NoSuchClient), + ) => { // This should never happen return Err(RouteError::Internal(Box::new(e))); } @@ -390,7 +393,10 @@ pub(crate) async fn get( Err(GrantCompletionError::Internal(e)) => { return Err(RouteError::Internal(e)) } - Err(e @ GrantCompletionError::NotPending) => { + Err( + e @ (GrantCompletionError::NotPending + | GrantCompletionError::NoSuchClient), + ) => { // This should never happen return Err(RouteError::Internal(Box::new(e))); } diff --git a/crates/handlers/src/oauth2/consent.rs b/crates/handlers/src/oauth2/consent.rs index 45107783..64b72d58 100644 --- a/crates/handlers/src/oauth2/consent.rs +++ b/crates/handlers/src/oauth2/consent.rs @@ -28,9 +28,13 @@ use mas_data_model::AuthorizationGrantStage; use mas_keystore::Encrypter; use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; -use mas_storage::oauth2::{ - authorization_grant::{get_grant_by_id, give_consent_to_grant}, - consent::insert_client_consent, +use mas_storage::{ + oauth2::{ + authorization_grant::{get_grant_by_id, give_consent_to_grant}, + consent::insert_client_consent, + OAuth2ClientRepository, + }, + Repository, }; use mas_templates::{ConsentContext, PolicyViolationContext, TemplateContext, Templates}; use sqlx::PgPool; @@ -55,6 +59,9 @@ pub enum RouteError { #[error("Policy violation")] PolicyViolation, + + #[error("Failed to load client")] + NoSuchClient, } impl_from_error_for_route!(sqlx::Error); @@ -160,6 +167,12 @@ pub(crate) async fn post( return Err(RouteError::PolicyViolation); } + let client = txn + .oauth2_client() + .lookup(grant.client_id) + .await? + .ok_or(RouteError::NoSuchClient)?; + // Do not consent for the "urn:matrix:org.matrix.msc2967.client:device:*" scope let scope_without_device = grant .scope @@ -172,7 +185,7 @@ pub(crate) async fn post( &mut rng, &clock, &session.user, - &grant.client, + &client, &scope_without_device, ) .await?; diff --git a/crates/storage/src/oauth2/authorization_grant.rs b/crates/storage/src/oauth2/authorization_grant.rs index 33bd8b5d..c5d96976 100644 --- a/crates/storage/src/oauth2/authorization_grant.rs +++ b/crates/storage/src/oauth2/authorization_grant.rs @@ -26,8 +26,7 @@ use ulid::Ulid; use url::Url; use uuid::Uuid; -use super::OAuth2ClientRepository; -use crate::{Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Repository}; +use crate::{Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; #[tracing::instrument( skip_all, @@ -116,7 +115,7 @@ pub async fn new_authorization_grant( stage: AuthorizationGrantStage::Pending, code, redirect_uri, - client, + client_id: client.id, scope, state, nonce, @@ -151,35 +150,27 @@ struct GrantLookup { oauth2_session_id: Option, } -impl GrantLookup { - #[allow(clippy::too_many_lines)] - async fn into_authorization_grant( - self, - conn: &mut PgConnection, - ) -> Result { - let id = self.oauth2_authorization_grant_id.into(); - let scope: Scope = self.oauth2_authorization_grant_scope.parse().map_err(|e| { - DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("scope") - .row(id) - .source(e) - })?; +impl TryFrom for AuthorizationGrant { + type Error = DatabaseInconsistencyError; - let client = conn - .oauth2_client() - .lookup(self.oauth2_client_id.into()) - .await? - .ok_or_else(|| { + #[allow(clippy::too_many_lines)] + fn try_from(value: GrantLookup) -> Result { + let id = value.oauth2_authorization_grant_id.into(); + let scope: Scope = value + .oauth2_authorization_grant_scope + .parse() + .map_err(|e| { DatabaseInconsistencyError::on("oauth2_authorization_grants") - .column("client_id") + .column("scope") .row(id) + .source(e) })?; let stage = match ( - self.oauth2_authorization_grant_fulfilled_at, - self.oauth2_authorization_grant_exchanged_at, - self.oauth2_authorization_grant_cancelled_at, - self.oauth2_session_id, + value.oauth2_authorization_grant_fulfilled_at, + value.oauth2_authorization_grant_exchanged_at, + value.oauth2_authorization_grant_cancelled_at, + value.oauth2_session_id, ) { (None, None, None, None) => AuthorizationGrantStage::Pending, (Some(fulfilled_at), None, None, Some(session_id)) => { @@ -202,15 +193,14 @@ impl GrantLookup { return Err( DatabaseInconsistencyError::on("oauth2_authorization_grants") .column("stage") - .row(id) - .into(), + .row(id), ); } }; let pkce = match ( - self.oauth2_authorization_grant_code_challenge, - self.oauth2_authorization_grant_code_challenge_method, + value.oauth2_authorization_grant_code_challenge, + value.oauth2_authorization_grant_code_challenge_method, ) { (Some(challenge), Some(challenge_method)) if challenge_method == "plain" => { Some(Pkce { @@ -227,15 +217,14 @@ impl GrantLookup { return Err( DatabaseInconsistencyError::on("oauth2_authorization_grants") .column("code_challenge_method") - .row(id) - .into(), + .row(id), ); } }; let code: Option = match ( - self.oauth2_authorization_grant_response_type_code, - self.oauth2_authorization_grant_code, + value.oauth2_authorization_grant_response_type_code, + value.oauth2_authorization_grant_code, pkce, ) { (false, None, None) => None, @@ -244,13 +233,12 @@ impl GrantLookup { return Err( DatabaseInconsistencyError::on("oauth2_authorization_grants") .column("authorization_code") - .row(id) - .into(), + .row(id), ); } }; - let redirect_uri = self + let redirect_uri = value .oauth2_authorization_grant_redirect_uri .parse() .map_err(|e| { @@ -260,7 +248,7 @@ impl GrantLookup { .source(e) })?; - let response_mode = self + let response_mode = value .oauth2_authorization_grant_response_mode .parse() .map_err(|e| { @@ -270,7 +258,7 @@ impl GrantLookup { .source(e) })?; - let max_age = self + let max_age = value .oauth2_authorization_grant_max_age .map(u32::try_from) .transpose() @@ -292,17 +280,17 @@ impl GrantLookup { Ok(AuthorizationGrant { id, stage, - client, + client_id: value.oauth2_client_id.into(), code, scope, - state: self.oauth2_authorization_grant_state, - nonce: self.oauth2_authorization_grant_nonce, + state: value.oauth2_authorization_grant_state, + nonce: value.oauth2_authorization_grant_nonce, max_age, response_mode, redirect_uri, - created_at: self.oauth2_authorization_grant_created_at, - response_type_id_token: self.oauth2_authorization_grant_response_type_id_token, - requires_consent: self.oauth2_authorization_grant_requires_consent, + created_at: value.oauth2_authorization_grant_created_at, + response_type_id_token: value.oauth2_authorization_grant_response_type_id_token, + requires_consent: value.oauth2_authorization_grant_requires_consent, }) } } @@ -351,9 +339,7 @@ pub async fn get_grant_by_id( let Some(res) = res else { return Ok(None) }; - let grant = res.into_authorization_grant(&mut *conn).await?; - - Ok(Some(grant)) + Ok(Some(res.try_into()?)) } #[tracing::instrument(skip_all, err)] @@ -396,16 +382,14 @@ pub async fn lookup_grant_by_code( let Some(res) = res else { return Ok(None) }; - let grant = res.into_authorization_grant(&mut *conn).await?; - - Ok(Some(grant)) + Ok(Some(res.try_into()?)) } #[tracing::instrument( skip_all, fields( %grant.id, - client.id = %grant.client.id, + client.id = %grant.client_id, %session.id, user_session.id = %session.user_session_id, ), @@ -446,7 +430,7 @@ pub async fn fulfill_grant( skip_all, fields( %grant.id, - client.id = %grant.client.id, + client.id = %grant.client_id, ), err, )] @@ -476,7 +460,7 @@ pub async fn give_consent_to_grant( skip_all, fields( %grant.id, - client.id = %grant.client.id, + client.id = %grant.client_id, ), err, )] diff --git a/crates/storage/src/oauth2/session.rs b/crates/storage/src/oauth2/session.rs index 7acaf843..3a681a84 100644 --- a/crates/storage/src/oauth2/session.rs +++ b/crates/storage/src/oauth2/session.rs @@ -142,7 +142,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { %user_session.id, user.id = %user_session.user.id, %grant.id, - client.id = %grant.client.id, + client.id = %grant.client_id, session.id, session.scope = %grant.scope, ), @@ -172,7 +172,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { "#, Uuid::from(id), Uuid::from(user_session.id), - Uuid::from(grant.client.id), + Uuid::from(grant.client_id), grant.scope.to_string(), created_at, ) @@ -183,7 +183,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { Ok(Session { id, user_session_id: user_session.id, - client_id: grant.client.id, + client_id: grant.client_id, scope: grant.scope.clone(), finished_at: None, }) @@ -201,11 +201,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { ), err, )] - async fn finish( - &mut self, - clock: &Clock, - mut session: Session, - ) -> Result { + async fn finish(&mut self, clock: &Clock, session: Session) -> Result { let finished_at = clock.now(); let res = sqlx::query!( r#" @@ -222,9 +218,9 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { DatabaseError::ensure_affected_rows(&res, 1)?; - session.finished_at = Some(finished_at); - - Ok(session) + session + .finish(finished_at) + .map_err(DatabaseError::to_invalid_operation) } #[tracing::instrument(