From 4307276b0e7b227a1dd8f2fcfb47636dbd81b799 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 19 Oct 2021 19:18:25 +0200 Subject: [PATCH] Refactor DB interactions for OAuth code and tokens This ensures complex types like scopes are properly parsed back from the database. --- Cargo.lock | 5 +- crates/core/Cargo.toml | 8 - crates/core/sqlx-data.json | 484 +++++++++++------- crates/core/src/filters/authenticate.rs | 53 +- .../core/src/handlers/oauth2/authorization.rs | 27 +- .../core/src/handlers/oauth2/introspection.rs | 39 +- crates/core/src/handlers/oauth2/token.rs | 163 +++--- crates/core/src/handlers/oauth2/userinfo.rs | 14 +- crates/core/src/storage/mod.rs | 9 +- .../core/src/storage/oauth2/access_token.rs | 176 +++++-- .../src/storage/oauth2/authorization_code.rs | 210 ++++++-- .../core/src/storage/oauth2/refresh_token.rs | 195 +++++-- crates/core/src/storage/oauth2/session.rs | 27 +- crates/core/src/storage/user.rs | 31 +- crates/data-model/Cargo.toml | 1 + crates/data-model/src/lib.rs | 34 +- 16 files changed, 947 insertions(+), 529 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d14461db..6b5e3deb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1520,7 +1520,6 @@ dependencies = [ "crc", "data-encoding", "elliptic-curve", - "figment", "futures-util", "headers", "hyper", @@ -1539,15 +1538,12 @@ dependencies = [ "pkcs8", "rand 0.8.4", "rsa", - "schemars", "serde", "serde_json", "serde_urlencoded", "serde_with", - "serde_yaml", "sha2", "sqlx", - "tera", "thiserror", "tokio", "tokio-stream", @@ -1564,6 +1560,7 @@ dependencies = [ "oauth2-types", "serde", "thiserror", + "url", ] [[package]] diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index 6333188a..9696136b 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -24,23 +24,15 @@ anyhow = "1.0.44" warp = "0.3.1" hyper = { version = "0.14.13", features = ["full"] } -# Template engine -tera = "1.13.0" - # Database access sqlx = { version = "0.5.9", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "offline"] } # Various structure (de)serialization serde = { version = "1.0.130", features = ["derive"] } -serde_yaml = "0.8.21" serde_with = { version = "1.10.0", features = ["hex", "chrono"] } serde_json = "1.0.68" serde_urlencoded = "0.7.0" -# Argument & config parsing -figment = { version = "0.10.6", features = ["env", "yaml", "test"] } -schemars = { version = "0.8.6", features = ["url", "chrono"] } - # Password hashing argon2 = { version = "0.3.1", features = ["password-hash"] } password-hash = { version = "0.3.2", features = ["std"] } diff --git a/crates/core/sqlx-data.json b/crates/core/sqlx-data.json index fa7fa539..5bd07067 100644 --- a/crates/core/sqlx-data.json +++ b/crates/core/sqlx-data.json @@ -26,8 +26,8 @@ ] } }, - "138c3297a66107d8428ca10d04f9a4dd75faf9c1d3f84bcedd3b09f55dd84206": { - "query": "\n INSERT INTO oauth2_codes\n (oauth2_session_id, code, code_challenge_method, code_challenge)\n VALUES\n ($1, $2, $3, $4)\n RETURNING\n id, oauth2_session_id, code, code_challenge_method, code_challenge\n ", + "0c056fcc1a85d00db88034bcc582376cf220e1933d2932e520c44ed9931f5c9d": { + "query": "\n INSERT INTO oauth2_refresh_tokens\n (oauth2_session_id, oauth2_access_token_id, token)\n VALUES\n ($1, $2, $3)\n RETURNING\n id, created_at\n ", "describe": { "columns": [ { @@ -37,39 +37,20 @@ }, { "ordinal": 1, - "name": "oauth2_session_id", - "type_info": "Int8" - }, - { - "ordinal": 2, - "name": "code", - "type_info": "Text" - }, - { - "ordinal": 3, - "name": "code_challenge_method", - "type_info": "Int2" - }, - { - "ordinal": 4, - "name": "code_challenge", - "type_info": "Text" + "name": "created_at", + "type_info": "Timestamptz" } ], "parameters": { "Left": [ "Int8", - "Text", - "Int2", + "Int8", "Text" ] }, "nullable": [ false, - false, - false, - true, - true + false ] } }, @@ -93,6 +74,104 @@ ] } }, + "282548c5ad51bd95b7d9ad290714bab5860f1e1291021e7d786dc926d12b5dd9": { + "query": "\n SELECT\n oc.id,\n oc.code_challenge,\n oc.code_challenge_method,\n os.id AS \"oauth2_session_id!\",\n os.client_id AS \"client_id!\",\n os.redirect_uri,\n os.scope AS \"scope!\",\n os.nonce,\n us.id AS \"user_session_id?\",\n us.created_at AS \"user_session_created_at?\",\n u.id AS \"user_id?\",\n u.username AS \"user_username?\",\n usa.id AS \"user_session_last_authentication_id?\",\n usa.created_at AS \"user_session_last_authentication_created_at?\"\n FROM oauth2_codes oc\n INNER JOIN oauth2_sessions os\n ON os.id = oc.oauth2_session_id\n LEFT JOIN user_sessions us\n ON us.id = os.user_session_id\n LEFT JOIN user_session_authentications usa\n ON usa.session_id = us.id\n LEFT JOIN users u\n ON u.id = us.user_id\n WHERE oc.code = $1\n ORDER BY usa.created_at DESC\n LIMIT 1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "code_challenge", + "type_info": "Text" + }, + { + "ordinal": 2, + "name": "code_challenge_method", + "type_info": "Int2" + }, + { + "ordinal": 3, + "name": "oauth2_session_id!", + "type_info": "Int8" + }, + { + "ordinal": 4, + "name": "client_id!", + "type_info": "Text" + }, + { + "ordinal": 5, + "name": "redirect_uri", + "type_info": "Text" + }, + { + "ordinal": 6, + "name": "scope!", + "type_info": "Text" + }, + { + "ordinal": 7, + "name": "nonce", + "type_info": "Text" + }, + { + "ordinal": 8, + "name": "user_session_id?", + "type_info": "Int8" + }, + { + "ordinal": 9, + "name": "user_session_created_at?", + "type_info": "Timestamptz" + }, + { + "ordinal": 10, + "name": "user_id?", + "type_info": "Int8" + }, + { + "ordinal": 11, + "name": "user_username?", + "type_info": "Text" + }, + { + "ordinal": 12, + "name": "user_session_last_authentication_id?", + "type_info": "Int8" + }, + { + "ordinal": 13, + "name": "user_session_last_authentication_created_at?", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Text" + ] + }, + "nullable": [ + false, + true, + true, + false, + false, + false, + false, + true, + false, + false, + false, + false, + false, + false + ] + } + }, "307fd9f71e7a94a0a0d9ce523ee9792e127485d0d12480c43f179dd9b75afbab": { "query": "\n INSERT INTO user_sessions (user_id)\n VALUES ($1)\n RETURNING id, created_at\n ", "describe": { @@ -119,58 +198,31 @@ ] } }, - "49888f812910633b87ce65c277f8969377fe264be154d8aa6b33d861d26d2b3b": { - "query": "\n SELECT\n u.username AS \"username!\",\n us.active AS \"active!\",\n os.client_id AS \"client_id!\",\n os.scope AS \"scope!\",\n at.created_at AS \"created_at!\",\n at.expires_after AS \"expires_after!\"\n FROM oauth2_access_tokens at\n INNER JOIN oauth2_sessions os\n ON os.id = at.oauth2_session_id\n INNER JOIN user_sessions us\n ON us.id = os.user_session_id\n INNER JOIN users u\n ON u.id = us.user_id\n WHERE at.token = $1\n ", + "47a7a8d2ef7db8bb1d41230626ded4e4661d488891fbda9b872c0749a9ba58f4": { + "query": "\n INSERT INTO oauth2_codes\n (oauth2_session_id, code, code_challenge_method, code_challenge)\n VALUES\n ($1, $2, $3, $4)\n RETURNING\n id\n ", "describe": { "columns": [ { "ordinal": 0, - "name": "username!", - "type_info": "Text" - }, - { - "ordinal": 1, - "name": "active!", - "type_info": "Bool" - }, - { - "ordinal": 2, - "name": "client_id!", - "type_info": "Text" - }, - { - "ordinal": 3, - "name": "scope!", - "type_info": "Text" - }, - { - "ordinal": 4, - "name": "created_at!", - "type_info": "Timestamptz" - }, - { - "ordinal": 5, - "name": "expires_after!", - "type_info": "Int4" + "name": "id", + "type_info": "Int8" } ], "parameters": { "Left": [ + "Int8", + "Text", + "Int2", "Text" ] }, "nullable": [ - false, - false, - false, - false, - false, false ] } }, - "562b0d4dcf857e99c20e9288e9c8bd46232290715c0d2459b0398a1c746cf65d": { - "query": "\n SELECT\n rt.id,\n rt.oauth2_session_id,\n rt.oauth2_access_token_id,\n os.client_id AS \"client_id!\",\n os.scope AS \"scope!\"\n FROM oauth2_refresh_tokens rt\n INNER JOIN oauth2_sessions os\n ON os.id = rt.oauth2_session_id\n WHERE rt.token = $1 AND rt.next_token_id IS NULL\n ", + "59e8a5de682642883a9b9fc1b522736fa4397f0a0c97074f2c8908e5956c0166": { + "query": "\n INSERT INTO oauth2_access_tokens\n (oauth2_session_id, token, expires_after)\n VALUES\n ($1, $2, $3)\n RETURNING\n id, created_at\n ", "describe": { "columns": [ { @@ -180,23 +232,116 @@ }, { "ordinal": 1, - "name": "oauth2_session_id", + "name": "created_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Int8", + "Text", + "Int4" + ] + }, + "nullable": [ + false, + false + ] + } + }, + "5d032f4bdb28534da7cf8e9806442a12708d632b7be28f8b952bd3cb63a8b1af": { + "query": "\n SELECT\n rt.id AS refresh_token_id,\n rt.token AS refresh_token,\n rt.created_at AS refresh_token_created_at,\n at.id AS \"access_token_id?\",\n at.token AS \"access_token?\",\n at.expires_after AS \"access_token_expires_after?\",\n at.created_at AS \"access_token_created_at?\",\n os.id AS \"session_id!\",\n os.client_id AS \"client_id!\",\n os.scope AS \"scope!\",\n os.redirect_uri AS \"redirect_uri!\",\n os.nonce AS \"nonce\",\n us.id AS \"user_session_id!\",\n us.created_at AS \"user_session_created_at!\",\n u.id AS \"user_id!\",\n u.username AS \"user_username!\",\n usa.id AS \"user_session_last_authentication_id?\",\n usa.created_at AS \"user_session_last_authentication_created_at?\"\n FROM oauth2_refresh_tokens rt\n LEFT JOIN oauth2_access_tokens at\n ON at.id = rt.oauth2_access_token_id\n INNER JOIN oauth2_sessions os\n ON os.id = rt.oauth2_session_id\n INNER JOIN user_sessions us\n ON us.id = os.user_session_id\n INNER JOIN users u\n ON u.id = us.user_id\n LEFT JOIN user_session_authentications usa\n ON usa.session_id = us.id\n\n WHERE rt.token = $1\n AND rt.next_token_id IS NULL\n AND us.active\n\n ORDER BY usa.created_at DESC\n LIMIT 1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "refresh_token_id", "type_info": "Int8" }, + { + "ordinal": 1, + "name": "refresh_token", + "type_info": "Text" + }, { "ordinal": 2, - "name": "oauth2_access_token_id", - "type_info": "Int8" + "name": "refresh_token_created_at", + "type_info": "Timestamptz" }, { "ordinal": 3, + "name": "access_token_id?", + "type_info": "Int8" + }, + { + "ordinal": 4, + "name": "access_token?", + "type_info": "Text" + }, + { + "ordinal": 5, + "name": "access_token_expires_after?", + "type_info": "Int4" + }, + { + "ordinal": 6, + "name": "access_token_created_at?", + "type_info": "Timestamptz" + }, + { + "ordinal": 7, + "name": "session_id!", + "type_info": "Int8" + }, + { + "ordinal": 8, "name": "client_id!", "type_info": "Text" }, { - "ordinal": 4, + "ordinal": 9, "name": "scope!", "type_info": "Text" + }, + { + "ordinal": 10, + "name": "redirect_uri!", + "type_info": "Text" + }, + { + "ordinal": 11, + "name": "nonce", + "type_info": "Text" + }, + { + "ordinal": 12, + "name": "user_session_id!", + "type_info": "Int8" + }, + { + "ordinal": 13, + "name": "user_session_created_at!", + "type_info": "Timestamptz" + }, + { + "ordinal": 14, + "name": "user_id!", + "type_info": "Int8" + }, + { + "ordinal": 15, + "name": "user_username!", + "type_info": "Text" + }, + { + "ordinal": 16, + "name": "user_session_last_authentication_id?", + "type_info": "Int8" + }, + { + "ordinal": 17, + "name": "user_session_last_authentication_created_at?", + "type_info": "Timestamptz" } ], "parameters": { @@ -205,10 +350,23 @@ ] }, "nullable": [ + false, + false, + false, + false, + false, + false, + false, + false, + false, false, false, true, false, + false, + false, + false, + false, false ] } @@ -223,60 +381,106 @@ "nullable": [] } }, - "73f2d928f7bf88af79a3685bd6346652b4e4454b0ce75e38343840c9765e3f27": { - "query": "\n INSERT INTO oauth2_refresh_tokens\n (oauth2_session_id, oauth2_access_token_id, token)\n VALUES\n ($1, $2, $3)\n RETURNING\n id, oauth2_session_id, oauth2_access_token_id, token, next_token_id, \n created_at, updated_at\n ", + "686a796a7de689b73a9377083718c95ac5ac51ce396dcf32e614402051d93e16": { + "query": "\n SELECT\n at.id AS \"access_token_id\",\n at.token AS \"access_token\",\n at.expires_after AS \"access_token_expires_after\",\n at.created_at AS \"access_token_created_at\",\n os.id AS \"session_id!\",\n os.client_id AS \"client_id!\",\n os.scope AS \"scope!\",\n os.redirect_uri AS \"redirect_uri!\",\n os.nonce AS \"nonce\",\n us.id AS \"user_session_id!\",\n us.created_at AS \"user_session_created_at!\",\n u.id AS \"user_id!\",\n u.username AS \"user_username!\",\n usa.id AS \"user_session_last_authentication_id?\",\n usa.created_at AS \"user_session_last_authentication_created_at?\"\n\n FROM oauth2_access_tokens at\n INNER JOIN oauth2_sessions os\n ON os.id = at.oauth2_session_id\n INNER JOIN user_sessions us\n ON us.id = os.user_session_id\n INNER JOIN users u\n ON u.id = us.user_id\n LEFT JOIN user_session_authentications usa\n ON usa.session_id = us.id\n\n WHERE at.token = $1\n AND at.created_at + (at.expires_after * INTERVAL '1 second') >= now()\n AND us.active\n\n ORDER BY usa.created_at DESC\n LIMIT 1\n ", "describe": { "columns": [ { "ordinal": 0, - "name": "id", + "name": "access_token_id", "type_info": "Int8" }, { "ordinal": 1, - "name": "oauth2_session_id", - "type_info": "Int8" - }, - { - "ordinal": 2, - "name": "oauth2_access_token_id", - "type_info": "Int8" - }, - { - "ordinal": 3, - "name": "token", + "name": "access_token", "type_info": "Text" }, + { + "ordinal": 2, + "name": "access_token_expires_after", + "type_info": "Int4" + }, + { + "ordinal": 3, + "name": "access_token_created_at", + "type_info": "Timestamptz" + }, { "ordinal": 4, - "name": "next_token_id", + "name": "session_id!", "type_info": "Int8" }, { "ordinal": 5, - "name": "created_at", - "type_info": "Timestamptz" + "name": "client_id!", + "type_info": "Text" }, { "ordinal": 6, - "name": "updated_at", + "name": "scope!", + "type_info": "Text" + }, + { + "ordinal": 7, + "name": "redirect_uri!", + "type_info": "Text" + }, + { + "ordinal": 8, + "name": "nonce", + "type_info": "Text" + }, + { + "ordinal": 9, + "name": "user_session_id!", + "type_info": "Int8" + }, + { + "ordinal": 10, + "name": "user_session_created_at!", + "type_info": "Timestamptz" + }, + { + "ordinal": 11, + "name": "user_id!", + "type_info": "Int8" + }, + { + "ordinal": 12, + "name": "user_username!", + "type_info": "Text" + }, + { + "ordinal": 13, + "name": "user_session_last_authentication_id?", + "type_info": "Int8" + }, + { + "ordinal": 14, + "name": "user_session_last_authentication_created_at?", "type_info": "Timestamptz" } ], "parameters": { "Left": [ - "Int8", - "Int8", "Text" ] }, "nullable": [ false, false, - true, + false, + false, + false, + false, + false, false, true, false, + false, + false, + false, + false, false ] } @@ -388,52 +592,6 @@ "nullable": [] } }, - "b766b2b41d8770b5bef9928bb3b96abbaf8466b473e12b21f145c015b7cf2f05": { - "query": "\n INSERT INTO oauth2_access_tokens\n (oauth2_session_id, token, expires_after)\n VALUES\n ($1, $2, $3)\n RETURNING\n id, oauth2_session_id, token, expires_after, created_at\n ", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "id", - "type_info": "Int8" - }, - { - "ordinal": 1, - "name": "oauth2_session_id", - "type_info": "Int8" - }, - { - "ordinal": 2, - "name": "token", - "type_info": "Text" - }, - { - "ordinal": 3, - "name": "expires_after", - "type_info": "Int4" - }, - { - "ordinal": 4, - "name": "created_at", - "type_info": "Timestamptz" - } - ], - "parameters": { - "Left": [ - "Int8", - "Text", - "Int4" - ] - }, - "nullable": [ - false, - false, - false, - false, - false - ] - } - }, "c2c402cfe0adcafa615f14a499caba4c96ca71d9ffb163e1feb05e5d85f3462c": { "query": "\n UPDATE oauth2_refresh_tokens\n SET next_token_id = $2\n WHERE id = $1\n ", "describe": { @@ -579,68 +737,6 @@ "nullable": [] } }, - "eb5f772a7387de0dc2f9f660f470476c075da097134a8ded226eb630545c16eb": { - "query": "\n SELECT\n oc.id,\n oc.code_challenge,\n oc.code_challenge_method,\n os.id AS \"oauth2_session_id!\",\n os.client_id AS \"client_id!\",\n os.redirect_uri,\n os.scope AS \"scope!\",\n os.nonce\n FROM oauth2_codes oc\n INNER JOIN oauth2_sessions os\n ON os.id = oc.oauth2_session_id\n WHERE oc.code = $1\n ", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "id", - "type_info": "Int8" - }, - { - "ordinal": 1, - "name": "code_challenge", - "type_info": "Text" - }, - { - "ordinal": 2, - "name": "code_challenge_method", - "type_info": "Int2" - }, - { - "ordinal": 3, - "name": "oauth2_session_id!", - "type_info": "Int8" - }, - { - "ordinal": 4, - "name": "client_id!", - "type_info": "Text" - }, - { - "ordinal": 5, - "name": "redirect_uri", - "type_info": "Text" - }, - { - "ordinal": 6, - "name": "scope!", - "type_info": "Text" - }, - { - "ordinal": 7, - "name": "nonce", - "type_info": "Text" - } - ], - "parameters": { - "Left": [ - "Text" - ] - }, - "nullable": [ - false, - true, - true, - false, - false, - false, - false, - true - ] - } - }, "f9a09ff53b6f221649f4f050e3d5ade114f852ddf50a78610a6c0ef0689af681": { "query": "\n INSERT INTO users (username, hashed_password)\n VALUES ($1, $2)\n RETURNING id\n ", "describe": { diff --git a/crates/core/src/filters/authenticate.rs b/crates/core/src/filters/authenticate.rs index 773c076e..e19482d8 100644 --- a/crates/core/src/filters/authenticate.rs +++ b/crates/core/src/filters/authenticate.rs @@ -14,9 +14,9 @@ //! Authenticate an endpoint with an access token as bearer authorization token -use chrono::Utc; use headers::{authorization::Bearer, Authorization}; use hyper::StatusCode; +use mas_data_model::{AccessToken, Session}; use sqlx::{pool::PoolConnection, PgPool, Postgres}; use thiserror::Error; use warp::{ @@ -31,8 +31,9 @@ use super::{ }; use crate::{ errors::wrapped_error, - storage::oauth2::access_token::{ - lookup_access_token, AccessTokenLookupError, OAuth2AccessTokenLookup, + storage::{ + oauth2::access_token::{lookup_active_access_token, AccessTokenLookupError}, + PostgresqlBackend, }, tokens::{TokenFormatError, TokenType}, }; @@ -82,19 +83,25 @@ impl Reject for AuthenticationError {} #[must_use] pub fn authentication( pool: &PgPool, -) -> impl Filter + Clone + Send + Sync + 'static -{ +) -> impl Filter< + Extract = (AccessToken, Session), + Error = Rejection, +> + Clone + + Send + + Sync + + 'static { connection(pool) .and(typed_header()) .and_then(authenticate) .recover(recover) .unify() + .untuple_one() } async fn authenticate( mut conn: PoolConnection, auth: Authorization, -) -> Result { +) -> Result<(AccessToken, Session), Rejection> { let token = auth.0.token(); let token_type = TokenType::check(token).map_err(AuthenticationError::TokenFormat)?; @@ -102,29 +109,25 @@ async fn authenticate( return Err(AuthenticationError::WrongTokenType(token_type).into()); } - let token = lookup_access_token(&mut conn, token).await.map_err(|e| { - if e.not_found() { - // This error happens if the token was not found and should be recovered - warp::reject::custom(AuthenticationError::TokenNotFound(e)) - } else { - // This is a generic database error that we want to propagate - warp::reject::custom(wrapped_error(e)) - } - })?; + let (token, session) = lookup_active_access_token(&mut conn, token) + .await + .map_err(|e| { + if e.not_found() { + // This error happens if the token was not found and should be recovered + warp::reject::custom(AuthenticationError::TokenNotFound(e)) + } else { + // This is a generic database error that we want to propagate + warp::reject::custom(wrapped_error(e)) + } + })?; - if !token.active { - return Err(AuthenticationError::TokenInactive.into()); - } - - if token.exp() < Utc::now() { - return Err(AuthenticationError::TokenExpired.into()); - } - - Ok(token) + Ok((token, session)) } /// Transform the rejections from the [`with_typed_header`] filter -async fn recover(rejection: Rejection) -> Result { +async fn recover( + rejection: Rejection, +) -> Result<(AccessToken, Session), Rejection> { if rejection.find::().is_some() { return Err(warp::reject::custom( AuthenticationError::MissingAuthorizationHeader, diff --git a/crates/core/src/handlers/oauth2/authorization.rs b/crates/core/src/handlers/oauth2/authorization.rs index 12373abf..413fb181 100644 --- a/crates/core/src/handlers/oauth2/authorization.rs +++ b/crates/core/src/handlers/oauth2/authorization.rs @@ -395,7 +395,7 @@ impl StepRequest { async fn step( oauth2_session_id: i64, - user_session: BrowserSession, + browser_session: BrowserSession, mut txn: Transaction<'_, Postgres>, ) -> Result { let mut oauth2_session = get_session_by_id(&mut txn, oauth2_session_id) @@ -403,7 +403,7 @@ async fn step( .wrap_error()?; let user_session = oauth2_session - .match_or_set_session(&mut txn, user_session) + .match_or_set_session(&mut txn, browser_session) .await .wrap_error()?; @@ -427,7 +427,7 @@ async fn step( // Did they request an access token? if response_type.contains(&ResponseType::Token) { let ttl = Duration::minutes(5); - let (access_token, refresh_token) = { + let (access_token_str, refresh_token_str) = { let mut rng = thread_rng(); ( AccessToken.generate(&mut rng), @@ -435,19 +435,24 @@ async fn step( ) }; - let access_token = add_access_token(&mut txn, oauth2_session_id, &access_token, ttl) - .await - .wrap_error()?; - - let refresh_token = - add_refresh_token(&mut txn, oauth2_session_id, access_token.id, &refresh_token) + let access_token = + add_access_token(&mut txn, oauth2_session_id, &access_token_str, ttl) .await .wrap_error()?; + let _refresh_token = add_refresh_token( + &mut txn, + oauth2_session_id, + access_token, + &refresh_token_str, + ) + .await + .wrap_error()?; + params.response = Some( - AccessTokenResponse::new(access_token.token) + AccessTokenResponse::new(access_token_str) .with_expires_in(ttl) - .with_refresh_token(refresh_token.token), + .with_refresh_token(refresh_token_str), ); } diff --git a/crates/core/src/handlers/oauth2/introspection.rs b/crates/core/src/handlers/oauth2/introspection.rs index 8b9320a2..ae989cd2 100644 --- a/crates/core/src/handlers/oauth2/introspection.rs +++ b/crates/core/src/handlers/oauth2/introspection.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use chrono::Utc; use hyper::Method; use oauth2_types::requests::{IntrospectionRequest, IntrospectionResponse, TokenTypeHint}; use sqlx::{pool::PoolConnection, PgPool, Postgres}; @@ -27,7 +26,9 @@ use crate::{ cors::cors, database::connection, }, - storage::oauth2::{access_token::lookup_access_token, refresh_token::lookup_refresh_token}, + storage::oauth2::{ + access_token::lookup_active_access_token, refresh_token::lookup_active_refresh_token, + }, tokens::{self, TokenType}, }; @@ -84,43 +85,41 @@ async fn introspect( let reply = match token_type { tokens::TokenType::AccessToken => { - let token = lookup_access_token(&mut conn, token).await.wrap_error()?; + let (token, session) = lookup_active_access_token(&mut conn, token) + .await + .wrap_error()?; let exp = token.exp(); - // Check it is active and did not expire - if !token.active || exp < Utc::now() { - info!(?token, "Access token expired"); - return Ok(warp::reply::json(&INACTIVE)); - } - IntrospectionResponse { active: true, - scope: None, // TODO: parse back scopes - client_id: Some(token.client_id.clone()), - username: Some(token.username.clone()), + scope: Some(session.scope), + client_id: Some(session.client.client_id), + username: session.browser_session.clone().map(|s| s.user.username), token_type: Some(TokenTypeHint::AccessToken), exp: Some(exp), iat: Some(token.created_at), nbf: Some(token.created_at), - sub: None, + sub: session.browser_session.map(|s| s.user.sub), aud: None, iss: None, jti: None, } } tokens::TokenType::RefreshToken => { - let token = lookup_refresh_token(&mut conn, token).await.wrap_error()?; + let (token, session) = lookup_active_refresh_token(&mut conn, token) + .await + .wrap_error()?; IntrospectionResponse { active: true, - scope: None, // TODO: parse back scopes - client_id: Some(token.client_id), - username: None, + scope: Some(session.scope), + client_id: Some(session.client.client_id), + username: session.browser_session.clone().map(|s| s.user.username), token_type: Some(TokenTypeHint::RefreshToken), exp: None, - iat: None, - nbf: None, - sub: None, + iat: Some(token.created_at), + nbf: Some(token.created_at), + sub: session.browser_session.map(|s| s.user.sub), aud: None, iss: None, jti: None, diff --git a/crates/core/src/handlers/oauth2/token.rs b/crates/core/src/handlers/oauth2/token.rs index d51d043c..e4f28800 100644 --- a/crates/core/src/handlers/oauth2/token.rs +++ b/crates/core/src/handlers/oauth2/token.rs @@ -19,13 +19,11 @@ use headers::{CacheControl, Pragma}; use hyper::{Method, StatusCode}; use jwt_compact::{Claims, Header, TimeOptions}; use oauth2_types::{ - errors::{ - InvalidGrant, InvalidRequest, OAuth2Error, OAuth2ErrorCode, ServerError, UnauthorizedClient, - }, - pkce::CodeChallengeMethod, + errors::{InvalidGrant, InvalidRequest, OAuth2Error, OAuth2ErrorCode, UnauthorizedClient}, requests::{ AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, RefreshTokenGrant, }, + scope::OPENID, }; use rand::thread_rng; use serde::Serialize; @@ -52,7 +50,7 @@ use crate::{ storage::oauth2::{ access_token::{add_access_token, revoke_access_token}, authorization_code::{consume_code, lookup_code}, - refresh_token::{add_refresh_token, lookup_refresh_token, replace_refresh_token}, + refresh_token::{add_refresh_token, lookup_active_refresh_token, replace_refresh_token}, }, tokens::{AccessToken, RefreshToken}, }; @@ -163,46 +161,37 @@ async fn authorization_code_grant( // TODO: we should invalidate the existing session if a code is used twice after // some period of time. See the `oidcc-codereuse-30seconds` test from the // conformance suite - let code = match lookup_code(&mut txn, &grant.code).await { + let (code, session) = match lookup_code(&mut txn, &grant.code).await { Err(e) if e.not_found() => return error(InvalidGrant), x => x, }?; - if client.client_id != code.client_id { + if client.client_id != session.client.client_id { return error(UnauthorizedClient); } - match ( - code.code_challenge_method.as_ref(), - code.code_challenge.as_ref(), - grant.code_verifier.as_ref(), - ) { - (None, None, None) => {} + match (code.pkce.as_ref(), grant.code_verifier.as_ref()) { + (None, None) => {} // We have a challenge but no verifier (or vice-versa)? Bad request. - (Some(_), Some(_), None) | (None, None, Some(_)) => return error(InvalidRequest), - (Some(0 /* Plain */), Some(code_challenge), Some(code_verifier)) => { - if !CodeChallengeMethod::Plain.verify(code_challenge, code_verifier) { + (Some(_), None) | (None, Some(_)) => return error(InvalidRequest), + // If we have both, we need to check the code validity + (Some(pkce), Some(verifier)) => { + if !pkce.verify(verifier) { return error(InvalidRequest); } } - (Some(1 /* S256 */), Some(code_challenge), Some(code_verifier)) => { - if !CodeChallengeMethod::S256.verify(code_challenge, code_verifier) { - return error(InvalidRequest); - } - } - - // We have something else? - // That's a DB inconcistancy, we should bail out - _ => { - // TODO: are we sure we want to handle errors like that? - tracing::error!("Invalid state from the database"); - return error(ServerError); // Somthing bad happened in the database - } }; - // TODO: verify PKCE + // TODO: this should probably not happen? + let browser_session = session + .browser_session + .ok_or_else(|| { + anyhow::anyhow!("this oauth2 session has no database session attached to it") + }) + .wrap_error()?; + let ttl = Duration::minutes(5); - let (access_token, refresh_token) = { + let (access_token_str, refresh_token_str) = { let mut rng = thread_rng(); ( AccessToken.generate(&mut rng), @@ -210,45 +199,48 @@ async fn authorization_code_grant( ) }; - let access_token = add_access_token(&mut txn, code.oauth2_session_id, &access_token, ttl) + let access_token = add_access_token(&mut txn, session.data, &access_token_str, ttl) .await .wrap_error()?; - let refresh_token = add_refresh_token( - &mut txn, - code.oauth2_session_id, - access_token.id, - &refresh_token, - ) - .await - .wrap_error()?; + let _refresh_token = + add_refresh_token(&mut txn, session.data, access_token, &refresh_token_str) + .await + .wrap_error()?; - // TODO: generate id_token only if the "openid" scope was asked - let header = Header::default(); - let options = TimeOptions::default(); - let claims = Claims::new(CustomClaims { - issuer, - // TODO: get that from the session - subject: "random-subject".to_string(), - audiences: vec![client.client_id.clone()], - nonce: code.nonce, - at_hash: hash(Sha256::new(), &access_token.token).wrap_error()?, - c_hash: hash(Sha256::new(), &grant.code).wrap_error()?, - }) - .set_duration_and_issuance(&options, Duration::minutes(30)); - let id_token = keys - .token(crate::config::Algorithm::Rs256, header, claims) - .await - .context("could not sign ID token") - .wrap_error()?; + let id_token = if session.scope.contains(&OPENID) { + let header = Header::default(); + let options = TimeOptions::default(); + let claims = Claims::new(CustomClaims { + issuer, + subject: browser_session.user.sub, + audiences: vec![client.client_id.clone()], + nonce: session.nonce, + at_hash: hash(Sha256::new(), &access_token_str).wrap_error()?, + c_hash: hash(Sha256::new(), &grant.code).wrap_error()?, + }) + .set_duration_and_issuance(&options, Duration::minutes(30)); + let id_token = keys + .token(crate::config::Algorithm::Rs256, header, claims) + .await + .context("could not sign ID token") + .wrap_error()?; - // TODO: have the scopes back here - let params = AccessTokenResponse::new(access_token.token) + Some(id_token) + } else { + None + }; + + let mut params = AccessTokenResponse::new(access_token_str) .with_expires_in(ttl) - .with_refresh_token(refresh_token.token) - .with_id_token(id_token); + .with_refresh_token(refresh_token_str) + .with_scope(session.scope); - consume_code(&mut txn, code.id).await.wrap_error()?; + if let Some(id_token) = id_token { + params = params.with_id_token(id_token); + } + + consume_code(&mut txn, code).await.wrap_error()?; txn.commit().await.wrap_error()?; @@ -261,18 +253,17 @@ async fn refresh_token_grant( conn: &mut PoolConnection, ) -> Result { let mut txn = conn.begin().await.wrap_error()?; - // TODO: scope handling - let refresh_token_lookup = lookup_refresh_token(&mut txn, &grant.refresh_token) + let (refresh_token, session) = lookup_active_refresh_token(&mut txn, &grant.refresh_token) .await .wrap_error()?; - if client.client_id != refresh_token_lookup.client_id { + if client.client_id != session.client.client_id { // As per https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 return error(InvalidGrant); } let ttl = Duration::minutes(5); - let (access_token, refresh_token) = { + let (access_token_str, refresh_token_str) = { let mut rng = thread_rng(); ( AccessToken.generate(&mut rng), @@ -280,37 +271,29 @@ async fn refresh_token_grant( ) }; - let access_token = add_access_token( - &mut txn, - refresh_token_lookup.oauth2_session_id, - &access_token, - ttl, - ) - .await - .wrap_error()?; - - let refresh_token = add_refresh_token( - &mut txn, - refresh_token_lookup.oauth2_session_id, - access_token.id, - &refresh_token, - ) - .await - .wrap_error()?; - - replace_refresh_token(&mut txn, refresh_token_lookup.id, refresh_token.id) + let new_access_token = add_access_token(&mut txn, session.data, &access_token_str, ttl) .await .wrap_error()?; - if let Some(access_token_id) = refresh_token_lookup.oauth2_access_token_id { - revoke_access_token(&mut txn, access_token_id) + let new_refresh_token = + add_refresh_token(&mut txn, session.data, new_access_token, &refresh_token_str) + .await + .wrap_error()?; + + replace_refresh_token(&mut txn, &refresh_token, &new_refresh_token) + .await + .wrap_error()?; + + if let Some(access_token) = refresh_token.access_token { + revoke_access_token(&mut txn, access_token.data) .await .wrap_error()?; } - let params = AccessTokenResponse::new(access_token.token) + let params = AccessTokenResponse::new(access_token_str) .with_expires_in(ttl) - .with_refresh_token(refresh_token.token); + .with_refresh_token(refresh_token_str) + .with_scope(session.scope); txn.commit().await.wrap_error()?; diff --git a/crates/core/src/handlers/oauth2/userinfo.rs b/crates/core/src/handlers/oauth2/userinfo.rs index 9d0cafb9..26cec4d4 100644 --- a/crates/core/src/handlers/oauth2/userinfo.rs +++ b/crates/core/src/handlers/oauth2/userinfo.rs @@ -13,6 +13,7 @@ // limitations under the License. use hyper::Method; +use mas_data_model::{AccessToken, Session}; use serde::Serialize; use sqlx::PgPool; use warp::{Filter, Rejection, Reply}; @@ -23,12 +24,13 @@ use crate::{ authenticate::{authentication, recover_unauthorized}, cors::cors, }, - storage::oauth2::access_token::OAuth2AccessTokenLookup, + storage::PostgresqlBackend, }; #[derive(Serialize)] struct UserInfo { sub: String, + username: String, } pub(super) fn filter( @@ -46,8 +48,14 @@ pub(super) fn filter( ) } -async fn userinfo(token: OAuth2AccessTokenLookup) -> Result { +async fn userinfo( + _token: AccessToken, + session: Session, +) -> Result { + // TODO: we really should not have an Option here + let user = session.browser_session.unwrap().user; Ok(warp::reply::json(&UserInfo { - sub: token.username, + sub: user.sub, + username: user.username, })) } diff --git a/crates/core/src/storage/mod.rs b/crates/core/src/storage/mod.rs index 522faa17..641e6e00 100644 --- a/crates/core/src/storage/mod.rs +++ b/crates/core/src/storage/mod.rs @@ -16,13 +16,14 @@ #![allow(clippy::used_underscore_binding)] // This is needed by sqlx macros +use chrono::{DateTime, Utc}; use mas_data_model::{StorageBackend, StorageBackendMarker}; use serde::Serialize; use sqlx::migrate::Migrator; use thiserror::Error; #[derive(Debug, Error)] -#[error("databse query returned an inconsistent state")] +#[error("database query returned an inconsistent state")] pub struct DatabaseInconsistencyError; #[derive(Serialize, Debug, Clone, PartialEq)] @@ -34,12 +35,18 @@ impl StorageBackend for PostgresqlBackend { type AuthorizationCodeData = i64; type BrowserSessionData = i64; type ClientData = (); + type RefreshTokenData = i64; type SessionData = i64; type UserData = i64; } impl StorageBackendMarker for PostgresqlBackend {} +struct IdAndCreationTime { + id: i64, + created_at: DateTime, +} + pub mod oauth2; pub mod user; diff --git a/crates/core/src/storage/oauth2/access_token.rs b/crates/core/src/storage/oauth2/access_token.rs index b1458837..c43b25fc 100644 --- a/crates/core/src/storage/oauth2/access_token.rs +++ b/crates/core/src/storage/oauth2/access_token.rs @@ -16,89 +16,108 @@ use std::convert::TryFrom; use anyhow::Context; use chrono::{DateTime, Duration, Utc}; -use serde::Serialize; -use sqlx::{Executor, FromRow, Postgres}; +use mas_data_model::{AccessToken, Authentication, BrowserSession, Client, Session, User}; +use sqlx::PgExecutor; use thiserror::Error; -#[derive(FromRow, Serialize)] -pub struct OAuth2AccessToken { - pub id: i64, - pub oauth2_session_id: i64, - pub token: String, - expires_after: i32, - created_at: DateTime, -} +use crate::storage::{DatabaseInconsistencyError, IdAndCreationTime, PostgresqlBackend}; pub async fn add_access_token( - executor: impl Executor<'_, Database = Postgres>, + executor: impl PgExecutor<'_>, oauth2_session_id: i64, token: &str, expires_after: Duration, -) -> anyhow::Result { +) -> anyhow::Result> { // Checked convertion of duration to i32, maxing at i32::MAX - let expires_after = i32::try_from(expires_after.num_seconds()).unwrap_or(i32::MAX); + let expires_after_seconds = i32::try_from(expires_after.num_seconds()).unwrap_or(i32::MAX); - sqlx::query_as!( - OAuth2AccessToken, + let res = sqlx::query_as!( + IdAndCreationTime, r#" INSERT INTO oauth2_access_tokens (oauth2_session_id, token, expires_after) VALUES ($1, $2, $3) RETURNING - id, oauth2_session_id, token, expires_after, created_at + id, created_at "#, oauth2_session_id, token, - expires_after, + expires_after_seconds, ) .fetch_one(executor) .await - .context("could not insert oauth2 access token") + .context("could not insert oauth2 access token")?; + + Ok(AccessToken { + data: res.id, + expires_after, + token: token.to_string(), + jti: format!("{}", res.id), + created_at: res.created_at, + }) } #[derive(Debug)] pub struct OAuth2AccessTokenLookup { - pub active: bool, - pub username: String, - pub client_id: String, - pub scope: String, - pub created_at: DateTime, - expires_after: i32, -} - -impl OAuth2AccessTokenLookup { - #[must_use] - pub fn exp(&self) -> DateTime { - self.created_at + Duration::seconds(i64::from(self.expires_after)) - } + access_token_id: i64, + access_token: String, + access_token_expires_after: i32, + access_token_created_at: DateTime, + session_id: i64, + client_id: String, + scope: String, + redirect_uri: String, + nonce: Option, + user_session_id: i64, + user_session_created_at: DateTime, + user_id: i64, + user_username: String, + user_session_last_authentication_id: Option, + user_session_last_authentication_created_at: Option>, } #[derive(Debug, Error)] #[error("failed to lookup access token")] -pub struct AccessTokenLookupError(#[from] sqlx::Error); +pub enum AccessTokenLookupError { + Database(#[from] sqlx::Error), + Inconsistency(#[from] DatabaseInconsistencyError), +} impl AccessTokenLookupError { #[must_use] pub fn not_found(&self) -> bool { - matches!(self.0, sqlx::Error::RowNotFound) + matches!( + self, + &AccessTokenLookupError::Database(sqlx::Error::RowNotFound) + ) } } -pub async fn lookup_access_token( - executor: impl Executor<'_, Database = Postgres>, +pub async fn lookup_active_access_token( + executor: impl PgExecutor<'_>, token: &str, -) -> Result { +) -> Result<(AccessToken, Session), AccessTokenLookupError> { let res = sqlx::query_as!( OAuth2AccessTokenLookup, r#" SELECT - u.username AS "username!", - us.active AS "active!", - os.client_id AS "client_id!", - os.scope AS "scope!", - at.created_at AS "created_at!", - at.expires_after AS "expires_after!" + at.id AS "access_token_id", + at.token AS "access_token", + at.expires_after AS "access_token_expires_after", + at.created_at AS "access_token_created_at", + os.id AS "session_id!", + os.client_id AS "client_id!", + os.scope AS "scope!", + os.redirect_uri AS "redirect_uri!", + os.nonce AS "nonce", + us.id AS "user_session_id!", + us.created_at AS "user_session_created_at!", + u.id AS "user_id!", + u.username AS "user_username!", + usa.id AS "user_session_last_authentication_id?", + usa.created_at AS "user_session_last_authentication_created_at?" + FROM oauth2_access_tokens at INNER JOIN oauth2_sessions os ON os.id = at.oauth2_session_id @@ -106,20 +125,79 @@ pub async fn lookup_access_token( ON us.id = os.user_session_id INNER JOIN users u ON u.id = us.user_id + LEFT JOIN user_session_authentications usa + ON usa.session_id = us.id + WHERE at.token = $1 + AND at.created_at + (at.expires_after * INTERVAL '1 second') >= now() + AND us.active + + ORDER BY usa.created_at DESC + LIMIT 1 "#, token, ) .fetch_one(executor) .await?; - Ok(res) + let access_token = AccessToken { + data: res.access_token_id, + jti: format!("{}", res.access_token_id), + token: res.access_token, + created_at: res.access_token_created_at, + expires_after: Duration::seconds(res.access_token_expires_after.into()), + }; + + let client = Client { + data: (), + client_id: res.client_id, + }; + + let user = User { + data: res.user_id, + username: res.user_username, + sub: format!("fake-sub-{}", res.user_id), + }; + + let last_authentication = match ( + res.user_session_last_authentication_id, + res.user_session_last_authentication_created_at, + ) { + (None, None) => None, + (Some(id), Some(created_at)) => Some(Authentication { + data: id, + created_at, + }), + _ => return Err(DatabaseInconsistencyError.into()), + }; + + let browser_session = Some(BrowserSession { + data: res.user_session_id, + created_at: res.user_session_created_at, + user, + last_authentication, + }); + + let scope = res.scope.parse().map_err(|_e| DatabaseInconsistencyError)?; + + let redirect_uri = res + .redirect_uri + .parse() + .map_err(|_e| DatabaseInconsistencyError)?; + + let session = Session { + data: res.session_id, + client, + browser_session, + scope, + redirect_uri, + nonce: res.nonce, + }; + + Ok((access_token, session)) } -pub async fn revoke_access_token( - executor: impl Executor<'_, Database = Postgres>, - id: i64, -) -> anyhow::Result<()> { +pub async fn revoke_access_token(executor: impl PgExecutor<'_>, id: i64) -> anyhow::Result<()> { let res = sqlx::query!( r#" DELETE FROM oauth2_access_tokens @@ -138,9 +216,7 @@ pub async fn revoke_access_token( } } -pub async fn cleanup_expired( - executor: impl Executor<'_, Database = Postgres>, -) -> anyhow::Result { +pub async fn cleanup_expired(executor: impl PgExecutor<'_>) -> anyhow::Result { let res = sqlx::query!( r#" DELETE FROM oauth2_access_tokens diff --git a/crates/core/src/storage/oauth2/authorization_code.rs b/crates/core/src/storage/oauth2/authorization_code.rs index a8ede24b..5b4b9667 100644 --- a/crates/core/src/storage/oauth2/authorization_code.rs +++ b/crates/core/src/storage/oauth2/authorization_code.rs @@ -13,40 +13,33 @@ // limitations under the License. use anyhow::Context; +use chrono::{DateTime, Utc}; +use mas_data_model::{ + Authentication, AuthorizationCode, BrowserSession, Client, Pkce, Session, User, +}; use oauth2_types::pkce; -use serde::Serialize; -use sqlx::{Executor, FromRow, Postgres}; +use sqlx::PgExecutor; use thiserror::Error; use warp::reject::Reject; -#[derive(FromRow, Serialize)] -pub struct OAuth2Code { - id: i64, - oauth2_session_id: i64, - pub code: String, - code_challenge: Option, - code_challenge_method: Option, -} +use crate::storage::{DatabaseInconsistencyError, PostgresqlBackend}; pub async fn add_code( - executor: impl Executor<'_, Database = Postgres>, + executor: impl PgExecutor<'_>, oauth2_session_id: i64, code: &str, - code_challenge: &Option, -) -> anyhow::Result { - let code_challenge_method = code_challenge - .as_ref() - .map(|c| c.code_challenge_method as i16); - let code_challenge = code_challenge.as_ref().map(|c| &c.code_challenge); - sqlx::query_as!( - OAuth2Code, + pkce: &Option, +) -> anyhow::Result> { + let code_challenge_method = pkce.as_ref().map(|c| c.code_challenge_method as i16); + let code_challenge = pkce.as_ref().map(|c| &c.code_challenge); + let id = sqlx::query_scalar!( r#" INSERT INTO oauth2_codes (oauth2_session_id, code, code_challenge_method, code_challenge) VALUES ($1, $2, $3, $4) RETURNING - id, oauth2_session_id, code, code_challenge_method, code_challenge + id "#, oauth2_session_id, code, @@ -55,38 +48,108 @@ pub async fn add_code( ) .fetch_one(executor) .await - .context("could not insert oauth2 authorization code") + .context("could not insert oauth2 authorization code")?; + + let pkce = pkce + .as_ref() + .map(|c| Pkce::new(c.code_challenge_method, c.code_challenge.clone())); + + Ok(AuthorizationCode { + data: id, + code: code.to_string(), + pkce, + }) } -pub struct OAuth2CodeLookup { - pub id: i64, - pub oauth2_session_id: i64, - pub client_id: String, - pub redirect_uri: String, - pub scope: String, - pub nonce: Option, - pub code_challenge: Option, - pub code_challenge_method: Option, +struct OAuth2CodeLookup { + id: i64, + oauth2_session_id: i64, + client_id: String, + redirect_uri: String, + scope: String, + nonce: Option, + code_challenge: Option, + code_challenge_method: Option, + user_session_id: Option, + user_session_created_at: Option>, + user_id: Option, + user_username: Option, + user_session_last_authentication_id: Option, + user_session_last_authentication_created_at: Option>, +} + +fn browser_session_from_database( + user_session_id: Option, + user_session_created_at: Option>, + user_id: Option, + user_username: Option, + user_session_last_authentication_id: Option, + user_session_last_authentication_created_at: Option>, +) -> Result>, DatabaseInconsistencyError> { + match ( + user_session_id, + user_session_created_at, + user_id, + user_username, + ) { + (None, None, None, None) => Ok(None), + (Some(session_id), Some(session_created_at), Some(user_id), Some(user_username)) => { + let user = User { + data: user_id, + username: user_username, + sub: format!("fake-sub-{}", user_id), + }; + + let last_authentication = match ( + user_session_last_authentication_id, + user_session_last_authentication_created_at, + ) { + (None, None) => None, + (Some(id), Some(created_at)) => Some(Authentication { + data: id, + created_at, + }), + _ => return Err(DatabaseInconsistencyError), + }; + + Ok(Some(BrowserSession { + data: session_id, + created_at: session_created_at, + user, + last_authentication, + })) + } + _ => Err(DatabaseInconsistencyError), + } } #[derive(Debug, Error)] #[error("failed to lookup oauth2 code")] -pub struct CodeLookupError(#[from] sqlx::Error); +pub enum CodeLookupError { + Database(#[from] sqlx::Error), + Inconsistency(#[from] DatabaseInconsistencyError), +} impl Reject for CodeLookupError {} impl CodeLookupError { #[must_use] pub fn not_found(&self) -> bool { - matches!(self.0, sqlx::Error::RowNotFound) + matches!(self, &CodeLookupError::Database(sqlx::Error::RowNotFound)) } } +#[allow(clippy::too_many_lines)] pub async fn lookup_code( - executor: impl Executor<'_, Database = Postgres>, + executor: impl PgExecutor<'_>, code: &str, -) -> Result { - // TODO: this should return a better type +) -> Result< + ( + AuthorizationCode, + Session, + ), + CodeLookupError, +> { let res = sqlx::query_as!( OAuth2CodeLookup, r#" @@ -94,27 +157,88 @@ pub async fn lookup_code( oc.id, oc.code_challenge, oc.code_challenge_method, - os.id AS "oauth2_session_id!", - os.client_id AS "client_id!", + os.id AS "oauth2_session_id!", + os.client_id AS "client_id!", os.redirect_uri, - os.scope AS "scope!", - os.nonce + os.scope AS "scope!", + os.nonce, + us.id AS "user_session_id?", + us.created_at AS "user_session_created_at?", + u.id AS "user_id?", + u.username AS "user_username?", + usa.id AS "user_session_last_authentication_id?", + usa.created_at AS "user_session_last_authentication_created_at?" FROM oauth2_codes oc INNER JOIN oauth2_sessions os ON os.id = oc.oauth2_session_id + LEFT JOIN user_sessions us + ON us.id = os.user_session_id + LEFT JOIN user_session_authentications usa + ON usa.session_id = us.id + LEFT JOIN users u + ON u.id = us.user_id WHERE oc.code = $1 + ORDER BY usa.created_at DESC + LIMIT 1 "#, code, ) .fetch_one(executor) .await?; - Ok(res) + let pkce = match (res.code_challenge_method, res.code_challenge) { + (None, None) => None, + (Some(0 /* Plain */), Some(challenge)) => { + Some(Pkce::new(pkce::CodeChallengeMethod::Plain, challenge)) + } + (Some(1 /* S256 */), Some(challenge)) => { + Some(Pkce::new(pkce::CodeChallengeMethod::S256, challenge)) + } + _ => return Err(DatabaseInconsistencyError.into()), + }; + + let code = AuthorizationCode { + data: res.id, + code: code.to_string(), + pkce, + }; + + let client = Client { + data: (), + client_id: res.client_id, + }; + + let browser_session = browser_session_from_database( + res.user_session_id, + res.user_session_created_at, + res.user_id, + res.user_username, + res.user_session_last_authentication_id, + res.user_session_last_authentication_created_at, + )?; + + let scope = res.scope.parse().map_err(|_e| DatabaseInconsistencyError)?; + + let redirect_uri = res + .redirect_uri + .parse() + .map_err(|_e| DatabaseInconsistencyError)?; + + let session = Session { + data: res.oauth2_session_id, + client, + browser_session, + scope, + redirect_uri, + nonce: res.nonce, + }; + + Ok((code, session)) } pub async fn consume_code( - executor: impl Executor<'_, Database = Postgres>, - code_id: i64, + executor: impl PgExecutor<'_>, + code: AuthorizationCode, ) -> anyhow::Result<()> { // TODO: mark the code as invalid instead to allow invalidating the whole // session on code reuse @@ -123,7 +247,7 @@ pub async fn consume_code( DELETE FROM oauth2_codes WHERE id = $1 "#, - code_id, + code.data, ) .execute(executor) .await diff --git a/crates/core/src/storage/oauth2/refresh_token.rs b/crates/core/src/storage/oauth2/refresh_token.rs index 2ab97477..6fb54c64 100644 --- a/crates/core/src/storage/oauth2/refresh_token.rs +++ b/crates/core/src/storage/oauth2/refresh_token.rs @@ -13,83 +13,192 @@ // limitations under the License. use anyhow::Context; -use chrono::{DateTime, Utc}; -use sqlx::{Executor, Postgres}; +use chrono::{DateTime, Duration, Utc}; +use mas_data_model::{ + AccessToken, Authentication, BrowserSession, Client, RefreshToken, Session, User, +}; +use sqlx::PgExecutor; -#[derive(Debug)] -pub struct OAuth2RefreshToken { - pub id: i64, - oauth2_session_id: i64, - oauth2_access_token_id: Option, - pub token: String, - next_token_id: Option, - created_at: DateTime, - updated_at: DateTime, -} +use crate::storage::{DatabaseInconsistencyError, IdAndCreationTime, PostgresqlBackend}; pub async fn add_refresh_token( - executor: impl Executor<'_, Database = Postgres>, + executor: impl PgExecutor<'_>, oauth2_session_id: i64, - oauth2_access_token_id: i64, + access_token: AccessToken, token: &str, -) -> anyhow::Result { - sqlx::query_as!( - OAuth2RefreshToken, +) -> anyhow::Result> { + let res = sqlx::query_as!( + IdAndCreationTime, r#" INSERT INTO oauth2_refresh_tokens (oauth2_session_id, oauth2_access_token_id, token) VALUES ($1, $2, $3) RETURNING - id, oauth2_session_id, oauth2_access_token_id, token, next_token_id, - created_at, updated_at + id, created_at "#, oauth2_session_id, - oauth2_access_token_id, + access_token.data, token, ) .fetch_one(executor) .await - .context("could not insert oauth2 refresh token") + .context("could not insert oauth2 refresh token")?; + + Ok(RefreshToken { + data: res.id, + token: token.to_string(), + access_token: Some(access_token), + created_at: res.created_at, + }) } -pub struct OAuth2RefreshTokenLookup { - pub id: i64, - pub oauth2_session_id: i64, - pub oauth2_access_token_id: Option, - pub client_id: String, - pub scope: String, +struct OAuth2RefreshTokenLookup { + refresh_token_id: i64, + refresh_token: String, + refresh_token_created_at: DateTime, + access_token_id: Option, + access_token: Option, + access_token_expires_after: Option, + access_token_created_at: Option>, + session_id: i64, + client_id: String, + scope: String, + redirect_uri: String, + nonce: Option, + user_session_id: i64, + user_session_created_at: DateTime, + user_id: i64, + user_username: String, + user_session_last_authentication_id: Option, + user_session_last_authentication_created_at: Option>, } -pub async fn lookup_refresh_token( - executor: impl Executor<'_, Database = Postgres>, +#[allow(clippy::too_many_lines)] +pub async fn lookup_active_refresh_token( + executor: impl PgExecutor<'_>, token: &str, -) -> anyhow::Result { - sqlx::query_as!( +) -> anyhow::Result<(RefreshToken, Session)> { + let res = sqlx::query_as!( OAuth2RefreshTokenLookup, r#" SELECT - rt.id, - rt.oauth2_session_id, - rt.oauth2_access_token_id, - os.client_id AS "client_id!", - os.scope AS "scope!" + rt.id AS refresh_token_id, + rt.token AS refresh_token, + rt.created_at AS refresh_token_created_at, + at.id AS "access_token_id?", + at.token AS "access_token?", + at.expires_after AS "access_token_expires_after?", + at.created_at AS "access_token_created_at?", + os.id AS "session_id!", + os.client_id AS "client_id!", + os.scope AS "scope!", + os.redirect_uri AS "redirect_uri!", + os.nonce AS "nonce", + us.id AS "user_session_id!", + us.created_at AS "user_session_created_at!", + u.id AS "user_id!", + u.username AS "user_username!", + usa.id AS "user_session_last_authentication_id?", + usa.created_at AS "user_session_last_authentication_created_at?" FROM oauth2_refresh_tokens rt + LEFT JOIN oauth2_access_tokens at + ON at.id = rt.oauth2_access_token_id INNER JOIN oauth2_sessions os ON os.id = rt.oauth2_session_id - WHERE rt.token = $1 AND rt.next_token_id IS NULL + INNER JOIN user_sessions us + ON us.id = os.user_session_id + INNER JOIN users u + ON u.id = us.user_id + LEFT JOIN user_session_authentications usa + ON usa.session_id = us.id + + WHERE rt.token = $1 + AND rt.next_token_id IS NULL + AND us.active + + ORDER BY usa.created_at DESC + LIMIT 1 "#, token, ) .fetch_one(executor) .await - .context("failed to fetch oauth2 refresh token") + .context("failed to fetch oauth2 refresh token")?; + + let access_token = match ( + res.access_token_id, + res.access_token, + res.access_token_created_at, + res.access_token_expires_after, + ) { + (None, None, None, None) => None, + (Some(id), Some(token), Some(created_at), Some(expires_after)) => Some(AccessToken { + data: id, + jti: format!("{}", id), + token, + created_at, + expires_after: Duration::seconds(expires_after.into()), + }), + _ => return Err(DatabaseInconsistencyError.into()), + }; + + let refresh_token = RefreshToken { + data: res.refresh_token_id, + token: res.refresh_token, + created_at: res.refresh_token_created_at, + access_token, + }; + + let client = Client { + data: (), + client_id: res.client_id, + }; + + let user = User { + data: res.user_id, + username: res.user_username, + sub: format!("fake-sub-{}", res.user_id), + }; + + let last_authentication = match ( + res.user_session_last_authentication_id, + res.user_session_last_authentication_created_at, + ) { + (None, None) => None, + (Some(id), Some(created_at)) => Some(Authentication { + data: id, + created_at, + }), + _ => return Err(DatabaseInconsistencyError.into()), + }; + + let browser_session = Some(BrowserSession { + data: res.user_session_id, + created_at: res.user_session_created_at, + user, + last_authentication, + }); + + let session = Session { + data: res.session_id, + client, + browser_session, + scope: res.scope.parse().context("invalid scope in database")?, + redirect_uri: res + .redirect_uri + .parse() + .context("invalid redirect_uri in database")?, + nonce: res.nonce, + }; + + Ok((refresh_token, session)) } pub async fn replace_refresh_token( - executor: impl Executor<'_, Database = Postgres>, - refresh_token_id: i64, - next_refresh_token_id: i64, + executor: impl PgExecutor<'_>, + refresh_token: &RefreshToken, + next_refresh_token: &RefreshToken, ) -> anyhow::Result<()> { let res = sqlx::query!( r#" @@ -97,8 +206,8 @@ pub async fn replace_refresh_token( SET next_token_id = $2 WHERE id = $1 "#, - refresh_token_id, - next_refresh_token_id + refresh_token.data, + next_refresh_token.data ) .execute(executor) .await diff --git a/crates/core/src/storage/oauth2/session.rs b/crates/core/src/storage/oauth2/session.rs index 7af4c700..9cac2438 100644 --- a/crates/core/src/storage/oauth2/session.rs +++ b/crates/core/src/storage/oauth2/session.rs @@ -17,19 +17,19 @@ use std::{collections::HashSet, convert::TryFrom, str::FromStr, string::ToString use anyhow::Context; use chrono::{DateTime, Duration, Utc}; use itertools::Itertools; -use mas_data_model::BrowserSession; +use mas_data_model::{AuthorizationCode, BrowserSession}; use oauth2_types::{ pkce, requests::{ResponseMode, ResponseType}, }; use serde::Serialize; -use sqlx::{Executor, FromRow, Postgres}; +use sqlx::PgExecutor; use url::Url; -use super::authorization_code::{add_code, OAuth2Code}; +use super::authorization_code::add_code; use crate::storage::{lookup_active_session, PostgresqlBackend}; -#[derive(FromRow, Serialize)] +#[derive(Serialize)] pub struct OAuth2Session { pub id: i64, user_session_id: Option, @@ -49,16 +49,16 @@ pub struct OAuth2Session { impl OAuth2Session { pub async fn add_code<'e>( &self, - executor: impl Executor<'e, Database = Postgres>, + executor: impl PgExecutor<'e>, code: &str, code_challenge: &Option, - ) -> anyhow::Result { + ) -> anyhow::Result> { add_code(executor, self.id, code, code_challenge).await } pub async fn fetch_session( &self, - executor: impl Executor<'_, Database = Postgres>, + executor: impl PgExecutor<'_>, ) -> anyhow::Result>> { match self.user_session_id { Some(id) => { @@ -70,16 +70,13 @@ impl OAuth2Session { } } - pub async fn fetch_code( - &self, - executor: impl Executor<'_, Database = Postgres>, - ) -> anyhow::Result { + pub async fn fetch_code(&self, executor: impl PgExecutor<'_>) -> anyhow::Result { get_code_for_session(executor, self.id).await } pub async fn match_or_set_session( &mut self, - executor: impl Executor<'_, Database = Postgres>, + executor: impl PgExecutor<'_>, session: BrowserSession, ) -> anyhow::Result> { match self.user_session_id { @@ -130,7 +127,7 @@ impl OAuth2Session { #[allow(clippy::too_many_arguments)] pub async fn start_session( - executor: impl Executor<'_, Database = Postgres>, + executor: impl PgExecutor<'_>, optional_session_id: Option, client_id: &str, redirect_uri: &Url, @@ -178,7 +175,7 @@ pub async fn start_session( } pub async fn get_session_by_id( - executor: impl Executor<'_, Database = Postgres>, + executor: impl PgExecutor<'_>, oauth2_session_id: i64, ) -> anyhow::Result { sqlx::query_as!( @@ -198,7 +195,7 @@ pub async fn get_session_by_id( } pub async fn get_code_for_session( - executor: impl Executor<'_, Database = Postgres>, + executor: impl PgExecutor<'_>, oauth2_session_id: i64, ) -> anyhow::Result { sqlx::query_scalar!( diff --git a/crates/core/src/storage/user.rs b/crates/core/src/storage/user.rs index 3f8fe2fc..93f3cf99 100644 --- a/crates/core/src/storage/user.rs +++ b/crates/core/src/storage/user.rs @@ -20,15 +20,16 @@ use chrono::{DateTime, Utc}; use mas_data_model::{errors::HtmlError, Authentication, BrowserSession, User}; use password_hash::{PasswordHash, PasswordHasher, SaltString}; use rand::rngs::OsRng; -use sqlx::{Acquire, Executor, FromRow, Postgres, Transaction}; +use sqlx::{Acquire, PgExecutor, Postgres, Transaction}; use thiserror::Error; use tokio::task; use tracing::{info_span, Instrument}; use warp::reject::Reject; use super::{DatabaseInconsistencyError, PostgresqlBackend}; +use crate::storage::IdAndCreationTime; -#[derive(Debug, Clone, FromRow)] +#[derive(Debug, Clone)] struct UserLookup { pub id: i64, pub username: String, @@ -159,7 +160,7 @@ impl TryInto> for SessionLookup { } pub async fn lookup_active_session( - executor: impl Executor<'_, Database = Postgres>, + executor: impl PgExecutor<'_>, id: i64, ) -> Result, ActiveSessionLookupError> { let res = sqlx::query_as!( @@ -190,18 +191,12 @@ pub async fn lookup_active_session( Ok(res) } -#[derive(FromRow)] -struct SessionStartResult { - id: i64, - created_at: DateTime, -} - pub async fn start_session( - executor: impl Executor<'_, Database = Postgres>, + executor: impl PgExecutor<'_>, user: User, ) -> anyhow::Result> { let res = sqlx::query_as!( - SessionStartResult, + IdAndCreationTime, r#" INSERT INTO user_sessions (user_id) VALUES ($1) @@ -238,12 +233,6 @@ pub enum AuthenticationError { Internal(#[from] tokio::task::JoinError), } -#[derive(FromRow)] -struct AuthenticationInsertionResult { - id: i64, - created_at: DateTime, -} - pub async fn authenticate_session( txn: &mut Transaction<'_, Postgres>, session: &BrowserSession, @@ -277,7 +266,7 @@ pub async fn authenticate_session( // That went well, let's insert the auth info let res = sqlx::query_as!( - AuthenticationInsertionResult, + IdAndCreationTime, r#" INSERT INTO user_session_authentications (session_id) VALUES ($1) @@ -296,7 +285,7 @@ pub async fn authenticate_session( } pub async fn register_user( - executor: impl Executor<'_, Database = Postgres>, + executor: impl PgExecutor<'_>, phf: impl PasswordHasher, username: &str, password: &str, @@ -326,7 +315,7 @@ pub async fn register_user( } pub async fn end_session( - executor: impl Executor<'_, Database = Postgres>, + executor: impl PgExecutor<'_>, session: &BrowserSession, ) -> anyhow::Result<()> { let res = sqlx::query!( @@ -346,7 +335,7 @@ pub async fn end_session( } pub async fn lookup_user_by_username( - executor: impl Executor<'_, Database = Postgres>, + executor: impl PgExecutor<'_>, username: &str, ) -> Result, sqlx::Error> { let res = sqlx::query_as!( diff --git a/crates/data-model/Cargo.toml b/crates/data-model/Cargo.toml index 4bb778f1..fa55b3a5 100644 --- a/crates/data-model/Cargo.toml +++ b/crates/data-model/Cargo.toml @@ -9,5 +9,6 @@ license = "Apache-2.0" chrono = "0.4.19" thiserror = "1.0.30" serde = "1.0.130" +url = { version = "2.2.2", features = ["serde"] } oauth2-types = { path = "../oauth2-types" } diff --git a/crates/data-model/src/lib.rs b/crates/data-model/src/lib.rs index 827294d9..a59e786c 100644 --- a/crates/data-model/src/lib.rs +++ b/crates/data-model/src/lib.rs @@ -15,6 +15,7 @@ use chrono::{DateTime, Duration, Utc}; use oauth2_types::{pkce::CodeChallengeMethod, scope::Scope}; use serde::Serialize; +use url::Url; pub mod errors; @@ -28,6 +29,7 @@ pub trait StorageBackend { type SessionData: Clone + std::fmt::Debug + PartialEq; type AuthorizationCodeData: Clone + std::fmt::Debug + PartialEq; type AccessTokenData: Clone + std::fmt::Debug + PartialEq; + type RefreshTokenData: Clone + std::fmt::Debug + PartialEq; } impl StorageBackend for () { @@ -36,6 +38,7 @@ impl StorageBackend for () { type AuthorizationCodeData = (); type BrowserSessionData = (); type ClientData = (); + type RefreshTokenData = (); type SessionData = (); type UserData = (); } @@ -153,6 +156,8 @@ pub struct Session { pub browser_session: Option>, pub client: Client, pub scope: Scope, + pub redirect_uri: Url, + pub nonce: Option, } impl From> for Session<()> { @@ -162,6 +167,8 @@ impl From> for Session<()> { browser_session: s.browser_session.map(Into::into), client: s.client.into(), scope: s.scope, + redirect_uri: s.redirect_uri, + nonce: s.nonce, } } } @@ -191,7 +198,7 @@ pub struct AuthorizationCode { #[serde(skip_serializing)] pub data: T::AuthorizationCodeData, pub code: String, - pub pkce: Pkce, + pub pkce: Option, } impl From> for AuthorizationCode<()> { @@ -224,3 +231,28 @@ impl From> for AccessToken<()> { } } } + +impl AccessToken { + pub fn exp(&self) -> DateTime { + self.created_at + self.expires_after + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct RefreshToken { + pub data: T::RefreshTokenData, + pub token: String, + pub created_at: DateTime, + pub access_token: Option>, +} + +impl From> for RefreshToken<()> { + fn from(t: RefreshToken) -> Self { + RefreshToken { + data: (), + token: t.token, + created_at: t.created_at, + access_token: t.access_token.map(Into::into), + } + } +}