You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-31 09:24:31 +03:00
storage: add a method to create an OAuth 2.0 session for a client credentials grant
This commit is contained in:
@ -90,7 +90,7 @@ async fn start_oauth_session(
|
||||
|
||||
let session = repo
|
||||
.oauth2_session()
|
||||
.add(&mut rng, &state.clock, client, &browser_session, scope)
|
||||
.add_from_browser_session(&mut rng, &state.clock, client, &browser_session, scope)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
@ -244,7 +244,7 @@ pub(crate) async fn complete(
|
||||
// All good, let's start the session
|
||||
let session = repo
|
||||
.oauth2_session()
|
||||
.add(rng, clock, client, browser_session, grant.scope.clone())
|
||||
.add_from_browser_session(rng, clock, client, browser_session, grant.scope.clone())
|
||||
.await?;
|
||||
|
||||
let grant = repo
|
||||
|
@ -443,7 +443,7 @@ mod tests {
|
||||
|
||||
let session = repo
|
||||
.oauth2_session()
|
||||
.add(
|
||||
.add_from_browser_session(
|
||||
&mut state.rng(),
|
||||
&state.clock,
|
||||
&client,
|
||||
|
@ -302,7 +302,7 @@ mod tests {
|
||||
|
||||
let session = repo
|
||||
.oauth2_session()
|
||||
.add(
|
||||
.add_from_browser_session(
|
||||
&mut state.rng(),
|
||||
&state.clock,
|
||||
&client,
|
||||
@ -369,7 +369,7 @@ mod tests {
|
||||
let mut repo = state.repository().await.unwrap();
|
||||
let session = repo
|
||||
.oauth2_session()
|
||||
.add(
|
||||
.add_from_browser_session(
|
||||
&mut state.rng(),
|
||||
&state.clock,
|
||||
&client,
|
||||
|
@ -506,7 +506,7 @@ mod tests {
|
||||
|
||||
let session = repo
|
||||
.oauth2_session()
|
||||
.add(
|
||||
.add_from_browser_session(
|
||||
&mut state.rng(),
|
||||
&state.clock,
|
||||
&client,
|
||||
@ -606,7 +606,7 @@ mod tests {
|
||||
|
||||
let session = repo
|
||||
.oauth2_session()
|
||||
.add(
|
||||
.add_from_browser_session(
|
||||
&mut state.rng(),
|
||||
&state.clock,
|
||||
&client,
|
||||
@ -691,7 +691,7 @@ mod tests {
|
||||
// Get a token pair
|
||||
let session = repo
|
||||
.oauth2_session()
|
||||
.add(
|
||||
.add_from_browser_session(
|
||||
&mut state.rng(),
|
||||
&state.clock,
|
||||
&client,
|
||||
|
17
crates/storage-pg/.sqlx/query-6554d3620a5f7fb0e85af44e8a21c2f2f3ebe4b805ec67aca4a2278a8ae16693.json
generated
Normal file
17
crates/storage-pg/.sqlx/query-6554d3620a5f7fb0e85af44e8a21c2f2f3ebe4b805ec67aca4a2278a8ae16693.json
generated
Normal file
@ -0,0 +1,17 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n INSERT INTO oauth2_sessions\n ( oauth2_session_id\n , oauth2_client_id\n , scope_list\n , created_at\n )\n VALUES ($1, $2, $3, $4)\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Uuid",
|
||||
"Uuid",
|
||||
"TextArray",
|
||||
"Timestamptz"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "6554d3620a5f7fb0e85af44e8a21c2f2f3ebe4b805ec67aca4a2278a8ae16693"
|
||||
}
|
@ -211,7 +211,7 @@ mod tests {
|
||||
// Create an OAuth session
|
||||
let session = repo
|
||||
.oauth2_session()
|
||||
.add(
|
||||
.add_from_browser_session(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&client,
|
||||
@ -464,28 +464,28 @@ mod tests {
|
||||
// we're getting consistent ordering in lists.
|
||||
let session11 = repo
|
||||
.oauth2_session()
|
||||
.add(&mut rng, &clock, &client1, &user1_session, scope.clone())
|
||||
.add_from_browser_session(&mut rng, &clock, &client1, &user1_session, scope.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
clock.advance(Duration::minutes(1));
|
||||
|
||||
let session12 = repo
|
||||
.oauth2_session()
|
||||
.add(&mut rng, &clock, &client1, &user2_session, scope.clone())
|
||||
.add_from_browser_session(&mut rng, &clock, &client1, &user2_session, scope.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
clock.advance(Duration::minutes(1));
|
||||
|
||||
let session21 = repo
|
||||
.oauth2_session()
|
||||
.add(&mut rng, &clock, &client2, &user1_session, scope2.clone())
|
||||
.add_from_browser_session(&mut rng, &clock, &client2, &user1_session, scope2.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
clock.advance(Duration::minutes(1));
|
||||
|
||||
let session22 = repo
|
||||
.oauth2_session()
|
||||
.add(&mut rng, &clock, &client2, &user2_session, scope2.clone())
|
||||
.add_from_browser_session(&mut rng, &clock, &client2, &user2_session, scope2.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
clock.advance(Duration::minutes(1));
|
||||
|
@ -133,7 +133,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_session.add",
|
||||
name = "db.oauth2_session.add_from_browser_session",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
@ -145,7 +145,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn add(
|
||||
async fn add_from_browser_session(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &dyn Clock,
|
||||
@ -193,6 +193,60 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_session.add_from_client_credentials",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%client.id,
|
||||
session.id,
|
||||
session.scope = %scope,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn add_from_client_credentials(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &dyn Clock,
|
||||
client: &Client,
|
||||
scope: Scope,
|
||||
) -> Result<Session, Self::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record("session.id", tracing::field::display(id));
|
||||
|
||||
let scope_list: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO oauth2_sessions
|
||||
( oauth2_session_id
|
||||
, oauth2_client_id
|
||||
, scope_list
|
||||
, created_at
|
||||
)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(client.id),
|
||||
&scope_list,
|
||||
created_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(Session {
|
||||
id,
|
||||
state: SessionState::Valid,
|
||||
created_at,
|
||||
user_id: None,
|
||||
user_session_id: None,
|
||||
client_id: client.id,
|
||||
scope,
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_session.finish",
|
||||
skip_all,
|
||||
|
@ -140,7 +140,7 @@ pub trait OAuth2SessionRepository: Send + Sync {
|
||||
/// Returns [`Self::Error`] if the underlying repository fails
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error>;
|
||||
|
||||
/// Create a new [`Session`]
|
||||
/// Create a new [`Session`] out of a [`Client`] and a [`BrowserSession`]
|
||||
///
|
||||
/// Returns the newly created [`Session`]
|
||||
///
|
||||
@ -156,7 +156,7 @@ pub trait OAuth2SessionRepository: Send + Sync {
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`Self::Error`] if the underlying repository fails
|
||||
async fn add(
|
||||
async fn add_from_browser_session(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &dyn Clock,
|
||||
@ -165,6 +165,29 @@ pub trait OAuth2SessionRepository: Send + Sync {
|
||||
scope: Scope,
|
||||
) -> Result<Session, Self::Error>;
|
||||
|
||||
/// Create a new [`Session`] for a [`Client`] using the client credentials
|
||||
/// flow
|
||||
///
|
||||
/// Returns the newly created [`Session`]
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// * `rng`: The random number generator to use
|
||||
/// * `clock`: The clock used to generate timestamps
|
||||
/// * `client`: The [`Client`] which created the [`Session`]
|
||||
/// * `scope`: The [`Scope`] of the [`Session`]
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`Self::Error`] if the underlying repository fails
|
||||
async fn add_from_client_credentials(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &dyn Clock,
|
||||
client: &Client,
|
||||
scope: Scope,
|
||||
) -> Result<Session, Self::Error>;
|
||||
|
||||
/// Mark a [`Session`] as finished
|
||||
///
|
||||
/// Returns the updated [`Session`]
|
||||
@ -211,7 +234,7 @@ pub trait OAuth2SessionRepository: Send + Sync {
|
||||
repository_impl!(OAuth2SessionRepository:
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error>;
|
||||
|
||||
async fn add(
|
||||
async fn add_from_browser_session(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &dyn Clock,
|
||||
@ -220,6 +243,14 @@ repository_impl!(OAuth2SessionRepository:
|
||||
scope: Scope,
|
||||
) -> Result<Session, Self::Error>;
|
||||
|
||||
async fn add_from_client_credentials(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &dyn Clock,
|
||||
client: &Client,
|
||||
scope: Scope,
|
||||
) -> Result<Session, Self::Error>;
|
||||
|
||||
async fn finish(&mut self, clock: &dyn Clock, session: Session)
|
||||
-> Result<Session, Self::Error>;
|
||||
|
||||
|
Reference in New Issue
Block a user