1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

storage: remaining oauth2 repositories

- authorization grants
 - access tokens
 - refresh tokens
This commit is contained in:
Quentin Gliech
2023-01-12 18:26:04 +01:00
parent 36396c0b45
commit 488a666a8d
17 changed files with 1700 additions and 1366 deletions

View File

@ -1,4 +1,4 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -12,67 +12,61 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};
use mas_data_model::{AccessToken, AccessTokenState, Session};
use rand::Rng;
use sqlx::{PgConnection, PgExecutor};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{Clock, DatabaseError, LookupResultExt};
use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt};
#[tracing::instrument(
skip_all,
fields(
%session.id,
user_session.id = %session.user_session_id,
client.id = %session.client_id,
access_token.id,
),
err,
)]
pub async fn add_access_token(
executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
session: &Session,
access_token: String,
expires_after: Duration,
) -> Result<AccessToken, sqlx::Error> {
let created_at = clock.now();
let expires_at = created_at + expires_after;
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
#[async_trait]
pub trait OAuth2AccessTokenRepository: Send + Sync {
type Error;
tracing::Span::current().record("access_token.id", tracing::field::display(id));
/// Lookup an access token by its ID
async fn lookup(&mut self, id: Ulid) -> Result<Option<AccessToken>, Self::Error>;
sqlx::query!(
r#"
INSERT INTO oauth2_access_tokens
(oauth2_access_token_id, oauth2_session_id, access_token, created_at, expires_at)
VALUES
($1, $2, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(session.id),
&access_token,
created_at,
expires_at,
)
.execute(executor)
.await?;
/// Find an access token by its token
async fn find_by_token(
&mut self,
access_token: &str,
) -> Result<Option<AccessToken>, Self::Error>;
Ok(AccessToken {
id,
state: AccessTokenState::default(),
access_token,
session_id: session.id,
created_at,
expires_at,
})
/// Add a new access token to the database
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
session: &Session,
access_token: String,
expires_after: Duration,
) -> Result<AccessToken, Self::Error>;
/// Revoke an access token
async fn revoke(
&mut self,
clock: &Clock,
access_token: AccessToken,
) -> Result<AccessToken, Self::Error>;
/// Cleanup expired access tokens
async fn cleanup_expired(&mut self, clock: &Clock) -> Result<usize, Self::Error>;
}
#[derive(Debug)]
pub struct OAuth2AccessTokenLookup {
pub struct PgOAuth2AccessTokenRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgOAuth2AccessTokenRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
struct OAuth2AccessTokenLookup {
oauth2_access_token_id: Uuid,
oauth2_session_id: Uuid,
access_token: String,
@ -99,118 +93,164 @@ impl From<OAuth2AccessTokenLookup> for AccessToken {
}
}
#[tracing::instrument(skip_all, err)]
pub async fn find_access_token(
conn: &mut PgConnection,
token: &str,
) -> Result<Option<AccessToken>, DatabaseError> {
let res = sqlx::query_as!(
OAuth2AccessTokenLookup,
r#"
SELECT oauth2_access_token_id
, access_token
, created_at
, expires_at
, revoked_at
, oauth2_session_id
#[async_trait]
impl<'c> OAuth2AccessTokenRepository for PgOAuth2AccessTokenRepository<'c> {
type Error = DatabaseError;
FROM oauth2_access_tokens
async fn lookup(&mut self, id: Ulid) -> Result<Option<AccessToken>, Self::Error> {
let res = sqlx::query_as!(
OAuth2AccessTokenLookup,
r#"
SELECT oauth2_access_token_id
, access_token
, created_at
, expires_at
, revoked_at
, oauth2_session_id
WHERE access_token = $1
"#,
token,
)
.fetch_one(&mut *conn)
.await
.to_option()?;
FROM oauth2_access_tokens
let Some(res) = res else { return Ok(None) };
WHERE oauth2_access_token_id = $1
"#,
Uuid::from(id),
)
.fetch_one(&mut *self.conn)
.await
.to_option()?;
Ok(Some(res.into()))
}
#[tracing::instrument(
skip_all,
fields(access_token.id = %access_token_id),
err,
)]
pub async fn lookup_access_token(
conn: &mut PgConnection,
access_token_id: Ulid,
) -> Result<Option<AccessToken>, DatabaseError> {
let res = sqlx::query_as!(
OAuth2AccessTokenLookup,
r#"
SELECT oauth2_access_token_id
, access_token
, created_at
, expires_at
, revoked_at
, oauth2_session_id
FROM oauth2_access_tokens
WHERE oauth2_access_token_id = $1
"#,
Uuid::from(access_token_id),
)
.fetch_one(&mut *conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
skip_all,
fields(
%access_token.id,
session.id = %access_token.session_id,
),
err,
)]
pub async fn revoke_access_token(
executor: impl PgExecutor<'_>,
clock: &Clock,
access_token: AccessToken,
) -> Result<AccessToken, DatabaseError> {
let revoked_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE oauth2_access_tokens
SET revoked_at = $2
WHERE oauth2_access_token_id = $1
"#,
Uuid::from(access_token.id),
revoked_at,
)
.execute(executor)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
access_token
.revoke(revoked_at)
.map_err(DatabaseError::to_invalid_operation)
}
pub async fn cleanup_expired(
executor: impl PgExecutor<'_>,
clock: &Clock,
) -> Result<u64, sqlx::Error> {
// Cleanup token which expired more than 15 minutes ago
let threshold = clock.now() - Duration::minutes(15);
let res = sqlx::query!(
r#"
DELETE FROM oauth2_access_tokens
WHERE expires_at < $1
"#,
threshold,
)
.execute(executor)
.await?;
Ok(res.rows_affected())
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.oauth2_access_token.find_by_token",
skip_all,
fields(
db.statement,
),
err,
)]
async fn find_by_token(
&mut self,
access_token: &str,
) -> Result<Option<AccessToken>, Self::Error> {
let res = sqlx::query_as!(
OAuth2AccessTokenLookup,
r#"
SELECT oauth2_access_token_id
, access_token
, created_at
, expires_at
, revoked_at
, oauth2_session_id
FROM oauth2_access_tokens
WHERE access_token = $1
"#,
access_token,
)
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.oauth2_access_token.add",
skip_all,
fields(
db.statement,
%session.id,
user_session.id = %session.user_session_id,
client.id = %session.client_id,
access_token.id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
session: &Session,
access_token: String,
expires_after: Duration,
) -> Result<AccessToken, Self::Error> {
let created_at = clock.now();
let expires_at = created_at + expires_after;
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("access_token.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO oauth2_access_tokens
(oauth2_access_token_id, oauth2_session_id, access_token, created_at, expires_at)
VALUES
($1, $2, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(session.id),
&access_token,
created_at,
expires_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(AccessToken {
id,
state: AccessTokenState::default(),
access_token,
session_id: session.id,
created_at,
expires_at,
})
}
async fn revoke(
&mut self,
clock: &Clock,
access_token: AccessToken,
) -> Result<AccessToken, Self::Error> {
let revoked_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE oauth2_access_tokens
SET revoked_at = $2
WHERE oauth2_access_token_id = $1
"#,
Uuid::from(access_token.id),
revoked_at,
)
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
access_token
.revoke(revoked_at)
.map_err(DatabaseError::to_invalid_operation)
}
async fn cleanup_expired(&mut self, clock: &Clock) -> Result<usize, Self::Error> {
// Cleanup token which expired more than 15 minutes ago
let threshold = clock.now() - Duration::minutes(15);
let res = sqlx::query!(
r#"
DELETE FROM oauth2_access_tokens
WHERE expires_at < $1
"#,
threshold,
)
.execute(&mut *self.conn)
.await?;
Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
}
}

View File

@ -14,138 +14,97 @@
use std::num::NonZeroU32;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{
AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Pkce, Session,
};
use mas_iana::oauth::PkceCodeChallengeMethod;
use oauth2_types::{requests::ResponseMode, scope::Scope};
use rand::Rng;
use sqlx::{PgConnection, PgExecutor};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use url::Url;
use uuid::Uuid;
use crate::{Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt};
use crate::{
tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
};
#[tracing::instrument(
skip_all,
fields(
%client.id,
grant.id,
),
err,
)]
#[allow(clippy::too_many_arguments)]
pub async fn new_authorization_grant(
executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
client: Client,
redirect_uri: Url,
scope: Scope,
code: Option<AuthorizationCode>,
state: Option<String>,
nonce: Option<String>,
max_age: Option<NonZeroU32>,
_acr_values: Option<String>,
response_mode: ResponseMode,
response_type_id_token: bool,
requires_consent: bool,
) -> Result<AuthorizationGrant, sqlx::Error> {
let code_challenge = code
.as_ref()
.and_then(|c| c.pkce.as_ref())
.map(|p| &p.challenge);
let code_challenge_method = code
.as_ref()
.and_then(|c| c.pkce.as_ref())
.map(|p| p.challenge_method.to_string());
// TODO: this conversion is a bit ugly
let max_age_i32 = max_age.map(|x| i32::try_from(u32::from(x)).unwrap_or(i32::MAX));
let code_str = code.as_ref().map(|c| &c.code);
#[async_trait]
pub trait OAuth2AuthorizationGrantRepository {
type Error;
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("grant.id", tracing::field::display(id));
#[allow(clippy::too_many_arguments)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
client: &Client,
redirect_uri: Url,
scope: Scope,
code: Option<AuthorizationCode>,
state: Option<String>,
nonce: Option<String>,
max_age: Option<NonZeroU32>,
response_mode: ResponseMode,
response_type_id_token: bool,
requires_consent: bool,
) -> Result<AuthorizationGrant, Self::Error>;
sqlx::query!(
r#"
INSERT INTO oauth2_authorization_grants (
oauth2_authorization_grant_id,
oauth2_client_id,
redirect_uri,
scope,
state,
nonce,
max_age,
response_mode,
code_challenge,
code_challenge_method,
response_type_code,
response_type_id_token,
authorization_code,
requires_consent,
created_at
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
"#,
Uuid::from(id),
Uuid::from(client.id),
redirect_uri.to_string(),
scope.to_string(),
state,
nonce,
max_age_i32,
response_mode.to_string(),
code_challenge,
code_challenge_method,
code.is_some(),
response_type_id_token,
code_str,
requires_consent,
created_at,
)
.execute(executor)
.await?;
async fn lookup(&mut self, id: Ulid) -> Result<Option<AuthorizationGrant>, Self::Error>;
Ok(AuthorizationGrant {
id,
stage: AuthorizationGrantStage::Pending,
code,
redirect_uri,
client_id: client.id,
scope,
state,
nonce,
max_age,
response_mode,
created_at,
response_type_id_token,
requires_consent,
})
async fn find_by_code(&mut self, code: &str)
-> Result<Option<AuthorizationGrant>, Self::Error>;
async fn fulfill(
&mut self,
clock: &Clock,
session: &Session,
authorization_grant: AuthorizationGrant,
) -> Result<AuthorizationGrant, Self::Error>;
async fn exchange(
&mut self,
clock: &Clock,
authorization_grant: AuthorizationGrant,
) -> Result<AuthorizationGrant, Self::Error>;
async fn give_consent(
&mut self,
authorization_grant: AuthorizationGrant,
) -> Result<AuthorizationGrant, Self::Error>;
}
pub struct PgOAuth2AuthorizationGrantRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgOAuth2AuthorizationGrantRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
#[allow(clippy::struct_excessive_bools)]
struct GrantLookup {
oauth2_authorization_grant_id: Uuid,
oauth2_authorization_grant_created_at: DateTime<Utc>,
oauth2_authorization_grant_cancelled_at: Option<DateTime<Utc>>,
oauth2_authorization_grant_fulfilled_at: Option<DateTime<Utc>>,
oauth2_authorization_grant_exchanged_at: Option<DateTime<Utc>>,
oauth2_authorization_grant_scope: String,
oauth2_authorization_grant_state: Option<String>,
oauth2_authorization_grant_nonce: Option<String>,
oauth2_authorization_grant_redirect_uri: String,
oauth2_authorization_grant_response_mode: String,
oauth2_authorization_grant_max_age: Option<i32>,
oauth2_authorization_grant_response_type_code: bool,
oauth2_authorization_grant_response_type_id_token: bool,
oauth2_authorization_grant_code: Option<String>,
oauth2_authorization_grant_code_challenge: Option<String>,
oauth2_authorization_grant_code_challenge_method: Option<String>,
oauth2_authorization_grant_requires_consent: bool,
created_at: DateTime<Utc>,
cancelled_at: Option<DateTime<Utc>>,
fulfilled_at: Option<DateTime<Utc>>,
exchanged_at: Option<DateTime<Utc>>,
scope: String,
state: Option<String>,
nonce: Option<String>,
redirect_uri: String,
response_mode: String,
max_age: Option<i32>,
response_type_code: bool,
response_type_id_token: bool,
authorization_code: Option<String>,
code_challenge: Option<String>,
code_challenge_method: Option<String>,
requires_consent: bool,
oauth2_client_id: Uuid,
oauth2_session_id: Option<Uuid>,
}
@ -156,20 +115,17 @@ impl TryFrom<GrantLookup> for AuthorizationGrant {
#[allow(clippy::too_many_lines)]
fn try_from(value: GrantLookup) -> Result<Self, Self::Error> {
let id = value.oauth2_authorization_grant_id.into();
let scope: Scope = value
.oauth2_authorization_grant_scope
.parse()
.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("scope")
.row(id)
.source(e)
})?;
let scope: Scope = value.scope.parse().map_err(|e| {
DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("scope")
.row(id)
.source(e)
})?;
let stage = match (
value.oauth2_authorization_grant_fulfilled_at,
value.oauth2_authorization_grant_exchanged_at,
value.oauth2_authorization_grant_cancelled_at,
value.fulfilled_at,
value.exchanged_at,
value.cancelled_at,
value.oauth2_session_id,
) {
(None, None, None, None) => AuthorizationGrantStage::Pending,
@ -198,10 +154,7 @@ impl TryFrom<GrantLookup> for AuthorizationGrant {
}
};
let pkce = match (
value.oauth2_authorization_grant_code_challenge,
value.oauth2_authorization_grant_code_challenge_method,
) {
let pkce = match (value.code_challenge, value.code_challenge_method) {
(Some(challenge), Some(challenge_method)) if challenge_method == "plain" => {
Some(Pkce {
challenge_method: PkceCodeChallengeMethod::Plain,
@ -222,44 +175,35 @@ impl TryFrom<GrantLookup> for AuthorizationGrant {
}
};
let code: Option<AuthorizationCode> = match (
value.oauth2_authorization_grant_response_type_code,
value.oauth2_authorization_grant_code,
pkce,
) {
(false, None, None) => None,
(true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }),
_ => {
return Err(
DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("authorization_code")
.row(id),
);
}
};
let code: Option<AuthorizationCode> =
match (value.response_type_code, value.authorization_code, pkce) {
(false, None, None) => None,
(true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }),
_ => {
return Err(
DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("authorization_code")
.row(id),
);
}
};
let redirect_uri = value
.oauth2_authorization_grant_redirect_uri
.parse()
.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("redirect_uri")
.row(id)
.source(e)
})?;
let redirect_uri = value.redirect_uri.parse().map_err(|e| {
DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("redirect_uri")
.row(id)
.source(e)
})?;
let response_mode = value
.oauth2_authorization_grant_response_mode
.parse()
.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("response_mode")
.row(id)
.source(e)
})?;
let response_mode = value.response_mode.parse().map_err(|e| {
DatabaseInconsistencyError::on("oauth2_authorization_grants")
.column("response_mode")
.row(id)
.source(e)
})?;
let max_age = value
.oauth2_authorization_grant_max_age
.max_age
.map(u32::try_from)
.transpose()
.map_err(|e| {
@ -283,209 +227,330 @@ impl TryFrom<GrantLookup> for AuthorizationGrant {
client_id: value.oauth2_client_id.into(),
code,
scope,
state: value.oauth2_authorization_grant_state,
nonce: value.oauth2_authorization_grant_nonce,
state: value.state,
nonce: value.nonce,
max_age,
response_mode,
redirect_uri,
created_at: value.oauth2_authorization_grant_created_at,
response_type_id_token: value.oauth2_authorization_grant_response_type_id_token,
requires_consent: value.oauth2_authorization_grant_requires_consent,
created_at: value.created_at,
response_type_id_token: value.response_type_id_token,
requires_consent: value.requires_consent,
})
}
}
#[tracing::instrument(
skip_all,
fields(grant.id = %id),
err,
)]
pub async fn get_grant_by_id(
conn: &mut PgConnection,
id: Ulid,
) -> Result<Option<AuthorizationGrant>, DatabaseError> {
let res = sqlx::query_as!(
GrantLookup,
r#"
SELECT oauth2_authorization_grant_id
, created_at AS oauth2_authorization_grant_created_at
, cancelled_at AS oauth2_authorization_grant_cancelled_at
, fulfilled_at AS oauth2_authorization_grant_fulfilled_at
, exchanged_at AS oauth2_authorization_grant_exchanged_at
, scope AS oauth2_authorization_grant_scope
, state AS oauth2_authorization_grant_state
, redirect_uri AS oauth2_authorization_grant_redirect_uri
, response_mode AS oauth2_authorization_grant_response_mode
, nonce AS oauth2_authorization_grant_nonce
, max_age AS oauth2_authorization_grant_max_age
, oauth2_client_id AS oauth2_client_id
, authorization_code AS oauth2_authorization_grant_code
, response_type_code AS oauth2_authorization_grant_response_type_code
, response_type_id_token AS oauth2_authorization_grant_response_type_id_token
, code_challenge AS oauth2_authorization_grant_code_challenge
, code_challenge_method AS oauth2_authorization_grant_code_challenge_method
, requires_consent AS oauth2_authorization_grant_requires_consent
, oauth2_session_id AS "oauth2_session_id?"
FROM
oauth2_authorization_grants
#[async_trait]
impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantRepository<'c> {
type Error = DatabaseError;
WHERE oauth2_authorization_grant_id = $1
"#,
Uuid::from(id),
)
.fetch_one(&mut *conn)
.await
.to_option()?;
#[tracing::instrument(
name = "db.oauth2_authorization_grant.add",
skip_all,
fields(
db.statement,
grant.id,
grant.scope = %scope,
%client.id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
client: &Client,
redirect_uri: Url,
scope: Scope,
code: Option<AuthorizationCode>,
state: Option<String>,
nonce: Option<String>,
max_age: Option<NonZeroU32>,
response_mode: ResponseMode,
response_type_id_token: bool,
requires_consent: bool,
) -> Result<AuthorizationGrant, Self::Error> {
let code_challenge = code
.as_ref()
.and_then(|c| c.pkce.as_ref())
.map(|p| &p.challenge);
let code_challenge_method = code
.as_ref()
.and_then(|c| c.pkce.as_ref())
.map(|p| p.challenge_method.to_string());
// TODO: this conversion is a bit ugly
let max_age_i32 = max_age.map(|x| i32::try_from(u32::from(x)).unwrap_or(i32::MAX));
let code_str = code.as_ref().map(|c| &c.code);
let Some(res) = res else { return Ok(None) };
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("grant.id", tracing::field::display(id));
Ok(Some(res.try_into()?))
}
#[tracing::instrument(skip_all, err)]
pub async fn lookup_grant_by_code(
conn: &mut PgConnection,
code: &str,
) -> Result<Option<AuthorizationGrant>, DatabaseError> {
let res = sqlx::query_as!(
GrantLookup,
r#"
SELECT oauth2_authorization_grant_id
, created_at AS oauth2_authorization_grant_created_at
, cancelled_at AS oauth2_authorization_grant_cancelled_at
, fulfilled_at AS oauth2_authorization_grant_fulfilled_at
, exchanged_at AS oauth2_authorization_grant_exchanged_at
, scope AS oauth2_authorization_grant_scope
, state AS oauth2_authorization_grant_state
, redirect_uri AS oauth2_authorization_grant_redirect_uri
, response_mode AS oauth2_authorization_grant_response_mode
, nonce AS oauth2_authorization_grant_nonce
, max_age AS oauth2_authorization_grant_max_age
, oauth2_client_id AS oauth2_client_id
, authorization_code AS oauth2_authorization_grant_code
, response_type_code AS oauth2_authorization_grant_response_type_code
, response_type_id_token AS oauth2_authorization_grant_response_type_id_token
, code_challenge AS oauth2_authorization_grant_code_challenge
, code_challenge_method AS oauth2_authorization_grant_code_challenge_method
, requires_consent AS oauth2_authorization_grant_requires_consent
, oauth2_session_id AS "oauth2_session_id?"
FROM
oauth2_authorization_grants
WHERE authorization_code = $1
"#,
code,
)
.fetch_one(&mut *conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
skip_all,
fields(
%grant.id,
client.id = %grant.client_id,
%session.id,
user_session.id = %session.user_session_id,
),
err,
)]
pub async fn fulfill_grant(
executor: impl PgExecutor<'_>,
mut grant: AuthorizationGrant,
session: Session,
) -> Result<AuthorizationGrant, DatabaseError> {
let fulfilled_at = sqlx::query_scalar!(
r#"
UPDATE oauth2_authorization_grants AS og
SET
oauth2_session_id = os.oauth2_session_id,
fulfilled_at = os.created_at
FROM oauth2_sessions os
WHERE
og.oauth2_authorization_grant_id = $1
AND os.oauth2_session_id = $2
RETURNING fulfilled_at AS "fulfilled_at!: DateTime<Utc>"
"#,
Uuid::from(grant.id),
Uuid::from(session.id),
)
.fetch_one(executor)
.await?;
grant.stage = grant
.stage
.fulfill(fulfilled_at, &session)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(grant)
}
#[tracing::instrument(
skip_all,
fields(
%grant.id,
client.id = %grant.client_id,
),
err,
)]
pub async fn give_consent_to_grant(
executor: impl PgExecutor<'_>,
mut grant: AuthorizationGrant,
) -> Result<AuthorizationGrant, sqlx::Error> {
sqlx::query!(
r#"
UPDATE oauth2_authorization_grants AS og
SET
requires_consent = 'f'
WHERE
og.oauth2_authorization_grant_id = $1
"#,
Uuid::from(grant.id),
)
.execute(executor)
.await?;
grant.requires_consent = false;
Ok(grant)
}
#[tracing::instrument(
skip_all,
fields(
%grant.id,
client.id = %grant.client_id,
),
err,
)]
pub async fn exchange_grant(
executor: impl PgExecutor<'_>,
clock: &Clock,
mut grant: AuthorizationGrant,
) -> Result<AuthorizationGrant, DatabaseError> {
let exchanged_at = clock.now();
sqlx::query!(
r#"
UPDATE oauth2_authorization_grants
SET exchanged_at = $2
WHERE oauth2_authorization_grant_id = $1
"#,
Uuid::from(grant.id),
exchanged_at,
)
.execute(executor)
.await?;
grant.stage = grant
.stage
.exchange(exchanged_at)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(grant)
sqlx::query!(
r#"
INSERT INTO oauth2_authorization_grants (
oauth2_authorization_grant_id,
oauth2_client_id,
redirect_uri,
scope,
state,
nonce,
max_age,
response_mode,
code_challenge,
code_challenge_method,
response_type_code,
response_type_id_token,
authorization_code,
requires_consent,
created_at
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
"#,
Uuid::from(id),
Uuid::from(client.id),
redirect_uri.to_string(),
scope.to_string(),
state,
nonce,
max_age_i32,
response_mode.to_string(),
code_challenge,
code_challenge_method,
code.is_some(),
response_type_id_token,
code_str,
requires_consent,
created_at,
)
.execute(&mut *self.conn)
.await?;
Ok(AuthorizationGrant {
id,
stage: AuthorizationGrantStage::Pending,
code,
redirect_uri,
client_id: client.id,
scope,
state,
nonce,
max_age,
response_mode,
created_at,
response_type_id_token,
requires_consent,
})
}
#[tracing::instrument(
name = "db.oauth2_authorization_grant.lookup",
skip_all,
fields(
db.statement,
grant.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<AuthorizationGrant>, Self::Error> {
let res = sqlx::query_as!(
GrantLookup,
r#"
SELECT oauth2_authorization_grant_id
, created_at
, cancelled_at
, fulfilled_at
, exchanged_at
, scope
, state
, redirect_uri
, response_mode
, nonce
, max_age
, oauth2_client_id
, authorization_code
, response_type_code
, response_type_id_token
, code_challenge
, code_challenge_method
, requires_consent
, oauth2_session_id
FROM
oauth2_authorization_grants
WHERE oauth2_authorization_grant_id = $1
"#,
Uuid::from(id),
)
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
name = "db.oauth2_authorization_grant.find_by_code",
skip_all,
fields(
db.statement,
),
err,
)]
async fn find_by_code(
&mut self,
code: &str,
) -> Result<Option<AuthorizationGrant>, Self::Error> {
let res = sqlx::query_as!(
GrantLookup,
r#"
SELECT oauth2_authorization_grant_id
, created_at
, cancelled_at
, fulfilled_at
, exchanged_at
, scope
, state
, redirect_uri
, response_mode
, nonce
, max_age
, oauth2_client_id
, authorization_code
, response_type_code
, response_type_id_token
, code_challenge
, code_challenge_method
, requires_consent
, oauth2_session_id
FROM
oauth2_authorization_grants
WHERE authorization_code = $1
"#,
code,
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
name = "db.oauth2_authorization_grant.fulfill",
skip_all,
fields(
db.statement,
%grant.id,
client.id = %grant.client_id,
%session.id,
user_session.id = %session.user_session_id,
),
err,
)]
async fn fulfill(
&mut self,
clock: &Clock,
session: &Session,
grant: AuthorizationGrant,
) -> Result<AuthorizationGrant, Self::Error> {
let fulfilled_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE oauth2_authorization_grants
SET fulfilled_at = $2
, oauth2_session_id = $3
WHERE oauth2_authorization_grant_id = $1
"#,
Uuid::from(grant.id),
fulfilled_at,
Uuid::from(session.id),
)
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
// XXX: check affected rows & new methods
let grant = grant
.fulfill(fulfilled_at, session)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(grant)
}
#[tracing::instrument(
name = "db.oauth2_authorization_grant.exchange",
skip_all,
fields(
db.statement,
%grant.id,
client.id = %grant.client_id,
),
err,
)]
async fn exchange(
&mut self,
clock: &Clock,
grant: AuthorizationGrant,
) -> Result<AuthorizationGrant, Self::Error> {
let exchanged_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE oauth2_authorization_grants
SET exchanged_at = $2
WHERE oauth2_authorization_grant_id = $1
"#,
Uuid::from(grant.id),
exchanged_at,
)
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
let grant = grant
.exchange(exchanged_at)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(grant)
}
#[tracing::instrument(
name = "db.oauth2_authorization_grant.give_consent",
skip_all,
fields(
db.statement,
%grant.id,
client.id = %grant.client_id,
),
err,
)]
async fn give_consent(
&mut self,
mut grant: AuthorizationGrant,
) -> Result<AuthorizationGrant, Self::Error> {
sqlx::query!(
r#"
UPDATE oauth2_authorization_grants AS og
SET
requires_consent = 'f'
WHERE
og.oauth2_authorization_grant_id = $1
"#,
Uuid::from(grant.id),
)
.execute(&mut *self.conn)
.await?;
grant.requires_consent = false;
Ok(grant)
}
}

View File

@ -14,17 +14,21 @@
use std::{
collections::{BTreeMap, BTreeSet},
str::FromStr,
string::ToString,
};
use async_trait::async_trait;
use mas_data_model::{Client, JwksOrJwksUri};
use mas_data_model::{Client, JwksOrJwksUri, User};
use mas_iana::{
jose::JsonWebSignatureAlg,
oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod},
};
use mas_jose::jwk::PublicJsonWebKeySet;
use oauth2_types::requests::GrantType;
use oauth2_types::{
requests::GrantType,
scope::{Scope, ScopeToken},
};
use rand::{Rng, RngCore};
use sqlx::PgConnection;
use tracing::{info_span, Instrument};
@ -87,6 +91,21 @@ pub trait OAuth2ClientRepository: Send + Sync {
jwks_uri: Option<Url>,
redirect_uris: Vec<Url>,
) -> Result<Client, Self::Error>;
async fn get_consent_for_user(
&mut self,
client: &Client,
user: &User,
) -> Result<Scope, Self::Error>;
async fn give_consent_for_user(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
client: &Client,
user: &User,
scope: &Scope,
) -> Result<(), Self::Error>;
}
pub struct PgOAuth2ClientRepository<'c> {
@ -702,4 +721,94 @@ impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> {
initiate_login_uri: None,
})
}
#[tracing::instrument(
name = "db.oauth2_client.get_consent_for_user",
skip_all,
fields(
db.statement,
%user.id,
%client.id,
),
err,
)]
async fn get_consent_for_user(
&mut self,
client: &Client,
user: &User,
) -> Result<Scope, Self::Error> {
let scope_tokens: Vec<String> = sqlx::query_scalar!(
r#"
SELECT scope_token
FROM oauth2_consents
WHERE user_id = $1 AND oauth2_client_id = $2
"#,
Uuid::from(user.id),
Uuid::from(client.id),
)
.fetch_all(&mut *self.conn)
.await?;
let scope: Result<Scope, _> = scope_tokens
.into_iter()
.map(|s| ScopeToken::from_str(&s))
.collect();
let scope = scope.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_consents")
.column("scope_token")
.source(e)
})?;
Ok(scope)
}
#[tracing::instrument(
skip_all,
fields(
db.statement,
%user.id,
%client.id,
%scope,
),
err,
)]
async fn give_consent_for_user(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
client: &Client,
user: &User,
scope: &Scope,
) -> Result<(), Self::Error> {
let now = clock.now();
let (tokens, ids): (Vec<String>, Vec<Uuid>) = scope
.iter()
.map(|token| {
(
token.to_string(),
Uuid::from(Ulid::from_datetime_with_source(now.into(), rng)),
)
})
.unzip();
sqlx::query!(
r#"
INSERT INTO oauth2_consents
(oauth2_consent_id, user_id, oauth2_client_id, scope_token, created_at)
SELECT id, $2, $3, scope_token, $5 FROM UNNEST($1::uuid[], $4::text[]) u(id, scope_token)
ON CONFLICT (user_id, oauth2_client_id, scope_token) DO UPDATE SET refreshed_at = $5
"#,
&ids,
Uuid::from(user.id),
Uuid::from(client.id),
&tokens,
now,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(())
}
}

View File

@ -1,110 +0,0 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::str::FromStr;
use mas_data_model::{Client, User};
use oauth2_types::scope::{Scope, ScopeToken};
use rand::Rng;
use sqlx::PgExecutor;
use ulid::Ulid;
use uuid::Uuid;
use crate::{Clock, DatabaseError, DatabaseInconsistencyError};
#[tracing::instrument(
skip_all,
fields(
%user.id,
%client.id,
),
err,
)]
pub async fn fetch_client_consent(
executor: impl PgExecutor<'_>,
user: &User,
client: &Client,
) -> Result<Scope, DatabaseError> {
let scope_tokens: Vec<String> = sqlx::query_scalar!(
r#"
SELECT scope_token
FROM oauth2_consents
WHERE user_id = $1 AND oauth2_client_id = $2
"#,
Uuid::from(user.id),
Uuid::from(client.id),
)
.fetch_all(executor)
.await?;
let scope: Result<Scope, _> = scope_tokens
.into_iter()
.map(|s| ScopeToken::from_str(&s))
.collect();
let scope = scope.map_err(|e| {
DatabaseInconsistencyError::on("oauth2_consents")
.column("scope_token")
.source(e)
})?;
Ok(scope)
}
#[tracing::instrument(
skip_all,
fields(
%user.id,
%client.id,
%scope,
),
err,
)]
pub async fn insert_client_consent(
executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
user: &User,
client: &Client,
scope: &Scope,
) -> Result<(), sqlx::Error> {
let now = clock.now();
let (tokens, ids): (Vec<String>, Vec<Uuid>) = scope
.iter()
.map(|token| {
(
token.to_string(),
Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)),
)
})
.unzip();
sqlx::query!(
r#"
INSERT INTO oauth2_consents
(oauth2_consent_id, user_id, oauth2_client_id, scope_token, created_at)
SELECT id, $2, $3, scope_token, $5 FROM UNNEST($1::uuid[], $4::text[]) u(id, scope_token)
ON CONFLICT (user_id, oauth2_client_id, scope_token) DO UPDATE SET refreshed_at = $5
"#,
&ids,
Uuid::from(user.id),
Uuid::from(client.id),
&tokens,
now,
)
.execute(executor)
.await?;
Ok(())
}

View File

@ -12,14 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod access_token;
mod access_token;
pub mod authorization_grant;
mod client;
pub mod consent;
pub mod refresh_token;
mod refresh_token;
mod session;
pub use self::{
access_token::{OAuth2AccessTokenRepository, PgOAuth2AccessTokenRepository},
authorization_grant::{
OAuth2AuthorizationGrantRepository, PgOAuth2AuthorizationGrantRepository,
},
client::{OAuth2ClientRepository, PgOAuth2ClientRepository},
refresh_token::{OAuth2RefreshTokenRepository, PgOAuth2RefreshTokenRepository},
session::{OAuth2SessionRepository, PgOAuth2SessionRepository},
};

View File

@ -12,62 +12,55 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{AccessToken, RefreshToken, RefreshTokenState, Session};
use rand::Rng;
use sqlx::{PgConnection, PgExecutor};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{Clock, DatabaseError};
use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt};
#[tracing::instrument(
skip_all,
fields(
%session.id,
user_session.id = %session.user_session_id,
client.id = %session.client_id,
refresh_token.id,
),
err,
)]
pub async fn add_refresh_token(
executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
session: &Session,
access_token: &AccessToken,
refresh_token: String,
) -> Result<RefreshToken, sqlx::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("refresh_token.id", tracing::field::display(id));
#[async_trait]
pub trait OAuth2RefreshTokenRepository: Send + Sync {
type Error;
sqlx::query!(
r#"
INSERT INTO oauth2_refresh_tokens
(oauth2_refresh_token_id, oauth2_session_id, oauth2_access_token_id,
refresh_token, created_at)
VALUES
($1, $2, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(session.id),
Uuid::from(access_token.id),
refresh_token,
created_at,
)
.execute(executor)
.await?;
/// Lookup a refresh token by its ID
async fn lookup(&mut self, id: Ulid) -> Result<Option<RefreshToken>, Self::Error>;
Ok(RefreshToken {
id,
state: RefreshTokenState::default(),
session_id: session.id,
refresh_token,
access_token_id: Some(access_token.id),
created_at,
})
/// Find a refresh token by its token
async fn find_by_token(
&mut self,
refresh_token: &str,
) -> Result<Option<RefreshToken>, Self::Error>;
/// Add a new refresh token to the database
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
session: &Session,
access_token: &AccessToken,
refresh_token: String,
) -> Result<RefreshToken, Self::Error>;
/// Consume a refresh token
async fn consume(
&mut self,
clock: &Clock,
refresh_token: RefreshToken,
) -> Result<RefreshToken, Self::Error>;
}
pub struct PgOAuth2RefreshTokenRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgOAuth2RefreshTokenRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
struct OAuth2RefreshTokenLookup {
@ -79,75 +72,183 @@ struct OAuth2RefreshTokenLookup {
oauth2_session_id: Uuid,
}
#[tracing::instrument(skip_all, err)]
#[allow(clippy::too_many_lines)]
pub async fn lookup_refresh_token(
conn: &mut PgConnection,
token: &str,
) -> Result<Option<RefreshToken>, DatabaseError> {
let res = sqlx::query_as!(
OAuth2RefreshTokenLookup,
r#"
SELECT oauth2_refresh_token_id
, refresh_token
, created_at
, consumed_at
, oauth2_access_token_id
, oauth2_session_id
FROM oauth2_refresh_tokens
impl From<OAuth2RefreshTokenLookup> for RefreshToken {
fn from(value: OAuth2RefreshTokenLookup) -> Self {
let state = match value.consumed_at {
None => RefreshTokenState::Valid,
Some(consumed_at) => RefreshTokenState::Consumed { consumed_at },
};
WHERE refresh_token = $1
"#,
token,
)
.fetch_one(&mut *conn)
.await?;
let state = match res.consumed_at {
None => RefreshTokenState::Valid,
Some(consumed_at) => RefreshTokenState::Consumed { consumed_at },
};
let refresh_token = RefreshToken {
id: res.oauth2_refresh_token_id.into(),
state,
session_id: res.oauth2_session_id.into(),
refresh_token: res.refresh_token,
created_at: res.created_at,
access_token_id: res.oauth2_access_token_id.map(Ulid::from),
};
Ok(Some(refresh_token))
RefreshToken {
id: value.oauth2_refresh_token_id.into(),
state,
session_id: value.oauth2_session_id.into(),
refresh_token: value.refresh_token,
created_at: value.created_at,
access_token_id: value.oauth2_access_token_id.map(Ulid::from),
}
}
}
#[tracing::instrument(
skip_all,
fields(
%refresh_token.id,
),
err,
)]
pub async fn consume_refresh_token(
executor: impl PgExecutor<'_>,
clock: &Clock,
refresh_token: RefreshToken,
) -> Result<RefreshToken, DatabaseError> {
let consumed_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE oauth2_refresh_tokens
SET consumed_at = $2
WHERE oauth2_refresh_token_id = $1
"#,
Uuid::from(refresh_token.id),
consumed_at,
)
.execute(executor)
.await?;
#[async_trait]
impl<'c> OAuth2RefreshTokenRepository for PgOAuth2RefreshTokenRepository<'c> {
type Error = DatabaseError;
DatabaseError::ensure_affected_rows(&res, 1)?;
#[tracing::instrument(
name = "db.oauth2_refresh_token.lookup",
skip_all,
fields(
db.statement,
refresh_token.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<RefreshToken>, Self::Error> {
let res = sqlx::query_as!(
OAuth2RefreshTokenLookup,
r#"
SELECT oauth2_refresh_token_id
, refresh_token
, created_at
, consumed_at
, oauth2_access_token_id
, oauth2_session_id
FROM oauth2_refresh_tokens
refresh_token
.consume(consumed_at)
.map_err(DatabaseError::to_invalid_operation)
WHERE oauth2_refresh_token_id = $1
"#,
Uuid::from(id),
)
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.oauth2_refresh_token.find_by_token",
skip_all,
fields(
db.statement,
),
err,
)]
async fn find_by_token(
&mut self,
refresh_token: &str,
) -> Result<Option<RefreshToken>, Self::Error> {
let res = sqlx::query_as!(
OAuth2RefreshTokenLookup,
r#"
SELECT oauth2_refresh_token_id
, refresh_token
, created_at
, consumed_at
, oauth2_access_token_id
, oauth2_session_id
FROM oauth2_refresh_tokens
WHERE refresh_token = $1
"#,
refresh_token,
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.oauth2_refresh_token.add",
skip_all,
fields(
db.statement,
%session.id,
user_session.id = %session.user_session_id,
client.id = %session.client_id,
refresh_token.id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
session: &Session,
access_token: &AccessToken,
refresh_token: String,
) -> Result<RefreshToken, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("refresh_token.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO oauth2_refresh_tokens
(oauth2_refresh_token_id, oauth2_session_id, oauth2_access_token_id,
refresh_token, created_at)
VALUES
($1, $2, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(session.id),
Uuid::from(access_token.id),
refresh_token,
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(RefreshToken {
id,
state: RefreshTokenState::default(),
session_id: session.id,
refresh_token,
access_token_id: Some(access_token.id),
created_at,
})
}
#[tracing::instrument(
name = "db.oauth2_refresh_token.consume",
skip_all,
fields(
db.statement,
%refresh_token.id,
session.id = %refresh_token.session_id,
),
err,
)]
async fn consume(
&mut self,
clock: &Clock,
refresh_token: RefreshToken,
) -> Result<RefreshToken, Self::Error> {
let consumed_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE oauth2_refresh_tokens
SET consumed_at = $2
WHERE oauth2_refresh_token_id = $1
"#,
Uuid::from(refresh_token.id),
consumed_at,
)
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
refresh_token
.consume(consumed_at)
.map_err(DatabaseError::to_invalid_operation)
}
}

View File

@ -19,7 +19,10 @@ use crate::{
PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository, PgCompatSessionRepository,
PgCompatSsoLoginRepository,
},
oauth2::{PgOAuth2ClientRepository, PgOAuth2SessionRepository},
oauth2::{
PgOAuth2AccessTokenRepository, PgOAuth2AuthorizationGrantRepository,
PgOAuth2ClientRepository, PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository,
},
upstream_oauth2::{
PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository,
PgUpstreamOAuthSessionRepository,
@ -63,10 +66,22 @@ pub trait Repository {
where
Self: 'c;
type OAuth2AuthorizationGrantRepository<'c>
where
Self: 'c;
type OAuth2SessionRepository<'c>
where
Self: 'c;
type OAuth2AccessTokenRepository<'c>
where
Self: 'c;
type OAuth2RefreshTokenRepository<'c>
where
Self: 'c;
type CompatSessionRepository<'c>
where
Self: 'c;
@ -91,7 +106,10 @@ pub trait Repository {
fn user_password(&mut self) -> Self::UserPasswordRepository<'_>;
fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_>;
fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_>;
fn oauth2_authorization_grant(&mut self) -> Self::OAuth2AuthorizationGrantRepository<'_>;
fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_>;
fn oauth2_access_token(&mut self) -> Self::OAuth2AccessTokenRepository<'_>;
fn oauth2_refresh_token(&mut self) -> Self::OAuth2RefreshTokenRepository<'_>;
fn compat_session(&mut self) -> Self::CompatSessionRepository<'_>;
fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_>;
fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_>;
@ -107,7 +125,10 @@ impl Repository for PgConnection {
type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c;
type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c;
type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c;
type OAuth2AuthorizationGrantRepository<'c> = PgOAuth2AuthorizationGrantRepository<'c> where Self: 'c;
type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c;
type OAuth2AccessTokenRepository<'c> = PgOAuth2AccessTokenRepository<'c> where Self: 'c;
type OAuth2RefreshTokenRepository<'c> = PgOAuth2RefreshTokenRepository<'c> where Self: 'c;
type CompatSessionRepository<'c> = PgCompatSessionRepository<'c> where Self: 'c;
type CompatSsoLoginRepository<'c> = PgCompatSsoLoginRepository<'c> where Self: 'c;
type CompatAccessTokenRepository<'c> = PgCompatAccessTokenRepository<'c> where Self: 'c;
@ -145,10 +166,22 @@ impl Repository for PgConnection {
PgOAuth2ClientRepository::new(self)
}
fn oauth2_authorization_grant(&mut self) -> Self::OAuth2AuthorizationGrantRepository<'_> {
PgOAuth2AuthorizationGrantRepository::new(self)
}
fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> {
PgOAuth2SessionRepository::new(self)
}
fn oauth2_access_token(&mut self) -> Self::OAuth2AccessTokenRepository<'_> {
PgOAuth2AccessTokenRepository::new(self)
}
fn oauth2_refresh_token(&mut self) -> Self::OAuth2RefreshTokenRepository<'_> {
PgOAuth2RefreshTokenRepository::new(self)
}
fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> {
PgCompatSessionRepository::new(self)
}
@ -175,7 +208,10 @@ impl<'t> Repository for Transaction<'t, Postgres> {
type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c;
type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c;
type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c;
type OAuth2AuthorizationGrantRepository<'c> = PgOAuth2AuthorizationGrantRepository<'c> where Self: 'c;
type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c;
type OAuth2AccessTokenRepository<'c> = PgOAuth2AccessTokenRepository<'c> where Self: 'c;
type OAuth2RefreshTokenRepository<'c> = PgOAuth2RefreshTokenRepository<'c> where Self: 'c;
type CompatSessionRepository<'c> = PgCompatSessionRepository<'c> where Self: 'c;
type CompatSsoLoginRepository<'c> = PgCompatSsoLoginRepository<'c> where Self: 'c;
type CompatAccessTokenRepository<'c> = PgCompatAccessTokenRepository<'c> where Self: 'c;
@ -213,10 +249,22 @@ impl<'t> Repository for Transaction<'t, Postgres> {
PgOAuth2ClientRepository::new(self)
}
fn oauth2_authorization_grant(&mut self) -> Self::OAuth2AuthorizationGrantRepository<'_> {
PgOAuth2AuthorizationGrantRepository::new(self)
}
fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> {
PgOAuth2SessionRepository::new(self)
}
fn oauth2_access_token(&mut self) -> Self::OAuth2AccessTokenRepository<'_> {
PgOAuth2AccessTokenRepository::new(self)
}
fn oauth2_refresh_token(&mut self) -> Self::OAuth2RefreshTokenRepository<'_> {
PgOAuth2RefreshTokenRepository::new(self)
}
fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> {
PgCompatSessionRepository::new(self)
}