1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

storage: upstream oauth session repository + unit tests

This commit is contained in:
Quentin Gliech
2022-12-30 15:39:51 +01:00
parent 0faf08fce2
commit 870a37151f
9 changed files with 469 additions and 490 deletions

View File

@ -19,7 +19,111 @@ mod session;
pub use self::{
link::{PgUpstreamOAuthLinkRepository, UpstreamOAuthLinkRepository},
provider::{PgUpstreamOAuthProviderRepository, UpstreamOAuthProviderRepository},
session::{
add_session, complete_session, consume_session, lookup_session, lookup_session_on_link,
},
session::{PgUpstreamOAuthSessionRepository, UpstreamOAuthSessionRepository},
};
#[cfg(test)]
mod tests {
use oauth2_types::scope::{Scope, OPENID};
use rand::SeedableRng;
use sqlx::PgPool;
use super::*;
use crate::{Clock, Repository};
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_repository(pool: PgPool) -> Result<(), Box<dyn std::error::Error>> {
let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
let clock = Clock::default();
let mut conn = pool.acquire().await?;
// The provider list should be empty at the start
let all_providers = conn.upstream_oauth_provider().all().await?;
assert!(all_providers.is_empty());
// Let's add a provider
let provider = conn
.upstream_oauth_provider()
.add(
&mut rng,
&clock,
"https://example.com/".to_owned(),
Scope::from_iter([OPENID]),
mas_iana::oauth::OAuthClientAuthenticationMethod::None,
None,
"client-id".to_owned(),
None,
)
.await?;
// Look it up in the database
let provider = conn
.upstream_oauth_provider()
.lookup(provider.id)
.await?
.expect("provider to be found in the database");
assert_eq!(provider.issuer, "https://example.com/");
assert_eq!(provider.client_id, "client-id");
// Start a session
let session = conn
.upstream_oauth_session()
.add(
&mut rng,
&clock,
&provider,
"some-state".to_owned(),
None,
"some-nonce".to_owned(),
)
.await?;
// Look it up in the database
let session = conn
.upstream_oauth_session()
.lookup(session.id)
.await?
.expect("session to be found in the database");
assert_eq!(session.provider_id, provider.id);
assert_eq!(session.link_id, None);
assert!(!session.completed());
assert!(!session.consumed());
// Create a link
let link = conn
.upstream_oauth_link()
.add(&mut rng, &clock, &provider, "a-subject".to_owned())
.await?;
// We can look it up by its ID
conn.upstream_oauth_link()
.lookup(link.id)
.await?
.expect("link to be found in database");
// or by its subject
let link = conn
.upstream_oauth_link()
.find_by_subject(&provider, "a-subject")
.await?
.expect("link to be found in database");
assert_eq!(link.subject, "a-subject");
assert_eq!(link.provider_id, provider.id);
let session = conn
.upstream_oauth_session()
.complete_with_link(&clock, session, &link, None)
.await?;
assert!(session.completed());
assert!(!session.consumed());
assert_eq!(session.link_id, Some(link.id));
let session = conn
.upstream_oauth_session()
.consume(&clock, session)
.await?;
assert!(session.consumed());
Ok(())
}
}

View File

@ -12,261 +12,62 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider};
use rand::Rng;
use sqlx::PgExecutor;
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt};
use crate::{Clock, DatabaseError, LookupResultExt};
struct SessionAndProviderLookup {
upstream_oauth_authorization_session_id: Uuid,
upstream_oauth_provider_id: Uuid,
upstream_oauth_link_id: Option<Uuid>,
state: String,
code_challenge_verifier: Option<String>,
nonce: String,
id_token: Option<String>,
created_at: DateTime<Utc>,
completed_at: Option<DateTime<Utc>>,
consumed_at: Option<DateTime<Utc>>,
provider_issuer: String,
provider_scope: String,
provider_client_id: String,
provider_encrypted_client_secret: Option<String>,
provider_token_endpoint_auth_method: String,
provider_token_endpoint_signing_alg: Option<String>,
provider_created_at: DateTime<Utc>,
#[async_trait]
pub trait UpstreamOAuthSessionRepository: Send + Sync {
type Error;
/// Lookup a session by its ID
async fn lookup(
&mut self,
id: Ulid,
) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error>;
/// Add a session to the database
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
upstream_oauth_provider: &UpstreamOAuthProvider,
state: String,
code_challenge_verifier: Option<String>,
nonce: String,
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
/// Mark a session as completed and associate the given link
async fn complete_with_link(
&mut self,
clock: &Clock,
upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
upstream_oauth_link: &UpstreamOAuthLink,
id_token: Option<String>,
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
/// Mark a session as consumed
async fn consume(
&mut self,
clock: &Clock,
upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
}
/// Lookup a session and its provider by its ID
#[tracing::instrument(
skip_all,
fields(upstream_oauth_authorization_session.id = %id),
err,
)]
pub async fn lookup_session(
executor: impl PgExecutor<'_>,
id: Ulid,
) -> Result<Option<(UpstreamOAuthProvider, UpstreamOAuthAuthorizationSession)>, DatabaseError> {
let res = sqlx::query_as!(
SessionAndProviderLookup,
r#"
SELECT
ua.upstream_oauth_authorization_session_id,
ua.upstream_oauth_provider_id,
ua.upstream_oauth_link_id,
ua.state,
ua.code_challenge_verifier,
ua.nonce,
ua.id_token,
ua.created_at,
ua.completed_at,
ua.consumed_at,
up.issuer AS "provider_issuer",
up.scope AS "provider_scope",
up.client_id AS "provider_client_id",
up.encrypted_client_secret AS "provider_encrypted_client_secret",
up.token_endpoint_auth_method AS "provider_token_endpoint_auth_method",
up.token_endpoint_signing_alg AS "provider_token_endpoint_signing_alg",
up.created_at AS "provider_created_at"
FROM upstream_oauth_authorization_sessions ua
INNER JOIN upstream_oauth_providers up
USING (upstream_oauth_provider_id)
WHERE upstream_oauth_authorization_session_id = $1
"#,
Uuid::from(id),
)
.fetch_one(executor)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
let id = res.upstream_oauth_provider_id.into();
let provider = UpstreamOAuthProvider {
id,
issuer: res.provider_issuer,
scope: res.provider_scope.parse().map_err(|e| {
DatabaseInconsistencyError::on("upstream_oauth_providers")
.column("scope")
.row(id)
.source(e)
})?,
client_id: res.provider_client_id,
encrypted_client_secret: res.provider_encrypted_client_secret,
token_endpoint_auth_method: res.provider_token_endpoint_auth_method.parse().map_err(
|e| {
DatabaseInconsistencyError::on("upstream_oauth_providers")
.column("token_endpoint_auth_method")
.row(id)
.source(e)
},
)?,
token_endpoint_signing_alg: res
.provider_token_endpoint_signing_alg
.map(|x| x.parse())
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("upstream_oauth_providers")
.column("token_endpoint_signing_alg")
.row(id)
.source(e)
})?,
created_at: res.provider_created_at,
};
let session = UpstreamOAuthAuthorizationSession {
id: res.upstream_oauth_authorization_session_id.into(),
provider_id: provider.id,
link_id: res.upstream_oauth_link_id.map(Ulid::from),
state: res.state,
code_challenge_verifier: res.code_challenge_verifier,
nonce: res.nonce,
id_token: res.id_token,
created_at: res.created_at,
completed_at: res.completed_at,
consumed_at: res.consumed_at,
};
Ok(Some((provider, session)))
pub struct PgUpstreamOAuthSessionRepository<'c> {
conn: &'c mut PgConnection,
}
/// Add a session to the database
#[tracing::instrument(
skip_all,
fields(
%upstream_oauth_provider.id,
%upstream_oauth_provider.issuer,
%upstream_oauth_provider.client_id,
upstream_oauth_authorization_session.id,
),
err,
)]
pub async fn add_session(
executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
upstream_oauth_provider: &UpstreamOAuthProvider,
state: String,
code_challenge_verifier: Option<String>,
nonce: String,
) -> Result<UpstreamOAuthAuthorizationSession, sqlx::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record(
"upstream_oauth_authorization_session.id",
tracing::field::display(id),
);
sqlx::query!(
r#"
INSERT INTO upstream_oauth_authorization_sessions (
upstream_oauth_authorization_session_id,
upstream_oauth_provider_id,
state,
code_challenge_verifier,
nonce,
created_at,
completed_at,
consumed_at,
id_token
) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL)
"#,
Uuid::from(id),
Uuid::from(upstream_oauth_provider.id),
&state,
code_challenge_verifier.as_deref(),
nonce,
created_at,
)
.execute(executor)
.await?;
Ok(UpstreamOAuthAuthorizationSession {
id,
provider_id: upstream_oauth_provider.id,
link_id: None,
state,
code_challenge_verifier,
nonce,
id_token: None,
created_at,
completed_at: None,
consumed_at: None,
})
}
/// Mark a session as completed and associate the given link
#[tracing::instrument(
skip_all,
fields(
%upstream_oauth_authorization_session.id,
%upstream_oauth_link.id,
),
err,
)]
pub async fn complete_session(
executor: impl PgExecutor<'_>,
clock: &Clock,
mut upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
upstream_oauth_link: &UpstreamOAuthLink,
id_token: Option<String>,
) -> Result<UpstreamOAuthAuthorizationSession, sqlx::Error> {
let completed_at = clock.now();
sqlx::query!(
r#"
UPDATE upstream_oauth_authorization_sessions
SET upstream_oauth_link_id = $1,
completed_at = $2,
id_token = $3
WHERE upstream_oauth_authorization_session_id = $4
"#,
Uuid::from(upstream_oauth_link.id),
completed_at,
id_token,
Uuid::from(upstream_oauth_authorization_session.id),
)
.execute(executor)
.await?;
upstream_oauth_authorization_session.completed_at = Some(completed_at);
upstream_oauth_authorization_session.id_token = id_token;
Ok(upstream_oauth_authorization_session)
}
/// Mark a session as consumed
#[tracing::instrument(
skip_all,
fields(
%upstream_oauth_authorization_session.id,
),
err,
)]
pub async fn consume_session(
executor: impl PgExecutor<'_>,
clock: &Clock,
mut upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
) -> Result<UpstreamOAuthAuthorizationSession, sqlx::Error> {
let consumed_at = clock.now();
sqlx::query!(
r#"
UPDATE upstream_oauth_authorization_sessions
SET consumed_at = $1
WHERE upstream_oauth_authorization_session_id = $2
"#,
consumed_at,
Uuid::from(upstream_oauth_authorization_session.id),
)
.execute(executor)
.await?;
upstream_oauth_authorization_session.consumed_at = Some(consumed_at);
Ok(upstream_oauth_authorization_session)
impl<'c> PgUpstreamOAuthSessionRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
struct SessionLookup {
@ -282,57 +83,191 @@ struct SessionLookup {
consumed_at: Option<DateTime<Utc>>,
}
/// Lookup a session, which belongs to a link, by its ID
#[tracing::instrument(
skip_all,
fields(
upstream_oauth_authorization_session.id = %id,
%upstream_oauth_link.id,
),
err,
)]
pub async fn lookup_session_on_link(
executor: impl PgExecutor<'_>,
upstream_oauth_link: &UpstreamOAuthLink,
id: Ulid,
) -> Result<Option<UpstreamOAuthAuthorizationSession>, sqlx::Error> {
let res = sqlx::query_as!(
SessionLookup,
r#"
SELECT
upstream_oauth_authorization_session_id,
upstream_oauth_provider_id,
upstream_oauth_link_id,
state,
code_challenge_verifier,
nonce,
id_token,
created_at,
completed_at,
consumed_at
FROM upstream_oauth_authorization_sessions
WHERE upstream_oauth_authorization_session_id = $1
AND upstream_oauth_link_id = $2
"#,
Uuid::from(id),
Uuid::from(upstream_oauth_link.id),
)
.fetch_one(executor)
.await
.to_option()?;
#[async_trait]
impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> {
type Error = DatabaseError;
let Some(res) = res else { return Ok(None) };
#[tracing::instrument(
skip_all,
fields(upstream_oauth_provider.id = %id),
err,
)]
async fn lookup(
&mut self,
id: Ulid,
) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error> {
let res = sqlx::query_as!(
SessionLookup,
r#"
SELECT
upstream_oauth_authorization_session_id,
upstream_oauth_provider_id,
upstream_oauth_link_id,
state,
code_challenge_verifier,
nonce,
id_token,
created_at,
completed_at,
consumed_at
FROM upstream_oauth_authorization_sessions
WHERE upstream_oauth_authorization_session_id = $1
"#,
Uuid::from(id),
)
.fetch_one(&mut *self.conn)
.await
.to_option()?;
Ok(Some(UpstreamOAuthAuthorizationSession {
id: res.upstream_oauth_authorization_session_id.into(),
provider_id: res.upstream_oauth_provider_id.into(),
link_id: res.upstream_oauth_link_id.map(Ulid::from),
state: res.state,
code_challenge_verifier: res.code_challenge_verifier,
nonce: res.nonce,
id_token: res.id_token,
created_at: res.created_at,
completed_at: res.completed_at,
consumed_at: res.consumed_at,
}))
let Some(res) = res else { return Ok(None) };
let session = UpstreamOAuthAuthorizationSession {
id: res.upstream_oauth_authorization_session_id.into(),
provider_id: res.upstream_oauth_provider_id.into(),
link_id: res.upstream_oauth_link_id.map(Ulid::from),
state: res.state,
code_challenge_verifier: res.code_challenge_verifier,
nonce: res.nonce,
id_token: res.id_token,
created_at: res.created_at,
completed_at: res.completed_at,
consumed_at: res.consumed_at,
};
Ok(Some(session))
}
#[tracing::instrument(
skip_all,
fields(
%upstream_oauth_provider.id,
%upstream_oauth_provider.issuer,
%upstream_oauth_provider.client_id,
upstream_oauth_authorization_session.id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
upstream_oauth_provider: &UpstreamOAuthProvider,
state: String,
code_challenge_verifier: Option<String>,
nonce: String,
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record(
"upstream_oauth_authorization_session.id",
tracing::field::display(id),
);
sqlx::query!(
r#"
INSERT INTO upstream_oauth_authorization_sessions (
upstream_oauth_authorization_session_id,
upstream_oauth_provider_id,
state,
code_challenge_verifier,
nonce,
created_at,
completed_at,
consumed_at,
id_token
) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL)
"#,
Uuid::from(id),
Uuid::from(upstream_oauth_provider.id),
&state,
code_challenge_verifier.as_deref(),
nonce,
created_at,
)
.execute(&mut *self.conn)
.await?;
Ok(UpstreamOAuthAuthorizationSession {
id,
provider_id: upstream_oauth_provider.id,
link_id: None,
state,
code_challenge_verifier,
nonce,
id_token: None,
created_at,
completed_at: None,
consumed_at: None,
})
}
#[tracing::instrument(
skip_all,
fields(
%upstream_oauth_authorization_session.id,
%upstream_oauth_link.id,
),
err,
)]
async fn complete_with_link(
&mut self,
clock: &Clock,
mut upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
upstream_oauth_link: &UpstreamOAuthLink,
id_token: Option<String>,
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
let completed_at = clock.now();
sqlx::query!(
r#"
UPDATE upstream_oauth_authorization_sessions
SET upstream_oauth_link_id = $1,
completed_at = $2,
id_token = $3
WHERE upstream_oauth_authorization_session_id = $4
"#,
Uuid::from(upstream_oauth_link.id),
completed_at,
id_token,
Uuid::from(upstream_oauth_authorization_session.id),
)
.execute(&mut *self.conn)
.await?;
upstream_oauth_authorization_session.completed_at = Some(completed_at);
upstream_oauth_authorization_session.id_token = id_token;
upstream_oauth_authorization_session.link_id = Some(upstream_oauth_link.id);
Ok(upstream_oauth_authorization_session)
}
/// Mark a session as consumed
#[tracing::instrument(
skip_all,
fields(
%upstream_oauth_authorization_session.id,
),
err,
)]
async fn consume(
&mut self,
clock: &Clock,
mut upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
let consumed_at = clock.now();
sqlx::query!(
r#"
UPDATE upstream_oauth_authorization_sessions
SET consumed_at = $1
WHERE upstream_oauth_authorization_session_id = $2
"#,
consumed_at,
Uuid::from(upstream_oauth_authorization_session.id),
)
.execute(&mut *self.conn)
.await?;
upstream_oauth_authorization_session.consumed_at = Some(consumed_at);
Ok(upstream_oauth_authorization_session)
}
}