1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-07 17:03:01 +03:00

Have a Requester in the GraphQL API, in preparation for accessing it with OAuth credentials

This commit is contained in:
Quentin Gliech
2023-04-21 14:32:32 +02:00
parent be765fe04f
commit c2d8243586
6 changed files with 86 additions and 65 deletions

View File

@@ -27,6 +27,7 @@
)] )]
use async_graphql::EmptySubscription; use async_graphql::EmptySubscription;
use mas_data_model::{BrowserSession, User};
mod model; mod model;
mod mutations; mod mutations;
@@ -49,3 +50,42 @@ pub fn schema_builder() -> SchemaBuilder {
.register_output_type::<Node>() .register_output_type::<Node>()
.register_output_type::<CreationEvent>() .register_output_type::<CreationEvent>()
} }
/// The identity of the requester.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum Requester {
/// The requester presented no authentication information.
#[default]
Anonymous,
/// The requester is a browser session, stored in a cookie.
BrowserSession(BrowserSession),
}
impl Requester {
fn browser_session(&self) -> Option<&BrowserSession> {
match self {
Self::BrowserSession(session) => Some(session),
Self::Anonymous => None,
}
}
fn user(&self) -> Option<&User> {
self.browser_session().map(|session| &session.user)
}
}
impl From<BrowserSession> for Requester {
fn from(session: BrowserSession) -> Self {
Self::BrowserSession(session)
}
}
impl<T> From<Option<T>> for Requester
where
T: Into<Requester>,
{
fn from(session: Option<T>) -> Self {
session.map(Into::into).unwrap_or_default()
}
}

View File

@@ -50,13 +50,11 @@ impl RootMutations {
) -> Result<UserEmail, async_graphql::Error> { ) -> Result<UserEmail, async_graphql::Error> {
let state = ctx.state(); let state = ctx.state();
let id = NodeType::User.extract_ulid(&user_id)?; let id = NodeType::User.extract_ulid(&user_id)?;
let session = ctx.session(); let requester = ctx.requester();
let Some(session) = session else { let user = requester.user().context("Unauthorized")?;
return Err(async_graphql::Error::new("Unauthorized"));
};
if session.user.id != id { if user.id != id {
return Err(async_graphql::Error::new("Unauthorized")); return Err(async_graphql::Error::new("Unauthorized"));
} }
@@ -65,16 +63,14 @@ impl RootMutations {
// XXX: this logic should be extracted somewhere else, since most of it is // XXX: this logic should be extracted somewhere else, since most of it is
// duplicated in mas_handlers // duplicated in mas_handlers
// Find an existing email address // Find an existing email address
let existing_user_email = repo.user_email().find(&session.user, &email).await?; let existing_user_email = repo.user_email().find(user, &email).await?;
let user_email = if let Some(user_email) = existing_user_email { let user_email = if let Some(user_email) = existing_user_email {
user_email user_email
} else { } else {
let clock = state.clock(); let clock = state.clock();
let mut rng = state.rng(); let mut rng = state.rng();
repo.user_email() repo.user_email().add(&mut rng, &clock, user, email).await?
.add(&mut rng, &clock, &session.user, email)
.await?
}; };
// Schedule a job to verify the email address if needed // Schedule a job to verify the email address if needed
@@ -98,11 +94,8 @@ impl RootMutations {
) -> Result<UserEmail, async_graphql::Error> { ) -> Result<UserEmail, async_graphql::Error> {
let state = ctx.state(); let state = ctx.state();
let user_email_id = NodeType::UserEmail.extract_ulid(&user_email_id)?; let user_email_id = NodeType::UserEmail.extract_ulid(&user_email_id)?;
let session = ctx.session(); let requester = ctx.requester();
let user = requester.user().context("Unauthorized")?;
let Some(session) = session else {
return Err(async_graphql::Error::new("Unauthorized"));
};
let mut repo = state.repository().await?; let mut repo = state.repository().await?;
@@ -112,7 +105,7 @@ impl RootMutations {
.await? .await?
.context("User email not found")?; .context("User email not found")?;
if user_email.user_id != session.user.id { if user_email.user_id != user.id {
return Err(async_graphql::Error::new("Unauthorized")); return Err(async_graphql::Error::new("Unauthorized"));
} }
@@ -138,11 +131,9 @@ impl RootMutations {
) -> Result<UserEmail, async_graphql::Error> { ) -> Result<UserEmail, async_graphql::Error> {
let state = ctx.state(); let state = ctx.state();
let user_email_id = NodeType::UserEmail.extract_ulid(&user_email_id)?; let user_email_id = NodeType::UserEmail.extract_ulid(&user_email_id)?;
let session = ctx.session(); let requester = ctx.requester();
let Some(session) = session else { let user = requester.user().context("Unauthorized")?;
return Err(async_graphql::Error::new("Unauthorized"));
};
let clock = state.clock(); let clock = state.clock();
let mut repo = state.repository().await?; let mut repo = state.repository().await?;
@@ -153,7 +144,7 @@ impl RootMutations {
.await? .await?
.context("User email not found")?; .context("User email not found")?;
if user_email.user_id != session.user.id { if user_email.user_id != user.id {
return Err(async_graphql::Error::new("Unauthorized")); return Err(async_graphql::Error::new("Unauthorized"));
} }
@@ -173,7 +164,7 @@ impl RootMutations {
.await?; .await?;
// XXX: is this the right place to do this? // XXX: is this the right place to do this?
if session.user.primary_user_email_id.is_none() { if user.primary_user_email_id.is_none() {
repo.user_email().set_as_primary(&user_email).await?; repo.user_email().set_as_primary(&user_email).await?;
} }
@@ -182,9 +173,7 @@ impl RootMutations {
.mark_as_verified(&clock, user_email) .mark_as_verified(&clock, user_email)
.await?; .await?;
repo.job() repo.job().schedule_job(ProvisionUserJob::new(user)).await?;
.schedule_job(ProvisionUserJob::new(&session.user))
.await?;
repo.save().await?; repo.save().await?;

View File

@@ -46,14 +46,17 @@ impl RootQuery {
&self, &self,
ctx: &Context<'_>, ctx: &Context<'_>,
) -> Result<Option<BrowserSession>, async_graphql::Error> { ) -> Result<Option<BrowserSession>, async_graphql::Error> {
let session = ctx.session().cloned(); let requester = ctx.requester();
Ok(session.map(BrowserSession::from)) Ok(requester
.browser_session()
.cloned()
.map(BrowserSession::from))
} }
/// Get the current logged in user /// Get the current logged in user
async fn current_user(&self, ctx: &Context<'_>) -> Result<Option<User>, async_graphql::Error> { async fn current_user(&self, ctx: &Context<'_>) -> Result<Option<User>, async_graphql::Error> {
let session = ctx.session().cloned(); let requester = ctx.requester();
Ok(session.map(User::from)) Ok(requester.user().cloned().map(User::from))
} }
/// Fetch an OAuth 2.0 client by its ID. /// Fetch an OAuth 2.0 client by its ID.
@@ -75,14 +78,12 @@ impl RootQuery {
/// Fetch a user by its ID. /// Fetch a user by its ID.
async fn user(&self, ctx: &Context<'_>, id: ID) -> Result<Option<User>, async_graphql::Error> { async fn user(&self, ctx: &Context<'_>, id: ID) -> Result<Option<User>, async_graphql::Error> {
let id = NodeType::User.extract_ulid(&id)?; let id = NodeType::User.extract_ulid(&id)?;
let requester = ctx.requester();
let session = ctx.session().cloned(); let Some(current_user) = requester.user() else { return Ok(None) };
let Some(session) = session else { return Ok(None) };
let current_user = session.user;
if current_user.id == id { if current_user.id == id {
Ok(Some(User(current_user))) Ok(Some(User(current_user.clone())))
} else { } else {
Ok(None) Ok(None)
} }
@@ -96,13 +97,11 @@ impl RootQuery {
) -> Result<Option<BrowserSession>, async_graphql::Error> { ) -> Result<Option<BrowserSession>, async_graphql::Error> {
let state = ctx.state(); let state = ctx.state();
let id = NodeType::BrowserSession.extract_ulid(&id)?; let id = NodeType::BrowserSession.extract_ulid(&id)?;
let requester = ctx.requester();
let session = ctx.session().cloned(); let Some(current_user) = requester.user() else { return Ok(None) };
let mut repo = state.repository().await?; let mut repo = state.repository().await?;
let Some(session) = session else { return Ok(None) };
let current_user = session.user;
let browser_session = repo.browser_session().lookup(id).await?; let browser_session = repo.browser_session().lookup(id).await?;
repo.cancel().await?; repo.cancel().await?;
@@ -126,13 +125,11 @@ impl RootQuery {
) -> Result<Option<UserEmail>, async_graphql::Error> { ) -> Result<Option<UserEmail>, async_graphql::Error> {
let state = ctx.state(); let state = ctx.state();
let id = NodeType::UserEmail.extract_ulid(&id)?; let id = NodeType::UserEmail.extract_ulid(&id)?;
let requester = ctx.requester();
let session = ctx.session().cloned(); let Some(current_user) = requester.user() else { return Ok(None) };
let mut repo = state.repository().await?; let mut repo = state.repository().await?;
let Some(session) = session else { return Ok(None) };
let current_user = session.user;
let user_email = repo let user_email = repo
.user_email() .user_email()
.lookup(id) .lookup(id)
@@ -152,13 +149,11 @@ impl RootQuery {
) -> Result<Option<UpstreamOAuth2Link>, async_graphql::Error> { ) -> Result<Option<UpstreamOAuth2Link>, async_graphql::Error> {
let state = ctx.state(); let state = ctx.state();
let id = NodeType::UpstreamOAuth2Link.extract_ulid(&id)?; let id = NodeType::UpstreamOAuth2Link.extract_ulid(&id)?;
let requester = ctx.requester();
let session = ctx.session().cloned(); let Some(current_user) = requester.user() else { return Ok(None) };
let mut repo = state.repository().await?; let mut repo = state.repository().await?;
let Some(session) = session else { return Ok(None) };
let current_user = session.user;
let link = repo.upstream_oauth_link().lookup(id).await?; let link = repo.upstream_oauth_link().lookup(id).await?;
// Ensure that the link belongs to the current user // Ensure that the link belongs to the current user

View File

@@ -12,9 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use mas_data_model::BrowserSession;
use mas_storage::{BoxClock, BoxRepository, BoxRng, RepositoryError}; use mas_storage::{BoxClock, BoxRepository, BoxRng, RepositoryError};
use crate::Requester;
#[async_trait::async_trait] #[async_trait::async_trait]
pub trait State { pub trait State {
async fn repository(&self) -> Result<BoxRepository, RepositoryError>; async fn repository(&self) -> Result<BoxRepository, RepositoryError>;
@@ -27,15 +28,15 @@ pub type BoxState = Box<dyn State + Send + Sync + 'static>;
pub trait ContextExt { pub trait ContextExt {
fn state(&self) -> &BoxState; fn state(&self) -> &BoxState;
fn session(&self) -> Option<&BrowserSession>; fn requester(&self) -> &Requester;
} }
impl ContextExt for async_graphql::Context<'_> { impl ContextExt for async_graphql::Context<'_> {
fn state(&self) -> &BoxState { fn state(&self) -> &BoxState {
self.data_unchecked::<BoxState>() self.data_unchecked()
} }
fn session(&self) -> Option<&BrowserSession> { fn requester(&self) -> &Requester {
self.data_opt() self.data_unchecked()
} }
} }

View File

@@ -27,7 +27,7 @@ use futures_util::TryStreamExt;
use headers::{ContentType, HeaderValue}; use headers::{ContentType, HeaderValue};
use hyper::header::CACHE_CONTROL; use hyper::header::CACHE_CONTROL;
use mas_axum_utils::{FancyError, SessionInfoExt}; use mas_axum_utils::{FancyError, SessionInfoExt};
use mas_graphql::Schema; use mas_graphql::{Requester, Schema};
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_storage::{BoxClock, BoxRepository, BoxRng, Repository, RepositoryError, SystemClock}; use mas_storage::{BoxClock, BoxRepository, BoxRng, Repository, RepositoryError, SystemClock};
use mas_storage_pg::PgRepository; use mas_storage_pg::PgRepository;
@@ -100,23 +100,21 @@ pub async fn post(
content_type: Option<TypedHeader<ContentType>>, content_type: Option<TypedHeader<ContentType>>,
body: BodyStream, body: BodyStream,
) -> Result<impl IntoResponse, FancyError> { ) -> Result<impl IntoResponse, FancyError> {
let content_type = content_type.map(|TypedHeader(h)| h.to_string());
let (session_info, _cookie_jar) = cookie_jar.session_info(); let (session_info, _cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?; let maybe_session = session_info.load_session(&mut repo).await?;
let requester = Requester::from(maybe_session);
repo.cancel().await?; repo.cancel().await?;
let mut request = async_graphql::http::receive_body( let content_type = content_type.map(|TypedHeader(h)| h.to_string());
let request = async_graphql::http::receive_body(
content_type, content_type,
body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)) body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
.into_async_read(), .into_async_read(),
MultipartOptions::default(), MultipartOptions::default(),
) )
.await?; // XXX: this should probably return another error response? .await?
.data(requester); // XXX: this should probably return another error response?
if let Some(session) = maybe_session {
request = request.data(session);
}
let span = span_for_graphql_request(&request); let span = span_for_graphql_request(&request);
let response = schema.execute(request).instrument(span).await; let response = schema.execute(request).instrument(span).await;
@@ -140,13 +138,11 @@ pub async fn get(
) -> Result<impl IntoResponse, FancyError> { ) -> Result<impl IntoResponse, FancyError> {
let (session_info, _cookie_jar) = cookie_jar.session_info(); let (session_info, _cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?; let maybe_session = session_info.load_session(&mut repo).await?;
let requester = Requester::from(maybe_session);
repo.cancel().await?; repo.cancel().await?;
let mut request = async_graphql::http::parse_query_string(&query.unwrap_or_default())?; let request =
async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(requester);
if let Some(session) = maybe_session {
request = request.data(session);
}
let span = span_for_graphql_request(&request); let span = span_for_graphql_request(&request);
let response = schema.execute(request).instrument(span).await; let response = schema.execute(request).instrument(span).await;

View File

@@ -108,7 +108,7 @@ impl TestState {
let policy_factory = Arc::new(policy_factory); let policy_factory = Arc::new(policy_factory);
let graphql_schema = graphql_schema(); let graphql_schema = graphql_schema(&pool);
let http_client_factory = HttpClientFactory::new(10); let http_client_factory = HttpClientFactory::new(10);