1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Refactor DB interactions for OAuth code and tokens

This ensures complex types like scopes are properly parsed back from the
database.
This commit is contained in:
Quentin Gliech
2021-10-19 19:18:25 +02:00
parent 617ab83ab2
commit 4307276b0e
16 changed files with 947 additions and 529 deletions

5
Cargo.lock generated
View File

@ -1520,7 +1520,6 @@ dependencies = [
"crc", "crc",
"data-encoding", "data-encoding",
"elliptic-curve", "elliptic-curve",
"figment",
"futures-util", "futures-util",
"headers", "headers",
"hyper", "hyper",
@ -1539,15 +1538,12 @@ dependencies = [
"pkcs8", "pkcs8",
"rand 0.8.4", "rand 0.8.4",
"rsa", "rsa",
"schemars",
"serde", "serde",
"serde_json", "serde_json",
"serde_urlencoded", "serde_urlencoded",
"serde_with", "serde_with",
"serde_yaml",
"sha2", "sha2",
"sqlx", "sqlx",
"tera",
"thiserror", "thiserror",
"tokio", "tokio",
"tokio-stream", "tokio-stream",
@ -1564,6 +1560,7 @@ dependencies = [
"oauth2-types", "oauth2-types",
"serde", "serde",
"thiserror", "thiserror",
"url",
] ]
[[package]] [[package]]

View File

@ -24,23 +24,15 @@ anyhow = "1.0.44"
warp = "0.3.1" warp = "0.3.1"
hyper = { version = "0.14.13", features = ["full"] } hyper = { version = "0.14.13", features = ["full"] }
# Template engine
tera = "1.13.0"
# Database access # Database access
sqlx = { version = "0.5.9", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "offline"] } sqlx = { version = "0.5.9", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "offline"] }
# Various structure (de)serialization # Various structure (de)serialization
serde = { version = "1.0.130", features = ["derive"] } serde = { version = "1.0.130", features = ["derive"] }
serde_yaml = "0.8.21"
serde_with = { version = "1.10.0", features = ["hex", "chrono"] } serde_with = { version = "1.10.0", features = ["hex", "chrono"] }
serde_json = "1.0.68" serde_json = "1.0.68"
serde_urlencoded = "0.7.0" serde_urlencoded = "0.7.0"
# Argument & config parsing
figment = { version = "0.10.6", features = ["env", "yaml", "test"] }
schemars = { version = "0.8.6", features = ["url", "chrono"] }
# Password hashing # Password hashing
argon2 = { version = "0.3.1", features = ["password-hash"] } argon2 = { version = "0.3.1", features = ["password-hash"] }
password-hash = { version = "0.3.2", features = ["std"] } password-hash = { version = "0.3.2", features = ["std"] }

View File

@ -26,8 +26,8 @@
] ]
} }
}, },
"138c3297a66107d8428ca10d04f9a4dd75faf9c1d3f84bcedd3b09f55dd84206": { "0c056fcc1a85d00db88034bcc582376cf220e1933d2932e520c44ed9931f5c9d": {
"query": "\n INSERT INTO oauth2_codes\n (oauth2_session_id, code, code_challenge_method, code_challenge)\n VALUES\n ($1, $2, $3, $4)\n RETURNING\n id, oauth2_session_id, code, code_challenge_method, code_challenge\n ", "query": "\n INSERT INTO oauth2_refresh_tokens\n (oauth2_session_id, oauth2_access_token_id, token)\n VALUES\n ($1, $2, $3)\n RETURNING\n id, created_at\n ",
"describe": { "describe": {
"columns": [ "columns": [
{ {
@ -37,39 +37,20 @@
}, },
{ {
"ordinal": 1, "ordinal": 1,
"name": "oauth2_session_id", "name": "created_at",
"type_info": "Int8" "type_info": "Timestamptz"
},
{
"ordinal": 2,
"name": "code",
"type_info": "Text"
},
{
"ordinal": 3,
"name": "code_challenge_method",
"type_info": "Int2"
},
{
"ordinal": 4,
"name": "code_challenge",
"type_info": "Text"
} }
], ],
"parameters": { "parameters": {
"Left": [ "Left": [
"Int8", "Int8",
"Text", "Int8",
"Int2",
"Text" "Text"
] ]
}, },
"nullable": [ "nullable": [
false, false,
false, false
false,
true,
true
] ]
} }
}, },
@ -93,6 +74,104 @@
] ]
} }
}, },
"282548c5ad51bd95b7d9ad290714bab5860f1e1291021e7d786dc926d12b5dd9": {
"query": "\n SELECT\n oc.id,\n oc.code_challenge,\n oc.code_challenge_method,\n os.id AS \"oauth2_session_id!\",\n os.client_id AS \"client_id!\",\n os.redirect_uri,\n os.scope AS \"scope!\",\n os.nonce,\n us.id AS \"user_session_id?\",\n us.created_at AS \"user_session_created_at?\",\n u.id AS \"user_id?\",\n u.username AS \"user_username?\",\n usa.id AS \"user_session_last_authentication_id?\",\n usa.created_at AS \"user_session_last_authentication_created_at?\"\n FROM oauth2_codes oc\n INNER JOIN oauth2_sessions os\n ON os.id = oc.oauth2_session_id\n LEFT JOIN user_sessions us\n ON us.id = os.user_session_id\n LEFT JOIN user_session_authentications usa\n ON usa.session_id = us.id\n LEFT JOIN users u\n ON u.id = us.user_id\n WHERE oc.code = $1\n ORDER BY usa.created_at DESC\n LIMIT 1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Int8"
},
{
"ordinal": 1,
"name": "code_challenge",
"type_info": "Text"
},
{
"ordinal": 2,
"name": "code_challenge_method",
"type_info": "Int2"
},
{
"ordinal": 3,
"name": "oauth2_session_id!",
"type_info": "Int8"
},
{
"ordinal": 4,
"name": "client_id!",
"type_info": "Text"
},
{
"ordinal": 5,
"name": "redirect_uri",
"type_info": "Text"
},
{
"ordinal": 6,
"name": "scope!",
"type_info": "Text"
},
{
"ordinal": 7,
"name": "nonce",
"type_info": "Text"
},
{
"ordinal": 8,
"name": "user_session_id?",
"type_info": "Int8"
},
{
"ordinal": 9,
"name": "user_session_created_at?",
"type_info": "Timestamptz"
},
{
"ordinal": 10,
"name": "user_id?",
"type_info": "Int8"
},
{
"ordinal": 11,
"name": "user_username?",
"type_info": "Text"
},
{
"ordinal": 12,
"name": "user_session_last_authentication_id?",
"type_info": "Int8"
},
{
"ordinal": 13,
"name": "user_session_last_authentication_created_at?",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Text"
]
},
"nullable": [
false,
true,
true,
false,
false,
false,
false,
true,
false,
false,
false,
false,
false,
false
]
}
},
"307fd9f71e7a94a0a0d9ce523ee9792e127485d0d12480c43f179dd9b75afbab": { "307fd9f71e7a94a0a0d9ce523ee9792e127485d0d12480c43f179dd9b75afbab": {
"query": "\n INSERT INTO user_sessions (user_id)\n VALUES ($1)\n RETURNING id, created_at\n ", "query": "\n INSERT INTO user_sessions (user_id)\n VALUES ($1)\n RETURNING id, created_at\n ",
"describe": { "describe": {
@ -119,58 +198,31 @@
] ]
} }
}, },
"49888f812910633b87ce65c277f8969377fe264be154d8aa6b33d861d26d2b3b": { "47a7a8d2ef7db8bb1d41230626ded4e4661d488891fbda9b872c0749a9ba58f4": {
"query": "\n SELECT\n u.username AS \"username!\",\n us.active AS \"active!\",\n os.client_id AS \"client_id!\",\n os.scope AS \"scope!\",\n at.created_at AS \"created_at!\",\n at.expires_after AS \"expires_after!\"\n FROM oauth2_access_tokens at\n INNER JOIN oauth2_sessions os\n ON os.id = at.oauth2_session_id\n INNER JOIN user_sessions us\n ON us.id = os.user_session_id\n INNER JOIN users u\n ON u.id = us.user_id\n WHERE at.token = $1\n ", "query": "\n INSERT INTO oauth2_codes\n (oauth2_session_id, code, code_challenge_method, code_challenge)\n VALUES\n ($1, $2, $3, $4)\n RETURNING\n id\n ",
"describe": { "describe": {
"columns": [ "columns": [
{ {
"ordinal": 0, "ordinal": 0,
"name": "username!", "name": "id",
"type_info": "Text" "type_info": "Int8"
},
{
"ordinal": 1,
"name": "active!",
"type_info": "Bool"
},
{
"ordinal": 2,
"name": "client_id!",
"type_info": "Text"
},
{
"ordinal": 3,
"name": "scope!",
"type_info": "Text"
},
{
"ordinal": 4,
"name": "created_at!",
"type_info": "Timestamptz"
},
{
"ordinal": 5,
"name": "expires_after!",
"type_info": "Int4"
} }
], ],
"parameters": { "parameters": {
"Left": [ "Left": [
"Int8",
"Text",
"Int2",
"Text" "Text"
] ]
}, },
"nullable": [ "nullable": [
false,
false,
false,
false,
false,
false false
] ]
} }
}, },
"562b0d4dcf857e99c20e9288e9c8bd46232290715c0d2459b0398a1c746cf65d": { "59e8a5de682642883a9b9fc1b522736fa4397f0a0c97074f2c8908e5956c0166": {
"query": "\n SELECT\n rt.id,\n rt.oauth2_session_id,\n rt.oauth2_access_token_id,\n os.client_id AS \"client_id!\",\n os.scope AS \"scope!\"\n FROM oauth2_refresh_tokens rt\n INNER JOIN oauth2_sessions os\n ON os.id = rt.oauth2_session_id\n WHERE rt.token = $1 AND rt.next_token_id IS NULL\n ", "query": "\n INSERT INTO oauth2_access_tokens\n (oauth2_session_id, token, expires_after)\n VALUES\n ($1, $2, $3)\n RETURNING\n id, created_at\n ",
"describe": { "describe": {
"columns": [ "columns": [
{ {
@ -180,23 +232,116 @@
}, },
{ {
"ordinal": 1, "ordinal": 1,
"name": "oauth2_session_id", "name": "created_at",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Int8",
"Text",
"Int4"
]
},
"nullable": [
false,
false
]
}
},
"5d032f4bdb28534da7cf8e9806442a12708d632b7be28f8b952bd3cb63a8b1af": {
"query": "\n SELECT\n rt.id AS refresh_token_id,\n rt.token AS refresh_token,\n rt.created_at AS refresh_token_created_at,\n at.id AS \"access_token_id?\",\n at.token AS \"access_token?\",\n at.expires_after AS \"access_token_expires_after?\",\n at.created_at AS \"access_token_created_at?\",\n os.id AS \"session_id!\",\n os.client_id AS \"client_id!\",\n os.scope AS \"scope!\",\n os.redirect_uri AS \"redirect_uri!\",\n os.nonce AS \"nonce\",\n us.id AS \"user_session_id!\",\n us.created_at AS \"user_session_created_at!\",\n u.id AS \"user_id!\",\n u.username AS \"user_username!\",\n usa.id AS \"user_session_last_authentication_id?\",\n usa.created_at AS \"user_session_last_authentication_created_at?\"\n FROM oauth2_refresh_tokens rt\n LEFT JOIN oauth2_access_tokens at\n ON at.id = rt.oauth2_access_token_id\n INNER JOIN oauth2_sessions os\n ON os.id = rt.oauth2_session_id\n INNER JOIN user_sessions us\n ON us.id = os.user_session_id\n INNER JOIN users u\n ON u.id = us.user_id\n LEFT JOIN user_session_authentications usa\n ON usa.session_id = us.id\n\n WHERE rt.token = $1\n AND rt.next_token_id IS NULL\n AND us.active\n\n ORDER BY usa.created_at DESC\n LIMIT 1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "refresh_token_id",
"type_info": "Int8" "type_info": "Int8"
}, },
{
"ordinal": 1,
"name": "refresh_token",
"type_info": "Text"
},
{ {
"ordinal": 2, "ordinal": 2,
"name": "oauth2_access_token_id", "name": "refresh_token_created_at",
"type_info": "Int8" "type_info": "Timestamptz"
}, },
{ {
"ordinal": 3, "ordinal": 3,
"name": "access_token_id?",
"type_info": "Int8"
},
{
"ordinal": 4,
"name": "access_token?",
"type_info": "Text"
},
{
"ordinal": 5,
"name": "access_token_expires_after?",
"type_info": "Int4"
},
{
"ordinal": 6,
"name": "access_token_created_at?",
"type_info": "Timestamptz"
},
{
"ordinal": 7,
"name": "session_id!",
"type_info": "Int8"
},
{
"ordinal": 8,
"name": "client_id!", "name": "client_id!",
"type_info": "Text" "type_info": "Text"
}, },
{ {
"ordinal": 4, "ordinal": 9,
"name": "scope!", "name": "scope!",
"type_info": "Text" "type_info": "Text"
},
{
"ordinal": 10,
"name": "redirect_uri!",
"type_info": "Text"
},
{
"ordinal": 11,
"name": "nonce",
"type_info": "Text"
},
{
"ordinal": 12,
"name": "user_session_id!",
"type_info": "Int8"
},
{
"ordinal": 13,
"name": "user_session_created_at!",
"type_info": "Timestamptz"
},
{
"ordinal": 14,
"name": "user_id!",
"type_info": "Int8"
},
{
"ordinal": 15,
"name": "user_username!",
"type_info": "Text"
},
{
"ordinal": 16,
"name": "user_session_last_authentication_id?",
"type_info": "Int8"
},
{
"ordinal": 17,
"name": "user_session_last_authentication_created_at?",
"type_info": "Timestamptz"
} }
], ],
"parameters": { "parameters": {
@ -205,10 +350,23 @@
] ]
}, },
"nullable": [ "nullable": [
false,
false,
false,
false,
false,
false,
false,
false,
false,
false, false,
false, false,
true, true,
false, false,
false,
false,
false,
false,
false false
] ]
} }
@ -223,60 +381,106 @@
"nullable": [] "nullable": []
} }
}, },
"73f2d928f7bf88af79a3685bd6346652b4e4454b0ce75e38343840c9765e3f27": { "686a796a7de689b73a9377083718c95ac5ac51ce396dcf32e614402051d93e16": {
"query": "\n INSERT INTO oauth2_refresh_tokens\n (oauth2_session_id, oauth2_access_token_id, token)\n VALUES\n ($1, $2, $3)\n RETURNING\n id, oauth2_session_id, oauth2_access_token_id, token, next_token_id, \n created_at, updated_at\n ", "query": "\n SELECT\n at.id AS \"access_token_id\",\n at.token AS \"access_token\",\n at.expires_after AS \"access_token_expires_after\",\n at.created_at AS \"access_token_created_at\",\n os.id AS \"session_id!\",\n os.client_id AS \"client_id!\",\n os.scope AS \"scope!\",\n os.redirect_uri AS \"redirect_uri!\",\n os.nonce AS \"nonce\",\n us.id AS \"user_session_id!\",\n us.created_at AS \"user_session_created_at!\",\n u.id AS \"user_id!\",\n u.username AS \"user_username!\",\n usa.id AS \"user_session_last_authentication_id?\",\n usa.created_at AS \"user_session_last_authentication_created_at?\"\n\n FROM oauth2_access_tokens at\n INNER JOIN oauth2_sessions os\n ON os.id = at.oauth2_session_id\n INNER JOIN user_sessions us\n ON us.id = os.user_session_id\n INNER JOIN users u\n ON u.id = us.user_id\n LEFT JOIN user_session_authentications usa\n ON usa.session_id = us.id\n\n WHERE at.token = $1\n AND at.created_at + (at.expires_after * INTERVAL '1 second') >= now()\n AND us.active\n\n ORDER BY usa.created_at DESC\n LIMIT 1\n ",
"describe": { "describe": {
"columns": [ "columns": [
{ {
"ordinal": 0, "ordinal": 0,
"name": "id", "name": "access_token_id",
"type_info": "Int8" "type_info": "Int8"
}, },
{ {
"ordinal": 1, "ordinal": 1,
"name": "oauth2_session_id", "name": "access_token",
"type_info": "Int8"
},
{
"ordinal": 2,
"name": "oauth2_access_token_id",
"type_info": "Int8"
},
{
"ordinal": 3,
"name": "token",
"type_info": "Text" "type_info": "Text"
}, },
{
"ordinal": 2,
"name": "access_token_expires_after",
"type_info": "Int4"
},
{
"ordinal": 3,
"name": "access_token_created_at",
"type_info": "Timestamptz"
},
{ {
"ordinal": 4, "ordinal": 4,
"name": "next_token_id", "name": "session_id!",
"type_info": "Int8" "type_info": "Int8"
}, },
{ {
"ordinal": 5, "ordinal": 5,
"name": "created_at", "name": "client_id!",
"type_info": "Timestamptz" "type_info": "Text"
}, },
{ {
"ordinal": 6, "ordinal": 6,
"name": "updated_at", "name": "scope!",
"type_info": "Text"
},
{
"ordinal": 7,
"name": "redirect_uri!",
"type_info": "Text"
},
{
"ordinal": 8,
"name": "nonce",
"type_info": "Text"
},
{
"ordinal": 9,
"name": "user_session_id!",
"type_info": "Int8"
},
{
"ordinal": 10,
"name": "user_session_created_at!",
"type_info": "Timestamptz"
},
{
"ordinal": 11,
"name": "user_id!",
"type_info": "Int8"
},
{
"ordinal": 12,
"name": "user_username!",
"type_info": "Text"
},
{
"ordinal": 13,
"name": "user_session_last_authentication_id?",
"type_info": "Int8"
},
{
"ordinal": 14,
"name": "user_session_last_authentication_created_at?",
"type_info": "Timestamptz" "type_info": "Timestamptz"
} }
], ],
"parameters": { "parameters": {
"Left": [ "Left": [
"Int8",
"Int8",
"Text" "Text"
] ]
}, },
"nullable": [ "nullable": [
false, false,
false, false,
true, false,
false,
false,
false,
false,
false, false,
true, true,
false, false,
false,
false,
false,
false,
false false
] ]
} }
@ -388,52 +592,6 @@
"nullable": [] "nullable": []
} }
}, },
"b766b2b41d8770b5bef9928bb3b96abbaf8466b473e12b21f145c015b7cf2f05": {
"query": "\n INSERT INTO oauth2_access_tokens\n (oauth2_session_id, token, expires_after)\n VALUES\n ($1, $2, $3)\n RETURNING\n id, oauth2_session_id, token, expires_after, created_at\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Int8"
},
{
"ordinal": 1,
"name": "oauth2_session_id",
"type_info": "Int8"
},
{
"ordinal": 2,
"name": "token",
"type_info": "Text"
},
{
"ordinal": 3,
"name": "expires_after",
"type_info": "Int4"
},
{
"ordinal": 4,
"name": "created_at",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Int8",
"Text",
"Int4"
]
},
"nullable": [
false,
false,
false,
false,
false
]
}
},
"c2c402cfe0adcafa615f14a499caba4c96ca71d9ffb163e1feb05e5d85f3462c": { "c2c402cfe0adcafa615f14a499caba4c96ca71d9ffb163e1feb05e5d85f3462c": {
"query": "\n UPDATE oauth2_refresh_tokens\n SET next_token_id = $2\n WHERE id = $1\n ", "query": "\n UPDATE oauth2_refresh_tokens\n SET next_token_id = $2\n WHERE id = $1\n ",
"describe": { "describe": {
@ -579,68 +737,6 @@
"nullable": [] "nullable": []
} }
}, },
"eb5f772a7387de0dc2f9f660f470476c075da097134a8ded226eb630545c16eb": {
"query": "\n SELECT\n oc.id,\n oc.code_challenge,\n oc.code_challenge_method,\n os.id AS \"oauth2_session_id!\",\n os.client_id AS \"client_id!\",\n os.redirect_uri,\n os.scope AS \"scope!\",\n os.nonce\n FROM oauth2_codes oc\n INNER JOIN oauth2_sessions os\n ON os.id = oc.oauth2_session_id\n WHERE oc.code = $1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Int8"
},
{
"ordinal": 1,
"name": "code_challenge",
"type_info": "Text"
},
{
"ordinal": 2,
"name": "code_challenge_method",
"type_info": "Int2"
},
{
"ordinal": 3,
"name": "oauth2_session_id!",
"type_info": "Int8"
},
{
"ordinal": 4,
"name": "client_id!",
"type_info": "Text"
},
{
"ordinal": 5,
"name": "redirect_uri",
"type_info": "Text"
},
{
"ordinal": 6,
"name": "scope!",
"type_info": "Text"
},
{
"ordinal": 7,
"name": "nonce",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"Text"
]
},
"nullable": [
false,
true,
true,
false,
false,
false,
false,
true
]
}
},
"f9a09ff53b6f221649f4f050e3d5ade114f852ddf50a78610a6c0ef0689af681": { "f9a09ff53b6f221649f4f050e3d5ade114f852ddf50a78610a6c0ef0689af681": {
"query": "\n INSERT INTO users (username, hashed_password)\n VALUES ($1, $2)\n RETURNING id\n ", "query": "\n INSERT INTO users (username, hashed_password)\n VALUES ($1, $2)\n RETURNING id\n ",
"describe": { "describe": {

View File

@ -14,9 +14,9 @@
//! Authenticate an endpoint with an access token as bearer authorization token //! Authenticate an endpoint with an access token as bearer authorization token
use chrono::Utc;
use headers::{authorization::Bearer, Authorization}; use headers::{authorization::Bearer, Authorization};
use hyper::StatusCode; use hyper::StatusCode;
use mas_data_model::{AccessToken, Session};
use sqlx::{pool::PoolConnection, PgPool, Postgres}; use sqlx::{pool::PoolConnection, PgPool, Postgres};
use thiserror::Error; use thiserror::Error;
use warp::{ use warp::{
@ -31,8 +31,9 @@ use super::{
}; };
use crate::{ use crate::{
errors::wrapped_error, errors::wrapped_error,
storage::oauth2::access_token::{ storage::{
lookup_access_token, AccessTokenLookupError, OAuth2AccessTokenLookup, oauth2::access_token::{lookup_active_access_token, AccessTokenLookupError},
PostgresqlBackend,
}, },
tokens::{TokenFormatError, TokenType}, tokens::{TokenFormatError, TokenType},
}; };
@ -82,19 +83,25 @@ impl Reject for AuthenticationError {}
#[must_use] #[must_use]
pub fn authentication( pub fn authentication(
pool: &PgPool, pool: &PgPool,
) -> impl Filter<Extract = (OAuth2AccessTokenLookup,), Error = Rejection> + Clone + Send + Sync + 'static ) -> impl Filter<
{ Extract = (AccessToken<PostgresqlBackend>, Session<PostgresqlBackend>),
Error = Rejection,
> + Clone
+ Send
+ Sync
+ 'static {
connection(pool) connection(pool)
.and(typed_header()) .and(typed_header())
.and_then(authenticate) .and_then(authenticate)
.recover(recover) .recover(recover)
.unify() .unify()
.untuple_one()
} }
async fn authenticate( async fn authenticate(
mut conn: PoolConnection<Postgres>, mut conn: PoolConnection<Postgres>,
auth: Authorization<Bearer>, auth: Authorization<Bearer>,
) -> Result<OAuth2AccessTokenLookup, Rejection> { ) -> Result<(AccessToken<PostgresqlBackend>, Session<PostgresqlBackend>), Rejection> {
let token = auth.0.token(); let token = auth.0.token();
let token_type = TokenType::check(token).map_err(AuthenticationError::TokenFormat)?; let token_type = TokenType::check(token).map_err(AuthenticationError::TokenFormat)?;
@ -102,29 +109,25 @@ async fn authenticate(
return Err(AuthenticationError::WrongTokenType(token_type).into()); return Err(AuthenticationError::WrongTokenType(token_type).into());
} }
let token = lookup_access_token(&mut conn, token).await.map_err(|e| { let (token, session) = lookup_active_access_token(&mut conn, token)
if e.not_found() { .await
// This error happens if the token was not found and should be recovered .map_err(|e| {
warp::reject::custom(AuthenticationError::TokenNotFound(e)) if e.not_found() {
} else { // This error happens if the token was not found and should be recovered
// This is a generic database error that we want to propagate warp::reject::custom(AuthenticationError::TokenNotFound(e))
warp::reject::custom(wrapped_error(e)) } else {
} // This is a generic database error that we want to propagate
})?; warp::reject::custom(wrapped_error(e))
}
})?;
if !token.active { Ok((token, session))
return Err(AuthenticationError::TokenInactive.into());
}
if token.exp() < Utc::now() {
return Err(AuthenticationError::TokenExpired.into());
}
Ok(token)
} }
/// Transform the rejections from the [`with_typed_header`] filter /// Transform the rejections from the [`with_typed_header`] filter
async fn recover(rejection: Rejection) -> Result<OAuth2AccessTokenLookup, Rejection> { async fn recover(
rejection: Rejection,
) -> Result<(AccessToken<PostgresqlBackend>, Session<PostgresqlBackend>), Rejection> {
if rejection.find::<MissingHeader>().is_some() { if rejection.find::<MissingHeader>().is_some() {
return Err(warp::reject::custom( return Err(warp::reject::custom(
AuthenticationError::MissingAuthorizationHeader, AuthenticationError::MissingAuthorizationHeader,

View File

@ -395,7 +395,7 @@ impl StepRequest {
async fn step( async fn step(
oauth2_session_id: i64, oauth2_session_id: i64,
user_session: BrowserSession<PostgresqlBackend>, browser_session: BrowserSession<PostgresqlBackend>,
mut txn: Transaction<'_, Postgres>, mut txn: Transaction<'_, Postgres>,
) -> Result<ReplyOrBackToClient, Rejection> { ) -> Result<ReplyOrBackToClient, Rejection> {
let mut oauth2_session = get_session_by_id(&mut txn, oauth2_session_id) let mut oauth2_session = get_session_by_id(&mut txn, oauth2_session_id)
@ -403,7 +403,7 @@ async fn step(
.wrap_error()?; .wrap_error()?;
let user_session = oauth2_session let user_session = oauth2_session
.match_or_set_session(&mut txn, user_session) .match_or_set_session(&mut txn, browser_session)
.await .await
.wrap_error()?; .wrap_error()?;
@ -427,7 +427,7 @@ async fn step(
// Did they request an access token? // Did they request an access token?
if response_type.contains(&ResponseType::Token) { if response_type.contains(&ResponseType::Token) {
let ttl = Duration::minutes(5); let ttl = Duration::minutes(5);
let (access_token, refresh_token) = { let (access_token_str, refresh_token_str) = {
let mut rng = thread_rng(); let mut rng = thread_rng();
( (
AccessToken.generate(&mut rng), AccessToken.generate(&mut rng),
@ -435,19 +435,24 @@ async fn step(
) )
}; };
let access_token = add_access_token(&mut txn, oauth2_session_id, &access_token, ttl) let access_token =
.await add_access_token(&mut txn, oauth2_session_id, &access_token_str, ttl)
.wrap_error()?;
let refresh_token =
add_refresh_token(&mut txn, oauth2_session_id, access_token.id, &refresh_token)
.await .await
.wrap_error()?; .wrap_error()?;
let _refresh_token = add_refresh_token(
&mut txn,
oauth2_session_id,
access_token,
&refresh_token_str,
)
.await
.wrap_error()?;
params.response = Some( params.response = Some(
AccessTokenResponse::new(access_token.token) AccessTokenResponse::new(access_token_str)
.with_expires_in(ttl) .with_expires_in(ttl)
.with_refresh_token(refresh_token.token), .with_refresh_token(refresh_token_str),
); );
} }

View File

@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use chrono::Utc;
use hyper::Method; use hyper::Method;
use oauth2_types::requests::{IntrospectionRequest, IntrospectionResponse, TokenTypeHint}; use oauth2_types::requests::{IntrospectionRequest, IntrospectionResponse, TokenTypeHint};
use sqlx::{pool::PoolConnection, PgPool, Postgres}; use sqlx::{pool::PoolConnection, PgPool, Postgres};
@ -27,7 +26,9 @@ use crate::{
cors::cors, cors::cors,
database::connection, database::connection,
}, },
storage::oauth2::{access_token::lookup_access_token, refresh_token::lookup_refresh_token}, storage::oauth2::{
access_token::lookup_active_access_token, refresh_token::lookup_active_refresh_token,
},
tokens::{self, TokenType}, tokens::{self, TokenType},
}; };
@ -84,43 +85,41 @@ async fn introspect(
let reply = match token_type { let reply = match token_type {
tokens::TokenType::AccessToken => { tokens::TokenType::AccessToken => {
let token = lookup_access_token(&mut conn, token).await.wrap_error()?; let (token, session) = lookup_active_access_token(&mut conn, token)
.await
.wrap_error()?;
let exp = token.exp(); let exp = token.exp();
// Check it is active and did not expire
if !token.active || exp < Utc::now() {
info!(?token, "Access token expired");
return Ok(warp::reply::json(&INACTIVE));
}
IntrospectionResponse { IntrospectionResponse {
active: true, active: true,
scope: None, // TODO: parse back scopes scope: Some(session.scope),
client_id: Some(token.client_id.clone()), client_id: Some(session.client.client_id),
username: Some(token.username.clone()), username: session.browser_session.clone().map(|s| s.user.username),
token_type: Some(TokenTypeHint::AccessToken), token_type: Some(TokenTypeHint::AccessToken),
exp: Some(exp), exp: Some(exp),
iat: Some(token.created_at), iat: Some(token.created_at),
nbf: Some(token.created_at), nbf: Some(token.created_at),
sub: None, sub: session.browser_session.map(|s| s.user.sub),
aud: None, aud: None,
iss: None, iss: None,
jti: None, jti: None,
} }
} }
tokens::TokenType::RefreshToken => { tokens::TokenType::RefreshToken => {
let token = lookup_refresh_token(&mut conn, token).await.wrap_error()?; let (token, session) = lookup_active_refresh_token(&mut conn, token)
.await
.wrap_error()?;
IntrospectionResponse { IntrospectionResponse {
active: true, active: true,
scope: None, // TODO: parse back scopes scope: Some(session.scope),
client_id: Some(token.client_id), client_id: Some(session.client.client_id),
username: None, username: session.browser_session.clone().map(|s| s.user.username),
token_type: Some(TokenTypeHint::RefreshToken), token_type: Some(TokenTypeHint::RefreshToken),
exp: None, exp: None,
iat: None, iat: Some(token.created_at),
nbf: None, nbf: Some(token.created_at),
sub: None, sub: session.browser_session.map(|s| s.user.sub),
aud: None, aud: None,
iss: None, iss: None,
jti: None, jti: None,

View File

@ -19,13 +19,11 @@ use headers::{CacheControl, Pragma};
use hyper::{Method, StatusCode}; use hyper::{Method, StatusCode};
use jwt_compact::{Claims, Header, TimeOptions}; use jwt_compact::{Claims, Header, TimeOptions};
use oauth2_types::{ use oauth2_types::{
errors::{ errors::{InvalidGrant, InvalidRequest, OAuth2Error, OAuth2ErrorCode, UnauthorizedClient},
InvalidGrant, InvalidRequest, OAuth2Error, OAuth2ErrorCode, ServerError, UnauthorizedClient,
},
pkce::CodeChallengeMethod,
requests::{ requests::{
AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, RefreshTokenGrant, AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, RefreshTokenGrant,
}, },
scope::OPENID,
}; };
use rand::thread_rng; use rand::thread_rng;
use serde::Serialize; use serde::Serialize;
@ -52,7 +50,7 @@ use crate::{
storage::oauth2::{ storage::oauth2::{
access_token::{add_access_token, revoke_access_token}, access_token::{add_access_token, revoke_access_token},
authorization_code::{consume_code, lookup_code}, authorization_code::{consume_code, lookup_code},
refresh_token::{add_refresh_token, lookup_refresh_token, replace_refresh_token}, refresh_token::{add_refresh_token, lookup_active_refresh_token, replace_refresh_token},
}, },
tokens::{AccessToken, RefreshToken}, tokens::{AccessToken, RefreshToken},
}; };
@ -163,46 +161,37 @@ async fn authorization_code_grant(
// TODO: we should invalidate the existing session if a code is used twice after // TODO: we should invalidate the existing session if a code is used twice after
// some period of time. See the `oidcc-codereuse-30seconds` test from the // some period of time. See the `oidcc-codereuse-30seconds` test from the
// conformance suite // conformance suite
let code = match lookup_code(&mut txn, &grant.code).await { let (code, session) = match lookup_code(&mut txn, &grant.code).await {
Err(e) if e.not_found() => return error(InvalidGrant), Err(e) if e.not_found() => return error(InvalidGrant),
x => x, x => x,
}?; }?;
if client.client_id != code.client_id { if client.client_id != session.client.client_id {
return error(UnauthorizedClient); return error(UnauthorizedClient);
} }
match ( match (code.pkce.as_ref(), grant.code_verifier.as_ref()) {
code.code_challenge_method.as_ref(), (None, None) => {}
code.code_challenge.as_ref(),
grant.code_verifier.as_ref(),
) {
(None, None, None) => {}
// We have a challenge but no verifier (or vice-versa)? Bad request. // We have a challenge but no verifier (or vice-versa)? Bad request.
(Some(_), Some(_), None) | (None, None, Some(_)) => return error(InvalidRequest), (Some(_), None) | (None, Some(_)) => return error(InvalidRequest),
(Some(0 /* Plain */), Some(code_challenge), Some(code_verifier)) => { // If we have both, we need to check the code validity
if !CodeChallengeMethod::Plain.verify(code_challenge, code_verifier) { (Some(pkce), Some(verifier)) => {
if !pkce.verify(verifier) {
return error(InvalidRequest); return error(InvalidRequest);
} }
} }
(Some(1 /* S256 */), Some(code_challenge), Some(code_verifier)) => {
if !CodeChallengeMethod::S256.verify(code_challenge, code_verifier) {
return error(InvalidRequest);
}
}
// We have something else?
// That's a DB inconcistancy, we should bail out
_ => {
// TODO: are we sure we want to handle errors like that?
tracing::error!("Invalid state from the database");
return error(ServerError); // Somthing bad happened in the database
}
}; };
// TODO: verify PKCE // TODO: this should probably not happen?
let browser_session = session
.browser_session
.ok_or_else(|| {
anyhow::anyhow!("this oauth2 session has no database session attached to it")
})
.wrap_error()?;
let ttl = Duration::minutes(5); let ttl = Duration::minutes(5);
let (access_token, refresh_token) = { let (access_token_str, refresh_token_str) = {
let mut rng = thread_rng(); let mut rng = thread_rng();
( (
AccessToken.generate(&mut rng), AccessToken.generate(&mut rng),
@ -210,45 +199,48 @@ async fn authorization_code_grant(
) )
}; };
let access_token = add_access_token(&mut txn, code.oauth2_session_id, &access_token, ttl) let access_token = add_access_token(&mut txn, session.data, &access_token_str, ttl)
.await .await
.wrap_error()?; .wrap_error()?;
let refresh_token = add_refresh_token( let _refresh_token =
&mut txn, add_refresh_token(&mut txn, session.data, access_token, &refresh_token_str)
code.oauth2_session_id, .await
access_token.id, .wrap_error()?;
&refresh_token,
)
.await
.wrap_error()?;
// TODO: generate id_token only if the "openid" scope was asked let id_token = if session.scope.contains(&OPENID) {
let header = Header::default(); let header = Header::default();
let options = TimeOptions::default(); let options = TimeOptions::default();
let claims = Claims::new(CustomClaims { let claims = Claims::new(CustomClaims {
issuer, issuer,
// TODO: get that from the session subject: browser_session.user.sub,
subject: "random-subject".to_string(), audiences: vec![client.client_id.clone()],
audiences: vec![client.client_id.clone()], nonce: session.nonce,
nonce: code.nonce, at_hash: hash(Sha256::new(), &access_token_str).wrap_error()?,
at_hash: hash(Sha256::new(), &access_token.token).wrap_error()?, c_hash: hash(Sha256::new(), &grant.code).wrap_error()?,
c_hash: hash(Sha256::new(), &grant.code).wrap_error()?, })
}) .set_duration_and_issuance(&options, Duration::minutes(30));
.set_duration_and_issuance(&options, Duration::minutes(30)); let id_token = keys
let id_token = keys .token(crate::config::Algorithm::Rs256, header, claims)
.token(crate::config::Algorithm::Rs256, header, claims) .await
.await .context("could not sign ID token")
.context("could not sign ID token") .wrap_error()?;
.wrap_error()?;
// TODO: have the scopes back here Some(id_token)
let params = AccessTokenResponse::new(access_token.token) } else {
None
};
let mut params = AccessTokenResponse::new(access_token_str)
.with_expires_in(ttl) .with_expires_in(ttl)
.with_refresh_token(refresh_token.token) .with_refresh_token(refresh_token_str)
.with_id_token(id_token); .with_scope(session.scope);
consume_code(&mut txn, code.id).await.wrap_error()?; if let Some(id_token) = id_token {
params = params.with_id_token(id_token);
}
consume_code(&mut txn, code).await.wrap_error()?;
txn.commit().await.wrap_error()?; txn.commit().await.wrap_error()?;
@ -261,18 +253,17 @@ async fn refresh_token_grant(
conn: &mut PoolConnection<Postgres>, conn: &mut PoolConnection<Postgres>,
) -> Result<AccessTokenResponse, Rejection> { ) -> Result<AccessTokenResponse, Rejection> {
let mut txn = conn.begin().await.wrap_error()?; let mut txn = conn.begin().await.wrap_error()?;
// TODO: scope handling let (refresh_token, session) = lookup_active_refresh_token(&mut txn, &grant.refresh_token)
let refresh_token_lookup = lookup_refresh_token(&mut txn, &grant.refresh_token)
.await .await
.wrap_error()?; .wrap_error()?;
if client.client_id != refresh_token_lookup.client_id { if client.client_id != session.client.client_id {
// As per https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 // As per https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
return error(InvalidGrant); return error(InvalidGrant);
} }
let ttl = Duration::minutes(5); let ttl = Duration::minutes(5);
let (access_token, refresh_token) = { let (access_token_str, refresh_token_str) = {
let mut rng = thread_rng(); let mut rng = thread_rng();
( (
AccessToken.generate(&mut rng), AccessToken.generate(&mut rng),
@ -280,37 +271,29 @@ async fn refresh_token_grant(
) )
}; };
let access_token = add_access_token( let new_access_token = add_access_token(&mut txn, session.data, &access_token_str, ttl)
&mut txn,
refresh_token_lookup.oauth2_session_id,
&access_token,
ttl,
)
.await
.wrap_error()?;
let refresh_token = add_refresh_token(
&mut txn,
refresh_token_lookup.oauth2_session_id,
access_token.id,
&refresh_token,
)
.await
.wrap_error()?;
replace_refresh_token(&mut txn, refresh_token_lookup.id, refresh_token.id)
.await .await
.wrap_error()?; .wrap_error()?;
if let Some(access_token_id) = refresh_token_lookup.oauth2_access_token_id { let new_refresh_token =
revoke_access_token(&mut txn, access_token_id) add_refresh_token(&mut txn, session.data, new_access_token, &refresh_token_str)
.await
.wrap_error()?;
replace_refresh_token(&mut txn, &refresh_token, &new_refresh_token)
.await
.wrap_error()?;
if let Some(access_token) = refresh_token.access_token {
revoke_access_token(&mut txn, access_token.data)
.await .await
.wrap_error()?; .wrap_error()?;
} }
let params = AccessTokenResponse::new(access_token.token) let params = AccessTokenResponse::new(access_token_str)
.with_expires_in(ttl) .with_expires_in(ttl)
.with_refresh_token(refresh_token.token); .with_refresh_token(refresh_token_str)
.with_scope(session.scope);
txn.commit().await.wrap_error()?; txn.commit().await.wrap_error()?;

View File

@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
use hyper::Method; use hyper::Method;
use mas_data_model::{AccessToken, Session};
use serde::Serialize; use serde::Serialize;
use sqlx::PgPool; use sqlx::PgPool;
use warp::{Filter, Rejection, Reply}; use warp::{Filter, Rejection, Reply};
@ -23,12 +24,13 @@ use crate::{
authenticate::{authentication, recover_unauthorized}, authenticate::{authentication, recover_unauthorized},
cors::cors, cors::cors,
}, },
storage::oauth2::access_token::OAuth2AccessTokenLookup, storage::PostgresqlBackend,
}; };
#[derive(Serialize)] #[derive(Serialize)]
struct UserInfo { struct UserInfo {
sub: String, sub: String,
username: String,
} }
pub(super) fn filter( pub(super) fn filter(
@ -46,8 +48,14 @@ pub(super) fn filter(
) )
} }
async fn userinfo(token: OAuth2AccessTokenLookup) -> Result<impl Reply, Rejection> { async fn userinfo(
_token: AccessToken<PostgresqlBackend>,
session: Session<PostgresqlBackend>,
) -> Result<impl Reply, Rejection> {
// TODO: we really should not have an Option here
let user = session.browser_session.unwrap().user;
Ok(warp::reply::json(&UserInfo { Ok(warp::reply::json(&UserInfo {
sub: token.username, sub: user.sub,
username: user.username,
})) }))
} }

View File

@ -16,13 +16,14 @@
#![allow(clippy::used_underscore_binding)] // This is needed by sqlx macros #![allow(clippy::used_underscore_binding)] // This is needed by sqlx macros
use chrono::{DateTime, Utc};
use mas_data_model::{StorageBackend, StorageBackendMarker}; use mas_data_model::{StorageBackend, StorageBackendMarker};
use serde::Serialize; use serde::Serialize;
use sqlx::migrate::Migrator; use sqlx::migrate::Migrator;
use thiserror::Error; use thiserror::Error;
#[derive(Debug, Error)] #[derive(Debug, Error)]
#[error("databse query returned an inconsistent state")] #[error("database query returned an inconsistent state")]
pub struct DatabaseInconsistencyError; pub struct DatabaseInconsistencyError;
#[derive(Serialize, Debug, Clone, PartialEq)] #[derive(Serialize, Debug, Clone, PartialEq)]
@ -34,12 +35,18 @@ impl StorageBackend for PostgresqlBackend {
type AuthorizationCodeData = i64; type AuthorizationCodeData = i64;
type BrowserSessionData = i64; type BrowserSessionData = i64;
type ClientData = (); type ClientData = ();
type RefreshTokenData = i64;
type SessionData = i64; type SessionData = i64;
type UserData = i64; type UserData = i64;
} }
impl StorageBackendMarker for PostgresqlBackend {} impl StorageBackendMarker for PostgresqlBackend {}
struct IdAndCreationTime {
id: i64,
created_at: DateTime<Utc>,
}
pub mod oauth2; pub mod oauth2;
pub mod user; pub mod user;

View File

@ -16,89 +16,108 @@ use std::convert::TryFrom;
use anyhow::Context; use anyhow::Context;
use chrono::{DateTime, Duration, Utc}; use chrono::{DateTime, Duration, Utc};
use serde::Serialize; use mas_data_model::{AccessToken, Authentication, BrowserSession, Client, Session, User};
use sqlx::{Executor, FromRow, Postgres}; use sqlx::PgExecutor;
use thiserror::Error; use thiserror::Error;
#[derive(FromRow, Serialize)] use crate::storage::{DatabaseInconsistencyError, IdAndCreationTime, PostgresqlBackend};
pub struct OAuth2AccessToken {
pub id: i64,
pub oauth2_session_id: i64,
pub token: String,
expires_after: i32,
created_at: DateTime<Utc>,
}
pub async fn add_access_token( pub async fn add_access_token(
executor: impl Executor<'_, Database = Postgres>, executor: impl PgExecutor<'_>,
oauth2_session_id: i64, oauth2_session_id: i64,
token: &str, token: &str,
expires_after: Duration, expires_after: Duration,
) -> anyhow::Result<OAuth2AccessToken> { ) -> anyhow::Result<AccessToken<PostgresqlBackend>> {
// Checked convertion of duration to i32, maxing at i32::MAX // Checked convertion of duration to i32, maxing at i32::MAX
let expires_after = i32::try_from(expires_after.num_seconds()).unwrap_or(i32::MAX); let expires_after_seconds = i32::try_from(expires_after.num_seconds()).unwrap_or(i32::MAX);
sqlx::query_as!( let res = sqlx::query_as!(
OAuth2AccessToken, IdAndCreationTime,
r#" r#"
INSERT INTO oauth2_access_tokens INSERT INTO oauth2_access_tokens
(oauth2_session_id, token, expires_after) (oauth2_session_id, token, expires_after)
VALUES VALUES
($1, $2, $3) ($1, $2, $3)
RETURNING RETURNING
id, oauth2_session_id, token, expires_after, created_at id, created_at
"#, "#,
oauth2_session_id, oauth2_session_id,
token, token,
expires_after, expires_after_seconds,
) )
.fetch_one(executor) .fetch_one(executor)
.await .await
.context("could not insert oauth2 access token") .context("could not insert oauth2 access token")?;
Ok(AccessToken {
data: res.id,
expires_after,
token: token.to_string(),
jti: format!("{}", res.id),
created_at: res.created_at,
})
} }
#[derive(Debug)] #[derive(Debug)]
pub struct OAuth2AccessTokenLookup { pub struct OAuth2AccessTokenLookup {
pub active: bool, access_token_id: i64,
pub username: String, access_token: String,
pub client_id: String, access_token_expires_after: i32,
pub scope: String, access_token_created_at: DateTime<Utc>,
pub created_at: DateTime<Utc>, session_id: i64,
expires_after: i32, client_id: String,
} scope: String,
redirect_uri: String,
impl OAuth2AccessTokenLookup { nonce: Option<String>,
#[must_use] user_session_id: i64,
pub fn exp(&self) -> DateTime<Utc> { user_session_created_at: DateTime<Utc>,
self.created_at + Duration::seconds(i64::from(self.expires_after)) user_id: i64,
} user_username: String,
user_session_last_authentication_id: Option<i64>,
user_session_last_authentication_created_at: Option<DateTime<Utc>>,
} }
#[derive(Debug, Error)] #[derive(Debug, Error)]
#[error("failed to lookup access token")] #[error("failed to lookup access token")]
pub struct AccessTokenLookupError(#[from] sqlx::Error); pub enum AccessTokenLookupError {
Database(#[from] sqlx::Error),
Inconsistency(#[from] DatabaseInconsistencyError),
}
impl AccessTokenLookupError { impl AccessTokenLookupError {
#[must_use] #[must_use]
pub fn not_found(&self) -> bool { pub fn not_found(&self) -> bool {
matches!(self.0, sqlx::Error::RowNotFound) matches!(
self,
&AccessTokenLookupError::Database(sqlx::Error::RowNotFound)
)
} }
} }
pub async fn lookup_access_token( pub async fn lookup_active_access_token(
executor: impl Executor<'_, Database = Postgres>, executor: impl PgExecutor<'_>,
token: &str, token: &str,
) -> Result<OAuth2AccessTokenLookup, AccessTokenLookupError> { ) -> Result<(AccessToken<PostgresqlBackend>, Session<PostgresqlBackend>), AccessTokenLookupError> {
let res = sqlx::query_as!( let res = sqlx::query_as!(
OAuth2AccessTokenLookup, OAuth2AccessTokenLookup,
r#" r#"
SELECT SELECT
u.username AS "username!", at.id AS "access_token_id",
us.active AS "active!", at.token AS "access_token",
os.client_id AS "client_id!", at.expires_after AS "access_token_expires_after",
os.scope AS "scope!", at.created_at AS "access_token_created_at",
at.created_at AS "created_at!", os.id AS "session_id!",
at.expires_after AS "expires_after!" os.client_id AS "client_id!",
os.scope AS "scope!",
os.redirect_uri AS "redirect_uri!",
os.nonce AS "nonce",
us.id AS "user_session_id!",
us.created_at AS "user_session_created_at!",
u.id AS "user_id!",
u.username AS "user_username!",
usa.id AS "user_session_last_authentication_id?",
usa.created_at AS "user_session_last_authentication_created_at?"
FROM oauth2_access_tokens at FROM oauth2_access_tokens at
INNER JOIN oauth2_sessions os INNER JOIN oauth2_sessions os
ON os.id = at.oauth2_session_id ON os.id = at.oauth2_session_id
@ -106,20 +125,79 @@ pub async fn lookup_access_token(
ON us.id = os.user_session_id ON us.id = os.user_session_id
INNER JOIN users u INNER JOIN users u
ON u.id = us.user_id ON u.id = us.user_id
LEFT JOIN user_session_authentications usa
ON usa.session_id = us.id
WHERE at.token = $1 WHERE at.token = $1
AND at.created_at + (at.expires_after * INTERVAL '1 second') >= now()
AND us.active
ORDER BY usa.created_at DESC
LIMIT 1
"#, "#,
token, token,
) )
.fetch_one(executor) .fetch_one(executor)
.await?; .await?;
Ok(res) let access_token = AccessToken {
data: res.access_token_id,
jti: format!("{}", res.access_token_id),
token: res.access_token,
created_at: res.access_token_created_at,
expires_after: Duration::seconds(res.access_token_expires_after.into()),
};
let client = Client {
data: (),
client_id: res.client_id,
};
let user = User {
data: res.user_id,
username: res.user_username,
sub: format!("fake-sub-{}", res.user_id),
};
let last_authentication = match (
res.user_session_last_authentication_id,
res.user_session_last_authentication_created_at,
) {
(None, None) => None,
(Some(id), Some(created_at)) => Some(Authentication {
data: id,
created_at,
}),
_ => return Err(DatabaseInconsistencyError.into()),
};
let browser_session = Some(BrowserSession {
data: res.user_session_id,
created_at: res.user_session_created_at,
user,
last_authentication,
});
let scope = res.scope.parse().map_err(|_e| DatabaseInconsistencyError)?;
let redirect_uri = res
.redirect_uri
.parse()
.map_err(|_e| DatabaseInconsistencyError)?;
let session = Session {
data: res.session_id,
client,
browser_session,
scope,
redirect_uri,
nonce: res.nonce,
};
Ok((access_token, session))
} }
pub async fn revoke_access_token( pub async fn revoke_access_token(executor: impl PgExecutor<'_>, id: i64) -> anyhow::Result<()> {
executor: impl Executor<'_, Database = Postgres>,
id: i64,
) -> anyhow::Result<()> {
let res = sqlx::query!( let res = sqlx::query!(
r#" r#"
DELETE FROM oauth2_access_tokens DELETE FROM oauth2_access_tokens
@ -138,9 +216,7 @@ pub async fn revoke_access_token(
} }
} }
pub async fn cleanup_expired( pub async fn cleanup_expired(executor: impl PgExecutor<'_>) -> anyhow::Result<u64> {
executor: impl Executor<'_, Database = Postgres>,
) -> anyhow::Result<u64> {
let res = sqlx::query!( let res = sqlx::query!(
r#" r#"
DELETE FROM oauth2_access_tokens DELETE FROM oauth2_access_tokens

View File

@ -13,40 +13,33 @@
// limitations under the License. // limitations under the License.
use anyhow::Context; use anyhow::Context;
use chrono::{DateTime, Utc};
use mas_data_model::{
Authentication, AuthorizationCode, BrowserSession, Client, Pkce, Session, User,
};
use oauth2_types::pkce; use oauth2_types::pkce;
use serde::Serialize; use sqlx::PgExecutor;
use sqlx::{Executor, FromRow, Postgres};
use thiserror::Error; use thiserror::Error;
use warp::reject::Reject; use warp::reject::Reject;
#[derive(FromRow, Serialize)] use crate::storage::{DatabaseInconsistencyError, PostgresqlBackend};
pub struct OAuth2Code {
id: i64,
oauth2_session_id: i64,
pub code: String,
code_challenge: Option<String>,
code_challenge_method: Option<i16>,
}
pub async fn add_code( pub async fn add_code(
executor: impl Executor<'_, Database = Postgres>, executor: impl PgExecutor<'_>,
oauth2_session_id: i64, oauth2_session_id: i64,
code: &str, code: &str,
code_challenge: &Option<pkce::AuthorizationRequest>, pkce: &Option<pkce::AuthorizationRequest>,
) -> anyhow::Result<OAuth2Code> { ) -> anyhow::Result<AuthorizationCode<PostgresqlBackend>> {
let code_challenge_method = code_challenge let code_challenge_method = pkce.as_ref().map(|c| c.code_challenge_method as i16);
.as_ref() let code_challenge = pkce.as_ref().map(|c| &c.code_challenge);
.map(|c| c.code_challenge_method as i16); let id = sqlx::query_scalar!(
let code_challenge = code_challenge.as_ref().map(|c| &c.code_challenge);
sqlx::query_as!(
OAuth2Code,
r#" r#"
INSERT INTO oauth2_codes INSERT INTO oauth2_codes
(oauth2_session_id, code, code_challenge_method, code_challenge) (oauth2_session_id, code, code_challenge_method, code_challenge)
VALUES VALUES
($1, $2, $3, $4) ($1, $2, $3, $4)
RETURNING RETURNING
id, oauth2_session_id, code, code_challenge_method, code_challenge id
"#, "#,
oauth2_session_id, oauth2_session_id,
code, code,
@ -55,38 +48,108 @@ pub async fn add_code(
) )
.fetch_one(executor) .fetch_one(executor)
.await .await
.context("could not insert oauth2 authorization code") .context("could not insert oauth2 authorization code")?;
let pkce = pkce
.as_ref()
.map(|c| Pkce::new(c.code_challenge_method, c.code_challenge.clone()));
Ok(AuthorizationCode {
data: id,
code: code.to_string(),
pkce,
})
} }
pub struct OAuth2CodeLookup { struct OAuth2CodeLookup {
pub id: i64, id: i64,
pub oauth2_session_id: i64, oauth2_session_id: i64,
pub client_id: String, client_id: String,
pub redirect_uri: String, redirect_uri: String,
pub scope: String, scope: String,
pub nonce: Option<String>, nonce: Option<String>,
pub code_challenge: Option<String>, code_challenge: Option<String>,
pub code_challenge_method: Option<i16>, code_challenge_method: Option<i16>,
user_session_id: Option<i64>,
user_session_created_at: Option<DateTime<Utc>>,
user_id: Option<i64>,
user_username: Option<String>,
user_session_last_authentication_id: Option<i64>,
user_session_last_authentication_created_at: Option<DateTime<Utc>>,
}
fn browser_session_from_database(
user_session_id: Option<i64>,
user_session_created_at: Option<DateTime<Utc>>,
user_id: Option<i64>,
user_username: Option<String>,
user_session_last_authentication_id: Option<i64>,
user_session_last_authentication_created_at: Option<DateTime<Utc>>,
) -> Result<Option<BrowserSession<PostgresqlBackend>>, DatabaseInconsistencyError> {
match (
user_session_id,
user_session_created_at,
user_id,
user_username,
) {
(None, None, None, None) => Ok(None),
(Some(session_id), Some(session_created_at), Some(user_id), Some(user_username)) => {
let user = User {
data: user_id,
username: user_username,
sub: format!("fake-sub-{}", user_id),
};
let last_authentication = match (
user_session_last_authentication_id,
user_session_last_authentication_created_at,
) {
(None, None) => None,
(Some(id), Some(created_at)) => Some(Authentication {
data: id,
created_at,
}),
_ => return Err(DatabaseInconsistencyError),
};
Ok(Some(BrowserSession {
data: session_id,
created_at: session_created_at,
user,
last_authentication,
}))
}
_ => Err(DatabaseInconsistencyError),
}
} }
#[derive(Debug, Error)] #[derive(Debug, Error)]
#[error("failed to lookup oauth2 code")] #[error("failed to lookup oauth2 code")]
pub struct CodeLookupError(#[from] sqlx::Error); pub enum CodeLookupError {
Database(#[from] sqlx::Error),
Inconsistency(#[from] DatabaseInconsistencyError),
}
impl Reject for CodeLookupError {} impl Reject for CodeLookupError {}
impl CodeLookupError { impl CodeLookupError {
#[must_use] #[must_use]
pub fn not_found(&self) -> bool { pub fn not_found(&self) -> bool {
matches!(self.0, sqlx::Error::RowNotFound) matches!(self, &CodeLookupError::Database(sqlx::Error::RowNotFound))
} }
} }
#[allow(clippy::too_many_lines)]
pub async fn lookup_code( pub async fn lookup_code(
executor: impl Executor<'_, Database = Postgres>, executor: impl PgExecutor<'_>,
code: &str, code: &str,
) -> Result<OAuth2CodeLookup, CodeLookupError> { ) -> Result<
// TODO: this should return a better type (
AuthorizationCode<PostgresqlBackend>,
Session<PostgresqlBackend>,
),
CodeLookupError,
> {
let res = sqlx::query_as!( let res = sqlx::query_as!(
OAuth2CodeLookup, OAuth2CodeLookup,
r#" r#"
@ -94,27 +157,88 @@ pub async fn lookup_code(
oc.id, oc.id,
oc.code_challenge, oc.code_challenge,
oc.code_challenge_method, oc.code_challenge_method,
os.id AS "oauth2_session_id!", os.id AS "oauth2_session_id!",
os.client_id AS "client_id!", os.client_id AS "client_id!",
os.redirect_uri, os.redirect_uri,
os.scope AS "scope!", os.scope AS "scope!",
os.nonce os.nonce,
us.id AS "user_session_id?",
us.created_at AS "user_session_created_at?",
u.id AS "user_id?",
u.username AS "user_username?",
usa.id AS "user_session_last_authentication_id?",
usa.created_at AS "user_session_last_authentication_created_at?"
FROM oauth2_codes oc FROM oauth2_codes oc
INNER JOIN oauth2_sessions os INNER JOIN oauth2_sessions os
ON os.id = oc.oauth2_session_id ON os.id = oc.oauth2_session_id
LEFT JOIN user_sessions us
ON us.id = os.user_session_id
LEFT JOIN user_session_authentications usa
ON usa.session_id = us.id
LEFT JOIN users u
ON u.id = us.user_id
WHERE oc.code = $1 WHERE oc.code = $1
ORDER BY usa.created_at DESC
LIMIT 1
"#, "#,
code, code,
) )
.fetch_one(executor) .fetch_one(executor)
.await?; .await?;
Ok(res) let pkce = match (res.code_challenge_method, res.code_challenge) {
(None, None) => None,
(Some(0 /* Plain */), Some(challenge)) => {
Some(Pkce::new(pkce::CodeChallengeMethod::Plain, challenge))
}
(Some(1 /* S256 */), Some(challenge)) => {
Some(Pkce::new(pkce::CodeChallengeMethod::S256, challenge))
}
_ => return Err(DatabaseInconsistencyError.into()),
};
let code = AuthorizationCode {
data: res.id,
code: code.to_string(),
pkce,
};
let client = Client {
data: (),
client_id: res.client_id,
};
let browser_session = browser_session_from_database(
res.user_session_id,
res.user_session_created_at,
res.user_id,
res.user_username,
res.user_session_last_authentication_id,
res.user_session_last_authentication_created_at,
)?;
let scope = res.scope.parse().map_err(|_e| DatabaseInconsistencyError)?;
let redirect_uri = res
.redirect_uri
.parse()
.map_err(|_e| DatabaseInconsistencyError)?;
let session = Session {
data: res.oauth2_session_id,
client,
browser_session,
scope,
redirect_uri,
nonce: res.nonce,
};
Ok((code, session))
} }
pub async fn consume_code( pub async fn consume_code(
executor: impl Executor<'_, Database = Postgres>, executor: impl PgExecutor<'_>,
code_id: i64, code: AuthorizationCode<PostgresqlBackend>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
// TODO: mark the code as invalid instead to allow invalidating the whole // TODO: mark the code as invalid instead to allow invalidating the whole
// session on code reuse // session on code reuse
@ -123,7 +247,7 @@ pub async fn consume_code(
DELETE FROM oauth2_codes DELETE FROM oauth2_codes
WHERE id = $1 WHERE id = $1
"#, "#,
code_id, code.data,
) )
.execute(executor) .execute(executor)
.await .await

View File

@ -13,83 +13,192 @@
// limitations under the License. // limitations under the License.
use anyhow::Context; use anyhow::Context;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Duration, Utc};
use sqlx::{Executor, Postgres}; use mas_data_model::{
AccessToken, Authentication, BrowserSession, Client, RefreshToken, Session, User,
};
use sqlx::PgExecutor;
#[derive(Debug)] use crate::storage::{DatabaseInconsistencyError, IdAndCreationTime, PostgresqlBackend};
pub struct OAuth2RefreshToken {
pub id: i64,
oauth2_session_id: i64,
oauth2_access_token_id: Option<i64>,
pub token: String,
next_token_id: Option<i64>,
created_at: DateTime<Utc>,
updated_at: DateTime<Utc>,
}
pub async fn add_refresh_token( pub async fn add_refresh_token(
executor: impl Executor<'_, Database = Postgres>, executor: impl PgExecutor<'_>,
oauth2_session_id: i64, oauth2_session_id: i64,
oauth2_access_token_id: i64, access_token: AccessToken<PostgresqlBackend>,
token: &str, token: &str,
) -> anyhow::Result<OAuth2RefreshToken> { ) -> anyhow::Result<RefreshToken<PostgresqlBackend>> {
sqlx::query_as!( let res = sqlx::query_as!(
OAuth2RefreshToken, IdAndCreationTime,
r#" r#"
INSERT INTO oauth2_refresh_tokens INSERT INTO oauth2_refresh_tokens
(oauth2_session_id, oauth2_access_token_id, token) (oauth2_session_id, oauth2_access_token_id, token)
VALUES VALUES
($1, $2, $3) ($1, $2, $3)
RETURNING RETURNING
id, oauth2_session_id, oauth2_access_token_id, token, next_token_id, id, created_at
created_at, updated_at
"#, "#,
oauth2_session_id, oauth2_session_id,
oauth2_access_token_id, access_token.data,
token, token,
) )
.fetch_one(executor) .fetch_one(executor)
.await .await
.context("could not insert oauth2 refresh token") .context("could not insert oauth2 refresh token")?;
Ok(RefreshToken {
data: res.id,
token: token.to_string(),
access_token: Some(access_token),
created_at: res.created_at,
})
} }
pub struct OAuth2RefreshTokenLookup { struct OAuth2RefreshTokenLookup {
pub id: i64, refresh_token_id: i64,
pub oauth2_session_id: i64, refresh_token: String,
pub oauth2_access_token_id: Option<i64>, refresh_token_created_at: DateTime<Utc>,
pub client_id: String, access_token_id: Option<i64>,
pub scope: String, access_token: Option<String>,
access_token_expires_after: Option<i32>,
access_token_created_at: Option<DateTime<Utc>>,
session_id: i64,
client_id: String,
scope: String,
redirect_uri: String,
nonce: Option<String>,
user_session_id: i64,
user_session_created_at: DateTime<Utc>,
user_id: i64,
user_username: String,
user_session_last_authentication_id: Option<i64>,
user_session_last_authentication_created_at: Option<DateTime<Utc>>,
} }
pub async fn lookup_refresh_token( #[allow(clippy::too_many_lines)]
executor: impl Executor<'_, Database = Postgres>, pub async fn lookup_active_refresh_token(
executor: impl PgExecutor<'_>,
token: &str, token: &str,
) -> anyhow::Result<OAuth2RefreshTokenLookup> { ) -> anyhow::Result<(RefreshToken<PostgresqlBackend>, Session<PostgresqlBackend>)> {
sqlx::query_as!( let res = sqlx::query_as!(
OAuth2RefreshTokenLookup, OAuth2RefreshTokenLookup,
r#" r#"
SELECT SELECT
rt.id, rt.id AS refresh_token_id,
rt.oauth2_session_id, rt.token AS refresh_token,
rt.oauth2_access_token_id, rt.created_at AS refresh_token_created_at,
os.client_id AS "client_id!", at.id AS "access_token_id?",
os.scope AS "scope!" at.token AS "access_token?",
at.expires_after AS "access_token_expires_after?",
at.created_at AS "access_token_created_at?",
os.id AS "session_id!",
os.client_id AS "client_id!",
os.scope AS "scope!",
os.redirect_uri AS "redirect_uri!",
os.nonce AS "nonce",
us.id AS "user_session_id!",
us.created_at AS "user_session_created_at!",
u.id AS "user_id!",
u.username AS "user_username!",
usa.id AS "user_session_last_authentication_id?",
usa.created_at AS "user_session_last_authentication_created_at?"
FROM oauth2_refresh_tokens rt FROM oauth2_refresh_tokens rt
LEFT JOIN oauth2_access_tokens at
ON at.id = rt.oauth2_access_token_id
INNER JOIN oauth2_sessions os INNER JOIN oauth2_sessions os
ON os.id = rt.oauth2_session_id ON os.id = rt.oauth2_session_id
WHERE rt.token = $1 AND rt.next_token_id IS NULL INNER JOIN user_sessions us
ON us.id = os.user_session_id
INNER JOIN users u
ON u.id = us.user_id
LEFT JOIN user_session_authentications usa
ON usa.session_id = us.id
WHERE rt.token = $1
AND rt.next_token_id IS NULL
AND us.active
ORDER BY usa.created_at DESC
LIMIT 1
"#, "#,
token, token,
) )
.fetch_one(executor) .fetch_one(executor)
.await .await
.context("failed to fetch oauth2 refresh token") .context("failed to fetch oauth2 refresh token")?;
let access_token = match (
res.access_token_id,
res.access_token,
res.access_token_created_at,
res.access_token_expires_after,
) {
(None, None, None, None) => None,
(Some(id), Some(token), Some(created_at), Some(expires_after)) => Some(AccessToken {
data: id,
jti: format!("{}", id),
token,
created_at,
expires_after: Duration::seconds(expires_after.into()),
}),
_ => return Err(DatabaseInconsistencyError.into()),
};
let refresh_token = RefreshToken {
data: res.refresh_token_id,
token: res.refresh_token,
created_at: res.refresh_token_created_at,
access_token,
};
let client = Client {
data: (),
client_id: res.client_id,
};
let user = User {
data: res.user_id,
username: res.user_username,
sub: format!("fake-sub-{}", res.user_id),
};
let last_authentication = match (
res.user_session_last_authentication_id,
res.user_session_last_authentication_created_at,
) {
(None, None) => None,
(Some(id), Some(created_at)) => Some(Authentication {
data: id,
created_at,
}),
_ => return Err(DatabaseInconsistencyError.into()),
};
let browser_session = Some(BrowserSession {
data: res.user_session_id,
created_at: res.user_session_created_at,
user,
last_authentication,
});
let session = Session {
data: res.session_id,
client,
browser_session,
scope: res.scope.parse().context("invalid scope in database")?,
redirect_uri: res
.redirect_uri
.parse()
.context("invalid redirect_uri in database")?,
nonce: res.nonce,
};
Ok((refresh_token, session))
} }
pub async fn replace_refresh_token( pub async fn replace_refresh_token(
executor: impl Executor<'_, Database = Postgres>, executor: impl PgExecutor<'_>,
refresh_token_id: i64, refresh_token: &RefreshToken<PostgresqlBackend>,
next_refresh_token_id: i64, next_refresh_token: &RefreshToken<PostgresqlBackend>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let res = sqlx::query!( let res = sqlx::query!(
r#" r#"
@ -97,8 +206,8 @@ pub async fn replace_refresh_token(
SET next_token_id = $2 SET next_token_id = $2
WHERE id = $1 WHERE id = $1
"#, "#,
refresh_token_id, refresh_token.data,
next_refresh_token_id next_refresh_token.data
) )
.execute(executor) .execute(executor)
.await .await

View File

@ -17,19 +17,19 @@ use std::{collections::HashSet, convert::TryFrom, str::FromStr, string::ToString
use anyhow::Context; use anyhow::Context;
use chrono::{DateTime, Duration, Utc}; use chrono::{DateTime, Duration, Utc};
use itertools::Itertools; use itertools::Itertools;
use mas_data_model::BrowserSession; use mas_data_model::{AuthorizationCode, BrowserSession};
use oauth2_types::{ use oauth2_types::{
pkce, pkce,
requests::{ResponseMode, ResponseType}, requests::{ResponseMode, ResponseType},
}; };
use serde::Serialize; use serde::Serialize;
use sqlx::{Executor, FromRow, Postgres}; use sqlx::PgExecutor;
use url::Url; use url::Url;
use super::authorization_code::{add_code, OAuth2Code}; use super::authorization_code::add_code;
use crate::storage::{lookup_active_session, PostgresqlBackend}; use crate::storage::{lookup_active_session, PostgresqlBackend};
#[derive(FromRow, Serialize)] #[derive(Serialize)]
pub struct OAuth2Session { pub struct OAuth2Session {
pub id: i64, pub id: i64,
user_session_id: Option<i64>, user_session_id: Option<i64>,
@ -49,16 +49,16 @@ pub struct OAuth2Session {
impl OAuth2Session { impl OAuth2Session {
pub async fn add_code<'e>( pub async fn add_code<'e>(
&self, &self,
executor: impl Executor<'e, Database = Postgres>, executor: impl PgExecutor<'e>,
code: &str, code: &str,
code_challenge: &Option<pkce::AuthorizationRequest>, code_challenge: &Option<pkce::AuthorizationRequest>,
) -> anyhow::Result<OAuth2Code> { ) -> anyhow::Result<AuthorizationCode<PostgresqlBackend>> {
add_code(executor, self.id, code, code_challenge).await add_code(executor, self.id, code, code_challenge).await
} }
pub async fn fetch_session( pub async fn fetch_session(
&self, &self,
executor: impl Executor<'_, Database = Postgres>, executor: impl PgExecutor<'_>,
) -> anyhow::Result<Option<BrowserSession<PostgresqlBackend>>> { ) -> anyhow::Result<Option<BrowserSession<PostgresqlBackend>>> {
match self.user_session_id { match self.user_session_id {
Some(id) => { Some(id) => {
@ -70,16 +70,13 @@ impl OAuth2Session {
} }
} }
pub async fn fetch_code( pub async fn fetch_code(&self, executor: impl PgExecutor<'_>) -> anyhow::Result<String> {
&self,
executor: impl Executor<'_, Database = Postgres>,
) -> anyhow::Result<String> {
get_code_for_session(executor, self.id).await get_code_for_session(executor, self.id).await
} }
pub async fn match_or_set_session( pub async fn match_or_set_session(
&mut self, &mut self,
executor: impl Executor<'_, Database = Postgres>, executor: impl PgExecutor<'_>,
session: BrowserSession<PostgresqlBackend>, session: BrowserSession<PostgresqlBackend>,
) -> anyhow::Result<BrowserSession<PostgresqlBackend>> { ) -> anyhow::Result<BrowserSession<PostgresqlBackend>> {
match self.user_session_id { match self.user_session_id {
@ -130,7 +127,7 @@ impl OAuth2Session {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub async fn start_session( pub async fn start_session(
executor: impl Executor<'_, Database = Postgres>, executor: impl PgExecutor<'_>,
optional_session_id: Option<i64>, optional_session_id: Option<i64>,
client_id: &str, client_id: &str,
redirect_uri: &Url, redirect_uri: &Url,
@ -178,7 +175,7 @@ pub async fn start_session(
} }
pub async fn get_session_by_id( pub async fn get_session_by_id(
executor: impl Executor<'_, Database = Postgres>, executor: impl PgExecutor<'_>,
oauth2_session_id: i64, oauth2_session_id: i64,
) -> anyhow::Result<OAuth2Session> { ) -> anyhow::Result<OAuth2Session> {
sqlx::query_as!( sqlx::query_as!(
@ -198,7 +195,7 @@ pub async fn get_session_by_id(
} }
pub async fn get_code_for_session( pub async fn get_code_for_session(
executor: impl Executor<'_, Database = Postgres>, executor: impl PgExecutor<'_>,
oauth2_session_id: i64, oauth2_session_id: i64,
) -> anyhow::Result<String> { ) -> anyhow::Result<String> {
sqlx::query_scalar!( sqlx::query_scalar!(

View File

@ -20,15 +20,16 @@ use chrono::{DateTime, Utc};
use mas_data_model::{errors::HtmlError, Authentication, BrowserSession, User}; use mas_data_model::{errors::HtmlError, Authentication, BrowserSession, User};
use password_hash::{PasswordHash, PasswordHasher, SaltString}; use password_hash::{PasswordHash, PasswordHasher, SaltString};
use rand::rngs::OsRng; use rand::rngs::OsRng;
use sqlx::{Acquire, Executor, FromRow, Postgres, Transaction}; use sqlx::{Acquire, PgExecutor, Postgres, Transaction};
use thiserror::Error; use thiserror::Error;
use tokio::task; use tokio::task;
use tracing::{info_span, Instrument}; use tracing::{info_span, Instrument};
use warp::reject::Reject; use warp::reject::Reject;
use super::{DatabaseInconsistencyError, PostgresqlBackend}; use super::{DatabaseInconsistencyError, PostgresqlBackend};
use crate::storage::IdAndCreationTime;
#[derive(Debug, Clone, FromRow)] #[derive(Debug, Clone)]
struct UserLookup { struct UserLookup {
pub id: i64, pub id: i64,
pub username: String, pub username: String,
@ -159,7 +160,7 @@ impl TryInto<BrowserSession<PostgresqlBackend>> for SessionLookup {
} }
pub async fn lookup_active_session( pub async fn lookup_active_session(
executor: impl Executor<'_, Database = Postgres>, executor: impl PgExecutor<'_>,
id: i64, id: i64,
) -> Result<BrowserSession<PostgresqlBackend>, ActiveSessionLookupError> { ) -> Result<BrowserSession<PostgresqlBackend>, ActiveSessionLookupError> {
let res = sqlx::query_as!( let res = sqlx::query_as!(
@ -190,18 +191,12 @@ pub async fn lookup_active_session(
Ok(res) Ok(res)
} }
#[derive(FromRow)]
struct SessionStartResult {
id: i64,
created_at: DateTime<Utc>,
}
pub async fn start_session( pub async fn start_session(
executor: impl Executor<'_, Database = Postgres>, executor: impl PgExecutor<'_>,
user: User<PostgresqlBackend>, user: User<PostgresqlBackend>,
) -> anyhow::Result<BrowserSession<PostgresqlBackend>> { ) -> anyhow::Result<BrowserSession<PostgresqlBackend>> {
let res = sqlx::query_as!( let res = sqlx::query_as!(
SessionStartResult, IdAndCreationTime,
r#" r#"
INSERT INTO user_sessions (user_id) INSERT INTO user_sessions (user_id)
VALUES ($1) VALUES ($1)
@ -238,12 +233,6 @@ pub enum AuthenticationError {
Internal(#[from] tokio::task::JoinError), Internal(#[from] tokio::task::JoinError),
} }
#[derive(FromRow)]
struct AuthenticationInsertionResult {
id: i64,
created_at: DateTime<Utc>,
}
pub async fn authenticate_session( pub async fn authenticate_session(
txn: &mut Transaction<'_, Postgres>, txn: &mut Transaction<'_, Postgres>,
session: &BrowserSession<PostgresqlBackend>, session: &BrowserSession<PostgresqlBackend>,
@ -277,7 +266,7 @@ pub async fn authenticate_session(
// That went well, let's insert the auth info // That went well, let's insert the auth info
let res = sqlx::query_as!( let res = sqlx::query_as!(
AuthenticationInsertionResult, IdAndCreationTime,
r#" r#"
INSERT INTO user_session_authentications (session_id) INSERT INTO user_session_authentications (session_id)
VALUES ($1) VALUES ($1)
@ -296,7 +285,7 @@ pub async fn authenticate_session(
} }
pub async fn register_user( pub async fn register_user(
executor: impl Executor<'_, Database = Postgres>, executor: impl PgExecutor<'_>,
phf: impl PasswordHasher, phf: impl PasswordHasher,
username: &str, username: &str,
password: &str, password: &str,
@ -326,7 +315,7 @@ pub async fn register_user(
} }
pub async fn end_session( pub async fn end_session(
executor: impl Executor<'_, Database = Postgres>, executor: impl PgExecutor<'_>,
session: &BrowserSession<PostgresqlBackend>, session: &BrowserSession<PostgresqlBackend>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let res = sqlx::query!( let res = sqlx::query!(
@ -346,7 +335,7 @@ pub async fn end_session(
} }
pub async fn lookup_user_by_username( pub async fn lookup_user_by_username(
executor: impl Executor<'_, Database = Postgres>, executor: impl PgExecutor<'_>,
username: &str, username: &str,
) -> Result<User<PostgresqlBackend>, sqlx::Error> { ) -> Result<User<PostgresqlBackend>, sqlx::Error> {
let res = sqlx::query_as!( let res = sqlx::query_as!(

View File

@ -9,5 +9,6 @@ license = "Apache-2.0"
chrono = "0.4.19" chrono = "0.4.19"
thiserror = "1.0.30" thiserror = "1.0.30"
serde = "1.0.130" serde = "1.0.130"
url = { version = "2.2.2", features = ["serde"] }
oauth2-types = { path = "../oauth2-types" } oauth2-types = { path = "../oauth2-types" }

View File

@ -15,6 +15,7 @@
use chrono::{DateTime, Duration, Utc}; use chrono::{DateTime, Duration, Utc};
use oauth2_types::{pkce::CodeChallengeMethod, scope::Scope}; use oauth2_types::{pkce::CodeChallengeMethod, scope::Scope};
use serde::Serialize; use serde::Serialize;
use url::Url;
pub mod errors; pub mod errors;
@ -28,6 +29,7 @@ pub trait StorageBackend {
type SessionData: Clone + std::fmt::Debug + PartialEq; type SessionData: Clone + std::fmt::Debug + PartialEq;
type AuthorizationCodeData: Clone + std::fmt::Debug + PartialEq; type AuthorizationCodeData: Clone + std::fmt::Debug + PartialEq;
type AccessTokenData: Clone + std::fmt::Debug + PartialEq; type AccessTokenData: Clone + std::fmt::Debug + PartialEq;
type RefreshTokenData: Clone + std::fmt::Debug + PartialEq;
} }
impl StorageBackend for () { impl StorageBackend for () {
@ -36,6 +38,7 @@ impl StorageBackend for () {
type AuthorizationCodeData = (); type AuthorizationCodeData = ();
type BrowserSessionData = (); type BrowserSessionData = ();
type ClientData = (); type ClientData = ();
type RefreshTokenData = ();
type SessionData = (); type SessionData = ();
type UserData = (); type UserData = ();
} }
@ -153,6 +156,8 @@ pub struct Session<T: StorageBackend> {
pub browser_session: Option<BrowserSession<T>>, pub browser_session: Option<BrowserSession<T>>,
pub client: Client<T>, pub client: Client<T>,
pub scope: Scope, pub scope: Scope,
pub redirect_uri: Url,
pub nonce: Option<String>,
} }
impl<S: StorageBackendMarker> From<Session<S>> for Session<()> { impl<S: StorageBackendMarker> From<Session<S>> for Session<()> {
@ -162,6 +167,8 @@ impl<S: StorageBackendMarker> From<Session<S>> for Session<()> {
browser_session: s.browser_session.map(Into::into), browser_session: s.browser_session.map(Into::into),
client: s.client.into(), client: s.client.into(),
scope: s.scope, scope: s.scope,
redirect_uri: s.redirect_uri,
nonce: s.nonce,
} }
} }
} }
@ -191,7 +198,7 @@ pub struct AuthorizationCode<T: StorageBackend> {
#[serde(skip_serializing)] #[serde(skip_serializing)]
pub data: T::AuthorizationCodeData, pub data: T::AuthorizationCodeData,
pub code: String, pub code: String,
pub pkce: Pkce, pub pkce: Option<Pkce>,
} }
impl<S: StorageBackendMarker> From<AuthorizationCode<S>> for AuthorizationCode<()> { impl<S: StorageBackendMarker> From<AuthorizationCode<S>> for AuthorizationCode<()> {
@ -224,3 +231,28 @@ impl<S: StorageBackendMarker> From<AccessToken<S>> for AccessToken<()> {
} }
} }
} }
impl<T: StorageBackend> AccessToken<T> {
pub fn exp(&self) -> DateTime<Utc> {
self.created_at + self.expires_after
}
}
#[derive(Debug, Clone, PartialEq)]
pub struct RefreshToken<T: StorageBackend> {
pub data: T::RefreshTokenData,
pub token: String,
pub created_at: DateTime<Utc>,
pub access_token: Option<AccessToken<T>>,
}
impl<S: StorageBackendMarker> From<RefreshToken<S>> for RefreshToken<()> {
fn from(t: RefreshToken<S>) -> Self {
RefreshToken {
data: (),
token: t.token,
created_at: t.created_at,
access_token: t.access_token.map(Into::into),
}
}
}