diff --git a/crates/data-model/src/oauth2/device_code_grant.rs b/crates/data-model/src/oauth2/device_code_grant.rs index a22b095f..203d7387 100644 --- a/crates/data-model/src/oauth2/device_code_grant.rs +++ b/crates/data-model/src/oauth2/device_code_grant.rs @@ -17,7 +17,7 @@ use oauth2_types::scope::Scope; use serde::Serialize; use ulid::Ulid; -use crate::{BrowserSession, InvalidTransitionError}; +use crate::{BrowserSession, InvalidTransitionError, Session}; #[derive(Debug, Clone, PartialEq, Eq, Serialize)] #[serde(rename_all = "snake_case", tag = "state")] @@ -117,7 +117,7 @@ impl DeviceCodeGrantState { /// [`Fulfilled`]: DeviceCodeGrantState::Fulfilled pub fn exchange( self, - session_id: Ulid, + session: &Session, exchanged_at: DateTime, ) -> Result { match self { @@ -129,7 +129,7 @@ impl DeviceCodeGrantState { browser_session_id, fulfilled_at, exchanged_at, - session_id, + session_id: session.id, }), _ => Err(InvalidTransitionError), } @@ -251,11 +251,11 @@ impl DeviceCodeGrant { /// [`Fulfilled`]: DeviceCodeGrantState::Fulfilled pub fn exchange( self, - session_id: Ulid, + session: &Session, exchanged_at: DateTime, ) -> Result { Ok(Self { - state: self.state.exchange(session_id, exchanged_at)?, + state: self.state.exchange(session, exchanged_at)?, ..self }) } diff --git a/crates/storage-pg/migrations/20231207090532_oauth_device_code_grant.sql b/crates/storage-pg/migrations/20231207090532_oauth_device_code_grant.sql new file mode 100644 index 00000000..7f471f7b --- /dev/null +++ b/crates/storage-pg/migrations/20231207090532_oauth_device_code_grant.sql @@ -0,0 +1,76 @@ +-- Copyright 2023 The Matrix.org Foundation C.I.C. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +--- Adds a table to store device codes for OAuth 2.0 device code flows +-- +-- +-- This has 4 possible states, only going in one direction: +-- +-- [[ Pending ]] +-- | | +-- | [ Rejected ] -- The `rejected_at` and `user_session_id` fields are set +-- | +-- [ Fulfilled ] -- The `fulfilled_at` and `user_session_id` fields are set +-- | +-- [ Exchanged ] -- The `exchanged_at` and `oauth2_session_id` fields are also set +-- +CREATE TABLE "oauth2_device_code_grant" ( + "oauth2_device_code_grant_id" UUID NOT NULL + PRIMARY KEY, + + -- The client who initiated the device code grant + "oauth2_client_id" UUID NOT NULL + REFERENCES "oauth2_clients" ("oauth2_client_id") + ON DELETE CASCADE, + + -- The scope requested + "scope" TEXT NOT NULL, + + -- The random code that is displayed to the user + "user_code" TEXT NOT NULL + UNIQUE, + + -- The random code that the client uses to poll for the access token + "device_code" TEXT NOT NULL + UNIQUE, + + -- Timestamp when the device code was created + "created_at" TIMESTAMP WITH TIME ZONE NOT NULL, + + -- Timestamp when the device code expires + "expires_at" TIMESTAMP WITH TIME ZONE NOT NULL, + + -- When the device code was fulfilled, i.e. the user has granted access + -- This is mutually exclusive with rejected_at + "fulfilled_at" TIMESTAMP WITH TIME ZONE, + + -- When the device code was rejected, i.e. the user has denied access + -- This is mutually exclusive with fulfilled_at + "rejected_at" TIMESTAMP WITH TIME ZONE, + + -- When the device code was exchanged + -- This means "fulfilled_at" has also been set + "exchanged_at" TIMESTAMP WITH TIME ZONE, + + -- The OAuth 2.0 session generated for this device code + -- This means "exchanged_at" has also been set + "oauth2_session_id" UUID + REFERENCES "oauth2_sessions" ("oauth2_session_id") + ON DELETE CASCADE, + + -- The browser session ID that the user used to authenticate + -- This means "fulfilled_at" or "rejected_at" has also been set + "user_session_id" UUID + REFERENCES "user_sessions" ("user_session_id") +); diff --git a/crates/storage-pg/src/oauth2/authorization_grant.rs b/crates/storage-pg/src/oauth2/authorization_grant.rs index 6fe56010..ce4d25cf 100644 --- a/crates/storage-pg/src/oauth2/authorization_grant.rs +++ b/crates/storage-pg/src/oauth2/authorization_grant.rs @@ -281,6 +281,7 @@ impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantReposi requires_consent, created_at, ) + .traced() .execute(&mut *self.conn) .await?; @@ -340,6 +341,7 @@ impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantReposi "#, Uuid::from(id), ) + .traced() .fetch_optional(&mut *self.conn) .await?; @@ -427,6 +429,7 @@ impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantReposi fulfilled_at, Uuid::from(session.id), ) + .traced() .execute(&mut *self.conn) .await?; @@ -465,6 +468,7 @@ impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantReposi Uuid::from(grant.id), exchanged_at, ) + .traced() .execute(&mut *self.conn) .await?; @@ -501,6 +505,7 @@ impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantReposi "#, Uuid::from(grant.id), ) + .traced() .execute(&mut *self.conn) .await?; diff --git a/crates/storage-pg/src/oauth2/device_code_grant.rs b/crates/storage-pg/src/oauth2/device_code_grant.rs new file mode 100644 index 00000000..0b5094b1 --- /dev/null +++ b/crates/storage-pg/src/oauth2/device_code_grant.rs @@ -0,0 +1,463 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{BrowserSession, DeviceCodeGrant, DeviceCodeGrantState, Session}; +use mas_storage::{ + oauth2::{OAuth2DeviceCodeGrantParams, OAuth2DeviceCodeGrantRepository}, + Clock, +}; +use oauth2_types::scope::Scope; +use rand::RngCore; +use sqlx::PgConnection; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{errors::DatabaseInconsistencyError, DatabaseError, ExecuteExt}; + +/// An implementation of [`OAuth2DeviceCodeGrantRepository`] for a PostgreSQL +/// connection +pub struct PgOAuth2DeviceCodeGrantRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgOAuth2DeviceCodeGrantRepository<'c> { + /// Create a new [`PgOAuth2DeviceCodeGrantRepository`] from an active + /// PostgreSQL connection + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +struct OAuth2DeviceGrantLookup { + oauth2_device_code_grant_id: Uuid, + oauth2_client_id: Uuid, + scope: String, + device_code: String, + user_code: String, + created_at: DateTime, + expires_at: DateTime, + fulfilled_at: Option>, + rejected_at: Option>, + exchanged_at: Option>, + user_session_id: Option, + oauth2_session_id: Option, +} + +impl TryFrom for DeviceCodeGrant { + type Error = DatabaseInconsistencyError; + + fn try_from( + OAuth2DeviceGrantLookup { + oauth2_device_code_grant_id, + oauth2_client_id, + scope, + device_code, + user_code, + created_at, + expires_at, + fulfilled_at, + rejected_at, + exchanged_at, + user_session_id, + oauth2_session_id, + }: OAuth2DeviceGrantLookup, + ) -> Result { + let id = Ulid::from(oauth2_device_code_grant_id); + let client_id = Ulid::from(oauth2_client_id); + + let scope: Scope = scope.parse().map_err(|e| { + DatabaseInconsistencyError::on("oauth2_authorization_grants") + .column("scope") + .row(id) + .source(e) + })?; + + let state = match ( + fulfilled_at, + rejected_at, + exchanged_at, + user_session_id, + oauth2_session_id, + ) { + (None, None, None, None, None) => DeviceCodeGrantState::Pending, + + (Some(fulfilled_at), None, None, Some(user_session_id), None) => { + DeviceCodeGrantState::Fulfilled { + browser_session_id: Ulid::from(user_session_id), + fulfilled_at, + } + } + + (None, Some(rejected_at), None, Some(user_session_id), None) => { + DeviceCodeGrantState::Rejected { + browser_session_id: Ulid::from(user_session_id), + rejected_at, + } + } + + ( + Some(fulfilled_at), + None, + Some(exchanged_at), + Some(user_session_id), + Some(oauth2_session_id), + ) => DeviceCodeGrantState::Exchanged { + browser_session_id: Ulid::from(user_session_id), + session_id: Ulid::from(oauth2_session_id), + fulfilled_at, + exchanged_at, + }, + + _ => return Err(DatabaseInconsistencyError::on("oauth2_device_code_grant").row(id)), + }; + + Ok(DeviceCodeGrant { + id, + state, + client_id, + scope, + user_code, + device_code, + created_at, + expires_at, + }) + } +} + +#[async_trait] +impl<'c> OAuth2DeviceCodeGrantRepository for PgOAuth2DeviceCodeGrantRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.oauth2_device_code_grant.add", + skip_all, + fields( + db.statement, + oauth2_device_code.id, + oauth2_device_code.scope = %params.scope, + oauth2_client.id = %params.client.id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + params: OAuth2DeviceCodeGrantParams<'_>, + ) -> Result { + let now = clock.now(); + let id = Ulid::from_datetime_with_source(now.into(), rng); + tracing::Span::current().record("oauth2_device_code.id", tracing::field::display(id)); + + let created_at = now; + let expires_at = now + params.expires_in; + let client_id = params.client.id; + + sqlx::query!( + r#" + INSERT INTO "oauth2_device_code_grant" + ( oauth2_device_code_grant_id + , oauth2_client_id + , scope + , device_code + , user_code + , created_at + , expires_at + ) + VALUES + ($1, $2, $3, $4, $5, $6, $7) + "#, + Uuid::from(id), + Uuid::from(client_id), + params.scope.to_string(), + ¶ms.device_code, + ¶ms.user_code, + created_at, + expires_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(DeviceCodeGrant { + id, + state: DeviceCodeGrantState::Pending, + client_id, + scope: params.scope, + user_code: params.user_code, + device_code: params.device_code, + created_at, + expires_at, + }) + } + + #[tracing::instrument( + name = "db.oauth2_device_code_grant.lookup", + skip_all, + fields( + db.statement, + oauth2_device_code.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + OAuth2DeviceGrantLookup, + r#" + SELECT oauth2_device_code_grant_id + , oauth2_client_id + , scope + , device_code + , user_code + , created_at + , expires_at + , fulfilled_at + , rejected_at + , exchanged_at + , user_session_id + , oauth2_session_id + FROM + oauth2_device_code_grant + + WHERE oauth2_device_code_grant_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_optional(&mut *self.conn) + .await?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.oauth2_device_code_grant.find_by_user_code", + skip_all, + fields( + db.statement, + oauth2_device_code.user_code = %user_code, + ), + err, + )] + async fn find_by_user_code( + &mut self, + user_code: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + OAuth2DeviceGrantLookup, + r#" + SELECT oauth2_device_code_grant_id + , oauth2_client_id + , scope + , device_code + , user_code + , created_at + , expires_at + , fulfilled_at + , rejected_at + , exchanged_at + , user_session_id + , oauth2_session_id + FROM + oauth2_device_code_grant + + WHERE user_code = $1 + "#, + user_code, + ) + .traced() + .fetch_optional(&mut *self.conn) + .await?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.oauth2_device_code_grant.find_by_device_code", + skip_all, + fields( + db.statement, + oauth2_device_code.device_code = %device_code, + ), + err, + )] + async fn find_by_device_code( + &mut self, + device_code: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + OAuth2DeviceGrantLookup, + r#" + SELECT oauth2_device_code_grant_id + , oauth2_client_id + , scope + , device_code + , user_code + , created_at + , expires_at + , fulfilled_at + , rejected_at + , exchanged_at + , user_session_id + , oauth2_session_id + FROM + oauth2_device_code_grant + + WHERE device_code = $1 + "#, + device_code, + ) + .traced() + .fetch_optional(&mut *self.conn) + .await?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.oauth2_device_code_grant.fulfill", + skip_all, + fields( + db.statement, + oauth2_device_code.id = %device_code_grant.id, + oauth2_client.id = %device_code_grant.client_id, + browser_session.id = %browser_session.id, + user.id = %browser_session.user.id, + ), + err, + )] + async fn fulfill( + &mut self, + clock: &dyn Clock, + device_code_grant: DeviceCodeGrant, + browser_session: &BrowserSession, + ) -> Result { + let fulfilled_at = clock.now(); + let device_code_grant = device_code_grant + .fulfill(&browser_session, fulfilled_at) + .map_err(DatabaseError::to_invalid_operation)?; + + let res = sqlx::query!( + r#" + UPDATE oauth2_device_code_grant + SET fulfilled_at = $1 + , user_session_id = $2 + WHERE oauth2_device_code_grant_id = $3 + "#, + fulfilled_at, + Uuid::from(browser_session.id), + Uuid::from(device_code_grant.id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + Ok(device_code_grant) + } + + #[tracing::instrument( + name = "db.oauth2_device_code_grant.reject", + skip_all, + fields( + db.statement, + oauth2_device_code.id = %device_code_grant.id, + oauth2_client.id = %device_code_grant.client_id, + browser_session.id = %browser_session.id, + user.id = %browser_session.user.id, + ), + err, + )] + async fn reject( + &mut self, + clock: &dyn Clock, + device_code_grant: DeviceCodeGrant, + browser_session: &BrowserSession, + ) -> Result { + let fulfilled_at = clock.now(); + let device_code_grant = device_code_grant + .reject(&browser_session, fulfilled_at) + .map_err(DatabaseError::to_invalid_operation)?; + + let res = sqlx::query!( + r#" + UPDATE oauth2_device_code_grant + SET rejected_at = $1 + , user_session_id = $2 + WHERE oauth2_device_code_grant_id = $3 + "#, + fulfilled_at, + Uuid::from(browser_session.id), + Uuid::from(device_code_grant.id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + Ok(device_code_grant) + } + + #[tracing::instrument( + name = "db.oauth2_device_code_grant.exchange", + skip_all, + fields( + db.statement, + oauth2_device_code.id = %device_code_grant.id, + oauth2_client.id = %device_code_grant.client_id, + oauth2_session.id = %session.id, + ), + err, + )] + async fn exchange( + &mut self, + clock: &dyn Clock, + device_code_grant: DeviceCodeGrant, + session: &Session, + ) -> Result { + let exchanged_at = clock.now(); + let device_code_grant = device_code_grant + .exchange(session, exchanged_at) + .map_err(DatabaseError::to_invalid_operation)?; + + let res = sqlx::query!( + r#" + UPDATE oauth2_device_code_grant + SET exchanged_at = $1 + , oauth2_session_id = $2 + WHERE oauth2_device_code_grant_id = $3 + "#, + exchanged_at, + Uuid::from(session.id), + Uuid::from(device_code_grant.id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + Ok(device_code_grant) + } +} diff --git a/crates/storage-pg/src/oauth2/mod.rs b/crates/storage-pg/src/oauth2/mod.rs index 4e90a6eb..4e0a5d40 100644 --- a/crates/storage-pg/src/oauth2/mod.rs +++ b/crates/storage-pg/src/oauth2/mod.rs @@ -18,12 +18,14 @@ mod access_token; mod authorization_grant; mod client; +mod device_code_grant; mod refresh_token; mod session; pub use self::{ access_token::PgOAuth2AccessTokenRepository, authorization_grant::PgOAuth2AuthorizationGrantRepository, client::PgOAuth2ClientRepository, + device_code_grant::PgOAuth2DeviceCodeGrantRepository, refresh_token::PgOAuth2RefreshTokenRepository, session::PgOAuth2SessionRepository, }; @@ -33,7 +35,7 @@ mod tests { use mas_data_model::AuthorizationCode; use mas_storage::{ clock::MockClock, - oauth2::{OAuth2SessionFilter, OAuth2SessionRepository}, + oauth2::{OAuth2DeviceCodeGrantParams, OAuth2SessionFilter, OAuth2SessionRepository}, Clock, Pagination, Repository, }; use oauth2_types::{ @@ -690,4 +692,226 @@ mod tests { assert_eq!(list.edges[0], session11); assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1); } + + /// Test the [`OAuth2DeviceCodeGrantRepository`] implementation + #[sqlx::test(migrator = "crate::MIGRATOR")] + async fn test_device_code_grant_repository(pool: PgPool) { + let mut rng = ChaChaRng::seed_from_u64(42); + let clock = MockClock::default(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed(); + + // Provision a client + let client = repo + .oauth2_client() + .add( + &mut rng, + &clock, + vec!["https://example.com/redirect".parse().unwrap()], + None, + None, + vec![GrantType::AuthorizationCode], + Vec::new(), // TODO: contacts are not yet saved + // vec!["contact@example.com".to_owned()], + Some("Example".to_owned()), + Some("https://example.com/logo.png".parse().unwrap()), + Some("https://example.com/".parse().unwrap()), + Some("https://example.com/policy".parse().unwrap()), + Some("https://example.com/tos".parse().unwrap()), + Some("https://example.com/jwks.json".parse().unwrap()), + None, + None, + None, + None, + None, + Some("https://example.com/login".parse().unwrap()), + ) + .await + .unwrap(); + + // Provision a user + let user = repo + .user() + .add(&mut rng, &clock, "john".to_owned()) + .await + .unwrap(); + + // Provision a browser session + let browser_session = repo + .browser_session() + .add(&mut rng, &clock, &user, None) + .await + .unwrap(); + + let user_code = "usercode"; + let device_code = "devicecode"; + let scope = Scope::from_iter([OPENID, EMAIL]); + + // Create a device code grant + let grant = repo + .oauth2_device_code_grant() + .add( + &mut rng, + &clock, + OAuth2DeviceCodeGrantParams { + client: &client, + scope: scope.clone(), + device_code: device_code.to_owned(), + user_code: user_code.to_owned(), + expires_in: Duration::minutes(5), + }, + ) + .await + .unwrap(); + + assert!(grant.is_pending()); + + // Check that we can find the grant by ID + let id = grant.id; + let lookup = repo.oauth2_device_code_grant().lookup(id).await.unwrap(); + assert_eq!(lookup.as_ref(), Some(&grant)); + + // Check that we can find the grant by device code + let lookup = repo + .oauth2_device_code_grant() + .find_by_device_code(device_code) + .await + .unwrap(); + assert_eq!(lookup.as_ref(), Some(&grant)); + + // Check that we can find the grant by user code + let lookup = repo + .oauth2_device_code_grant() + .find_by_user_code(user_code) + .await + .unwrap(); + assert_eq!(lookup.as_ref(), Some(&grant)); + + // Let's mark it as fulfilled + let grant = repo + .oauth2_device_code_grant() + .fulfill(&clock, grant, &browser_session) + .await + .unwrap(); + assert!(!grant.is_pending()); + assert!(grant.is_fulfilled()); + + // Check that we can't mark it as rejected now + let res = repo + .oauth2_device_code_grant() + .reject(&clock, grant, &browser_session) + .await; + assert!(res.is_err()); + + // Look it up again + let grant = repo + .oauth2_device_code_grant() + .lookup(id) + .await + .unwrap() + .unwrap(); + + // We can't mark it as fulfilled again + let res = repo + .oauth2_device_code_grant() + .fulfill(&clock, grant, &browser_session) + .await; + assert!(res.is_err()); + + // Look it up again + let grant = repo + .oauth2_device_code_grant() + .lookup(id) + .await + .unwrap() + .unwrap(); + + // Create an OAuth 2.0 session + let session = repo + .oauth2_session() + .add_from_browser_session(&mut rng, &clock, &client, &browser_session, scope.clone()) + .await + .unwrap(); + + // We can mark it as exchanged + let grant = repo + .oauth2_device_code_grant() + .exchange(&clock, grant, &session) + .await + .unwrap(); + assert!(!grant.is_pending()); + assert!(!grant.is_fulfilled()); + assert!(grant.is_exchanged()); + + // We can't mark it as exchanged again + let res = repo + .oauth2_device_code_grant() + .exchange(&clock, grant, &session) + .await; + assert!(res.is_err()); + + // Do a new grant to reject it + let grant = repo + .oauth2_device_code_grant() + .add( + &mut rng, + &clock, + OAuth2DeviceCodeGrantParams { + client: &client, + scope: scope.clone(), + device_code: "second_devicecode".to_owned(), + user_code: "second_usercode".to_owned(), + expires_in: Duration::minutes(5), + }, + ) + .await + .unwrap(); + + let id = grant.id; + + // We can mark it as rejected + let grant = repo + .oauth2_device_code_grant() + .reject(&clock, grant, &browser_session) + .await + .unwrap(); + assert!(!grant.is_pending()); + assert!(grant.is_rejected()); + + // We can't mark it as rejected again + let res = repo + .oauth2_device_code_grant() + .reject(&clock, grant, &browser_session) + .await; + assert!(res.is_err()); + + // Look it up again + let grant = repo + .oauth2_device_code_grant() + .lookup(id) + .await + .unwrap() + .unwrap(); + + // We can't mark it as fulfilled + let res = repo + .oauth2_device_code_grant() + .fulfill(&clock, grant, &browser_session) + .await; + assert!(res.is_err()); + + // Look it up again + let grant = repo + .oauth2_device_code_grant() + .lookup(id) + .await + .unwrap() + .unwrap(); + + // We can't mark it as exchanged + let res = repo + .oauth2_device_code_grant() + .exchange(&clock, grant, &session) + .await; + assert!(res.is_err()); + } } diff --git a/crates/storage-pg/src/repository.rs b/crates/storage-pg/src/repository.rs index 5b4926d5..0ac94e10 100644 --- a/crates/storage-pg/src/repository.rs +++ b/crates/storage-pg/src/repository.rs @@ -24,7 +24,7 @@ use mas_storage::{ job::JobRepository, oauth2::{ OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, - OAuth2RefreshTokenRepository, OAuth2SessionRepository, + OAuth2DeviceCodeGrantRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository, }, upstream_oauth2::{ UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, @@ -45,7 +45,8 @@ use crate::{ job::PgJobRepository, oauth2::{ PgOAuth2AccessTokenRepository, PgOAuth2AuthorizationGrantRepository, - PgOAuth2ClientRepository, PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository, + PgOAuth2ClientRepository, PgOAuth2DeviceCodeGrantRepository, + PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository, }, upstream_oauth2::{ PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository, @@ -220,6 +221,12 @@ where Box::new(PgOAuth2RefreshTokenRepository::new(self.conn.as_mut())) } + fn oauth2_device_code_grant<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgOAuth2DeviceCodeGrantRepository::new(self.conn.as_mut())) + } + fn compat_session<'c>( &'c mut self, ) -> Box + 'c> { diff --git a/crates/storage/src/oauth2/device_code_grant.rs b/crates/storage/src/oauth2/device_code_grant.rs new file mode 100644 index 00000000..9b432ef2 --- /dev/null +++ b/crates/storage/src/oauth2/device_code_grant.rs @@ -0,0 +1,228 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::Duration; +use mas_data_model::{BrowserSession, Client, DeviceCodeGrant, Session}; +use oauth2_types::scope::Scope; +use rand_core::RngCore; +use ulid::Ulid; + +use crate::{repository_impl, Clock}; + +/// Parameters used to create a new [`DeviceCodeGrant`] +pub struct OAuth2DeviceCodeGrantParams<'a> { + /// The client which requested the device code grant + pub client: &'a Client, + + /// The scope requested by the client + pub scope: Scope, + + /// The device code which the client uses to poll for authorisation + pub device_code: String, + + /// The user code which the client uses to display to the user + pub user_code: String, + + /// After how long the device code expires + pub expires_in: Duration, +} + +/// An [`OAuth2DeviceCodeGrantRepository`] helps interacting with +/// [`DeviceCodeGrant`] saved in the storage backend. +#[async_trait] +pub trait OAuth2DeviceCodeGrantRepository: Send + Sync { + /// The error type returned by the repository + type Error; + + /// Create a new device code grant + /// + /// Returns the newly created device code grant + /// + /// # Parameters + /// + /// * `rng`: A random number generator + /// * `clock`: The clock used to generate timestamps + /// * `params`: The parameters used to create the device code grant. See the + /// fields of [`DeviceCodeGrantParams`] + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + params: OAuth2DeviceCodeGrantParams<'_>, + ) -> Result; + + /// Lookup a device code grant by its ID + /// + /// Returns the device code grant if found, [`None`] otherwise + /// + /// # Parameters + /// + /// * `id`: The ID of the device code grant + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + /// Lookup a device code grant by its device code + /// + /// Returns the device code grant if found, [`None`] otherwise + /// + /// # Parameters + /// + /// * `device_code`: The device code of the device code grant + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn find_by_device_code( + &mut self, + device_code: &str, + ) -> Result, Self::Error>; + + /// Lookup a device code grant by its user code + /// + /// Returns the device code grant if found, [`None`] otherwise + /// + /// # Parameters + /// + /// * `user_code`: The user code of the device code grant + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn find_by_user_code( + &mut self, + user_code: &str, + ) -> Result, Self::Error>; + + /// Mark the device code grant as fulfilled with the given browser session + /// + /// Returns the updated device code grant + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate timestamps + /// * `device_code_grant`: The device code grant to fulfill + /// * `browser_session`: The browser session which was used to fulfill the + /// device code grant + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails or if the + /// device code grant is not in the [`Pending`] state + /// + /// [`Pending`]: DeviceCodeGrantState::Pending + async fn fulfill( + &mut self, + clock: &dyn Clock, + device_code_grant: DeviceCodeGrant, + browser_session: &BrowserSession, + ) -> Result; + + /// Mark the device code grant as rejected with the given browser session + /// + /// Returns the updated device code grant + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate timestamps + /// * `device_code_grant`: The device code grant to reject + /// * `browser_session`: The browser session which was used to reject the + /// device code grant + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails or if the + /// device code grant is not in the [`Pending`] state + /// + /// [`Pending`]: DeviceCodeGrantState::Pending + async fn reject( + &mut self, + clock: &dyn Clock, + device_code_grant: DeviceCodeGrant, + browser_session: &BrowserSession, + ) -> Result; + + /// Mark the device code grant as exchanged and store the session which was + /// created + /// + /// Returns the updated device code grant + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate timestamps + /// * `device_code_grant`: The device code grant to exchange + /// * `session`: The OAuth 2.0 session which was created + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails or if the + /// device code grant is not in the [`Fulfilled`] state + /// + /// [`Fulfilled`]: DeviceCodeGrantState::Fulfilled + async fn exchange( + &mut self, + clock: &dyn Clock, + device_code_grant: DeviceCodeGrant, + session: &Session, + ) -> Result; +} + +repository_impl!(OAuth2DeviceCodeGrantRepository: + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + params: OAuth2DeviceCodeGrantParams<'_>, + ) -> Result; + + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + async fn find_by_device_code( + &mut self, + device_code: &str, + ) -> Result, Self::Error>; + + async fn find_by_user_code( + &mut self, + user_code: &str, + ) -> Result, Self::Error>; + + async fn fulfill( + &mut self, + clock: &dyn Clock, + device_code_grant: DeviceCodeGrant, + browser_session: &BrowserSession, + ) -> Result; + + async fn reject( + &mut self, + clock: &dyn Clock, + device_code_grant: DeviceCodeGrant, + browser_session: &BrowserSession, + ) -> Result; + + async fn exchange( + &mut self, + clock: &dyn Clock, + device_code_grant: DeviceCodeGrant, + session: &Session, + ) -> Result; +); diff --git a/crates/storage/src/oauth2/mod.rs b/crates/storage/src/oauth2/mod.rs index 3ed528fd..eae6e454 100644 --- a/crates/storage/src/oauth2/mod.rs +++ b/crates/storage/src/oauth2/mod.rs @@ -17,6 +17,7 @@ mod access_token; mod authorization_grant; mod client; +mod device_code_grant; mod refresh_token; mod session; @@ -24,6 +25,7 @@ pub use self::{ access_token::OAuth2AccessTokenRepository, authorization_grant::OAuth2AuthorizationGrantRepository, client::OAuth2ClientRepository, + device_code_grant::{OAuth2DeviceCodeGrantParams, OAuth2DeviceCodeGrantRepository}, refresh_token::OAuth2RefreshTokenRepository, session::{OAuth2SessionFilter, OAuth2SessionRepository}, }; diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs index 7f71ab96..5328178b 100644 --- a/crates/storage/src/repository.rs +++ b/crates/storage/src/repository.rs @@ -24,7 +24,7 @@ use crate::{ job::JobRepository, oauth2::{ OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, - OAuth2RefreshTokenRepository, OAuth2SessionRepository, + OAuth2DeviceCodeGrantRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository, }, upstream_oauth2::{ UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, @@ -178,6 +178,11 @@ pub trait RepositoryAccess: Send { &'c mut self, ) -> Box + 'c>; + /// Get an [`OAuth2DeviceCodeGrantRepository`] + fn oauth2_device_code_grant<'c>( + &'c mut self, + ) -> Box + 'c>; + /// Get a [`CompatSessionRepository`] fn compat_session<'c>( &'c mut self, @@ -217,7 +222,8 @@ mod impls { job::JobRepository, oauth2::{ OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, - OAuth2ClientRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository, + OAuth2ClientRepository, OAuth2DeviceCodeGrantRepository, OAuth2RefreshTokenRepository, + OAuth2SessionRepository, }, upstream_oauth2::{ UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, @@ -360,6 +366,15 @@ mod impls { )) } + fn oauth2_device_code_grant<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.oauth2_device_code_grant(), + &mut self.mapper, + )) + } + fn compat_session<'c>( &'c mut self, ) -> Box + 'c> { @@ -472,6 +487,12 @@ mod impls { (**self).oauth2_refresh_token() } + fn oauth2_device_code_grant<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).oauth2_device_code_grant() + } + fn compat_session<'c>( &'c mut self, ) -> Box + 'c> {