diff --git a/crates/core/migrations/20211021201500_oauth2_sessions.down.sql b/crates/core/migrations/20211021201500_oauth2_sessions.down.sql new file mode 100644 index 00000000..e69de29b diff --git a/crates/core/migrations/20211021201500_oauth2_sessions.up.sql b/crates/core/migrations/20211021201500_oauth2_sessions.up.sql new file mode 100644 index 00000000..69ba4c62 --- /dev/null +++ b/crates/core/migrations/20211021201500_oauth2_sessions.up.sql @@ -0,0 +1,103 @@ +-- Copyright 2021 The Matrix.org Foundation C.I.C. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + + +-- Replace the old "sessions" table +ALTER TABLE oauth2_sessions RENAME TO oauth2_sessions_old; + +-- TODO: how do we handle temporary session upgrades (aka. sudo mode)? +CREATE TABLE oauth2_sessions ( + "id" BIGSERIAL PRIMARY KEY, + "user_session_id" BIGINT NOT NULL REFERENCES user_sessions (id) ON DELETE CASCADE, + "client_id" TEXT NOT NULL, -- The "authorization party" would be more accurate in that case + "scope" TEXT NOT NULL, + + "created_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now() +); + +TRUNCATE oauth2_access_tokens, oauth2_refresh_tokens; +ALTER TABLE oauth2_access_tokens + DROP CONSTRAINT oauth2_access_tokens_oauth2_session_id_fkey, + ADD CONSTRAINT oauth2_access_tokens_oauth2_session_id_fkey + FOREIGN KEY (oauth2_session_id) REFERENCES oauth2_sessions (id); +ALTER TABLE oauth2_refresh_tokens + DROP CONSTRAINT oauth2_refresh_tokens_oauth2_session_id_fkey, + ADD CONSTRAINT oauth2_refresh_tokens_oauth2_session_id_fkey + FOREIGN KEY (oauth2_session_id) REFERENCES oauth2_sessions (id); +DROP TABLE oauth2_codes, oauth2_sessions_old; + +CREATE TABLE oauth2_authorization_grants ( + "id" BIGSERIAL PRIMARY KEY, -- Saved as encrypted cookie + + -- All this comes from the authorization request + "client_id" TEXT NOT NULL, -- This should be verified before insertion + "redirect_uri" TEXT NOT NULL, -- This should be verified before insertion + "scope" TEXT NOT NULL, -- This should be verified before insertion + "state" TEXT, + "nonce" TEXT, + "max_age" INT CHECK ("max_age" IS NULL OR "max_age" > 0), + "acr_values" TEXT, -- This should be verified before insertion + "response_mode" TEXT NOT NULL, + "code_challenge_method" TEXT, + "code_challenge" TEXT, + + -- The "response_type" parameter broken down + "response_type_code" BOOLEAN NOT NULL, + "response_type_token" BOOLEAN NOT NULL, + "response_type_id_token" BOOLEAN NOT NULL, + + -- This one is created eagerly on grant creation if the response_type + -- includes "code" + -- When looking up codes, it should do "where fulfilled_at is not null" and + -- "inner join on oauth2_sessions". When doing that, it should check the + -- "exchanged_at" field: if it is not null and was exchanged more than 30s + -- ago, the session shold be considered as hijacked and fully invalidated + "code" TEXT UNIQUE, + + "created_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(), + "fulfilled_at" TIMESTAMP WITH TIME ZONE, -- When we got back to the client + "cancelled_at" TIMESTAMP WITH TIME ZONE, -- When that grant was cancelled + "exchanged_at" TIMESTAMP WITH TIME ZONE, -- When the code was exchanged by the client + + "oauth2_session_id" BIGINT REFERENCES oauth2_sessions (id) ON DELETE CASCADE, + + -- Check a few invariants to keep a coherent state. + -- Even though the service should never violate those, it helps ensuring we're not doing anything wrong + + -- Code exchange can only happen after the grant was fulfilled + CONSTRAINT "oauth2_authorization_grants_exchanged_after_fullfill" + CHECK (("exchanged_at" IS NULL) + OR ("exchanged_at" IS NOT NULL AND + "fulfilled_at" IS NOT NULL AND + "exchanged_at" >= "fulfilled_at")), + + -- A grant can be either fulfilled or cancelled, but not both + CONSTRAINT "oauth2_authorization_grants_fulfilled_xor_cancelled" + CHECK ("fulfilled_at" IS NULL OR "cancelled_at" IS NULL), + + -- If it was fulfilled there is an oauth2_session_id attached to it + CONSTRAINT "oauth2_authorization_grants_fulfilled_and_session" + CHECK (("fulfilled_at" IS NULL AND "oauth2_session_id" IS NULL) + OR ("fulfilled_at" IS NOT NULL AND "oauth2_session_id" IS NOT NULL)), + + -- We should have a code if and only if the "code" response_type was asked + CONSTRAINT "oauth2_authorization_grants_code" + CHECK (("response_type_code" IS TRUE AND "code" IS NOT NULL) + OR ("response_type_code" IS FALSE AND "code" IS NULL)), + + -- If we have a challenge, we also have a challenge method and a code + CONSTRAINT "oauth2_authorization_grants_code_challenge" + CHECK (("code_challenge" IS NULL AND "code_challenge_method" IS NULL) + OR ("code_challenge" IS NOT NULL AND "code_challenge_method" IS NOT NULL AND "response_type_code" IS TRUE)) +); diff --git a/crates/core/sqlx-data.json b/crates/core/sqlx-data.json index 5bd07067..f8250a46 100644 --- a/crates/core/sqlx-data.json +++ b/crates/core/sqlx-data.json @@ -54,14 +54,139 @@ ] } }, - "17729fd0354a84e04bfcd525db6575ed2ba75dd730bea3f2be964f4b347dd484": { - "query": "\n SELECT code\n FROM oauth2_codes\n WHERE oauth2_session_id = $1\n ", + "0cc63e00143cf94f63695be24acdcdffd8e8a3da50ea1ddf973a39bc34f861d4": { + "query": "\n SELECT\n og.id AS grant_id,\n og.created_at AS grant_created_at,\n og.cancelled_at AS grant_cancelled_at,\n og.fulfilled_at AS grant_fulfilled_at,\n og.exchanged_at AS grant_exchanged_at,\n og.scope AS grant_scope,\n og.state AS grant_state,\n og.redirect_uri AS grant_redirect_uri,\n og.response_mode AS grant_response_mode,\n og.nonce AS grant_nonce,\n og.max_age AS grant_max_age,\n og.acr_values AS grant_acr_values,\n og.client_id AS client_id,\n og.code AS grant_code,\n og.response_type_code AS grant_response_type_code,\n og.response_type_token AS grant_response_type_token,\n og.response_type_id_token AS grant_response_type_id_token,\n og.code_challenge AS grant_code_challenge,\n og.code_challenge_method AS grant_code_challenge_method,\n os.id AS \"session_id?\",\n us.id AS \"user_session_id?\",\n us.created_at AS \"user_session_created_at?\",\n u.id AS \"user_id?\",\n u.username AS \"user_username?\",\n usa.id AS \"user_session_last_authentication_id?\",\n usa.created_at AS \"user_session_last_authentication_created_at?\"\n FROM\n oauth2_authorization_grants og\n LEFT JOIN oauth2_sessions os\n ON os.id = og.oauth2_session_id\n LEFT JOIN user_sessions us\n ON us.id = os.user_session_id\n LEFT JOIN users u\n ON u.id = us.user_id\n LEFT JOIN user_session_authentications usa\n ON usa.session_id = us.id\n WHERE\n og.id = $1\n ", "describe": { "columns": [ { "ordinal": 0, - "name": "code", + "name": "grant_id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "grant_created_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 2, + "name": "grant_cancelled_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 3, + "name": "grant_fulfilled_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 4, + "name": "grant_exchanged_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 5, + "name": "grant_scope", "type_info": "Text" + }, + { + "ordinal": 6, + "name": "grant_state", + "type_info": "Text" + }, + { + "ordinal": 7, + "name": "grant_redirect_uri", + "type_info": "Text" + }, + { + "ordinal": 8, + "name": "grant_response_mode", + "type_info": "Text" + }, + { + "ordinal": 9, + "name": "grant_nonce", + "type_info": "Text" + }, + { + "ordinal": 10, + "name": "grant_max_age", + "type_info": "Int4" + }, + { + "ordinal": 11, + "name": "grant_acr_values", + "type_info": "Text" + }, + { + "ordinal": 12, + "name": "client_id", + "type_info": "Text" + }, + { + "ordinal": 13, + "name": "grant_code", + "type_info": "Text" + }, + { + "ordinal": 14, + "name": "grant_response_type_code", + "type_info": "Bool" + }, + { + "ordinal": 15, + "name": "grant_response_type_token", + "type_info": "Bool" + }, + { + "ordinal": 16, + "name": "grant_response_type_id_token", + "type_info": "Bool" + }, + { + "ordinal": 17, + "name": "grant_code_challenge", + "type_info": "Text" + }, + { + "ordinal": 18, + "name": "grant_code_challenge_method", + "type_info": "Text" + }, + { + "ordinal": 19, + "name": "session_id?", + "type_info": "Int8" + }, + { + "ordinal": 20, + "name": "user_session_id?", + "type_info": "Int8" + }, + { + "ordinal": 21, + "name": "user_session_created_at?", + "type_info": "Timestamptz" + }, + { + "ordinal": 22, + "name": "user_id?", + "type_info": "Int8" + }, + { + "ordinal": 23, + "name": "user_username?", + "type_info": "Text" + }, + { + "ordinal": 24, + "name": "user_session_last_authentication_id?", + "type_info": "Int8" + }, + { + "ordinal": 25, + "name": "user_session_last_authentication_created_at?", + "type_info": "Timestamptz" } ], "parameters": { @@ -70,42 +195,67 @@ ] }, "nullable": [ + false, + false, + true, + true, + true, + false, + true, + false, + false, + true, + true, + true, + false, + true, + false, + false, + false, + true, + true, + false, + false, + false, + false, + false, + false, false ] } }, - "282548c5ad51bd95b7d9ad290714bab5860f1e1291021e7d786dc926d12b5dd9": { - "query": "\n SELECT\n oc.id,\n oc.code_challenge,\n oc.code_challenge_method,\n os.id AS \"oauth2_session_id!\",\n os.client_id AS \"client_id!\",\n os.redirect_uri,\n os.scope AS \"scope!\",\n os.nonce,\n us.id AS \"user_session_id?\",\n us.created_at AS \"user_session_created_at?\",\n u.id AS \"user_id?\",\n u.username AS \"user_username?\",\n usa.id AS \"user_session_last_authentication_id?\",\n usa.created_at AS \"user_session_last_authentication_created_at?\"\n FROM oauth2_codes oc\n INNER JOIN oauth2_sessions os\n ON os.id = oc.oauth2_session_id\n LEFT JOIN user_sessions us\n ON us.id = os.user_session_id\n LEFT JOIN user_session_authentications usa\n ON usa.session_id = us.id\n LEFT JOIN users u\n ON u.id = us.user_id\n WHERE oc.code = $1\n ORDER BY usa.created_at DESC\n LIMIT 1\n ", + "2dbccaf2fb557dd36598bf4d00941280535cc523ac3a481903ed825088901bce": { + "query": "\n SELECT\n at.id AS \"access_token_id\",\n at.token AS \"access_token\",\n at.expires_after AS \"access_token_expires_after\",\n at.created_at AS \"access_token_created_at\",\n os.id AS \"session_id!\",\n os.client_id AS \"client_id!\",\n os.scope AS \"scope!\",\n us.id AS \"user_session_id!\",\n us.created_at AS \"user_session_created_at!\",\n u.id AS \"user_id!\",\n u.username AS \"user_username!\",\n usa.id AS \"user_session_last_authentication_id?\",\n usa.created_at AS \"user_session_last_authentication_created_at?\"\n\n FROM oauth2_access_tokens at\n INNER JOIN oauth2_sessions os\n ON os.id = at.oauth2_session_id\n INNER JOIN user_sessions us\n ON us.id = os.user_session_id\n INNER JOIN users u\n ON u.id = us.user_id\n LEFT JOIN user_session_authentications usa\n ON usa.session_id = us.id\n\n WHERE at.token = $1\n AND at.created_at + (at.expires_after * INTERVAL '1 second') >= now()\n AND us.active\n\n ORDER BY usa.created_at DESC\n LIMIT 1\n ", "describe": { "columns": [ { "ordinal": 0, - "name": "id", + "name": "access_token_id", "type_info": "Int8" }, { "ordinal": 1, - "name": "code_challenge", + "name": "access_token", "type_info": "Text" }, { "ordinal": 2, - "name": "code_challenge_method", - "type_info": "Int2" + "name": "access_token_expires_after", + "type_info": "Int4" }, { "ordinal": 3, - "name": "oauth2_session_id!", - "type_info": "Int8" + "name": "access_token_created_at", + "type_info": "Timestamptz" }, { "ordinal": 4, - "name": "client_id!", - "type_info": "Text" + "name": "session_id!", + "type_info": "Int8" }, { "ordinal": 5, - "name": "redirect_uri", + "name": "client_id!", "type_info": "Text" }, { @@ -115,36 +265,31 @@ }, { "ordinal": 7, - "name": "nonce", - "type_info": "Text" + "name": "user_session_id!", + "type_info": "Int8" }, { "ordinal": 8, - "name": "user_session_id?", - "type_info": "Int8" - }, - { - "ordinal": 9, - "name": "user_session_created_at?", + "name": "user_session_created_at!", "type_info": "Timestamptz" }, { - "ordinal": 10, - "name": "user_id?", + "ordinal": 9, + "name": "user_id!", "type_info": "Int8" }, { - "ordinal": 11, - "name": "user_username?", + "ordinal": 10, + "name": "user_username!", "type_info": "Text" }, { - "ordinal": 12, + "ordinal": 11, "name": "user_session_last_authentication_id?", "type_info": "Int8" }, { - "ordinal": 13, + "ordinal": 12, "name": "user_session_last_authentication_created_at?", "type_info": "Timestamptz" } @@ -156,13 +301,12 @@ }, "nullable": [ false, - true, - true, false, false, false, false, - true, + false, + false, false, false, false, @@ -198,25 +342,41 @@ ] } }, - "47a7a8d2ef7db8bb1d41230626ded4e4661d488891fbda9b872c0749a9ba58f4": { - "query": "\n INSERT INTO oauth2_codes\n (oauth2_session_id, code, code_challenge_method, code_challenge)\n VALUES\n ($1, $2, $3, $4)\n RETURNING\n id\n ", + "38641231a3bff71252e8bc0ead3a033c9148762ea64d707642551c01a4c89b84": { + "query": "\n INSERT INTO oauth2_authorization_grants\n (client_id, redirect_uri, scope, state, nonce, max_age,\n acr_values, response_mode, code_challenge, code_challenge_method,\n response_type_code, response_type_token, response_type_id_token,\n code)\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)\n RETURNING id, created_at\n ", "describe": { "columns": [ { "ordinal": 0, "name": "id", "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "created_at", + "type_info": "Timestamptz" } ], "parameters": { "Left": [ - "Int8", "Text", - "Int2", + "Text", + "Text", + "Text", + "Text", + "Int4", + "Text", + "Text", + "Text", + "Text", + "Bool", + "Bool", + "Bool", "Text" ] }, "nullable": [ + false, false ] } @@ -249,8 +409,18 @@ ] } }, - "5d032f4bdb28534da7cf8e9806442a12708d632b7be28f8b952bd3cb63a8b1af": { - "query": "\n SELECT\n rt.id AS refresh_token_id,\n rt.token AS refresh_token,\n rt.created_at AS refresh_token_created_at,\n at.id AS \"access_token_id?\",\n at.token AS \"access_token?\",\n at.expires_after AS \"access_token_expires_after?\",\n at.created_at AS \"access_token_created_at?\",\n os.id AS \"session_id!\",\n os.client_id AS \"client_id!\",\n os.scope AS \"scope!\",\n os.redirect_uri AS \"redirect_uri!\",\n os.nonce AS \"nonce\",\n us.id AS \"user_session_id!\",\n us.created_at AS \"user_session_created_at!\",\n u.id AS \"user_id!\",\n u.username AS \"user_username!\",\n usa.id AS \"user_session_last_authentication_id?\",\n usa.created_at AS \"user_session_last_authentication_created_at?\"\n FROM oauth2_refresh_tokens rt\n LEFT JOIN oauth2_access_tokens at\n ON at.id = rt.oauth2_access_token_id\n INNER JOIN oauth2_sessions os\n ON os.id = rt.oauth2_session_id\n INNER JOIN user_sessions us\n ON us.id = os.user_session_id\n INNER JOIN users u\n ON u.id = us.user_id\n LEFT JOIN user_session_authentications usa\n ON usa.session_id = us.id\n\n WHERE rt.token = $1\n AND rt.next_token_id IS NULL\n AND us.active\n\n ORDER BY usa.created_at DESC\n LIMIT 1\n ", + "5d1a17b2ad6153217551ae31549ad9d62cc39d2f9a4e62a7ccb60fd91e0ac685": { + "query": "\n DELETE FROM oauth2_access_tokens\n WHERE created_at + (expires_after * INTERVAL '1 second') + INTERVAL '15 minutes' < now()\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [] + }, + "nullable": [] + } + }, + "6765e725d31a1490ddee3f28e32dea41abdd9acefb1edd9a7b4e6790ec131173": { + "query": "\n SELECT\n rt.id AS refresh_token_id,\n rt.token AS refresh_token,\n rt.created_at AS refresh_token_created_at,\n at.id AS \"access_token_id?\",\n at.token AS \"access_token?\",\n at.expires_after AS \"access_token_expires_after?\",\n at.created_at AS \"access_token_created_at?\",\n os.id AS \"session_id!\",\n os.client_id AS \"client_id!\",\n os.scope AS \"scope!\",\n us.id AS \"user_session_id!\",\n us.created_at AS \"user_session_created_at!\",\n u.id AS \"user_id!\",\n u.username AS \"user_username!\",\n usa.id AS \"user_session_last_authentication_id?\",\n usa.created_at AS \"user_session_last_authentication_created_at?\"\n FROM oauth2_refresh_tokens rt\n LEFT JOIN oauth2_access_tokens at\n ON at.id = rt.oauth2_access_token_id\n INNER JOIN oauth2_sessions os\n ON os.id = rt.oauth2_session_id\n INNER JOIN user_sessions us\n ON us.id = os.user_session_id\n INNER JOIN users u\n ON u.id = us.user_id\n LEFT JOIN user_session_authentications usa\n ON usa.session_id = us.id\n\n WHERE rt.token = $1\n AND rt.next_token_id IS NULL\n AND us.active\n\n ORDER BY usa.created_at DESC\n LIMIT 1\n ", "describe": { "columns": [ { @@ -305,41 +475,31 @@ }, { "ordinal": 10, - "name": "redirect_uri!", - "type_info": "Text" - }, - { - "ordinal": 11, - "name": "nonce", - "type_info": "Text" - }, - { - "ordinal": 12, "name": "user_session_id!", "type_info": "Int8" }, { - "ordinal": 13, + "ordinal": 11, "name": "user_session_created_at!", "type_info": "Timestamptz" }, { - "ordinal": 14, + "ordinal": 12, "name": "user_id!", "type_info": "Int8" }, { - "ordinal": 15, + "ordinal": 13, "name": "user_username!", "type_info": "Text" }, { - "ordinal": 16, + "ordinal": 14, "name": "user_session_last_authentication_id?", "type_info": "Int8" }, { - "ordinal": 17, + "ordinal": 15, "name": "user_session_last_authentication_created_at?", "type_info": "Timestamptz" } @@ -361,8 +521,6 @@ false, false, false, - true, - false, false, false, false, @@ -371,117 +529,24 @@ ] } }, - "5d1a17b2ad6153217551ae31549ad9d62cc39d2f9a4e62a7ccb60fd91e0ac685": { - "query": "\n DELETE FROM oauth2_access_tokens\n WHERE created_at + (expires_after * INTERVAL '1 second') + INTERVAL '15 minutes' < now()\n ", - "describe": { - "columns": [], - "parameters": { - "Left": [] - }, - "nullable": [] - } - }, - "686a796a7de689b73a9377083718c95ac5ac51ce396dcf32e614402051d93e16": { - "query": "\n SELECT\n at.id AS \"access_token_id\",\n at.token AS \"access_token\",\n at.expires_after AS \"access_token_expires_after\",\n at.created_at AS \"access_token_created_at\",\n os.id AS \"session_id!\",\n os.client_id AS \"client_id!\",\n os.scope AS \"scope!\",\n os.redirect_uri AS \"redirect_uri!\",\n os.nonce AS \"nonce\",\n us.id AS \"user_session_id!\",\n us.created_at AS \"user_session_created_at!\",\n u.id AS \"user_id!\",\n u.username AS \"user_username!\",\n usa.id AS \"user_session_last_authentication_id?\",\n usa.created_at AS \"user_session_last_authentication_created_at?\"\n\n FROM oauth2_access_tokens at\n INNER JOIN oauth2_sessions os\n ON os.id = at.oauth2_session_id\n INNER JOIN user_sessions us\n ON us.id = os.user_session_id\n INNER JOIN users u\n ON u.id = us.user_id\n LEFT JOIN user_session_authentications usa\n ON usa.session_id = us.id\n\n WHERE at.token = $1\n AND at.created_at + (at.expires_after * INTERVAL '1 second') >= now()\n AND us.active\n\n ORDER BY usa.created_at DESC\n LIMIT 1\n ", + "703850ba4e001d53776d77a64cbc1ee6feb61485ce41aff1103251f9b3778128": { + "query": "\n UPDATE oauth2_authorization_grants AS og\n SET\n oauth2_session_id = os.id,\n fulfilled_at = os.created_at\n FROM oauth2_sessions os\n WHERE\n og.id = $1 AND os.id = $2\n RETURNING fulfilled_at AS \"fulfilled_at!: DateTime\"\n ", "describe": { "columns": [ { "ordinal": 0, - "name": "access_token_id", - "type_info": "Int8" - }, - { - "ordinal": 1, - "name": "access_token", - "type_info": "Text" - }, - { - "ordinal": 2, - "name": "access_token_expires_after", - "type_info": "Int4" - }, - { - "ordinal": 3, - "name": "access_token_created_at", - "type_info": "Timestamptz" - }, - { - "ordinal": 4, - "name": "session_id!", - "type_info": "Int8" - }, - { - "ordinal": 5, - "name": "client_id!", - "type_info": "Text" - }, - { - "ordinal": 6, - "name": "scope!", - "type_info": "Text" - }, - { - "ordinal": 7, - "name": "redirect_uri!", - "type_info": "Text" - }, - { - "ordinal": 8, - "name": "nonce", - "type_info": "Text" - }, - { - "ordinal": 9, - "name": "user_session_id!", - "type_info": "Int8" - }, - { - "ordinal": 10, - "name": "user_session_created_at!", - "type_info": "Timestamptz" - }, - { - "ordinal": 11, - "name": "user_id!", - "type_info": "Int8" - }, - { - "ordinal": 12, - "name": "user_username!", - "type_info": "Text" - }, - { - "ordinal": 13, - "name": "user_session_last_authentication_id?", - "type_info": "Int8" - }, - { - "ordinal": 14, - "name": "user_session_last_authentication_created_at?", + "name": "fulfilled_at!: DateTime", "type_info": "Timestamptz" } ], "parameters": { "Left": [ - "Text" + "Int8", + "Int8" ] }, "nullable": [ - false, - false, - false, - false, - false, - false, - false, - false, - true, - false, - false, - false, - false, - false, - false + true ] } }, @@ -547,6 +612,176 @@ "nullable": [] } }, + "8dde452a37c8faad20df68eb2b665202e0fb6b4ce805138e5f19d4e7eb0ce802": { + "query": "\n SELECT\n og.id AS grant_id,\n og.created_at AS grant_created_at,\n og.cancelled_at AS grant_cancelled_at,\n og.fulfilled_at AS grant_fulfilled_at,\n og.exchanged_at AS grant_exchanged_at,\n og.scope AS grant_scope,\n og.state AS grant_state,\n og.redirect_uri AS grant_redirect_uri,\n og.response_mode AS grant_response_mode,\n og.nonce AS grant_nonce,\n og.max_age AS grant_max_age,\n og.acr_values AS grant_acr_values,\n og.client_id AS client_id,\n og.code AS grant_code,\n og.response_type_code AS grant_response_type_code,\n og.response_type_token AS grant_response_type_token,\n og.response_type_id_token AS grant_response_type_id_token,\n og.code_challenge AS grant_code_challenge,\n og.code_challenge_method AS grant_code_challenge_method,\n os.id AS \"session_id?\",\n us.id AS \"user_session_id?\",\n us.created_at AS \"user_session_created_at?\",\n u.id AS \"user_id?\",\n u.username AS \"user_username?\",\n usa.id AS \"user_session_last_authentication_id?\",\n usa.created_at AS \"user_session_last_authentication_created_at?\"\n FROM\n oauth2_authorization_grants og\n LEFT JOIN oauth2_sessions os\n ON os.id = og.oauth2_session_id\n LEFT JOIN user_sessions us\n ON us.id = os.user_session_id\n LEFT JOIN users u\n ON u.id = us.user_id\n LEFT JOIN user_session_authentications usa\n ON usa.session_id = us.id\n WHERE\n og.code = $1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "grant_id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "grant_created_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 2, + "name": "grant_cancelled_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 3, + "name": "grant_fulfilled_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 4, + "name": "grant_exchanged_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 5, + "name": "grant_scope", + "type_info": "Text" + }, + { + "ordinal": 6, + "name": "grant_state", + "type_info": "Text" + }, + { + "ordinal": 7, + "name": "grant_redirect_uri", + "type_info": "Text" + }, + { + "ordinal": 8, + "name": "grant_response_mode", + "type_info": "Text" + }, + { + "ordinal": 9, + "name": "grant_nonce", + "type_info": "Text" + }, + { + "ordinal": 10, + "name": "grant_max_age", + "type_info": "Int4" + }, + { + "ordinal": 11, + "name": "grant_acr_values", + "type_info": "Text" + }, + { + "ordinal": 12, + "name": "client_id", + "type_info": "Text" + }, + { + "ordinal": 13, + "name": "grant_code", + "type_info": "Text" + }, + { + "ordinal": 14, + "name": "grant_response_type_code", + "type_info": "Bool" + }, + { + "ordinal": 15, + "name": "grant_response_type_token", + "type_info": "Bool" + }, + { + "ordinal": 16, + "name": "grant_response_type_id_token", + "type_info": "Bool" + }, + { + "ordinal": 17, + "name": "grant_code_challenge", + "type_info": "Text" + }, + { + "ordinal": 18, + "name": "grant_code_challenge_method", + "type_info": "Text" + }, + { + "ordinal": 19, + "name": "session_id?", + "type_info": "Int8" + }, + { + "ordinal": 20, + "name": "user_session_id?", + "type_info": "Int8" + }, + { + "ordinal": 21, + "name": "user_session_created_at?", + "type_info": "Timestamptz" + }, + { + "ordinal": 22, + "name": "user_id?", + "type_info": "Int8" + }, + { + "ordinal": 23, + "name": "user_username?", + "type_info": "Text" + }, + { + "ordinal": 24, + "name": "user_session_last_authentication_id?", + "type_info": "Int8" + }, + { + "ordinal": 25, + "name": "user_session_last_authentication_created_at?", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Text" + ] + }, + "nullable": [ + false, + false, + true, + true, + true, + false, + true, + false, + false, + true, + true, + true, + false, + true, + false, + false, + false, + true, + true, + false, + false, + false, + false, + false, + false, + false + ] + } + }, "a09dfe1019110f2ec6eba0d35bafa467ab4b7980dd8b556826f03863f8edb0ab": { "query": "UPDATE user_sessions SET active = FALSE WHERE id = $1", "describe": { @@ -579,17 +814,31 @@ ] } }, - "a6eb935107d060dd01bf9824ceff87b9ff5492b58cefef002a49f444d3a3daa1": { - "query": "UPDATE oauth2_sessions SET user_session_id = $1 WHERE id = $2", + "c29e741474aacc91c0aacc028a9e7452a5327d5ce6d4b791bf20a2636069087e": { + "query": "\n INSERT INTO oauth2_sessions\n (user_session_id, client_id, scope)\n SELECT\n $1,\n og.client_id,\n og.scope\n FROM\n oauth2_authorization_grants og\n WHERE\n og.id = $2\n RETURNING id, created_at\n ", "describe": { - "columns": [], + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "created_at", + "type_info": "Timestamptz" + } + ], "parameters": { "Left": [ "Int8", "Int8" ] }, - "nullable": [] + "nullable": [ + false, + false + ] } }, "c2c402cfe0adcafa615f14a499caba4c96ca71d9ffb163e1feb05e5d85f3462c": { @@ -605,97 +854,23 @@ "nullable": [] } }, - "cacec823f5d4ed886854fbd62b5f5bb2def792582df58c8a047c769d34d9b190": { - "query": "\n INSERT INTO oauth2_sessions\n (user_session_id, client_id, redirect_uri, scope, state, nonce, max_age,\n response_type, response_mode)\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9)\n RETURNING\n id, user_session_id, client_id, redirect_uri, scope, state, nonce, max_age,\n response_type, response_mode, created_at, updated_at\n ", + "d604e13bdfb2ff3d354d995f0b68f04091847755db98bafea7c45bd7b5c4ab68": { + "query": "\n UPDATE oauth2_authorization_grants\n SET\n exchanged_at = NOW()\n WHERE\n id = $1\n RETURNING exchanged_at AS \"exchanged_at!: DateTime\"\n ", "describe": { "columns": [ { "ordinal": 0, - "name": "id", - "type_info": "Int8" - }, - { - "ordinal": 1, - "name": "user_session_id", - "type_info": "Int8" - }, - { - "ordinal": 2, - "name": "client_id", - "type_info": "Text" - }, - { - "ordinal": 3, - "name": "redirect_uri", - "type_info": "Text" - }, - { - "ordinal": 4, - "name": "scope", - "type_info": "Text" - }, - { - "ordinal": 5, - "name": "state", - "type_info": "Text" - }, - { - "ordinal": 6, - "name": "nonce", - "type_info": "Text" - }, - { - "ordinal": 7, - "name": "max_age", - "type_info": "Int4" - }, - { - "ordinal": 8, - "name": "response_type", - "type_info": "Text" - }, - { - "ordinal": 9, - "name": "response_mode", - "type_info": "Text" - }, - { - "ordinal": 10, - "name": "created_at", - "type_info": "Timestamptz" - }, - { - "ordinal": 11, - "name": "updated_at", + "name": "exchanged_at!: DateTime", "type_info": "Timestamptz" } ], "parameters": { "Left": [ - "Int8", - "Text", - "Text", - "Text", - "Text", - "Text", - "Int4", - "Text", - "Text" + "Int8" ] }, "nullable": [ - false, - true, - false, - false, - false, - true, - true, - true, - false, - false, - false, - false + true ] } }, @@ -725,18 +900,6 @@ ] } }, - "eaddc1e33715ad31b4195fda72dbe870f179dd8da53a88d0543b72a278ed1d3d": { - "query": "\n DELETE FROM oauth2_codes\n WHERE id = $1\n ", - "describe": { - "columns": [], - "parameters": { - "Left": [ - "Int8" - ] - }, - "nullable": [] - } - }, "f9a09ff53b6f221649f4f050e3d5ade114f852ddf50a78610a6c0ef0689af681": { "query": "\n INSERT INTO users (username, hashed_password)\n VALUES ($1, $2)\n RETURNING id\n ", "describe": { @@ -757,91 +920,5 @@ false ] } - }, - "ff515ebb80ba4af1948472f5c7120a03e25b1ebe42151b8a2036bfbb042f17f6": { - "query": "\n SELECT\n id, user_session_id, client_id, redirect_uri, scope, state, nonce,\n max_age, response_type, response_mode, created_at, updated_at\n FROM oauth2_sessions\n WHERE id = $1\n ", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "id", - "type_info": "Int8" - }, - { - "ordinal": 1, - "name": "user_session_id", - "type_info": "Int8" - }, - { - "ordinal": 2, - "name": "client_id", - "type_info": "Text" - }, - { - "ordinal": 3, - "name": "redirect_uri", - "type_info": "Text" - }, - { - "ordinal": 4, - "name": "scope", - "type_info": "Text" - }, - { - "ordinal": 5, - "name": "state", - "type_info": "Text" - }, - { - "ordinal": 6, - "name": "nonce", - "type_info": "Text" - }, - { - "ordinal": 7, - "name": "max_age", - "type_info": "Int4" - }, - { - "ordinal": 8, - "name": "response_type", - "type_info": "Text" - }, - { - "ordinal": 9, - "name": "response_mode", - "type_info": "Text" - }, - { - "ordinal": 10, - "name": "created_at", - "type_info": "Timestamptz" - }, - { - "ordinal": 11, - "name": "updated_at", - "type_info": "Timestamptz" - } - ], - "parameters": { - "Left": [ - "Int8" - ] - }, - "nullable": [ - false, - true, - false, - false, - false, - true, - true, - true, - false, - false, - false, - false - ] - } } } \ No newline at end of file diff --git a/crates/core/src/handlers/oauth2/authorization.rs b/crates/core/src/handlers/oauth2/authorization.rs index 413fb181..67a6402d 100644 --- a/crates/core/src/handlers/oauth2/authorization.rs +++ b/crates/core/src/handlers/oauth2/authorization.rs @@ -23,11 +23,12 @@ use hyper::{ http::uri::{Parts, PathAndQuery, Uri}, StatusCode, }; -use itertools::Itertools; -use mas_data_model::BrowserSession; +use mas_data_model::{ + Authentication, AuthorizationCode, AuthorizationGrantStage, BrowserSession, Pkce, +}; use mas_templates::{FormPostContext, Templates}; use oauth2_types::{ - errors::{ErrorResponse, InvalidRequest, OAuth2Error}, + errors::{ErrorResponse, InvalidGrant, InvalidRequest, OAuth2Error}, pkce, requests::{ AccessTokenResponse, AuthorizationRequest, AuthorizationResponse, ResponseMode, @@ -58,8 +59,10 @@ use crate::{ storage::{ oauth2::{ access_token::add_access_token, + authorization_grant::{ + derive_session, fulfill_grant, get_grant_by_id, new_authorization_grant, + }, refresh_token::add_refresh_token, - session::{get_session_by_id, start_session}, }, PostgresqlBackend, }, @@ -308,13 +311,6 @@ async fn get( .ok_or_else(|| anyhow::anyhow!("could not find client")) .wrap_error()?; - let maybe_session_id = maybe_session.as_ref().map(|s| s.data); - - let scope: String = { - let it = params.auth.scope.iter().map(ToString::to_string); - Itertools::intersperse(it, " ".to_string()).collect() - }; - let redirect_uri = client .resolve_redirect_uri(¶ms.auth.redirect_uri) .wrap_error()?; @@ -322,23 +318,7 @@ async fn get( let response_mode = resolve_response_mode(response_type, params.auth.response_mode).wrap_error()?; - let oauth2_session = start_session( - &mut txn, - maybe_session_id, - &client.client_id, - redirect_uri, - &scope, - params.auth.state.as_deref(), - params.auth.nonce.as_deref(), - params.auth.max_age, - response_type, - response_mode, - ) - .await - .wrap_error()?; - - // Generate the code at this stage, since we have the PKCE params ready - if response_type.contains(&ResponseType::Code) { + let code: Option = if response_type.contains(&ResponseType::Code) { // 32 random alphanumeric characters, about 190bit of entropy let code: String = thread_rng() .sample_iter(&Alphanumeric) @@ -346,22 +326,47 @@ async fn get( .map(char::from) .collect(); - oauth2_session - .add_code(&mut txn, &code, ¶ms.pkce) - .await - .wrap_error()?; - } + let pkce = params.pkce.map(|p| Pkce { + challenge: p.code_challenge, + challenge_method: p.code_challenge_method, + }); - // Do we already have a user session for this oauth2 session? - let user_session = oauth2_session.fetch_session(&mut txn).await.wrap_error()?; + Some(AuthorizationCode { code, pkce }) + } else { + // If the request had PKCE params but no code asked, it should get back with an + // error + if params.pkce.is_some() { + return Ok(ReplyOrBackToClient::Error(Box::new(InvalidGrant))); + } - if let Some(user_session) = user_session { - step(oauth2_session.id, user_session, txn).await + None + }; + + let grant = new_authorization_grant( + &mut txn, + client.client_id.clone(), + redirect_uri.clone(), + params.auth.scope, + code, + params.auth.state, + params.auth.nonce, + // TODO: support max_age and acr_values + None, + None, + response_mode, + response_type.contains(&ResponseType::Token), + response_type.contains(&ResponseType::IdToken), + ) + .await + .wrap_error()?; + + if let Some(user_session) = maybe_session { + step(grant.data, user_session, txn).await } else { // If not, redirect the user to the login page txn.commit().await.wrap_error()?; - let next = StepRequest::new(oauth2_session.id) + let next = StepRequest::new(grant.data) .build_uri() .wrap_error()? .to_string(); @@ -393,85 +398,84 @@ impl StepRequest { } } +fn reauth() -> ReplyOrBackToClient { + // Ask for a reauth + // TODO: have the OAuth2 session ID in there + ReplyOrBackToClient::Reply(Box::new(see_other(Uri::from_static("/reauth")))) +} + async fn step( - oauth2_session_id: i64, + grant_id: i64, browser_session: BrowserSession, mut txn: Transaction<'_, Postgres>, ) -> Result { - let mut oauth2_session = get_session_by_id(&mut txn, oauth2_session_id) - .await - .wrap_error()?; + // TODO: we should check if the grant here was started by the browser doing that + // request using a signed cookie + let grant = get_grant_by_id(&mut txn, grant_id).await.wrap_error()?; - let user_session = oauth2_session - .match_or_set_session(&mut txn, browser_session) - .await - .wrap_error()?; + if !matches!(grant.stage, AuthorizationGrantStage::Pending) { + return Err(anyhow::anyhow!("authorization grant not pending")).wrap_error(); + } - let response_mode = oauth2_session.response_mode().wrap_error()?; - let response_type = oauth2_session.response_type().wrap_error()?; - let redirect_uri = oauth2_session.redirect_uri().wrap_error()?; + let reply = match browser_session.last_authentication { + Some(Authentication { created_at, .. }) if created_at < grant.max_auth_time() => { + let session = derive_session(&mut txn, &grant, browser_session) + .await + .wrap_error()?; - // Check if the active session is valid - // TODO: this is ugly & should check if the session is active - let reply = if user_session.last_authentication.map(|x| x.created_at) - >= oauth2_session.max_auth_time() - { - // Yep! Let's complete the auth now - let mut params = AuthorizationResponse::default(); + let grant = fulfill_grant(&mut txn, grant, session.clone()) + .await + .wrap_error()?; - // Did they request an auth code? - if response_type.contains(&ResponseType::Code) { - params.code = Some(oauth2_session.fetch_code(&mut txn).await.wrap_error()?); - } + // Yep! Let's complete the auth now + let mut params = AuthorizationResponse::default(); - // Did they request an access token? - if response_type.contains(&ResponseType::Token) { - let ttl = Duration::minutes(5); - let (access_token_str, refresh_token_str) = { - let mut rng = thread_rng(); - ( - AccessToken.generate(&mut rng), - RefreshToken.generate(&mut rng), - ) - }; + // Did they request an auth code? + if let Some(code) = grant.code { + params.code = Some(code.code); + } - let access_token = - add_access_token(&mut txn, oauth2_session_id, &access_token_str, ttl) + // Did they request an access token? + if grant.response_type_token { + let ttl = Duration::minutes(5); + let (access_token_str, refresh_token_str) = { + let mut rng = thread_rng(); + ( + AccessToken.generate(&mut rng), + RefreshToken.generate(&mut rng), + ) + }; + + let access_token = add_access_token(&mut txn, &session, &access_token_str, ttl) .await .wrap_error()?; - let _refresh_token = add_refresh_token( - &mut txn, - oauth2_session_id, - access_token, - &refresh_token_str, - ) - .await - .wrap_error()?; + let _refresh_token = + add_refresh_token(&mut txn, &session, access_token, &refresh_token_str) + .await + .wrap_error()?; - params.response = Some( - AccessTokenResponse::new(access_token_str) - .with_expires_in(ttl) - .with_refresh_token(refresh_token_str), - ); - } + params.response = Some( + AccessTokenResponse::new(access_token_str) + .with_expires_in(ttl) + .with_refresh_token(refresh_token_str), + ); + } - // Did they request an ID token? - if response_type.contains(&ResponseType::IdToken) { - todo!("id tokens are not implemented yet"); - } + // Did they request an ID token? + if grant.response_type_id_token { + todo!("id tokens are not implemented yet"); + } - let params = serde_json::to_value(¶ms).unwrap(); - ReplyOrBackToClient::BackToClient { - redirect_uri, - response_mode, - state: oauth2_session.state.clone(), - params, + let params = serde_json::to_value(¶ms).unwrap(); + ReplyOrBackToClient::BackToClient { + redirect_uri: grant.redirect_uri, + response_mode: grant.response_mode, + state: grant.state, + params, + } } - } else { - // Ask for a reauth - // TODO: have the OAuth2 session ID in there - ReplyOrBackToClient::Reply(Box::new(see_other(Uri::from_static("/reauth")))) + _ => reauth(), }; txn.commit().await.wrap_error()?; diff --git a/crates/core/src/handlers/oauth2/introspection.rs b/crates/core/src/handlers/oauth2/introspection.rs index ae989cd2..5757960b 100644 --- a/crates/core/src/handlers/oauth2/introspection.rs +++ b/crates/core/src/handlers/oauth2/introspection.rs @@ -94,12 +94,12 @@ async fn introspect( active: true, scope: Some(session.scope), client_id: Some(session.client.client_id), - username: session.browser_session.clone().map(|s| s.user.username), + username: Some(session.browser_session.user.username), token_type: Some(TokenTypeHint::AccessToken), exp: Some(exp), iat: Some(token.created_at), nbf: Some(token.created_at), - sub: session.browser_session.map(|s| s.user.sub), + sub: Some(session.browser_session.user.sub), aud: None, iss: None, jti: None, @@ -114,12 +114,12 @@ async fn introspect( active: true, scope: Some(session.scope), client_id: Some(session.client.client_id), - username: session.browser_session.clone().map(|s| s.user.username), + username: Some(session.browser_session.user.username), token_type: Some(TokenTypeHint::RefreshToken), exp: None, iat: Some(token.created_at), nbf: Some(token.created_at), - sub: session.browser_session.map(|s| s.user.sub), + sub: Some(session.browser_session.user.sub), aud: None, iss: None, jti: None, diff --git a/crates/core/src/handlers/oauth2/token.rs b/crates/core/src/handlers/oauth2/token.rs index e4f28800..d33c896c 100644 --- a/crates/core/src/handlers/oauth2/token.rs +++ b/crates/core/src/handlers/oauth2/token.rs @@ -18,6 +18,7 @@ use data_encoding::BASE64URL_NOPAD; use headers::{CacheControl, Pragma}; use hyper::{Method, StatusCode}; use jwt_compact::{Claims, Header, TimeOptions}; +use mas_data_model::AuthorizationGrantStage; use oauth2_types::{ errors::{InvalidGrant, InvalidRequest, OAuth2Error, OAuth2ErrorCode, UnauthorizedClient}, requests::{ @@ -30,6 +31,7 @@ use serde::Serialize; use serde_with::skip_serializing_none; use sha2::{Digest, Sha256}; use sqlx::{pool::PoolConnection, Acquire, PgPool, Postgres}; +use tracing::debug; use url::Url; use warp::{ reject::Reject, @@ -47,10 +49,15 @@ use crate::{ with_keys, }, reply::with_typed_header, - storage::oauth2::{ - access_token::{add_access_token, revoke_access_token}, - authorization_code::{consume_code, lookup_code}, - refresh_token::{add_refresh_token, lookup_active_refresh_token, replace_refresh_token}, + storage::{ + oauth2::{ + access_token::{add_access_token, revoke_access_token}, + authorization_grant::{exchange_grant, lookup_grant_by_code}, + refresh_token::{ + add_refresh_token, lookup_active_refresh_token, replace_refresh_token, + }, + }, + DatabaseInconsistencyError, }, tokens::{AccessToken, RefreshToken}, }; @@ -156,15 +163,50 @@ async fn authorization_code_grant( issuer: Url, conn: &mut PoolConnection, ) -> Result { + // TODO: there is a bunch of unnecessary cloning here let mut txn = conn.begin().await.wrap_error()?; - // TODO: we should invalidate the existing session if a code is used twice after - // some period of time. See the `oidcc-codereuse-30seconds` test from the - // conformance suite - let (code, session) = match lookup_code(&mut txn, &grant.code).await { - Err(e) if e.not_found() => return error(InvalidGrant), - x => x, - }?; + // TODO: handle "not found" cases + let authz_grant = lookup_grant_by_code(&mut txn, &grant.code) + .await + .wrap_error()?; + + let session = match authz_grant.stage { + AuthorizationGrantStage::Cancelled { cancelled_at } => { + debug!(%cancelled_at, "Authorization grant was cancelled"); + return error(InvalidGrant); + } + AuthorizationGrantStage::Exchanged { + exchanged_at, + fulfilled_at, + session: _, + } => { + // TODO: we should invalidate the existing session if a code is used twice after + // some period of time. See the `oidcc-codereuse-30seconds` test from the + // conformance suite + debug!(%exchanged_at, %fulfilled_at, "Authorization code was already exchanged"); + return error(InvalidGrant); + } + AuthorizationGrantStage::Pending => { + debug!("Authorization grant has not been fulfilled yet"); + return error(InvalidGrant); + } + AuthorizationGrantStage::Fulfilled { + ref session, + fulfilled_at: _, + } => { + // TODO: we should check that the session was not fullfilled too long ago + // (30s to 1min?). The main problem is getting a timestamp from the database + session + } + }; + + // This should never happen, since we looked up in the database using the code + let code = authz_grant + .code + .as_ref() + .ok_or(DatabaseInconsistencyError) + .wrap_error()?; if client.client_id != session.client.client_id { return error(UnauthorizedClient); @@ -182,13 +224,7 @@ async fn authorization_code_grant( } }; - // TODO: this should probably not happen? - let browser_session = session - .browser_session - .ok_or_else(|| { - anyhow::anyhow!("this oauth2 session has no database session attached to it") - }) - .wrap_error()?; + let browser_session = &session.browser_session; let ttl = Duration::minutes(5); let (access_token_str, refresh_token_str) = { @@ -199,23 +235,22 @@ async fn authorization_code_grant( ) }; - let access_token = add_access_token(&mut txn, session.data, &access_token_str, ttl) + let access_token = add_access_token(&mut txn, session, &access_token_str, ttl) .await .wrap_error()?; - let _refresh_token = - add_refresh_token(&mut txn, session.data, access_token, &refresh_token_str) - .await - .wrap_error()?; + let _refresh_token = add_refresh_token(&mut txn, session, access_token, &refresh_token_str) + .await + .wrap_error()?; let id_token = if session.scope.contains(&OPENID) { let header = Header::default(); let options = TimeOptions::default(); let claims = Claims::new(CustomClaims { issuer, - subject: browser_session.user.sub, + subject: browser_session.user.sub.clone(), audiences: vec![client.client_id.clone()], - nonce: session.nonce, + nonce: authz_grant.nonce.clone(), at_hash: hash(Sha256::new(), &access_token_str).wrap_error()?, c_hash: hash(Sha256::new(), &grant.code).wrap_error()?, }) @@ -234,13 +269,13 @@ async fn authorization_code_grant( let mut params = AccessTokenResponse::new(access_token_str) .with_expires_in(ttl) .with_refresh_token(refresh_token_str) - .with_scope(session.scope); + .with_scope(session.scope.clone()); if let Some(id_token) = id_token { params = params.with_id_token(id_token); } - consume_code(&mut txn, code).await.wrap_error()?; + exchange_grant(&mut txn, authz_grant).await.wrap_error()?; txn.commit().await.wrap_error()?; @@ -271,12 +306,12 @@ async fn refresh_token_grant( ) }; - let new_access_token = add_access_token(&mut txn, session.data, &access_token_str, ttl) + let new_access_token = add_access_token(&mut txn, &session, &access_token_str, ttl) .await .wrap_error()?; let new_refresh_token = - add_refresh_token(&mut txn, session.data, new_access_token, &refresh_token_str) + add_refresh_token(&mut txn, &session, new_access_token, &refresh_token_str) .await .wrap_error()?; @@ -285,7 +320,7 @@ async fn refresh_token_grant( .wrap_error()?; if let Some(access_token) = refresh_token.access_token { - revoke_access_token(&mut txn, access_token.data) + revoke_access_token(&mut txn, &access_token) .await .wrap_error()?; } diff --git a/crates/core/src/handlers/oauth2/userinfo.rs b/crates/core/src/handlers/oauth2/userinfo.rs index 26cec4d4..bf3a2a1c 100644 --- a/crates/core/src/handlers/oauth2/userinfo.rs +++ b/crates/core/src/handlers/oauth2/userinfo.rs @@ -52,8 +52,7 @@ async fn userinfo( _token: AccessToken, session: Session, ) -> Result { - // TODO: we really should not have an Option here - let user = session.browser_session.unwrap().user; + let user = session.browser_session.user; Ok(warp::reply::json(&UserInfo { sub: user.sub, username: user.username, diff --git a/crates/core/src/storage/mod.rs b/crates/core/src/storage/mod.rs index 641e6e00..633e7d24 100644 --- a/crates/core/src/storage/mod.rs +++ b/crates/core/src/storage/mod.rs @@ -32,7 +32,7 @@ pub struct PostgresqlBackend; impl StorageBackend for PostgresqlBackend { type AccessTokenData = i64; type AuthenticationData = i64; - type AuthorizationCodeData = i64; + type AuthorizationGrantData = i64; type BrowserSessionData = i64; type ClientData = (); type RefreshTokenData = i64; diff --git a/crates/core/src/storage/oauth2/access_token.rs b/crates/core/src/storage/oauth2/access_token.rs index c43b25fc..80fe575e 100644 --- a/crates/core/src/storage/oauth2/access_token.rs +++ b/crates/core/src/storage/oauth2/access_token.rs @@ -24,7 +24,7 @@ use crate::storage::{DatabaseInconsistencyError, IdAndCreationTime, PostgresqlBa pub async fn add_access_token( executor: impl PgExecutor<'_>, - oauth2_session_id: i64, + session: &Session, token: &str, expires_after: Duration, ) -> anyhow::Result> { @@ -41,7 +41,7 @@ pub async fn add_access_token( RETURNING id, created_at "#, - oauth2_session_id, + session.data, token, expires_after_seconds, ) @@ -67,8 +67,6 @@ pub struct OAuth2AccessTokenLookup { session_id: i64, client_id: String, scope: String, - redirect_uri: String, - nonce: Option, user_session_id: i64, user_session_created_at: DateTime, user_id: i64, @@ -109,8 +107,6 @@ pub async fn lookup_active_access_token( os.id AS "session_id!", os.client_id AS "client_id!", os.scope AS "scope!", - os.redirect_uri AS "redirect_uri!", - os.nonce AS "nonce", us.id AS "user_session_id!", us.created_at AS "user_session_created_at!", u.id AS "user_id!", @@ -171,39 +167,35 @@ pub async fn lookup_active_access_token( _ => return Err(DatabaseInconsistencyError.into()), }; - let browser_session = Some(BrowserSession { + let browser_session = BrowserSession { data: res.user_session_id, created_at: res.user_session_created_at, user, last_authentication, - }); + }; let scope = res.scope.parse().map_err(|_e| DatabaseInconsistencyError)?; - let redirect_uri = res - .redirect_uri - .parse() - .map_err(|_e| DatabaseInconsistencyError)?; - let session = Session { data: res.session_id, client, browser_session, scope, - redirect_uri, - nonce: res.nonce, }; Ok((access_token, session)) } -pub async fn revoke_access_token(executor: impl PgExecutor<'_>, id: i64) -> anyhow::Result<()> { +pub async fn revoke_access_token( + executor: impl PgExecutor<'_>, + access_token: &AccessToken, +) -> anyhow::Result<()> { let res = sqlx::query!( r#" DELETE FROM oauth2_access_tokens WHERE id = $1 "#, - id, + access_token.data, ) .execute(executor) .await diff --git a/crates/core/src/storage/oauth2/authorization_code.rs b/crates/core/src/storage/oauth2/authorization_code.rs deleted file mode 100644 index 5b4b9667..00000000 --- a/crates/core/src/storage/oauth2/authorization_code.rs +++ /dev/null @@ -1,263 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use anyhow::Context; -use chrono::{DateTime, Utc}; -use mas_data_model::{ - Authentication, AuthorizationCode, BrowserSession, Client, Pkce, Session, User, -}; -use oauth2_types::pkce; -use sqlx::PgExecutor; -use thiserror::Error; -use warp::reject::Reject; - -use crate::storage::{DatabaseInconsistencyError, PostgresqlBackend}; - -pub async fn add_code( - executor: impl PgExecutor<'_>, - oauth2_session_id: i64, - code: &str, - pkce: &Option, -) -> anyhow::Result> { - let code_challenge_method = pkce.as_ref().map(|c| c.code_challenge_method as i16); - let code_challenge = pkce.as_ref().map(|c| &c.code_challenge); - let id = sqlx::query_scalar!( - r#" - INSERT INTO oauth2_codes - (oauth2_session_id, code, code_challenge_method, code_challenge) - VALUES - ($1, $2, $3, $4) - RETURNING - id - "#, - oauth2_session_id, - code, - code_challenge_method, - code_challenge, - ) - .fetch_one(executor) - .await - .context("could not insert oauth2 authorization code")?; - - let pkce = pkce - .as_ref() - .map(|c| Pkce::new(c.code_challenge_method, c.code_challenge.clone())); - - Ok(AuthorizationCode { - data: id, - code: code.to_string(), - pkce, - }) -} - -struct OAuth2CodeLookup { - id: i64, - oauth2_session_id: i64, - client_id: String, - redirect_uri: String, - scope: String, - nonce: Option, - code_challenge: Option, - code_challenge_method: Option, - user_session_id: Option, - user_session_created_at: Option>, - user_id: Option, - user_username: Option, - user_session_last_authentication_id: Option, - user_session_last_authentication_created_at: Option>, -} - -fn browser_session_from_database( - user_session_id: Option, - user_session_created_at: Option>, - user_id: Option, - user_username: Option, - user_session_last_authentication_id: Option, - user_session_last_authentication_created_at: Option>, -) -> Result>, DatabaseInconsistencyError> { - match ( - user_session_id, - user_session_created_at, - user_id, - user_username, - ) { - (None, None, None, None) => Ok(None), - (Some(session_id), Some(session_created_at), Some(user_id), Some(user_username)) => { - let user = User { - data: user_id, - username: user_username, - sub: format!("fake-sub-{}", user_id), - }; - - let last_authentication = match ( - user_session_last_authentication_id, - user_session_last_authentication_created_at, - ) { - (None, None) => None, - (Some(id), Some(created_at)) => Some(Authentication { - data: id, - created_at, - }), - _ => return Err(DatabaseInconsistencyError), - }; - - Ok(Some(BrowserSession { - data: session_id, - created_at: session_created_at, - user, - last_authentication, - })) - } - _ => Err(DatabaseInconsistencyError), - } -} - -#[derive(Debug, Error)] -#[error("failed to lookup oauth2 code")] -pub enum CodeLookupError { - Database(#[from] sqlx::Error), - Inconsistency(#[from] DatabaseInconsistencyError), -} - -impl Reject for CodeLookupError {} - -impl CodeLookupError { - #[must_use] - pub fn not_found(&self) -> bool { - matches!(self, &CodeLookupError::Database(sqlx::Error::RowNotFound)) - } -} - -#[allow(clippy::too_many_lines)] -pub async fn lookup_code( - executor: impl PgExecutor<'_>, - code: &str, -) -> Result< - ( - AuthorizationCode, - Session, - ), - CodeLookupError, -> { - let res = sqlx::query_as!( - OAuth2CodeLookup, - r#" - SELECT - oc.id, - oc.code_challenge, - oc.code_challenge_method, - os.id AS "oauth2_session_id!", - os.client_id AS "client_id!", - os.redirect_uri, - os.scope AS "scope!", - os.nonce, - us.id AS "user_session_id?", - us.created_at AS "user_session_created_at?", - u.id AS "user_id?", - u.username AS "user_username?", - usa.id AS "user_session_last_authentication_id?", - usa.created_at AS "user_session_last_authentication_created_at?" - FROM oauth2_codes oc - INNER JOIN oauth2_sessions os - ON os.id = oc.oauth2_session_id - LEFT JOIN user_sessions us - ON us.id = os.user_session_id - LEFT JOIN user_session_authentications usa - ON usa.session_id = us.id - LEFT JOIN users u - ON u.id = us.user_id - WHERE oc.code = $1 - ORDER BY usa.created_at DESC - LIMIT 1 - "#, - code, - ) - .fetch_one(executor) - .await?; - - let pkce = match (res.code_challenge_method, res.code_challenge) { - (None, None) => None, - (Some(0 /* Plain */), Some(challenge)) => { - Some(Pkce::new(pkce::CodeChallengeMethod::Plain, challenge)) - } - (Some(1 /* S256 */), Some(challenge)) => { - Some(Pkce::new(pkce::CodeChallengeMethod::S256, challenge)) - } - _ => return Err(DatabaseInconsistencyError.into()), - }; - - let code = AuthorizationCode { - data: res.id, - code: code.to_string(), - pkce, - }; - - let client = Client { - data: (), - client_id: res.client_id, - }; - - let browser_session = browser_session_from_database( - res.user_session_id, - res.user_session_created_at, - res.user_id, - res.user_username, - res.user_session_last_authentication_id, - res.user_session_last_authentication_created_at, - )?; - - let scope = res.scope.parse().map_err(|_e| DatabaseInconsistencyError)?; - - let redirect_uri = res - .redirect_uri - .parse() - .map_err(|_e| DatabaseInconsistencyError)?; - - let session = Session { - data: res.oauth2_session_id, - client, - browser_session, - scope, - redirect_uri, - nonce: res.nonce, - }; - - Ok((code, session)) -} - -pub async fn consume_code( - executor: impl PgExecutor<'_>, - code: AuthorizationCode, -) -> anyhow::Result<()> { - // TODO: mark the code as invalid instead to allow invalidating the whole - // session on code reuse - let res = sqlx::query!( - r#" - DELETE FROM oauth2_codes - WHERE id = $1 - "#, - code.data, - ) - .execute(executor) - .await - .context("could not consume authorization code")?; - - if res.rows_affected() == 1 { - Ok(()) - } else { - Err(anyhow::anyhow!( - "no row were affected when consuming authorization code" - )) - } -} diff --git a/crates/core/src/storage/oauth2/authorization_grant.rs b/crates/core/src/storage/oauth2/authorization_grant.rs new file mode 100644 index 00000000..56309957 --- /dev/null +++ b/crates/core/src/storage/oauth2/authorization_grant.rs @@ -0,0 +1,499 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#![allow(clippy::unused_async)] + +use std::{ + convert::{TryFrom, TryInto}, + num::NonZeroU32, +}; + +use anyhow::Context; +use chrono::{DateTime, Utc}; +use mas_data_model::{ + Authentication, AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, BrowserSession, + Client, Pkce, Session, User, +}; +use oauth2_types::{pkce::CodeChallengeMethod, requests::ResponseMode, scope::Scope}; +use sqlx::PgExecutor; +use url::Url; + +use crate::storage::{DatabaseInconsistencyError, IdAndCreationTime, PostgresqlBackend}; + +#[allow(clippy::too_many_arguments)] +pub async fn new_authorization_grant( + executor: impl PgExecutor<'_>, + client_id: String, + redirect_uri: Url, + scope: Scope, + code: Option, + state: Option, + nonce: Option, + max_age: Option, + acr_values: Option, + response_mode: ResponseMode, + response_type_token: bool, + response_type_id_token: bool, +) -> anyhow::Result> { + let code_challenge = code + .as_ref() + .and_then(|c| c.pkce.as_ref()) + .map(|p| &p.challenge); + let code_challenge_method = code + .as_ref() + .and_then(|c| c.pkce.as_ref()) + .map(|p| p.challenge_method.to_string()); + let code_str = code.as_ref().map(|c| &c.code); + let res = sqlx::query_as!( + IdAndCreationTime, + r#" + INSERT INTO oauth2_authorization_grants + (client_id, redirect_uri, scope, state, nonce, max_age, + acr_values, response_mode, code_challenge, code_challenge_method, + response_type_code, response_type_token, response_type_id_token, + code) + VALUES + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) + RETURNING id, created_at + "#, + &client_id, + redirect_uri.to_string(), + scope.to_string(), + state, + nonce, + // TODO: this conversion is a bit ugly + max_age.map(|x| i32::try_from(u32::from(x)).unwrap_or(i32::MAX)), + acr_values, + response_mode.to_string(), + code_challenge, + code_challenge_method, + code.is_some(), + response_type_token, + response_type_id_token, + code_str, + ) + .fetch_one(executor) + .await + .context("could not insert oauth2 authorization grant")?; + + let client = Client { + data: (), + client_id, + }; + + Ok(AuthorizationGrant { + data: res.id, + stage: AuthorizationGrantStage::Pending, + code, + redirect_uri, + client, + scope, + state, + nonce, + max_age, + acr_values, + response_mode, + created_at: res.created_at, + response_type_token, + response_type_id_token, + }) +} + +struct GrantLookup { + grant_id: i64, + grant_created_at: DateTime, + grant_cancelled_at: Option>, + grant_fulfilled_at: Option>, + grant_exchanged_at: Option>, + grant_scope: String, + grant_state: Option, + grant_redirect_uri: String, + grant_response_mode: String, + grant_nonce: Option, + #[allow(dead_code)] + grant_max_age: Option, + grant_acr_values: Option, + grant_response_type_code: bool, + grant_response_type_token: bool, + grant_response_type_id_token: bool, + grant_code: Option, + grant_code_challenge: Option, + grant_code_challenge_method: Option, + client_id: String, + session_id: Option, + user_session_id: Option, + user_session_created_at: Option>, + user_id: Option, + user_username: Option, + user_session_last_authentication_id: Option, + user_session_last_authentication_created_at: Option>, +} + +impl TryInto> for GrantLookup { + type Error = DatabaseInconsistencyError; + + #[allow(clippy::too_many_lines)] + fn try_into(self) -> Result, Self::Error> { + let scope: Scope = self + .grant_scope + .parse() + .map_err(|_e| DatabaseInconsistencyError)?; + + let client = Client { + data: (), + client_id: self.client_id, + }; + + let last_authentication = match ( + self.user_session_last_authentication_id, + self.user_session_last_authentication_created_at, + ) { + (Some(id), Some(created_at)) => Some(Authentication { + data: id, + created_at, + }), + (None, None) => None, + _ => return Err(DatabaseInconsistencyError), + }; + + let session = match ( + self.session_id, + self.user_session_id, + self.user_session_created_at, + self.user_id, + self.user_username, + last_authentication, + ) { + ( + Some(session_id), + Some(user_session_id), + Some(user_session_created_at), + Some(user_id), + Some(user_username), + last_authentication, + ) => { + let user = User { + data: user_id, + username: user_username, + sub: format!("fake-sub-{}", user_id), + }; + + let browser_session = BrowserSession { + data: user_session_id, + user, + created_at: user_session_created_at, + last_authentication, + }; + + let client = client.clone(); + let scope = scope.clone(); + + let session = Session { + data: session_id, + client, + browser_session, + scope, + }; + + Some(session) + } + (None, None, None, None, None, None) => None, + _ => return Err(DatabaseInconsistencyError), + }; + + let stage = match ( + self.grant_fulfilled_at, + self.grant_exchanged_at, + self.grant_cancelled_at, + session, + ) { + (None, None, None, None) => AuthorizationGrantStage::Pending, + (Some(fulfilled_at), None, None, Some(session)) => AuthorizationGrantStage::Fulfilled { + session, + fulfilled_at, + }, + (Some(fulfilled_at), Some(exchanged_at), None, Some(session)) => { + AuthorizationGrantStage::Exchanged { + session, + fulfilled_at, + exchanged_at, + } + } + (None, None, Some(cancelled_at), None) => { + AuthorizationGrantStage::Cancelled { cancelled_at } + } + _ => { + return Err(DatabaseInconsistencyError); + } + }; + + let pkce = match (self.grant_code_challenge, self.grant_code_challenge_method) { + (Some(challenge), Some(challenge_method)) if challenge_method == "plain" => { + Some(Pkce { + challenge_method: CodeChallengeMethod::Plain, + challenge, + }) + } + (Some(challenge), Some(challenge_method)) if challenge_method == "S256" => Some(Pkce { + challenge_method: CodeChallengeMethod::S256, + challenge, + }), + (None, None) => None, + _ => { + return Err(DatabaseInconsistencyError); + } + }; + + let code: Option = + match (self.grant_response_type_code, self.grant_code, pkce) { + (false, None, None) => None, + (true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }), + _ => { + return Err(DatabaseInconsistencyError); + } + }; + + let redirect_uri = self + .grant_redirect_uri + .parse() + .map_err(|_e| DatabaseInconsistencyError)?; + + let response_mode = self + .grant_response_mode + .parse() + .map_err(|_e| DatabaseInconsistencyError)?; + + Ok(AuthorizationGrant { + data: self.grant_id, + stage, + client, + code, + acr_values: self.grant_acr_values, + scope, + state: self.grant_state, + nonce: self.grant_nonce, + max_age: None, // TODO + response_mode, + redirect_uri, + created_at: self.grant_created_at, + response_type_token: self.grant_response_type_token, + response_type_id_token: self.grant_response_type_id_token, + }) + } +} + +pub async fn get_grant_by_id( + executor: impl PgExecutor<'_>, + id: i64, +) -> anyhow::Result> { + // TODO: handle "not found" cases + let res = sqlx::query_as!( + GrantLookup, + r#" + SELECT + og.id AS grant_id, + og.created_at AS grant_created_at, + og.cancelled_at AS grant_cancelled_at, + og.fulfilled_at AS grant_fulfilled_at, + og.exchanged_at AS grant_exchanged_at, + og.scope AS grant_scope, + og.state AS grant_state, + og.redirect_uri AS grant_redirect_uri, + og.response_mode AS grant_response_mode, + og.nonce AS grant_nonce, + og.max_age AS grant_max_age, + og.acr_values AS grant_acr_values, + og.client_id AS client_id, + og.code AS grant_code, + og.response_type_code AS grant_response_type_code, + og.response_type_token AS grant_response_type_token, + og.response_type_id_token AS grant_response_type_id_token, + og.code_challenge AS grant_code_challenge, + og.code_challenge_method AS grant_code_challenge_method, + os.id AS "session_id?", + us.id AS "user_session_id?", + us.created_at AS "user_session_created_at?", + u.id AS "user_id?", + u.username AS "user_username?", + usa.id AS "user_session_last_authentication_id?", + usa.created_at AS "user_session_last_authentication_created_at?" + FROM + oauth2_authorization_grants og + LEFT JOIN oauth2_sessions os + ON os.id = og.oauth2_session_id + LEFT JOIN user_sessions us + ON us.id = os.user_session_id + LEFT JOIN users u + ON u.id = us.user_id + LEFT JOIN user_session_authentications usa + ON usa.session_id = us.id + WHERE + og.id = $1 + "#, + id, + ) + .fetch_one(executor) + .await + .context("failed to get grant by id")?; + + let grant = res.try_into()?; + + Ok(grant) +} + +pub async fn lookup_grant_by_code( + executor: impl PgExecutor<'_>, + code: &str, +) -> anyhow::Result> { + // TODO: handle "not found" cases + let res = sqlx::query_as!( + GrantLookup, + r#" + SELECT + og.id AS grant_id, + og.created_at AS grant_created_at, + og.cancelled_at AS grant_cancelled_at, + og.fulfilled_at AS grant_fulfilled_at, + og.exchanged_at AS grant_exchanged_at, + og.scope AS grant_scope, + og.state AS grant_state, + og.redirect_uri AS grant_redirect_uri, + og.response_mode AS grant_response_mode, + og.nonce AS grant_nonce, + og.max_age AS grant_max_age, + og.acr_values AS grant_acr_values, + og.client_id AS client_id, + og.code AS grant_code, + og.response_type_code AS grant_response_type_code, + og.response_type_token AS grant_response_type_token, + og.response_type_id_token AS grant_response_type_id_token, + og.code_challenge AS grant_code_challenge, + og.code_challenge_method AS grant_code_challenge_method, + os.id AS "session_id?", + us.id AS "user_session_id?", + us.created_at AS "user_session_created_at?", + u.id AS "user_id?", + u.username AS "user_username?", + usa.id AS "user_session_last_authentication_id?", + usa.created_at AS "user_session_last_authentication_created_at?" + FROM + oauth2_authorization_grants og + LEFT JOIN oauth2_sessions os + ON os.id = og.oauth2_session_id + LEFT JOIN user_sessions us + ON us.id = os.user_session_id + LEFT JOIN users u + ON u.id = us.user_id + LEFT JOIN user_session_authentications usa + ON usa.session_id = us.id + WHERE + og.code = $1 + "#, + code, + ) + .fetch_one(executor) + .await + .context("failed to lookup grant by code")?; + + let grant = res.try_into()?; + + Ok(grant) +} + +pub async fn derive_session( + executor: impl PgExecutor<'_>, + grant: &AuthorizationGrant, + browser_session: BrowserSession, +) -> anyhow::Result> { + let res = sqlx::query_as!( + IdAndCreationTime, + r#" + INSERT INTO oauth2_sessions + (user_session_id, client_id, scope) + SELECT + $1, + og.client_id, + og.scope + FROM + oauth2_authorization_grants og + WHERE + og.id = $2 + RETURNING id, created_at + "#, + browser_session.data, + grant.data, + ) + .fetch_one(executor) + .await + .context("could not insert oauth2 session")?; + + Ok(Session { + data: res.id, + browser_session, + client: grant.client.clone(), + scope: grant.scope.clone(), + }) +} + +pub async fn fulfill_grant( + executor: impl PgExecutor<'_>, + mut grant: AuthorizationGrant, + session: Session, +) -> anyhow::Result> { + let fulfilled_at = sqlx::query_scalar!( + r#" + UPDATE oauth2_authorization_grants AS og + SET + oauth2_session_id = os.id, + fulfilled_at = os.created_at + FROM oauth2_sessions os + WHERE + og.id = $1 AND os.id = $2 + RETURNING fulfilled_at AS "fulfilled_at!: DateTime" + "#, + grant.data, + session.data, + ) + .fetch_one(executor) + .await + .context("could not makr grant as fulfilled")?; + + grant.stage = grant.stage.fulfill(fulfilled_at, session)?; + + Ok(grant) +} + +pub async fn exchange_grant( + executor: impl PgExecutor<'_>, + mut grant: AuthorizationGrant, +) -> anyhow::Result> { + let exchanged_at = sqlx::query_scalar!( + r#" + UPDATE oauth2_authorization_grants + SET + exchanged_at = NOW() + WHERE + id = $1 + RETURNING exchanged_at AS "exchanged_at!: DateTime" + "#, + grant.data, + ) + .fetch_one(executor) + .await + .context("could not mark grant as exchanged")?; + + grant.stage = grant.stage.exchange(exchanged_at)?; + + Ok(grant) +} diff --git a/crates/core/src/storage/oauth2/mod.rs b/crates/core/src/storage/oauth2/mod.rs index e9f83bc8..46d9d516 100644 --- a/crates/core/src/storage/oauth2/mod.rs +++ b/crates/core/src/storage/oauth2/mod.rs @@ -13,6 +13,5 @@ // limitations under the License. pub mod access_token; -pub mod authorization_code; +pub mod authorization_grant; pub mod refresh_token; -pub mod session; diff --git a/crates/core/src/storage/oauth2/refresh_token.rs b/crates/core/src/storage/oauth2/refresh_token.rs index 6fb54c64..69c12426 100644 --- a/crates/core/src/storage/oauth2/refresh_token.rs +++ b/crates/core/src/storage/oauth2/refresh_token.rs @@ -23,7 +23,7 @@ use crate::storage::{DatabaseInconsistencyError, IdAndCreationTime, PostgresqlBa pub async fn add_refresh_token( executor: impl PgExecutor<'_>, - oauth2_session_id: i64, + session: &Session, access_token: AccessToken, token: &str, ) -> anyhow::Result> { @@ -37,7 +37,7 @@ pub async fn add_refresh_token( RETURNING id, created_at "#, - oauth2_session_id, + session.data, access_token.data, token, ) @@ -64,8 +64,6 @@ struct OAuth2RefreshTokenLookup { session_id: i64, client_id: String, scope: String, - redirect_uri: String, - nonce: Option, user_session_id: i64, user_session_created_at: DateTime, user_id: i64, @@ -93,8 +91,6 @@ pub async fn lookup_active_refresh_token( os.id AS "session_id!", os.client_id AS "client_id!", os.scope AS "scope!", - os.redirect_uri AS "redirect_uri!", - os.nonce AS "nonce", us.id AS "user_session_id!", us.created_at AS "user_session_created_at!", u.id AS "user_id!", @@ -173,23 +169,18 @@ pub async fn lookup_active_refresh_token( _ => return Err(DatabaseInconsistencyError.into()), }; - let browser_session = Some(BrowserSession { + let browser_session = BrowserSession { data: res.user_session_id, created_at: res.user_session_created_at, user, last_authentication, - }); + }; let session = Session { data: res.session_id, client, browser_session, scope: res.scope.parse().context("invalid scope in database")?, - redirect_uri: res - .redirect_uri - .parse() - .context("invalid redirect_uri in database")?, - nonce: res.nonce, }; Ok((refresh_token, session)) diff --git a/crates/core/src/storage/oauth2/session.rs b/crates/core/src/storage/oauth2/session.rs deleted file mode 100644 index 9cac2438..00000000 --- a/crates/core/src/storage/oauth2/session.rs +++ /dev/null @@ -1,212 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::{collections::HashSet, convert::TryFrom, str::FromStr, string::ToString}; - -use anyhow::Context; -use chrono::{DateTime, Duration, Utc}; -use itertools::Itertools; -use mas_data_model::{AuthorizationCode, BrowserSession}; -use oauth2_types::{ - pkce, - requests::{ResponseMode, ResponseType}, -}; -use serde::Serialize; -use sqlx::PgExecutor; -use url::Url; - -use super::authorization_code::add_code; -use crate::storage::{lookup_active_session, PostgresqlBackend}; - -#[derive(Serialize)] -pub struct OAuth2Session { - pub id: i64, - user_session_id: Option, - pub client_id: String, - redirect_uri: String, - scope: String, - pub state: Option, - nonce: Option, - max_age: Option, - response_type: String, - response_mode: String, - - created_at: DateTime, - updated_at: DateTime, -} - -impl OAuth2Session { - pub async fn add_code<'e>( - &self, - executor: impl PgExecutor<'e>, - code: &str, - code_challenge: &Option, - ) -> anyhow::Result> { - add_code(executor, self.id, code, code_challenge).await - } - - pub async fn fetch_session( - &self, - executor: impl PgExecutor<'_>, - ) -> anyhow::Result>> { - match self.user_session_id { - Some(id) => { - // TODO: and if the session is inactive? - let info = lookup_active_session(executor, id).await?; - Ok(Some(info)) - } - None => Ok(None), - } - } - - pub async fn fetch_code(&self, executor: impl PgExecutor<'_>) -> anyhow::Result { - get_code_for_session(executor, self.id).await - } - - pub async fn match_or_set_session( - &mut self, - executor: impl PgExecutor<'_>, - session: BrowserSession, - ) -> anyhow::Result> { - match self.user_session_id { - Some(id) if id == session.data => Ok(session), - Some(id) => Err(anyhow::anyhow!( - "session mismatch, expected {}, got {}", - id, - session.data - )), - None => { - sqlx::query!( - "UPDATE oauth2_sessions SET user_session_id = $1 WHERE id = $2", - session.data, - self.id, - ) - .execute(executor) - .await - .context("could not update oauth2 session")?; - Ok(session) - } - } - } - - #[must_use] - pub fn max_auth_time(&self) -> Option> { - self.max_age - .map(|d| Duration::seconds(i64::from(d))) - .map(|d| self.created_at - d) - } - - pub fn response_type(&self) -> anyhow::Result> { - self.response_type - .split(' ') - .map(|s| { - ResponseType::from_str(s).with_context(|| format!("invalid response type {}", s)) - }) - .collect() - } - - pub fn response_mode(&self) -> anyhow::Result { - self.response_mode.parse().context("invalid response mode") - } - - pub fn redirect_uri(&self) -> anyhow::Result { - self.redirect_uri.parse().context("invalid redirect uri") - } -} - -#[allow(clippy::too_many_arguments)] -pub async fn start_session( - executor: impl PgExecutor<'_>, - optional_session_id: Option, - client_id: &str, - redirect_uri: &Url, - scope: &str, - state: Option<&str>, - nonce: Option<&str>, - max_age: Option, - response_type: &HashSet, - response_mode: ResponseMode, -) -> anyhow::Result { - // Checked convertion of duration to i32, maxing at i32::MAX - let max_age = max_age.map(|d| i32::try_from(d.num_seconds()).unwrap_or(i32::MAX)); - let response_mode = response_mode.to_string(); - let redirect_uri = redirect_uri.to_string(); - let response_type: String = { - let it = response_type.iter().map(ToString::to_string); - Itertools::intersperse(it, " ".to_string()).collect() - }; - - sqlx::query_as!( - OAuth2Session, - r#" - INSERT INTO oauth2_sessions - (user_session_id, client_id, redirect_uri, scope, state, nonce, max_age, - response_type, response_mode) - VALUES - ($1, $2, $3, $4, $5, $6, $7, $8, $9) - RETURNING - id, user_session_id, client_id, redirect_uri, scope, state, nonce, max_age, - response_type, response_mode, created_at, updated_at - "#, - optional_session_id, - client_id, - redirect_uri, - scope, - state, - nonce, - max_age, - response_type, - response_mode, - ) - .fetch_one(executor) - .await - .context("could not insert oauth2 session") -} - -pub async fn get_session_by_id( - executor: impl PgExecutor<'_>, - oauth2_session_id: i64, -) -> anyhow::Result { - sqlx::query_as!( - OAuth2Session, - r#" - SELECT - id, user_session_id, client_id, redirect_uri, scope, state, nonce, - max_age, response_type, response_mode, created_at, updated_at - FROM oauth2_sessions - WHERE id = $1 - "#, - oauth2_session_id - ) - .fetch_one(executor) - .await - .context("could not fetch oauth2 session") -} - -pub async fn get_code_for_session( - executor: impl PgExecutor<'_>, - oauth2_session_id: i64, -) -> anyhow::Result { - sqlx::query_scalar!( - r#" - SELECT code - FROM oauth2_codes - WHERE oauth2_session_id = $1 - "#, - oauth2_session_id - ) - .fetch_one(executor) - .await - .context("could not fetch oauth2 code") -} diff --git a/crates/data-model/src/lib.rs b/crates/data-model/src/lib.rs index a59e786c..e3e8fbf5 100644 --- a/crates/data-model/src/lib.rs +++ b/crates/data-model/src/lib.rs @@ -12,9 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::num::NonZeroU32; + use chrono::{DateTime, Duration, Utc}; -use oauth2_types::{pkce::CodeChallengeMethod, scope::Scope}; +use oauth2_types::{pkce::CodeChallengeMethod, requests::ResponseMode, scope::Scope}; use serde::Serialize; +use thiserror::Error; use url::Url; pub mod errors; @@ -27,7 +30,7 @@ pub trait StorageBackend { type BrowserSessionData: Clone + std::fmt::Debug + PartialEq; type ClientData: Clone + std::fmt::Debug + PartialEq; type SessionData: Clone + std::fmt::Debug + PartialEq; - type AuthorizationCodeData: Clone + std::fmt::Debug + PartialEq; + type AuthorizationGrantData: Clone + std::fmt::Debug + PartialEq; type AccessTokenData: Clone + std::fmt::Debug + PartialEq; type RefreshTokenData: Clone + std::fmt::Debug + PartialEq; } @@ -35,7 +38,7 @@ pub trait StorageBackend { impl StorageBackend for () { type AccessTokenData = (); type AuthenticationData = (); - type AuthorizationCodeData = (); + type AuthorizationGrantData = (); type BrowserSessionData = (); type ClientData = (); type RefreshTokenData = (); @@ -153,60 +156,18 @@ impl From> for Client<()> { pub struct Session { #[serde(skip_serializing)] pub data: T::SessionData, - pub browser_session: Option>, + pub browser_session: BrowserSession, pub client: Client, pub scope: Scope, - pub redirect_uri: Url, - pub nonce: Option, } impl From> for Session<()> { fn from(s: Session) -> Self { Session { data: (), - browser_session: s.browser_session.map(Into::into), + browser_session: s.browser_session.into(), client: s.client.into(), scope: s.scope, - redirect_uri: s.redirect_uri, - nonce: s.nonce, - } - } -} - -#[derive(Debug, Clone, PartialEq, Eq, Serialize)] -pub struct Pkce { - challenge_method: CodeChallengeMethod, - challenge: String, -} - -impl Pkce { - pub fn new(challenge_method: CodeChallengeMethod, challenge: String) -> Self { - Pkce { - challenge_method, - challenge, - } - } - - pub fn verify(&self, verifier: &str) -> bool { - self.challenge_method.verify(&self.challenge, verifier) - } -} - -#[derive(Debug, Clone, PartialEq, Eq, Serialize)] -#[serde(bound = "T: StorageBackend")] -pub struct AuthorizationCode { - #[serde(skip_serializing)] - pub data: T::AuthorizationCodeData, - pub code: String, - pub pkce: Option, -} - -impl From> for AuthorizationCode<()> { - fn from(c: AuthorizationCode) -> Self { - AuthorizationCode { - data: (), - code: c.code, - pkce: c.pkce, } } } @@ -256,3 +217,125 @@ impl From> for RefreshToken<()> { } } } + +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +pub struct Pkce { + pub challenge_method: CodeChallengeMethod, + pub challenge: String, +} + +impl Pkce { + pub fn new(challenge_method: CodeChallengeMethod, challenge: String) -> Self { + Pkce { + challenge_method, + challenge, + } + } + + pub fn verify(&self, verifier: &str) -> bool { + self.challenge_method.verify(&self.challenge, verifier) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +pub struct AuthorizationCode { + pub code: String, + pub pkce: Option, +} + +#[derive(Debug, Error)] +#[error("invalid state transition")] +pub struct InvalidTransitionError; + +#[derive(Debug, Clone, PartialEq, Serialize)] +#[serde(bound = "T: StorageBackend")] +pub enum AuthorizationGrantStage { + Pending, + Fulfilled { + session: Session, + fulfilled_at: DateTime, + }, + Exchanged { + session: Session, + fulfilled_at: DateTime, + exchanged_at: DateTime, + }, + Cancelled { + cancelled_at: DateTime, + }, +} + +impl Default for AuthorizationGrantStage { + fn default() -> Self { + Self::Pending + } +} + +impl AuthorizationGrantStage { + pub fn new() -> Self { + Self::Pending + } + + pub fn fulfill( + self, + fulfilled_at: DateTime, + session: Session, + ) -> Result { + match self { + Self::Pending => Ok(Self::Fulfilled { + fulfilled_at, + session, + }), + _ => Err(InvalidTransitionError), + } + } + + pub fn exchange(self, exchanged_at: DateTime) -> Result { + match self { + Self::Fulfilled { + fulfilled_at, + session, + } => Ok(Self::Exchanged { + fulfilled_at, + exchanged_at, + session, + }), + _ => Err(InvalidTransitionError), + } + } + + pub fn cancel(self, cancelled_at: DateTime) -> Result { + match self { + Self::Pending => Ok(Self::Cancelled { cancelled_at }), + _ => Err(InvalidTransitionError), + } + } +} + +#[derive(Debug, Clone, PartialEq, Serialize)] +#[serde(bound = "T: StorageBackend")] +pub struct AuthorizationGrant { + #[serde(skip_serializing)] + pub data: T::AuthorizationGrantData, + #[serde(flatten)] + pub stage: AuthorizationGrantStage, + pub code: Option, + pub client: Client, + pub redirect_uri: Url, + pub scope: Scope, + pub state: Option, + pub nonce: Option, + pub max_age: Option, + pub acr_values: Option, + pub response_mode: ResponseMode, + pub response_type_token: bool, + pub response_type_id_token: bool, + pub created_at: DateTime, +} + +impl AuthorizationGrant { + pub fn max_auth_time(&self) -> DateTime { + let max_age: Option = self.max_age.map(|x| x.get().into()); + self.created_at + Duration::seconds(max_age.unwrap_or(3600 * 24 * 365)) + } +} diff --git a/crates/oauth2-types/src/requests.rs b/crates/oauth2-types/src/requests.rs index 5ab20a32..0a499d3e 100644 --- a/crates/oauth2-types/src/requests.rs +++ b/crates/oauth2-types/src/requests.rs @@ -148,8 +148,7 @@ pub struct AuthorizationRequest { pub redirect_uri: Option, - #[serde_as(as = "StringWithSeparator::")] - pub scope: HashSet, + pub scope: Scope, pub state: Option,