diff --git a/crates/graphql/src/lib.rs b/crates/graphql/src/lib.rs index 6699b95f..162e1d0f 100644 --- a/crates/graphql/src/lib.rs +++ b/crates/graphql/src/lib.rs @@ -27,6 +27,7 @@ )] use async_graphql::EmptySubscription; +use mas_data_model::{BrowserSession, User}; mod model; mod mutations; @@ -49,3 +50,42 @@ pub fn schema_builder() -> SchemaBuilder { .register_output_type::() .register_output_type::() } + +/// The identity of the requester. +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub enum Requester { + /// The requester presented no authentication information. + #[default] + Anonymous, + + /// The requester is a browser session, stored in a cookie. + BrowserSession(BrowserSession), +} + +impl Requester { + fn browser_session(&self) -> Option<&BrowserSession> { + match self { + Self::BrowserSession(session) => Some(session), + Self::Anonymous => None, + } + } + + fn user(&self) -> Option<&User> { + self.browser_session().map(|session| &session.user) + } +} + +impl From for Requester { + fn from(session: BrowserSession) -> Self { + Self::BrowserSession(session) + } +} + +impl From> for Requester +where + T: Into, +{ + fn from(session: Option) -> Self { + session.map(Into::into).unwrap_or_default() + } +} diff --git a/crates/graphql/src/mutations.rs b/crates/graphql/src/mutations.rs index 24041502..7d610d6b 100644 --- a/crates/graphql/src/mutations.rs +++ b/crates/graphql/src/mutations.rs @@ -50,13 +50,11 @@ impl RootMutations { ) -> Result { let state = ctx.state(); let id = NodeType::User.extract_ulid(&user_id)?; - let session = ctx.session(); + let requester = ctx.requester(); - let Some(session) = session else { - return Err(async_graphql::Error::new("Unauthorized")); - }; + let user = requester.user().context("Unauthorized")?; - if session.user.id != id { + if user.id != id { return Err(async_graphql::Error::new("Unauthorized")); } @@ -65,16 +63,14 @@ impl RootMutations { // XXX: this logic should be extracted somewhere else, since most of it is // duplicated in mas_handlers // Find an existing email address - let existing_user_email = repo.user_email().find(&session.user, &email).await?; + let existing_user_email = repo.user_email().find(user, &email).await?; let user_email = if let Some(user_email) = existing_user_email { user_email } else { let clock = state.clock(); let mut rng = state.rng(); - repo.user_email() - .add(&mut rng, &clock, &session.user, email) - .await? + repo.user_email().add(&mut rng, &clock, user, email).await? }; // Schedule a job to verify the email address if needed @@ -98,11 +94,8 @@ impl RootMutations { ) -> Result { let state = ctx.state(); let user_email_id = NodeType::UserEmail.extract_ulid(&user_email_id)?; - let session = ctx.session(); - - let Some(session) = session else { - return Err(async_graphql::Error::new("Unauthorized")); - }; + let requester = ctx.requester(); + let user = requester.user().context("Unauthorized")?; let mut repo = state.repository().await?; @@ -112,7 +105,7 @@ impl RootMutations { .await? .context("User email not found")?; - if user_email.user_id != session.user.id { + if user_email.user_id != user.id { return Err(async_graphql::Error::new("Unauthorized")); } @@ -138,11 +131,9 @@ impl RootMutations { ) -> Result { let state = ctx.state(); let user_email_id = NodeType::UserEmail.extract_ulid(&user_email_id)?; - let session = ctx.session(); + let requester = ctx.requester(); - let Some(session) = session else { - return Err(async_graphql::Error::new("Unauthorized")); - }; + let user = requester.user().context("Unauthorized")?; let clock = state.clock(); let mut repo = state.repository().await?; @@ -153,7 +144,7 @@ impl RootMutations { .await? .context("User email not found")?; - if user_email.user_id != session.user.id { + if user_email.user_id != user.id { return Err(async_graphql::Error::new("Unauthorized")); } @@ -173,7 +164,7 @@ impl RootMutations { .await?; // XXX: is this the right place to do this? - if session.user.primary_user_email_id.is_none() { + if user.primary_user_email_id.is_none() { repo.user_email().set_as_primary(&user_email).await?; } @@ -182,9 +173,7 @@ impl RootMutations { .mark_as_verified(&clock, user_email) .await?; - repo.job() - .schedule_job(ProvisionUserJob::new(&session.user)) - .await?; + repo.job().schedule_job(ProvisionUserJob::new(user)).await?; repo.save().await?; diff --git a/crates/graphql/src/query.rs b/crates/graphql/src/query.rs index 5b225920..bc3dcf6e 100644 --- a/crates/graphql/src/query.rs +++ b/crates/graphql/src/query.rs @@ -46,14 +46,17 @@ impl RootQuery { &self, ctx: &Context<'_>, ) -> Result, async_graphql::Error> { - let session = ctx.session().cloned(); - Ok(session.map(BrowserSession::from)) + let requester = ctx.requester(); + Ok(requester + .browser_session() + .cloned() + .map(BrowserSession::from)) } /// Get the current logged in user async fn current_user(&self, ctx: &Context<'_>) -> Result, async_graphql::Error> { - let session = ctx.session().cloned(); - Ok(session.map(User::from)) + let requester = ctx.requester(); + Ok(requester.user().cloned().map(User::from)) } /// Fetch an OAuth 2.0 client by its ID. @@ -75,14 +78,12 @@ impl RootQuery { /// Fetch a user by its ID. async fn user(&self, ctx: &Context<'_>, id: ID) -> Result, async_graphql::Error> { let id = NodeType::User.extract_ulid(&id)?; + let requester = ctx.requester(); - let session = ctx.session().cloned(); - - let Some(session) = session else { return Ok(None) }; - let current_user = session.user; + let Some(current_user) = requester.user() else { return Ok(None) }; if current_user.id == id { - Ok(Some(User(current_user))) + Ok(Some(User(current_user.clone()))) } else { Ok(None) } @@ -96,13 +97,11 @@ impl RootQuery { ) -> Result, async_graphql::Error> { let state = ctx.state(); let id = NodeType::BrowserSession.extract_ulid(&id)?; + let requester = ctx.requester(); - let session = ctx.session().cloned(); + let Some(current_user) = requester.user() else { return Ok(None) }; let mut repo = state.repository().await?; - let Some(session) = session else { return Ok(None) }; - let current_user = session.user; - let browser_session = repo.browser_session().lookup(id).await?; repo.cancel().await?; @@ -126,13 +125,11 @@ impl RootQuery { ) -> Result, async_graphql::Error> { let state = ctx.state(); let id = NodeType::UserEmail.extract_ulid(&id)?; + let requester = ctx.requester(); - let session = ctx.session().cloned(); + let Some(current_user) = requester.user() else { return Ok(None) }; let mut repo = state.repository().await?; - let Some(session) = session else { return Ok(None) }; - let current_user = session.user; - let user_email = repo .user_email() .lookup(id) @@ -152,13 +149,11 @@ impl RootQuery { ) -> Result, async_graphql::Error> { let state = ctx.state(); let id = NodeType::UpstreamOAuth2Link.extract_ulid(&id)?; + let requester = ctx.requester(); - let session = ctx.session().cloned(); + let Some(current_user) = requester.user() else { return Ok(None) }; let mut repo = state.repository().await?; - let Some(session) = session else { return Ok(None) }; - let current_user = session.user; - let link = repo.upstream_oauth_link().lookup(id).await?; // Ensure that the link belongs to the current user diff --git a/crates/graphql/src/state.rs b/crates/graphql/src/state.rs index e9775fbb..fb35d736 100644 --- a/crates/graphql/src/state.rs +++ b/crates/graphql/src/state.rs @@ -12,9 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -use mas_data_model::BrowserSession; use mas_storage::{BoxClock, BoxRepository, BoxRng, RepositoryError}; +use crate::Requester; + #[async_trait::async_trait] pub trait State { async fn repository(&self) -> Result; @@ -27,15 +28,15 @@ pub type BoxState = Box; pub trait ContextExt { fn state(&self) -> &BoxState; - fn session(&self) -> Option<&BrowserSession>; + fn requester(&self) -> &Requester; } impl ContextExt for async_graphql::Context<'_> { fn state(&self) -> &BoxState { - self.data_unchecked::() + self.data_unchecked() } - fn session(&self) -> Option<&BrowserSession> { - self.data_opt() + fn requester(&self) -> &Requester { + self.data_unchecked() } } diff --git a/crates/handlers/src/graphql.rs b/crates/handlers/src/graphql.rs index 1f76f157..e32a9d12 100644 --- a/crates/handlers/src/graphql.rs +++ b/crates/handlers/src/graphql.rs @@ -27,7 +27,7 @@ use futures_util::TryStreamExt; use headers::{ContentType, HeaderValue}; use hyper::header::CACHE_CONTROL; use mas_axum_utils::{FancyError, SessionInfoExt}; -use mas_graphql::Schema; +use mas_graphql::{Requester, Schema}; use mas_keystore::Encrypter; use mas_storage::{BoxClock, BoxRepository, BoxRng, Repository, RepositoryError, SystemClock}; use mas_storage_pg::PgRepository; @@ -100,23 +100,21 @@ pub async fn post( content_type: Option>, body: BodyStream, ) -> Result { - let content_type = content_type.map(|TypedHeader(h)| h.to_string()); - let (session_info, _cookie_jar) = cookie_jar.session_info(); let maybe_session = session_info.load_session(&mut repo).await?; + let requester = Requester::from(maybe_session); repo.cancel().await?; - let mut request = async_graphql::http::receive_body( + let content_type = content_type.map(|TypedHeader(h)| h.to_string()); + + let request = async_graphql::http::receive_body( content_type, body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)) .into_async_read(), MultipartOptions::default(), ) - .await?; // XXX: this should probably return another error response? - - if let Some(session) = maybe_session { - request = request.data(session); - } + .await? + .data(requester); // XXX: this should probably return another error response? let span = span_for_graphql_request(&request); let response = schema.execute(request).instrument(span).await; @@ -140,13 +138,11 @@ pub async fn get( ) -> Result { let (session_info, _cookie_jar) = cookie_jar.session_info(); let maybe_session = session_info.load_session(&mut repo).await?; + let requester = Requester::from(maybe_session); repo.cancel().await?; - let mut request = async_graphql::http::parse_query_string(&query.unwrap_or_default())?; - - if let Some(session) = maybe_session { - request = request.data(session); - } + let request = + async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(requester); let span = span_for_graphql_request(&request); let response = schema.execute(request).instrument(span).await; diff --git a/crates/handlers/src/test_utils.rs b/crates/handlers/src/test_utils.rs index 4d20deb1..1a5ad0bf 100644 --- a/crates/handlers/src/test_utils.rs +++ b/crates/handlers/src/test_utils.rs @@ -108,7 +108,7 @@ impl TestState { let policy_factory = Arc::new(policy_factory); - let graphql_schema = graphql_schema(); + let graphql_schema = graphql_schema(&pool); let http_client_factory = HttpClientFactory::new(10);