You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-08-09 04:22:45 +03:00
storage: oauth2 session repository
This commit is contained in:
@@ -134,6 +134,7 @@ pub async fn lookup_active_access_token(
|
||||
client_id: res.oauth2_client_id.into(),
|
||||
user_session_id: res.user_session_id.into(),
|
||||
scope,
|
||||
finished_at: None,
|
||||
};
|
||||
|
||||
Ok(Some((access_token, session)))
|
||||
|
@@ -16,8 +16,7 @@ use std::num::NonZeroU32;
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{
|
||||
AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, BrowserSession, Client, Pkce,
|
||||
Session,
|
||||
AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Pkce, Session,
|
||||
};
|
||||
use mas_iana::oauth::PkceCodeChallengeMethod;
|
||||
use oauth2_types::{requests::ResponseMode, scope::Scope};
|
||||
@@ -27,7 +26,7 @@ use ulid::Ulid;
|
||||
use url::Url;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::client::OAuth2ClientRepository;
|
||||
use super::OAuth2ClientRepository;
|
||||
use crate::{Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Repository};
|
||||
|
||||
#[tracing::instrument(
|
||||
@@ -186,6 +185,7 @@ impl GrantLookup {
|
||||
client_id: client.id,
|
||||
user_session_id: user_session_id.into(),
|
||||
scope,
|
||||
finished_at: None,
|
||||
};
|
||||
|
||||
Some(session)
|
||||
@@ -431,59 +431,6 @@ pub async fn lookup_grant_by_code(
|
||||
Ok(Some(grant))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
%grant.id,
|
||||
client.id = %grant.client.id,
|
||||
session.id,
|
||||
user_session.id = %browser_session.id,
|
||||
user.id = %browser_session.user.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn derive_session(
|
||||
executor: impl PgExecutor<'_>,
|
||||
mut rng: impl Rng + Send,
|
||||
clock: &Clock,
|
||||
grant: &AuthorizationGrant,
|
||||
browser_session: BrowserSession,
|
||||
) -> Result<Session, sqlx::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||
tracing::Span::current().record("session.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO oauth2_sessions
|
||||
(oauth2_session_id, user_session_id, oauth2_client_id, scope, created_at)
|
||||
SELECT
|
||||
$1,
|
||||
$2,
|
||||
og.oauth2_client_id,
|
||||
og.scope,
|
||||
$3
|
||||
FROM
|
||||
oauth2_authorization_grants og
|
||||
WHERE
|
||||
og.oauth2_authorization_grant_id = $4
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(browser_session.id),
|
||||
created_at,
|
||||
Uuid::from(grant.id),
|
||||
)
|
||||
.execute(executor)
|
||||
.await?;
|
||||
|
||||
Ok(Session {
|
||||
id,
|
||||
user_session_id: browser_session.id,
|
||||
client_id: grant.client.id,
|
||||
scope: grant.scope.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
|
@@ -12,129 +12,14 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use mas_data_model::{Session, User};
|
||||
use sqlx::{PgConnection, PgExecutor, QueryBuilder};
|
||||
use tracing::{info_span, Instrument};
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
pagination::{process_page, QueryBuilderExt},
|
||||
Clock, DatabaseError, DatabaseInconsistencyError,
|
||||
};
|
||||
|
||||
pub mod access_token;
|
||||
pub mod authorization_grant;
|
||||
pub mod client;
|
||||
mod client;
|
||||
pub mod consent;
|
||||
pub mod refresh_token;
|
||||
pub mod session;
|
||||
mod session;
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
%session.id,
|
||||
user_session.id = %session.user_session_id,
|
||||
client.id = %session.client_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn end_oauth_session(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
session: Session,
|
||||
) -> Result<(), DatabaseError> {
|
||||
let finished_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE oauth2_sessions
|
||||
SET finished_at = $2
|
||||
WHERE oauth2_session_id = $1
|
||||
"#,
|
||||
Uuid::from(session.id),
|
||||
finished_at,
|
||||
)
|
||||
.execute(executor)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)
|
||||
}
|
||||
|
||||
#[derive(sqlx::FromRow)]
|
||||
struct OAuthSessionLookup {
|
||||
oauth2_session_id: Uuid,
|
||||
user_session_id: Uuid,
|
||||
oauth2_client_id: Uuid,
|
||||
scope: String,
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
%user.id,
|
||||
%user.username,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn get_paginated_user_oauth_sessions(
|
||||
conn: &mut PgConnection,
|
||||
user: &User,
|
||||
before: Option<Ulid>,
|
||||
after: Option<Ulid>,
|
||||
first: Option<usize>,
|
||||
last: Option<usize>,
|
||||
) -> Result<(bool, bool, Vec<Session>), DatabaseError> {
|
||||
let mut query = QueryBuilder::new(
|
||||
r#"
|
||||
SELECT
|
||||
os.oauth2_session_id,
|
||||
os.user_session_id,
|
||||
os.oauth2_client_id,
|
||||
os.scope,
|
||||
os.created_at,
|
||||
os.finished_at
|
||||
FROM oauth2_sessions os
|
||||
LEFT JOIN user_sessions us
|
||||
USING (user_session_id)
|
||||
"#,
|
||||
);
|
||||
|
||||
query
|
||||
.push(" WHERE us.user_id = ")
|
||||
.push_bind(Uuid::from(user.id))
|
||||
.generate_pagination("oauth2_session_id", before, after, first, last)?;
|
||||
|
||||
let span = info_span!(
|
||||
"Fetch paginated user oauth sessions",
|
||||
db.statement = query.sql()
|
||||
);
|
||||
let page: Vec<OAuthSessionLookup> = query
|
||||
.build_query_as()
|
||||
.fetch_all(&mut *conn)
|
||||
.instrument(span)
|
||||
.await?;
|
||||
|
||||
let (has_previous_page, has_next_page, page) = process_page(page, first, last)?;
|
||||
|
||||
let page: Result<Vec<_>, DatabaseInconsistencyError> = page
|
||||
.into_iter()
|
||||
.map(|item| {
|
||||
let id = Ulid::from(item.oauth2_session_id);
|
||||
let scope = item.scope.parse().map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_sessions")
|
||||
.column("scope")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
Ok(Session {
|
||||
id: Ulid::from(item.oauth2_session_id),
|
||||
client_id: item.oauth2_client_id.into(),
|
||||
user_session_id: item.user_session_id.into(),
|
||||
scope,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok((has_previous_page, has_next_page, page?))
|
||||
}
|
||||
pub use self::{
|
||||
client::{OAuth2ClientRepository, PgOAuth2ClientRepository},
|
||||
session::{OAuth2SessionRepository, PgOAuth2SessionRepository},
|
||||
};
|
||||
|
@@ -158,6 +158,7 @@ pub async fn lookup_active_refresh_token(
|
||||
client_id: res.oauth2_client_id.into(),
|
||||
user_session_id: res.user_session_id.into(),
|
||||
scope,
|
||||
finished_at: None,
|
||||
};
|
||||
|
||||
Ok(Some((refresh_token, session)))
|
||||
|
@@ -13,8 +13,231 @@
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{AuthorizationGrant, BrowserSession, Session, User};
|
||||
use rand::RngCore;
|
||||
use sqlx::{PgConnection, QueryBuilder};
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
pagination::{process_page, Page, QueryBuilderExt},
|
||||
tracing::ExecuteExt,
|
||||
Clock, DatabaseError, DatabaseInconsistencyError,
|
||||
};
|
||||
|
||||
#[async_trait]
|
||||
pub trait OAuth2SessionRepository {
|
||||
type Error;
|
||||
|
||||
async fn create_from_grant(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
grant: &AuthorizationGrant,
|
||||
user_session: &BrowserSession,
|
||||
) -> Result<Session, Self::Error>;
|
||||
|
||||
async fn finish(&mut self, clock: &Clock, session: Session) -> Result<Session, Self::Error>;
|
||||
|
||||
async fn list_paginated(
|
||||
&mut self,
|
||||
user: &User,
|
||||
before: Option<Ulid>,
|
||||
after: Option<Ulid>,
|
||||
first: Option<usize>,
|
||||
last: Option<usize>,
|
||||
) -> Result<Page<Session>, Self::Error>;
|
||||
}
|
||||
|
||||
pub struct PgOAuth2SessionRepository<'c> {
|
||||
conn: &'c mut PgConnection,
|
||||
}
|
||||
|
||||
impl<'c> PgOAuth2SessionRepository<'c> {
|
||||
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
Self { conn }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(sqlx::FromRow)]
|
||||
struct OAuthSessionLookup {
|
||||
oauth2_session_id: Uuid,
|
||||
user_session_id: Uuid,
|
||||
oauth2_client_id: Uuid,
|
||||
scope: String,
|
||||
finished_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
impl TryFrom<OAuthSessionLookup> for Session {
|
||||
type Error = DatabaseInconsistencyError;
|
||||
|
||||
fn try_from(value: OAuthSessionLookup) -> Result<Self, Self::Error> {
|
||||
let id = Ulid::from(value.oauth2_session_id);
|
||||
let scope = value.scope.parse().map_err(|e| {
|
||||
DatabaseInconsistencyError::on("oauth2_sessions")
|
||||
.column("scope")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
Ok(Session {
|
||||
id,
|
||||
client_id: value.oauth2_client_id.into(),
|
||||
user_session_id: value.user_session_id.into(),
|
||||
scope,
|
||||
finished_at: value.finished_at,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_session.create_from_grant",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user_session.id,
|
||||
user.id = %user_session.user.id,
|
||||
%grant.id,
|
||||
client.id = %grant.client.id,
|
||||
session.id,
|
||||
session.scope = %grant.scope,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn create_from_grant(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
grant: &AuthorizationGrant,
|
||||
user_session: &BrowserSession,
|
||||
) -> Result<Session, Self::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record("session.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO oauth2_sessions
|
||||
( oauth2_session_id
|
||||
, user_session_id
|
||||
, oauth2_client_id
|
||||
, scope
|
||||
, created_at
|
||||
)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(user_session.id),
|
||||
Uuid::from(grant.client.id),
|
||||
grant.scope.to_string(),
|
||||
created_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(Session {
|
||||
id,
|
||||
user_session_id: user_session.id,
|
||||
client_id: grant.client.id,
|
||||
scope: grant.scope.clone(),
|
||||
finished_at: None,
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_session.finish",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%session.id,
|
||||
%session.scope,
|
||||
user_session.id = %session.user_session_id,
|
||||
client.id = %session.client_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn finish(
|
||||
&mut self,
|
||||
clock: &Clock,
|
||||
mut session: Session,
|
||||
) -> Result<Session, Self::Error> {
|
||||
let finished_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE oauth2_sessions
|
||||
SET finished_at = $2
|
||||
WHERE oauth2_session_id = $1
|
||||
"#,
|
||||
Uuid::from(session.id),
|
||||
finished_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
session.finished_at = Some(finished_at);
|
||||
|
||||
Ok(session)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_session.list_paginated",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user.id,
|
||||
%user.username,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn list_paginated(
|
||||
&mut self,
|
||||
user: &User,
|
||||
before: Option<Ulid>,
|
||||
after: Option<Ulid>,
|
||||
first: Option<usize>,
|
||||
last: Option<usize>,
|
||||
) -> Result<Page<Session>, Self::Error> {
|
||||
let mut query = QueryBuilder::new(
|
||||
r#"
|
||||
SELECT oauth2_session_id
|
||||
, user_session_id
|
||||
, oauth2_client_id
|
||||
, scope
|
||||
, created_at
|
||||
, finished_at
|
||||
FROM oauth2_sessions os
|
||||
"#,
|
||||
);
|
||||
|
||||
query
|
||||
.push(" WHERE us.user_id = ")
|
||||
.push_bind(Uuid::from(user.id))
|
||||
.generate_pagination("oauth2_session_id", before, after, first, last)?;
|
||||
|
||||
let edges: Vec<OAuthSessionLookup> = query
|
||||
.build_query_as()
|
||||
.traced()
|
||||
.fetch_all(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
let (has_previous_page, has_next_page, edges) = process_page(edges, first, last)?;
|
||||
|
||||
let edges: Result<Vec<_>, DatabaseInconsistencyError> =
|
||||
edges.into_iter().map(Session::try_from).collect();
|
||||
|
||||
Ok(Page {
|
||||
has_next_page,
|
||||
has_previous_page,
|
||||
edges: edges?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@@ -15,7 +15,7 @@
|
||||
use sqlx::{PgConnection, Postgres, Transaction};
|
||||
|
||||
use crate::{
|
||||
oauth2::client::PgOAuth2ClientRepository,
|
||||
oauth2::{PgOAuth2ClientRepository, PgOAuth2SessionRepository},
|
||||
upstream_oauth2::{
|
||||
PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository,
|
||||
PgUpstreamOAuthSessionRepository,
|
||||
@@ -59,6 +59,10 @@ pub trait Repository {
|
||||
where
|
||||
Self: 'c;
|
||||
|
||||
type OAuth2SessionRepository<'c>
|
||||
where
|
||||
Self: 'c;
|
||||
|
||||
fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_>;
|
||||
fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_>;
|
||||
fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_>;
|
||||
@@ -67,6 +71,7 @@ pub trait Repository {
|
||||
fn user_password(&mut self) -> Self::UserPasswordRepository<'_>;
|
||||
fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_>;
|
||||
fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_>;
|
||||
fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_>;
|
||||
}
|
||||
|
||||
impl Repository for PgConnection {
|
||||
@@ -78,6 +83,7 @@ impl Repository for PgConnection {
|
||||
type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c;
|
||||
type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c;
|
||||
type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c;
|
||||
type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c;
|
||||
|
||||
fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> {
|
||||
PgUpstreamOAuthLinkRepository::new(self)
|
||||
@@ -110,6 +116,10 @@ impl Repository for PgConnection {
|
||||
fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_> {
|
||||
PgOAuth2ClientRepository::new(self)
|
||||
}
|
||||
|
||||
fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> {
|
||||
PgOAuth2SessionRepository::new(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'t> Repository for Transaction<'t, Postgres> {
|
||||
@@ -121,6 +131,7 @@ impl<'t> Repository for Transaction<'t, Postgres> {
|
||||
type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c;
|
||||
type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c;
|
||||
type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c;
|
||||
type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c;
|
||||
|
||||
fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> {
|
||||
PgUpstreamOAuthLinkRepository::new(self)
|
||||
@@ -153,4 +164,8 @@ impl<'t> Repository for Transaction<'t, Postgres> {
|
||||
fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_> {
|
||||
PgOAuth2ClientRepository::new(self)
|
||||
}
|
||||
|
||||
fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> {
|
||||
PgOAuth2SessionRepository::new(self)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user