1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

data-model: simplify the authorization grants and sessions

This commit is contained in:
Quentin Gliech
2022-12-07 15:08:04 +01:00
parent 92d6f5b087
commit 12ce2a3d04
18 changed files with 92 additions and 233 deletions

View File

@ -29,7 +29,7 @@ use http::{header::WWW_AUTHENTICATE, HeaderMap, HeaderValue, Request, StatusCode
use mas_data_model::Session; use mas_data_model::Session;
use mas_storage::{ use mas_storage::{
oauth2::access_token::{lookup_active_access_token, AccessTokenLookupError}, oauth2::access_token::{lookup_active_access_token, AccessTokenLookupError},
LookupError, PostgresqlBackend, LookupError,
}; };
use serde::{de::DeserializeOwned, Deserialize}; use serde::{de::DeserializeOwned, Deserialize};
use sqlx::PgConnection; use sqlx::PgConnection;
@ -55,10 +55,7 @@ impl AccessToken {
pub async fn fetch( pub async fn fetch(
&self, &self,
conn: &mut PgConnection, conn: &mut PgConnection,
) -> Result< ) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError> {
(mas_data_model::AccessToken, Session<PostgresqlBackend>),
AuthorizationVerificationError,
> {
let token = match self { let token = match self {
AccessToken::Form(t) | AccessToken::Header(t) => t, AccessToken::Form(t) | AccessToken::Header(t) => t,
AccessToken::None => return Err(AuthorizationVerificationError::MissingToken), AccessToken::None => return Err(AuthorizationVerificationError::MissingToken),
@ -81,7 +78,7 @@ impl<F: Send> UserAuthorization<F> {
pub async fn protected_form( pub async fn protected_form(
self, self,
conn: &mut PgConnection, conn: &mut PgConnection,
) -> Result<(Session<PostgresqlBackend>, F), AuthorizationVerificationError> { ) -> Result<(Session, F), AuthorizationVerificationError> {
let form = match self.form { let form = match self.form {
Some(f) => f, Some(f) => f,
None => return Err(AuthorizationVerificationError::MissingForm), None => return Err(AuthorizationVerificationError::MissingForm),
@ -96,7 +93,7 @@ impl<F: Send> UserAuthorization<F> {
pub async fn protected( pub async fn protected(
self, self,
conn: &mut PgConnection, conn: &mut PgConnection,
) -> Result<Session<PostgresqlBackend>, AuthorizationVerificationError> { ) -> Result<Session, AuthorizationVerificationError> {
let (_token, session) = self.access_token.fetch(conn).await?; let (_token, session) = self.access_token.fetch(conn).await?;
Ok(session) Ok(session)

View File

@ -26,7 +26,6 @@
pub(crate) mod compat; pub(crate) mod compat;
pub(crate) mod oauth2; pub(crate) mod oauth2;
pub(crate) mod tokens; pub(crate) mod tokens;
pub(crate) mod traits;
pub(crate) mod upstream_oauth2; pub(crate) mod upstream_oauth2;
pub(crate) mod users; pub(crate) mod users;
@ -40,7 +39,6 @@ pub use self::{
InvalidRedirectUriError, JwksOrJwksUri, Pkce, Session, InvalidRedirectUriError, JwksOrJwksUri, Pkce, Session,
}, },
tokens::{AccessToken, RefreshToken, TokenFormatError, TokenType}, tokens::{AccessToken, RefreshToken, TokenFormatError, TokenType},
traits::{StorageBackend, StorageBackendMarker},
upstream_oauth2::{ upstream_oauth2::{
UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider, UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider,
}, },

View File

@ -22,10 +22,10 @@ use oauth2_types::{
}; };
use serde::Serialize; use serde::Serialize;
use thiserror::Error; use thiserror::Error;
use ulid::Ulid;
use url::Url; use url::Url;
use super::{client::Client, session::Session}; use super::{client::Client, session::Session};
use crate::{traits::StorageBackend, StorageBackendMarker};
#[derive(Debug, Clone, PartialEq, Eq, Serialize)] #[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct Pkce { pub struct Pkce {
@ -57,16 +57,17 @@ pub struct AuthorizationCode {
#[error("invalid state transition")] #[error("invalid state transition")]
pub struct InvalidTransitionError; pub struct InvalidTransitionError;
#[derive(Debug, Clone, PartialEq, Serialize)] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Default)]
#[serde(bound = "T: StorageBackend", tag = "stage", rename_all = "lowercase")] #[serde(tag = "stage", rename_all = "lowercase")]
pub enum AuthorizationGrantStage<T: StorageBackend> { pub enum AuthorizationGrantStage {
#[default]
Pending, Pending,
Fulfilled { Fulfilled {
session: Session<T>, session: Session,
fulfilled_at: DateTime<Utc>, fulfilled_at: DateTime<Utc>,
}, },
Exchanged { Exchanged {
session: Session<T>, session: Session,
fulfilled_at: DateTime<Utc>, fulfilled_at: DateTime<Utc>,
exchanged_at: DateTime<Utc>, exchanged_at: DateTime<Utc>,
}, },
@ -75,13 +76,7 @@ pub enum AuthorizationGrantStage<T: StorageBackend> {
}, },
} }
impl<T: StorageBackend> Default for AuthorizationGrantStage<T> { impl AuthorizationGrantStage {
fn default() -> Self {
Self::Pending
}
}
impl<T: StorageBackend> AuthorizationGrantStage<T> {
#[must_use] #[must_use]
pub fn new() -> Self { pub fn new() -> Self {
Self::Pending Self::Pending
@ -90,7 +85,7 @@ impl<T: StorageBackend> AuthorizationGrantStage<T> {
pub fn fulfill( pub fn fulfill(
self, self,
fulfilled_at: DateTime<Utc>, fulfilled_at: DateTime<Utc>,
session: Session<T>, session: Session,
) -> Result<Self, InvalidTransitionError> { ) -> Result<Self, InvalidTransitionError> {
match self { match self {
Self::Pending => Ok(Self::Fulfilled { Self::Pending => Ok(Self::Fulfilled {
@ -131,39 +126,11 @@ impl<T: StorageBackend> AuthorizationGrantStage<T> {
} }
} }
impl<S: StorageBackendMarker> From<AuthorizationGrantStage<S>> for AuthorizationGrantStage<()> { #[derive(Debug, Clone, PartialEq, Eq, Serialize)]
fn from(s: AuthorizationGrantStage<S>) -> Self { pub struct AuthorizationGrant {
use AuthorizationGrantStage::{Cancelled, Exchanged, Fulfilled, Pending}; pub id: Ulid,
match s {
Pending => Pending,
Fulfilled {
session,
fulfilled_at,
} => Fulfilled {
session: session.into(),
fulfilled_at,
},
Exchanged {
session,
fulfilled_at,
exchanged_at,
} => Exchanged {
session: session.into(),
fulfilled_at,
exchanged_at,
},
Cancelled { cancelled_at } => Cancelled { cancelled_at },
}
}
}
#[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(bound = "T: StorageBackend")]
pub struct AuthorizationGrant<T: StorageBackend> {
#[serde(skip_serializing)]
pub data: T::AuthorizationGrantData,
#[serde(flatten)] #[serde(flatten)]
pub stage: AuthorizationGrantStage<T>, pub stage: AuthorizationGrantStage,
pub code: Option<AuthorizationCode>, pub code: Option<AuthorizationCode>,
pub client: Client, pub client: Client,
pub redirect_uri: Url, pub redirect_uri: Url,
@ -177,27 +144,8 @@ pub struct AuthorizationGrant<T: StorageBackend> {
pub requires_consent: bool, pub requires_consent: bool,
} }
impl<S: StorageBackendMarker> From<AuthorizationGrant<S>> for AuthorizationGrant<()> { impl AuthorizationGrant {
fn from(g: AuthorizationGrant<S>) -> Self { #[must_use]
AuthorizationGrant {
data: (),
stage: g.stage.into(),
code: g.code,
client: g.client,
redirect_uri: g.redirect_uri,
scope: g.scope,
state: g.state,
nonce: g.nonce,
max_age: g.max_age,
response_mode: g.response_mode,
response_type_id_token: g.response_type_id_token,
created_at: g.created_at,
requires_consent: g.requires_consent,
}
}
}
impl<T: StorageBackend> AuthorizationGrant<T> {
pub fn max_auth_time(&self) -> DateTime<Utc> { pub fn max_auth_time(&self) -> DateTime<Utc> {
let max_age: Option<i64> = self.max_age.map(|x| x.get().into()); let max_age: Option<i64> = self.max_age.map(|x| x.get().into());
self.created_at - Duration::seconds(max_age.unwrap_or(3600 * 24 * 365)) self.created_at - Duration::seconds(max_age.unwrap_or(3600 * 24 * 365))

View File

@ -14,30 +14,15 @@
use oauth2_types::scope::Scope; use oauth2_types::scope::Scope;
use serde::Serialize; use serde::Serialize;
use ulid::Ulid;
use super::client::Client; use super::client::Client;
use crate::{ use crate::users::BrowserSession;
traits::{StorageBackend, StorageBackendMarker},
users::BrowserSession,
};
#[derive(Debug, Clone, PartialEq, Serialize)] #[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(bound = "T: StorageBackend")] pub struct Session {
pub struct Session<T: StorageBackend> { pub id: Ulid,
#[serde(skip_serializing)]
pub data: T::SessionData,
pub browser_session: BrowserSession, pub browser_session: BrowserSession,
pub client: Client, pub client: Client,
pub scope: Scope, pub scope: Scope,
} }
impl<S: StorageBackendMarker> From<Session<S>> for Session<()> {
fn from(s: Session<S>) -> Self {
Session {
data: (),
browser_session: s.browser_session,
client: s.client,
scope: s.scope,
}
}
}

View File

@ -1,42 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::fmt::Debug;
use serde::{de::DeserializeOwned, Serialize};
pub trait StorageBackendMarker: StorageBackend {}
/// Marker trait of traits that should be implemented by primary keys
pub trait Data:
Clone + Debug + PartialEq + Serialize + DeserializeOwned + Default + Sync + Send
{
}
impl<T: Clone + Debug + PartialEq + Serialize + DeserializeOwned + Default + Sync + Send> Data
for T
{
}
pub trait StorageBackend {
type ClientData: Data;
type SessionData: Data;
type AuthorizationGrantData: Data;
}
impl StorageBackend for () {
type AuthorizationGrantData = ();
type ClientData = ();
type SessionData = ();
}

View File

@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
use async_graphql::{Context, Description, Object, ID}; use async_graphql::{Context, Description, Object, ID};
use mas_storage::{oauth2::client::lookup_client, PostgresqlBackend}; use mas_storage::oauth2::client::lookup_client;
use oauth2_types::scope::Scope; use oauth2_types::scope::Scope;
use sqlx::PgPool; use sqlx::PgPool;
use ulid::Ulid; use ulid::Ulid;
@ -24,13 +24,13 @@ use super::{BrowserSession, NodeType, User};
/// An OAuth 2.0 session represents a client session which used the OAuth APIs /// An OAuth 2.0 session represents a client session which used the OAuth APIs
/// to login. /// to login.
#[derive(Description)] #[derive(Description)]
pub struct OAuth2Session(pub mas_data_model::Session<PostgresqlBackend>); pub struct OAuth2Session(pub mas_data_model::Session);
#[Object(use_type_description)] #[Object(use_type_description)]
impl OAuth2Session { impl OAuth2Session {
/// ID of the object. /// ID of the object.
pub async fn id(&self) -> ID { pub async fn id(&self) -> ID {
NodeType::OAuth2Session.id(self.0.data) NodeType::OAuth2Session.id(self.0.id)
} }
/// OAuth 2.0 client used by this session. /// OAuth 2.0 client used by this session.

View File

@ -242,7 +242,7 @@ impl User {
let mut connection = Connection::new(has_previous_page, has_next_page); let mut connection = Connection::new(has_previous_page, has_next_page);
connection.edges.extend(edges.into_iter().map(|s| { connection.edges.extend(edges.into_iter().map(|s| {
Edge::new( Edge::new(
OpaqueCursor(NodeCursor(NodeType::OAuth2Session, s.data)), OpaqueCursor(NodeCursor(NodeType::OAuth2Session, s.id)),
OAuth2Session(s), OAuth2Session(s),
) )
})); }));

View File

@ -17,7 +17,7 @@
use std::collections::HashMap; use std::collections::HashMap;
use axum::response::{Html, IntoResponse, Redirect, Response}; use axum::response::{Html, IntoResponse, Redirect, Response};
use mas_data_model::{AuthorizationGrant, StorageBackend}; use mas_data_model::AuthorizationGrant;
use mas_templates::{FormPostContext, Templates}; use mas_templates::{FormPostContext, Templates};
use oauth2_types::requests::ResponseMode; use oauth2_types::requests::ResponseMode;
use serde::Serialize; use serde::Serialize;
@ -61,10 +61,10 @@ pub enum CallbackDestinationError {
ParamsSerialization(#[from] serde_urlencoded::ser::Error), ParamsSerialization(#[from] serde_urlencoded::ser::Error),
} }
impl<S: StorageBackend> TryFrom<&AuthorizationGrant<S>> for CallbackDestination { impl TryFrom<&AuthorizationGrant> for CallbackDestination {
type Error = IntoCallbackDestinationError; type Error = IntoCallbackDestinationError;
fn try_from(value: &AuthorizationGrant<S>) -> Result<Self, Self::Error> { fn try_from(value: &AuthorizationGrant) -> Result<Self, Self::Error> {
Self::try_new( Self::try_new(
&value.response_mode, &value.response_mode,
value.redirect_uri.clone(), value.redirect_uri.clone(),

View File

@ -32,7 +32,6 @@ use mas_storage::{
consent::fetch_client_consent, consent::fetch_client_consent,
}, },
user::ActiveSessionLookupError, user::ActiveSessionLookupError,
PostgresqlBackend,
}; };
use mas_templates::Templates; use mas_templates::Templates;
use oauth2_types::requests::{AccessTokenResponse, AuthorizationResponse}; use oauth2_types::requests::{AccessTokenResponse, AuthorizationResponse};
@ -185,7 +184,7 @@ impl From<IntoCallbackDestinationError> for GrantCompletionError {
} }
pub(crate) async fn complete( pub(crate) async fn complete(
grant: AuthorizationGrant<PostgresqlBackend>, grant: AuthorizationGrant,
browser_session: BrowserSession, browser_session: BrowserSession,
policy_factory: &PolicyFactory, policy_factory: &PolicyFactory,
mut txn: Transaction<'_, Postgres>, mut txn: Transaction<'_, Postgres>,

View File

@ -315,7 +315,7 @@ pub(crate) async fn get(
requires_consent, requires_consent,
) )
.await?; .await?;
let continue_grant = PostAuthAction::continue_grant(grant.data); let continue_grant = PostAuthAction::continue_grant(grant.id);
let res = match maybe_session { let res = match maybe_session {
// Cases where there is no active session, redirect to the relevant page // Cases where there is no active session, redirect to the relevant page
@ -391,7 +391,7 @@ pub(crate) async fn get(
} }
} }
Some(user_session) => { Some(user_session) => {
let grant_id = grant.data; let grant_id = grant.id;
// Else, we show the relevant reauth/consent page if necessary // Else, we show the relevant reauth/consent page if necessary
match self::complete::complete(grant, user_session, &policy_factory, txn).await match self::complete::complete(grant, user_session, &policy_factory, txn).await
{ {

View File

@ -45,7 +45,7 @@ impl OptionalPostAuthAction {
let ctx = match action { let ctx = match action {
PostAuthAction::ContinueAuthorizationGrant { data } => { PostAuthAction::ContinueAuthorizationGrant { data } => {
let grant = get_grant_by_id(conn, data).await?; let grant = get_grant_by_id(conn, data).await?;
let grant = Box::new(grant.into()); let grant = Box::new(grant);
PostAuthContextInner::ContinueAuthorizationGrant { grant } PostAuthContextInner::ContinueAuthorizationGrant { grant }
} }

View File

@ -18,7 +18,7 @@
#![allow(clippy::missing_errors_doc)] #![allow(clippy::missing_errors_doc)]
use anyhow::bail; use anyhow::bail;
use mas_data_model::{AuthorizationGrant, StorageBackend, User}; use mas_data_model::{AuthorizationGrant, User};
use oauth2_types::registration::VerifiedClientMetadata; use oauth2_types::registration::VerifiedClientMetadata;
use opa_wasm::Runtime; use opa_wasm::Runtime;
use serde::Deserialize; use serde::Deserialize;
@ -210,9 +210,9 @@ impl Policy {
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub async fn evaluate_authorization_grant<T: StorageBackend + std::fmt::Debug>( pub async fn evaluate_authorization_grant(
&mut self, &mut self,
authorization_grant: &AuthorizationGrant<T>, authorization_grant: &AuthorizationGrant,
user: &User, user: &User,
) -> Result<EvaluationResult, anyhow::Error> { ) -> Result<EvaluationResult, anyhow::Error> {
let authorization_grant = serde_json::to_value(authorization_grant)?; let authorization_grant = serde_json::to_value(authorization_grant)?;

View File

@ -29,11 +29,8 @@
)] )]
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use mas_data_model::{StorageBackend, StorageBackendMarker};
use serde::Serialize;
use sqlx::migrate::Migrator; use sqlx::migrate::Migrator;
use thiserror::Error; use thiserror::Error;
use ulid::Ulid;
#[derive(Debug, Error)] #[derive(Debug, Error)]
#[error("failed to lookup {what}")] #[error("failed to lookup {what}")]
@ -101,17 +98,6 @@ impl Clock {
#[error("database query returned an inconsistent state")] #[error("database query returned an inconsistent state")]
pub struct DatabaseInconsistencyError; pub struct DatabaseInconsistencyError;
#[derive(Serialize, Debug, Clone, PartialEq, Eq)]
pub struct PostgresqlBackend;
impl StorageBackend for PostgresqlBackend {
type AuthorizationGrantData = Ulid;
type ClientData = Ulid;
type SessionData = Ulid;
}
impl StorageBackendMarker for PostgresqlBackend {}
pub mod compat; pub mod compat;
pub mod oauth2; pub mod oauth2;
pub(crate) mod pagination; pub(crate) mod pagination;

View File

@ -22,12 +22,12 @@ use ulid::Ulid;
use uuid::Uuid; use uuid::Uuid;
use super::client::{lookup_client, ClientFetchError}; use super::client::{lookup_client, ClientFetchError};
use crate::{Clock, DatabaseInconsistencyError, LookupError, PostgresqlBackend}; use crate::{Clock, DatabaseInconsistencyError, LookupError};
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
fields( fields(
session.id = %session.data, %session.id,
client.id = %session.client.id, client.id = %session.client.id,
user.id = %session.browser_session.user.id, user.id = %session.browser_session.user.id,
access_token.id, access_token.id,
@ -38,7 +38,7 @@ pub async fn add_access_token(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send, mut rng: impl Rng + Send,
clock: &Clock, clock: &Clock,
session: &Session<PostgresqlBackend>, session: &Session,
access_token: String, access_token: String,
expires_after: Duration, expires_after: Duration,
) -> Result<AccessToken, anyhow::Error> { ) -> Result<AccessToken, anyhow::Error> {
@ -56,7 +56,7 @@ pub async fn add_access_token(
($1, $2, $3, $4, $5) ($1, $2, $3, $4, $5)
"#, "#,
Uuid::from(id), Uuid::from(id),
Uuid::from(session.data), Uuid::from(session.id),
&access_token, &access_token,
created_at, created_at,
expires_at, expires_at,
@ -113,7 +113,7 @@ impl LookupError for AccessTokenLookupError {
pub async fn lookup_active_access_token( pub async fn lookup_active_access_token(
conn: &mut PgConnection, conn: &mut PgConnection,
token: &str, token: &str,
) -> Result<(AccessToken, Session<PostgresqlBackend>), AccessTokenLookupError> { ) -> Result<(AccessToken, Session), AccessTokenLookupError> {
let res = sqlx::query_as!( let res = sqlx::query_as!(
OAuth2AccessTokenLookup, OAuth2AccessTokenLookup,
r#" r#"
@ -217,7 +217,7 @@ pub async fn lookup_active_access_token(
let scope = res.scope.parse().map_err(|_e| DatabaseInconsistencyError)?; let scope = res.scope.parse().map_err(|_e| DatabaseInconsistencyError)?;
let session = Session { let session = Session {
data: res.oauth2_session_id.into(), id: res.oauth2_session_id.into(),
client, client,
browser_session, browser_session,
scope, scope,

View File

@ -31,7 +31,7 @@ use url::Url;
use uuid::Uuid; use uuid::Uuid;
use super::client::lookup_client; use super::client::lookup_client;
use crate::{Clock, DatabaseInconsistencyError, PostgresqlBackend}; use crate::{Clock, DatabaseInconsistencyError};
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
@ -57,7 +57,7 @@ pub async fn new_authorization_grant(
response_mode: ResponseMode, response_mode: ResponseMode,
response_type_id_token: bool, response_type_id_token: bool,
requires_consent: bool, requires_consent: bool,
) -> Result<AuthorizationGrant<PostgresqlBackend>, anyhow::Error> { ) -> Result<AuthorizationGrant, anyhow::Error> {
let code_challenge = code let code_challenge = code
.as_ref() .as_ref()
.and_then(|c| c.pkce.as_ref()) .and_then(|c| c.pkce.as_ref())
@ -117,7 +117,7 @@ pub async fn new_authorization_grant(
.context("could not insert oauth2 authorization grant")?; .context("could not insert oauth2 authorization grant")?;
Ok(AuthorizationGrant { Ok(AuthorizationGrant {
data: id, id,
stage: AuthorizationGrantStage::Pending, stage: AuthorizationGrantStage::Pending,
code, code,
redirect_uri, redirect_uri,
@ -171,7 +171,7 @@ impl GrantLookup {
async fn into_authorization_grant( async fn into_authorization_grant(
self, self,
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
) -> Result<AuthorizationGrant<PostgresqlBackend>, DatabaseInconsistencyError> { ) -> Result<AuthorizationGrant, DatabaseInconsistencyError> {
let scope: Scope = self let scope: Scope = self
.oauth2_authorization_grant_scope .oauth2_authorization_grant_scope
.parse() .parse()
@ -247,7 +247,7 @@ impl GrantLookup {
let scope = scope.clone(); let scope = scope.clone();
let session = Session { let session = Session {
data: session_id.into(), id: session_id.into(),
client, client,
browser_session, browser_session,
scope, scope,
@ -337,7 +337,7 @@ impl GrantLookup {
.map_err(|_e| DatabaseInconsistencyError)?; .map_err(|_e| DatabaseInconsistencyError)?;
Ok(AuthorizationGrant { Ok(AuthorizationGrant {
data: self.oauth2_authorization_grant_id.into(), id: self.oauth2_authorization_grant_id.into(),
stage, stage,
client, client,
code, code,
@ -362,7 +362,7 @@ impl GrantLookup {
pub async fn get_grant_by_id( pub async fn get_grant_by_id(
conn: &mut PgConnection, conn: &mut PgConnection,
id: Ulid, id: Ulid,
) -> Result<AuthorizationGrant<PostgresqlBackend>, anyhow::Error> { ) -> Result<AuthorizationGrant, anyhow::Error> {
// TODO: handle "not found" cases // TODO: handle "not found" cases
let res = sqlx::query_as!( let res = sqlx::query_as!(
GrantLookup, GrantLookup,
@ -430,7 +430,7 @@ pub async fn get_grant_by_id(
pub async fn lookup_grant_by_code( pub async fn lookup_grant_by_code(
conn: &mut PgConnection, conn: &mut PgConnection,
code: &str, code: &str,
) -> Result<AuthorizationGrant<PostgresqlBackend>, anyhow::Error> { ) -> Result<AuthorizationGrant, anyhow::Error> {
// TODO: handle "not found" cases // TODO: handle "not found" cases
let res = sqlx::query_as!( let res = sqlx::query_as!(
GrantLookup, GrantLookup,
@ -497,7 +497,7 @@ pub async fn lookup_grant_by_code(
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
fields( fields(
grant.id = %grant.data, %grant.id,
client.id = %grant.client.id, client.id = %grant.client.id,
session.id, session.id,
user_session.id = %browser_session.id, user_session.id = %browser_session.id,
@ -509,9 +509,9 @@ pub async fn derive_session(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send, mut rng: impl Rng + Send,
clock: &Clock, clock: &Clock,
grant: &AuthorizationGrant<PostgresqlBackend>, grant: &AuthorizationGrant,
browser_session: BrowserSession, browser_session: BrowserSession,
) -> Result<Session<PostgresqlBackend>, anyhow::Error> { ) -> Result<Session, anyhow::Error> {
let created_at = clock.now(); let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("session.id", tracing::field::display(id)); tracing::Span::current().record("session.id", tracing::field::display(id));
@ -534,14 +534,14 @@ pub async fn derive_session(
Uuid::from(id), Uuid::from(id),
Uuid::from(browser_session.id), Uuid::from(browser_session.id),
created_at, created_at,
Uuid::from(grant.data), Uuid::from(grant.id),
) )
.execute(executor) .execute(executor)
.await .await
.context("could not insert oauth2 session")?; .context("could not insert oauth2 session")?;
Ok(Session { Ok(Session {
data: id, id,
browser_session, browser_session,
client: grant.client.clone(), client: grant.client.clone(),
scope: grant.scope.clone(), scope: grant.scope.clone(),
@ -551,9 +551,9 @@ pub async fn derive_session(
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
fields( fields(
grant.id = %grant.data, %grant.id,
client.id = %grant.client.id, client.id = %grant.client.id,
session.id = %session.data, %session.id,
user_session.id = %session.browser_session.id, user_session.id = %session.browser_session.id,
user.id = %session.browser_session.user.id, user.id = %session.browser_session.user.id,
), ),
@ -561,9 +561,9 @@ pub async fn derive_session(
)] )]
pub async fn fulfill_grant( pub async fn fulfill_grant(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
mut grant: AuthorizationGrant<PostgresqlBackend>, mut grant: AuthorizationGrant,
session: Session<PostgresqlBackend>, session: Session,
) -> Result<AuthorizationGrant<PostgresqlBackend>, anyhow::Error> { ) -> Result<AuthorizationGrant, anyhow::Error> {
let fulfilled_at = sqlx::query_scalar!( let fulfilled_at = sqlx::query_scalar!(
r#" r#"
UPDATE oauth2_authorization_grants AS og UPDATE oauth2_authorization_grants AS og
@ -576,8 +576,8 @@ pub async fn fulfill_grant(
AND os.oauth2_session_id = $2 AND os.oauth2_session_id = $2
RETURNING fulfilled_at AS "fulfilled_at!: DateTime<Utc>" RETURNING fulfilled_at AS "fulfilled_at!: DateTime<Utc>"
"#, "#,
Uuid::from(grant.data), Uuid::from(grant.id),
Uuid::from(session.data), Uuid::from(session.id),
) )
.fetch_one(executor) .fetch_one(executor)
.await .await
@ -591,15 +591,15 @@ pub async fn fulfill_grant(
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
fields( fields(
grant.id = %grant.data, %grant.id,
client.id = %grant.client.id, client.id = %grant.client.id,
), ),
err(Debug), err(Debug),
)] )]
pub async fn give_consent_to_grant( pub async fn give_consent_to_grant(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
mut grant: AuthorizationGrant<PostgresqlBackend>, mut grant: AuthorizationGrant,
) -> Result<AuthorizationGrant<PostgresqlBackend>, sqlx::Error> { ) -> Result<AuthorizationGrant, sqlx::Error> {
sqlx::query!( sqlx::query!(
r#" r#"
UPDATE oauth2_authorization_grants AS og UPDATE oauth2_authorization_grants AS og
@ -608,7 +608,7 @@ pub async fn give_consent_to_grant(
WHERE WHERE
og.oauth2_authorization_grant_id = $1 og.oauth2_authorization_grant_id = $1
"#, "#,
Uuid::from(grant.data), Uuid::from(grant.id),
) )
.execute(executor) .execute(executor)
.await?; .await?;
@ -621,7 +621,7 @@ pub async fn give_consent_to_grant(
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
fields( fields(
grant.id = %grant.data, %grant.id,
client.id = %grant.client.id, client.id = %grant.client.id,
), ),
err(Debug), err(Debug),
@ -629,8 +629,8 @@ pub async fn give_consent_to_grant(
pub async fn exchange_grant( pub async fn exchange_grant(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
clock: &Clock, clock: &Clock,
mut grant: AuthorizationGrant<PostgresqlBackend>, mut grant: AuthorizationGrant,
) -> Result<AuthorizationGrant<PostgresqlBackend>, anyhow::Error> { ) -> Result<AuthorizationGrant, anyhow::Error> {
let exchanged_at = clock.now(); let exchanged_at = clock.now();
sqlx::query!( sqlx::query!(
r#" r#"
@ -638,7 +638,7 @@ pub async fn exchange_grant(
SET exchanged_at = $2 SET exchanged_at = $2
WHERE oauth2_authorization_grant_id = $1 WHERE oauth2_authorization_grant_id = $1
"#, "#,
Uuid::from(grant.data), Uuid::from(grant.id),
exchanged_at, exchanged_at,
) )
.execute(executor) .execute(executor)

View File

@ -25,7 +25,7 @@ use self::client::lookup_clients;
use crate::{ use crate::{
pagination::{process_page, QueryBuilderExt}, pagination::{process_page, QueryBuilderExt},
user::lookup_active_session, user::lookup_active_session,
Clock, PostgresqlBackend, Clock,
}; };
pub mod access_token; pub mod access_token;
@ -37,7 +37,7 @@ pub mod refresh_token;
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
fields( fields(
session.id = %session.data, %session.id,
user.id = %session.browser_session.user.id, user.id = %session.browser_session.user.id,
user_session.id = %session.browser_session.id, user_session.id = %session.browser_session.id,
client.id = %session.client.id, client.id = %session.client.id,
@ -47,7 +47,7 @@ pub mod refresh_token;
pub async fn end_oauth_session( pub async fn end_oauth_session(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
clock: &Clock, clock: &Clock,
session: Session<PostgresqlBackend>, session: Session,
) -> Result<(), anyhow::Error> { ) -> Result<(), anyhow::Error> {
let finished_at = clock.now(); let finished_at = clock.now();
let res = sqlx::query!( let res = sqlx::query!(
@ -56,7 +56,7 @@ pub async fn end_oauth_session(
SET finished_at = $2 SET finished_at = $2
WHERE oauth2_session_id = $1 WHERE oauth2_session_id = $1
"#, "#,
Uuid::from(session.data), Uuid::from(session.id),
finished_at, finished_at,
) )
.execute(executor) .execute(executor)
@ -79,7 +79,7 @@ struct OAuthSessionLookup {
skip_all, skip_all,
fields( fields(
%user.id, %user.id,
user.username = user.username, %user.username,
), ),
err(Display), err(Display),
)] )]
@ -90,7 +90,7 @@ pub async fn get_paginated_user_oauth_sessions(
after: Option<Ulid>, after: Option<Ulid>,
first: Option<usize>, first: Option<usize>,
last: Option<usize>, last: Option<usize>,
) -> Result<(bool, bool, Vec<Session<PostgresqlBackend>>), anyhow::Error> { ) -> Result<(bool, bool, Vec<Session>), anyhow::Error> {
let mut query = QueryBuilder::new( let mut query = QueryBuilder::new(
r#" r#"
SELECT SELECT
@ -157,7 +157,7 @@ pub async fn get_paginated_user_oauth_sessions(
let scope = item.scope.parse()?; let scope = item.scope.parse()?;
anyhow::Ok(Session { anyhow::Ok(Session {
data: Ulid::from(item.oauth2_session_id), id: Ulid::from(item.oauth2_session_id),
client, client,
browser_session, browser_session,
scope, scope,

View File

@ -24,12 +24,12 @@ use ulid::Ulid;
use uuid::Uuid; use uuid::Uuid;
use super::client::{lookup_client, ClientFetchError}; use super::client::{lookup_client, ClientFetchError};
use crate::{Clock, DatabaseInconsistencyError, LookupError, PostgresqlBackend}; use crate::{Clock, DatabaseInconsistencyError, LookupError};
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
fields( fields(
session.id = %session.data, %session.id,
user.id = %session.browser_session.user.id, user.id = %session.browser_session.user.id,
user_session.id = %session.browser_session.id, user_session.id = %session.browser_session.id,
client.id = %session.client.id, client.id = %session.client.id,
@ -41,7 +41,7 @@ pub async fn add_refresh_token(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send, mut rng: impl Rng + Send,
clock: &Clock, clock: &Clock,
session: &Session<PostgresqlBackend>, session: &Session,
access_token: AccessToken, access_token: AccessToken,
refresh_token: String, refresh_token: String,
) -> anyhow::Result<RefreshToken> { ) -> anyhow::Result<RefreshToken> {
@ -58,7 +58,7 @@ pub async fn add_refresh_token(
($1, $2, $3, $4, $5) ($1, $2, $3, $4, $5)
"#, "#,
Uuid::from(id), Uuid::from(id),
Uuid::from(session.data), Uuid::from(session.id),
Uuid::from(access_token.id), Uuid::from(access_token.id),
refresh_token, refresh_token,
created_at, created_at,
@ -117,7 +117,7 @@ impl LookupError for RefreshTokenLookupError {
pub async fn lookup_active_refresh_token( pub async fn lookup_active_refresh_token(
conn: &mut PgConnection, conn: &mut PgConnection,
token: &str, token: &str,
) -> Result<(RefreshToken, Session<PostgresqlBackend>), RefreshTokenLookupError> { ) -> Result<(RefreshToken, Session), RefreshTokenLookupError> {
let res = sqlx::query_as!( let res = sqlx::query_as!(
OAuth2RefreshTokenLookup, OAuth2RefreshTokenLookup,
r#" r#"
@ -248,7 +248,7 @@ pub async fn lookup_active_refresh_token(
.map_err(|_e| DatabaseInconsistencyError)?; .map_err(|_e| DatabaseInconsistencyError)?;
let session = Session { let session = Session {
data: res.oauth2_session_id.into(), id: res.oauth2_session_id.into(),
client, client,
browser_session, browser_session,
scope, scope,

View File

@ -249,7 +249,7 @@ pub enum PostAuthContextInner {
/// Continue an authorization grant /// Continue an authorization grant
ContinueAuthorizationGrant { ContinueAuthorizationGrant {
/// The authorization grant that will be continued after authentication /// The authorization grant that will be continued after authentication
grant: Box<AuthorizationGrant<()>>, grant: Box<AuthorizationGrant>,
}, },
/// Continue legacy login /// Continue legacy login
@ -394,7 +394,7 @@ impl RegisterContext {
/// Context used by the `consent.html` template /// Context used by the `consent.html` template
#[derive(Serialize)] #[derive(Serialize)]
pub struct ConsentContext { pub struct ConsentContext {
grant: AuthorizationGrant<()>, grant: AuthorizationGrant,
action: PostAuthAction, action: PostAuthAction,
} }
@ -411,21 +411,15 @@ impl TemplateContext for ConsentContext {
impl ConsentContext { impl ConsentContext {
/// Constructs a context for the client consent page /// Constructs a context for the client consent page
#[must_use] #[must_use]
pub fn new<T>(grant: T, action: PostAuthAction) -> Self pub fn new(grant: AuthorizationGrant, action: PostAuthAction) -> Self {
where Self { grant, action }
T: Into<AuthorizationGrant<()>>,
{
Self {
grant: grant.into(),
action,
}
} }
} }
/// Context used by the `policy_violation.html` template /// Context used by the `policy_violation.html` template
#[derive(Serialize)] #[derive(Serialize)]
pub struct PolicyViolationContext { pub struct PolicyViolationContext {
grant: AuthorizationGrant<()>, grant: AuthorizationGrant,
action: PostAuthAction, action: PostAuthAction,
} }
@ -442,14 +436,8 @@ impl TemplateContext for PolicyViolationContext {
impl PolicyViolationContext { impl PolicyViolationContext {
/// Constructs a context for the policy violation page /// Constructs a context for the policy violation page
#[must_use] #[must_use]
pub fn new<T>(grant: T, action: PostAuthAction) -> Self pub fn new(grant: AuthorizationGrant, action: PostAuthAction) -> Self {
where Self { grant, action }
T: Into<AuthorizationGrant<()>>,
{
Self {
grant: grant.into(),
action,
}
} }
} }