1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

data-model: simplify the oauth2 clients

This commit is contained in:
Quentin Gliech
2022-12-07 14:46:08 +01:00
parent 6d82199910
commit 92d6f5b087
12 changed files with 46 additions and 80 deletions

View File

@ -26,15 +26,12 @@ use axum::{
}; };
use headers::{authorization::Basic, Authorization}; use headers::{authorization::Basic, Authorization};
use http::{Request, StatusCode}; use http::{Request, StatusCode};
use mas_data_model::{Client, JwksOrJwksUri, StorageBackend}; use mas_data_model::{Client, JwksOrJwksUri};
use mas_http::HttpServiceExt; use mas_http::HttpServiceExt;
use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_jose::{jwk::PublicJsonWebKeySet, jwt::Jwt}; use mas_jose::{jwk::PublicJsonWebKeySet, jwt::Jwt};
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_storage::{ use mas_storage::oauth2::client::{lookup_client_by_client_id, ClientFetchError};
oauth2::client::{lookup_client_by_client_id, ClientFetchError},
PostgresqlBackend,
};
use serde::{de::DeserializeOwned, Deserialize}; use serde::{de::DeserializeOwned, Deserialize};
use serde_json::Value; use serde_json::Value;
use sqlx::PgExecutor; use sqlx::PgExecutor;
@ -76,10 +73,7 @@ pub enum Credentials {
} }
impl Credentials { impl Credentials {
pub async fn fetch( pub async fn fetch(&self, executor: impl PgExecutor<'_>) -> Result<Client, ClientFetchError> {
&self,
executor: impl PgExecutor<'_>,
) -> Result<Client<PostgresqlBackend>, ClientFetchError> {
let client_id = match self { let client_id = match self {
Credentials::None { client_id } Credentials::None { client_id }
| Credentials::ClientSecretBasic { client_id, .. } | Credentials::ClientSecretBasic { client_id, .. }
@ -91,12 +85,12 @@ impl Credentials {
} }
#[tracing::instrument(skip_all, err)] #[tracing::instrument(skip_all, err)]
pub async fn verify<S: StorageBackend>( pub async fn verify(
&self, &self,
http_client_factory: &HttpClientFactory, http_client_factory: &HttpClientFactory,
encrypter: &Encrypter, encrypter: &Encrypter,
method: &OAuthClientAuthenticationMethod, method: &OAuthClientAuthenticationMethod,
client: &Client<S>, client: &Client,
) -> Result<(), CredentialsVerificationError> { ) -> Result<(), CredentialsVerificationError> {
match (self, method) { match (self, method) {
(Credentials::None { .. }, OAuthClientAuthenticationMethod::None) => {} (Credentials::None { .. }, OAuthClientAuthenticationMethod::None) => {}

View File

@ -165,7 +165,7 @@ pub struct AuthorizationGrant<T: StorageBackend> {
#[serde(flatten)] #[serde(flatten)]
pub stage: AuthorizationGrantStage<T>, pub stage: AuthorizationGrantStage<T>,
pub code: Option<AuthorizationCode>, pub code: Option<AuthorizationCode>,
pub client: Client<T>, pub client: Client,
pub redirect_uri: Url, pub redirect_uri: Url,
pub scope: oauth2_types::scope::Scope, pub scope: oauth2_types::scope::Scope,
pub state: Option<String>, pub state: Option<String>,
@ -183,7 +183,7 @@ impl<S: StorageBackendMarker> From<AuthorizationGrant<S>> for AuthorizationGrant
data: (), data: (),
stage: g.stage.into(), stage: g.stage.into(),
code: g.code, code: g.code,
client: g.client.into(), client: g.client,
redirect_uri: g.redirect_uri, redirect_uri: g.redirect_uri,
scope: g.scope, scope: g.scope,
state: g.state, state: g.state,

View File

@ -20,10 +20,9 @@ use mas_jose::jwk::PublicJsonWebKeySet;
use oauth2_types::requests::GrantType; use oauth2_types::requests::GrantType;
use serde::Serialize; use serde::Serialize;
use thiserror::Error; use thiserror::Error;
use ulid::Ulid;
use url::Url; use url::Url;
use crate::traits::{StorageBackend, StorageBackendMarker};
#[derive(Debug, Clone, PartialEq, Eq, Serialize)] #[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum JwksOrJwksUri { pub enum JwksOrJwksUri {
@ -34,11 +33,9 @@ pub enum JwksOrJwksUri {
JwksUri(Url), JwksUri(Url),
} }
#[derive(Debug, Clone, PartialEq, Serialize)] #[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(bound = "T: StorageBackend")] pub struct Client {
pub struct Client<T: StorageBackend> { pub id: Ulid,
#[serde(skip_serializing)]
pub data: T::ClientData,
/// Client identifier /// Client identifier
pub client_id: String, pub client_id: String,
@ -98,31 +95,6 @@ pub struct Client<T: StorageBackend> {
pub initiate_login_uri: Option<Url>, pub initiate_login_uri: Option<Url>,
} }
impl<S: StorageBackendMarker> From<Client<S>> for Client<()> {
fn from(c: Client<S>) -> Self {
Client {
data: (),
client_id: c.client_id,
encrypted_client_secret: c.encrypted_client_secret,
redirect_uris: c.redirect_uris,
response_types: c.response_types,
grant_types: c.grant_types,
contacts: c.contacts,
client_name: c.client_name,
logo_uri: c.logo_uri,
client_uri: c.client_uri,
policy_uri: c.policy_uri,
tos_uri: c.tos_uri,
jwks: c.jwks,
id_token_signed_response_alg: c.id_token_signed_response_alg,
userinfo_signed_response_alg: c.userinfo_signed_response_alg,
token_endpoint_auth_method: c.token_endpoint_auth_method,
token_endpoint_auth_signing_alg: c.token_endpoint_auth_signing_alg,
initiate_login_uri: c.initiate_login_uri,
}
}
}
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum InvalidRedirectUriError { pub enum InvalidRedirectUriError {
#[error("redirect_uri is not allowed for this client")] #[error("redirect_uri is not allowed for this client")]
@ -135,7 +107,7 @@ pub enum InvalidRedirectUriError {
NoneRegistered, NoneRegistered,
} }
impl<S: StorageBackend> Client<S> { impl Client {
pub fn resolve_redirect_uri<'a>( pub fn resolve_redirect_uri<'a>(
&'a self, &'a self,
redirect_uri: &'a Option<Url>, redirect_uri: &'a Option<Url>,

View File

@ -27,7 +27,7 @@ pub struct Session<T: StorageBackend> {
#[serde(skip_serializing)] #[serde(skip_serializing)]
pub data: T::SessionData, pub data: T::SessionData,
pub browser_session: BrowserSession, pub browser_session: BrowserSession,
pub client: Client<T>, pub client: Client,
pub scope: Scope, pub scope: Scope,
} }
@ -36,7 +36,7 @@ impl<S: StorageBackendMarker> From<Session<S>> for Session<()> {
Session { Session {
data: (), data: (),
browser_session: s.browser_session, browser_session: s.browser_session,
client: s.client.into(), client: s.client,
scope: s.scope, scope: s.scope,
} }
} }

View File

@ -56,13 +56,13 @@ impl OAuth2Session {
/// An OAuth 2.0 client /// An OAuth 2.0 client
#[derive(Description)] #[derive(Description)]
pub struct OAuth2Client(pub mas_data_model::Client<PostgresqlBackend>); pub struct OAuth2Client(pub mas_data_model::Client);
#[Object(use_type_description)] #[Object(use_type_description)]
impl OAuth2Client { impl OAuth2Client {
/// ID of the object. /// ID of the object.
pub async fn id(&self) -> ID { pub async fn id(&self) -> ID {
NodeType::OAuth2Client.id(self.0.data) NodeType::OAuth2Client.id(self.0.id)
} }
/// OAuth 2.0 client ID /// OAuth 2.0 client ID

View File

@ -43,7 +43,7 @@ use mas_storage::{
RefreshTokenLookupError, RefreshTokenLookupError,
}, },
}, },
DatabaseInconsistencyError, LookupError, PostgresqlBackend, DatabaseInconsistencyError, LookupError,
}; };
use oauth2_types::{ use oauth2_types::{
errors::{ClientError, ClientErrorCode}, errors::{ClientError, ClientErrorCode},
@ -239,7 +239,7 @@ pub(crate) async fn post(
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
async fn authorization_code_grant( async fn authorization_code_grant(
grant: &AuthorizationCodeGrant, grant: &AuthorizationCodeGrant,
client: &Client<PostgresqlBackend>, client: &Client,
key_store: &Keystore, key_store: &Keystore,
url_builder: &UrlBuilder, url_builder: &UrlBuilder,
mut txn: Transaction<'_, Postgres>, mut txn: Transaction<'_, Postgres>,
@ -391,7 +391,7 @@ async fn authorization_code_grant(
async fn refresh_token_grant( async fn refresh_token_grant(
grant: &RefreshTokenGrant, grant: &RefreshTokenGrant,
client: &Client<PostgresqlBackend>, client: &Client,
mut txn: Transaction<'_, Postgres>, mut txn: Transaction<'_, Postgres>,
) -> Result<AccessTokenResponse, RouteError> { ) -> Result<AccessTokenResponse, RouteError> {
let (clock, mut rng) = crate::rng_and_clock()?; let (clock, mut rng) = crate::rng_and_clock()?;

View File

@ -28,7 +28,7 @@ use crate::{Clock, DatabaseInconsistencyError, LookupError, PostgresqlBackend};
skip_all, skip_all,
fields( fields(
session.id = %session.data, session.id = %session.data,
client.id = %session.client.data, client.id = %session.client.id,
user.id = %session.browser_session.user.id, user.id = %session.browser_session.user.id,
access_token.id, access_token.id,
), ),

View File

@ -36,7 +36,7 @@ use crate::{Clock, DatabaseInconsistencyError, PostgresqlBackend};
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
fields( fields(
client.id = %client.data, %client.id,
grant.id, grant.id,
), ),
err(Debug), err(Debug),
@ -46,7 +46,7 @@ pub async fn new_authorization_grant(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send, mut rng: impl Rng + Send,
clock: &Clock, clock: &Clock,
client: Client<PostgresqlBackend>, client: Client,
redirect_uri: Url, redirect_uri: Url,
scope: Scope, scope: Scope,
code: Option<AuthorizationCode>, code: Option<AuthorizationCode>,
@ -97,7 +97,7 @@ pub async fn new_authorization_grant(
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15) ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
"#, "#,
Uuid::from(id), Uuid::from(id),
Uuid::from(client.data), Uuid::from(client.id),
redirect_uri.to_string(), redirect_uri.to_string(),
scope.to_string(), scope.to_string(),
state, state,
@ -498,7 +498,7 @@ pub async fn lookup_grant_by_code(
skip_all, skip_all,
fields( fields(
grant.id = %grant.data, grant.id = %grant.data,
client.id = %grant.client.data, client.id = %grant.client.id,
session.id, session.id,
user_session.id = %browser_session.id, user_session.id = %browser_session.id,
user.id = %browser_session.user.id, user.id = %browser_session.user.id,
@ -552,7 +552,7 @@ pub async fn derive_session(
skip_all, skip_all,
fields( fields(
grant.id = %grant.data, grant.id = %grant.data,
client.id = %grant.client.data, client.id = %grant.client.id,
session.id = %session.data, session.id = %session.data,
user_session.id = %session.browser_session.id, user_session.id = %session.browser_session.id,
user.id = %session.browser_session.user.id, user.id = %session.browser_session.user.id,
@ -592,7 +592,7 @@ pub async fn fulfill_grant(
skip_all, skip_all,
fields( fields(
grant.id = %grant.data, grant.id = %grant.data,
client.id = %grant.client.data, client.id = %grant.client.id,
), ),
err(Debug), err(Debug),
)] )]
@ -622,7 +622,7 @@ pub async fn give_consent_to_grant(
skip_all, skip_all,
fields( fields(
grant.id = %grant.data, grant.id = %grant.data,
client.id = %grant.client.data, client.id = %grant.client.id,
), ),
err(Debug), err(Debug),
)] )]

View File

@ -28,7 +28,7 @@ use ulid::Ulid;
use url::Url; use url::Url;
use uuid::Uuid; use uuid::Uuid;
use crate::{Clock, LookupError, PostgresqlBackend}; use crate::{Clock, LookupError};
// XXX: response_types & contacts // XXX: response_types & contacts
#[derive(Debug)] #[derive(Debug)]
@ -90,11 +90,11 @@ impl LookupError for ClientFetchError {
} }
} }
impl TryInto<Client<PostgresqlBackend>> for OAuth2ClientLookup { impl TryInto<Client> for OAuth2ClientLookup {
type Error = ClientFetchError; type Error = ClientFetchError;
#[allow(clippy::too_many_lines)] // TODO: refactor some of the field parsing #[allow(clippy::too_many_lines)] // TODO: refactor some of the field parsing
fn try_into(self) -> Result<Client<PostgresqlBackend>, Self::Error> { fn try_into(self) -> Result<Client, Self::Error> {
let redirect_uris: Result<Vec<Url>, _> = let redirect_uris: Result<Vec<Url>, _> =
self.redirect_uris.iter().map(|s| s.parse()).collect(); self.redirect_uris.iter().map(|s| s.parse()).collect();
let redirect_uris = redirect_uris.map_err(|source| ClientFetchError::ParseUrl { let redirect_uris = redirect_uris.map_err(|source| ClientFetchError::ParseUrl {
@ -226,7 +226,7 @@ impl TryInto<Client<PostgresqlBackend>> for OAuth2ClientLookup {
let id = Ulid::from(self.oauth2_client_id); let id = Ulid::from(self.oauth2_client_id);
Ok(Client { Ok(Client {
data: id, id,
client_id: id.to_string(), client_id: id.to_string(),
encrypted_client_secret: self.encrypted_client_secret, encrypted_client_secret: self.encrypted_client_secret,
redirect_uris, redirect_uris,
@ -253,7 +253,7 @@ impl TryInto<Client<PostgresqlBackend>> for OAuth2ClientLookup {
pub async fn lookup_clients( pub async fn lookup_clients(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
ids: impl IntoIterator<Item = Ulid> + Send, ids: impl IntoIterator<Item = Ulid> + Send,
) -> Result<HashMap<Ulid, Client<PostgresqlBackend>>, ClientFetchError> { ) -> Result<HashMap<Ulid, Client>, ClientFetchError> {
let ids: Vec<Uuid> = ids.into_iter().map(Uuid::from).collect(); let ids: Vec<Uuid> = ids.into_iter().map(Uuid::from).collect();
let res = sqlx::query_as!( let res = sqlx::query_as!(
OAuth2ClientLookup, OAuth2ClientLookup,
@ -289,9 +289,9 @@ pub async fn lookup_clients(
.fetch_all(executor) .fetch_all(executor)
.await?; .await?;
let clients: Result<HashMap<Ulid, Client<PostgresqlBackend>>, _> = res let clients: Result<HashMap<Ulid, Client>, _> = res
.into_iter() .into_iter()
.map(|r| r.try_into().map(|c: Client<PostgresqlBackend>| (c.data, c))) .map(|r| r.try_into().map(|c: Client| (c.id, c)))
.collect(); .collect();
clients clients
@ -305,7 +305,7 @@ pub async fn lookup_clients(
pub async fn lookup_client( pub async fn lookup_client(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
id: Ulid, id: Ulid,
) -> Result<Client<PostgresqlBackend>, ClientFetchError> { ) -> Result<Client, ClientFetchError> {
let res = sqlx::query_as!( let res = sqlx::query_as!(
OAuth2ClientLookup, OAuth2ClientLookup,
r#" r#"
@ -353,7 +353,7 @@ pub async fn lookup_client(
pub async fn lookup_client_by_client_id( pub async fn lookup_client_by_client_id(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
client_id: &str, client_id: &str,
) -> Result<Client<PostgresqlBackend>, ClientFetchError> { ) -> Result<Client, ClientFetchError> {
let id: Ulid = client_id.parse()?; let id: Ulid = client_id.parse()?;
lookup_client(executor, id).await lookup_client(executor, id).await
} }

View File

@ -21,20 +21,20 @@ use sqlx::PgExecutor;
use ulid::Ulid; use ulid::Ulid;
use uuid::Uuid; use uuid::Uuid;
use crate::{Clock, PostgresqlBackend}; use crate::Clock;
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
fields( fields(
%user.id, %user.id,
client.id = %client.data, %client.id,
), ),
err(Debug), err(Debug),
)] )]
pub async fn fetch_client_consent( pub async fn fetch_client_consent(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
user: &User, user: &User,
client: &Client<PostgresqlBackend>, client: &Client,
) -> Result<Scope, anyhow::Error> { ) -> Result<Scope, anyhow::Error> {
let scope_tokens: Vec<String> = sqlx::query_scalar!( let scope_tokens: Vec<String> = sqlx::query_scalar!(
r#" r#"
@ -43,7 +43,7 @@ pub async fn fetch_client_consent(
WHERE user_id = $1 AND oauth2_client_id = $2 WHERE user_id = $1 AND oauth2_client_id = $2
"#, "#,
Uuid::from(user.id), Uuid::from(user.id),
Uuid::from(client.data), Uuid::from(client.id),
) )
.fetch_all(executor) .fetch_all(executor)
.await?; .await?;
@ -60,8 +60,8 @@ pub async fn fetch_client_consent(
skip_all, skip_all,
fields( fields(
%user.id, %user.id,
client.id = %client.data, %client.id,
scope = scope.to_string(), %scope,
), ),
err(Debug), err(Debug),
)] )]
@ -70,7 +70,7 @@ pub async fn insert_client_consent(
mut rng: impl Rng + Send, mut rng: impl Rng + Send,
clock: &Clock, clock: &Clock,
user: &User, user: &User,
client: &Client<PostgresqlBackend>, client: &Client,
scope: &Scope, scope: &Scope,
) -> Result<(), anyhow::Error> { ) -> Result<(), anyhow::Error> {
let now = clock.now(); let now = clock.now();
@ -93,7 +93,7 @@ pub async fn insert_client_consent(
"#, "#,
&ids, &ids,
Uuid::from(user.id), Uuid::from(user.id),
Uuid::from(client.data), Uuid::from(client.id),
&tokens, &tokens,
now, now,
) )

View File

@ -40,7 +40,7 @@ pub mod refresh_token;
session.id = %session.data, session.id = %session.data,
user.id = %session.browser_session.user.id, user.id = %session.browser_session.user.id,
user_session.id = %session.browser_session.id, user_session.id = %session.browser_session.id,
client.id = %session.client.data, client.id = %session.client.id,
), ),
err(Debug), err(Debug),
)] )]

View File

@ -32,7 +32,7 @@ use crate::{Clock, DatabaseInconsistencyError, LookupError, PostgresqlBackend};
session.id = %session.data, session.id = %session.data,
user.id = %session.browser_session.user.id, user.id = %session.browser_session.user.id,
user_session.id = %session.browser_session.id, user_session.id = %session.browser_session.id,
client.id = %session.client.data, client.id = %session.client.id,
refresh_token.id, refresh_token.id,
), ),
err(Debug), err(Debug),