diff --git a/crates/data-model/src/compat.rs b/crates/data-model/src/compat/device.rs similarity index 64% rename from crates/data-model/src/compat.rs rename to crates/data-model/src/compat/device.rs index 07ff9aaa..84bdd067 100644 --- a/crates/data-model/src/compat.rs +++ b/crates/data-model/src/compat/device.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use chrono::{DateTime, Utc}; use oauth2_types::scope::ScopeToken; use rand::{ distributions::{Alphanumeric, DistString}, @@ -20,8 +19,6 @@ use rand::{ }; use serde::Serialize; use thiserror::Error; -use ulid::Ulid; -use url::Url; static DEVICE_ID_LENGTH: usize = 10; @@ -79,50 +76,3 @@ impl TryFrom for Device { Ok(Self { id }) } } - -#[derive(Debug, Clone, PartialEq, Eq, Serialize)] -pub struct CompatSession { - pub id: Ulid, - pub user_id: Ulid, - pub device: Device, - pub created_at: DateTime, - pub finished_at: Option>, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct CompatAccessToken { - pub id: Ulid, - pub token: String, - pub created_at: DateTime, - pub expires_at: Option>, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct CompatRefreshToken { - pub id: Ulid, - pub token: String, - pub created_at: DateTime, -} - -#[derive(Debug, Clone, PartialEq, Eq, Serialize)] -pub enum CompatSsoLoginState { - Pending, - Fulfilled { - fulfilled_at: DateTime, - session: CompatSession, - }, - Exchanged { - fulfilled_at: DateTime, - exchanged_at: DateTime, - session: CompatSession, - }, -} - -#[derive(Debug, Clone, PartialEq, Eq, Serialize)] -pub struct CompatSsoLogin { - pub id: Ulid, - pub redirect_uri: Url, - pub login_token: String, - pub created_at: DateTime, - pub state: CompatSsoLoginState, -} diff --git a/crates/data-model/src/compat/mod.rs b/crates/data-model/src/compat/mod.rs new file mode 100644 index 00000000..f6e19bd8 --- /dev/null +++ b/crates/data-model/src/compat/mod.rs @@ -0,0 +1,41 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use chrono::{DateTime, Utc}; +use ulid::Ulid; + +mod device; +mod session; +mod sso_login; + +pub use self::{ + device::Device, + session::{CompatSession, CompatSessionState}, + sso_login::{CompatSsoLogin, CompatSsoLoginState}, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CompatAccessToken { + pub id: Ulid, + pub token: String, + pub created_at: DateTime, + pub expires_at: Option>, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CompatRefreshToken { + pub id: Ulid, + pub token: String, + pub created_at: DateTime, +} diff --git a/crates/data-model/src/compat/session.rs b/crates/data-model/src/compat/session.rs new file mode 100644 index 00000000..2c4cdf2d --- /dev/null +++ b/crates/data-model/src/compat/session.rs @@ -0,0 +1,79 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use chrono::{DateTime, Utc}; +use serde::Serialize; +use ulid::Ulid; + +use super::Device; +use crate::InvalidTransitionError; + +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize)] +pub enum CompatSessionState { + #[default] + Valid, + Finished { + finished_at: DateTime, + }, +} + +impl CompatSessionState { + /// Returns `true` if the compta session state is [`Valid`]. + /// + /// [`Valid`]: ComptaSessionState::Valid + #[must_use] + pub fn is_valid(&self) -> bool { + matches!(self, Self::Valid) + } + + /// Returns `true` if the compta session state is [`Finished`]. + /// + /// [`Finished`]: ComptaSessionState::Finished + #[must_use] + pub fn is_finished(&self) -> bool { + matches!(self, Self::Finished { .. }) + } + + pub fn finish(self, finished_at: DateTime) -> Result { + match self { + Self::Valid => Ok(Self::Finished { finished_at }), + Self::Finished { .. } => Err(InvalidTransitionError), + } + } + + #[must_use] + pub fn finished_at(&self) -> Option> { + match self { + CompatSessionState::Valid => None, + CompatSessionState::Finished { finished_at } => Some(*finished_at), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +pub struct CompatSession { + pub id: Ulid, + pub state: CompatSessionState, + pub user_id: Ulid, + pub device: Device, + pub created_at: DateTime, +} + +impl std::ops::Deref for CompatSession { + type Target = CompatSessionState; + + fn deref(&self) -> &Self::Target { + &self.state + } +} diff --git a/crates/data-model/src/compat/sso_login.rs b/crates/data-model/src/compat/sso_login.rs new file mode 100644 index 00000000..7e494f82 --- /dev/null +++ b/crates/data-model/src/compat/sso_login.rs @@ -0,0 +1,148 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use chrono::{DateTime, Utc}; +use serde::Serialize; +use ulid::Ulid; +use url::Url; + +use super::CompatSession; +use crate::InvalidTransitionError; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +pub enum CompatSsoLoginState { + Pending, + Fulfilled { + fulfilled_at: DateTime, + session: CompatSession, + }, + Exchanged { + fulfilled_at: DateTime, + exchanged_at: DateTime, + session: CompatSession, + }, +} + +impl CompatSsoLoginState { + /// Returns `true` if the compat sso login state is [`Pending`]. + /// + /// [`Pending`]: CompatSsoLoginState::Pending + #[must_use] + pub fn is_pending(&self) -> bool { + matches!(self, Self::Pending) + } + + /// Returns `true` if the compat sso login state is [`Fulfilled`]. + /// + /// [`Fulfilled`]: CompatSsoLoginState::Fulfilled + #[must_use] + pub fn is_fulfilled(&self) -> bool { + matches!(self, Self::Fulfilled { .. }) + } + + /// Returns `true` if the compat sso login state is [`Exchanged`]. + /// + /// [`Exchanged`]: CompatSsoLoginState::Exchanged + #[must_use] + pub fn is_exchanged(&self) -> bool { + matches!(self, Self::Exchanged { .. }) + } + + #[must_use] + pub fn fulfilled_at(&self) -> Option> { + match self { + Self::Pending => None, + Self::Fulfilled { fulfilled_at, .. } | Self::Exchanged { fulfilled_at, .. } => { + Some(*fulfilled_at) + } + } + } + + #[must_use] + pub fn exchanged_at(&self) -> Option> { + match self { + Self::Pending | Self::Fulfilled { .. } => None, + Self::Exchanged { exchanged_at, .. } => Some(*exchanged_at), + } + } + + #[must_use] + pub fn session(&self) -> Option<&CompatSession> { + match self { + Self::Pending => None, + Self::Fulfilled { session, .. } | Self::Exchanged { session, .. } => Some(session), + } + } + + pub fn fulfill( + self, + fulfilled_at: DateTime, + session: CompatSession, + ) -> Result { + match self { + Self::Pending => Ok(Self::Fulfilled { + fulfilled_at, + session, + }), + Self::Fulfilled { .. } | Self::Exchanged { .. } => Err(InvalidTransitionError), + } + } + + pub fn exchange(self, exchanged_at: DateTime) -> Result { + match self { + Self::Fulfilled { + fulfilled_at, + session, + } => Ok(Self::Exchanged { + fulfilled_at, + exchanged_at, + session, + }), + Self::Pending { .. } | Self::Exchanged { .. } => Err(InvalidTransitionError), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +pub struct CompatSsoLogin { + pub id: Ulid, + pub redirect_uri: Url, + pub login_token: String, + pub created_at: DateTime, + pub state: CompatSsoLoginState, +} + +impl std::ops::Deref for CompatSsoLogin { + type Target = CompatSsoLoginState; + + fn deref(&self) -> &Self::Target { + &self.state + } +} + +impl CompatSsoLogin { + pub fn fulfill( + mut self, + fulfilled_at: DateTime, + session: CompatSession, + ) -> Result { + self.state = self.state.fulfill(fulfilled_at, session)?; + Ok(self) + } + + pub fn exchange(mut self, exchanged_at: DateTime) -> Result { + self.state = self.state.exchange(exchanged_at)?; + Ok(self) + } +} diff --git a/crates/data-model/src/lib.rs b/crates/data-model/src/lib.rs index c5f4539c..879dc641 100644 --- a/crates/data-model/src/lib.rs +++ b/crates/data-model/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright 2021, 2022 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -37,16 +37,17 @@ pub struct InvalidTransitionError; pub use self::{ compat::{ - CompatAccessToken, CompatRefreshToken, CompatSession, CompatSsoLogin, CompatSsoLoginState, - Device, + CompatAccessToken, CompatRefreshToken, CompatSession, CompatSessionState, CompatSsoLogin, + CompatSsoLoginState, Device, }, oauth2::{ AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, - InvalidRedirectUriError, JwksOrJwksUri, Pkce, Session, + InvalidRedirectUriError, JwksOrJwksUri, Pkce, Session, SessionState, }, tokens::{AccessToken, RefreshToken, TokenFormatError, TokenType}, upstream_oauth2::{ - UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider, + UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState, + UpstreamOAuthLink, UpstreamOAuthProvider, }, users::{ Authentication, BrowserSession, Password, User, UserEmail, UserEmailVerification, diff --git a/crates/data-model/src/oauth2/mod.rs b/crates/data-model/src/oauth2/mod.rs index ef512260..bc76b091 100644 --- a/crates/data-model/src/oauth2/mod.rs +++ b/crates/data-model/src/oauth2/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2021, 2022 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -19,5 +19,5 @@ pub(self) mod session; pub use self::{ authorization_grant::{AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Pkce}, client::{Client, InvalidRedirectUriError, JwksOrJwksUri}, - session::Session, + session::{Session, SessionState}, }; diff --git a/crates/data-model/src/oauth2/session.rs b/crates/data-model/src/oauth2/session.rs index bbadd3a7..68ac821c 100644 --- a/crates/data-model/src/oauth2/session.rs +++ b/crates/data-model/src/oauth2/session.rs @@ -1,4 +1,4 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -19,22 +19,69 @@ use ulid::Ulid; use crate::InvalidTransitionError; +trait T { + type State; +} + +impl T for Session { + type State = SessionState; +} + +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize)] +pub enum SessionState { + #[default] + Valid, + Finished { + finished_at: DateTime, + }, +} + +impl SessionState { + /// Returns `true` if the session state is [`Valid`]. + /// + /// [`Valid`]: SessionState::Valid + #[must_use] + pub fn is_valid(&self) -> bool { + matches!(self, Self::Valid) + } + + /// Returns `true` if the session state is [`Finished`]. + /// + /// [`Finished`]: SessionState::Finished + #[must_use] + pub fn is_finished(&self) -> bool { + matches!(self, Self::Finished { .. }) + } + + pub fn finish(self, finished_at: DateTime) -> Result { + match self { + Self::Valid => Ok(Self::Finished { finished_at }), + Self::Finished { .. } => Err(InvalidTransitionError), + } + } +} + #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct Session { pub id: Ulid, + pub state: SessionState, + pub created_at: DateTime, pub user_session_id: Ulid, pub client_id: Ulid, pub scope: Scope, - pub finished_at: Option>, +} + +impl std::ops::Deref for Session { + type Target = SessionState; + + fn deref(&self) -> &Self::Target { + &self.state + } } impl Session { pub fn finish(mut self, finished_at: DateTime) -> Result { - if self.finished_at.is_some() { - return Err(InvalidTransitionError); - } - - self.finished_at = Some(finished_at); + self.state = self.state.finish(finished_at)?; Ok(self) } } diff --git a/crates/data-model/src/upstream_oauth2/link.rs b/crates/data-model/src/upstream_oauth2/link.rs new file mode 100644 index 00000000..c0699173 --- /dev/null +++ b/crates/data-model/src/upstream_oauth2/link.rs @@ -0,0 +1,26 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use chrono::{DateTime, Utc}; +use serde::Serialize; +use ulid::Ulid; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +pub struct UpstreamOAuthLink { + pub id: Ulid, + pub provider_id: Ulid, + pub user_id: Option, + pub subject: String, + pub created_at: DateTime, +} diff --git a/crates/data-model/src/upstream_oauth2/mod.rs b/crates/data-model/src/upstream_oauth2/mod.rs index 08fbf6c0..90780a8b 100644 --- a/crates/data-model/src/upstream_oauth2/mod.rs +++ b/crates/data-model/src/upstream_oauth2/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,55 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -use chrono::{DateTime, Utc}; -use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; -use oauth2_types::scope::Scope; -use serde::Serialize; -use ulid::Ulid; +mod link; +mod provider; +mod session; -#[derive(Debug, Clone, PartialEq, Eq, Serialize)] -pub struct UpstreamOAuthProvider { - pub id: Ulid, - pub issuer: String, - pub scope: Scope, - pub client_id: String, - pub encrypted_client_secret: Option, - pub token_endpoint_signing_alg: Option, - pub token_endpoint_auth_method: OAuthClientAuthenticationMethod, - pub created_at: DateTime, -} - -#[derive(Debug, Clone, PartialEq, Eq, Serialize)] -pub struct UpstreamOAuthLink { - pub id: Ulid, - pub provider_id: Ulid, - pub user_id: Option, - pub subject: String, - pub created_at: DateTime, -} - -#[derive(Debug, Clone, PartialEq, Eq, Serialize)] -pub struct UpstreamOAuthAuthorizationSession { - pub id: Ulid, - pub provider_id: Ulid, - pub link_id: Option, - pub state: String, - pub code_challenge_verifier: Option, - pub nonce: String, - pub created_at: DateTime, - pub completed_at: Option>, - pub consumed_at: Option>, - pub id_token: Option, -} - -impl UpstreamOAuthAuthorizationSession { - #[must_use] - pub const fn completed(&self) -> bool { - self.completed_at.is_some() - } - - #[must_use] - pub const fn consumed(&self) -> bool { - self.consumed_at.is_some() - } -} +pub use self::{ + link::UpstreamOAuthLink, + provider::UpstreamOAuthProvider, + session::{UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState}, +}; diff --git a/crates/data-model/src/upstream_oauth2/provider.rs b/crates/data-model/src/upstream_oauth2/provider.rs new file mode 100644 index 00000000..919b2221 --- /dev/null +++ b/crates/data-model/src/upstream_oauth2/provider.rs @@ -0,0 +1,31 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use chrono::{DateTime, Utc}; +use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; +use oauth2_types::scope::Scope; +use serde::Serialize; +use ulid::Ulid; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +pub struct UpstreamOAuthProvider { + pub id: Ulid, + pub issuer: String, + pub scope: Scope, + pub client_id: String, + pub encrypted_client_secret: Option, + pub token_endpoint_signing_alg: Option, + pub token_endpoint_auth_method: OAuthClientAuthenticationMethod, + pub created_at: DateTime, +} diff --git a/crates/data-model/src/upstream_oauth2/session.rs b/crates/data-model/src/upstream_oauth2/session.rs new file mode 100644 index 00000000..9ce61266 --- /dev/null +++ b/crates/data-model/src/upstream_oauth2/session.rs @@ -0,0 +1,170 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use chrono::{DateTime, Utc}; +use serde::Serialize; +use ulid::Ulid; + +use super::UpstreamOAuthLink; +use crate::InvalidTransitionError; + +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize)] +pub enum UpstreamOAuthAuthorizationSessionState { + #[default] + Pending, + Completed { + completed_at: DateTime, + link_id: Ulid, + id_token: Option, + }, + Consumed { + completed_at: DateTime, + consumed_at: DateTime, + link_id: Ulid, + id_token: Option, + }, +} + +impl UpstreamOAuthAuthorizationSessionState { + pub fn complete( + self, + completed_at: DateTime, + link: &UpstreamOAuthLink, + id_token: Option, + ) -> Result { + match self { + Self::Pending => Ok(Self::Completed { + completed_at, + link_id: link.id, + id_token, + }), + Self::Completed { .. } | Self::Consumed { .. } => Err(InvalidTransitionError), + } + } + + pub fn consume(self, consumed_at: DateTime) -> Result { + match self { + Self::Completed { + completed_at, + link_id, + id_token, + } => Ok(Self::Consumed { + completed_at, + link_id, + consumed_at, + id_token, + }), + Self::Pending | Self::Consumed { .. } => Err(InvalidTransitionError), + } + } + + #[must_use] + pub fn link_id(&self) -> Option { + match self { + Self::Pending => None, + Self::Completed { link_id, .. } | Self::Consumed { link_id, .. } => Some(*link_id), + } + } + + #[must_use] + pub fn completed_at(&self) -> Option> { + match self { + Self::Pending => None, + Self::Completed { completed_at, .. } | Self::Consumed { completed_at, .. } => { + Some(*completed_at) + } + } + } + + #[must_use] + pub fn id_token(&self) -> Option<&str> { + match self { + Self::Pending => None, + Self::Completed { id_token, .. } | Self::Consumed { id_token, .. } => { + id_token.as_deref() + } + } + } + + #[must_use] + pub fn consumed_at(&self) -> Option> { + match self { + Self::Pending | Self::Completed { .. } => None, + Self::Consumed { consumed_at, .. } => Some(*consumed_at), + } + } + + /// Returns `true` if the upstream oauth authorization session state is + /// [`Pending`]. + /// + /// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending + #[must_use] + pub fn is_pending(&self) -> bool { + matches!(self, Self::Pending) + } + + /// Returns `true` if the upstream oauth authorization session state is + /// [`Completed`]. + /// + /// [`Completed`]: UpstreamOAuthAuthorizationSessionState::Completed + #[must_use] + pub fn is_completed(&self) -> bool { + matches!(self, Self::Completed { .. }) + } + + /// Returns `true` if the upstream oauth authorization session state is + /// [`Consumed`]. + /// + /// [`Consumed`]: UpstreamOAuthAuthorizationSessionState::Consumed + #[must_use] + pub fn is_consumed(&self) -> bool { + matches!(self, Self::Consumed { .. }) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +pub struct UpstreamOAuthAuthorizationSession { + pub id: Ulid, + pub state: UpstreamOAuthAuthorizationSessionState, + pub provider_id: Ulid, + pub state_str: String, + pub code_challenge_verifier: Option, + pub nonce: String, + pub created_at: DateTime, +} + +impl std::ops::Deref for UpstreamOAuthAuthorizationSession { + type Target = UpstreamOAuthAuthorizationSessionState; + + fn deref(&self) -> &Self::Target { + &self.state + } +} + +impl UpstreamOAuthAuthorizationSession { + pub fn complete( + mut self, + completed_at: DateTime, + link: &UpstreamOAuthLink, + id_token: Option, + ) -> Result { + self.state = self.state.complete(completed_at, link, id_token)?; + Ok(self) + } + + pub fn consume(mut self, consumed_at: DateTime) -> Result { + self.state = self.state.consume(consumed_at)?; + Ok(self) + } +} diff --git a/crates/graphql/src/model/compat_sessions.rs b/crates/graphql/src/model/compat_sessions.rs index f3610233..38fe866c 100644 --- a/crates/graphql/src/model/compat_sessions.rs +++ b/crates/graphql/src/model/compat_sessions.rs @@ -15,7 +15,6 @@ use anyhow::Context as _; use async_graphql::{Context, Description, Object, ID}; use chrono::{DateTime, Utc}; -use mas_data_model::CompatSsoLoginState; use mas_storage::{user::UserRepository, Repository}; use sqlx::PgPool; use url::Url; @@ -57,7 +56,7 @@ impl CompatSession { /// When the session ended. pub async fn finished_at(&self) -> Option> { - self.0.finished_at + self.0.finished_at() } } @@ -86,29 +85,16 @@ impl CompatSsoLogin { /// When the login was fulfilled, and the user was redirected back to the /// client. async fn fulfilled_at(&self) -> Option> { - match &self.0.state { - CompatSsoLoginState::Pending => None, - CompatSsoLoginState::Fulfilled { fulfilled_at, .. } - | CompatSsoLoginState::Exchanged { fulfilled_at, .. } => Some(*fulfilled_at), - } + self.0.fulfilled_at() } /// When the client exchanged the login token sent during the redirection. async fn exchanged_at(&self) -> Option> { - match &self.0.state { - CompatSsoLoginState::Pending | CompatSsoLoginState::Fulfilled { .. } => None, - CompatSsoLoginState::Exchanged { exchanged_at, .. } => Some(*exchanged_at), - } + self.0.exchanged_at() } /// The compat session which was started by this login. async fn session(&self) -> Option { - match &self.0.state { - CompatSsoLoginState::Pending => None, - CompatSsoLoginState::Fulfilled { session, .. } - | CompatSsoLoginState::Exchanged { session, .. } => { - Some(CompatSession(session.clone())) - } - } + self.0.session().cloned().map(CompatSession) } } diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index 295f7307..8cb9a605 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -153,12 +153,12 @@ pub(crate) async fn get( return Err(RouteError::ProviderMismatch); } - if params.state != session.state { + if params.state != session.state_str { // The state in the session cookie should match the one from the params return Err(RouteError::StateMismatch); } - if session.completed() { + if !session.is_pending() { // The session was already completed return Err(RouteError::AlreadyCompleted); } @@ -207,7 +207,7 @@ pub(crate) async fn get( // TODO: all that should be borrowed let validation_data = AuthorizationValidationData { - state: session.state.clone(), + state: session.state_str.clone(), nonce: session.nonce.clone(), code_challenge_verifier: session.code_challenge_verifier.clone(), redirect_uri, diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index 80fa04f7..10e1f80e 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -121,11 +121,11 @@ pub(crate) async fn get( // This checks that we're in a browser session which is allowed to consume this // link: the upstream auth session should have been started in this browser. - if upstream_session.link_id != Some(link.id) { + if upstream_session.link_id() != Some(link.id) { return Err(RouteError::SessionNotFound); } - if upstream_session.consumed() { + if upstream_session.is_consumed() { return Err(RouteError::SessionConsumed); } @@ -243,11 +243,11 @@ pub(crate) async fn post( // This checks that we're in a browser session which is allowed to consume this // link: the upstream auth session should have been started in this browser. - if upstream_session.link_id != Some(link.id) { + if upstream_session.link_id() != Some(link.id) { return Err(RouteError::SessionNotFound); } - if upstream_session.consumed() { + if upstream_session.is_consumed() { return Err(RouteError::SessionConsumed); } diff --git a/crates/storage/sqlx-data.json b/crates/storage/sqlx-data.json index 740e9a3a..e8a33bd2 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage/sqlx-data.json @@ -657,6 +657,74 @@ }, "query": "\n DELETE FROM oauth2_access_tokens\n WHERE expires_at < $1\n " }, + "5f0e2aec0d7766d3674af3e68417921fec7068e83845e218a4a00d86487557f9": { + "describe": { + "columns": [ + { + "name": "oauth2_access_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "oauth2_access_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "oauth2_access_token_created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_access_token_expires_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_session_created_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_session_id!", + "ordinal": 5, + "type_info": "Uuid" + }, + { + "name": "oauth2_client_id!", + "ordinal": 6, + "type_info": "Uuid" + }, + { + "name": "scope!", + "ordinal": 7, + "type_info": "Text" + }, + { + "name": "user_session_id!", + "ordinal": 8, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + false, + false, + false, + false, + false, + false + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT at.oauth2_access_token_id\n , at.access_token AS \"oauth2_access_token\"\n , at.created_at AS \"oauth2_access_token_created_at\"\n , at.expires_at AS \"oauth2_access_token_expires_at\"\n , os.created_at AS \"oauth2_session_created_at\"\n , os.oauth2_session_id AS \"oauth2_session_id!\"\n , os.oauth2_client_id AS \"oauth2_client_id!\"\n , os.scope AS \"scope!\"\n , os.user_session_id AS \"user_session_id!\"\n\n FROM oauth2_access_tokens at\n INNER JOIN oauth2_sessions os\n USING (oauth2_session_id)\n\n WHERE at.access_token = $1\n AND at.revoked_at IS NULL\n AND os.finished_at IS NULL\n " + }, "5f6b7e38ef9bc3b39deabba277d0255fb8cfb2adaa65f47b78a8fac11d8c91c3": { "describe": { "columns": [], @@ -1397,6 +1465,134 @@ }, "query": "\n UPDATE oauth2_authorization_grants AS og\n SET\n requires_consent = 'f'\n WHERE\n og.oauth2_authorization_grant_id = $1\n " }, + "aa2fd69c595f94d8598715766a79671dba8f87b9d7af6ac30e3fa1fbc8cce28a": { + "describe": { + "columns": [ + { + "name": "oauth2_authorization_grant_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "oauth2_authorization_grant_created_at", + "ordinal": 1, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_authorization_grant_cancelled_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_authorization_grant_fulfilled_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_authorization_grant_exchanged_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_authorization_grant_scope", + "ordinal": 5, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_state", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_redirect_uri", + "ordinal": 7, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_response_mode", + "ordinal": 8, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_nonce", + "ordinal": 9, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_max_age", + "ordinal": 10, + "type_info": "Int4" + }, + { + "name": "oauth2_client_id", + "ordinal": 11, + "type_info": "Uuid" + }, + { + "name": "oauth2_authorization_grant_code", + "ordinal": 12, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_response_type_code", + "ordinal": 13, + "type_info": "Bool" + }, + { + "name": "oauth2_authorization_grant_response_type_id_token", + "ordinal": 14, + "type_info": "Bool" + }, + { + "name": "oauth2_authorization_grant_code_challenge", + "ordinal": 15, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_code_challenge_method", + "ordinal": 16, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_requires_consent", + "ordinal": 17, + "type_info": "Bool" + }, + { + "name": "oauth2_session_id?", + "ordinal": 18, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + true, + true, + true, + false, + true, + false, + false, + true, + true, + false, + true, + false, + false, + true, + true, + false, + true + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT oauth2_authorization_grant_id\n , created_at AS oauth2_authorization_grant_created_at\n , cancelled_at AS oauth2_authorization_grant_cancelled_at\n , fulfilled_at AS oauth2_authorization_grant_fulfilled_at\n , exchanged_at AS oauth2_authorization_grant_exchanged_at\n , scope AS oauth2_authorization_grant_scope\n , state AS oauth2_authorization_grant_state\n , redirect_uri AS oauth2_authorization_grant_redirect_uri\n , response_mode AS oauth2_authorization_grant_response_mode\n , nonce AS oauth2_authorization_grant_nonce\n , max_age AS oauth2_authorization_grant_max_age\n , oauth2_client_id AS oauth2_client_id\n , authorization_code AS oauth2_authorization_grant_code\n , response_type_code AS oauth2_authorization_grant_response_type_code\n , response_type_id_token AS oauth2_authorization_grant_response_type_id_token\n , code_challenge AS oauth2_authorization_grant_code_challenge\n , code_challenge_method AS oauth2_authorization_grant_code_challenge_method\n , requires_consent AS oauth2_authorization_grant_requires_consent\n , oauth2_session_id AS \"oauth2_session_id?\"\n FROM\n oauth2_authorization_grants\n\n WHERE authorization_code = $1\n " + }, "aff08a8caabeb62f4929e6e901e7ca7c55e284c18c5c1d1e78821dd9bc961412": { "describe": { "columns": [ @@ -1442,6 +1638,134 @@ }, "query": "\n SELECT user_email_id\n , user_id\n , email\n , created_at\n , confirmed_at\n FROM user_emails\n\n WHERE user_id = $1 AND email = $2\n " }, + "b12f7ba71ad522261f54ffbb6739a7a06214b4f01e3ed6f7fdaa2033d249f3fb": { + "describe": { + "columns": [ + { + "name": "oauth2_authorization_grant_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "oauth2_authorization_grant_created_at", + "ordinal": 1, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_authorization_grant_cancelled_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_authorization_grant_fulfilled_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_authorization_grant_exchanged_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_authorization_grant_scope", + "ordinal": 5, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_state", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_redirect_uri", + "ordinal": 7, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_response_mode", + "ordinal": 8, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_nonce", + "ordinal": 9, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_max_age", + "ordinal": 10, + "type_info": "Int4" + }, + { + "name": "oauth2_client_id", + "ordinal": 11, + "type_info": "Uuid" + }, + { + "name": "oauth2_authorization_grant_code", + "ordinal": 12, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_response_type_code", + "ordinal": 13, + "type_info": "Bool" + }, + { + "name": "oauth2_authorization_grant_response_type_id_token", + "ordinal": 14, + "type_info": "Bool" + }, + { + "name": "oauth2_authorization_grant_code_challenge", + "ordinal": 15, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_code_challenge_method", + "ordinal": 16, + "type_info": "Text" + }, + { + "name": "oauth2_authorization_grant_requires_consent", + "ordinal": 17, + "type_info": "Bool" + }, + { + "name": "oauth2_session_id?", + "ordinal": 18, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + true, + true, + true, + false, + true, + false, + false, + true, + true, + false, + true, + false, + false, + true, + true, + false, + true + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT oauth2_authorization_grant_id\n , created_at AS oauth2_authorization_grant_created_at\n , cancelled_at AS oauth2_authorization_grant_cancelled_at\n , fulfilled_at AS oauth2_authorization_grant_fulfilled_at\n , exchanged_at AS oauth2_authorization_grant_exchanged_at\n , scope AS oauth2_authorization_grant_scope\n , state AS oauth2_authorization_grant_state\n , redirect_uri AS oauth2_authorization_grant_redirect_uri\n , response_mode AS oauth2_authorization_grant_response_mode\n , nonce AS oauth2_authorization_grant_nonce\n , max_age AS oauth2_authorization_grant_max_age\n , oauth2_client_id AS oauth2_client_id\n , authorization_code AS oauth2_authorization_grant_code\n , response_type_code AS oauth2_authorization_grant_response_type_code\n , response_type_id_token AS oauth2_authorization_grant_response_type_id_token\n , code_challenge AS oauth2_authorization_grant_code_challenge\n , code_challenge_method AS oauth2_authorization_grant_code_challenge_method\n , requires_consent AS oauth2_authorization_grant_requires_consent\n , oauth2_session_id AS \"oauth2_session_id?\"\n FROM\n oauth2_authorization_grants\n\n WHERE oauth2_authorization_grant_id = $1\n " + }, "b26ae7dd28f8a756b55a76e80cdedd7be9ba26435ea4a914421483f8ed832537": { "describe": { "columns": [], @@ -1544,140 +1868,6 @@ }, "query": "\n INSERT INTO user_sessions (user_session_id, user_id, created_at)\n VALUES ($1, $2, $3)\n " }, - "c467144ae98322e3ed6d34df6626d63c15bdfc7137e12097cfb6f9398f7029ca": { - "describe": { - "columns": [ - { - "name": "oauth2_authorization_grant_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "oauth2_authorization_grant_created_at", - "ordinal": 1, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_cancelled_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_fulfilled_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_exchanged_at", - "ordinal": 4, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_scope", - "ordinal": 5, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_state", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_redirect_uri", - "ordinal": 7, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_response_mode", - "ordinal": 8, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_nonce", - "ordinal": 9, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_max_age", - "ordinal": 10, - "type_info": "Int4" - }, - { - "name": "oauth2_client_id", - "ordinal": 11, - "type_info": "Uuid" - }, - { - "name": "oauth2_authorization_grant_code", - "ordinal": 12, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_response_type_code", - "ordinal": 13, - "type_info": "Bool" - }, - { - "name": "oauth2_authorization_grant_response_type_id_token", - "ordinal": 14, - "type_info": "Bool" - }, - { - "name": "oauth2_authorization_grant_code_challenge", - "ordinal": 15, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_code_challenge_method", - "ordinal": 16, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_requires_consent", - "ordinal": 17, - "type_info": "Bool" - }, - { - "name": "oauth2_session_id?", - "ordinal": 18, - "type_info": "Uuid" - }, - { - "name": "user_session_id?", - "ordinal": 19, - "type_info": "Uuid" - } - ], - "nullable": [ - false, - false, - true, - true, - true, - false, - true, - false, - false, - true, - true, - false, - true, - false, - false, - true, - true, - false, - false, - false - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT og.oauth2_authorization_grant_id\n , og.created_at AS oauth2_authorization_grant_created_at\n , og.cancelled_at AS oauth2_authorization_grant_cancelled_at\n , og.fulfilled_at AS oauth2_authorization_grant_fulfilled_at\n , og.exchanged_at AS oauth2_authorization_grant_exchanged_at\n , og.scope AS oauth2_authorization_grant_scope\n , og.state AS oauth2_authorization_grant_state\n , og.redirect_uri AS oauth2_authorization_grant_redirect_uri\n , og.response_mode AS oauth2_authorization_grant_response_mode\n , og.nonce AS oauth2_authorization_grant_nonce\n , og.max_age AS oauth2_authorization_grant_max_age\n , og.oauth2_client_id AS oauth2_client_id\n , og.authorization_code AS oauth2_authorization_grant_code\n , og.response_type_code AS oauth2_authorization_grant_response_type_code\n , og.response_type_id_token AS oauth2_authorization_grant_response_type_id_token\n , og.code_challenge AS oauth2_authorization_grant_code_challenge\n , og.code_challenge_method AS oauth2_authorization_grant_code_challenge_method\n , og.requires_consent AS oauth2_authorization_grant_requires_consent\n , os.oauth2_session_id AS \"oauth2_session_id?\"\n , os.user_session_id AS \"user_session_id?\"\n FROM\n oauth2_authorization_grants og\n LEFT JOIN oauth2_sessions os\n USING (oauth2_session_id)\n\n WHERE og.authorization_code = $1\n " - }, "c88376abdba124ff0487a9a69d2345c7d69d7394f355111ec369cfa6d45fb40f": { "describe": { "columns": [], @@ -1704,86 +1894,6 @@ }, "query": "\n INSERT INTO oauth2_authorization_grants (\n oauth2_authorization_grant_id,\n oauth2_client_id,\n redirect_uri,\n scope,\n state,\n nonce,\n max_age,\n response_mode,\n code_challenge,\n code_challenge_method,\n response_type_code,\n response_type_id_token,\n authorization_code,\n requires_consent,\n created_at\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)\n " }, - "cad4d47709278a9ddbebfc91642967b465bafa596827d9b86a336841b2cfbf0c": { - "describe": { - "columns": [ - { - "name": "oauth2_refresh_token_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "oauth2_refresh_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "oauth2_refresh_token_created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_access_token_id?", - "ordinal": 3, - "type_info": "Uuid" - }, - { - "name": "oauth2_access_token?", - "ordinal": 4, - "type_info": "Text" - }, - { - "name": "oauth2_access_token_created_at?", - "ordinal": 5, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_access_token_expires_at?", - "ordinal": 6, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_session_id!", - "ordinal": 7, - "type_info": "Uuid" - }, - { - "name": "oauth2_client_id!", - "ordinal": 8, - "type_info": "Uuid" - }, - { - "name": "oauth2_session_scope!", - "ordinal": 9, - "type_info": "Text" - }, - { - "name": "user_session_id!", - "ordinal": 10, - "type_info": "Uuid" - } - ], - "nullable": [ - false, - false, - false, - false, - false, - false, - false, - false, - false, - false, - false - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT rt.oauth2_refresh_token_id\n , rt.refresh_token AS oauth2_refresh_token\n , rt.created_at AS oauth2_refresh_token_created_at\n , at.oauth2_access_token_id AS \"oauth2_access_token_id?\"\n , at.access_token AS \"oauth2_access_token?\"\n , at.created_at AS \"oauth2_access_token_created_at?\"\n , at.expires_at AS \"oauth2_access_token_expires_at?\"\n , os.oauth2_session_id AS \"oauth2_session_id!\"\n , os.oauth2_client_id AS \"oauth2_client_id!\"\n , os.scope AS \"oauth2_session_scope!\"\n , os.user_session_id AS \"user_session_id!\"\n FROM oauth2_refresh_tokens rt\n INNER JOIN oauth2_sessions os\n USING (oauth2_session_id)\n LEFT JOIN oauth2_access_tokens at\n USING (oauth2_access_token_id)\n\n WHERE rt.refresh_token = $1\n AND rt.consumed_at IS NULL\n AND rt.revoked_at IS NULL\n AND os.finished_at IS NULL\n " - }, "caf54e4659306a746747aa61906bdb2cb8da51176e90435aa8b9754ebf3e4d60": { "describe": { "columns": [], @@ -1799,140 +1909,6 @@ }, "query": "\n INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at)\n VALUES ($1, $2, $3, $4)\n " }, - "d08b787fc422b6699ffc0a491ecf92fb993db0aca51534b315bcfa4891baca84": { - "describe": { - "columns": [ - { - "name": "oauth2_authorization_grant_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "oauth2_authorization_grant_created_at", - "ordinal": 1, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_cancelled_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_fulfilled_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_exchanged_at", - "ordinal": 4, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_authorization_grant_scope", - "ordinal": 5, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_state", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_redirect_uri", - "ordinal": 7, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_response_mode", - "ordinal": 8, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_nonce", - "ordinal": 9, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_max_age", - "ordinal": 10, - "type_info": "Int4" - }, - { - "name": "oauth2_client_id", - "ordinal": 11, - "type_info": "Uuid" - }, - { - "name": "oauth2_authorization_grant_code", - "ordinal": 12, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_response_type_code", - "ordinal": 13, - "type_info": "Bool" - }, - { - "name": "oauth2_authorization_grant_response_type_id_token", - "ordinal": 14, - "type_info": "Bool" - }, - { - "name": "oauth2_authorization_grant_code_challenge", - "ordinal": 15, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_code_challenge_method", - "ordinal": 16, - "type_info": "Text" - }, - { - "name": "oauth2_authorization_grant_requires_consent", - "ordinal": 17, - "type_info": "Bool" - }, - { - "name": "oauth2_session_id?", - "ordinal": 18, - "type_info": "Uuid" - }, - { - "name": "user_session_id?", - "ordinal": 19, - "type_info": "Uuid" - } - ], - "nullable": [ - false, - false, - true, - true, - true, - false, - true, - false, - false, - true, - true, - false, - true, - false, - false, - true, - true, - false, - false, - false - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT og.oauth2_authorization_grant_id\n , og.created_at AS oauth2_authorization_grant_created_at\n , og.cancelled_at AS oauth2_authorization_grant_cancelled_at\n , og.fulfilled_at AS oauth2_authorization_grant_fulfilled_at\n , og.exchanged_at AS oauth2_authorization_grant_exchanged_at\n , og.scope AS oauth2_authorization_grant_scope\n , og.state AS oauth2_authorization_grant_state\n , og.redirect_uri AS oauth2_authorization_grant_redirect_uri\n , og.response_mode AS oauth2_authorization_grant_response_mode\n , og.nonce AS oauth2_authorization_grant_nonce\n , og.max_age AS oauth2_authorization_grant_max_age\n , og.oauth2_client_id AS oauth2_client_id\n , og.authorization_code AS oauth2_authorization_grant_code\n , og.response_type_code AS oauth2_authorization_grant_response_type_code\n , og.response_type_id_token AS oauth2_authorization_grant_response_type_id_token\n , og.code_challenge AS oauth2_authorization_grant_code_challenge\n , og.code_challenge_method AS oauth2_authorization_grant_code_challenge_method\n , og.requires_consent AS oauth2_authorization_grant_requires_consent\n , os.oauth2_session_id AS \"oauth2_session_id?\"\n , os.user_session_id AS \"user_session_id?\"\n FROM\n oauth2_authorization_grants og\n LEFT JOIN oauth2_sessions os\n USING (oauth2_session_id)\n\n WHERE og.oauth2_authorization_grant_id = $1\n " - }, "d12a513b81b3ef658eae1f0a719933323f28c6ee260b52cafe337dd3d19e865c": { "describe": { "columns": [ @@ -2182,6 +2158,74 @@ }, "query": "\n UPDATE user_sessions\n SET finished_at = $1\n WHERE user_session_id = $2\n " }, + "e25b8071b59075c4be9fac283410ec4acf771fdf06076ef7bbb11bf086c4bc03": { + "describe": { + "columns": [ + { + "name": "oauth2_refresh_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "oauth2_refresh_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "oauth2_refresh_token_created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_access_token_id?", + "ordinal": 3, + "type_info": "Uuid" + }, + { + "name": "oauth2_session_created_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "oauth2_session_id!", + "ordinal": 5, + "type_info": "Uuid" + }, + { + "name": "oauth2_client_id!", + "ordinal": 6, + "type_info": "Uuid" + }, + { + "name": "oauth2_session_scope!", + "ordinal": 7, + "type_info": "Text" + }, + { + "name": "user_session_id!", + "ordinal": 8, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + true, + false, + false, + false, + false, + false + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT rt.oauth2_refresh_token_id\n , rt.refresh_token AS oauth2_refresh_token\n , rt.created_at AS oauth2_refresh_token_created_at\n , rt.oauth2_access_token_id AS \"oauth2_access_token_id?\"\n , os.created_at AS \"oauth2_session_created_at\"\n , os.oauth2_session_id AS \"oauth2_session_id!\"\n , os.oauth2_client_id AS \"oauth2_client_id!\"\n , os.scope AS \"oauth2_session_scope!\"\n , os.user_session_id AS \"user_session_id!\"\n FROM oauth2_refresh_tokens rt\n INNER JOIN oauth2_sessions os\n USING (oauth2_session_id)\n\n WHERE rt.refresh_token = $1\n AND rt.consumed_at IS NULL\n AND rt.revoked_at IS NULL\n AND os.finished_at IS NULL\n " + }, "e6dc63984aced9e19c20e90e9cd75d6f6d7ade64f782697715ac4da077b2e1fc": { "describe": { "columns": [ @@ -2227,6 +2271,56 @@ }, "query": "\n SELECT\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n FROM upstream_oauth_links\n WHERE upstream_oauth_provider_id = $1\n AND subject = $2\n " }, + "f0ace1af3775192a555c4ebb59b81183f359771f9f77e5fad759d38d872541d1": { + "describe": { + "columns": [ + { + "name": "oauth2_session_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "user_session_id", + "ordinal": 1, + "type_info": "Uuid" + }, + { + "name": "oauth2_client_id", + "ordinal": 2, + "type_info": "Uuid" + }, + { + "name": "scope", + "ordinal": 3, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "finished_at", + "ordinal": 5, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + false, + false, + false, + true + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT oauth2_session_id\n , user_session_id\n , oauth2_client_id\n , scope\n , created_at\n , finished_at\n FROM oauth2_sessions\n\n WHERE oauth2_session_id = $1\n " + }, "f5edcd4c306ca8179cdf9d4aab59fbba971b54611c91345849920954dd8089b3": { "describe": { "columns": [], @@ -2253,67 +2347,5 @@ } }, "query": "\n INSERT INTO oauth2_clients\n ( oauth2_client_id\n , encrypted_client_secret\n , grant_type_authorization_code\n , grant_type_refresh_token\n , client_name\n , logo_uri\n , client_uri\n , policy_uri\n , tos_uri\n , jwks_uri\n , jwks\n , id_token_signed_response_alg\n , userinfo_signed_response_alg\n , token_endpoint_auth_method\n , token_endpoint_auth_signing_alg\n , initiate_login_uri\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)\n " - }, - "fba88894ee24cd181f50412571a19ee658f77012d330e7dab43a3c18d549355a": { - "describe": { - "columns": [ - { - "name": "oauth2_access_token_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "oauth2_access_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "oauth2_access_token_created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_access_token_expires_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "oauth2_session_id!", - "ordinal": 4, - "type_info": "Uuid" - }, - { - "name": "oauth2_client_id!", - "ordinal": 5, - "type_info": "Uuid" - }, - { - "name": "scope!", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "user_session_id!", - "ordinal": 7, - "type_info": "Uuid" - } - ], - "nullable": [ - false, - false, - false, - false, - false, - false, - false, - false - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT at.oauth2_access_token_id\n , at.access_token AS \"oauth2_access_token\"\n , at.created_at AS \"oauth2_access_token_created_at\"\n , at.expires_at AS \"oauth2_access_token_expires_at\"\n , os.oauth2_session_id AS \"oauth2_session_id!\"\n , os.oauth2_client_id AS \"oauth2_client_id!\"\n , os.scope AS \"scope!\"\n , os.user_session_id AS \"user_session_id!\"\n\n FROM oauth2_access_tokens at\n INNER JOIN oauth2_sessions os\n USING (oauth2_session_id)\n\n WHERE at.access_token = $1\n AND at.revoked_at IS NULL\n AND os.finished_at IS NULL\n " } } \ No newline at end of file diff --git a/crates/storage/src/compat.rs b/crates/storage/src/compat.rs index ba47990d..c328f9e1 100644 --- a/crates/storage/src/compat.rs +++ b/crates/storage/src/compat.rs @@ -14,8 +14,8 @@ use chrono::{DateTime, Duration, Utc}; use mas_data_model::{ - CompatAccessToken, CompatRefreshToken, CompatSession, CompatSsoLogin, CompatSsoLoginState, - Device, User, + CompatAccessToken, CompatRefreshToken, CompatSession, CompatSessionState, CompatSsoLogin, + CompatSsoLoginState, Device, User, }; use rand::Rng; use sqlx::{Acquire, PgExecutor, Postgres, QueryBuilder}; @@ -93,12 +93,17 @@ pub async fn lookup_active_compat_access_token( .source(e) })?; + let state = match res.compat_session_finished_at { + None => CompatSessionState::Valid, + Some(finished_at) => CompatSessionState::Finished { finished_at }, + }; + let session = CompatSession { id, + state, user_id: res.user_id.into(), device, created_at: res.compat_session_created_at, - finished_at: res.compat_session_finished_at, }; Ok(Some((token, session))) @@ -181,12 +186,17 @@ pub async fn lookup_active_compat_refresh_token( .source(e) })?; + let state = match res.compat_session_finished_at { + None => CompatSessionState::Valid, + Some(finished_at) => CompatSessionState::Finished { finished_at }, + }; + let session = CompatSession { id, + state, user_id: res.user_id.into(), device, created_at: res.compat_session_created_at, - finished_at: res.compat_session_finished_at, }; Ok(Some((refresh_token, access_token, session))) @@ -468,12 +478,18 @@ impl TryFrom for CompatSsoLogin { .row(id) .source(e) })?; + + let state = match finished_at { + None => CompatSessionState::Valid, + Some(finished_at) => CompatSessionState::Finished { finished_at }, + }; + Some(CompatSession { id, + state, user_id: user_id.into(), device, created_at, - finished_at, }) } (None, None, None, None, None) => None, @@ -686,10 +702,10 @@ pub async fn start_compat_session( Ok(CompatSession { id, + state: CompatSessionState::default(), user_id: user.id, device, created_at, - finished_at: None, }) } @@ -709,7 +725,7 @@ pub async fn fullfill_compat_sso_login( mut rng: impl Rng + Send, clock: &Clock, user: &User, - mut compat_sso_login: CompatSsoLogin, + compat_sso_login: CompatSsoLogin, device: Device, ) -> Result { if !matches!(compat_sso_login.state, CompatSsoLoginState::Pending) { @@ -719,8 +735,12 @@ pub async fn fullfill_compat_sso_login( let mut txn = conn.begin().await?; let session = start_compat_session(&mut txn, &mut rng, clock, user, device).await?; + let session_id = session.id; let fulfilled_at = clock.now(); + let compat_sso_login = compat_sso_login + .fulfill(fulfilled_at, session) + .map_err(DatabaseError::to_invalid_operation)?; sqlx::query!( r#" UPDATE compat_sso_logins @@ -731,20 +751,13 @@ pub async fn fullfill_compat_sso_login( compat_sso_login_id = $1 "#, Uuid::from(compat_sso_login.id), - Uuid::from(session.id), + Uuid::from(session_id), fulfilled_at, ) .execute(&mut txn) .instrument(tracing::info_span!("Update compat SSO login")) .await?; - let state = CompatSsoLoginState::Fulfilled { - fulfilled_at, - session, - }; - - compat_sso_login.state = state; - txn.commit().await?; Ok(compat_sso_login) @@ -761,13 +774,13 @@ pub async fn fullfill_compat_sso_login( pub async fn mark_compat_sso_login_as_exchanged( executor: impl PgExecutor<'_>, clock: &Clock, - mut compat_sso_login: CompatSsoLogin, + compat_sso_login: CompatSsoLogin, ) -> Result { - let CompatSsoLoginState::Fulfilled { fulfilled_at, session } = compat_sso_login.state else { - return Err(DatabaseError::invalid_operation()); - }; - let exchanged_at = clock.now(); + let compat_sso_login = compat_sso_login + .exchange(exchanged_at) + .map_err(DatabaseError::to_invalid_operation)?; + sqlx::query!( r#" UPDATE compat_sso_logins @@ -783,11 +796,5 @@ pub async fn mark_compat_sso_login_as_exchanged( .instrument(tracing::info_span!("Update compat SSO login")) .await?; - let state = CompatSsoLoginState::Exchanged { - fulfilled_at, - exchanged_at, - session, - }; - compat_sso_login.state = state; Ok(compat_sso_login) } diff --git a/crates/storage/src/oauth2/access_token.rs b/crates/storage/src/oauth2/access_token.rs index cd4cafbf..58c13b19 100644 --- a/crates/storage/src/oauth2/access_token.rs +++ b/crates/storage/src/oauth2/access_token.rs @@ -13,7 +13,7 @@ // limitations under the License. use chrono::{DateTime, Duration, Utc}; -use mas_data_model::{AccessToken, Session}; +use mas_data_model::{AccessToken, Session, SessionState}; use rand::Rng; use sqlx::{PgConnection, PgExecutor}; use ulid::Ulid; @@ -76,6 +76,7 @@ pub struct OAuth2AccessTokenLookup { oauth2_access_token: String, oauth2_access_token_created_at: DateTime, oauth2_access_token_expires_at: DateTime, + oauth2_session_created_at: DateTime, oauth2_session_id: Uuid, oauth2_client_id: Uuid, scope: String, @@ -94,6 +95,7 @@ pub async fn lookup_active_access_token( , at.access_token AS "oauth2_access_token" , at.created_at AS "oauth2_access_token_created_at" , at.expires_at AS "oauth2_access_token_expires_at" + , os.created_at AS "oauth2_session_created_at" , os.oauth2_session_id AS "oauth2_session_id!" , os.oauth2_client_id AS "oauth2_client_id!" , os.scope AS "scope!" @@ -131,10 +133,11 @@ pub async fn lookup_active_access_token( let session = Session { id: session_id, + state: SessionState::Valid, + created_at: res.oauth2_session_created_at, client_id: res.oauth2_client_id.into(), user_session_id: res.user_session_id.into(), scope, - finished_at: None, }; Ok(Some((access_token, session))) diff --git a/crates/storage/src/oauth2/refresh_token.rs b/crates/storage/src/oauth2/refresh_token.rs index e4c35c71..f49b38e8 100644 --- a/crates/storage/src/oauth2/refresh_token.rs +++ b/crates/storage/src/oauth2/refresh_token.rs @@ -13,7 +13,7 @@ // limitations under the License. use chrono::{DateTime, Utc}; -use mas_data_model::{AccessToken, RefreshToken, Session}; +use mas_data_model::{AccessToken, RefreshToken, Session, SessionState}; use rand::Rng; use sqlx::{PgConnection, PgExecutor}; use ulid::Ulid; @@ -73,6 +73,7 @@ struct OAuth2RefreshTokenLookup { oauth2_refresh_token: String, oauth2_refresh_token_created_at: DateTime, oauth2_access_token_id: Option, + oauth2_session_created_at: DateTime, oauth2_session_id: Uuid, oauth2_client_id: Uuid, oauth2_session_scope: String, @@ -92,6 +93,7 @@ pub async fn lookup_active_refresh_token( , rt.refresh_token AS oauth2_refresh_token , rt.created_at AS oauth2_refresh_token_created_at , rt.oauth2_access_token_id AS "oauth2_access_token_id?" + , os.created_at AS "oauth2_session_created_at" , os.oauth2_session_id AS "oauth2_session_id!" , os.oauth2_client_id AS "oauth2_client_id!" , os.scope AS "oauth2_session_scope!" @@ -127,10 +129,11 @@ pub async fn lookup_active_refresh_token( let session = Session { id: session_id, + state: SessionState::Valid, + created_at: res.oauth2_session_created_at, client_id: res.oauth2_client_id.into(), user_session_id: res.user_session_id.into(), scope, - finished_at: None, }; Ok(Some((refresh_token, session))) diff --git a/crates/storage/src/oauth2/session.rs b/crates/storage/src/oauth2/session.rs index 3a681a84..9df2f61d 100644 --- a/crates/storage/src/oauth2/session.rs +++ b/crates/storage/src/oauth2/session.rs @@ -14,7 +14,7 @@ use async_trait::async_trait; use chrono::{DateTime, Utc}; -use mas_data_model::{AuthorizationGrant, BrowserSession, Session, User}; +use mas_data_model::{AuthorizationGrant, BrowserSession, Session, SessionState, User}; use rand::RngCore; use sqlx::{PgConnection, QueryBuilder}; use ulid::Ulid; @@ -85,12 +85,18 @@ impl TryFrom for Session { .source(e) })?; + let state = match value.finished_at { + None => SessionState::Valid, + Some(finished_at) => SessionState::Finished { finished_at }, + }; + Ok(Session { id, + state, + created_at: value.created_at, client_id: value.oauth2_client_id.into(), user_session_id: value.user_session_id.into(), scope, - finished_at: value.finished_at, }) } } @@ -182,10 +188,11 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { Ok(Session { id, + state: SessionState::Valid, + created_at, user_session_id: user_session.id, client_id: grant.client_id, scope: grant.scope.clone(), - finished_at: None, }) } diff --git a/crates/storage/src/upstream_oauth2/mod.rs b/crates/storage/src/upstream_oauth2/mod.rs index 1abcd1d0..e195056c 100644 --- a/crates/storage/src/upstream_oauth2/mod.rs +++ b/crates/storage/src/upstream_oauth2/mod.rs @@ -85,9 +85,10 @@ mod tests { .await? .expect("session to be found in the database"); assert_eq!(session.provider_id, provider.id); - assert_eq!(session.link_id, None); - assert!(!session.completed()); - assert!(!session.consumed()); + assert_eq!(session.link_id(), None); + assert!(session.is_pending()); + assert!(!session.is_completed()); + assert!(!session.is_consumed()); // Create a link let link = conn @@ -114,15 +115,15 @@ mod tests { .upstream_oauth_session() .complete_with_link(&clock, session, &link, None) .await?; - assert!(session.completed()); - assert!(!session.consumed()); - assert_eq!(session.link_id, Some(link.id)); + assert!(session.is_completed()); + assert!(!session.is_consumed()); + assert_eq!(session.link_id(), Some(link.id)); let session = conn .upstream_oauth_session() .consume(&clock, session) .await?; - assert!(session.consumed()); + assert!(session.is_consumed()); Ok(()) } diff --git a/crates/storage/src/upstream_oauth2/session.rs b/crates/storage/src/upstream_oauth2/session.rs index f13c6ec8..d5da6ef8 100644 --- a/crates/storage/src/upstream_oauth2/session.rs +++ b/crates/storage/src/upstream_oauth2/session.rs @@ -14,13 +14,18 @@ use async_trait::async_trait; use chrono::{DateTime, Utc}; -use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider}; +use mas_data_model::{ + UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState, UpstreamOAuthLink, + UpstreamOAuthProvider, +}; use rand::RngCore; use sqlx::PgConnection; use ulid::Ulid; use uuid::Uuid; -use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt}; +use crate::{ + tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, +}; #[async_trait] pub trait UpstreamOAuthSessionRepository: Send + Sync { @@ -83,6 +88,52 @@ struct SessionLookup { consumed_at: Option>, } +impl TryFrom for UpstreamOAuthAuthorizationSession { + type Error = DatabaseInconsistencyError; + + fn try_from(value: SessionLookup) -> Result { + let id = value.upstream_oauth_authorization_session_id.into(); + let state = match ( + value.upstream_oauth_link_id, + value.id_token, + value.completed_at, + value.consumed_at, + ) { + (None, None, None, None) => UpstreamOAuthAuthorizationSessionState::Pending, + (Some(link_id), id_token, Some(completed_at), None) => { + UpstreamOAuthAuthorizationSessionState::Completed { + completed_at, + link_id: link_id.into(), + id_token, + } + } + (Some(link_id), id_token, Some(completed_at), Some(consumed_at)) => { + UpstreamOAuthAuthorizationSessionState::Consumed { + completed_at, + link_id: link_id.into(), + id_token, + consumed_at, + } + } + _ => { + return Err( + DatabaseInconsistencyError::on("upstream_oauth_authorization_sessions").row(id), + ) + } + }; + + Ok(Self { + id, + provider_id: value.upstream_oauth_provider_id.into(), + state_str: value.state, + nonce: value.nonce, + code_challenge_verifier: value.code_challenge_verifier, + created_at: value.created_at, + state, + }) + } +} + #[async_trait] impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> { type Error = DatabaseError; @@ -126,20 +177,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> let Some(res) = res else { return Ok(None) }; - let session = UpstreamOAuthAuthorizationSession { - id: res.upstream_oauth_authorization_session_id.into(), - provider_id: res.upstream_oauth_provider_id.into(), - link_id: res.upstream_oauth_link_id.map(Ulid::from), - state: res.state, - code_challenge_verifier: res.code_challenge_verifier, - nonce: res.nonce, - id_token: res.id_token, - created_at: res.created_at, - completed_at: res.completed_at, - consumed_at: res.consumed_at, - }; - - Ok(Some(session)) + Ok(Some(res.try_into()?)) } #[tracing::instrument( @@ -159,7 +197,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> rng: &mut (dyn RngCore + Send), clock: &Clock, upstream_oauth_provider: &UpstreamOAuthProvider, - state: String, + state_str: String, code_challenge_verifier: Option, nonce: String, ) -> Result { @@ -186,7 +224,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> "#, Uuid::from(id), Uuid::from(upstream_oauth_provider.id), - &state, + &state_str, code_challenge_verifier.as_deref(), nonce, created_at, @@ -197,15 +235,12 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> Ok(UpstreamOAuthAuthorizationSession { id, + state: UpstreamOAuthAuthorizationSessionState::default(), provider_id: upstream_oauth_provider.id, - link_id: None, - state, + state_str, code_challenge_verifier, nonce, - id_token: None, created_at, - completed_at: None, - consumed_at: None, }) } @@ -222,11 +257,12 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> async fn complete_with_link( &mut self, clock: &Clock, - mut upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, + upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, upstream_oauth_link: &UpstreamOAuthLink, id_token: Option, ) -> Result { let completed_at = clock.now(); + sqlx::query!( r#" UPDATE upstream_oauth_authorization_sessions @@ -244,9 +280,9 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> .execute(&mut *self.conn) .await?; - upstream_oauth_authorization_session.completed_at = Some(completed_at); - upstream_oauth_authorization_session.id_token = id_token; - upstream_oauth_authorization_session.link_id = Some(upstream_oauth_link.id); + let upstream_oauth_authorization_session = upstream_oauth_authorization_session + .complete(completed_at, upstream_oauth_link, id_token) + .map_err(DatabaseError::to_invalid_operation)?; Ok(upstream_oauth_authorization_session) } @@ -264,7 +300,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> async fn consume( &mut self, clock: &Clock, - mut upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, + upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, ) -> Result { let consumed_at = clock.now(); sqlx::query!( @@ -280,7 +316,9 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> .execute(&mut *self.conn) .await?; - upstream_oauth_authorization_session.consumed_at = Some(consumed_at); + let upstream_oauth_authorization_session = upstream_oauth_authorization_session + .consume(consumed_at) + .map_err(DatabaseError::to_invalid_operation)?; Ok(upstream_oauth_authorization_session) }