You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-31 09:24:31 +03:00
storage: remaining oauth2 repositories
- authorization grants - access tokens - refresh tokens
This commit is contained in:
@ -29,7 +29,7 @@ use headers::{authorization::Bearer, Authorization, Header, HeaderMapExt, Header
|
|||||||
use http::{header::WWW_AUTHENTICATE, HeaderMap, HeaderValue, Request, StatusCode};
|
use http::{header::WWW_AUTHENTICATE, HeaderMap, HeaderValue, Request, StatusCode};
|
||||||
use mas_data_model::Session;
|
use mas_data_model::Session;
|
||||||
use mas_storage::{
|
use mas_storage::{
|
||||||
oauth2::{access_token::find_access_token, OAuth2SessionRepository},
|
oauth2::{OAuth2AccessTokenRepository, OAuth2SessionRepository},
|
||||||
DatabaseError, Repository,
|
DatabaseError, Repository,
|
||||||
};
|
};
|
||||||
use serde::{de::DeserializeOwned, Deserialize};
|
use serde::{de::DeserializeOwned, Deserialize};
|
||||||
@ -62,7 +62,9 @@ impl AccessToken {
|
|||||||
AccessToken::None => return Err(AuthorizationVerificationError::MissingToken),
|
AccessToken::None => return Err(AuthorizationVerificationError::MissingToken),
|
||||||
};
|
};
|
||||||
|
|
||||||
let token = find_access_token(conn, token.as_str())
|
let token = conn
|
||||||
|
.oauth2_access_token()
|
||||||
|
.find_by_token(token.as_str())
|
||||||
.await?
|
.await?
|
||||||
.ok_or(AuthorizationVerificationError::InvalidToken)?;
|
.ok_or(AuthorizationVerificationError::InvalidToken)?;
|
||||||
|
|
||||||
|
@ -78,7 +78,7 @@ impl AuthorizationGrantStage {
|
|||||||
Self::Pending
|
Self::Pending
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn fulfill(
|
fn fulfill(
|
||||||
self,
|
self,
|
||||||
fulfilled_at: DateTime<Utc>,
|
fulfilled_at: DateTime<Utc>,
|
||||||
session: &Session,
|
session: &Session,
|
||||||
@ -92,7 +92,7 @@ impl AuthorizationGrantStage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn exchange(self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
|
fn exchange(self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
|
||||||
match self {
|
match self {
|
||||||
Self::Fulfilled {
|
Self::Fulfilled {
|
||||||
fulfilled_at,
|
fulfilled_at,
|
||||||
@ -106,7 +106,7 @@ impl AuthorizationGrantStage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn cancel(self, cancelled_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
|
fn cancel(self, cancelled_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
|
||||||
match self {
|
match self {
|
||||||
Self::Pending => Ok(Self::Cancelled { cancelled_at }),
|
Self::Pending => Ok(Self::Cancelled { cancelled_at }),
|
||||||
_ => Err(InvalidTransitionError),
|
_ => Err(InvalidTransitionError),
|
||||||
@ -146,4 +146,24 @@ impl AuthorizationGrant {
|
|||||||
let max_age: Option<i64> = self.max_age.map(|x| x.get().into());
|
let max_age: Option<i64> = self.max_age.map(|x| x.get().into());
|
||||||
self.created_at - Duration::seconds(max_age.unwrap_or(3600 * 24 * 365))
|
self.created_at - Duration::seconds(max_age.unwrap_or(3600 * 24 * 365))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn exchange(mut self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
|
||||||
|
self.stage = self.stage.exchange(exchanged_at)?;
|
||||||
|
Ok(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn fulfill(
|
||||||
|
mut self,
|
||||||
|
fulfilled_at: DateTime<Utc>,
|
||||||
|
session: &Session,
|
||||||
|
) -> Result<Self, InvalidTransitionError> {
|
||||||
|
self.stage = self.stage.fulfill(fulfilled_at, session)?;
|
||||||
|
Ok(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: this is not used?
|
||||||
|
pub fn cancel(mut self, canceld_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
|
||||||
|
self.stage = self.stage.cancel(canceld_at)?;
|
||||||
|
Ok(self)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -26,11 +26,7 @@ use mas_keystore::Encrypter;
|
|||||||
use mas_policy::PolicyFactory;
|
use mas_policy::PolicyFactory;
|
||||||
use mas_router::{PostAuthAction, Route};
|
use mas_router::{PostAuthAction, Route};
|
||||||
use mas_storage::{
|
use mas_storage::{
|
||||||
oauth2::{
|
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2SessionRepository},
|
||||||
authorization_grant::{fulfill_grant, get_grant_by_id},
|
|
||||||
consent::fetch_client_consent,
|
|
||||||
OAuth2ClientRepository, OAuth2SessionRepository,
|
|
||||||
},
|
|
||||||
Repository,
|
Repository,
|
||||||
};
|
};
|
||||||
use mas_templates::Templates;
|
use mas_templates::Templates;
|
||||||
@ -94,7 +90,9 @@ pub(crate) async fn get(
|
|||||||
|
|
||||||
let maybe_session = session_info.load_session(&mut txn).await?;
|
let maybe_session = session_info.load_session(&mut txn).await?;
|
||||||
|
|
||||||
let grant = get_grant_by_id(&mut txn, grant_id)
|
let grant = txn
|
||||||
|
.oauth2_authorization_grant()
|
||||||
|
.lookup(grant_id)
|
||||||
.await?
|
.await?
|
||||||
.ok_or(RouteError::NotFound)?;
|
.ok_or(RouteError::NotFound)?;
|
||||||
|
|
||||||
@ -192,7 +190,10 @@ pub(crate) async fn complete(
|
|||||||
.await?
|
.await?
|
||||||
.ok_or(GrantCompletionError::NoSuchClient)?;
|
.ok_or(GrantCompletionError::NoSuchClient)?;
|
||||||
|
|
||||||
let current_consent = fetch_client_consent(&mut txn, &browser_session.user, &client).await?;
|
let current_consent = txn
|
||||||
|
.oauth2_client()
|
||||||
|
.get_consent_for_user(&client, &browser_session.user)
|
||||||
|
.await?;
|
||||||
|
|
||||||
let lacks_consent = grant
|
let lacks_consent = grant
|
||||||
.scope
|
.scope
|
||||||
@ -211,7 +212,10 @@ pub(crate) async fn complete(
|
|||||||
.create_from_grant(&mut rng, &clock, &grant, &browser_session)
|
.create_from_grant(&mut rng, &clock, &grant, &browser_session)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let grant = fulfill_grant(&mut txn, grant, session.clone()).await?;
|
let grant = txn
|
||||||
|
.oauth2_authorization_grant()
|
||||||
|
.fulfill(&clock, &session, grant)
|
||||||
|
.await?;
|
||||||
|
|
||||||
// Yep! Let's complete the auth now
|
// Yep! Let's complete the auth now
|
||||||
let mut params = AuthorizationResponse::default();
|
let mut params = AuthorizationResponse::default();
|
||||||
|
@ -26,7 +26,7 @@ use mas_keystore::Encrypter;
|
|||||||
use mas_policy::PolicyFactory;
|
use mas_policy::PolicyFactory;
|
||||||
use mas_router::{PostAuthAction, Route};
|
use mas_router::{PostAuthAction, Route};
|
||||||
use mas_storage::{
|
use mas_storage::{
|
||||||
oauth2::{authorization_grant::new_authorization_grant, OAuth2ClientRepository},
|
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository},
|
||||||
Repository,
|
Repository,
|
||||||
};
|
};
|
||||||
use mas_templates::Templates;
|
use mas_templates::Templates;
|
||||||
@ -275,18 +275,18 @@ pub(crate) async fn get(
|
|||||||
|
|
||||||
let requires_consent = prompt.contains(&Prompt::Consent);
|
let requires_consent = prompt.contains(&Prompt::Consent);
|
||||||
|
|
||||||
let grant = new_authorization_grant(
|
let grant = txn
|
||||||
&mut txn,
|
.oauth2_authorization_grant()
|
||||||
|
.add(
|
||||||
&mut rng,
|
&mut rng,
|
||||||
&clock,
|
&clock,
|
||||||
client,
|
&client,
|
||||||
redirect_uri.clone(),
|
redirect_uri.clone(),
|
||||||
params.auth.scope,
|
params.auth.scope,
|
||||||
code,
|
code,
|
||||||
params.auth.state.clone(),
|
params.auth.state.clone(),
|
||||||
params.auth.nonce,
|
params.auth.nonce,
|
||||||
params.auth.max_age,
|
params.auth.max_age,
|
||||||
None,
|
|
||||||
response_mode,
|
response_mode,
|
||||||
response_type.has_id_token(),
|
response_type.has_id_token(),
|
||||||
requires_consent,
|
requires_consent,
|
||||||
|
@ -29,11 +29,7 @@ use mas_keystore::Encrypter;
|
|||||||
use mas_policy::PolicyFactory;
|
use mas_policy::PolicyFactory;
|
||||||
use mas_router::{PostAuthAction, Route};
|
use mas_router::{PostAuthAction, Route};
|
||||||
use mas_storage::{
|
use mas_storage::{
|
||||||
oauth2::{
|
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository},
|
||||||
authorization_grant::{get_grant_by_id, give_consent_to_grant},
|
|
||||||
consent::insert_client_consent,
|
|
||||||
OAuth2ClientRepository,
|
|
||||||
},
|
|
||||||
Repository,
|
Repository,
|
||||||
};
|
};
|
||||||
use mas_templates::{ConsentContext, PolicyViolationContext, TemplateContext, Templates};
|
use mas_templates::{ConsentContext, PolicyViolationContext, TemplateContext, Templates};
|
||||||
@ -91,7 +87,9 @@ pub(crate) async fn get(
|
|||||||
|
|
||||||
let maybe_session = session_info.load_session(&mut conn).await?;
|
let maybe_session = session_info.load_session(&mut conn).await?;
|
||||||
|
|
||||||
let grant = get_grant_by_id(&mut conn, grant_id)
|
let grant = conn
|
||||||
|
.oauth2_authorization_grant()
|
||||||
|
.lookup(grant_id)
|
||||||
.await?
|
.await?
|
||||||
.ok_or(RouteError::GrantNotFound)?;
|
.ok_or(RouteError::GrantNotFound)?;
|
||||||
|
|
||||||
@ -146,7 +144,9 @@ pub(crate) async fn post(
|
|||||||
|
|
||||||
let maybe_session = session_info.load_session(&mut txn).await?;
|
let maybe_session = session_info.load_session(&mut txn).await?;
|
||||||
|
|
||||||
let grant = get_grant_by_id(&mut txn, grant_id)
|
let grant = txn
|
||||||
|
.oauth2_authorization_grant()
|
||||||
|
.lookup(grant_id)
|
||||||
.await?
|
.await?
|
||||||
.ok_or(RouteError::GrantNotFound)?;
|
.ok_or(RouteError::GrantNotFound)?;
|
||||||
let next = PostAuthAction::continue_grant(grant_id);
|
let next = PostAuthAction::continue_grant(grant_id);
|
||||||
@ -180,17 +180,17 @@ pub(crate) async fn post(
|
|||||||
.filter(|s| !s.starts_with("urn:matrix:org.matrix.msc2967.client:device:"))
|
.filter(|s| !s.starts_with("urn:matrix:org.matrix.msc2967.client:device:"))
|
||||||
.cloned()
|
.cloned()
|
||||||
.collect();
|
.collect();
|
||||||
insert_client_consent(
|
txn.oauth2_client()
|
||||||
&mut txn,
|
.give_consent_for_user(
|
||||||
&mut rng,
|
&mut rng,
|
||||||
&clock,
|
&clock,
|
||||||
&session.user,
|
|
||||||
&client,
|
&client,
|
||||||
|
&session.user,
|
||||||
&scope_without_device,
|
&scope_without_device,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let _grant = give_consent_to_grant(&mut txn, grant).await?;
|
txn.oauth2_authorization_grant().give_consent(grant).await?;
|
||||||
|
|
||||||
txn.commit().await?;
|
txn.commit().await?;
|
||||||
|
|
||||||
|
@ -23,10 +23,7 @@ use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint};
|
|||||||
use mas_keystore::Encrypter;
|
use mas_keystore::Encrypter;
|
||||||
use mas_storage::{
|
use mas_storage::{
|
||||||
compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository},
|
compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository},
|
||||||
oauth2::{
|
oauth2::{OAuth2AccessTokenRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository},
|
||||||
access_token::find_access_token, refresh_token::lookup_refresh_token,
|
|
||||||
OAuth2SessionRepository,
|
|
||||||
},
|
|
||||||
user::{BrowserSessionRepository, UserRepository},
|
user::{BrowserSessionRepository, UserRepository},
|
||||||
Clock, Repository,
|
Clock, Repository,
|
||||||
};
|
};
|
||||||
@ -169,7 +166,9 @@ pub(crate) async fn post(
|
|||||||
|
|
||||||
let reply = match token_type {
|
let reply = match token_type {
|
||||||
TokenType::AccessToken => {
|
TokenType::AccessToken => {
|
||||||
let token = find_access_token(&mut conn, token)
|
let token = conn
|
||||||
|
.oauth2_access_token()
|
||||||
|
.find_by_token(token)
|
||||||
.await?
|
.await?
|
||||||
.filter(|t| t.is_valid(clock.now()))
|
.filter(|t| t.is_valid(clock.now()))
|
||||||
.ok_or(RouteError::UnknownToken)?;
|
.ok_or(RouteError::UnknownToken)?;
|
||||||
@ -206,7 +205,9 @@ pub(crate) async fn post(
|
|||||||
}
|
}
|
||||||
|
|
||||||
TokenType::RefreshToken => {
|
TokenType::RefreshToken => {
|
||||||
let token = lookup_refresh_token(&mut conn, token)
|
let token = conn
|
||||||
|
.oauth2_refresh_token()
|
||||||
|
.find_by_token(token)
|
||||||
.await?
|
.await?
|
||||||
.filter(|t| t.is_valid())
|
.filter(|t| t.is_valid())
|
||||||
.ok_or(RouteError::UnknownToken)?;
|
.ok_or(RouteError::UnknownToken)?;
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
|
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
|
||||||
//
|
//
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
// you may not use this file except in compliance with the License.
|
// you may not use this file except in compliance with the License.
|
||||||
@ -33,10 +33,8 @@ use mas_keystore::{Encrypter, Keystore};
|
|||||||
use mas_router::UrlBuilder;
|
use mas_router::UrlBuilder;
|
||||||
use mas_storage::{
|
use mas_storage::{
|
||||||
oauth2::{
|
oauth2::{
|
||||||
access_token::{add_access_token, lookup_access_token, revoke_access_token},
|
OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository,
|
||||||
authorization_grant::{exchange_grant, lookup_grant_by_code},
|
OAuth2RefreshTokenRepository, OAuth2SessionRepository,
|
||||||
refresh_token::{add_refresh_token, consume_refresh_token, lookup_refresh_token},
|
|
||||||
OAuth2SessionRepository,
|
|
||||||
},
|
},
|
||||||
user::BrowserSessionRepository,
|
user::BrowserSessionRepository,
|
||||||
Repository,
|
Repository,
|
||||||
@ -217,9 +215,9 @@ async fn authorization_code_grant(
|
|||||||
) -> Result<AccessTokenResponse, RouteError> {
|
) -> Result<AccessTokenResponse, RouteError> {
|
||||||
let (clock, mut rng) = crate::clock_and_rng();
|
let (clock, mut rng) = crate::clock_and_rng();
|
||||||
|
|
||||||
// TODO: there is a bunch of unnecessary cloning here
|
let authz_grant = txn
|
||||||
// TODO: handle "not found" cases
|
.oauth2_authorization_grant()
|
||||||
let authz_grant = lookup_grant_by_code(&mut txn, &grant.code)
|
.find_by_code(&grant.code)
|
||||||
.await?
|
.await?
|
||||||
.ok_or(RouteError::GrantNotFound)?;
|
.ok_or(RouteError::GrantNotFound)?;
|
||||||
|
|
||||||
@ -301,17 +299,14 @@ async fn authorization_code_grant(
|
|||||||
let access_token_str = TokenType::AccessToken.generate(&mut rng);
|
let access_token_str = TokenType::AccessToken.generate(&mut rng);
|
||||||
let refresh_token_str = TokenType::RefreshToken.generate(&mut rng);
|
let refresh_token_str = TokenType::RefreshToken.generate(&mut rng);
|
||||||
|
|
||||||
let access_token =
|
let access_token = txn
|
||||||
add_access_token(&mut txn, &mut rng, &clock, &session, access_token_str, ttl).await?;
|
.oauth2_access_token()
|
||||||
|
.add(&mut rng, &clock, &session, access_token_str, ttl)
|
||||||
|
.await?;
|
||||||
|
|
||||||
let refresh_token = add_refresh_token(
|
let refresh_token = txn
|
||||||
&mut txn,
|
.oauth2_refresh_token()
|
||||||
&mut rng,
|
.add(&mut rng, &clock, &session, &access_token, refresh_token_str)
|
||||||
&clock,
|
|
||||||
&session,
|
|
||||||
&access_token,
|
|
||||||
refresh_token_str,
|
|
||||||
)
|
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let id_token = if session.scope.contains(&scope::OPENID) {
|
let id_token = if session.scope.contains(&scope::OPENID) {
|
||||||
@ -360,7 +355,9 @@ async fn authorization_code_grant(
|
|||||||
params = params.with_id_token(id_token);
|
params = params.with_id_token(id_token);
|
||||||
}
|
}
|
||||||
|
|
||||||
exchange_grant(&mut txn, &clock, authz_grant).await?;
|
txn.oauth2_authorization_grant()
|
||||||
|
.exchange(&clock, authz_grant)
|
||||||
|
.await?;
|
||||||
|
|
||||||
txn.commit().await?;
|
txn.commit().await?;
|
||||||
|
|
||||||
@ -374,7 +371,9 @@ async fn refresh_token_grant(
|
|||||||
) -> Result<AccessTokenResponse, RouteError> {
|
) -> Result<AccessTokenResponse, RouteError> {
|
||||||
let (clock, mut rng) = crate::clock_and_rng();
|
let (clock, mut rng) = crate::clock_and_rng();
|
||||||
|
|
||||||
let refresh_token = lookup_refresh_token(&mut txn, &grant.refresh_token)
|
let refresh_token = txn
|
||||||
|
.oauth2_refresh_token()
|
||||||
|
.find_by_token(&grant.refresh_token)
|
||||||
.await?
|
.await?
|
||||||
.ok_or(RouteError::InvalidGrant)?;
|
.ok_or(RouteError::InvalidGrant)?;
|
||||||
|
|
||||||
@ -397,18 +396,14 @@ async fn refresh_token_grant(
|
|||||||
let access_token_str = TokenType::AccessToken.generate(&mut rng);
|
let access_token_str = TokenType::AccessToken.generate(&mut rng);
|
||||||
let refresh_token_str = TokenType::RefreshToken.generate(&mut rng);
|
let refresh_token_str = TokenType::RefreshToken.generate(&mut rng);
|
||||||
|
|
||||||
let new_access_token = add_access_token(
|
let new_access_token = txn
|
||||||
&mut txn,
|
.oauth2_access_token()
|
||||||
&mut rng,
|
.add(&mut rng, &clock, &session, access_token_str.clone(), ttl)
|
||||||
&clock,
|
|
||||||
&session,
|
|
||||||
access_token_str.clone(),
|
|
||||||
ttl,
|
|
||||||
)
|
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let new_refresh_token = add_refresh_token(
|
let new_refresh_token = txn
|
||||||
&mut txn,
|
.oauth2_refresh_token()
|
||||||
|
.add(
|
||||||
&mut rng,
|
&mut rng,
|
||||||
&clock,
|
&clock,
|
||||||
&session,
|
&session,
|
||||||
@ -417,11 +412,16 @@ async fn refresh_token_grant(
|
|||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let refresh_token = consume_refresh_token(&mut txn, &clock, refresh_token).await?;
|
let refresh_token = txn
|
||||||
|
.oauth2_refresh_token()
|
||||||
|
.consume(&clock, refresh_token)
|
||||||
|
.await?;
|
||||||
|
|
||||||
if let Some(access_token_id) = refresh_token.access_token_id {
|
if let Some(access_token_id) = refresh_token.access_token_id {
|
||||||
if let Some(access_token) = lookup_access_token(&mut txn, access_token_id).await? {
|
if let Some(access_token) = txn.oauth2_access_token().lookup(access_token_id).await? {
|
||||||
revoke_access_token(&mut txn, &clock, access_token).await?;
|
txn.oauth2_access_token()
|
||||||
|
.revoke(&clock, access_token)
|
||||||
|
.await?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -15,7 +15,7 @@
|
|||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use mas_router::{PostAuthAction, Route};
|
use mas_router::{PostAuthAction, Route};
|
||||||
use mas_storage::{
|
use mas_storage::{
|
||||||
compat::CompatSsoLoginRepository, oauth2::authorization_grant::get_grant_by_id,
|
compat::CompatSsoLoginRepository, oauth2::OAuth2AuthorizationGrantRepository,
|
||||||
upstream_oauth2::UpstreamOAuthProviderRepository, Repository, UpstreamOAuthLinkRepository,
|
upstream_oauth2::UpstreamOAuthProviderRepository, Repository, UpstreamOAuthLinkRepository,
|
||||||
};
|
};
|
||||||
use mas_templates::{PostAuthContext, PostAuthContextInner};
|
use mas_templates::{PostAuthContext, PostAuthContextInner};
|
||||||
@ -46,7 +46,9 @@ impl OptionalPostAuthAction {
|
|||||||
let Some(action) = self.post_auth_action.clone() else { return Ok(None) };
|
let Some(action) = self.post_auth_action.clone() else { return Ok(None) };
|
||||||
let ctx = match action {
|
let ctx = match action {
|
||||||
PostAuthAction::ContinueAuthorizationGrant { id } => {
|
PostAuthAction::ContinueAuthorizationGrant { id } => {
|
||||||
let grant = get_grant_by_id(conn, id)
|
let grant = conn
|
||||||
|
.oauth2_authorization_grant()
|
||||||
|
.lookup(id)
|
||||||
.await?
|
.await?
|
||||||
.context("Failed to load authorization grant")?;
|
.context("Failed to load authorization grant")?;
|
||||||
let grant = Box::new(grant);
|
let grant = Box::new(grant);
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -1,4 +1,4 @@
|
|||||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
|
||||||
//
|
//
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
// you may not use this file except in compliance with the License.
|
// you may not use this file except in compliance with the License.
|
||||||
@ -12,67 +12,61 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
use chrono::{DateTime, Duration, Utc};
|
use chrono::{DateTime, Duration, Utc};
|
||||||
use mas_data_model::{AccessToken, AccessTokenState, Session};
|
use mas_data_model::{AccessToken, AccessTokenState, Session};
|
||||||
use rand::Rng;
|
use rand::RngCore;
|
||||||
use sqlx::{PgConnection, PgExecutor};
|
use sqlx::PgConnection;
|
||||||
use ulid::Ulid;
|
use ulid::Ulid;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::{Clock, DatabaseError, LookupResultExt};
|
use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt};
|
||||||
|
|
||||||
#[tracing::instrument(
|
#[async_trait]
|
||||||
skip_all,
|
pub trait OAuth2AccessTokenRepository: Send + Sync {
|
||||||
fields(
|
type Error;
|
||||||
%session.id,
|
|
||||||
user_session.id = %session.user_session_id,
|
/// Lookup an access token by its ID
|
||||||
client.id = %session.client_id,
|
async fn lookup(&mut self, id: Ulid) -> Result<Option<AccessToken>, Self::Error>;
|
||||||
access_token.id,
|
|
||||||
),
|
/// Find an access token by its token
|
||||||
err,
|
async fn find_by_token(
|
||||||
)]
|
&mut self,
|
||||||
pub async fn add_access_token(
|
access_token: &str,
|
||||||
executor: impl PgExecutor<'_>,
|
) -> Result<Option<AccessToken>, Self::Error>;
|
||||||
mut rng: impl Rng + Send,
|
|
||||||
|
/// Add a new access token to the database
|
||||||
|
async fn add(
|
||||||
|
&mut self,
|
||||||
|
rng: &mut (dyn RngCore + Send),
|
||||||
clock: &Clock,
|
clock: &Clock,
|
||||||
session: &Session,
|
session: &Session,
|
||||||
access_token: String,
|
access_token: String,
|
||||||
expires_after: Duration,
|
expires_after: Duration,
|
||||||
) -> Result<AccessToken, sqlx::Error> {
|
) -> Result<AccessToken, Self::Error>;
|
||||||
let created_at = clock.now();
|
|
||||||
let expires_at = created_at + expires_after;
|
|
||||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
|
||||||
|
|
||||||
tracing::Span::current().record("access_token.id", tracing::field::display(id));
|
/// Revoke an access token
|
||||||
|
async fn revoke(
|
||||||
|
&mut self,
|
||||||
|
clock: &Clock,
|
||||||
|
access_token: AccessToken,
|
||||||
|
) -> Result<AccessToken, Self::Error>;
|
||||||
|
|
||||||
sqlx::query!(
|
/// Cleanup expired access tokens
|
||||||
r#"
|
async fn cleanup_expired(&mut self, clock: &Clock) -> Result<usize, Self::Error>;
|
||||||
INSERT INTO oauth2_access_tokens
|
|
||||||
(oauth2_access_token_id, oauth2_session_id, access_token, created_at, expires_at)
|
|
||||||
VALUES
|
|
||||||
($1, $2, $3, $4, $5)
|
|
||||||
"#,
|
|
||||||
Uuid::from(id),
|
|
||||||
Uuid::from(session.id),
|
|
||||||
&access_token,
|
|
||||||
created_at,
|
|
||||||
expires_at,
|
|
||||||
)
|
|
||||||
.execute(executor)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(AccessToken {
|
|
||||||
id,
|
|
||||||
state: AccessTokenState::default(),
|
|
||||||
access_token,
|
|
||||||
session_id: session.id,
|
|
||||||
created_at,
|
|
||||||
expires_at,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
pub struct PgOAuth2AccessTokenRepository<'c> {
|
||||||
pub struct OAuth2AccessTokenLookup {
|
conn: &'c mut PgConnection,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'c> PgOAuth2AccessTokenRepository<'c> {
|
||||||
|
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||||
|
Self { conn }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct OAuth2AccessTokenLookup {
|
||||||
oauth2_access_token_id: Uuid,
|
oauth2_access_token_id: Uuid,
|
||||||
oauth2_session_id: Uuid,
|
oauth2_session_id: Uuid,
|
||||||
access_token: String,
|
access_token: String,
|
||||||
@ -99,11 +93,48 @@ impl From<OAuth2AccessTokenLookup> for AccessToken {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(skip_all, err)]
|
#[async_trait]
|
||||||
pub async fn find_access_token(
|
impl<'c> OAuth2AccessTokenRepository for PgOAuth2AccessTokenRepository<'c> {
|
||||||
conn: &mut PgConnection,
|
type Error = DatabaseError;
|
||||||
token: &str,
|
|
||||||
) -> Result<Option<AccessToken>, DatabaseError> {
|
async fn lookup(&mut self, id: Ulid) -> Result<Option<AccessToken>, Self::Error> {
|
||||||
|
let res = sqlx::query_as!(
|
||||||
|
OAuth2AccessTokenLookup,
|
||||||
|
r#"
|
||||||
|
SELECT oauth2_access_token_id
|
||||||
|
, access_token
|
||||||
|
, created_at
|
||||||
|
, expires_at
|
||||||
|
, revoked_at
|
||||||
|
, oauth2_session_id
|
||||||
|
|
||||||
|
FROM oauth2_access_tokens
|
||||||
|
|
||||||
|
WHERE oauth2_access_token_id = $1
|
||||||
|
"#,
|
||||||
|
Uuid::from(id),
|
||||||
|
)
|
||||||
|
.fetch_one(&mut *self.conn)
|
||||||
|
.await
|
||||||
|
.to_option()?;
|
||||||
|
|
||||||
|
let Some(res) = res else { return Ok(None) };
|
||||||
|
|
||||||
|
Ok(Some(res.into()))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(
|
||||||
|
name = "db.oauth2_access_token.find_by_token",
|
||||||
|
skip_all,
|
||||||
|
fields(
|
||||||
|
db.statement,
|
||||||
|
),
|
||||||
|
err,
|
||||||
|
)]
|
||||||
|
async fn find_by_token(
|
||||||
|
&mut self,
|
||||||
|
access_token: &str,
|
||||||
|
) -> Result<Option<AccessToken>, Self::Error> {
|
||||||
let res = sqlx::query_as!(
|
let res = sqlx::query_as!(
|
||||||
OAuth2AccessTokenLookup,
|
OAuth2AccessTokenLookup,
|
||||||
r#"
|
r#"
|
||||||
@ -118,64 +149,75 @@ pub async fn find_access_token(
|
|||||||
|
|
||||||
WHERE access_token = $1
|
WHERE access_token = $1
|
||||||
"#,
|
"#,
|
||||||
token,
|
access_token,
|
||||||
)
|
)
|
||||||
.fetch_one(&mut *conn)
|
.fetch_one(&mut *self.conn)
|
||||||
.await
|
.await
|
||||||
.to_option()?;
|
.to_option()?;
|
||||||
|
|
||||||
let Some(res) = res else { return Ok(None) };
|
let Some(res) = res else { return Ok(None) };
|
||||||
|
|
||||||
Ok(Some(res.into()))
|
Ok(Some(res.into()))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(
|
#[tracing::instrument(
|
||||||
skip_all,
|
name = "db.oauth2_access_token.add",
|
||||||
fields(access_token.id = %access_token_id),
|
|
||||||
err,
|
|
||||||
)]
|
|
||||||
pub async fn lookup_access_token(
|
|
||||||
conn: &mut PgConnection,
|
|
||||||
access_token_id: Ulid,
|
|
||||||
) -> Result<Option<AccessToken>, DatabaseError> {
|
|
||||||
let res = sqlx::query_as!(
|
|
||||||
OAuth2AccessTokenLookup,
|
|
||||||
r#"
|
|
||||||
SELECT oauth2_access_token_id
|
|
||||||
, access_token
|
|
||||||
, created_at
|
|
||||||
, expires_at
|
|
||||||
, revoked_at
|
|
||||||
, oauth2_session_id
|
|
||||||
|
|
||||||
FROM oauth2_access_tokens
|
|
||||||
|
|
||||||
WHERE oauth2_access_token_id = $1
|
|
||||||
"#,
|
|
||||||
Uuid::from(access_token_id),
|
|
||||||
)
|
|
||||||
.fetch_one(&mut *conn)
|
|
||||||
.await
|
|
||||||
.to_option()?;
|
|
||||||
|
|
||||||
let Some(res) = res else { return Ok(None) };
|
|
||||||
|
|
||||||
Ok(Some(res.into()))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tracing::instrument(
|
|
||||||
skip_all,
|
skip_all,
|
||||||
fields(
|
fields(
|
||||||
%access_token.id,
|
db.statement,
|
||||||
session.id = %access_token.session_id,
|
%session.id,
|
||||||
|
user_session.id = %session.user_session_id,
|
||||||
|
client.id = %session.client_id,
|
||||||
|
access_token.id,
|
||||||
),
|
),
|
||||||
err,
|
err,
|
||||||
)]
|
)]
|
||||||
pub async fn revoke_access_token(
|
async fn add(
|
||||||
executor: impl PgExecutor<'_>,
|
&mut self,
|
||||||
|
rng: &mut (dyn RngCore + Send),
|
||||||
|
clock: &Clock,
|
||||||
|
session: &Session,
|
||||||
|
access_token: String,
|
||||||
|
expires_after: Duration,
|
||||||
|
) -> Result<AccessToken, Self::Error> {
|
||||||
|
let created_at = clock.now();
|
||||||
|
let expires_at = created_at + expires_after;
|
||||||
|
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||||
|
|
||||||
|
tracing::Span::current().record("access_token.id", tracing::field::display(id));
|
||||||
|
|
||||||
|
sqlx::query!(
|
||||||
|
r#"
|
||||||
|
INSERT INTO oauth2_access_tokens
|
||||||
|
(oauth2_access_token_id, oauth2_session_id, access_token, created_at, expires_at)
|
||||||
|
VALUES
|
||||||
|
($1, $2, $3, $4, $5)
|
||||||
|
"#,
|
||||||
|
Uuid::from(id),
|
||||||
|
Uuid::from(session.id),
|
||||||
|
&access_token,
|
||||||
|
created_at,
|
||||||
|
expires_at,
|
||||||
|
)
|
||||||
|
.traced()
|
||||||
|
.execute(&mut *self.conn)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(AccessToken {
|
||||||
|
id,
|
||||||
|
state: AccessTokenState::default(),
|
||||||
|
access_token,
|
||||||
|
session_id: session.id,
|
||||||
|
created_at,
|
||||||
|
expires_at,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn revoke(
|
||||||
|
&mut self,
|
||||||
clock: &Clock,
|
clock: &Clock,
|
||||||
access_token: AccessToken,
|
access_token: AccessToken,
|
||||||
) -> Result<AccessToken, DatabaseError> {
|
) -> Result<AccessToken, Self::Error> {
|
||||||
let revoked_at = clock.now();
|
let revoked_at = clock.now();
|
||||||
let res = sqlx::query!(
|
let res = sqlx::query!(
|
||||||
r#"
|
r#"
|
||||||
@ -186,7 +228,7 @@ pub async fn revoke_access_token(
|
|||||||
Uuid::from(access_token.id),
|
Uuid::from(access_token.id),
|
||||||
revoked_at,
|
revoked_at,
|
||||||
)
|
)
|
||||||
.execute(executor)
|
.execute(&mut *self.conn)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||||
@ -194,12 +236,9 @@ pub async fn revoke_access_token(
|
|||||||
access_token
|
access_token
|
||||||
.revoke(revoked_at)
|
.revoke(revoked_at)
|
||||||
.map_err(DatabaseError::to_invalid_operation)
|
.map_err(DatabaseError::to_invalid_operation)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn cleanup_expired(
|
async fn cleanup_expired(&mut self, clock: &Clock) -> Result<usize, Self::Error> {
|
||||||
executor: impl PgExecutor<'_>,
|
|
||||||
clock: &Clock,
|
|
||||||
) -> Result<u64, sqlx::Error> {
|
|
||||||
// Cleanup token which expired more than 15 minutes ago
|
// Cleanup token which expired more than 15 minutes ago
|
||||||
let threshold = clock.now() - Duration::minutes(15);
|
let threshold = clock.now() - Duration::minutes(15);
|
||||||
let res = sqlx::query!(
|
let res = sqlx::query!(
|
||||||
@ -209,8 +248,9 @@ pub async fn cleanup_expired(
|
|||||||
"#,
|
"#,
|
||||||
threshold,
|
threshold,
|
||||||
)
|
)
|
||||||
.execute(executor)
|
.execute(&mut *self.conn)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
Ok(res.rows_affected())
|
Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -14,45 +14,261 @@
|
|||||||
|
|
||||||
use std::num::NonZeroU32;
|
use std::num::NonZeroU32;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
use mas_data_model::{
|
use mas_data_model::{
|
||||||
AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Pkce, Session,
|
AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Pkce, Session,
|
||||||
};
|
};
|
||||||
use mas_iana::oauth::PkceCodeChallengeMethod;
|
use mas_iana::oauth::PkceCodeChallengeMethod;
|
||||||
use oauth2_types::{requests::ResponseMode, scope::Scope};
|
use oauth2_types::{requests::ResponseMode, scope::Scope};
|
||||||
use rand::Rng;
|
use rand::RngCore;
|
||||||
use sqlx::{PgConnection, PgExecutor};
|
use sqlx::PgConnection;
|
||||||
use ulid::Ulid;
|
use ulid::Ulid;
|
||||||
use url::Url;
|
use url::Url;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::{Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt};
|
use crate::{
|
||||||
|
tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
|
||||||
|
};
|
||||||
|
|
||||||
#[tracing::instrument(
|
#[async_trait]
|
||||||
skip_all,
|
pub trait OAuth2AuthorizationGrantRepository {
|
||||||
fields(
|
type Error;
|
||||||
%client.id,
|
|
||||||
grant.id,
|
#[allow(clippy::too_many_arguments)]
|
||||||
),
|
async fn add(
|
||||||
err,
|
&mut self,
|
||||||
)]
|
rng: &mut (dyn RngCore + Send),
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
pub async fn new_authorization_grant(
|
|
||||||
executor: impl PgExecutor<'_>,
|
|
||||||
mut rng: impl Rng + Send,
|
|
||||||
clock: &Clock,
|
clock: &Clock,
|
||||||
client: Client,
|
client: &Client,
|
||||||
redirect_uri: Url,
|
redirect_uri: Url,
|
||||||
scope: Scope,
|
scope: Scope,
|
||||||
code: Option<AuthorizationCode>,
|
code: Option<AuthorizationCode>,
|
||||||
state: Option<String>,
|
state: Option<String>,
|
||||||
nonce: Option<String>,
|
nonce: Option<String>,
|
||||||
max_age: Option<NonZeroU32>,
|
max_age: Option<NonZeroU32>,
|
||||||
_acr_values: Option<String>,
|
|
||||||
response_mode: ResponseMode,
|
response_mode: ResponseMode,
|
||||||
response_type_id_token: bool,
|
response_type_id_token: bool,
|
||||||
requires_consent: bool,
|
requires_consent: bool,
|
||||||
) -> Result<AuthorizationGrant, sqlx::Error> {
|
) -> Result<AuthorizationGrant, Self::Error>;
|
||||||
|
|
||||||
|
async fn lookup(&mut self, id: Ulid) -> Result<Option<AuthorizationGrant>, Self::Error>;
|
||||||
|
|
||||||
|
async fn find_by_code(&mut self, code: &str)
|
||||||
|
-> Result<Option<AuthorizationGrant>, Self::Error>;
|
||||||
|
|
||||||
|
async fn fulfill(
|
||||||
|
&mut self,
|
||||||
|
clock: &Clock,
|
||||||
|
session: &Session,
|
||||||
|
authorization_grant: AuthorizationGrant,
|
||||||
|
) -> Result<AuthorizationGrant, Self::Error>;
|
||||||
|
|
||||||
|
async fn exchange(
|
||||||
|
&mut self,
|
||||||
|
clock: &Clock,
|
||||||
|
authorization_grant: AuthorizationGrant,
|
||||||
|
) -> Result<AuthorizationGrant, Self::Error>;
|
||||||
|
|
||||||
|
async fn give_consent(
|
||||||
|
&mut self,
|
||||||
|
authorization_grant: AuthorizationGrant,
|
||||||
|
) -> Result<AuthorizationGrant, Self::Error>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct PgOAuth2AuthorizationGrantRepository<'c> {
|
||||||
|
conn: &'c mut PgConnection,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'c> PgOAuth2AuthorizationGrantRepository<'c> {
|
||||||
|
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||||
|
Self { conn }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::struct_excessive_bools)]
|
||||||
|
struct GrantLookup {
|
||||||
|
oauth2_authorization_grant_id: Uuid,
|
||||||
|
created_at: DateTime<Utc>,
|
||||||
|
cancelled_at: Option<DateTime<Utc>>,
|
||||||
|
fulfilled_at: Option<DateTime<Utc>>,
|
||||||
|
exchanged_at: Option<DateTime<Utc>>,
|
||||||
|
scope: String,
|
||||||
|
state: Option<String>,
|
||||||
|
nonce: Option<String>,
|
||||||
|
redirect_uri: String,
|
||||||
|
response_mode: String,
|
||||||
|
max_age: Option<i32>,
|
||||||
|
response_type_code: bool,
|
||||||
|
response_type_id_token: bool,
|
||||||
|
authorization_code: Option<String>,
|
||||||
|
code_challenge: Option<String>,
|
||||||
|
code_challenge_method: Option<String>,
|
||||||
|
requires_consent: bool,
|
||||||
|
oauth2_client_id: Uuid,
|
||||||
|
oauth2_session_id: Option<Uuid>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<GrantLookup> for AuthorizationGrant {
|
||||||
|
type Error = DatabaseInconsistencyError;
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_lines)]
|
||||||
|
fn try_from(value: GrantLookup) -> Result<Self, Self::Error> {
|
||||||
|
let id = value.oauth2_authorization_grant_id.into();
|
||||||
|
let scope: Scope = value.scope.parse().map_err(|e| {
|
||||||
|
DatabaseInconsistencyError::on("oauth2_authorization_grants")
|
||||||
|
.column("scope")
|
||||||
|
.row(id)
|
||||||
|
.source(e)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let stage = match (
|
||||||
|
value.fulfilled_at,
|
||||||
|
value.exchanged_at,
|
||||||
|
value.cancelled_at,
|
||||||
|
value.oauth2_session_id,
|
||||||
|
) {
|
||||||
|
(None, None, None, None) => AuthorizationGrantStage::Pending,
|
||||||
|
(Some(fulfilled_at), None, None, Some(session_id)) => {
|
||||||
|
AuthorizationGrantStage::Fulfilled {
|
||||||
|
session_id: session_id.into(),
|
||||||
|
fulfilled_at,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(Some(fulfilled_at), Some(exchanged_at), None, Some(session_id)) => {
|
||||||
|
AuthorizationGrantStage::Exchanged {
|
||||||
|
session_id: session_id.into(),
|
||||||
|
fulfilled_at,
|
||||||
|
exchanged_at,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(None, None, Some(cancelled_at), None) => {
|
||||||
|
AuthorizationGrantStage::Cancelled { cancelled_at }
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
return Err(
|
||||||
|
DatabaseInconsistencyError::on("oauth2_authorization_grants")
|
||||||
|
.column("stage")
|
||||||
|
.row(id),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let pkce = match (value.code_challenge, value.code_challenge_method) {
|
||||||
|
(Some(challenge), Some(challenge_method)) if challenge_method == "plain" => {
|
||||||
|
Some(Pkce {
|
||||||
|
challenge_method: PkceCodeChallengeMethod::Plain,
|
||||||
|
challenge,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
(Some(challenge), Some(challenge_method)) if challenge_method == "S256" => Some(Pkce {
|
||||||
|
challenge_method: PkceCodeChallengeMethod::S256,
|
||||||
|
challenge,
|
||||||
|
}),
|
||||||
|
(None, None) => None,
|
||||||
|
_ => {
|
||||||
|
return Err(
|
||||||
|
DatabaseInconsistencyError::on("oauth2_authorization_grants")
|
||||||
|
.column("code_challenge_method")
|
||||||
|
.row(id),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let code: Option<AuthorizationCode> =
|
||||||
|
match (value.response_type_code, value.authorization_code, pkce) {
|
||||||
|
(false, None, None) => None,
|
||||||
|
(true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }),
|
||||||
|
_ => {
|
||||||
|
return Err(
|
||||||
|
DatabaseInconsistencyError::on("oauth2_authorization_grants")
|
||||||
|
.column("authorization_code")
|
||||||
|
.row(id),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let redirect_uri = value.redirect_uri.parse().map_err(|e| {
|
||||||
|
DatabaseInconsistencyError::on("oauth2_authorization_grants")
|
||||||
|
.column("redirect_uri")
|
||||||
|
.row(id)
|
||||||
|
.source(e)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let response_mode = value.response_mode.parse().map_err(|e| {
|
||||||
|
DatabaseInconsistencyError::on("oauth2_authorization_grants")
|
||||||
|
.column("response_mode")
|
||||||
|
.row(id)
|
||||||
|
.source(e)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let max_age = value
|
||||||
|
.max_age
|
||||||
|
.map(u32::try_from)
|
||||||
|
.transpose()
|
||||||
|
.map_err(|e| {
|
||||||
|
DatabaseInconsistencyError::on("oauth2_authorization_grants")
|
||||||
|
.column("max_age")
|
||||||
|
.row(id)
|
||||||
|
.source(e)
|
||||||
|
})?
|
||||||
|
.map(NonZeroU32::try_from)
|
||||||
|
.transpose()
|
||||||
|
.map_err(|e| {
|
||||||
|
DatabaseInconsistencyError::on("oauth2_authorization_grants")
|
||||||
|
.column("max_age")
|
||||||
|
.row(id)
|
||||||
|
.source(e)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok(AuthorizationGrant {
|
||||||
|
id,
|
||||||
|
stage,
|
||||||
|
client_id: value.oauth2_client_id.into(),
|
||||||
|
code,
|
||||||
|
scope,
|
||||||
|
state: value.state,
|
||||||
|
nonce: value.nonce,
|
||||||
|
max_age,
|
||||||
|
response_mode,
|
||||||
|
redirect_uri,
|
||||||
|
created_at: value.created_at,
|
||||||
|
response_type_id_token: value.response_type_id_token,
|
||||||
|
requires_consent: value.requires_consent,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantRepository<'c> {
|
||||||
|
type Error = DatabaseError;
|
||||||
|
|
||||||
|
#[tracing::instrument(
|
||||||
|
name = "db.oauth2_authorization_grant.add",
|
||||||
|
skip_all,
|
||||||
|
fields(
|
||||||
|
db.statement,
|
||||||
|
grant.id,
|
||||||
|
grant.scope = %scope,
|
||||||
|
%client.id,
|
||||||
|
),
|
||||||
|
err,
|
||||||
|
)]
|
||||||
|
async fn add(
|
||||||
|
&mut self,
|
||||||
|
rng: &mut (dyn RngCore + Send),
|
||||||
|
clock: &Clock,
|
||||||
|
client: &Client,
|
||||||
|
redirect_uri: Url,
|
||||||
|
scope: Scope,
|
||||||
|
code: Option<AuthorizationCode>,
|
||||||
|
state: Option<String>,
|
||||||
|
nonce: Option<String>,
|
||||||
|
max_age: Option<NonZeroU32>,
|
||||||
|
response_mode: ResponseMode,
|
||||||
|
response_type_id_token: bool,
|
||||||
|
requires_consent: bool,
|
||||||
|
) -> Result<AuthorizationGrant, Self::Error> {
|
||||||
let code_challenge = code
|
let code_challenge = code
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.and_then(|c| c.pkce.as_ref())
|
.and_then(|c| c.pkce.as_ref())
|
||||||
@ -66,7 +282,7 @@ pub async fn new_authorization_grant(
|
|||||||
let code_str = code.as_ref().map(|c| &c.code);
|
let code_str = code.as_ref().map(|c| &c.code);
|
||||||
|
|
||||||
let created_at = clock.now();
|
let created_at = clock.now();
|
||||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||||
tracing::Span::current().record("grant.id", tracing::field::display(id));
|
tracing::Span::current().record("grant.id", tracing::field::display(id));
|
||||||
|
|
||||||
sqlx::query!(
|
sqlx::query!(
|
||||||
@ -107,7 +323,7 @@ pub async fn new_authorization_grant(
|
|||||||
requires_consent,
|
requires_consent,
|
||||||
created_at,
|
created_at,
|
||||||
)
|
)
|
||||||
.execute(executor)
|
.execute(&mut *self.conn)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
Ok(AuthorizationGrant {
|
Ok(AuthorizationGrant {
|
||||||
@ -125,207 +341,40 @@ pub async fn new_authorization_grant(
|
|||||||
response_type_id_token,
|
response_type_id_token,
|
||||||
requires_consent,
|
requires_consent,
|
||||||
})
|
})
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(clippy::struct_excessive_bools)]
|
|
||||||
struct GrantLookup {
|
|
||||||
oauth2_authorization_grant_id: Uuid,
|
|
||||||
oauth2_authorization_grant_created_at: DateTime<Utc>,
|
|
||||||
oauth2_authorization_grant_cancelled_at: Option<DateTime<Utc>>,
|
|
||||||
oauth2_authorization_grant_fulfilled_at: Option<DateTime<Utc>>,
|
|
||||||
oauth2_authorization_grant_exchanged_at: Option<DateTime<Utc>>,
|
|
||||||
oauth2_authorization_grant_scope: String,
|
|
||||||
oauth2_authorization_grant_state: Option<String>,
|
|
||||||
oauth2_authorization_grant_nonce: Option<String>,
|
|
||||||
oauth2_authorization_grant_redirect_uri: String,
|
|
||||||
oauth2_authorization_grant_response_mode: String,
|
|
||||||
oauth2_authorization_grant_max_age: Option<i32>,
|
|
||||||
oauth2_authorization_grant_response_type_code: bool,
|
|
||||||
oauth2_authorization_grant_response_type_id_token: bool,
|
|
||||||
oauth2_authorization_grant_code: Option<String>,
|
|
||||||
oauth2_authorization_grant_code_challenge: Option<String>,
|
|
||||||
oauth2_authorization_grant_code_challenge_method: Option<String>,
|
|
||||||
oauth2_authorization_grant_requires_consent: bool,
|
|
||||||
oauth2_client_id: Uuid,
|
|
||||||
oauth2_session_id: Option<Uuid>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TryFrom<GrantLookup> for AuthorizationGrant {
|
|
||||||
type Error = DatabaseInconsistencyError;
|
|
||||||
|
|
||||||
#[allow(clippy::too_many_lines)]
|
|
||||||
fn try_from(value: GrantLookup) -> Result<Self, Self::Error> {
|
|
||||||
let id = value.oauth2_authorization_grant_id.into();
|
|
||||||
let scope: Scope = value
|
|
||||||
.oauth2_authorization_grant_scope
|
|
||||||
.parse()
|
|
||||||
.map_err(|e| {
|
|
||||||
DatabaseInconsistencyError::on("oauth2_authorization_grants")
|
|
||||||
.column("scope")
|
|
||||||
.row(id)
|
|
||||||
.source(e)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let stage = match (
|
|
||||||
value.oauth2_authorization_grant_fulfilled_at,
|
|
||||||
value.oauth2_authorization_grant_exchanged_at,
|
|
||||||
value.oauth2_authorization_grant_cancelled_at,
|
|
||||||
value.oauth2_session_id,
|
|
||||||
) {
|
|
||||||
(None, None, None, None) => AuthorizationGrantStage::Pending,
|
|
||||||
(Some(fulfilled_at), None, None, Some(session_id)) => {
|
|
||||||
AuthorizationGrantStage::Fulfilled {
|
|
||||||
session_id: session_id.into(),
|
|
||||||
fulfilled_at,
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
(Some(fulfilled_at), Some(exchanged_at), None, Some(session_id)) => {
|
|
||||||
AuthorizationGrantStage::Exchanged {
|
|
||||||
session_id: session_id.into(),
|
|
||||||
fulfilled_at,
|
|
||||||
exchanged_at,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
(None, None, Some(cancelled_at), None) => {
|
|
||||||
AuthorizationGrantStage::Cancelled { cancelled_at }
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
return Err(
|
|
||||||
DatabaseInconsistencyError::on("oauth2_authorization_grants")
|
|
||||||
.column("stage")
|
|
||||||
.row(id),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let pkce = match (
|
#[tracing::instrument(
|
||||||
value.oauth2_authorization_grant_code_challenge,
|
name = "db.oauth2_authorization_grant.lookup",
|
||||||
value.oauth2_authorization_grant_code_challenge_method,
|
|
||||||
) {
|
|
||||||
(Some(challenge), Some(challenge_method)) if challenge_method == "plain" => {
|
|
||||||
Some(Pkce {
|
|
||||||
challenge_method: PkceCodeChallengeMethod::Plain,
|
|
||||||
challenge,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
(Some(challenge), Some(challenge_method)) if challenge_method == "S256" => Some(Pkce {
|
|
||||||
challenge_method: PkceCodeChallengeMethod::S256,
|
|
||||||
challenge,
|
|
||||||
}),
|
|
||||||
(None, None) => None,
|
|
||||||
_ => {
|
|
||||||
return Err(
|
|
||||||
DatabaseInconsistencyError::on("oauth2_authorization_grants")
|
|
||||||
.column("code_challenge_method")
|
|
||||||
.row(id),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let code: Option<AuthorizationCode> = match (
|
|
||||||
value.oauth2_authorization_grant_response_type_code,
|
|
||||||
value.oauth2_authorization_grant_code,
|
|
||||||
pkce,
|
|
||||||
) {
|
|
||||||
(false, None, None) => None,
|
|
||||||
(true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }),
|
|
||||||
_ => {
|
|
||||||
return Err(
|
|
||||||
DatabaseInconsistencyError::on("oauth2_authorization_grants")
|
|
||||||
.column("authorization_code")
|
|
||||||
.row(id),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let redirect_uri = value
|
|
||||||
.oauth2_authorization_grant_redirect_uri
|
|
||||||
.parse()
|
|
||||||
.map_err(|e| {
|
|
||||||
DatabaseInconsistencyError::on("oauth2_authorization_grants")
|
|
||||||
.column("redirect_uri")
|
|
||||||
.row(id)
|
|
||||||
.source(e)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let response_mode = value
|
|
||||||
.oauth2_authorization_grant_response_mode
|
|
||||||
.parse()
|
|
||||||
.map_err(|e| {
|
|
||||||
DatabaseInconsistencyError::on("oauth2_authorization_grants")
|
|
||||||
.column("response_mode")
|
|
||||||
.row(id)
|
|
||||||
.source(e)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let max_age = value
|
|
||||||
.oauth2_authorization_grant_max_age
|
|
||||||
.map(u32::try_from)
|
|
||||||
.transpose()
|
|
||||||
.map_err(|e| {
|
|
||||||
DatabaseInconsistencyError::on("oauth2_authorization_grants")
|
|
||||||
.column("max_age")
|
|
||||||
.row(id)
|
|
||||||
.source(e)
|
|
||||||
})?
|
|
||||||
.map(NonZeroU32::try_from)
|
|
||||||
.transpose()
|
|
||||||
.map_err(|e| {
|
|
||||||
DatabaseInconsistencyError::on("oauth2_authorization_grants")
|
|
||||||
.column("max_age")
|
|
||||||
.row(id)
|
|
||||||
.source(e)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
Ok(AuthorizationGrant {
|
|
||||||
id,
|
|
||||||
stage,
|
|
||||||
client_id: value.oauth2_client_id.into(),
|
|
||||||
code,
|
|
||||||
scope,
|
|
||||||
state: value.oauth2_authorization_grant_state,
|
|
||||||
nonce: value.oauth2_authorization_grant_nonce,
|
|
||||||
max_age,
|
|
||||||
response_mode,
|
|
||||||
redirect_uri,
|
|
||||||
created_at: value.oauth2_authorization_grant_created_at,
|
|
||||||
response_type_id_token: value.oauth2_authorization_grant_response_type_id_token,
|
|
||||||
requires_consent: value.oauth2_authorization_grant_requires_consent,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tracing::instrument(
|
|
||||||
skip_all,
|
skip_all,
|
||||||
fields(grant.id = %id),
|
fields(
|
||||||
|
db.statement,
|
||||||
|
grant.id = %id,
|
||||||
|
),
|
||||||
err,
|
err,
|
||||||
)]
|
)]
|
||||||
pub async fn get_grant_by_id(
|
async fn lookup(&mut self, id: Ulid) -> Result<Option<AuthorizationGrant>, Self::Error> {
|
||||||
conn: &mut PgConnection,
|
|
||||||
id: Ulid,
|
|
||||||
) -> Result<Option<AuthorizationGrant>, DatabaseError> {
|
|
||||||
let res = sqlx::query_as!(
|
let res = sqlx::query_as!(
|
||||||
GrantLookup,
|
GrantLookup,
|
||||||
r#"
|
r#"
|
||||||
SELECT oauth2_authorization_grant_id
|
SELECT oauth2_authorization_grant_id
|
||||||
, created_at AS oauth2_authorization_grant_created_at
|
, created_at
|
||||||
, cancelled_at AS oauth2_authorization_grant_cancelled_at
|
, cancelled_at
|
||||||
, fulfilled_at AS oauth2_authorization_grant_fulfilled_at
|
, fulfilled_at
|
||||||
, exchanged_at AS oauth2_authorization_grant_exchanged_at
|
, exchanged_at
|
||||||
, scope AS oauth2_authorization_grant_scope
|
, scope
|
||||||
, state AS oauth2_authorization_grant_state
|
, state
|
||||||
, redirect_uri AS oauth2_authorization_grant_redirect_uri
|
, redirect_uri
|
||||||
, response_mode AS oauth2_authorization_grant_response_mode
|
, response_mode
|
||||||
, nonce AS oauth2_authorization_grant_nonce
|
, nonce
|
||||||
, max_age AS oauth2_authorization_grant_max_age
|
, max_age
|
||||||
, oauth2_client_id AS oauth2_client_id
|
, oauth2_client_id
|
||||||
, authorization_code AS oauth2_authorization_grant_code
|
, authorization_code
|
||||||
, response_type_code AS oauth2_authorization_grant_response_type_code
|
, response_type_code
|
||||||
, response_type_id_token AS oauth2_authorization_grant_response_type_id_token
|
, response_type_id_token
|
||||||
, code_challenge AS oauth2_authorization_grant_code_challenge
|
, code_challenge
|
||||||
, code_challenge_method AS oauth2_authorization_grant_code_challenge_method
|
, code_challenge_method
|
||||||
, requires_consent AS oauth2_authorization_grant_requires_consent
|
, requires_consent
|
||||||
, oauth2_session_id AS "oauth2_session_id?"
|
, oauth2_session_id
|
||||||
FROM
|
FROM
|
||||||
oauth2_authorization_grants
|
oauth2_authorization_grants
|
||||||
|
|
||||||
@ -333,42 +382,49 @@ pub async fn get_grant_by_id(
|
|||||||
"#,
|
"#,
|
||||||
Uuid::from(id),
|
Uuid::from(id),
|
||||||
)
|
)
|
||||||
.fetch_one(&mut *conn)
|
.fetch_one(&mut *self.conn)
|
||||||
.await
|
.await
|
||||||
.to_option()?;
|
.to_option()?;
|
||||||
|
|
||||||
let Some(res) = res else { return Ok(None) };
|
let Some(res) = res else { return Ok(None) };
|
||||||
|
|
||||||
Ok(Some(res.try_into()?))
|
Ok(Some(res.try_into()?))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(skip_all, err)]
|
#[tracing::instrument(
|
||||||
pub async fn lookup_grant_by_code(
|
name = "db.oauth2_authorization_grant.find_by_code",
|
||||||
conn: &mut PgConnection,
|
skip_all,
|
||||||
|
fields(
|
||||||
|
db.statement,
|
||||||
|
),
|
||||||
|
err,
|
||||||
|
)]
|
||||||
|
async fn find_by_code(
|
||||||
|
&mut self,
|
||||||
code: &str,
|
code: &str,
|
||||||
) -> Result<Option<AuthorizationGrant>, DatabaseError> {
|
) -> Result<Option<AuthorizationGrant>, Self::Error> {
|
||||||
let res = sqlx::query_as!(
|
let res = sqlx::query_as!(
|
||||||
GrantLookup,
|
GrantLookup,
|
||||||
r#"
|
r#"
|
||||||
SELECT oauth2_authorization_grant_id
|
SELECT oauth2_authorization_grant_id
|
||||||
, created_at AS oauth2_authorization_grant_created_at
|
, created_at
|
||||||
, cancelled_at AS oauth2_authorization_grant_cancelled_at
|
, cancelled_at
|
||||||
, fulfilled_at AS oauth2_authorization_grant_fulfilled_at
|
, fulfilled_at
|
||||||
, exchanged_at AS oauth2_authorization_grant_exchanged_at
|
, exchanged_at
|
||||||
, scope AS oauth2_authorization_grant_scope
|
, scope
|
||||||
, state AS oauth2_authorization_grant_state
|
, state
|
||||||
, redirect_uri AS oauth2_authorization_grant_redirect_uri
|
, redirect_uri
|
||||||
, response_mode AS oauth2_authorization_grant_response_mode
|
, response_mode
|
||||||
, nonce AS oauth2_authorization_grant_nonce
|
, nonce
|
||||||
, max_age AS oauth2_authorization_grant_max_age
|
, max_age
|
||||||
, oauth2_client_id AS oauth2_client_id
|
, oauth2_client_id
|
||||||
, authorization_code AS oauth2_authorization_grant_code
|
, authorization_code
|
||||||
, response_type_code AS oauth2_authorization_grant_response_type_code
|
, response_type_code
|
||||||
, response_type_id_token AS oauth2_authorization_grant_response_type_id_token
|
, response_type_id_token
|
||||||
, code_challenge AS oauth2_authorization_grant_code_challenge
|
, code_challenge
|
||||||
, code_challenge_method AS oauth2_authorization_grant_code_challenge_method
|
, code_challenge_method
|
||||||
, requires_consent AS oauth2_authorization_grant_requires_consent
|
, requires_consent
|
||||||
, oauth2_session_id AS "oauth2_session_id?"
|
, oauth2_session_id
|
||||||
FROM
|
FROM
|
||||||
oauth2_authorization_grants
|
oauth2_authorization_grants
|
||||||
|
|
||||||
@ -376,68 +432,110 @@ pub async fn lookup_grant_by_code(
|
|||||||
"#,
|
"#,
|
||||||
code,
|
code,
|
||||||
)
|
)
|
||||||
.fetch_one(&mut *conn)
|
.traced()
|
||||||
|
.fetch_one(&mut *self.conn)
|
||||||
.await
|
.await
|
||||||
.to_option()?;
|
.to_option()?;
|
||||||
|
|
||||||
let Some(res) = res else { return Ok(None) };
|
let Some(res) = res else { return Ok(None) };
|
||||||
|
|
||||||
Ok(Some(res.try_into()?))
|
Ok(Some(res.try_into()?))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(
|
#[tracing::instrument(
|
||||||
|
name = "db.oauth2_authorization_grant.fulfill",
|
||||||
skip_all,
|
skip_all,
|
||||||
fields(
|
fields(
|
||||||
|
db.statement,
|
||||||
%grant.id,
|
%grant.id,
|
||||||
client.id = %grant.client_id,
|
client.id = %grant.client_id,
|
||||||
%session.id,
|
%session.id,
|
||||||
user_session.id = %session.user_session_id,
|
user_session.id = %session.user_session_id,
|
||||||
),
|
),
|
||||||
err,
|
err,
|
||||||
)]
|
)]
|
||||||
pub async fn fulfill_grant(
|
async fn fulfill(
|
||||||
executor: impl PgExecutor<'_>,
|
&mut self,
|
||||||
mut grant: AuthorizationGrant,
|
clock: &Clock,
|
||||||
session: Session,
|
session: &Session,
|
||||||
) -> Result<AuthorizationGrant, DatabaseError> {
|
grant: AuthorizationGrant,
|
||||||
let fulfilled_at = sqlx::query_scalar!(
|
) -> Result<AuthorizationGrant, Self::Error> {
|
||||||
|
let fulfilled_at = clock.now();
|
||||||
|
let res = sqlx::query!(
|
||||||
r#"
|
r#"
|
||||||
UPDATE oauth2_authorization_grants AS og
|
UPDATE oauth2_authorization_grants
|
||||||
SET
|
SET fulfilled_at = $2
|
||||||
oauth2_session_id = os.oauth2_session_id,
|
, oauth2_session_id = $3
|
||||||
fulfilled_at = os.created_at
|
WHERE oauth2_authorization_grant_id = $1
|
||||||
FROM oauth2_sessions os
|
|
||||||
WHERE
|
|
||||||
og.oauth2_authorization_grant_id = $1
|
|
||||||
AND os.oauth2_session_id = $2
|
|
||||||
RETURNING fulfilled_at AS "fulfilled_at!: DateTime<Utc>"
|
|
||||||
"#,
|
"#,
|
||||||
Uuid::from(grant.id),
|
Uuid::from(grant.id),
|
||||||
|
fulfilled_at,
|
||||||
Uuid::from(session.id),
|
Uuid::from(session.id),
|
||||||
)
|
)
|
||||||
.fetch_one(executor)
|
.execute(&mut *self.conn)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
grant.stage = grant
|
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||||
.stage
|
|
||||||
.fulfill(fulfilled_at, &session)
|
// XXX: check affected rows & new methods
|
||||||
|
let grant = grant
|
||||||
|
.fulfill(fulfilled_at, session)
|
||||||
.map_err(DatabaseError::to_invalid_operation)?;
|
.map_err(DatabaseError::to_invalid_operation)?;
|
||||||
|
|
||||||
Ok(grant)
|
Ok(grant)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(
|
#[tracing::instrument(
|
||||||
|
name = "db.oauth2_authorization_grant.exchange",
|
||||||
skip_all,
|
skip_all,
|
||||||
fields(
|
fields(
|
||||||
|
db.statement,
|
||||||
%grant.id,
|
%grant.id,
|
||||||
client.id = %grant.client_id,
|
client.id = %grant.client_id,
|
||||||
),
|
),
|
||||||
err,
|
err,
|
||||||
)]
|
)]
|
||||||
pub async fn give_consent_to_grant(
|
async fn exchange(
|
||||||
executor: impl PgExecutor<'_>,
|
&mut self,
|
||||||
|
clock: &Clock,
|
||||||
|
grant: AuthorizationGrant,
|
||||||
|
) -> Result<AuthorizationGrant, Self::Error> {
|
||||||
|
let exchanged_at = clock.now();
|
||||||
|
let res = sqlx::query!(
|
||||||
|
r#"
|
||||||
|
UPDATE oauth2_authorization_grants
|
||||||
|
SET exchanged_at = $2
|
||||||
|
WHERE oauth2_authorization_grant_id = $1
|
||||||
|
"#,
|
||||||
|
Uuid::from(grant.id),
|
||||||
|
exchanged_at,
|
||||||
|
)
|
||||||
|
.execute(&mut *self.conn)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||||
|
|
||||||
|
let grant = grant
|
||||||
|
.exchange(exchanged_at)
|
||||||
|
.map_err(DatabaseError::to_invalid_operation)?;
|
||||||
|
|
||||||
|
Ok(grant)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(
|
||||||
|
name = "db.oauth2_authorization_grant.give_consent",
|
||||||
|
skip_all,
|
||||||
|
fields(
|
||||||
|
db.statement,
|
||||||
|
%grant.id,
|
||||||
|
client.id = %grant.client_id,
|
||||||
|
),
|
||||||
|
err,
|
||||||
|
)]
|
||||||
|
async fn give_consent(
|
||||||
|
&mut self,
|
||||||
mut grant: AuthorizationGrant,
|
mut grant: AuthorizationGrant,
|
||||||
) -> Result<AuthorizationGrant, sqlx::Error> {
|
) -> Result<AuthorizationGrant, Self::Error> {
|
||||||
sqlx::query!(
|
sqlx::query!(
|
||||||
r#"
|
r#"
|
||||||
UPDATE oauth2_authorization_grants AS og
|
UPDATE oauth2_authorization_grants AS og
|
||||||
@ -448,44 +546,11 @@ pub async fn give_consent_to_grant(
|
|||||||
"#,
|
"#,
|
||||||
Uuid::from(grant.id),
|
Uuid::from(grant.id),
|
||||||
)
|
)
|
||||||
.execute(executor)
|
.execute(&mut *self.conn)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
grant.requires_consent = false;
|
grant.requires_consent = false;
|
||||||
|
|
||||||
Ok(grant)
|
Ok(grant)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(
|
|
||||||
skip_all,
|
|
||||||
fields(
|
|
||||||
%grant.id,
|
|
||||||
client.id = %grant.client_id,
|
|
||||||
),
|
|
||||||
err,
|
|
||||||
)]
|
|
||||||
pub async fn exchange_grant(
|
|
||||||
executor: impl PgExecutor<'_>,
|
|
||||||
clock: &Clock,
|
|
||||||
mut grant: AuthorizationGrant,
|
|
||||||
) -> Result<AuthorizationGrant, DatabaseError> {
|
|
||||||
let exchanged_at = clock.now();
|
|
||||||
sqlx::query!(
|
|
||||||
r#"
|
|
||||||
UPDATE oauth2_authorization_grants
|
|
||||||
SET exchanged_at = $2
|
|
||||||
WHERE oauth2_authorization_grant_id = $1
|
|
||||||
"#,
|
|
||||||
Uuid::from(grant.id),
|
|
||||||
exchanged_at,
|
|
||||||
)
|
|
||||||
.execute(executor)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
grant.stage = grant
|
|
||||||
.stage
|
|
||||||
.exchange(exchanged_at)
|
|
||||||
.map_err(DatabaseError::to_invalid_operation)?;
|
|
||||||
|
|
||||||
Ok(grant)
|
|
||||||
}
|
}
|
||||||
|
@ -14,17 +14,21 @@
|
|||||||
|
|
||||||
use std::{
|
use std::{
|
||||||
collections::{BTreeMap, BTreeSet},
|
collections::{BTreeMap, BTreeSet},
|
||||||
|
str::FromStr,
|
||||||
string::ToString,
|
string::ToString,
|
||||||
};
|
};
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use mas_data_model::{Client, JwksOrJwksUri};
|
use mas_data_model::{Client, JwksOrJwksUri, User};
|
||||||
use mas_iana::{
|
use mas_iana::{
|
||||||
jose::JsonWebSignatureAlg,
|
jose::JsonWebSignatureAlg,
|
||||||
oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod},
|
oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod},
|
||||||
};
|
};
|
||||||
use mas_jose::jwk::PublicJsonWebKeySet;
|
use mas_jose::jwk::PublicJsonWebKeySet;
|
||||||
use oauth2_types::requests::GrantType;
|
use oauth2_types::{
|
||||||
|
requests::GrantType,
|
||||||
|
scope::{Scope, ScopeToken},
|
||||||
|
};
|
||||||
use rand::{Rng, RngCore};
|
use rand::{Rng, RngCore};
|
||||||
use sqlx::PgConnection;
|
use sqlx::PgConnection;
|
||||||
use tracing::{info_span, Instrument};
|
use tracing::{info_span, Instrument};
|
||||||
@ -87,6 +91,21 @@ pub trait OAuth2ClientRepository: Send + Sync {
|
|||||||
jwks_uri: Option<Url>,
|
jwks_uri: Option<Url>,
|
||||||
redirect_uris: Vec<Url>,
|
redirect_uris: Vec<Url>,
|
||||||
) -> Result<Client, Self::Error>;
|
) -> Result<Client, Self::Error>;
|
||||||
|
|
||||||
|
async fn get_consent_for_user(
|
||||||
|
&mut self,
|
||||||
|
client: &Client,
|
||||||
|
user: &User,
|
||||||
|
) -> Result<Scope, Self::Error>;
|
||||||
|
|
||||||
|
async fn give_consent_for_user(
|
||||||
|
&mut self,
|
||||||
|
rng: &mut (dyn RngCore + Send),
|
||||||
|
clock: &Clock,
|
||||||
|
client: &Client,
|
||||||
|
user: &User,
|
||||||
|
scope: &Scope,
|
||||||
|
) -> Result<(), Self::Error>;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct PgOAuth2ClientRepository<'c> {
|
pub struct PgOAuth2ClientRepository<'c> {
|
||||||
@ -702,4 +721,94 @@ impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> {
|
|||||||
initiate_login_uri: None,
|
initiate_login_uri: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(
|
||||||
|
name = "db.oauth2_client.get_consent_for_user",
|
||||||
|
skip_all,
|
||||||
|
fields(
|
||||||
|
db.statement,
|
||||||
|
%user.id,
|
||||||
|
%client.id,
|
||||||
|
),
|
||||||
|
err,
|
||||||
|
)]
|
||||||
|
async fn get_consent_for_user(
|
||||||
|
&mut self,
|
||||||
|
client: &Client,
|
||||||
|
user: &User,
|
||||||
|
) -> Result<Scope, Self::Error> {
|
||||||
|
let scope_tokens: Vec<String> = sqlx::query_scalar!(
|
||||||
|
r#"
|
||||||
|
SELECT scope_token
|
||||||
|
FROM oauth2_consents
|
||||||
|
WHERE user_id = $1 AND oauth2_client_id = $2
|
||||||
|
"#,
|
||||||
|
Uuid::from(user.id),
|
||||||
|
Uuid::from(client.id),
|
||||||
|
)
|
||||||
|
.fetch_all(&mut *self.conn)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let scope: Result<Scope, _> = scope_tokens
|
||||||
|
.into_iter()
|
||||||
|
.map(|s| ScopeToken::from_str(&s))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let scope = scope.map_err(|e| {
|
||||||
|
DatabaseInconsistencyError::on("oauth2_consents")
|
||||||
|
.column("scope_token")
|
||||||
|
.source(e)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok(scope)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(
|
||||||
|
skip_all,
|
||||||
|
fields(
|
||||||
|
db.statement,
|
||||||
|
%user.id,
|
||||||
|
%client.id,
|
||||||
|
%scope,
|
||||||
|
),
|
||||||
|
err,
|
||||||
|
)]
|
||||||
|
async fn give_consent_for_user(
|
||||||
|
&mut self,
|
||||||
|
rng: &mut (dyn RngCore + Send),
|
||||||
|
clock: &Clock,
|
||||||
|
client: &Client,
|
||||||
|
user: &User,
|
||||||
|
scope: &Scope,
|
||||||
|
) -> Result<(), Self::Error> {
|
||||||
|
let now = clock.now();
|
||||||
|
let (tokens, ids): (Vec<String>, Vec<Uuid>) = scope
|
||||||
|
.iter()
|
||||||
|
.map(|token| {
|
||||||
|
(
|
||||||
|
token.to_string(),
|
||||||
|
Uuid::from(Ulid::from_datetime_with_source(now.into(), rng)),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.unzip();
|
||||||
|
|
||||||
|
sqlx::query!(
|
||||||
|
r#"
|
||||||
|
INSERT INTO oauth2_consents
|
||||||
|
(oauth2_consent_id, user_id, oauth2_client_id, scope_token, created_at)
|
||||||
|
SELECT id, $2, $3, scope_token, $5 FROM UNNEST($1::uuid[], $4::text[]) u(id, scope_token)
|
||||||
|
ON CONFLICT (user_id, oauth2_client_id, scope_token) DO UPDATE SET refreshed_at = $5
|
||||||
|
"#,
|
||||||
|
&ids,
|
||||||
|
Uuid::from(user.id),
|
||||||
|
Uuid::from(client.id),
|
||||||
|
&tokens,
|
||||||
|
now,
|
||||||
|
)
|
||||||
|
.traced()
|
||||||
|
.execute(&mut *self.conn)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,110 +0,0 @@
|
|||||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
use std::str::FromStr;
|
|
||||||
|
|
||||||
use mas_data_model::{Client, User};
|
|
||||||
use oauth2_types::scope::{Scope, ScopeToken};
|
|
||||||
use rand::Rng;
|
|
||||||
use sqlx::PgExecutor;
|
|
||||||
use ulid::Ulid;
|
|
||||||
use uuid::Uuid;
|
|
||||||
|
|
||||||
use crate::{Clock, DatabaseError, DatabaseInconsistencyError};
|
|
||||||
|
|
||||||
#[tracing::instrument(
|
|
||||||
skip_all,
|
|
||||||
fields(
|
|
||||||
%user.id,
|
|
||||||
%client.id,
|
|
||||||
),
|
|
||||||
err,
|
|
||||||
)]
|
|
||||||
pub async fn fetch_client_consent(
|
|
||||||
executor: impl PgExecutor<'_>,
|
|
||||||
user: &User,
|
|
||||||
client: &Client,
|
|
||||||
) -> Result<Scope, DatabaseError> {
|
|
||||||
let scope_tokens: Vec<String> = sqlx::query_scalar!(
|
|
||||||
r#"
|
|
||||||
SELECT scope_token
|
|
||||||
FROM oauth2_consents
|
|
||||||
WHERE user_id = $1 AND oauth2_client_id = $2
|
|
||||||
"#,
|
|
||||||
Uuid::from(user.id),
|
|
||||||
Uuid::from(client.id),
|
|
||||||
)
|
|
||||||
.fetch_all(executor)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let scope: Result<Scope, _> = scope_tokens
|
|
||||||
.into_iter()
|
|
||||||
.map(|s| ScopeToken::from_str(&s))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let scope = scope.map_err(|e| {
|
|
||||||
DatabaseInconsistencyError::on("oauth2_consents")
|
|
||||||
.column("scope_token")
|
|
||||||
.source(e)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
Ok(scope)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tracing::instrument(
|
|
||||||
skip_all,
|
|
||||||
fields(
|
|
||||||
%user.id,
|
|
||||||
%client.id,
|
|
||||||
%scope,
|
|
||||||
),
|
|
||||||
err,
|
|
||||||
)]
|
|
||||||
pub async fn insert_client_consent(
|
|
||||||
executor: impl PgExecutor<'_>,
|
|
||||||
mut rng: impl Rng + Send,
|
|
||||||
clock: &Clock,
|
|
||||||
user: &User,
|
|
||||||
client: &Client,
|
|
||||||
scope: &Scope,
|
|
||||||
) -> Result<(), sqlx::Error> {
|
|
||||||
let now = clock.now();
|
|
||||||
let (tokens, ids): (Vec<String>, Vec<Uuid>) = scope
|
|
||||||
.iter()
|
|
||||||
.map(|token| {
|
|
||||||
(
|
|
||||||
token.to_string(),
|
|
||||||
Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)),
|
|
||||||
)
|
|
||||||
})
|
|
||||||
.unzip();
|
|
||||||
|
|
||||||
sqlx::query!(
|
|
||||||
r#"
|
|
||||||
INSERT INTO oauth2_consents
|
|
||||||
(oauth2_consent_id, user_id, oauth2_client_id, scope_token, created_at)
|
|
||||||
SELECT id, $2, $3, scope_token, $5 FROM UNNEST($1::uuid[], $4::text[]) u(id, scope_token)
|
|
||||||
ON CONFLICT (user_id, oauth2_client_id, scope_token) DO UPDATE SET refreshed_at = $5
|
|
||||||
"#,
|
|
||||||
&ids,
|
|
||||||
Uuid::from(user.id),
|
|
||||||
Uuid::from(client.id),
|
|
||||||
&tokens,
|
|
||||||
now,
|
|
||||||
)
|
|
||||||
.execute(executor)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
@ -12,14 +12,18 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
pub mod access_token;
|
mod access_token;
|
||||||
pub mod authorization_grant;
|
pub mod authorization_grant;
|
||||||
mod client;
|
mod client;
|
||||||
pub mod consent;
|
mod refresh_token;
|
||||||
pub mod refresh_token;
|
|
||||||
mod session;
|
mod session;
|
||||||
|
|
||||||
pub use self::{
|
pub use self::{
|
||||||
|
access_token::{OAuth2AccessTokenRepository, PgOAuth2AccessTokenRepository},
|
||||||
|
authorization_grant::{
|
||||||
|
OAuth2AuthorizationGrantRepository, PgOAuth2AuthorizationGrantRepository,
|
||||||
|
},
|
||||||
client::{OAuth2ClientRepository, PgOAuth2ClientRepository},
|
client::{OAuth2ClientRepository, PgOAuth2ClientRepository},
|
||||||
|
refresh_token::{OAuth2RefreshTokenRepository, PgOAuth2RefreshTokenRepository},
|
||||||
session::{OAuth2SessionRepository, PgOAuth2SessionRepository},
|
session::{OAuth2SessionRepository, PgOAuth2SessionRepository},
|
||||||
};
|
};
|
||||||
|
@ -12,35 +12,181 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
use mas_data_model::{AccessToken, RefreshToken, RefreshTokenState, Session};
|
use mas_data_model::{AccessToken, RefreshToken, RefreshTokenState, Session};
|
||||||
use rand::Rng;
|
use rand::RngCore;
|
||||||
use sqlx::{PgConnection, PgExecutor};
|
use sqlx::PgConnection;
|
||||||
use ulid::Ulid;
|
use ulid::Ulid;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::{Clock, DatabaseError};
|
use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt};
|
||||||
|
|
||||||
#[tracing::instrument(
|
#[async_trait]
|
||||||
|
pub trait OAuth2RefreshTokenRepository: Send + Sync {
|
||||||
|
type Error;
|
||||||
|
|
||||||
|
/// Lookup a refresh token by its ID
|
||||||
|
async fn lookup(&mut self, id: Ulid) -> Result<Option<RefreshToken>, Self::Error>;
|
||||||
|
|
||||||
|
/// Find a refresh token by its token
|
||||||
|
async fn find_by_token(
|
||||||
|
&mut self,
|
||||||
|
refresh_token: &str,
|
||||||
|
) -> Result<Option<RefreshToken>, Self::Error>;
|
||||||
|
|
||||||
|
/// Add a new refresh token to the database
|
||||||
|
async fn add(
|
||||||
|
&mut self,
|
||||||
|
rng: &mut (dyn RngCore + Send),
|
||||||
|
clock: &Clock,
|
||||||
|
session: &Session,
|
||||||
|
access_token: &AccessToken,
|
||||||
|
refresh_token: String,
|
||||||
|
) -> Result<RefreshToken, Self::Error>;
|
||||||
|
|
||||||
|
/// Consume a refresh token
|
||||||
|
async fn consume(
|
||||||
|
&mut self,
|
||||||
|
clock: &Clock,
|
||||||
|
refresh_token: RefreshToken,
|
||||||
|
) -> Result<RefreshToken, Self::Error>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct PgOAuth2RefreshTokenRepository<'c> {
|
||||||
|
conn: &'c mut PgConnection,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'c> PgOAuth2RefreshTokenRepository<'c> {
|
||||||
|
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||||
|
Self { conn }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct OAuth2RefreshTokenLookup {
|
||||||
|
oauth2_refresh_token_id: Uuid,
|
||||||
|
refresh_token: String,
|
||||||
|
created_at: DateTime<Utc>,
|
||||||
|
consumed_at: Option<DateTime<Utc>>,
|
||||||
|
oauth2_access_token_id: Option<Uuid>,
|
||||||
|
oauth2_session_id: Uuid,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<OAuth2RefreshTokenLookup> for RefreshToken {
|
||||||
|
fn from(value: OAuth2RefreshTokenLookup) -> Self {
|
||||||
|
let state = match value.consumed_at {
|
||||||
|
None => RefreshTokenState::Valid,
|
||||||
|
Some(consumed_at) => RefreshTokenState::Consumed { consumed_at },
|
||||||
|
};
|
||||||
|
|
||||||
|
RefreshToken {
|
||||||
|
id: value.oauth2_refresh_token_id.into(),
|
||||||
|
state,
|
||||||
|
session_id: value.oauth2_session_id.into(),
|
||||||
|
refresh_token: value.refresh_token,
|
||||||
|
created_at: value.created_at,
|
||||||
|
access_token_id: value.oauth2_access_token_id.map(Ulid::from),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<'c> OAuth2RefreshTokenRepository for PgOAuth2RefreshTokenRepository<'c> {
|
||||||
|
type Error = DatabaseError;
|
||||||
|
|
||||||
|
#[tracing::instrument(
|
||||||
|
name = "db.oauth2_refresh_token.lookup",
|
||||||
skip_all,
|
skip_all,
|
||||||
fields(
|
fields(
|
||||||
|
db.statement,
|
||||||
|
refresh_token.id = %id,
|
||||||
|
),
|
||||||
|
err,
|
||||||
|
)]
|
||||||
|
async fn lookup(&mut self, id: Ulid) -> Result<Option<RefreshToken>, Self::Error> {
|
||||||
|
let res = sqlx::query_as!(
|
||||||
|
OAuth2RefreshTokenLookup,
|
||||||
|
r#"
|
||||||
|
SELECT oauth2_refresh_token_id
|
||||||
|
, refresh_token
|
||||||
|
, created_at
|
||||||
|
, consumed_at
|
||||||
|
, oauth2_access_token_id
|
||||||
|
, oauth2_session_id
|
||||||
|
FROM oauth2_refresh_tokens
|
||||||
|
|
||||||
|
WHERE oauth2_refresh_token_id = $1
|
||||||
|
"#,
|
||||||
|
Uuid::from(id),
|
||||||
|
)
|
||||||
|
.fetch_one(&mut *self.conn)
|
||||||
|
.await
|
||||||
|
.to_option()?;
|
||||||
|
|
||||||
|
let Some(res) = res else { return Ok(None) };
|
||||||
|
|
||||||
|
Ok(Some(res.into()))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(
|
||||||
|
name = "db.oauth2_refresh_token.find_by_token",
|
||||||
|
skip_all,
|
||||||
|
fields(
|
||||||
|
db.statement,
|
||||||
|
),
|
||||||
|
err,
|
||||||
|
)]
|
||||||
|
async fn find_by_token(
|
||||||
|
&mut self,
|
||||||
|
refresh_token: &str,
|
||||||
|
) -> Result<Option<RefreshToken>, Self::Error> {
|
||||||
|
let res = sqlx::query_as!(
|
||||||
|
OAuth2RefreshTokenLookup,
|
||||||
|
r#"
|
||||||
|
SELECT oauth2_refresh_token_id
|
||||||
|
, refresh_token
|
||||||
|
, created_at
|
||||||
|
, consumed_at
|
||||||
|
, oauth2_access_token_id
|
||||||
|
, oauth2_session_id
|
||||||
|
FROM oauth2_refresh_tokens
|
||||||
|
|
||||||
|
WHERE refresh_token = $1
|
||||||
|
"#,
|
||||||
|
refresh_token,
|
||||||
|
)
|
||||||
|
.traced()
|
||||||
|
.fetch_one(&mut *self.conn)
|
||||||
|
.await
|
||||||
|
.to_option()?;
|
||||||
|
|
||||||
|
let Some(res) = res else { return Ok(None) };
|
||||||
|
|
||||||
|
Ok(Some(res.into()))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(
|
||||||
|
name = "db.oauth2_refresh_token.add",
|
||||||
|
skip_all,
|
||||||
|
fields(
|
||||||
|
db.statement,
|
||||||
%session.id,
|
%session.id,
|
||||||
user_session.id = %session.user_session_id,
|
user_session.id = %session.user_session_id,
|
||||||
client.id = %session.client_id,
|
client.id = %session.client_id,
|
||||||
refresh_token.id,
|
refresh_token.id,
|
||||||
),
|
),
|
||||||
err,
|
err,
|
||||||
)]
|
)]
|
||||||
pub async fn add_refresh_token(
|
async fn add(
|
||||||
executor: impl PgExecutor<'_>,
|
&mut self,
|
||||||
mut rng: impl Rng + Send,
|
rng: &mut (dyn RngCore + Send),
|
||||||
clock: &Clock,
|
clock: &Clock,
|
||||||
session: &Session,
|
session: &Session,
|
||||||
access_token: &AccessToken,
|
access_token: &AccessToken,
|
||||||
refresh_token: String,
|
refresh_token: String,
|
||||||
) -> Result<RefreshToken, sqlx::Error> {
|
) -> Result<RefreshToken, Self::Error> {
|
||||||
let created_at = clock.now();
|
let created_at = clock.now();
|
||||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||||
tracing::Span::current().record("refresh_token.id", tracing::field::display(id));
|
tracing::Span::current().record("refresh_token.id", tracing::field::display(id));
|
||||||
|
|
||||||
sqlx::query!(
|
sqlx::query!(
|
||||||
@ -57,7 +203,8 @@ pub async fn add_refresh_token(
|
|||||||
refresh_token,
|
refresh_token,
|
||||||
created_at,
|
created_at,
|
||||||
)
|
)
|
||||||
.execute(executor)
|
.traced()
|
||||||
|
.execute(&mut *self.conn)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
Ok(RefreshToken {
|
Ok(RefreshToken {
|
||||||
@ -68,70 +215,23 @@ pub async fn add_refresh_token(
|
|||||||
access_token_id: Some(access_token.id),
|
access_token_id: Some(access_token.id),
|
||||||
created_at,
|
created_at,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
struct OAuth2RefreshTokenLookup {
|
#[tracing::instrument(
|
||||||
oauth2_refresh_token_id: Uuid,
|
name = "db.oauth2_refresh_token.consume",
|
||||||
refresh_token: String,
|
|
||||||
created_at: DateTime<Utc>,
|
|
||||||
consumed_at: Option<DateTime<Utc>>,
|
|
||||||
oauth2_access_token_id: Option<Uuid>,
|
|
||||||
oauth2_session_id: Uuid,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tracing::instrument(skip_all, err)]
|
|
||||||
#[allow(clippy::too_many_lines)]
|
|
||||||
pub async fn lookup_refresh_token(
|
|
||||||
conn: &mut PgConnection,
|
|
||||||
token: &str,
|
|
||||||
) -> Result<Option<RefreshToken>, DatabaseError> {
|
|
||||||
let res = sqlx::query_as!(
|
|
||||||
OAuth2RefreshTokenLookup,
|
|
||||||
r#"
|
|
||||||
SELECT oauth2_refresh_token_id
|
|
||||||
, refresh_token
|
|
||||||
, created_at
|
|
||||||
, consumed_at
|
|
||||||
, oauth2_access_token_id
|
|
||||||
, oauth2_session_id
|
|
||||||
FROM oauth2_refresh_tokens
|
|
||||||
|
|
||||||
WHERE refresh_token = $1
|
|
||||||
"#,
|
|
||||||
token,
|
|
||||||
)
|
|
||||||
.fetch_one(&mut *conn)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let state = match res.consumed_at {
|
|
||||||
None => RefreshTokenState::Valid,
|
|
||||||
Some(consumed_at) => RefreshTokenState::Consumed { consumed_at },
|
|
||||||
};
|
|
||||||
|
|
||||||
let refresh_token = RefreshToken {
|
|
||||||
id: res.oauth2_refresh_token_id.into(),
|
|
||||||
state,
|
|
||||||
session_id: res.oauth2_session_id.into(),
|
|
||||||
refresh_token: res.refresh_token,
|
|
||||||
created_at: res.created_at,
|
|
||||||
access_token_id: res.oauth2_access_token_id.map(Ulid::from),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(Some(refresh_token))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tracing::instrument(
|
|
||||||
skip_all,
|
skip_all,
|
||||||
fields(
|
fields(
|
||||||
|
db.statement,
|
||||||
%refresh_token.id,
|
%refresh_token.id,
|
||||||
|
session.id = %refresh_token.session_id,
|
||||||
),
|
),
|
||||||
err,
|
err,
|
||||||
)]
|
)]
|
||||||
pub async fn consume_refresh_token(
|
async fn consume(
|
||||||
executor: impl PgExecutor<'_>,
|
&mut self,
|
||||||
clock: &Clock,
|
clock: &Clock,
|
||||||
refresh_token: RefreshToken,
|
refresh_token: RefreshToken,
|
||||||
) -> Result<RefreshToken, DatabaseError> {
|
) -> Result<RefreshToken, Self::Error> {
|
||||||
let consumed_at = clock.now();
|
let consumed_at = clock.now();
|
||||||
let res = sqlx::query!(
|
let res = sqlx::query!(
|
||||||
r#"
|
r#"
|
||||||
@ -142,7 +242,7 @@ pub async fn consume_refresh_token(
|
|||||||
Uuid::from(refresh_token.id),
|
Uuid::from(refresh_token.id),
|
||||||
consumed_at,
|
consumed_at,
|
||||||
)
|
)
|
||||||
.execute(executor)
|
.execute(&mut *self.conn)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||||
@ -150,4 +250,5 @@ pub async fn consume_refresh_token(
|
|||||||
refresh_token
|
refresh_token
|
||||||
.consume(consumed_at)
|
.consume(consumed_at)
|
||||||
.map_err(DatabaseError::to_invalid_operation)
|
.map_err(DatabaseError::to_invalid_operation)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -19,7 +19,10 @@ use crate::{
|
|||||||
PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository, PgCompatSessionRepository,
|
PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository, PgCompatSessionRepository,
|
||||||
PgCompatSsoLoginRepository,
|
PgCompatSsoLoginRepository,
|
||||||
},
|
},
|
||||||
oauth2::{PgOAuth2ClientRepository, PgOAuth2SessionRepository},
|
oauth2::{
|
||||||
|
PgOAuth2AccessTokenRepository, PgOAuth2AuthorizationGrantRepository,
|
||||||
|
PgOAuth2ClientRepository, PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository,
|
||||||
|
},
|
||||||
upstream_oauth2::{
|
upstream_oauth2::{
|
||||||
PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository,
|
PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository,
|
||||||
PgUpstreamOAuthSessionRepository,
|
PgUpstreamOAuthSessionRepository,
|
||||||
@ -63,10 +66,22 @@ pub trait Repository {
|
|||||||
where
|
where
|
||||||
Self: 'c;
|
Self: 'c;
|
||||||
|
|
||||||
|
type OAuth2AuthorizationGrantRepository<'c>
|
||||||
|
where
|
||||||
|
Self: 'c;
|
||||||
|
|
||||||
type OAuth2SessionRepository<'c>
|
type OAuth2SessionRepository<'c>
|
||||||
where
|
where
|
||||||
Self: 'c;
|
Self: 'c;
|
||||||
|
|
||||||
|
type OAuth2AccessTokenRepository<'c>
|
||||||
|
where
|
||||||
|
Self: 'c;
|
||||||
|
|
||||||
|
type OAuth2RefreshTokenRepository<'c>
|
||||||
|
where
|
||||||
|
Self: 'c;
|
||||||
|
|
||||||
type CompatSessionRepository<'c>
|
type CompatSessionRepository<'c>
|
||||||
where
|
where
|
||||||
Self: 'c;
|
Self: 'c;
|
||||||
@ -91,7 +106,10 @@ pub trait Repository {
|
|||||||
fn user_password(&mut self) -> Self::UserPasswordRepository<'_>;
|
fn user_password(&mut self) -> Self::UserPasswordRepository<'_>;
|
||||||
fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_>;
|
fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_>;
|
||||||
fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_>;
|
fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_>;
|
||||||
|
fn oauth2_authorization_grant(&mut self) -> Self::OAuth2AuthorizationGrantRepository<'_>;
|
||||||
fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_>;
|
fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_>;
|
||||||
|
fn oauth2_access_token(&mut self) -> Self::OAuth2AccessTokenRepository<'_>;
|
||||||
|
fn oauth2_refresh_token(&mut self) -> Self::OAuth2RefreshTokenRepository<'_>;
|
||||||
fn compat_session(&mut self) -> Self::CompatSessionRepository<'_>;
|
fn compat_session(&mut self) -> Self::CompatSessionRepository<'_>;
|
||||||
fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_>;
|
fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_>;
|
||||||
fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_>;
|
fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_>;
|
||||||
@ -107,7 +125,10 @@ impl Repository for PgConnection {
|
|||||||
type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c;
|
type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c;
|
||||||
type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c;
|
type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c;
|
||||||
type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c;
|
type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c;
|
||||||
|
type OAuth2AuthorizationGrantRepository<'c> = PgOAuth2AuthorizationGrantRepository<'c> where Self: 'c;
|
||||||
type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c;
|
type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c;
|
||||||
|
type OAuth2AccessTokenRepository<'c> = PgOAuth2AccessTokenRepository<'c> where Self: 'c;
|
||||||
|
type OAuth2RefreshTokenRepository<'c> = PgOAuth2RefreshTokenRepository<'c> where Self: 'c;
|
||||||
type CompatSessionRepository<'c> = PgCompatSessionRepository<'c> where Self: 'c;
|
type CompatSessionRepository<'c> = PgCompatSessionRepository<'c> where Self: 'c;
|
||||||
type CompatSsoLoginRepository<'c> = PgCompatSsoLoginRepository<'c> where Self: 'c;
|
type CompatSsoLoginRepository<'c> = PgCompatSsoLoginRepository<'c> where Self: 'c;
|
||||||
type CompatAccessTokenRepository<'c> = PgCompatAccessTokenRepository<'c> where Self: 'c;
|
type CompatAccessTokenRepository<'c> = PgCompatAccessTokenRepository<'c> where Self: 'c;
|
||||||
@ -145,10 +166,22 @@ impl Repository for PgConnection {
|
|||||||
PgOAuth2ClientRepository::new(self)
|
PgOAuth2ClientRepository::new(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn oauth2_authorization_grant(&mut self) -> Self::OAuth2AuthorizationGrantRepository<'_> {
|
||||||
|
PgOAuth2AuthorizationGrantRepository::new(self)
|
||||||
|
}
|
||||||
|
|
||||||
fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> {
|
fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> {
|
||||||
PgOAuth2SessionRepository::new(self)
|
PgOAuth2SessionRepository::new(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn oauth2_access_token(&mut self) -> Self::OAuth2AccessTokenRepository<'_> {
|
||||||
|
PgOAuth2AccessTokenRepository::new(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn oauth2_refresh_token(&mut self) -> Self::OAuth2RefreshTokenRepository<'_> {
|
||||||
|
PgOAuth2RefreshTokenRepository::new(self)
|
||||||
|
}
|
||||||
|
|
||||||
fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> {
|
fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> {
|
||||||
PgCompatSessionRepository::new(self)
|
PgCompatSessionRepository::new(self)
|
||||||
}
|
}
|
||||||
@ -175,7 +208,10 @@ impl<'t> Repository for Transaction<'t, Postgres> {
|
|||||||
type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c;
|
type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c;
|
||||||
type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c;
|
type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c;
|
||||||
type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c;
|
type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c;
|
||||||
|
type OAuth2AuthorizationGrantRepository<'c> = PgOAuth2AuthorizationGrantRepository<'c> where Self: 'c;
|
||||||
type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c;
|
type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c;
|
||||||
|
type OAuth2AccessTokenRepository<'c> = PgOAuth2AccessTokenRepository<'c> where Self: 'c;
|
||||||
|
type OAuth2RefreshTokenRepository<'c> = PgOAuth2RefreshTokenRepository<'c> where Self: 'c;
|
||||||
type CompatSessionRepository<'c> = PgCompatSessionRepository<'c> where Self: 'c;
|
type CompatSessionRepository<'c> = PgCompatSessionRepository<'c> where Self: 'c;
|
||||||
type CompatSsoLoginRepository<'c> = PgCompatSsoLoginRepository<'c> where Self: 'c;
|
type CompatSsoLoginRepository<'c> = PgCompatSsoLoginRepository<'c> where Self: 'c;
|
||||||
type CompatAccessTokenRepository<'c> = PgCompatAccessTokenRepository<'c> where Self: 'c;
|
type CompatAccessTokenRepository<'c> = PgCompatAccessTokenRepository<'c> where Self: 'c;
|
||||||
@ -213,10 +249,22 @@ impl<'t> Repository for Transaction<'t, Postgres> {
|
|||||||
PgOAuth2ClientRepository::new(self)
|
PgOAuth2ClientRepository::new(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn oauth2_authorization_grant(&mut self) -> Self::OAuth2AuthorizationGrantRepository<'_> {
|
||||||
|
PgOAuth2AuthorizationGrantRepository::new(self)
|
||||||
|
}
|
||||||
|
|
||||||
fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> {
|
fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> {
|
||||||
PgOAuth2SessionRepository::new(self)
|
PgOAuth2SessionRepository::new(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn oauth2_access_token(&mut self) -> Self::OAuth2AccessTokenRepository<'_> {
|
||||||
|
PgOAuth2AccessTokenRepository::new(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn oauth2_refresh_token(&mut self) -> Self::OAuth2RefreshTokenRepository<'_> {
|
||||||
|
PgOAuth2RefreshTokenRepository::new(self)
|
||||||
|
}
|
||||||
|
|
||||||
fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> {
|
fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> {
|
||||||
PgCompatSessionRepository::new(self)
|
PgCompatSessionRepository::new(self)
|
||||||
}
|
}
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
|
|
||||||
//! Database-related tasks
|
//! Database-related tasks
|
||||||
|
|
||||||
use mas_storage::Clock;
|
use mas_storage::{oauth2::OAuth2AccessTokenRepository, Clock, Repository};
|
||||||
use sqlx::{Pool, Postgres};
|
use sqlx::{Pool, Postgres};
|
||||||
use tracing::{debug, error, info};
|
use tracing::{debug, error, info};
|
||||||
|
|
||||||
@ -32,7 +32,12 @@ impl std::fmt::Debug for CleanupExpired {
|
|||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
impl Task for CleanupExpired {
|
impl Task for CleanupExpired {
|
||||||
async fn run(&self) {
|
async fn run(&self) {
|
||||||
let res = mas_storage::oauth2::access_token::cleanup_expired(&self.0, &self.1).await;
|
let res = async move {
|
||||||
|
let mut conn = self.0.acquire().await?;
|
||||||
|
conn.oauth2_access_token().cleanup_expired(&self.1).await
|
||||||
|
}
|
||||||
|
.await;
|
||||||
|
|
||||||
match res {
|
match res {
|
||||||
Ok(0) => {
|
Ok(0) => {
|
||||||
debug!("no token to clean up");
|
debug!("no token to clean up");
|
||||||
|
Reference in New Issue
Block a user