1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-11-20 12:02:22 +03:00

Merge the mas_graphql crate into the mas_handlers crate (#2783)

This commit is contained in:
reivilibre
2024-05-17 17:22:34 +01:00
committed by GitHub
parent 37a10aea96
commit 206d45bb31
35 changed files with 199 additions and 278 deletions

View File

@@ -0,0 +1,27 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#![forbid(unsafe_code)]
#![deny(
clippy::all,
clippy::str_to_string,
rustdoc::broken_intra_doc_links,
clippy::future_not_send
)]
#![warn(clippy::pedantic)]
fn main() {
let schema = mas_handlers::graphql_schema_builder().finish();
println!("{}", schema.sdl());
}

View File

@@ -12,15 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#![allow(clippy::module_name_repetitions)]
use std::sync::Arc;
use async_graphql::{
extensions::Tracing,
http::{playground_source, GraphQLPlaygroundConfig, MultipartOptions},
EmptySubscription,
};
use axum::{
async_trait,
extract::{BodyStream, RawQuery, State},
extract::{BodyStream, RawQuery, State as AxumState},
http::StatusCode,
response::{Html, IntoResponse, Response},
Json, TypedHeader,
@@ -31,8 +34,7 @@ use hyper::header::CACHE_CONTROL;
use mas_axum_utils::{
cookies::CookieJar, sentry::SentryEventID, FancyError, SessionInfo, SessionInfoExt,
};
use mas_data_model::{SiteConfig, User};
use mas_graphql::{Requester, Schema};
use mas_data_model::{BrowserSession, Session, SiteConfig, User};
use mas_matrix::HomeserverConnection;
use mas_policy::{InstantiateError, Policy, PolicyFactory};
use mas_storage::{
@@ -44,7 +46,19 @@ use rand::{thread_rng, SeedableRng};
use rand_chacha::ChaChaRng;
use sqlx::PgPool;
use tracing::{info_span, Instrument};
use ulid::Ulid;
mod model;
mod mutations;
mod query;
mod state;
pub use self::state::{BoxState, State};
use self::{
model::{CreationEvent, Node},
mutations::Mutation,
query::Query,
};
use crate::{impl_from_error_for_route, BoundActivityTracker};
#[cfg(test)]
@@ -58,7 +72,7 @@ struct GraphQLState {
}
#[async_trait]
impl mas_graphql::State for GraphQLState {
impl state::State for GraphQLState {
async fn repository(&self) -> Result<BoxRepository, RepositoryError> {
let repo = PgRepository::from_pool(&self.pool)
.await
@@ -106,12 +120,9 @@ pub fn schema(
homeserver_connection: Arc::new(homeserver_connection),
site_config,
};
let state: mas_graphql::BoxState = Box::new(state);
let state: BoxState = Box::new(state);
mas_graphql::schema_builder()
.extension(Tracing)
.data(state)
.finish()
schema_builder().extension(Tracing).data(state).finish()
}
fn span_for_graphql_request(request: &async_graphql::Request) -> tracing::Span {
@@ -261,7 +272,7 @@ async fn get_requester(
}
pub async fn post(
State(schema): State<Schema>,
AxumState(schema): AxumState<Schema>,
clock: BoxClock,
repo: BoxRepository,
activity_tracker: BoundActivityTracker,
@@ -302,7 +313,7 @@ pub async fn post(
}
pub async fn get(
State(schema): State<Schema>,
AxumState(schema): AxumState<Schema>,
clock: BoxClock,
repo: BoxRepository,
activity_tracker: BoundActivityTracker,
@@ -338,3 +349,145 @@ pub async fn playground() -> impl IntoResponse {
GraphQLPlaygroundConfig::new("/graphql").with_setting("request.credentials", "include"),
))
}
pub type Schema = async_graphql::Schema<Query, Mutation, EmptySubscription>;
pub type SchemaBuilder = async_graphql::SchemaBuilder<Query, Mutation, EmptySubscription>;
#[must_use]
pub fn schema_builder() -> SchemaBuilder {
async_graphql::Schema::build(Query::new(), Mutation::new(), EmptySubscription)
.register_output_type::<Node>()
.register_output_type::<CreationEvent>()
}
/// The identity of the requester.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum Requester {
/// The requester presented no authentication information.
#[default]
Anonymous,
/// The requester is a browser session, stored in a cookie.
BrowserSession(Box<BrowserSession>),
/// The requester is a `OAuth2` session, with an access token.
OAuth2Session(Box<(Session, Option<User>)>),
}
trait OwnerId {
fn owner_id(&self) -> Option<Ulid>;
}
impl OwnerId for User {
fn owner_id(&self) -> Option<Ulid> {
Some(self.id)
}
}
impl OwnerId for BrowserSession {
fn owner_id(&self) -> Option<Ulid> {
Some(self.user.id)
}
}
impl OwnerId for mas_data_model::UserEmail {
fn owner_id(&self) -> Option<Ulid> {
Some(self.user_id)
}
}
impl OwnerId for Session {
fn owner_id(&self) -> Option<Ulid> {
self.user_id
}
}
impl OwnerId for mas_data_model::CompatSession {
fn owner_id(&self) -> Option<Ulid> {
Some(self.user_id)
}
}
impl OwnerId for mas_data_model::UpstreamOAuthLink {
fn owner_id(&self) -> Option<Ulid> {
self.user_id
}
}
/// A dumb wrapper around a `Ulid` to implement `OwnerId` for it.
pub struct UserId(Ulid);
impl OwnerId for UserId {
fn owner_id(&self) -> Option<Ulid> {
Some(self.0)
}
}
impl Requester {
fn browser_session(&self) -> Option<&BrowserSession> {
match self {
Self::BrowserSession(session) => Some(session),
Self::OAuth2Session(_) | Self::Anonymous => None,
}
}
fn user(&self) -> Option<&User> {
match self {
Self::BrowserSession(session) => Some(&session.user),
Self::OAuth2Session(tuple) => tuple.1.as_ref(),
Self::Anonymous => None,
}
}
fn oauth2_session(&self) -> Option<&Session> {
match self {
Self::OAuth2Session(tuple) => Some(&tuple.0),
Self::BrowserSession(_) | Self::Anonymous => None,
}
}
/// Returns true if the requester can access the resource.
fn is_owner_or_admin(&self, resource: &impl OwnerId) -> bool {
// If the requester is an admin, they can do anything.
if self.is_admin() {
return true;
}
// Otherwise, they must be the owner of the resource.
let Some(owner_id) = resource.owner_id() else {
return false;
};
let Some(user) = self.user() else {
return false;
};
user.id == owner_id
}
fn is_admin(&self) -> bool {
match self {
Self::OAuth2Session(tuple) => {
// TODO: is this the right scope?
// This has to be in sync with the policy
tuple.0.scope.contains("urn:mas:admin")
}
Self::BrowserSession(_) | Self::Anonymous => false,
}
}
}
impl From<BrowserSession> for Requester {
fn from(session: BrowserSession) -> Self {
Self::BrowserSession(Box::new(session))
}
}
impl<T> From<Option<T>> for Requester
where
T: Into<Requester>,
{
fn from(session: Option<T>) -> Self {
session.map(Into::into).unwrap_or_default()
}
}

View File

@@ -0,0 +1,213 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_graphql::{
connection::{query, Connection, Edge, OpaqueCursor},
Context, Description, Object, ID,
};
use chrono::{DateTime, Utc};
use mas_data_model::Device;
use mas_storage::{
app_session::AppSessionFilter, user::BrowserSessionRepository, Pagination, RepositoryAccess,
};
use super::{
AppSession, CompatSession, Cursor, NodeCursor, NodeType, OAuth2Session, PreloadedTotalCount,
SessionState, User, UserAgent,
};
use crate::graphql::state::ContextExt;
/// A browser session represents a logged in user in a browser.
#[derive(Description)]
pub struct BrowserSession(pub mas_data_model::BrowserSession);
impl From<mas_data_model::BrowserSession> for BrowserSession {
fn from(v: mas_data_model::BrowserSession) -> Self {
Self(v)
}
}
#[Object(use_type_description)]
impl BrowserSession {
/// ID of the object.
pub async fn id(&self) -> ID {
NodeType::BrowserSession.id(self.0.id)
}
/// The user logged in this session.
async fn user(&self) -> User {
User(self.0.user.clone())
}
/// The most recent authentication of this session.
async fn last_authentication(
&self,
ctx: &Context<'_>,
) -> Result<Option<Authentication>, async_graphql::Error> {
let state = ctx.state();
let mut repo = state.repository().await?;
let last_authentication = repo
.browser_session()
.get_last_authentication(&self.0)
.await?;
repo.cancel().await?;
Ok(last_authentication.map(Authentication))
}
/// When the object was created.
pub async fn created_at(&self) -> DateTime<Utc> {
self.0.created_at
}
/// When the session was finished.
pub async fn finished_at(&self) -> Option<DateTime<Utc>> {
self.0.finished_at
}
/// The state of the session.
pub async fn state(&self) -> SessionState {
if self.0.finished_at.is_some() {
SessionState::Finished
} else {
SessionState::Active
}
}
/// The user-agent with which the session was created.
pub async fn user_agent(&self) -> Option<UserAgent> {
self.0.user_agent.clone().map(UserAgent::from)
}
/// The last IP address used by the session.
pub async fn last_active_ip(&self) -> Option<String> {
self.0.last_active_ip.map(|ip| ip.to_string())
}
/// The last time the session was active.
pub async fn last_active_at(&self) -> Option<DateTime<Utc>> {
self.0.last_active_at
}
/// Get the list of both compat and OAuth 2.0 sessions started by this
/// browser session, chronologically sorted
#[allow(clippy::too_many_arguments)]
async fn app_sessions(
&self,
ctx: &Context<'_>,
#[graphql(name = "state", desc = "List only sessions in the given state.")]
state_param: Option<SessionState>,
#[graphql(name = "device", desc = "List only sessions for the given device.")]
device_param: Option<String>,
#[graphql(desc = "Returns the elements in the list that come after the cursor.")]
after: Option<String>,
#[graphql(desc = "Returns the elements in the list that come before the cursor.")]
before: Option<String>,
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, AppSession, PreloadedTotalCount>, async_graphql::Error> {
let state = ctx.state();
let mut repo = state.repository().await?;
query(
after,
before,
first,
last,
|after, before, first, last| async move {
let after_id = after
.map(|x: OpaqueCursor<NodeCursor>| {
x.extract_for_types(&[NodeType::OAuth2Session, NodeType::CompatSession])
})
.transpose()?;
let before_id = before
.map(|x: OpaqueCursor<NodeCursor>| {
x.extract_for_types(&[NodeType::OAuth2Session, NodeType::CompatSession])
})
.transpose()?;
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
let device_param = device_param.map(Device::try_from).transpose()?;
let filter = AppSessionFilter::new().for_browser_session(&self.0);
let filter = match state_param {
Some(SessionState::Active) => filter.active_only(),
Some(SessionState::Finished) => filter.finished_only(),
None => filter,
};
let filter = match device_param.as_ref() {
Some(device) => filter.for_device(device),
None => filter,
};
let page = repo.app_session().list(filter, pagination).await?;
let count = if ctx.look_ahead().field("totalCount").exists() {
Some(repo.app_session().count(filter).await?)
} else {
None
};
repo.cancel().await?;
let mut connection = Connection::with_additional_fields(
page.has_previous_page,
page.has_next_page,
PreloadedTotalCount(count),
);
connection
.edges
.extend(page.edges.into_iter().map(|s| match s {
mas_storage::app_session::AppSession::Compat(session) => Edge::new(
OpaqueCursor(NodeCursor(NodeType::CompatSession, session.id)),
AppSession::CompatSession(Box::new(CompatSession::new(*session))),
),
mas_storage::app_session::AppSession::OAuth2(session) => Edge::new(
OpaqueCursor(NodeCursor(NodeType::OAuth2Session, session.id)),
AppSession::OAuth2Session(Box::new(OAuth2Session(*session))),
),
}));
Ok::<_, async_graphql::Error>(connection)
},
)
.await
}
}
/// An authentication records when a user enter their credential in a browser
/// session.
#[derive(Description)]
pub struct Authentication(pub mas_data_model::Authentication);
#[Object(use_type_description)]
impl Authentication {
/// ID of the object.
pub async fn id(&self) -> ID {
NodeType::Authentication.id(self.0.id)
}
/// When the object was created.
pub async fn created_at(&self) -> DateTime<Utc> {
self.0.created_at
}
}

View File

@@ -0,0 +1,228 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::Context as _;
use async_graphql::{Context, Description, Enum, Object, ID};
use chrono::{DateTime, Utc};
use mas_storage::{compat::CompatSessionRepository, user::UserRepository};
use url::Url;
use super::{BrowserSession, NodeType, SessionState, User, UserAgent};
use crate::graphql::state::ContextExt;
/// Lazy-loaded reverse reference.
///
/// XXX: maybe we want to stick that in a utility module
#[derive(Clone, Debug, Default)]
enum ReverseReference<T> {
Loaded(T),
#[default]
Lazy,
}
/// A compat session represents a client session which used the legacy Matrix
/// login API.
#[derive(Description)]
pub struct CompatSession {
session: mas_data_model::CompatSession,
sso_login: ReverseReference<Option<mas_data_model::CompatSsoLogin>>,
}
impl CompatSession {
pub fn new(session: mas_data_model::CompatSession) -> Self {
Self {
session,
sso_login: ReverseReference::Lazy,
}
}
/// Save an eagerly loaded SSO login.
pub fn with_loaded_sso_login(
mut self,
sso_login: Option<mas_data_model::CompatSsoLogin>,
) -> Self {
self.sso_login = ReverseReference::Loaded(sso_login);
self
}
}
/// The type of a compatibility session.
#[derive(Enum, Copy, Clone, Eq, PartialEq)]
pub enum CompatSessionType {
/// The session was created by a SSO login.
SsoLogin,
/// The session was created by an unknown method.
Unknown,
}
#[Object(use_type_description)]
impl CompatSession {
/// ID of the object.
pub async fn id(&self) -> ID {
NodeType::CompatSession.id(self.session.id)
}
/// The user authorized for this session.
async fn user(&self, ctx: &Context<'_>) -> Result<User, async_graphql::Error> {
let state = ctx.state();
let mut repo = state.repository().await?;
let user = repo
.user()
.lookup(self.session.user_id)
.await?
.context("Could not load user")?;
repo.cancel().await?;
Ok(User(user))
}
/// The Matrix Device ID of this session.
async fn device_id(&self) -> &str {
self.session.device.as_str()
}
/// When the object was created.
pub async fn created_at(&self) -> DateTime<Utc> {
self.session.created_at
}
/// When the session ended.
pub async fn finished_at(&self) -> Option<DateTime<Utc>> {
self.session.finished_at()
}
/// The user-agent with which the session was created.
pub async fn user_agent(&self) -> Option<UserAgent> {
self.session.user_agent.clone().map(UserAgent::from)
}
/// The associated SSO login, if any.
pub async fn sso_login(
&self,
ctx: &Context<'_>,
) -> Result<Option<CompatSsoLogin>, async_graphql::Error> {
if let ReverseReference::Loaded(sso_login) = &self.sso_login {
return Ok(sso_login.clone().map(CompatSsoLogin));
}
// We need to load it on the fly
let state = ctx.state();
let mut repo = state.repository().await?;
let sso_login = repo
.compat_sso_login()
.find_for_session(&self.session)
.await
.context("Could not load SSO login")?;
repo.cancel().await?;
Ok(sso_login.map(CompatSsoLogin))
}
/// The browser session which started this session, if any.
pub async fn browser_session(
&self,
ctx: &Context<'_>,
) -> Result<Option<BrowserSession>, async_graphql::Error> {
let Some(user_session_id) = self.session.user_session_id else {
return Ok(None);
};
let state = ctx.state();
let mut repo = state.repository().await?;
let browser_session = repo
.browser_session()
.lookup(user_session_id)
.await?
.context("Could not load browser session")?;
repo.cancel().await?;
Ok(Some(BrowserSession(browser_session)))
}
/// The state of the session.
pub async fn state(&self) -> SessionState {
match &self.session.state {
mas_data_model::CompatSessionState::Valid => SessionState::Active,
mas_data_model::CompatSessionState::Finished { .. } => SessionState::Finished,
}
}
/// The last IP address used by the session.
pub async fn last_active_ip(&self) -> Option<String> {
self.session.last_active_ip.map(|ip| ip.to_string())
}
/// The last time the session was active.
pub async fn last_active_at(&self) -> Option<DateTime<Utc>> {
self.session.last_active_at
}
}
/// A compat SSO login represents a login done through the legacy Matrix login
/// API, via the `m.login.sso` login method.
#[derive(Description)]
pub struct CompatSsoLogin(pub mas_data_model::CompatSsoLogin);
#[Object(use_type_description)]
impl CompatSsoLogin {
/// ID of the object.
pub async fn id(&self) -> ID {
NodeType::CompatSsoLogin.id(self.0.id)
}
/// When the object was created.
pub async fn created_at(&self) -> DateTime<Utc> {
self.0.created_at
}
/// The redirect URI used during the login.
async fn redirect_uri(&self) -> &Url {
&self.0.redirect_uri
}
/// When the login was fulfilled, and the user was redirected back to the
/// client.
async fn fulfilled_at(&self) -> Option<DateTime<Utc>> {
self.0.fulfilled_at()
}
/// When the client exchanged the login token sent during the redirection.
async fn exchanged_at(&self) -> Option<DateTime<Utc>> {
self.0.exchanged_at()
}
/// The compat session which was started by this login.
async fn session(
&self,
ctx: &Context<'_>,
) -> Result<Option<CompatSession>, async_graphql::Error> {
let Some(session_id) = self.0.session_id() else {
return Ok(None);
};
let state = ctx.state();
let mut repo = state.repository().await?;
let session = repo
.compat_session()
.lookup(session_id)
.await?
.context("Could not load compat session")?;
repo.cancel().await?;
Ok(Some(
CompatSession::new(session).with_loaded_sso_login(Some(self.0.clone())),
))
}
}

View File

@@ -0,0 +1,42 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_graphql::connection::OpaqueCursor;
use serde::{Deserialize, Serialize};
use ulid::Ulid;
pub use super::NodeType;
#[derive(Serialize, Deserialize, PartialEq, Eq)]
pub struct NodeCursor(pub NodeType, pub Ulid);
impl NodeCursor {
pub fn extract_for_types(&self, node_types: &[NodeType]) -> Result<Ulid, async_graphql::Error> {
if node_types.contains(&self.0) {
Ok(self.1)
} else {
Err(async_graphql::Error::new("invalid cursor"))
}
}
pub fn extract_for_type(&self, node_type: NodeType) -> Result<Ulid, async_graphql::Error> {
if self.0 == node_type {
Ok(self.1)
} else {
Err(async_graphql::Error::new("invalid cursor"))
}
}
}
pub type Cursor = OpaqueCursor<NodeCursor>;

View File

@@ -0,0 +1,45 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_graphql::SimpleObject;
use mas_matrix::HomeserverConnection;
#[derive(SimpleObject)]
pub struct MatrixUser {
/// The Matrix ID of the user.
mxid: String,
/// The display name of the user, if any.
display_name: Option<String>,
/// The avatar URL of the user, if any.
avatar_url: Option<String>,
}
impl MatrixUser {
pub(crate) async fn load<C: HomeserverConnection + ?Sized>(
conn: &C,
user: &str,
) -> Result<MatrixUser, C::Error> {
let mxid = conn.mxid(user);
let info = conn.query_user(&mxid).await?;
Ok(MatrixUser {
mxid,
display_name: info.displayname,
avatar_url: info.avatar_url,
})
}
}

View File

@@ -0,0 +1,143 @@
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_graphql::{Enum, Interface, Object, SimpleObject};
use chrono::{DateTime, Utc};
mod browser_sessions;
mod compat_sessions;
mod cursor;
mod matrix;
mod node;
mod oauth;
mod site_config;
mod upstream_oauth;
mod users;
mod viewer;
pub use self::{
browser_sessions::{Authentication, BrowserSession},
compat_sessions::{CompatSession, CompatSsoLogin},
cursor::{Cursor, NodeCursor},
node::{Node, NodeType},
oauth::{OAuth2Client, OAuth2Session},
site_config::{SiteConfig, SITE_CONFIG_ID},
upstream_oauth::{UpstreamOAuth2Link, UpstreamOAuth2Provider},
users::{AppSession, User, UserEmail},
viewer::{Anonymous, Viewer, ViewerSession},
};
/// An object with a creation date.
#[derive(Interface)]
#[graphql(field(
name = "created_at",
desc = "When the object was created.",
ty = "DateTime<Utc>"
))]
pub enum CreationEvent {
Authentication(Box<Authentication>),
CompatSession(Box<CompatSession>),
BrowserSession(Box<BrowserSession>),
UserEmail(Box<UserEmail>),
UpstreamOAuth2Provider(Box<UpstreamOAuth2Provider>),
UpstreamOAuth2Link(Box<UpstreamOAuth2Link>),
OAuth2Session(Box<OAuth2Session>),
}
pub struct PreloadedTotalCount(pub Option<usize>);
#[Object]
impl PreloadedTotalCount {
/// Identifies the total count of items in the connection.
async fn total_count(&self) -> Result<usize, async_graphql::Error> {
self.0
.ok_or_else(|| async_graphql::Error::new("total count not preloaded"))
}
}
/// The state of a session
#[derive(Enum, Copy, Clone, Eq, PartialEq)]
pub enum SessionState {
/// The session is active.
Active,
/// The session is no longer active.
Finished,
}
/// The type of a user agent
#[derive(Enum, Copy, Clone, Eq, PartialEq)]
pub enum DeviceType {
/// A personal computer, laptop or desktop
Pc,
/// A mobile phone. Can also sometimes be a tablet.
Mobile,
/// A tablet
Tablet,
/// Unknown device type
Unknown,
}
impl From<mas_data_model::DeviceType> for DeviceType {
fn from(device_type: mas_data_model::DeviceType) -> Self {
match device_type {
mas_data_model::DeviceType::Pc => Self::Pc,
mas_data_model::DeviceType::Mobile => Self::Mobile,
mas_data_model::DeviceType::Tablet => Self::Tablet,
mas_data_model::DeviceType::Unknown => Self::Unknown,
}
}
}
/// A parsed user agent string
#[derive(SimpleObject)]
pub struct UserAgent {
/// The user agent string
pub raw: String,
/// The name of the browser
pub name: Option<String>,
/// The version of the browser
pub version: Option<String>,
/// The operating system name
pub os: Option<String>,
/// The operating system version
pub os_version: Option<String>,
/// The device model
pub model: Option<String>,
/// The device type
pub device_type: DeviceType,
}
impl From<mas_data_model::UserAgent> for UserAgent {
fn from(ua: mas_data_model::UserAgent) -> Self {
Self {
raw: ua.raw,
name: ua.name,
version: ua.version,
os: ua.os,
os_version: ua.os_version,
model: ua.model,
device_type: ua.device_type.into(),
}
}
}

View File

@@ -0,0 +1,131 @@
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_graphql::{Interface, ID};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use ulid::Ulid;
use super::{
Anonymous, Authentication, BrowserSession, CompatSession, CompatSsoLogin, OAuth2Client,
OAuth2Session, SiteConfig, UpstreamOAuth2Link, UpstreamOAuth2Provider, User, UserEmail,
};
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NodeType {
Authentication,
BrowserSession,
CompatSession,
CompatSsoLogin,
OAuth2Client,
OAuth2Session,
UpstreamOAuth2Provider,
UpstreamOAuth2Link,
User,
UserEmail,
}
#[derive(Debug, Error)]
#[error("invalid id")]
pub enum InvalidID {
InvalidFormat,
InvalidUlid(#[from] ulid::DecodeError),
UnknownPrefix,
TypeMismatch { got: NodeType, expected: NodeType },
}
impl NodeType {
fn to_prefix(self) -> &'static str {
match self {
NodeType::Authentication => "authentication",
NodeType::BrowserSession => "browser_session",
NodeType::CompatSession => "compat_session",
NodeType::CompatSsoLogin => "compat_sso_login",
NodeType::OAuth2Client => "oauth2_client",
NodeType::OAuth2Session => "oauth2_session",
NodeType::UpstreamOAuth2Provider => "upstream_oauth2_provider",
NodeType::UpstreamOAuth2Link => "upstream_oauth2_link",
NodeType::User => "user",
NodeType::UserEmail => "user_email",
}
}
fn from_prefix(prefix: &str) -> Option<Self> {
match prefix {
"authentication" => Some(NodeType::Authentication),
"browser_session" => Some(NodeType::BrowserSession),
"compat_session" => Some(NodeType::CompatSession),
"compat_sso_login" => Some(NodeType::CompatSsoLogin),
"oauth2_client" => Some(NodeType::OAuth2Client),
"oauth2_session" => Some(NodeType::OAuth2Session),
"upstream_oauth2_provider" => Some(NodeType::UpstreamOAuth2Provider),
"upstream_oauth2_link" => Some(NodeType::UpstreamOAuth2Link),
"user" => Some(NodeType::User),
"user_email" => Some(NodeType::UserEmail),
_ => None,
}
}
pub fn serialize(self, id: impl Into<Ulid>) -> String {
let prefix = self.to_prefix();
let id = id.into();
format!("{prefix}:{id}")
}
pub fn id(self, id: impl Into<Ulid>) -> ID {
ID(self.serialize(id))
}
pub fn deserialize(serialized: &str) -> Result<(Self, Ulid), InvalidID> {
let (prefix, id) = serialized.split_once(':').ok_or(InvalidID::InvalidFormat)?;
let prefix = NodeType::from_prefix(prefix).ok_or(InvalidID::UnknownPrefix)?;
let id = id.parse()?;
Ok((prefix, id))
}
pub fn from_id(id: &ID) -> Result<(Self, Ulid), InvalidID> {
Self::deserialize(&id.0)
}
pub fn extract_ulid(self, id: &ID) -> Result<Ulid, InvalidID> {
let (node_type, ulid) = Self::deserialize(&id.0)?;
if node_type == self {
Ok(ulid)
} else {
Err(InvalidID::TypeMismatch {
got: node_type,
expected: self,
})
}
}
}
/// An object with an ID.
#[derive(Interface)]
#[graphql(field(name = "id", desc = "ID of the object.", ty = "ID"))]
pub enum Node {
Anonymous(Box<Anonymous>),
Authentication(Box<Authentication>),
BrowserSession(Box<BrowserSession>),
CompatSession(Box<CompatSession>),
CompatSsoLogin(Box<CompatSsoLogin>),
OAuth2Client(Box<OAuth2Client>),
OAuth2Session(Box<OAuth2Session>),
SiteConfig(Box<SiteConfig>),
UpstreamOAuth2Provider(Box<UpstreamOAuth2Provider>),
UpstreamOAuth2Link(Box<UpstreamOAuth2Link>),
User(Box<User>),
UserEmail(Box<UserEmail>),
}

View File

@@ -0,0 +1,236 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::Context as _;
use async_graphql::{Context, Description, Enum, Object, ID};
use chrono::{DateTime, Utc};
use mas_storage::{oauth2::OAuth2ClientRepository, user::BrowserSessionRepository};
use oauth2_types::{oidc::ApplicationType, scope::Scope};
use ulid::Ulid;
use url::Url;
use super::{BrowserSession, NodeType, SessionState, User, UserAgent};
use crate::graphql::{state::ContextExt, UserId};
/// An OAuth 2.0 session represents a client session which used the OAuth APIs
/// to login.
#[derive(Description)]
pub struct OAuth2Session(pub mas_data_model::Session);
#[Object(use_type_description)]
impl OAuth2Session {
/// ID of the object.
pub async fn id(&self) -> ID {
NodeType::OAuth2Session.id(self.0.id)
}
/// OAuth 2.0 client used by this session.
pub async fn client(&self, ctx: &Context<'_>) -> Result<OAuth2Client, async_graphql::Error> {
let state = ctx.state();
let mut repo = state.repository().await?;
let client = repo
.oauth2_client()
.lookup(self.0.client_id)
.await?
.context("Could not load client")?;
repo.cancel().await?;
Ok(OAuth2Client(client))
}
/// Scope granted for this session.
pub async fn scope(&self) -> String {
self.0.scope.to_string()
}
/// When the object was created.
pub async fn created_at(&self) -> DateTime<Utc> {
self.0.created_at
}
/// When the session ended.
pub async fn finished_at(&self) -> Option<DateTime<Utc>> {
match &self.0.state {
mas_data_model::SessionState::Valid => None,
mas_data_model::SessionState::Finished { finished_at } => Some(*finished_at),
}
}
/// The user-agent with which the session was created.
pub async fn user_agent(&self) -> Option<UserAgent> {
self.0.user_agent.clone().map(UserAgent::from)
}
/// The state of the session.
pub async fn state(&self) -> SessionState {
match &self.0.state {
mas_data_model::SessionState::Valid => SessionState::Active,
mas_data_model::SessionState::Finished { .. } => SessionState::Finished,
}
}
/// The browser session which started this OAuth 2.0 session.
pub async fn browser_session(
&self,
ctx: &Context<'_>,
) -> Result<Option<BrowserSession>, async_graphql::Error> {
let Some(user_session_id) = self.0.user_session_id else {
return Ok(None);
};
let state = ctx.state();
let mut repo = state.repository().await?;
let browser_session = repo
.browser_session()
.lookup(user_session_id)
.await?
.context("Could not load browser session")?;
repo.cancel().await?;
Ok(Some(BrowserSession(browser_session)))
}
/// User authorized for this session.
pub async fn user(&self, ctx: &Context<'_>) -> Result<Option<User>, async_graphql::Error> {
let state = ctx.state();
let Some(user_id) = self.0.user_id else {
return Ok(None);
};
if !ctx.requester().is_owner_or_admin(&UserId(user_id)) {
return Err(async_graphql::Error::new("Unauthorized"));
}
let mut repo = state.repository().await?;
let user = repo
.user()
.lookup(user_id)
.await?
.context("Could not load user")?;
repo.cancel().await?;
Ok(Some(User(user)))
}
/// The last IP address used by the session.
pub async fn last_active_ip(&self) -> Option<String> {
self.0.last_active_ip.map(|ip| ip.to_string())
}
/// The last time the session was active.
pub async fn last_active_at(&self) -> Option<DateTime<Utc>> {
self.0.last_active_at
}
}
/// The application type advertised by the client.
#[derive(Enum, Copy, Clone, Eq, PartialEq)]
pub enum OAuth2ApplicationType {
/// Client is a web application.
Web,
/// Client is a native application.
Native,
}
/// An OAuth 2.0 client
#[derive(Description)]
pub struct OAuth2Client(pub mas_data_model::Client);
#[Object(use_type_description)]
impl OAuth2Client {
/// ID of the object.
pub async fn id(&self) -> ID {
NodeType::OAuth2Client.id(self.0.id)
}
/// OAuth 2.0 client ID
pub async fn client_id(&self) -> &str {
&self.0.client_id
}
/// Client name advertised by the client.
pub async fn client_name(&self) -> Option<&str> {
self.0.client_name.as_deref()
}
/// Client URI advertised by the client.
pub async fn client_uri(&self) -> Option<&Url> {
self.0.client_uri.as_ref()
}
/// Logo URI advertised by the client.
pub async fn logo_uri(&self) -> Option<&Url> {
self.0.logo_uri.as_ref()
}
/// Terms of services URI advertised by the client.
pub async fn tos_uri(&self) -> Option<&Url> {
self.0.tos_uri.as_ref()
}
/// Privacy policy URI advertised by the client.
pub async fn policy_uri(&self) -> Option<&Url> {
self.0.policy_uri.as_ref()
}
/// List of redirect URIs used for authorization grants by the client.
pub async fn redirect_uris(&self) -> &[Url] {
&self.0.redirect_uris
}
/// List of contacts advertised by the client.
pub async fn contacts(&self) -> &[String] {
&self.0.contacts
}
/// The application type advertised by the client.
pub async fn application_type(&self) -> Option<OAuth2ApplicationType> {
match self.0.application_type.as_ref()? {
ApplicationType::Web => Some(OAuth2ApplicationType::Web),
ApplicationType::Native => Some(OAuth2ApplicationType::Native),
ApplicationType::Unknown(_) => None,
}
}
}
/// An OAuth 2.0 consent represents the scope a user consented to grant to a
/// client.
#[derive(Description)]
pub struct OAuth2Consent {
scope: Scope,
client_id: Ulid,
}
#[Object(use_type_description)]
impl OAuth2Consent {
/// Scope consented by the user for this client.
pub async fn scope(&self) -> String {
self.scope.to_string()
}
/// OAuth 2.0 client for which the user granted access.
pub async fn client(&self, ctx: &Context<'_>) -> Result<OAuth2Client, async_graphql::Error> {
let state = ctx.state();
let mut repo = state.repository().await?;
let client = repo
.oauth2_client()
.lookup(self.client_id)
.await?
.context("Could not load client")?;
repo.cancel().await?;
Ok(OAuth2Client(client))
}
}

View File

@@ -0,0 +1,65 @@
// Copyright 2024 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#![allow(clippy::str_to_string)] // ComplexObject macro uses &str.to_string()
use async_graphql::{ComplexObject, SimpleObject, ID};
use url::Url;
pub const SITE_CONFIG_ID: &str = "site_config";
#[derive(SimpleObject)]
#[graphql(complex)]
pub struct SiteConfig {
/// The server name of the homeserver.
server_name: String,
/// The URL to the privacy policy.
policy_uri: Option<Url>,
/// The URL to the terms of service.
tos_uri: Option<Url>,
/// Imprint to show in the footer.
imprint: Option<String>,
/// Whether user can change their email.
email_change_allowed: bool,
/// Whether user can change their display name.
display_name_change_allowed: bool,
}
#[ComplexObject]
impl SiteConfig {
/// The ID of the site configuration.
pub async fn id(&self) -> ID {
SITE_CONFIG_ID.into()
}
}
impl SiteConfig {
/// Create a new [`SiteConfig`] from the data model
/// [`mas_data_model:::SiteConfig`].
pub fn new(data_model: &mas_data_model::SiteConfig) -> Self {
Self {
server_name: data_model.server_name.clone(),
policy_uri: data_model.policy_uri.clone(),
tos_uri: data_model.tos_uri.clone(),
imprint: data_model.imprint.clone(),
email_change_allowed: data_model.email_change_allowed,
display_name_change_allowed: data_model.displayname_change_allowed,
}
}
}

View File

@@ -0,0 +1,143 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::Context as _;
use async_graphql::{Context, Object, ID};
use chrono::{DateTime, Utc};
use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, user::UserRepository};
use super::{NodeType, User};
use crate::graphql::state::ContextExt;
#[derive(Debug, Clone)]
pub struct UpstreamOAuth2Provider {
provider: mas_data_model::UpstreamOAuthProvider,
}
impl UpstreamOAuth2Provider {
#[must_use]
pub const fn new(provider: mas_data_model::UpstreamOAuthProvider) -> Self {
Self { provider }
}
}
#[Object]
impl UpstreamOAuth2Provider {
/// ID of the object.
pub async fn id(&self) -> ID {
NodeType::UpstreamOAuth2Provider.id(self.provider.id)
}
/// When the object was created.
pub async fn created_at(&self) -> DateTime<Utc> {
self.provider.created_at
}
/// OpenID Connect issuer URL.
pub async fn issuer(&self) -> &str {
&self.provider.issuer
}
/// Client ID used for this provider.
pub async fn client_id(&self) -> &str {
&self.provider.client_id
}
}
impl UpstreamOAuth2Link {
#[must_use]
pub const fn new(link: mas_data_model::UpstreamOAuthLink) -> Self {
Self {
link,
provider: None,
user: None,
}
}
}
#[derive(Debug, Clone)]
pub struct UpstreamOAuth2Link {
link: mas_data_model::UpstreamOAuthLink,
provider: Option<mas_data_model::UpstreamOAuthProvider>,
user: Option<mas_data_model::User>,
}
#[Object]
impl UpstreamOAuth2Link {
/// ID of the object.
pub async fn id(&self) -> ID {
NodeType::UpstreamOAuth2Link.id(self.link.id)
}
/// When the object was created.
pub async fn created_at(&self) -> DateTime<Utc> {
self.link.created_at
}
/// Subject used for linking
pub async fn subject(&self) -> &str {
&self.link.subject
}
/// The provider for which this link is.
pub async fn provider(
&self,
ctx: &Context<'_>,
) -> Result<UpstreamOAuth2Provider, async_graphql::Error> {
let state = ctx.state();
let provider = if let Some(provider) = &self.provider {
// Cached
provider.clone()
} else {
// Fetch on-the-fly
let mut repo = state.repository().await?;
let provider = repo
.upstream_oauth_provider()
.lookup(self.link.provider_id)
.await?
.context("Upstream OAuth 2.0 provider not found")?;
repo.cancel().await?;
provider
};
Ok(UpstreamOAuth2Provider::new(provider))
}
/// The user to which this link is associated.
pub async fn user(&self, ctx: &Context<'_>) -> Result<Option<User>, async_graphql::Error> {
let state = ctx.state();
let user = if let Some(user) = &self.user {
// Cached
user.clone()
} else if let Some(user_id) = &self.link.user_id {
// Fetch on-the-fly
let mut repo = state.repository().await?;
let user = repo
.user()
.lookup(*user_id)
.await?
.context("User not found")?;
repo.cancel().await?;
user
} else {
return Ok(None);
};
Ok(Some(User(user)))
}
}

View File

@@ -0,0 +1,711 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::Context as _;
use async_graphql::{
connection::{query, Connection, Edge, OpaqueCursor},
Context, Description, Enum, Object, Union, ID,
};
use chrono::{DateTime, Utc};
use mas_data_model::Device;
use mas_storage::{
app_session::AppSessionFilter,
compat::{CompatSessionFilter, CompatSsoLoginFilter, CompatSsoLoginRepository},
oauth2::{OAuth2SessionFilter, OAuth2SessionRepository},
upstream_oauth2::{UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository},
user::{BrowserSessionFilter, BrowserSessionRepository, UserEmailFilter, UserEmailRepository},
Pagination, RepositoryAccess,
};
use super::{
compat_sessions::{CompatSessionType, CompatSsoLogin},
matrix::MatrixUser,
BrowserSession, CompatSession, Cursor, NodeCursor, NodeType, OAuth2Session,
PreloadedTotalCount, SessionState, UpstreamOAuth2Link,
};
use crate::graphql::state::ContextExt;
#[derive(Description)]
/// A user is an individual's account.
pub struct User(pub mas_data_model::User);
impl From<mas_data_model::User> for User {
fn from(v: mas_data_model::User) -> Self {
Self(v)
}
}
impl From<mas_data_model::BrowserSession> for User {
fn from(v: mas_data_model::BrowserSession) -> Self {
Self(v.user)
}
}
#[Object(use_type_description)]
impl User {
/// ID of the object.
pub async fn id(&self) -> ID {
NodeType::User.id(self.0.id)
}
/// Username chosen by the user.
async fn username(&self) -> &str {
&self.0.username
}
/// When the object was created.
pub async fn created_at(&self) -> DateTime<Utc> {
self.0.created_at
}
/// When the user was locked out.
pub async fn locked_at(&self) -> Option<DateTime<Utc>> {
self.0.locked_at
}
/// Whether the user can request admin privileges.
pub async fn can_request_admin(&self) -> bool {
self.0.can_request_admin
}
/// Access to the user's Matrix account information.
async fn matrix(&self, ctx: &Context<'_>) -> Result<MatrixUser, async_graphql::Error> {
let state = ctx.state();
let conn = state.homeserver_connection();
Ok(MatrixUser::load(conn, &self.0.username).await?)
}
/// Primary email address of the user.
async fn primary_email(
&self,
ctx: &Context<'_>,
) -> Result<Option<UserEmail>, async_graphql::Error> {
let state = ctx.state();
let mut repo = state.repository().await?;
let user_email = repo.user_email().get_primary(&self.0).await?.map(UserEmail);
repo.cancel().await?;
Ok(user_email)
}
/// Get the list of compatibility SSO logins, chronologically sorted
async fn compat_sso_logins(
&self,
ctx: &Context<'_>,
#[graphql(desc = "Returns the elements in the list that come after the cursor.")]
after: Option<String>,
#[graphql(desc = "Returns the elements in the list that come before the cursor.")]
before: Option<String>,
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, CompatSsoLogin, PreloadedTotalCount>, async_graphql::Error> {
let state = ctx.state();
let mut repo = state.repository().await?;
query(
after,
before,
first,
last,
|after, before, first, last| async move {
let after_id = after
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::CompatSsoLogin))
.transpose()?;
let before_id = before
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::CompatSsoLogin))
.transpose()?;
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
let filter = CompatSsoLoginFilter::new().for_user(&self.0);
let page = repo.compat_sso_login().list(filter, pagination).await?;
// Preload the total count if requested
let count = if ctx.look_ahead().field("totalCount").exists() {
Some(repo.compat_sso_login().count(filter).await?)
} else {
None
};
repo.cancel().await?;
let mut connection = Connection::with_additional_fields(
page.has_previous_page,
page.has_next_page,
PreloadedTotalCount(count),
);
connection.edges.extend(page.edges.into_iter().map(|u| {
Edge::new(
OpaqueCursor(NodeCursor(NodeType::CompatSsoLogin, u.id)),
CompatSsoLogin(u),
)
}));
Ok::<_, async_graphql::Error>(connection)
},
)
.await
}
/// Get the list of compatibility sessions, chronologically sorted
#[allow(clippy::too_many_arguments)]
async fn compat_sessions(
&self,
ctx: &Context<'_>,
#[graphql(name = "state", desc = "List only sessions with the given state.")]
state_param: Option<SessionState>,
#[graphql(name = "type", desc = "List only sessions with the given type.")]
type_param: Option<CompatSessionType>,
#[graphql(desc = "Returns the elements in the list that come after the cursor.")]
after: Option<String>,
#[graphql(desc = "Returns the elements in the list that come before the cursor.")]
before: Option<String>,
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, CompatSession, PreloadedTotalCount>, async_graphql::Error> {
let state = ctx.state();
let mut repo = state.repository().await?;
query(
after,
before,
first,
last,
|after, before, first, last| async move {
let after_id = after
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::CompatSession))
.transpose()?;
let before_id = before
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::CompatSession))
.transpose()?;
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
// Build the query filter
let filter = CompatSessionFilter::new().for_user(&self.0);
let filter = match state_param {
Some(SessionState::Active) => filter.active_only(),
Some(SessionState::Finished) => filter.finished_only(),
None => filter,
};
let filter = match type_param {
Some(CompatSessionType::SsoLogin) => filter.sso_login_only(),
Some(CompatSessionType::Unknown) => filter.unknown_only(),
None => filter,
};
let page = repo.compat_session().list(filter, pagination).await?;
// Preload the total count if requested
let count = if ctx.look_ahead().field("totalCount").exists() {
Some(repo.compat_session().count(filter).await?)
} else {
None
};
repo.cancel().await?;
let mut connection = Connection::with_additional_fields(
page.has_previous_page,
page.has_next_page,
PreloadedTotalCount(count),
);
connection
.edges
.extend(page.edges.into_iter().map(|(session, sso_login)| {
Edge::new(
OpaqueCursor(NodeCursor(NodeType::CompatSession, session.id)),
CompatSession::new(session).with_loaded_sso_login(sso_login),
)
}));
Ok::<_, async_graphql::Error>(connection)
},
)
.await
}
/// Get the list of active browser sessions, chronologically sorted
async fn browser_sessions(
&self,
ctx: &Context<'_>,
#[graphql(name = "state", desc = "List only sessions in the given state.")]
state_param: Option<SessionState>,
#[graphql(desc = "Returns the elements in the list that come after the cursor.")]
after: Option<String>,
#[graphql(desc = "Returns the elements in the list that come before the cursor.")]
before: Option<String>,
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, BrowserSession, PreloadedTotalCount>, async_graphql::Error> {
let state = ctx.state();
let mut repo = state.repository().await?;
query(
after,
before,
first,
last,
|after, before, first, last| async move {
let after_id = after
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::BrowserSession))
.transpose()?;
let before_id = before
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::BrowserSession))
.transpose()?;
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
let filter = BrowserSessionFilter::new().for_user(&self.0);
let filter = match state_param {
Some(SessionState::Active) => filter.active_only(),
Some(SessionState::Finished) => filter.finished_only(),
None => filter,
};
let page = repo.browser_session().list(filter, pagination).await?;
// Preload the total count if requested
let count = if ctx.look_ahead().field("totalCount").exists() {
Some(repo.browser_session().count(filter).await?)
} else {
None
};
repo.cancel().await?;
let mut connection = Connection::with_additional_fields(
page.has_previous_page,
page.has_next_page,
PreloadedTotalCount(count),
);
connection.edges.extend(page.edges.into_iter().map(|u| {
Edge::new(
OpaqueCursor(NodeCursor(NodeType::BrowserSession, u.id)),
BrowserSession(u),
)
}));
Ok::<_, async_graphql::Error>(connection)
},
)
.await
}
/// Get the list of emails, chronologically sorted
async fn emails(
&self,
ctx: &Context<'_>,
#[graphql(name = "state", desc = "List only emails in the given state.")]
state_param: Option<UserEmailState>,
#[graphql(desc = "Returns the elements in the list that come after the cursor.")]
after: Option<String>,
#[graphql(desc = "Returns the elements in the list that come before the cursor.")]
before: Option<String>,
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, UserEmail, PreloadedTotalCount>, async_graphql::Error> {
let state = ctx.state();
let mut repo = state.repository().await?;
query(
after,
before,
first,
last,
|after, before, first, last| async move {
let after_id = after
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::UserEmail))
.transpose()?;
let before_id = before
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::UserEmail))
.transpose()?;
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
let filter = UserEmailFilter::new().for_user(&self.0);
let filter = match state_param {
Some(UserEmailState::Pending) => filter.pending_only(),
Some(UserEmailState::Confirmed) => filter.verified_only(),
None => filter,
};
let page = repo.user_email().list(filter, pagination).await?;
// Preload the total count if requested
let count = if ctx.look_ahead().field("totalCount").exists() {
Some(repo.user_email().count(filter).await?)
} else {
None
};
repo.cancel().await?;
let mut connection = Connection::with_additional_fields(
page.has_previous_page,
page.has_next_page,
PreloadedTotalCount(count),
);
connection.edges.extend(page.edges.into_iter().map(|u| {
Edge::new(
OpaqueCursor(NodeCursor(NodeType::UserEmail, u.id)),
UserEmail(u),
)
}));
Ok::<_, async_graphql::Error>(connection)
},
)
.await
}
/// Get the list of OAuth 2.0 sessions, chronologically sorted
#[allow(clippy::too_many_arguments)]
async fn oauth2_sessions(
&self,
ctx: &Context<'_>,
#[graphql(name = "state", desc = "List only sessions in the given state.")]
state_param: Option<SessionState>,
#[graphql(desc = "List only sessions for the given client.")] client: Option<ID>,
#[graphql(desc = "Returns the elements in the list that come after the cursor.")]
after: Option<String>,
#[graphql(desc = "Returns the elements in the list that come before the cursor.")]
before: Option<String>,
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, OAuth2Session, PreloadedTotalCount>, async_graphql::Error> {
let state = ctx.state();
let mut repo = state.repository().await?;
query(
after,
before,
first,
last,
|after, before, first, last| async move {
let after_id = after
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::OAuth2Session))
.transpose()?;
let before_id = before
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::OAuth2Session))
.transpose()?;
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
let client = if let Some(id) = client {
// Load the client if we're filtering by it
let id = NodeType::OAuth2Client.extract_ulid(&id)?;
let client = repo
.oauth2_client()
.lookup(id)
.await?
.ok_or(async_graphql::Error::new("Unknown client ID"))?;
Some(client)
} else {
None
};
let filter = OAuth2SessionFilter::new().for_user(&self.0);
let filter = match state_param {
Some(SessionState::Active) => filter.active_only(),
Some(SessionState::Finished) => filter.finished_only(),
None => filter,
};
let filter = match client.as_ref() {
Some(client) => filter.for_client(client),
None => filter,
};
let page = repo.oauth2_session().list(filter, pagination).await?;
let count = if ctx.look_ahead().field("totalCount").exists() {
Some(repo.oauth2_session().count(filter).await?)
} else {
None
};
repo.cancel().await?;
let mut connection = Connection::with_additional_fields(
page.has_previous_page,
page.has_next_page,
PreloadedTotalCount(count),
);
connection.edges.extend(page.edges.into_iter().map(|s| {
Edge::new(
OpaqueCursor(NodeCursor(NodeType::OAuth2Session, s.id)),
OAuth2Session(s),
)
}));
Ok::<_, async_graphql::Error>(connection)
},
)
.await
}
/// Get the list of upstream OAuth 2.0 links
async fn upstream_oauth2_links(
&self,
ctx: &Context<'_>,
#[graphql(desc = "Returns the elements in the list that come after the cursor.")]
after: Option<String>,
#[graphql(desc = "Returns the elements in the list that come before the cursor.")]
before: Option<String>,
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, UpstreamOAuth2Link, PreloadedTotalCount>, async_graphql::Error>
{
let state = ctx.state();
let mut repo = state.repository().await?;
query(
after,
before,
first,
last,
|after, before, first, last| async move {
let after_id = after
.map(|x: OpaqueCursor<NodeCursor>| {
x.extract_for_type(NodeType::UpstreamOAuth2Link)
})
.transpose()?;
let before_id = before
.map(|x: OpaqueCursor<NodeCursor>| {
x.extract_for_type(NodeType::UpstreamOAuth2Link)
})
.transpose()?;
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
let filter = UpstreamOAuthLinkFilter::new()
.for_user(&self.0)
.enabled_providers_only();
let page = repo.upstream_oauth_link().list(filter, pagination).await?;
// Preload the total count if requested
let count = if ctx.look_ahead().field("totalCount").exists() {
Some(repo.upstream_oauth_link().count(filter).await?)
} else {
None
};
repo.cancel().await?;
let mut connection = Connection::with_additional_fields(
page.has_previous_page,
page.has_next_page,
PreloadedTotalCount(count),
);
connection.edges.extend(page.edges.into_iter().map(|s| {
Edge::new(
OpaqueCursor(NodeCursor(NodeType::UpstreamOAuth2Link, s.id)),
UpstreamOAuth2Link::new(s),
)
}));
Ok::<_, async_graphql::Error>(connection)
},
)
.await
}
/// Get the list of both compat and OAuth 2.0 sessions, chronologically
/// sorted
#[allow(clippy::too_many_arguments)]
async fn app_sessions(
&self,
ctx: &Context<'_>,
#[graphql(name = "state", desc = "List only sessions in the given state.")]
state_param: Option<SessionState>,
#[graphql(name = "device", desc = "List only sessions for the given device.")]
device_param: Option<String>,
#[graphql(
name = "browserSession",
desc = "List only sessions for the given session."
)]
browser_session_param: Option<ID>,
#[graphql(desc = "Returns the elements in the list that come after the cursor.")]
after: Option<String>,
#[graphql(desc = "Returns the elements in the list that come before the cursor.")]
before: Option<String>,
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, AppSession, PreloadedTotalCount>, async_graphql::Error> {
let state = ctx.state();
let requester = ctx.requester();
let mut repo = state.repository().await?;
query(
after,
before,
first,
last,
|after, before, first, last| async move {
let after_id = after
.map(|x: OpaqueCursor<NodeCursor>| {
x.extract_for_types(&[NodeType::OAuth2Session, NodeType::CompatSession])
})
.transpose()?;
let before_id = before
.map(|x: OpaqueCursor<NodeCursor>| {
x.extract_for_types(&[NodeType::OAuth2Session, NodeType::CompatSession])
})
.transpose()?;
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
let device_param = device_param.map(Device::try_from).transpose()?;
let filter = AppSessionFilter::new().for_user(&self.0);
let filter = match state_param {
Some(SessionState::Active) => filter.active_only(),
Some(SessionState::Finished) => filter.finished_only(),
None => filter,
};
let filter = match device_param.as_ref() {
Some(device) => filter.for_device(device),
None => filter,
};
let maybe_session = match browser_session_param {
Some(id) => {
// This might fail, but we're probably alright with it
let id = NodeType::BrowserSession
.extract_ulid(&id)
.context("Invalid browser_session parameter")?;
let Some(session) = repo
.browser_session()
.lookup(id)
.await?
.filter(|u| requester.is_owner_or_admin(u))
else {
// If we couldn't find the session or if the requester can't access it,
// return an empty list
return Ok(Connection::with_additional_fields(
false,
false,
PreloadedTotalCount(Some(0)),
));
};
Some(session)
}
None => None,
};
let filter = match maybe_session {
Some(ref session) => filter.for_browser_session(session),
None => filter,
};
let page = repo.app_session().list(filter, pagination).await?;
let count = if ctx.look_ahead().field("totalCount").exists() {
Some(repo.app_session().count(filter).await?)
} else {
None
};
repo.cancel().await?;
let mut connection = Connection::with_additional_fields(
page.has_previous_page,
page.has_next_page,
PreloadedTotalCount(count),
);
connection
.edges
.extend(page.edges.into_iter().map(|s| match s {
mas_storage::app_session::AppSession::Compat(session) => Edge::new(
OpaqueCursor(NodeCursor(NodeType::CompatSession, session.id)),
AppSession::CompatSession(Box::new(CompatSession::new(*session))),
),
mas_storage::app_session::AppSession::OAuth2(session) => Edge::new(
OpaqueCursor(NodeCursor(NodeType::OAuth2Session, session.id)),
AppSession::OAuth2Session(Box::new(OAuth2Session(*session))),
),
}));
Ok::<_, async_graphql::Error>(connection)
},
)
.await
}
}
/// A session in an application, either a compatibility or an OAuth 2.0 one
#[derive(Union)]
pub enum AppSession {
CompatSession(Box<CompatSession>),
OAuth2Session(Box<OAuth2Session>),
}
/// A user email address
#[derive(Description)]
pub struct UserEmail(pub mas_data_model::UserEmail);
#[Object(use_type_description)]
impl UserEmail {
/// ID of the object.
pub async fn id(&self) -> ID {
NodeType::UserEmail.id(self.0.id)
}
/// Email address
async fn email(&self) -> &str {
&self.0.email
}
/// When the object was created.
pub async fn created_at(&self) -> DateTime<Utc> {
self.0.created_at
}
/// When the email address was confirmed. Is `null` if the email was never
/// verified by the user.
async fn confirmed_at(&self) -> Option<DateTime<Utc>> {
self.0.confirmed_at
}
}
/// The state of a compatibility session.
#[derive(Enum, Copy, Clone, Eq, PartialEq)]
pub enum UserEmailState {
/// The email address is pending confirmation.
Pending,
/// The email address has been confirmed.
Confirmed,
}

View File

@@ -0,0 +1,26 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_graphql::{Object, ID};
/// An anonymous viewer
#[derive(Default, Clone, Copy)]
pub struct Anonymous;
#[Object]
impl Anonymous {
pub async fn id(&self) -> ID {
"anonymous".into()
}
}

View File

@@ -0,0 +1,59 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_graphql::Union;
use crate::graphql::model::{BrowserSession, OAuth2Session, User};
mod anonymous;
pub use self::anonymous::Anonymous;
/// Represents the current viewer
#[derive(Union)]
pub enum Viewer {
User(User),
Anonymous(Anonymous),
}
impl Viewer {
pub fn user(user: mas_data_model::User) -> Self {
Self::User(User(user))
}
pub fn anonymous() -> Self {
Self::Anonymous(Anonymous)
}
}
/// Represents the current viewer's session
#[derive(Union)]
pub enum ViewerSession {
BrowserSession(Box<BrowserSession>),
OAuth2Session(Box<OAuth2Session>),
Anonymous(Anonymous),
}
impl ViewerSession {
pub fn browser_session(session: mas_data_model::BrowserSession) -> Self {
Self::BrowserSession(Box::new(BrowserSession(session)))
}
pub fn oauth2_session(session: mas_data_model::Session) -> Self {
Self::OAuth2Session(Box::new(OAuth2Session(session)))
}
pub fn anonymous() -> Self {
Self::Anonymous(Anonymous)
}
}

View File

@@ -0,0 +1,101 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_graphql::{Context, Enum, InputObject, Object, ID};
use mas_storage::RepositoryAccess;
use crate::graphql::{
model::{BrowserSession, NodeType},
state::ContextExt,
};
#[derive(Default)]
pub struct BrowserSessionMutations {
_private: (),
}
/// The input of the `endBrowserSession` mutation.
#[derive(InputObject)]
pub struct EndBrowserSessionInput {
/// The ID of the session to end.
browser_session_id: ID,
}
/// The payload of the `endBrowserSession` mutation.
pub enum EndBrowserSessionPayload {
NotFound,
Ended(Box<mas_data_model::BrowserSession>),
}
/// The status of the `endBrowserSession` mutation.
#[derive(Enum, Copy, Clone, PartialEq, Eq, Debug)]
enum EndBrowserSessionStatus {
/// The session was ended.
Ended,
/// The session was not found.
NotFound,
}
#[Object]
impl EndBrowserSessionPayload {
/// The status of the mutation.
async fn status(&self) -> EndBrowserSessionStatus {
match self {
Self::Ended(_) => EndBrowserSessionStatus::Ended,
Self::NotFound => EndBrowserSessionStatus::NotFound,
}
}
/// Returns the ended session.
async fn browser_session(&self) -> Option<BrowserSession> {
match self {
Self::Ended(session) => Some(BrowserSession(*session.clone())),
Self::NotFound => None,
}
}
}
#[Object]
impl BrowserSessionMutations {
async fn end_browser_session(
&self,
ctx: &Context<'_>,
input: EndBrowserSessionInput,
) -> Result<EndBrowserSessionPayload, async_graphql::Error> {
let state = ctx.state();
let browser_session_id =
NodeType::BrowserSession.extract_ulid(&input.browser_session_id)?;
let requester = ctx.requester();
let mut repo = state.repository().await?;
let clock = state.clock();
let session = repo.browser_session().lookup(browser_session_id).await?;
let Some(session) = session else {
return Ok(EndBrowserSessionPayload::NotFound);
};
if !requester.is_owner_or_admin(&session) {
return Ok(EndBrowserSessionPayload::NotFound);
}
let session = repo.browser_session().finish(&clock, session).await?;
repo.save().await?;
Ok(EndBrowserSessionPayload::Ended(Box::new(session)))
}
}

View File

@@ -0,0 +1,115 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::Context as _;
use async_graphql::{Context, Enum, InputObject, Object, ID};
use mas_storage::{
compat::CompatSessionRepository,
job::{DeleteDeviceJob, JobRepositoryExt},
RepositoryAccess,
};
use crate::graphql::{
model::{CompatSession, NodeType},
state::ContextExt,
};
#[derive(Default)]
pub struct CompatSessionMutations {
_private: (),
}
/// The input of the `endCompatSession` mutation.
#[derive(InputObject)]
pub struct EndCompatSessionInput {
/// The ID of the session to end.
compat_session_id: ID,
}
/// The payload of the `endCompatSession` mutation.
pub enum EndCompatSessionPayload {
NotFound,
Ended(Box<mas_data_model::CompatSession>),
}
/// The status of the `endCompatSession` mutation.
#[derive(Enum, Copy, Clone, PartialEq, Eq, Debug)]
enum EndCompatSessionStatus {
/// The session was ended.
Ended,
/// The session was not found.
NotFound,
}
#[Object]
impl EndCompatSessionPayload {
/// The status of the mutation.
async fn status(&self) -> EndCompatSessionStatus {
match self {
Self::Ended(_) => EndCompatSessionStatus::Ended,
Self::NotFound => EndCompatSessionStatus::NotFound,
}
}
/// Returns the ended session.
async fn compat_session(&self) -> Option<CompatSession> {
match self {
Self::Ended(session) => Some(CompatSession::new(*session.clone())),
Self::NotFound => None,
}
}
}
#[Object]
impl CompatSessionMutations {
async fn end_compat_session(
&self,
ctx: &Context<'_>,
input: EndCompatSessionInput,
) -> Result<EndCompatSessionPayload, async_graphql::Error> {
let state = ctx.state();
let compat_session_id = NodeType::CompatSession.extract_ulid(&input.compat_session_id)?;
let requester = ctx.requester();
let mut repo = state.repository().await?;
let clock = state.clock();
let session = repo.compat_session().lookup(compat_session_id).await?;
let Some(session) = session else {
return Ok(EndCompatSessionPayload::NotFound);
};
if !requester.is_owner_or_admin(&session) {
return Ok(EndCompatSessionPayload::NotFound);
}
let user = repo
.user()
.lookup(session.user_id)
.await?
.context("Could not load user")?;
// Schedule a job to delete the device.
repo.job()
.schedule_job(DeleteDeviceJob::new(&user, &session.device))
.await?;
let session = repo.compat_session().finish(&clock, session).await?;
repo.save().await?;
Ok(EndCompatSessionPayload::Ended(Box::new(session)))
}
}

View File

@@ -0,0 +1,127 @@
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::Context as _;
use async_graphql::{Context, Description, Enum, InputObject, Object, ID};
use crate::graphql::{
model::{NodeType, User},
state::ContextExt,
UserId,
};
#[derive(Default)]
pub struct MatrixMutations {
_private: (),
}
/// The input for the `addEmail` mutation
#[derive(InputObject)]
struct SetDisplayNameInput {
/// The ID of the user to add the email address to
user_id: ID,
/// The display name to set. If `None`, the display name will be removed.
display_name: Option<String>,
}
/// The status of the `setDisplayName` mutation
#[derive(Enum, Copy, Clone, Eq, PartialEq)]
pub enum SetDisplayNameStatus {
/// The display name was set
Set,
/// The display name is invalid
Invalid,
}
/// The payload of the `setDisplayName` mutation
#[derive(Description)]
enum SetDisplayNamePayload {
Set(User),
Invalid,
}
#[Object(use_type_description)]
impl SetDisplayNamePayload {
/// Status of the operation
async fn status(&self) -> SetDisplayNameStatus {
match self {
SetDisplayNamePayload::Set(_) => SetDisplayNameStatus::Set,
SetDisplayNamePayload::Invalid => SetDisplayNameStatus::Invalid,
}
}
/// The user that was updated
async fn user(&self) -> Option<&User> {
match self {
SetDisplayNamePayload::Set(user) => Some(user),
SetDisplayNamePayload::Invalid => None,
}
}
}
#[Object]
impl MatrixMutations {
/// Set the display name of a user
async fn set_display_name(
&self,
ctx: &Context<'_>,
input: SetDisplayNameInput,
) -> Result<SetDisplayNamePayload, async_graphql::Error> {
let state = ctx.state();
let id = NodeType::User.extract_ulid(&input.user_id)?;
let requester = ctx.requester();
if !requester.is_owner_or_admin(&UserId(id)) {
return Err(async_graphql::Error::new("Unauthorized"));
}
// Allow non-admins to change their display name if the site config allows it
if !requester.is_admin() && !state.site_config().displayname_change_allowed {
return Err(async_graphql::Error::new("Unauthorized"));
}
let mut repo = state.repository().await?;
let user = repo
.user()
.lookup(id)
.await?
.context("Failed to lookup user")?;
repo.cancel().await?;
let conn = state.homeserver_connection();
let mxid = conn.mxid(&user.username);
if let Some(display_name) = &input.display_name {
// Let's do some basic validation on the display name
if display_name.len() > 256 {
return Ok(SetDisplayNamePayload::Invalid);
}
if display_name.is_empty() {
return Ok(SetDisplayNamePayload::Invalid);
}
conn.set_displayname(&mxid, display_name)
.await
.context("Failed to set display name")?;
} else {
conn.unset_displayname(&mxid)
.await
.context("Failed to unset display name")?;
}
Ok(SetDisplayNamePayload::Set(User(user.clone())))
}
}

View File

@@ -0,0 +1,40 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
mod browser_session;
mod compat_session;
mod matrix;
mod oauth2_session;
mod user;
mod user_email;
use async_graphql::MergedObject;
/// The mutations root of the GraphQL interface.
#[derive(Default, MergedObject)]
pub struct Mutation(
user_email::UserEmailMutations,
user::UserMutations,
oauth2_session::OAuth2SessionMutations,
compat_session::CompatSessionMutations,
browser_session::BrowserSessionMutations,
matrix::MatrixMutations,
);
impl Mutation {
#[must_use]
pub fn new() -> Self {
Self::default()
}
}

View File

@@ -0,0 +1,261 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::Context as _;
use async_graphql::{Context, Description, Enum, InputObject, Object, ID};
use chrono::Duration;
use mas_data_model::{Device, TokenType};
use mas_storage::{
job::{DeleteDeviceJob, JobRepositoryExt, ProvisionDeviceJob},
oauth2::{
OAuth2AccessTokenRepository, OAuth2ClientRepository, OAuth2RefreshTokenRepository,
OAuth2SessionRepository,
},
user::UserRepository,
RepositoryAccess,
};
use oauth2_types::scope::Scope;
use crate::graphql::{
model::{NodeType, OAuth2Session},
state::ContextExt,
};
#[derive(Default)]
pub struct OAuth2SessionMutations {
_private: (),
}
/// The input of the `createOauth2Session` mutation.
#[derive(InputObject)]
pub struct CreateOAuth2SessionInput {
/// The scope of the session
scope: String,
/// The ID of the user for which to create the session
user_id: ID,
/// Whether the session should issue a never-expiring access token
permanent: Option<bool>,
}
/// The payload of the `createOauth2Session` mutation.
#[derive(Description)]
pub struct CreateOAuth2SessionPayload {
access_token: String,
refresh_token: Option<String>,
session: mas_data_model::Session,
}
#[Object(use_type_description)]
impl CreateOAuth2SessionPayload {
/// Access token for this session
pub async fn access_token(&self) -> &str {
&self.access_token
}
/// Refresh token for this session, if it is not a permanent session
pub async fn refresh_token(&self) -> Option<&str> {
self.refresh_token.as_deref()
}
/// The OAuth 2.0 session which was just created
pub async fn oauth2_session(&self) -> OAuth2Session {
OAuth2Session(self.session.clone())
}
}
/// The input of the `endOauth2Session` mutation.
#[derive(InputObject)]
pub struct EndOAuth2SessionInput {
/// The ID of the session to end.
oauth2_session_id: ID,
}
/// The payload of the `endOauth2Session` mutation.
pub enum EndOAuth2SessionPayload {
NotFound,
Ended(mas_data_model::Session),
}
/// The status of the `endOauth2Session` mutation.
#[derive(Enum, Copy, Clone, PartialEq, Eq, Debug)]
enum EndOAuth2SessionStatus {
/// The session was ended.
Ended,
/// The session was not found.
NotFound,
}
#[Object]
impl EndOAuth2SessionPayload {
/// The status of the mutation.
async fn status(&self) -> EndOAuth2SessionStatus {
match self {
Self::Ended(_) => EndOAuth2SessionStatus::Ended,
Self::NotFound => EndOAuth2SessionStatus::NotFound,
}
}
/// Returns the ended session.
async fn oauth2_session(&self) -> Option<OAuth2Session> {
match self {
Self::Ended(session) => Some(OAuth2Session(session.clone())),
Self::NotFound => None,
}
}
}
#[Object]
impl OAuth2SessionMutations {
/// Create a new arbitrary OAuth 2.0 Session.
///
/// Only available for administrators.
async fn create_oauth2_session(
&self,
ctx: &Context<'_>,
input: CreateOAuth2SessionInput,
) -> Result<CreateOAuth2SessionPayload, async_graphql::Error> {
let state = ctx.state();
let user_id = NodeType::User.extract_ulid(&input.user_id)?;
let scope: Scope = input.scope.parse().context("Invalid scope")?;
let permanent = input.permanent.unwrap_or(false);
let requester = ctx.requester();
if !requester.is_admin() {
return Err(async_graphql::Error::new("Unauthorized"));
}
let session = requester
.oauth2_session()
.context("Requester should be a OAuth 2.0 client")?;
let mut repo = state.repository().await?;
let clock = state.clock();
let mut rng = state.rng();
let client = repo
.oauth2_client()
.lookup(session.client_id)
.await?
.context("Client not found")?;
let user = repo
.user()
.lookup(user_id)
.await?
.context("User not found")?;
// Generate a new access token
let access_token = TokenType::AccessToken.generate(&mut rng);
// Create the OAuth 2.0 Session
let session = repo
.oauth2_session()
.add(&mut rng, &clock, &client, Some(&user), None, scope)
.await?;
// Look for devices to provision
for scope in &*session.scope {
if let Some(device) = Device::from_scope_token(scope) {
repo.job()
.schedule_job(ProvisionDeviceJob::new(&user, &device))
.await?;
}
}
let ttl = if permanent {
None
} else {
Some(Duration::microseconds(5 * 60 * 1000 * 1000))
};
let access_token = repo
.oauth2_access_token()
.add(&mut rng, &clock, &session, access_token, ttl)
.await?;
let refresh_token = if permanent {
None
} else {
let refresh_token = TokenType::RefreshToken.generate(&mut rng);
let refresh_token = repo
.oauth2_refresh_token()
.add(&mut rng, &clock, &session, &access_token, refresh_token)
.await?;
Some(refresh_token)
};
repo.save().await?;
Ok(CreateOAuth2SessionPayload {
session,
access_token: access_token.access_token,
refresh_token: refresh_token.map(|t| t.refresh_token),
})
}
async fn end_oauth2_session(
&self,
ctx: &Context<'_>,
input: EndOAuth2SessionInput,
) -> Result<EndOAuth2SessionPayload, async_graphql::Error> {
let state = ctx.state();
let oauth2_session_id = NodeType::OAuth2Session.extract_ulid(&input.oauth2_session_id)?;
let requester = ctx.requester();
let mut repo = state.repository().await?;
let clock = state.clock();
let session = repo.oauth2_session().lookup(oauth2_session_id).await?;
let Some(session) = session else {
return Ok(EndOAuth2SessionPayload::NotFound);
};
if !requester.is_owner_or_admin(&session) {
return Ok(EndOAuth2SessionPayload::NotFound);
}
if let Some(user_id) = session.user_id {
let user = repo
.user()
.lookup(user_id)
.await?
.context("Could not load user")?;
// Scan the scopes of the session to find if there is any device that should be
// deleted from the Matrix server.
// TODO: this should be moved in a higher level "end oauth session" method.
// XXX: this might not be the right semantic, but it's the best we
// can do for now, since we're not explicitly storing devices for OAuth2
// sessions.
for scope in &*session.scope {
if let Some(device) = Device::from_scope_token(scope) {
// Schedule a job to delete the device.
repo.job()
.schedule_job(DeleteDeviceJob::new(&user, &device))
.await?;
}
}
}
let session = repo.oauth2_session().finish(&clock, session).await?;
repo.save().await?;
Ok(EndOAuth2SessionPayload::Ended(session))
}
}

View File

@@ -0,0 +1,388 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::Context as _;
use async_graphql::{Context, Description, Enum, InputObject, Object, ID};
use mas_storage::{
job::{DeactivateUserJob, JobRepositoryExt, ProvisionUserJob},
user::UserRepository,
};
use tracing::{info, warn};
use crate::graphql::{
model::{NodeType, User},
state::ContextExt,
UserId,
};
#[derive(Default)]
pub struct UserMutations {
_private: (),
}
/// The input for the `addUser` mutation.
#[derive(InputObject)]
struct AddUserInput {
/// The username of the user to add.
username: String,
/// Skip checking with the homeserver whether the username is valid.
///
/// Use this with caution! The main reason to use this, is when a user used
/// by an application service needs to exist in MAS to craft special
/// tokens (like with admin access) for them
skip_homeserver_check: Option<bool>,
}
/// The status of the `addUser` mutation.
#[derive(Enum, Copy, Clone, Eq, PartialEq)]
enum AddUserStatus {
/// The user was added.
Added,
/// The user already exists.
Exists,
/// The username is reserved.
Reserved,
/// The username is invalid.
Invalid,
}
/// The payload for the `addUser` mutation.
#[derive(Description)]
enum AddUserPayload {
Added(mas_data_model::User),
Exists(mas_data_model::User),
Reserved,
Invalid,
}
#[Object(use_type_description)]
impl AddUserPayload {
/// Status of the operation
async fn status(&self) -> AddUserStatus {
match self {
Self::Added(_) => AddUserStatus::Added,
Self::Exists(_) => AddUserStatus::Exists,
Self::Reserved => AddUserStatus::Reserved,
Self::Invalid => AddUserStatus::Invalid,
}
}
/// The user that was added.
async fn user(&self) -> Option<User> {
match self {
Self::Added(user) | Self::Exists(user) => Some(User(user.clone())),
Self::Invalid | Self::Reserved => None,
}
}
}
/// The input for the `lockUser` mutation.
#[derive(InputObject)]
struct LockUserInput {
/// The ID of the user to lock.
user_id: ID,
/// Permanently lock the user.
deactivate: Option<bool>,
}
/// The status of the `lockUser` mutation.
#[derive(Enum, Copy, Clone, Eq, PartialEq)]
enum LockUserStatus {
/// The user was locked.
Locked,
/// The user was not found.
NotFound,
}
/// The payload for the `lockUser` mutation.
#[derive(Description)]
enum LockUserPayload {
/// The user was locked.
Locked(mas_data_model::User),
/// The user was not found.
NotFound,
}
#[Object(use_type_description)]
impl LockUserPayload {
/// Status of the operation
async fn status(&self) -> LockUserStatus {
match self {
Self::Locked(_) => LockUserStatus::Locked,
Self::NotFound => LockUserStatus::NotFound,
}
}
/// The user that was locked.
async fn user(&self) -> Option<User> {
match self {
Self::Locked(user) => Some(User(user.clone())),
Self::NotFound => None,
}
}
}
/// The input for the `setCanRequestAdmin` mutation.
#[derive(InputObject)]
struct SetCanRequestAdminInput {
/// The ID of the user to update.
user_id: ID,
/// Whether the user can request admin.
can_request_admin: bool,
}
/// The payload for the `setCanRequestAdmin` mutation.
#[derive(Description)]
enum SetCanRequestAdminPayload {
/// The user was updated.
Updated(mas_data_model::User),
/// The user was not found.
NotFound,
}
#[Object(use_type_description)]
impl SetCanRequestAdminPayload {
/// The user that was updated.
async fn user(&self) -> Option<User> {
match self {
Self::Updated(user) => Some(User(user.clone())),
Self::NotFound => None,
}
}
}
/// The input for the `allowUserCrossSigningReset` mutation.
#[derive(InputObject)]
struct AllowUserCrossSigningResetInput {
/// The ID of the user to update.
user_id: ID,
}
/// The payload for the `allowUserCrossSigningReset` mutation.
#[derive(Description)]
enum AllowUserCrossSigningResetPayload {
/// The user was updated.
Allowed(mas_data_model::User),
/// The user was not found.
NotFound,
}
#[Object(use_type_description)]
impl AllowUserCrossSigningResetPayload {
/// The user that was updated.
async fn user(&self) -> Option<User> {
match self {
Self::Allowed(user) => Some(User(user.clone())),
Self::NotFound => None,
}
}
}
fn valid_username_character(c: char) -> bool {
c.is_ascii_lowercase()
|| c.is_ascii_digit()
|| c == '='
|| c == '_'
|| c == '-'
|| c == '.'
|| c == '/'
|| c == '+'
}
// XXX: this should probably be moved somewhere else
fn username_valid(username: &str) -> bool {
if username.is_empty() || username.len() > 255 {
return false;
}
// Should not start with an underscore
if username.get(0..1) == Some("_") {
return false;
}
// Should only contain valid characters
if !username.chars().all(valid_username_character) {
return false;
}
true
}
#[Object]
impl UserMutations {
/// Add a user. This is only available to administrators.
async fn add_user(
&self,
ctx: &Context<'_>,
input: AddUserInput,
) -> Result<AddUserPayload, async_graphql::Error> {
let state = ctx.state();
let requester = ctx.requester();
let clock = state.clock();
let mut rng = state.rng();
if !requester.is_admin() {
return Err(async_graphql::Error::new("Unauthorized"));
}
let mut repo = state.repository().await?;
if let Some(user) = repo.user().find_by_username(&input.username).await? {
return Ok(AddUserPayload::Exists(user));
}
// Do some basic check on the username
if !username_valid(&input.username) {
return Ok(AddUserPayload::Invalid);
}
// Ask the homeserver if the username is available
let homeserver_available = state
.homeserver_connection()
.is_localpart_available(&input.username)
.await?;
if !homeserver_available {
if !input.skip_homeserver_check.unwrap_or(false) {
return Ok(AddUserPayload::Reserved);
}
// If we skipped the check, we still want to shout about it
warn!("Skipped homeserver check for username {}", input.username);
}
let user = repo.user().add(&mut rng, &clock, input.username).await?;
repo.job()
.schedule_job(ProvisionUserJob::new(&user))
.await?;
repo.save().await?;
Ok(AddUserPayload::Added(user))
}
/// Lock a user. This is only available to administrators.
async fn lock_user(
&self,
ctx: &Context<'_>,
input: LockUserInput,
) -> Result<LockUserPayload, async_graphql::Error> {
let state = ctx.state();
let requester = ctx.requester();
if !requester.is_admin() {
return Err(async_graphql::Error::new("Unauthorized"));
}
let mut repo = state.repository().await?;
let user_id = NodeType::User.extract_ulid(&input.user_id)?;
let user = repo.user().lookup(user_id).await?;
let Some(user) = user else {
return Ok(LockUserPayload::NotFound);
};
let deactivate = input.deactivate.unwrap_or(false);
let user = repo.user().lock(&state.clock(), user).await?;
if deactivate {
info!("Scheduling deactivation of user {}", user.id);
repo.job()
.schedule_job(DeactivateUserJob::new(&user, deactivate))
.await?;
}
repo.save().await?;
Ok(LockUserPayload::Locked(user))
}
/// Set whether a user can request admin. This is only available to
/// administrators.
async fn set_can_request_admin(
&self,
ctx: &Context<'_>,
input: SetCanRequestAdminInput,
) -> Result<SetCanRequestAdminPayload, async_graphql::Error> {
let state = ctx.state();
let requester = ctx.requester();
if !requester.is_admin() {
return Err(async_graphql::Error::new("Unauthorized"));
}
let mut repo = state.repository().await?;
let user_id = NodeType::User.extract_ulid(&input.user_id)?;
let user = repo.user().lookup(user_id).await?;
let Some(user) = user else {
return Ok(SetCanRequestAdminPayload::NotFound);
};
let user = repo
.user()
.set_can_request_admin(user, input.can_request_admin)
.await?;
repo.save().await?;
Ok(SetCanRequestAdminPayload::Updated(user))
}
/// Temporarily allow user to reset their cross-signing keys.
async fn allow_user_cross_signing_reset(
&self,
ctx: &Context<'_>,
input: AllowUserCrossSigningResetInput,
) -> Result<AllowUserCrossSigningResetPayload, async_graphql::Error> {
let state = ctx.state();
let user_id = NodeType::User.extract_ulid(&input.user_id)?;
let requester = ctx.requester();
if !requester.is_owner_or_admin(&UserId(user_id)) {
return Err(async_graphql::Error::new("Unauthorized"));
}
let mut repo = state.repository().await?;
let user = repo.user().lookup(user_id).await?;
repo.cancel().await?;
let Some(user) = user else {
return Ok(AllowUserCrossSigningResetPayload::NotFound);
};
let conn = state.homeserver_connection();
let mxid = conn.mxid(&user.username);
conn.allow_cross_signing_reset(&mxid)
.await
.context("Failed to allow cross-signing reset")?;
Ok(AllowUserCrossSigningResetPayload::Allowed(user))
}
}

View File

@@ -0,0 +1,680 @@
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::Context as _;
use async_graphql::{Context, Description, Enum, InputObject, Object, ID};
use mas_storage::{
job::{JobRepositoryExt, ProvisionUserJob, VerifyEmailJob},
user::{UserEmailRepository, UserRepository},
RepositoryAccess,
};
use crate::graphql::{
model::{NodeType, User, UserEmail},
state::ContextExt,
UserId,
};
#[derive(Default)]
pub struct UserEmailMutations {
_private: (),
}
/// The input for the `addEmail` mutation
#[derive(InputObject)]
struct AddEmailInput {
/// The email address to add
email: String,
/// The ID of the user to add the email address to
user_id: ID,
/// Skip the email address verification. Only allowed for admins.
skip_verification: Option<bool>,
/// Skip the email address policy check. Only allowed for admins.
skip_policy_check: Option<bool>,
}
/// The status of the `addEmail` mutation
#[derive(Enum, Copy, Clone, Eq, PartialEq)]
pub enum AddEmailStatus {
/// The email address was added
Added,
/// The email address already exists
Exists,
/// The email address is invalid
Invalid,
/// The email address is not allowed by the policy
Denied,
}
/// The payload of the `addEmail` mutation
#[derive(Description)]
enum AddEmailPayload {
Added(mas_data_model::UserEmail),
Exists(mas_data_model::UserEmail),
Invalid,
Denied {
violations: Vec<mas_policy::Violation>,
},
}
#[Object(use_type_description)]
impl AddEmailPayload {
/// Status of the operation
async fn status(&self) -> AddEmailStatus {
match self {
AddEmailPayload::Added(_) => AddEmailStatus::Added,
AddEmailPayload::Exists(_) => AddEmailStatus::Exists,
AddEmailPayload::Invalid => AddEmailStatus::Invalid,
AddEmailPayload::Denied { .. } => AddEmailStatus::Denied,
}
}
/// The email address that was added
async fn email(&self) -> Option<UserEmail> {
match self {
AddEmailPayload::Added(email) | AddEmailPayload::Exists(email) => {
Some(UserEmail(email.clone()))
}
AddEmailPayload::Invalid | AddEmailPayload::Denied { .. } => None,
}
}
/// The user to whom the email address was added
async fn user(&self, ctx: &Context<'_>) -> Result<Option<User>, async_graphql::Error> {
let state = ctx.state();
let mut repo = state.repository().await?;
let user_id = match self {
AddEmailPayload::Added(email) | AddEmailPayload::Exists(email) => email.user_id,
AddEmailPayload::Invalid | AddEmailPayload::Denied { .. } => return Ok(None),
};
let user = repo
.user()
.lookup(user_id)
.await?
.context("User not found")?;
Ok(Some(User(user)))
}
/// The list of policy violations if the email address was denied
async fn violations(&self) -> Option<Vec<String>> {
let AddEmailPayload::Denied { violations } = self else {
return None;
};
let messages = violations.iter().map(|v| v.msg.clone()).collect();
Some(messages)
}
}
/// The input for the `sendVerificationEmail` mutation
#[derive(InputObject)]
struct SendVerificationEmailInput {
/// The ID of the email address to verify
user_email_id: ID,
}
/// The status of the `sendVerificationEmail` mutation
#[derive(Enum, Copy, Clone, Eq, PartialEq)]
enum SendVerificationEmailStatus {
/// The verification email was sent
Sent,
/// The email address is already verified
AlreadyVerified,
}
/// The payload of the `sendVerificationEmail` mutation
#[derive(Description)]
enum SendVerificationEmailPayload {
Sent(mas_data_model::UserEmail),
AlreadyVerified(mas_data_model::UserEmail),
}
#[Object(use_type_description)]
impl SendVerificationEmailPayload {
/// Status of the operation
async fn status(&self) -> SendVerificationEmailStatus {
match self {
SendVerificationEmailPayload::Sent(_) => SendVerificationEmailStatus::Sent,
SendVerificationEmailPayload::AlreadyVerified(_) => {
SendVerificationEmailStatus::AlreadyVerified
}
}
}
/// The email address to which the verification email was sent
async fn email(&self) -> UserEmail {
match self {
SendVerificationEmailPayload::Sent(email)
| SendVerificationEmailPayload::AlreadyVerified(email) => UserEmail(email.clone()),
}
}
/// The user to whom the email address belongs
async fn user(&self, ctx: &Context<'_>) -> Result<User, async_graphql::Error> {
let state = ctx.state();
let mut repo = state.repository().await?;
let user_id = match self {
SendVerificationEmailPayload::Sent(email)
| SendVerificationEmailPayload::AlreadyVerified(email) => email.user_id,
};
let user = repo
.user()
.lookup(user_id)
.await?
.context("User not found")?;
Ok(User(user))
}
}
/// The input for the `verifyEmail` mutation
#[derive(InputObject)]
struct VerifyEmailInput {
/// The ID of the email address to verify
user_email_id: ID,
/// The verification code
code: String,
}
/// The status of the `verifyEmail` mutation
#[derive(Enum, Copy, Clone, Eq, PartialEq)]
enum VerifyEmailStatus {
/// The email address was just verified
Verified,
/// The email address was already verified before
AlreadyVerified,
/// The verification code is invalid
InvalidCode,
}
/// The payload of the `verifyEmail` mutation
#[derive(Description)]
enum VerifyEmailPayload {
Verified(mas_data_model::UserEmail),
AlreadyVerified(mas_data_model::UserEmail),
InvalidCode,
}
#[Object(use_type_description)]
impl VerifyEmailPayload {
/// Status of the operation
async fn status(&self) -> VerifyEmailStatus {
match self {
VerifyEmailPayload::Verified(_) => VerifyEmailStatus::Verified,
VerifyEmailPayload::AlreadyVerified(_) => VerifyEmailStatus::AlreadyVerified,
VerifyEmailPayload::InvalidCode => VerifyEmailStatus::InvalidCode,
}
}
/// The email address that was verified
async fn email(&self) -> Option<UserEmail> {
match self {
VerifyEmailPayload::Verified(email) | VerifyEmailPayload::AlreadyVerified(email) => {
Some(UserEmail(email.clone()))
}
VerifyEmailPayload::InvalidCode => None,
}
}
/// The user to whom the email address belongs
async fn user(&self, ctx: &Context<'_>) -> Result<Option<User>, async_graphql::Error> {
let state = ctx.state();
let mut repo = state.repository().await?;
let user_id = match self {
VerifyEmailPayload::Verified(email) | VerifyEmailPayload::AlreadyVerified(email) => {
email.user_id
}
VerifyEmailPayload::InvalidCode => return Ok(None),
};
let user = repo
.user()
.lookup(user_id)
.await?
.context("User not found")?;
Ok(Some(User(user)))
}
}
/// The input for the `removeEmail` mutation
#[derive(InputObject)]
struct RemoveEmailInput {
/// The ID of the email address to remove
user_email_id: ID,
}
/// The status of the `removeEmail` mutation
#[derive(Enum, Copy, Clone, Eq, PartialEq)]
enum RemoveEmailStatus {
/// The email address was removed
Removed,
/// Can't remove the primary email address
Primary,
/// The email address was not found
NotFound,
}
/// The payload of the `removeEmail` mutation
#[derive(Description)]
enum RemoveEmailPayload {
Removed(mas_data_model::UserEmail),
Primary(mas_data_model::UserEmail),
NotFound,
}
#[Object(use_type_description)]
impl RemoveEmailPayload {
/// Status of the operation
async fn status(&self) -> RemoveEmailStatus {
match self {
RemoveEmailPayload::Removed(_) => RemoveEmailStatus::Removed,
RemoveEmailPayload::Primary(_) => RemoveEmailStatus::Primary,
RemoveEmailPayload::NotFound => RemoveEmailStatus::NotFound,
}
}
/// The email address that was removed
async fn email(&self) -> Option<UserEmail> {
match self {
RemoveEmailPayload::Removed(email) | RemoveEmailPayload::Primary(email) => {
Some(UserEmail(email.clone()))
}
RemoveEmailPayload::NotFound => None,
}
}
/// The user to whom the email address belonged
async fn user(&self, ctx: &Context<'_>) -> Result<Option<User>, async_graphql::Error> {
let state = ctx.state();
let mut repo = state.repository().await?;
let user_id = match self {
RemoveEmailPayload::Removed(email) | RemoveEmailPayload::Primary(email) => {
email.user_id
}
RemoveEmailPayload::NotFound => return Ok(None),
};
let user = repo
.user()
.lookup(user_id)
.await?
.context("User not found")?;
Ok(Some(User(user)))
}
}
/// The input for the `setPrimaryEmail` mutation
#[derive(InputObject)]
struct SetPrimaryEmailInput {
/// The ID of the email address to set as primary
user_email_id: ID,
}
/// The status of the `setPrimaryEmail` mutation
#[derive(Enum, Copy, Clone, Eq, PartialEq)]
enum SetPrimaryEmailStatus {
/// The email address was set as primary
Set,
/// The email address was not found
NotFound,
/// Can't make an unverified email address primary
Unverified,
}
/// The payload of the `setPrimaryEmail` mutation
#[derive(Description)]
enum SetPrimaryEmailPayload {
Set(mas_data_model::User),
NotFound,
Unverified,
}
#[Object(use_type_description)]
impl SetPrimaryEmailPayload {
async fn status(&self) -> SetPrimaryEmailStatus {
match self {
SetPrimaryEmailPayload::Set(_) => SetPrimaryEmailStatus::Set,
SetPrimaryEmailPayload::NotFound => SetPrimaryEmailStatus::NotFound,
SetPrimaryEmailPayload::Unverified => SetPrimaryEmailStatus::Unverified,
}
}
/// The user to whom the email address belongs
async fn user(&self) -> Option<User> {
match self {
SetPrimaryEmailPayload::Set(user) => Some(User(user.clone())),
SetPrimaryEmailPayload::NotFound | SetPrimaryEmailPayload::Unverified => None,
}
}
}
#[Object]
impl UserEmailMutations {
/// Add an email address to the specified user
async fn add_email(
&self,
ctx: &Context<'_>,
input: AddEmailInput,
) -> Result<AddEmailPayload, async_graphql::Error> {
let state = ctx.state();
let id = NodeType::User.extract_ulid(&input.user_id)?;
let requester = ctx.requester();
if !requester.is_owner_or_admin(&UserId(id)) {
return Err(async_graphql::Error::new("Unauthorized"));
}
// Allow non-admins to change their email address if the site config allows it
if !requester.is_admin() && !state.site_config().email_change_allowed {
return Err(async_graphql::Error::new("Unauthorized"));
}
// Only admins can skip validation
if (input.skip_verification.is_some() || input.skip_policy_check.is_some())
&& !requester.is_admin()
{
return Err(async_graphql::Error::new("Unauthorized"));
}
let skip_verification = input.skip_verification.unwrap_or(false);
let skip_policy_check = input.skip_policy_check.unwrap_or(false);
let mut repo = state.repository().await?;
let user = repo
.user()
.lookup(id)
.await?
.context("Failed to load user")?;
// XXX: this logic should be extracted somewhere else, since most of it is
// duplicated in mas_handlers
// Validate the email address
if input.email.parse::<lettre::Address>().is_err() {
return Ok(AddEmailPayload::Invalid);
}
if !skip_policy_check {
let mut policy = state.policy().await?;
let res = policy.evaluate_email(&input.email).await?;
if !res.valid() {
return Ok(AddEmailPayload::Denied {
violations: res.violations,
});
}
}
// Find an existing email address
let existing_user_email = repo.user_email().find(&user, &input.email).await?;
let (added, mut user_email) = if let Some(user_email) = existing_user_email {
(false, user_email)
} else {
let clock = state.clock();
let mut rng = state.rng();
let user_email = repo
.user_email()
.add(&mut rng, &clock, &user, input.email)
.await?;
(true, user_email)
};
// Schedule a job to verify the email address if needed
if user_email.confirmed_at.is_none() {
if skip_verification {
user_email = repo
.user_email()
.mark_as_verified(&state.clock(), user_email)
.await?;
} else {
// TODO: figure out the locale
repo.job()
.schedule_job(VerifyEmailJob::new(&user_email))
.await?;
}
}
repo.save().await?;
let payload = if added {
AddEmailPayload::Added(user_email)
} else {
AddEmailPayload::Exists(user_email)
};
Ok(payload)
}
/// Send a verification code for an email address
async fn send_verification_email(
&self,
ctx: &Context<'_>,
input: SendVerificationEmailInput,
) -> Result<SendVerificationEmailPayload, async_graphql::Error> {
let state = ctx.state();
let user_email_id = NodeType::UserEmail.extract_ulid(&input.user_email_id)?;
let requester = ctx.requester();
let mut repo = state.repository().await?;
let user_email = repo
.user_email()
.lookup(user_email_id)
.await?
.context("User email not found")?;
if !requester.is_owner_or_admin(&user_email) {
return Err(async_graphql::Error::new("User email not found"));
}
// Schedule a job to verify the email address if needed
let needs_verification = user_email.confirmed_at.is_none();
if needs_verification {
// TODO: figure out the locale
repo.job()
.schedule_job(VerifyEmailJob::new(&user_email))
.await?;
}
repo.save().await?;
let payload = if needs_verification {
SendVerificationEmailPayload::Sent(user_email)
} else {
SendVerificationEmailPayload::AlreadyVerified(user_email)
};
Ok(payload)
}
/// Submit a verification code for an email address
async fn verify_email(
&self,
ctx: &Context<'_>,
input: VerifyEmailInput,
) -> Result<VerifyEmailPayload, async_graphql::Error> {
let state = ctx.state();
let user_email_id = NodeType::UserEmail.extract_ulid(&input.user_email_id)?;
let requester = ctx.requester();
let clock = state.clock();
let mut repo = state.repository().await?;
let user_email = repo
.user_email()
.lookup(user_email_id)
.await?
.context("User email not found")?;
if !requester.is_owner_or_admin(&user_email) {
return Err(async_graphql::Error::new("User email not found"));
}
if user_email.confirmed_at.is_some() {
// Just return the email address if it's already verified
// XXX: should we return an error instead?
return Ok(VerifyEmailPayload::AlreadyVerified(user_email));
}
// XXX: this logic should be extracted somewhere else, since most of it is
// duplicated in mas_handlers
// Find the verification code
let verification = repo
.user_email()
.find_verification_code(&clock, &user_email, &input.code)
.await?
.filter(|v| v.is_valid());
let Some(verification) = verification else {
return Ok(VerifyEmailPayload::InvalidCode);
};
repo.user_email()
.consume_verification_code(&clock, verification)
.await?;
let user = repo
.user()
.lookup(user_email.user_id)
.await?
.context("Failed to load user")?;
// XXX: is this the right place to do this?
if user.primary_user_email_id.is_none() {
repo.user_email().set_as_primary(&user_email).await?;
}
let user_email = repo
.user_email()
.mark_as_verified(&clock, user_email)
.await?;
repo.job()
.schedule_job(ProvisionUserJob::new(&user))
.await?;
repo.save().await?;
Ok(VerifyEmailPayload::Verified(user_email))
}
/// Remove an email address
async fn remove_email(
&self,
ctx: &Context<'_>,
input: RemoveEmailInput,
) -> Result<RemoveEmailPayload, async_graphql::Error> {
let state = ctx.state();
let user_email_id = NodeType::UserEmail.extract_ulid(&input.user_email_id)?;
let requester = ctx.requester();
let mut repo = state.repository().await?;
let user_email = repo.user_email().lookup(user_email_id).await?;
let Some(user_email) = user_email else {
return Ok(RemoveEmailPayload::NotFound);
};
if !requester.is_owner_or_admin(&user_email) {
return Ok(RemoveEmailPayload::NotFound);
}
// Allow non-admins to remove their email address if the site config allows it
if !requester.is_admin() && !state.site_config().email_change_allowed {
return Err(async_graphql::Error::new("Unauthorized"));
}
let user = repo
.user()
.lookup(user_email.user_id)
.await?
.context("Failed to load user")?;
if user.primary_user_email_id == Some(user_email.id) {
// Prevent removing the primary email address
return Ok(RemoveEmailPayload::Primary(user_email));
}
repo.user_email().remove(user_email.clone()).await?;
// Schedule a job to update the user
repo.job()
.schedule_job(ProvisionUserJob::new(&user))
.await?;
repo.save().await?;
Ok(RemoveEmailPayload::Removed(user_email))
}
/// Set an email address as primary
async fn set_primary_email(
&self,
ctx: &Context<'_>,
input: SetPrimaryEmailInput,
) -> Result<SetPrimaryEmailPayload, async_graphql::Error> {
let state = ctx.state();
let user_email_id = NodeType::UserEmail.extract_ulid(&input.user_email_id)?;
let requester = ctx.requester();
let mut repo = state.repository().await?;
let user_email = repo.user_email().lookup(user_email_id).await?;
let Some(user_email) = user_email else {
return Ok(SetPrimaryEmailPayload::NotFound);
};
if !requester.is_owner_or_admin(&user_email) {
return Err(async_graphql::Error::new("Unauthorized"));
}
// Allow non-admins to change their primary email address if the site config
// allows it
if !requester.is_admin() && !state.site_config().email_change_allowed {
return Err(async_graphql::Error::new("Unauthorized"));
}
if user_email.confirmed_at.is_none() {
return Ok(SetPrimaryEmailPayload::Unverified);
}
repo.user_email().set_as_primary(&user_email).await?;
// The user primary email should already be up to date
let user = repo
.user()
.lookup(user_email.user_id)
.await?
.context("Failed to load user")?;
repo.save().await?;
Ok(SetPrimaryEmailPayload::Set(user))
}
}

View File

@@ -0,0 +1,294 @@
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_graphql::{Context, MergedObject, Object, ID};
use mas_storage::user::UserRepository;
use crate::graphql::{
model::{
Anonymous, BrowserSession, CompatSession, Node, NodeType, OAuth2Client, OAuth2Session,
SiteConfig, User, UserEmail,
},
state::ContextExt,
UserId,
};
mod session;
mod upstream_oauth;
mod viewer;
use self::{session::SessionQuery, upstream_oauth::UpstreamOAuthQuery, viewer::ViewerQuery};
/// The query root of the GraphQL interface.
#[derive(Default, MergedObject)]
pub struct Query(BaseQuery, UpstreamOAuthQuery, SessionQuery, ViewerQuery);
impl Query {
#[must_use]
pub fn new() -> Self {
Self::default()
}
}
#[derive(Default)]
struct BaseQuery;
// TODO: move the rest of the queries in separate modules
#[Object]
impl BaseQuery {
/// Get the current logged in browser session
#[graphql(deprecation = "Use `viewerSession` instead.")]
async fn current_browser_session(
&self,
ctx: &Context<'_>,
) -> Result<Option<BrowserSession>, async_graphql::Error> {
let requester = ctx.requester();
Ok(requester
.browser_session()
.cloned()
.map(BrowserSession::from))
}
/// Get the current logged in user
#[graphql(deprecation = "Use `viewer` instead.")]
async fn current_user(&self, ctx: &Context<'_>) -> Result<Option<User>, async_graphql::Error> {
let requester = ctx.requester();
Ok(requester.user().cloned().map(User::from))
}
/// Fetch an OAuth 2.0 client by its ID.
async fn oauth2_client(
&self,
ctx: &Context<'_>,
id: ID,
) -> Result<Option<OAuth2Client>, async_graphql::Error> {
let state = ctx.state();
let id = NodeType::OAuth2Client.extract_ulid(&id)?;
let mut repo = state.repository().await?;
let client = repo.oauth2_client().lookup(id).await?;
repo.cancel().await?;
Ok(client.map(OAuth2Client))
}
/// Fetch a user by its ID.
async fn user(&self, ctx: &Context<'_>, id: ID) -> Result<Option<User>, async_graphql::Error> {
let id = NodeType::User.extract_ulid(&id)?;
let requester = ctx.requester();
if !requester.is_owner_or_admin(&UserId(id)) {
return Ok(None);
}
// We could avoid the database lookup if the requester is the user we're looking
// for but that would make the code more complex and we're not very
// concerned about performance yet
let state = ctx.state();
let mut repo = state.repository().await?;
let user = repo.user().lookup(id).await?;
repo.cancel().await?;
Ok(user.map(User))
}
/// Fetch a user by its username.
async fn user_by_username(
&self,
ctx: &Context<'_>,
username: String,
) -> Result<Option<User>, async_graphql::Error> {
let requester = ctx.requester();
let state = ctx.state();
let mut repo = state.repository().await?;
let user = repo.user().find_by_username(&username).await?;
let Some(user) = user else {
// We don't want to leak the existence of a user
return Ok(None);
};
// Users can only see themselves, except for admins
if !requester.is_owner_or_admin(&user) {
return Ok(None);
}
Ok(Some(User(user)))
}
/// Fetch a browser session by its ID.
async fn browser_session(
&self,
ctx: &Context<'_>,
id: ID,
) -> Result<Option<BrowserSession>, async_graphql::Error> {
let state = ctx.state();
let id = NodeType::BrowserSession.extract_ulid(&id)?;
let requester = ctx.requester();
let mut repo = state.repository().await?;
let browser_session = repo.browser_session().lookup(id).await?;
repo.cancel().await?;
let Some(browser_session) = browser_session else {
return Ok(None);
};
if !requester.is_owner_or_admin(&browser_session) {
return Ok(None);
}
Ok(Some(BrowserSession(browser_session)))
}
/// Fetch a compatible session by its ID.
async fn compat_session(
&self,
ctx: &Context<'_>,
id: ID,
) -> Result<Option<CompatSession>, async_graphql::Error> {
let state = ctx.state();
let id = NodeType::CompatSession.extract_ulid(&id)?;
let requester = ctx.requester();
let mut repo = state.repository().await?;
let compat_session = repo.compat_session().lookup(id).await?;
repo.cancel().await?;
let Some(compat_session) = compat_session else {
return Ok(None);
};
if !requester.is_owner_or_admin(&compat_session) {
return Ok(None);
}
Ok(Some(CompatSession::new(compat_session)))
}
/// Fetch an OAuth 2.0 session by its ID.
async fn oauth2_session(
&self,
ctx: &Context<'_>,
id: ID,
) -> Result<Option<OAuth2Session>, async_graphql::Error> {
let state = ctx.state();
let id = NodeType::OAuth2Session.extract_ulid(&id)?;
let requester = ctx.requester();
let mut repo = state.repository().await?;
let oauth2_session = repo.oauth2_session().lookup(id).await?;
repo.cancel().await?;
let Some(oauth2_session) = oauth2_session else {
return Ok(None);
};
if !requester.is_owner_or_admin(&oauth2_session) {
return Ok(None);
}
Ok(Some(OAuth2Session(oauth2_session)))
}
/// Fetch a user email by its ID.
async fn user_email(
&self,
ctx: &Context<'_>,
id: ID,
) -> Result<Option<UserEmail>, async_graphql::Error> {
let state = ctx.state();
let id = NodeType::UserEmail.extract_ulid(&id)?;
let requester = ctx.requester();
let mut repo = state.repository().await?;
let user_email = repo.user_email().lookup(id).await?;
repo.cancel().await?;
let Some(user_email) = user_email else {
return Ok(None);
};
if !requester.is_owner_or_admin(&user_email) {
return Ok(None);
}
Ok(Some(UserEmail(user_email)))
}
/// Fetches an object given its ID.
async fn node(&self, ctx: &Context<'_>, id: ID) -> Result<Option<Node>, async_graphql::Error> {
// Special case for the anonymous user
if id.as_str() == "anonymous" {
return Ok(Some(Node::Anonymous(Box::new(Anonymous))));
}
if id.as_str() == crate::graphql::model::SITE_CONFIG_ID {
return Ok(Some(Node::SiteConfig(Box::new(SiteConfig::new(
ctx.state().site_config(),
)))));
}
let (node_type, _id) = NodeType::from_id(&id)?;
let ret = match node_type {
// TODO
NodeType::Authentication | NodeType::CompatSsoLogin => None,
NodeType::UpstreamOAuth2Provider => UpstreamOAuthQuery
.upstream_oauth2_provider(ctx, id)
.await?
.map(|c| Node::UpstreamOAuth2Provider(Box::new(c))),
NodeType::UpstreamOAuth2Link => UpstreamOAuthQuery
.upstream_oauth2_link(ctx, id)
.await?
.map(|c| Node::UpstreamOAuth2Link(Box::new(c))),
NodeType::OAuth2Client => self
.oauth2_client(ctx, id)
.await?
.map(|c| Node::OAuth2Client(Box::new(c))),
NodeType::UserEmail => self
.user_email(ctx, id)
.await?
.map(|e| Node::UserEmail(Box::new(e))),
NodeType::CompatSession => self
.compat_session(ctx, id)
.await?
.map(|s| Node::CompatSession(Box::new(s))),
NodeType::OAuth2Session => self
.oauth2_session(ctx, id)
.await?
.map(|s| Node::OAuth2Session(Box::new(s))),
NodeType::BrowserSession => self
.browser_session(ctx, id)
.await?
.map(|s| Node::BrowserSession(Box::new(s))),
NodeType::User => self.user(ctx, id).await?.map(|u| Node::User(Box::new(u))),
};
Ok(ret)
}
/// Get the current site configuration
async fn site_config(&self, ctx: &Context<'_>) -> SiteConfig {
SiteConfig::new(ctx.state().site_config())
}
}

View File

@@ -0,0 +1,118 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_graphql::{Context, Object, Union, ID};
use mas_data_model::Device;
use mas_storage::{
compat::{CompatSessionFilter, CompatSessionRepository},
oauth2::OAuth2SessionFilter,
Pagination, RepositoryAccess,
};
use oauth2_types::scope::Scope;
use crate::graphql::{
model::{CompatSession, NodeType, OAuth2Session},
state::ContextExt,
UserId,
};
#[derive(Default)]
pub struct SessionQuery;
/// A client session, either compat or OAuth 2.0
#[derive(Union)]
enum Session {
CompatSession(Box<CompatSession>),
OAuth2Session(Box<OAuth2Session>),
}
#[Object]
impl SessionQuery {
/// Lookup a compat or OAuth 2.0 session
async fn session(
&self,
ctx: &Context<'_>,
user_id: ID,
device_id: String,
) -> Result<Option<Session>, async_graphql::Error> {
let user_id = NodeType::User.extract_ulid(&user_id)?;
let requester = ctx.requester();
if !requester.is_owner_or_admin(&UserId(user_id)) {
return Ok(None);
}
let Ok(device) = Device::try_from(device_id) else {
return Ok(None);
};
let state = ctx.state();
let mut repo = state.repository().await?;
// Lookup the user
let Some(user) = repo.user().lookup(user_id).await? else {
return Ok(None);
};
// First, try to find a compat session
let filter = CompatSessionFilter::new()
.for_user(&user)
.active_only()
.for_device(&device);
// We only want most recent session
let pagination = Pagination::last(1);
let compat_sessions = repo.compat_session().list(filter, pagination).await?;
if compat_sessions.has_previous_page {
// XXX: should we bail out?
tracing::warn!(
"Found more than one active session with device {device} for user {user_id}"
);
}
if let Some((compat_session, sso_login)) = compat_sessions.edges.into_iter().next() {
repo.cancel().await?;
return Ok(Some(Session::CompatSession(Box::new(
CompatSession::new(compat_session).with_loaded_sso_login(sso_login),
))));
}
// Then, try to find an OAuth 2.0 session. Because we don't have any dedicated
// device column, we're looking up using the device scope.
let scope = Scope::from_iter([device.to_scope_token()]);
let filter = OAuth2SessionFilter::new()
.for_user(&user)
.active_only()
.with_scope(&scope);
let sessions = repo.oauth2_session().list(filter, pagination).await?;
// It's possible to have multiple active OAuth 2.0 sessions. For now, we just
// log it if it is the case
if sessions.has_previous_page {
// XXX: should we bail out?
tracing::warn!(
"Found more than one active session with device {device} for user {user_id}"
);
}
if let Some(session) = sessions.edges.into_iter().next() {
repo.cancel().await?;
return Ok(Some(Session::OAuth2Session(Box::new(OAuth2Session(
session,
)))));
}
repo.cancel().await?;
Ok(None)
}
}

View File

@@ -0,0 +1,153 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_graphql::{
connection::{query, Connection, Edge, OpaqueCursor},
Context, Object, ID,
};
use mas_storage::{upstream_oauth2::UpstreamOAuthProviderFilter, Pagination, RepositoryAccess};
use crate::graphql::{
model::{
Cursor, NodeCursor, NodeType, PreloadedTotalCount, UpstreamOAuth2Link,
UpstreamOAuth2Provider,
},
state::ContextExt,
};
#[derive(Default)]
pub struct UpstreamOAuthQuery;
#[Object]
impl UpstreamOAuthQuery {
/// Fetch an upstream OAuth 2.0 link by its ID.
pub async fn upstream_oauth2_link(
&self,
ctx: &Context<'_>,
id: ID,
) -> Result<Option<UpstreamOAuth2Link>, async_graphql::Error> {
let state = ctx.state();
let id = NodeType::UpstreamOAuth2Link.extract_ulid(&id)?;
let requester = ctx.requester();
let mut repo = state.repository().await?;
let link = repo.upstream_oauth_link().lookup(id).await?;
repo.cancel().await?;
let Some(link) = link else {
return Ok(None);
};
if !requester.is_owner_or_admin(&link) {
return Ok(None);
}
Ok(Some(UpstreamOAuth2Link::new(link)))
}
/// Fetch an upstream OAuth 2.0 provider by its ID.
pub async fn upstream_oauth2_provider(
&self,
ctx: &Context<'_>,
id: ID,
) -> Result<Option<UpstreamOAuth2Provider>, async_graphql::Error> {
let state = ctx.state();
let id = NodeType::UpstreamOAuth2Provider.extract_ulid(&id)?;
let mut repo = state.repository().await?;
let provider = repo.upstream_oauth_provider().lookup(id).await?;
repo.cancel().await?;
let Some(provider) = provider else {
return Ok(None);
};
// We only allow enabled providers to be fetched
if !provider.enabled() {
return Ok(None);
}
Ok(Some(UpstreamOAuth2Provider::new(provider)))
}
/// Get a list of upstream OAuth 2.0 providers.
async fn upstream_oauth2_providers(
&self,
ctx: &Context<'_>,
#[graphql(desc = "Returns the elements in the list that come after the cursor.")]
after: Option<String>,
#[graphql(desc = "Returns the elements in the list that come before the cursor.")]
before: Option<String>,
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, UpstreamOAuth2Provider, PreloadedTotalCount>, async_graphql::Error>
{
let state = ctx.state();
let mut repo = state.repository().await?;
query(
after,
before,
first,
last,
|after, before, first, last| async move {
let after_id = after
.map(|x: OpaqueCursor<NodeCursor>| {
x.extract_for_type(NodeType::UpstreamOAuth2Provider)
})
.transpose()?;
let before_id = before
.map(|x: OpaqueCursor<NodeCursor>| {
x.extract_for_type(NodeType::UpstreamOAuth2Provider)
})
.transpose()?;
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
// We only want enabled providers
// XXX: we may want to let admins see disabled providers
let filter = UpstreamOAuthProviderFilter::new().enabled_only();
let page = repo
.upstream_oauth_provider()
.list(filter, pagination)
.await?;
// Preload the total count if requested
let count = if ctx.look_ahead().field("totalCount").exists() {
Some(repo.upstream_oauth_provider().count(filter).await?)
} else {
None
};
repo.cancel().await?;
let mut connection = Connection::with_additional_fields(
page.has_previous_page,
page.has_next_page,
PreloadedTotalCount(count),
);
connection.edges.extend(page.edges.into_iter().map(|p| {
Edge::new(
OpaqueCursor(NodeCursor(NodeType::UpstreamOAuth2Provider, p.id)),
UpstreamOAuth2Provider::new(p),
)
}));
Ok::<_, async_graphql::Error>(connection)
},
)
.await
}
}

View File

@@ -0,0 +1,52 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_graphql::{Context, Object};
use crate::graphql::{
model::{Viewer, ViewerSession},
state::ContextExt,
Requester,
};
#[derive(Default)]
pub struct ViewerQuery;
#[Object]
impl ViewerQuery {
/// Get the viewer
async fn viewer(&self, ctx: &Context<'_>) -> Viewer {
let requester = ctx.requester();
match requester {
Requester::BrowserSession(session) => Viewer::user(session.user.clone()),
Requester::OAuth2Session(tuple) => match &tuple.1 {
Some(user) => Viewer::user(user.clone()),
None => Viewer::anonymous(),
},
Requester::Anonymous => Viewer::anonymous(),
}
}
/// Get the viewer's session
async fn viewer_session(&self, ctx: &Context<'_>) -> ViewerSession {
let requester = ctx.requester();
match requester {
Requester::BrowserSession(session) => ViewerSession::browser_session(*session.clone()),
Requester::OAuth2Session(tuple) => ViewerSession::oauth2_session(tuple.0.clone()),
Requester::Anonymous => ViewerSession::anonymous(),
}
}
}

View File

@@ -0,0 +1,48 @@
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use mas_data_model::SiteConfig;
use mas_matrix::HomeserverConnection;
use mas_policy::Policy;
use mas_storage::{BoxClock, BoxRepository, BoxRng, RepositoryError};
use crate::graphql::Requester;
#[async_trait::async_trait]
pub trait State {
async fn repository(&self) -> Result<BoxRepository, RepositoryError>;
async fn policy(&self) -> Result<Policy, mas_policy::InstantiateError>;
fn homeserver_connection(&self) -> &dyn HomeserverConnection<Error = anyhow::Error>;
fn clock(&self) -> BoxClock;
fn rng(&self) -> BoxRng;
fn site_config(&self) -> &SiteConfig;
}
pub type BoxState = Box<dyn State + Send + Sync + 'static>;
pub trait ContextExt {
fn state(&self) -> &BoxState;
fn requester(&self) -> &Requester;
}
impl ContextExt for async_graphql::Context<'_> {
fn state(&self) -> &BoxState {
self.data_unchecked()
}
fn requester(&self) -> &Requester {
self.data_unchecked()
}
}

View File

@@ -90,7 +90,9 @@ pub use mas_axum_utils::{
pub use self::{
activity_tracker::{ActivityTracker, Bound as BoundActivityTracker},
graphql::schema as graphql_schema,
graphql::{
schema as graphql_schema, schema_builder as graphql_schema_builder, Schema as GraphQLSchema,
},
preferred_language::PreferredLanguage,
upstream_oauth2::cache::MetadataCache,
};
@@ -110,7 +112,7 @@ where
<B as HttpBody>::Data: Into<Bytes>,
<B as HttpBody>::Error: std::error::Error + Send + Sync,
S: Clone + Send + Sync + 'static,
mas_graphql::Schema: FromRef<S>,
graphql::Schema: FromRef<S>,
BoundActivityTracker: FromRequestParts<S>,
BoxRepository: FromRequestParts<S>,
BoxClock: FromRequestParts<S>,

View File

@@ -54,6 +54,7 @@ use tower::{Layer, Service, ServiceExt};
use url::Url;
use crate::{
graphql,
passwords::{Hasher, PasswordManager},
upstream_oauth2::cache::MetadataCache,
ActivityTracker, BoundActivityTracker,
@@ -102,7 +103,7 @@ pub(crate) struct TestState {
pub url_builder: UrlBuilder,
pub homeserver_connection: Arc<MockHomeserverConnection>,
pub policy_factory: Arc<PolicyFactory>,
pub graphql_schema: mas_graphql::Schema,
pub graphql_schema: graphql::Schema,
pub http_client_factory: HttpClientFactory,
pub password_manager: PasswordManager,
pub site_config: SiteConfig,
@@ -198,9 +199,9 @@ impl TestState {
rng: Arc::clone(&rng),
clock: Arc::clone(&clock),
};
let state: mas_graphql::BoxState = Box::new(graphql_state);
let state: crate::graphql::BoxState = Box::new(graphql_state);
let graphql_schema = mas_graphql::schema_builder().data(state).finish();
let graphql_schema = graphql::schema_builder().data(state).finish();
let activity_tracker =
ActivityTracker::new(pool.clone(), std::time::Duration::from_secs(1));
@@ -316,7 +317,7 @@ struct TestGraphQLState {
}
#[async_trait]
impl mas_graphql::State for TestGraphQLState {
impl graphql::State for TestGraphQLState {
async fn repository(&self) -> Result<BoxRepository, mas_storage::RepositoryError> {
let repo = PgRepository::from_pool(&self.pool)
.await
@@ -356,7 +357,7 @@ impl FromRef<TestState> for PgPool {
}
}
impl FromRef<TestState> for mas_graphql::Schema {
impl FromRef<TestState> for graphql::Schema {
fn from_ref(input: &TestState) -> Self {
input.graphql_schema.clone()
}