1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-09 04:22:45 +03:00

data-model: don't embed the client in the auth grant

This commit is contained in:
Quentin Gliech
2023-01-09 10:49:51 +01:00
parent fb7c6f4dd1
commit 39cd9a2578
8 changed files with 104 additions and 81 deletions

View File

@@ -23,12 +23,18 @@
clippy::type_repetition_in_bounds clippy::type_repetition_in_bounds
)] )]
use thiserror::Error;
pub(crate) mod compat; pub(crate) mod compat;
pub(crate) mod oauth2; pub(crate) mod oauth2;
pub(crate) mod tokens; pub(crate) mod tokens;
pub(crate) mod upstream_oauth2; pub(crate) mod upstream_oauth2;
pub(crate) mod users; pub(crate) mod users;
#[derive(Debug, Error)]
#[error("invalid state transition")]
pub struct InvalidTransitionError;
pub use self::{ pub use self::{
compat::{ compat::{
CompatAccessToken, CompatRefreshToken, CompatSession, CompatSsoLogin, CompatSsoLoginState, CompatAccessToken, CompatRefreshToken, CompatSession, CompatSsoLogin, CompatSsoLoginState,

View File

@@ -21,11 +21,11 @@ use oauth2_types::{
requests::ResponseMode, requests::ResponseMode,
}; };
use serde::Serialize; use serde::Serialize;
use thiserror::Error;
use ulid::Ulid; use ulid::Ulid;
use url::Url; use url::Url;
use super::{client::Client, session::Session}; use super::session::Session;
use crate::InvalidTransitionError;
#[derive(Debug, Clone, PartialEq, Eq, Serialize)] #[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct Pkce { pub struct Pkce {
@@ -53,10 +53,6 @@ pub struct AuthorizationCode {
pub pkce: Option<Pkce>, pub pkce: Option<Pkce>,
} }
#[derive(Debug, Error)]
#[error("invalid state transition")]
pub struct InvalidTransitionError;
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Default)] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Default)]
#[serde(tag = "stage", rename_all = "lowercase")] #[serde(tag = "stage", rename_all = "lowercase")]
pub enum AuthorizationGrantStage { pub enum AuthorizationGrantStage {
@@ -132,7 +128,7 @@ pub struct AuthorizationGrant {
#[serde(flatten)] #[serde(flatten)]
pub stage: AuthorizationGrantStage, pub stage: AuthorizationGrantStage,
pub code: Option<AuthorizationCode>, pub code: Option<AuthorizationCode>,
pub client: Client, pub client_id: Ulid,
pub redirect_uri: Url, pub redirect_uri: Url,
pub scope: oauth2_types::scope::Scope, pub scope: oauth2_types::scope::Scope,
pub state: Option<String>, pub state: Option<String>,

View File

@@ -17,6 +17,8 @@ use oauth2_types::scope::Scope;
use serde::Serialize; use serde::Serialize;
use ulid::Ulid; use ulid::Ulid;
use crate::InvalidTransitionError;
#[derive(Debug, Clone, PartialEq, Eq, Serialize)] #[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct Session { pub struct Session {
pub id: Ulid, pub id: Ulid,
@@ -25,3 +27,14 @@ pub struct Session {
pub scope: Scope, pub scope: Scope,
pub finished_at: Option<DateTime<Utc>>, pub finished_at: Option<DateTime<Utc>>,
} }
impl Session {
pub fn finish(mut self, finished_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
if self.finished_at.is_some() {
return Err(InvalidTransitionError);
}
self.finished_at = Some(finished_at);
Ok(self)
}
}

View File

@@ -29,7 +29,7 @@ use mas_storage::{
oauth2::{ oauth2::{
authorization_grant::{fulfill_grant, get_grant_by_id}, authorization_grant::{fulfill_grant, get_grant_by_id},
consent::fetch_client_consent, consent::fetch_client_consent,
OAuth2SessionRepository, OAuth2ClientRepository, OAuth2SessionRepository,
}, },
Repository, Repository,
}; };
@@ -125,6 +125,7 @@ pub(crate) async fn get(
} }
Err(GrantCompletionError::NotPending) => Err(RouteError::NotPending), Err(GrantCompletionError::NotPending) => Err(RouteError::NotPending),
Err(GrantCompletionError::Internal(e)) => Err(RouteError::Internal(e)), Err(GrantCompletionError::Internal(e)) => Err(RouteError::Internal(e)),
Err(e) => Err(RouteError::Internal(e.into())),
} }
} }
@@ -144,6 +145,9 @@ pub enum GrantCompletionError {
#[error("denied by the policy")] #[error("denied by the policy")]
PolicyViolation, PolicyViolation,
#[error("failed to load client")]
NoSuchClient,
} }
impl_from_error_for_route!(GrantCompletionError: sqlx::Error); impl_from_error_for_route!(GrantCompletionError: sqlx::Error);
@@ -182,8 +186,13 @@ pub(crate) async fn complete(
return Err(GrantCompletionError::PolicyViolation); return Err(GrantCompletionError::PolicyViolation);
} }
let current_consent = let client = txn
fetch_client_consent(&mut txn, &browser_session.user, &grant.client).await?; .oauth2_client()
.lookup(grant.client_id)
.await?
.ok_or(GrantCompletionError::NoSuchClient)?;
let current_consent = fetch_client_consent(&mut txn, &browser_session.user, &client).await?;
let lacks_consent = grant let lacks_consent = grant
.scope .scope

View File

@@ -360,7 +360,10 @@ pub(crate) async fn get(
Err(GrantCompletionError::Internal(e)) => { Err(GrantCompletionError::Internal(e)) => {
return Err(RouteError::Internal(e)) return Err(RouteError::Internal(e))
} }
Err(e @ GrantCompletionError::NotPending) => { Err(
e @ (GrantCompletionError::NotPending
| GrantCompletionError::NoSuchClient),
) => {
// This should never happen // This should never happen
return Err(RouteError::Internal(Box::new(e))); return Err(RouteError::Internal(Box::new(e)));
} }
@@ -390,7 +393,10 @@ pub(crate) async fn get(
Err(GrantCompletionError::Internal(e)) => { Err(GrantCompletionError::Internal(e)) => {
return Err(RouteError::Internal(e)) return Err(RouteError::Internal(e))
} }
Err(e @ GrantCompletionError::NotPending) => { Err(
e @ (GrantCompletionError::NotPending
| GrantCompletionError::NoSuchClient),
) => {
// This should never happen // This should never happen
return Err(RouteError::Internal(Box::new(e))); return Err(RouteError::Internal(Box::new(e)));
} }

View File

@@ -28,9 +28,13 @@ use mas_data_model::AuthorizationGrantStage;
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_policy::PolicyFactory; use mas_policy::PolicyFactory;
use mas_router::{PostAuthAction, Route}; use mas_router::{PostAuthAction, Route};
use mas_storage::oauth2::{ use mas_storage::{
oauth2::{
authorization_grant::{get_grant_by_id, give_consent_to_grant}, authorization_grant::{get_grant_by_id, give_consent_to_grant},
consent::insert_client_consent, consent::insert_client_consent,
OAuth2ClientRepository,
},
Repository,
}; };
use mas_templates::{ConsentContext, PolicyViolationContext, TemplateContext, Templates}; use mas_templates::{ConsentContext, PolicyViolationContext, TemplateContext, Templates};
use sqlx::PgPool; use sqlx::PgPool;
@@ -55,6 +59,9 @@ pub enum RouteError {
#[error("Policy violation")] #[error("Policy violation")]
PolicyViolation, PolicyViolation,
#[error("Failed to load client")]
NoSuchClient,
} }
impl_from_error_for_route!(sqlx::Error); impl_from_error_for_route!(sqlx::Error);
@@ -160,6 +167,12 @@ pub(crate) async fn post(
return Err(RouteError::PolicyViolation); return Err(RouteError::PolicyViolation);
} }
let client = txn
.oauth2_client()
.lookup(grant.client_id)
.await?
.ok_or(RouteError::NoSuchClient)?;
// Do not consent for the "urn:matrix:org.matrix.msc2967.client:device:*" scope // Do not consent for the "urn:matrix:org.matrix.msc2967.client:device:*" scope
let scope_without_device = grant let scope_without_device = grant
.scope .scope
@@ -172,7 +185,7 @@ pub(crate) async fn post(
&mut rng, &mut rng,
&clock, &clock,
&session.user, &session.user,
&grant.client, &client,
&scope_without_device, &scope_without_device,
) )
.await?; .await?;

View File

@@ -26,8 +26,7 @@ use ulid::Ulid;
use url::Url; use url::Url;
use uuid::Uuid; use uuid::Uuid;
use super::OAuth2ClientRepository; use crate::{Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt};
use crate::{Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Repository};
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
@@ -116,7 +115,7 @@ pub async fn new_authorization_grant(
stage: AuthorizationGrantStage::Pending, stage: AuthorizationGrantStage::Pending,
code, code,
redirect_uri, redirect_uri,
client, client_id: client.id,
scope, scope,
state, state,
nonce, nonce,
@@ -151,35 +150,27 @@ struct GrantLookup {
oauth2_session_id: Option<Uuid>, oauth2_session_id: Option<Uuid>,
} }
impl GrantLookup { impl TryFrom<GrantLookup> for AuthorizationGrant {
type Error = DatabaseInconsistencyError;
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
async fn into_authorization_grant( fn try_from(value: GrantLookup) -> Result<Self, Self::Error> {
self, let id = value.oauth2_authorization_grant_id.into();
conn: &mut PgConnection, let scope: Scope = value
) -> Result<AuthorizationGrant, DatabaseError> { .oauth2_authorization_grant_scope
let id = self.oauth2_authorization_grant_id.into(); .parse()
let scope: Scope = self.oauth2_authorization_grant_scope.parse().map_err(|e| { .map_err(|e| {
DatabaseInconsistencyError::on("oauth2_authorization_grants") DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("scope") .column("scope")
.row(id) .row(id)
.source(e) .source(e)
})?; })?;
let client = conn
.oauth2_client()
.lookup(self.oauth2_client_id.into())
.await?
.ok_or_else(|| {
DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("client_id")
.row(id)
})?;
let stage = match ( let stage = match (
self.oauth2_authorization_grant_fulfilled_at, value.oauth2_authorization_grant_fulfilled_at,
self.oauth2_authorization_grant_exchanged_at, value.oauth2_authorization_grant_exchanged_at,
self.oauth2_authorization_grant_cancelled_at, value.oauth2_authorization_grant_cancelled_at,
self.oauth2_session_id, value.oauth2_session_id,
) { ) {
(None, None, None, None) => AuthorizationGrantStage::Pending, (None, None, None, None) => AuthorizationGrantStage::Pending,
(Some(fulfilled_at), None, None, Some(session_id)) => { (Some(fulfilled_at), None, None, Some(session_id)) => {
@@ -202,15 +193,14 @@ impl GrantLookup {
return Err( return Err(
DatabaseInconsistencyError::on("oauth2_authorization_grants") DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("stage") .column("stage")
.row(id) .row(id),
.into(),
); );
} }
}; };
let pkce = match ( let pkce = match (
self.oauth2_authorization_grant_code_challenge, value.oauth2_authorization_grant_code_challenge,
self.oauth2_authorization_grant_code_challenge_method, value.oauth2_authorization_grant_code_challenge_method,
) { ) {
(Some(challenge), Some(challenge_method)) if challenge_method == "plain" => { (Some(challenge), Some(challenge_method)) if challenge_method == "plain" => {
Some(Pkce { Some(Pkce {
@@ -227,15 +217,14 @@ impl GrantLookup {
return Err( return Err(
DatabaseInconsistencyError::on("oauth2_authorization_grants") DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("code_challenge_method") .column("code_challenge_method")
.row(id) .row(id),
.into(),
); );
} }
}; };
let code: Option<AuthorizationCode> = match ( let code: Option<AuthorizationCode> = match (
self.oauth2_authorization_grant_response_type_code, value.oauth2_authorization_grant_response_type_code,
self.oauth2_authorization_grant_code, value.oauth2_authorization_grant_code,
pkce, pkce,
) { ) {
(false, None, None) => None, (false, None, None) => None,
@@ -244,13 +233,12 @@ impl GrantLookup {
return Err( return Err(
DatabaseInconsistencyError::on("oauth2_authorization_grants") DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("authorization_code") .column("authorization_code")
.row(id) .row(id),
.into(),
); );
} }
}; };
let redirect_uri = self let redirect_uri = value
.oauth2_authorization_grant_redirect_uri .oauth2_authorization_grant_redirect_uri
.parse() .parse()
.map_err(|e| { .map_err(|e| {
@@ -260,7 +248,7 @@ impl GrantLookup {
.source(e) .source(e)
})?; })?;
let response_mode = self let response_mode = value
.oauth2_authorization_grant_response_mode .oauth2_authorization_grant_response_mode
.parse() .parse()
.map_err(|e| { .map_err(|e| {
@@ -270,7 +258,7 @@ impl GrantLookup {
.source(e) .source(e)
})?; })?;
let max_age = self let max_age = value
.oauth2_authorization_grant_max_age .oauth2_authorization_grant_max_age
.map(u32::try_from) .map(u32::try_from)
.transpose() .transpose()
@@ -292,17 +280,17 @@ impl GrantLookup {
Ok(AuthorizationGrant { Ok(AuthorizationGrant {
id, id,
stage, stage,
client, client_id: value.oauth2_client_id.into(),
code, code,
scope, scope,
state: self.oauth2_authorization_grant_state, state: value.oauth2_authorization_grant_state,
nonce: self.oauth2_authorization_grant_nonce, nonce: value.oauth2_authorization_grant_nonce,
max_age, max_age,
response_mode, response_mode,
redirect_uri, redirect_uri,
created_at: self.oauth2_authorization_grant_created_at, created_at: value.oauth2_authorization_grant_created_at,
response_type_id_token: self.oauth2_authorization_grant_response_type_id_token, response_type_id_token: value.oauth2_authorization_grant_response_type_id_token,
requires_consent: self.oauth2_authorization_grant_requires_consent, requires_consent: value.oauth2_authorization_grant_requires_consent,
}) })
} }
} }
@@ -351,9 +339,7 @@ pub async fn get_grant_by_id(
let Some(res) = res else { return Ok(None) }; let Some(res) = res else { return Ok(None) };
let grant = res.into_authorization_grant(&mut *conn).await?; Ok(Some(res.try_into()?))
Ok(Some(grant))
} }
#[tracing::instrument(skip_all, err)] #[tracing::instrument(skip_all, err)]
@@ -396,16 +382,14 @@ pub async fn lookup_grant_by_code(
let Some(res) = res else { return Ok(None) }; let Some(res) = res else { return Ok(None) };
let grant = res.into_authorization_grant(&mut *conn).await?; Ok(Some(res.try_into()?))
Ok(Some(grant))
} }
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
fields( fields(
%grant.id, %grant.id,
client.id = %grant.client.id, client.id = %grant.client_id,
%session.id, %session.id,
user_session.id = %session.user_session_id, user_session.id = %session.user_session_id,
), ),
@@ -446,7 +430,7 @@ pub async fn fulfill_grant(
skip_all, skip_all,
fields( fields(
%grant.id, %grant.id,
client.id = %grant.client.id, client.id = %grant.client_id,
), ),
err, err,
)] )]
@@ -476,7 +460,7 @@ pub async fn give_consent_to_grant(
skip_all, skip_all,
fields( fields(
%grant.id, %grant.id,
client.id = %grant.client.id, client.id = %grant.client_id,
), ),
err, err,
)] )]

View File

@@ -142,7 +142,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
%user_session.id, %user_session.id,
user.id = %user_session.user.id, user.id = %user_session.user.id,
%grant.id, %grant.id,
client.id = %grant.client.id, client.id = %grant.client_id,
session.id, session.id,
session.scope = %grant.scope, session.scope = %grant.scope,
), ),
@@ -172,7 +172,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
"#, "#,
Uuid::from(id), Uuid::from(id),
Uuid::from(user_session.id), Uuid::from(user_session.id),
Uuid::from(grant.client.id), Uuid::from(grant.client_id),
grant.scope.to_string(), grant.scope.to_string(),
created_at, created_at,
) )
@@ -183,7 +183,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
Ok(Session { Ok(Session {
id, id,
user_session_id: user_session.id, user_session_id: user_session.id,
client_id: grant.client.id, client_id: grant.client_id,
scope: grant.scope.clone(), scope: grant.scope.clone(),
finished_at: None, finished_at: None,
}) })
@@ -201,11 +201,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
), ),
err, err,
)] )]
async fn finish( async fn finish(&mut self, clock: &Clock, session: Session) -> Result<Session, Self::Error> {
&mut self,
clock: &Clock,
mut session: Session,
) -> Result<Session, Self::Error> {
let finished_at = clock.now(); let finished_at = clock.now();
let res = sqlx::query!( let res = sqlx::query!(
r#" r#"
@@ -222,9 +218,9 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
DatabaseError::ensure_affected_rows(&res, 1)?; DatabaseError::ensure_affected_rows(&res, 1)?;
session.finished_at = Some(finished_at); session
.finish(finished_at)
Ok(session) .map_err(DatabaseError::to_invalid_operation)
} }
#[tracing::instrument( #[tracing::instrument(