diff --git a/crates/graphql/src/lib.rs b/crates/graphql/src/lib.rs index 49366e17..ae06c4c1 100644 --- a/crates/graphql/src/lib.rs +++ b/crates/graphql/src/lib.rs @@ -26,7 +26,6 @@ clippy::unused_async )] -use anyhow::Context; use async_graphql::EmptySubscription; use mas_data_model::{BrowserSession, Session, User}; use ulid::Ulid; diff --git a/crates/graphql/src/mutations/browser_session.rs b/crates/graphql/src/mutations/browser_session.rs index 39a2f0ff..3e0337a9 100644 --- a/crates/graphql/src/mutations/browser_session.rs +++ b/crates/graphql/src/mutations/browser_session.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use anyhow::Context as _; use async_graphql::{Context, Enum, InputObject, Object, ID}; use mas_storage::RepositoryAccess; @@ -80,18 +79,17 @@ impl BrowserSessionMutations { NodeType::BrowserSession.extract_ulid(&input.browser_session_id)?; let requester = ctx.requester(); - let user = requester.user().context("Unauthorized")?; - let mut repo = state.repository().await?; let clock = state.clock(); let session = repo.browser_session().lookup(browser_session_id).await?; + let Some(session) = session else { return Ok(EndBrowserSessionPayload::NotFound); }; - if session.user.id != user.id { - return Err(async_graphql::Error::new("Unauthorized")); + if !requester.is_owner_or_admin(&session) { + return Ok(EndBrowserSessionPayload::NotFound); } let session = repo.browser_session().finish(&clock, session).await?; diff --git a/crates/graphql/src/mutations/compat_session.rs b/crates/graphql/src/mutations/compat_session.rs index ae065a4a..75dfef5f 100644 --- a/crates/graphql/src/mutations/compat_session.rs +++ b/crates/graphql/src/mutations/compat_session.rs @@ -84,8 +84,6 @@ impl CompatSessionMutations { let compat_session_id = NodeType::CompatSession.extract_ulid(&input.compat_session_id)?; let requester = ctx.requester(); - let user = requester.user().context("Unauthorized")?; - let mut repo = state.repository().await?; let clock = state.clock(); @@ -94,13 +92,19 @@ impl CompatSessionMutations { return Ok(EndCompatSessionPayload::NotFound); }; - if session.user_id != user.id { - return Err(async_graphql::Error::new("Unauthorized")); + if !requester.is_owner_or_admin(&session) { + return Ok(EndCompatSessionPayload::NotFound); } + let user = repo + .user() + .lookup(session.user_id) + .await? + .context("Could not load user")?; + // Schedule a job to delete the device. repo.job() - .schedule_job(DeleteDeviceJob::new(user, &session.device)) + .schedule_job(DeleteDeviceJob::new(&user, &session.device)) .await?; let session = repo.compat_session().finish(&clock, session).await?; diff --git a/crates/graphql/src/mutations/matrix.rs b/crates/graphql/src/mutations/matrix.rs index 6d7ec2cd..b16220bb 100644 --- a/crates/graphql/src/mutations/matrix.rs +++ b/crates/graphql/src/mutations/matrix.rs @@ -18,6 +18,7 @@ use async_graphql::{Context, Description, Enum, InputObject, Object, ID}; use crate::{ model::{NodeType, User}, state::ContextExt, + UserId, }; #[derive(Default)] @@ -82,12 +83,18 @@ impl MatrixMutations { let id = NodeType::User.extract_ulid(&input.user_id)?; let requester = ctx.requester(); - let user = requester.user().context("Unauthorized")?; - - if user.id != id { + if !requester.is_owner_or_admin(&UserId(id)) { return Err(async_graphql::Error::new("Unauthorized")); } + let mut repo = state.repository().await?; + let user = repo + .user() + .lookup(id) + .await? + .context("Failed to lookup user")?; + repo.cancel().await?; + let conn = state.homeserver_connection(); let mxid = conn.mxid(&user.username); diff --git a/crates/graphql/src/mutations/oauth2_session.rs b/crates/graphql/src/mutations/oauth2_session.rs index 3a7144ae..c3874e5e 100644 --- a/crates/graphql/src/mutations/oauth2_session.rs +++ b/crates/graphql/src/mutations/oauth2_session.rs @@ -84,8 +84,6 @@ impl OAuth2SessionMutations { let oauth2_session_id = NodeType::OAuth2Session.extract_ulid(&input.oauth2_session_id)?; let requester = ctx.requester(); - let user = requester.user().context("Unauthorized")?; - let mut repo = state.repository().await?; let clock = state.clock(); @@ -94,14 +92,15 @@ impl OAuth2SessionMutations { return Ok(EndOAuth2SessionPayload::NotFound); }; + // XXX: again, the user_id should be directly stored in the session. let user_session = repo .browser_session() .lookup(session.user_session_id) .await? - .context("Browser session not found")?; + .context("Could not load user session")?; - if user_session.user.id != user.id { - return Err(async_graphql::Error::new("Unauthorized")); + if !requester.is_owner_or_admin(&user_session) { + return Ok(EndOAuth2SessionPayload::NotFound); } // Scan the scopes of the session to find if there is any device that should be diff --git a/crates/graphql/src/mutations/user_email.rs b/crates/graphql/src/mutations/user_email.rs index 2aeca2a3..2ea7052c 100644 --- a/crates/graphql/src/mutations/user_email.rs +++ b/crates/graphql/src/mutations/user_email.rs @@ -23,6 +23,7 @@ use mas_storage::{ use crate::{ model::{NodeType, User, UserEmail}, state::ContextExt, + UserId, }; #[derive(Default)] @@ -361,14 +362,18 @@ impl UserEmailMutations { let id = NodeType::User.extract_ulid(&input.user_id)?; let requester = ctx.requester(); - let user = requester.user().context("Unauthorized")?; - - if user.id != id { + if !requester.is_owner_or_admin(&UserId(id)) { return Err(async_graphql::Error::new("Unauthorized")); } let mut repo = state.repository().await?; + let user = repo + .user() + .lookup(id) + .await? + .context("Failed to load user")?; + // XXX: this logic should be extracted somewhere else, since most of it is // duplicated in mas_handlers @@ -378,7 +383,7 @@ impl UserEmailMutations { } // Find an existing email address - let existing_user_email = repo.user_email().find(user, &input.email).await?; + let existing_user_email = repo.user_email().find(&user, &input.email).await?; let (added, user_email) = if let Some(user_email) = existing_user_email { (false, user_email) } else { @@ -387,7 +392,7 @@ impl UserEmailMutations { let user_email = repo .user_email() - .add(&mut rng, &clock, user, input.email) + .add(&mut rng, &clock, &user, input.email) .await?; (true, user_email) @@ -419,7 +424,6 @@ impl UserEmailMutations { let state = ctx.state(); let user_email_id = NodeType::UserEmail.extract_ulid(&input.user_email_id)?; let requester = ctx.requester(); - let user = requester.user().context("Unauthorized")?; let mut repo = state.repository().await?; @@ -429,8 +433,8 @@ impl UserEmailMutations { .await? .context("User email not found")?; - if user_email.user_id != user.id { - return Err(async_graphql::Error::new("Unauthorized")); + if !requester.is_owner_or_admin(&user_email) { + return Err(async_graphql::Error::new("User email not found")); } // Schedule a job to verify the email address if needed @@ -461,8 +465,6 @@ impl UserEmailMutations { let user_email_id = NodeType::UserEmail.extract_ulid(&input.user_email_id)?; let requester = ctx.requester(); - let user = requester.user().context("Unauthorized")?; - let clock = state.clock(); let mut repo = state.repository().await?; @@ -472,8 +474,8 @@ impl UserEmailMutations { .await? .context("User email not found")?; - if user_email.user_id != user.id { - return Err(async_graphql::Error::new("Unauthorized")); + if !requester.is_owner_or_admin(&user_email) { + return Err(async_graphql::Error::new("User email not found")); } if user_email.confirmed_at.is_some() { @@ -500,6 +502,12 @@ impl UserEmailMutations { .consume_verification_code(&clock, verification) .await?; + let user = repo + .user() + .lookup(user_email.user_id) + .await? + .context("Failed to load user")?; + // XXX: is this the right place to do this? if user.primary_user_email_id.is_none() { repo.user_email().set_as_primary(&user_email).await?; @@ -510,7 +518,9 @@ impl UserEmailMutations { .mark_as_verified(&clock, user_email) .await?; - repo.job().schedule_job(ProvisionUserJob::new(user)).await?; + repo.job() + .schedule_job(ProvisionUserJob::new(&user)) + .await?; repo.save().await?; @@ -527,8 +537,6 @@ impl UserEmailMutations { let user_email_id = NodeType::UserEmail.extract_ulid(&input.user_email_id)?; let requester = ctx.requester(); - let user = requester.user().context("Unauthorized")?; - let mut repo = state.repository().await?; let user_email = repo.user_email().lookup(user_email_id).await?; @@ -536,10 +544,16 @@ impl UserEmailMutations { return Ok(RemoveEmailPayload::NotFound); }; - if user_email.user_id != user.id { - return Err(async_graphql::Error::new("Unauthorized")); + if !requester.is_owner_or_admin(&user_email) { + return Ok(RemoveEmailPayload::NotFound); } + let user = repo + .user() + .lookup(user_email.user_id) + .await? + .context("Failed to load user")?; + if user.primary_user_email_id == Some(user_email.id) { // Prevent removing the primary email address return Ok(RemoveEmailPayload::Primary(user_email)); @@ -547,6 +561,11 @@ impl UserEmailMutations { repo.user_email().remove(user_email.clone()).await?; + // Schedule a job to update the user + repo.job() + .schedule_job(ProvisionUserJob::new(&user)) + .await?; + repo.save().await?; Ok(RemoveEmailPayload::Removed(user_email)) @@ -562,8 +581,6 @@ impl UserEmailMutations { let user_email_id = NodeType::UserEmail.extract_ulid(&input.user_email_id)?; let requester = ctx.requester(); - let user = requester.user().context("Unauthorized")?; - let mut repo = state.repository().await?; let user_email = repo.user_email().lookup(user_email_id).await?; @@ -571,7 +588,7 @@ impl UserEmailMutations { return Ok(SetPrimaryEmailPayload::NotFound); }; - if user_email.user_id != user.id { + if !requester.is_owner_or_admin(&user_email) { return Err(async_graphql::Error::new("Unauthorized")); } @@ -581,10 +598,15 @@ impl UserEmailMutations { repo.user_email().set_as_primary(&user_email).await?; + // The user primary email should already be up to date + let user = repo + .user() + .lookup(user_email.user_id) + .await? + .context("Failed to load user")?; + repo.save().await?; - let mut user = user.clone(); - user.primary_user_email_id = Some(user_email.id); Ok(SetPrimaryEmailPayload::Set(user)) } }