diff --git a/crates/core/src/handlers/oauth2/token.rs b/crates/core/src/handlers/oauth2/token.rs index 9f84463d..873171c2 100644 --- a/crates/core/src/handlers/oauth2/token.rs +++ b/crates/core/src/handlers/oauth2/token.rs @@ -153,8 +153,15 @@ async fn authorization_code_grant( conn: &mut PoolConnection, ) -> Result { let mut txn = conn.begin().await.wrap_error()?; - // TODO: recover from failed code lookup with invalid_grant instead - let code = lookup_code(&mut txn, &grant.code).await.wrap_error()?; + + // TODO: we should invalidate the existing session if a code is used twice after + // some period of time. See the `oidcc-codereuse-30seconds` test from the + // conformance suite + let code = match lookup_code(&mut txn, &grant.code).await { + Err(e) if e.not_found() => return error(InvalidGrant), + x => x, + }?; + if client.client_id != code.client_id { return error(UnauthorizedClient); } diff --git a/crates/core/src/storage/oauth2/authorization_code.rs b/crates/core/src/storage/oauth2/authorization_code.rs index a6437716..35c2e45d 100644 --- a/crates/core/src/storage/oauth2/authorization_code.rs +++ b/crates/core/src/storage/oauth2/authorization_code.rs @@ -16,6 +16,8 @@ use anyhow::Context; use oauth2_types::pkce; use serde::Serialize; use sqlx::{Executor, FromRow, Postgres}; +use thiserror::Error; +use warp::reject::Reject; #[derive(FromRow, Serialize)] pub struct OAuth2Code { @@ -65,11 +67,24 @@ pub struct OAuth2CodeLookup { pub nonce: Option, } +#[derive(Debug, Error)] +#[error("failed to lookup oauth2 code")] +pub struct CodeLookupError(#[from] sqlx::Error); + +impl Reject for CodeLookupError {} + +impl CodeLookupError { + #[must_use] + pub fn not_found(&self) -> bool { + matches!(self.0, sqlx::Error::RowNotFound) + } +} + pub async fn lookup_code( executor: impl Executor<'_, Database = Postgres>, code: &str, -) -> anyhow::Result { - sqlx::query_as!( +) -> Result { + let res = sqlx::query_as!( OAuth2CodeLookup, r#" SELECT @@ -87,8 +102,9 @@ pub async fn lookup_code( code, ) .fetch_one(executor) - .await - .context("could not lookup oauth2 code") + .await?; + + Ok(res) } pub async fn consume_code(