diff --git a/crates/graphql/src/lib.rs b/crates/graphql/src/lib.rs index fbc7dbfe..49366e17 100644 --- a/crates/graphql/src/lib.rs +++ b/crates/graphql/src/lib.rs @@ -67,6 +67,49 @@ pub enum Requester { OAuth2Session(Session, User), } +trait OwnerId { + fn owner_id(&self) -> Option; +} + +impl OwnerId for User { + fn owner_id(&self) -> Option { + Some(self.id) + } +} + +impl OwnerId for BrowserSession { + fn owner_id(&self) -> Option { + Some(self.user.id) + } +} + +impl OwnerId for mas_data_model::UserEmail { + fn owner_id(&self) -> Option { + Some(self.user_id) + } +} + +impl OwnerId for mas_data_model::CompatSession { + fn owner_id(&self) -> Option { + Some(self.user_id) + } +} + +impl OwnerId for mas_data_model::UpstreamOAuthLink { + fn owner_id(&self) -> Option { + self.user_id + } +} + +/// A dumb wrapper around a `Ulid` to implement `OwnerId` for it. +pub struct UserId(Ulid); + +impl OwnerId for UserId { + fn owner_id(&self) -> Option { + Some(self.0) + } +} + impl Requester { fn browser_session(&self) -> Option<&BrowserSession> { match self { @@ -83,19 +126,23 @@ impl Requester { } } - fn ensure_owner_or_admin(&self, user_id: Ulid) -> Result<(), async_graphql::Error> { + /// Returns true if the requester can access the resource. + fn is_owner_or_admin(&self, resource: &impl OwnerId) -> bool { // If the requester is an admin, they can do anything. if self.is_admin() { - return Ok(()); + return true; } - // Else check that they are the owner. - let user = self.user().context("Unauthorized")?; - if user.id == user_id { - Ok(()) - } else { - Err(async_graphql::Error::new("Unauthorized")) - } + // Otherwise, they must be the owner of the resource. + let Some(owner_id) = resource.owner_id() else { + return false; + }; + + let Some(user) = self.user() else { + return false; + }; + + user.id == owner_id } fn is_admin(&self) -> bool { diff --git a/crates/graphql/src/query/mod.rs b/crates/graphql/src/query/mod.rs index 0a645bc2..12d6f0a6 100644 --- a/crates/graphql/src/query/mod.rs +++ b/crates/graphql/src/query/mod.rs @@ -17,6 +17,7 @@ use async_graphql::{Context, MergedObject, Object, ID}; use crate::{ model::{Anonymous, BrowserSession, Node, NodeType, OAuth2Client, User, UserEmail}, state::ContextExt, + UserId, }; mod upstream_oauth; @@ -80,24 +81,21 @@ impl BaseQuery { /// Fetch a user by its ID. async fn user(&self, ctx: &Context<'_>, id: ID) -> Result, async_graphql::Error> { let id = NodeType::User.extract_ulid(&id)?; + let requester = ctx.requester(); - - let Some(current_user) = requester.user() else { + if !requester.is_owner_or_admin(&UserId(id)) { return Ok(None); - }; - - if current_user.id == id { - Ok(Some(User(current_user.clone()))) - } else if requester.is_admin() { - // An admin can fetch any user, not just themselves - let state = ctx.state(); - let mut repo = state.repository().await?; - let user = repo.user().lookup(id).await?; - repo.cancel().await?; - Ok(user.map(User)) - } else { - Ok(None) } + + // We could avoid the database lookup if the requester is the user we're looking + // for but that would make the code more complex and we're not very + // concerned about performance yet + let state = ctx.state(); + let mut repo = state.repository().await?; + let user = repo.user().lookup(id).await?; + repo.cancel().await?; + + Ok(user.map(User)) } /// Fetch a browser session by its ID. @@ -110,24 +108,19 @@ impl BaseQuery { let id = NodeType::BrowserSession.extract_ulid(&id)?; let requester = ctx.requester(); - let Some(current_user) = requester.user() else { - return Ok(None); - }; let mut repo = state.repository().await?; - let browser_session = repo.browser_session().lookup(id).await?; - repo.cancel().await?; - let ret = browser_session.and_then(|browser_session| { - if browser_session.user.id == current_user.id || requester.is_admin() { - Some(BrowserSession(browser_session)) - } else { - None - } - }); + let Some(browser_session) = browser_session else { + return Ok(None); + }; - Ok(ret) + if !requester.is_owner_or_admin(&browser_session) { + return Ok(None); + } + + Ok(Some(BrowserSession(browser_session))) } /// Fetch a user email by its ID. @@ -140,20 +133,19 @@ impl BaseQuery { let id = NodeType::UserEmail.extract_ulid(&id)?; let requester = ctx.requester(); - let Some(current_user) = requester.user() else { - return Ok(None); - }; let mut repo = state.repository().await?; - - let user_email = repo - .user_email() - .lookup(id) - .await? - .filter(|e| e.user_id == current_user.id || requester.is_admin()); - + let user_email = repo.user_email().lookup(id).await?; repo.cancel().await?; - Ok(user_email.map(UserEmail)) + let Some(user_email) = user_email else { + return Ok(None); + }; + + if !requester.is_owner_or_admin(&user_email) { + return Ok(None); + } + + Ok(Some(UserEmail(user_email))) } /// Fetches an object given its ID. diff --git a/crates/graphql/src/query/upstream_oauth.rs b/crates/graphql/src/query/upstream_oauth.rs index 43e48238..520be2fd 100644 --- a/crates/graphql/src/query/upstream_oauth.rs +++ b/crates/graphql/src/query/upstream_oauth.rs @@ -41,18 +41,19 @@ impl UpstreamOAuthQuery { let id = NodeType::UpstreamOAuth2Link.extract_ulid(&id)?; let requester = ctx.requester(); - let Some(current_user) = requester.user() else { + let mut repo = state.repository().await?; + let link = repo.upstream_oauth_link().lookup(id).await?; + repo.cancel().await?; + + let Some(link) = link else { return Ok(None); }; - let mut repo = state.repository().await?; - let link = repo.upstream_oauth_link().lookup(id).await?; + if !requester.is_owner_or_admin(&link) { + return Ok(None); + } - // Ensure that the link belongs to the current user - let link = - link.filter(|link| link.user_id == Some(current_user.id) || requester.is_admin()); - - Ok(link.map(UpstreamOAuth2Link::new)) + Ok(Some(UpstreamOAuth2Link::new(link))) } /// Fetch an upstream OAuth 2.0 provider by its ID. @@ -68,7 +69,11 @@ impl UpstreamOAuthQuery { let provider = repo.upstream_oauth_provider().lookup(id).await?; repo.cancel().await?; - Ok(provider.map(UpstreamOAuth2Provider::new)) + let Some(provider) = provider else { + return Ok(None); + }; + + Ok(Some(UpstreamOAuth2Provider::new(provider))) } /// Get a list of upstream OAuth 2.0 providers.