From cde9187adc53ab5a2bba85d510d163ea5eb3137b Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 23 Nov 2022 16:31:47 +0100 Subject: [PATCH] Lookup and save upstream links --- crates/data-model/src/upstream_oauth2/mod.rs | 7 + .../handlers/src/upstream_oauth2/callback.rs | 63 ++++++++- crates/storage/sqlx-data.json | 67 ++++++++++ crates/storage/src/upstream_oauth2/link.rs | 120 ++++++++++++++++++ crates/storage/src/upstream_oauth2/mod.rs | 4 +- crates/storage/src/upstream_oauth2/session.rs | 44 ++++++- 6 files changed, 293 insertions(+), 12 deletions(-) create mode 100644 crates/storage/src/upstream_oauth2/link.rs diff --git a/crates/data-model/src/upstream_oauth2/mod.rs b/crates/data-model/src/upstream_oauth2/mod.rs index ace6f100..8f2dc485 100644 --- a/crates/data-model/src/upstream_oauth2/mod.rs +++ b/crates/data-model/src/upstream_oauth2/mod.rs @@ -46,3 +46,10 @@ pub struct UpstreamOAuthAuthorizationSession { pub created_at: DateTime, pub completed_at: Option>, } + +impl UpstreamOAuthAuthorizationSession { + #[must_use] + pub const fn completed(&self) -> bool { + self.completed_at.is_some() + } +} diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index 632a150b..2ddf43a8 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -21,13 +21,17 @@ use axum_extra::extract::PrivateCookieJar; use hyper::StatusCode; use mas_axum_utils::http_client_factory::HttpClientFactory; use mas_http::ClientInitError; +use mas_jose::claims::ClaimError; use mas_keystore::{Encrypter, Keystore}; use mas_oidc_client::{ error::{DiscoveryError, JwksError, TokenAuthorizationCodeError}, requests::{authorization_code::AuthorizationValidationData, jose::JwtVerificationData}, }; use mas_router::UrlBuilder; -use mas_storage::{upstream_oauth2::lookup_session, LookupResultExt}; +use mas_storage::{ + upstream_oauth2::{add_link, complete_session, lookup_link_by_subject, lookup_session}, + GenericLookupError, LookupResultExt, +}; use oauth2_types::errors::ClientErrorCode; use serde::Deserialize; use sqlx::PgPool; @@ -66,9 +70,21 @@ pub(crate) enum RouteError { #[error("Provider mismatch")] ProviderMismatch, + #[error("Session already completed")] + AlreadyCompleted, + #[error("State parameter mismatch")] StateMismatch, + #[error("Missing ID token")] + MissingIDToken, + + #[error("Invalid ID token")] + InvalidIdToken(#[from] ClaimError), + + #[error("User already linked")] + UserAlreadyLinked, + #[error("Error from the provider: {error}")] ClientError { error: ClientErrorCode, @@ -88,6 +104,12 @@ pub(crate) enum RouteError { Anyhow(#[from] anyhow::Error), } +impl From for RouteError { + fn from(e: GenericLookupError) -> Self { + Self::InternalError(Box::new(e)) + } +} + impl From for RouteError { fn from(e: sqlx::Error) -> Self { Self::InternalError(Box::new(e)) @@ -182,6 +204,11 @@ pub(crate) async fn get( return Err(RouteError::StateMismatch); } + if session.completed() { + // The session was already completed + return Err(RouteError::AlreadyCompleted); + } + // Let's extract the code from the params, and return if there was an error let code = match params.code_or_error { CodeOrError::Error { @@ -224,10 +251,11 @@ pub(crate) async fn get( let redirect_uri = url_builder.upstream_oauth_callback(provider.id); + // TODO: all that should be borrowed let validation_data = AuthorizationValidationData { - state: session.state, - nonce: session.nonce, - code_challenge_verifier: session.code_challenge_verifier, + state: session.state.clone(), + nonce: session.nonce.clone(), + code_challenge_verifier: session.code_challenge_verifier.clone(), redirect_uri, }; @@ -243,7 +271,7 @@ pub(crate) async fn get( .http_service("upstream-exchange-code") .await?; - let (response, _id_token) = + let (response, id_token) = mas_oidc_client::requests::authorization_code::access_token_with_authorization_code( &http_service, client_credentials, @@ -256,5 +284,30 @@ pub(crate) async fn get( ) .await?; + let (_header, mut id_token) = id_token.ok_or(RouteError::MissingIDToken)?.into_parts(); + + // Extract the subject from the id_token + let subject = mas_jose::claims::SUB.extract_required(&mut id_token)?; + + // Look for an existing link + let maybe_link = lookup_link_by_subject(&mut txn, &provider, &subject) + .await + .to_option()?; + + let link = if let Some((link, maybe_user_id)) = maybe_link { + if let Some(_user_id) = maybe_user_id { + // TODO: Here we should login if the user is linked + return Err(RouteError::UserAlreadyLinked); + } + + link + } else { + add_link(&mut txn, &mut rng, &clock, &provider, subject).await? + }; + + let _session = complete_session(&mut txn, &clock, session, &link).await?; + + txn.commit().await?; + Ok(Json(response)) } diff --git a/crates/storage/sqlx-data.json b/crates/storage/sqlx-data.json index 26ab9744..4d06047d 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage/sqlx-data.json @@ -689,6 +689,45 @@ }, "query": "\n SELECT\n ue.user_email_id,\n ue.email AS \"user_email\",\n ue.created_at AS \"user_email_created_at\",\n ue.confirmed_at AS \"user_email_confirmed_at\"\n FROM user_emails ue\n\n WHERE ue.user_id = $1\n AND ue.email = $2\n " }, + "3a6de39a88ef93a91f3cc0465785bafd58ef7dbd4aae924a8bcfcefaf2f1a0d7": { + "describe": { + "columns": [ + { + "name": "upstream_oauth_link_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "user_id", + "ordinal": 1, + "type_info": "Uuid" + }, + { + "name": "subject", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 3, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + true, + false, + false + ], + "parameters": { + "Left": [ + "Uuid", + "Text" + ] + } + }, + "query": "\n SELECT\n upstream_oauth_link_id,\n user_id,\n subject,\n created_at\n FROM upstream_oauth_links\n WHERE upstream_oauth_provider_id = $1\n AND subject = $2\n " + }, "3df0838b660466f69ee681337fe6753133748defb715e53c8381badcc3e8bca9": { "describe": { "columns": [ @@ -954,6 +993,19 @@ }, "query": "\n INSERT INTO user_passwords (user_password_id, user_id, hashed_password, created_at)\n VALUES ($1, $2, $3, $4)\n " }, + "4b6a44d040a0dc849bb4e04abb11a181995b5847917605ef4c160389686a54f5": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n UPDATE upstream_oauth_authorization_sessions\n SET upstream_oauth_link_id = $1,\n completed_at = $2\n " + }, "4f8ec19f3f1bfe0268fe102a24e5a9fa542e77eccbebdce65e6deb1c197adf36": { "describe": { "columns": [ @@ -2417,6 +2469,21 @@ }, "query": "\n DELETE FROM user_emails\n WHERE user_emails.user_email_id = $1\n " }, + "e1dc9dd2bf26a341050a53151bf51f7638448ccc2bd458bbdfe87cc22f086313": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Text", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO upstream_oauth_links (\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n ) VALUES ($1, $2, NULL, $3, $4)\n " + }, "e446e37d48c8838ef2e0d0fd82f8f7b04893c84ad46747cdf193ebd83755ceb2": { "describe": { "columns": [], diff --git a/crates/storage/src/upstream_oauth2/link.rs b/crates/storage/src/upstream_oauth2/link.rs new file mode 100644 index 00000000..9b3fb696 --- /dev/null +++ b/crates/storage/src/upstream_oauth2/link.rs @@ -0,0 +1,120 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use chrono::{DateTime, Utc}; +use mas_data_model::{UpstreamOAuthLink, UpstreamOAuthProvider}; +use rand::Rng; +use sqlx::PgExecutor; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{Clock, GenericLookupError}; + +struct LinkLookup { + upstream_oauth_link_id: Uuid, + user_id: Option, + subject: String, + created_at: DateTime, +} + +#[tracing::instrument( + skip_all, + fields( + upstream_oauth_link.subject = subject, + %upstream_oauth_provider.id, + %upstream_oauth_provider.issuer, + %upstream_oauth_provider.client_id, + ), + err, +)] +pub async fn lookup_link_by_subject( + executor: impl PgExecutor<'_>, + upstream_oauth_provider: &UpstreamOAuthProvider, + subject: &str, +) -> Result<(UpstreamOAuthLink, Option), GenericLookupError> { + let res = sqlx::query_as!( + LinkLookup, + r#" + SELECT + upstream_oauth_link_id, + user_id, + subject, + created_at + FROM upstream_oauth_links + WHERE upstream_oauth_provider_id = $1 + AND subject = $2 + "#, + Uuid::from(upstream_oauth_provider.id), + subject, + ) + .fetch_one(executor) + .await + .map_err(GenericLookupError::what("Upstream OAuth 2.0 link"))?; + + Ok(( + UpstreamOAuthLink { + id: Ulid::from(res.upstream_oauth_link_id), + subject: res.subject, + created_at: res.created_at, + }, + res.user_id.map(Ulid::from), + )) +} + +#[tracing::instrument( + skip_all, + fields( + upstream_oauth_link.id, + upstream_oauth_link.subject = subject, + %upstream_oauth_provider.id, + %upstream_oauth_provider.issuer, + %upstream_oauth_provider.client_id, + ), + err, +)] +pub async fn add_link( + executor: impl PgExecutor<'_>, + mut rng: impl Rng + Send, + clock: &Clock, + upstream_oauth_provider: &UpstreamOAuthProvider, + subject: String, +) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); + tracing::Span::current().record("upstream_oauth_link.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO upstream_oauth_links ( + upstream_oauth_link_id, + upstream_oauth_provider_id, + user_id, + subject, + created_at + ) VALUES ($1, $2, NULL, $3, $4) + "#, + Uuid::from(id), + Uuid::from(upstream_oauth_provider.id), + &subject, + created_at, + ) + .execute(executor) + .await?; + + Ok(UpstreamOAuthLink { + id, + subject, + created_at, + }) +} diff --git a/crates/storage/src/upstream_oauth2/mod.rs b/crates/storage/src/upstream_oauth2/mod.rs index 8acb8229..4e5a0bff 100644 --- a/crates/storage/src/upstream_oauth2/mod.rs +++ b/crates/storage/src/upstream_oauth2/mod.rs @@ -12,10 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +mod link; mod provider; mod session; pub use self::{ + link::{add_link, lookup_link_by_subject}, provider::{add_provider, lookup_provider, ProviderLookupError}, - session::{add_session, lookup_session, SessionLookupError}, + session::{add_session, complete_session, lookup_session, SessionLookupError}, }; diff --git a/crates/storage/src/upstream_oauth2/session.rs b/crates/storage/src/upstream_oauth2/session.rs index 43be7a99..14a00a2f 100644 --- a/crates/storage/src/upstream_oauth2/session.rs +++ b/crates/storage/src/upstream_oauth2/session.rs @@ -13,7 +13,7 @@ // limitations under the License. use chrono::{DateTime, Utc}; -use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthProvider}; +use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider}; use rand::Rng; use sqlx::PgExecutor; use thiserror::Error; @@ -128,9 +128,9 @@ pub async fn lookup_session( #[tracing::instrument( skip_all, fields( - upstream_oauth_provider.id = %provider.id, - upstream_oauth_provider.issuer = %provider.issuer, - upstream_oauth_provider.client_id = %provider.client_id, + %upstream_oauth_provider.id, + %upstream_oauth_provider.issuer, + %upstream_oauth_provider.client_id, upstream_oauth_authorization_session.id, ), err, @@ -139,7 +139,7 @@ pub async fn add_session( executor: impl PgExecutor<'_>, mut rng: impl Rng + Send, clock: &Clock, - provider: &UpstreamOAuthProvider, + upstream_oauth_provider: &UpstreamOAuthProvider, state: String, code_challenge_verifier: Option, nonce: String, @@ -164,7 +164,7 @@ pub async fn add_session( ) VALUES ($1, $2, $3, $4, $5, $6, NULL) "#, Uuid::from(id), - Uuid::from(provider.id), + Uuid::from(upstream_oauth_provider.id), &state, code_challenge_verifier.as_deref(), nonce, @@ -182,3 +182,35 @@ pub async fn add_session( completed_at: None, }) } + +#[tracing::instrument( + skip_all, + fields( + %upstream_oauth_authorization_session.id, + %upstream_oauth_link.id, + ), + err, +)] +pub async fn complete_session( + executor: impl PgExecutor<'_>, + clock: &Clock, + mut upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, + upstream_oauth_link: &UpstreamOAuthLink, +) -> Result { + let completed_at = clock.now(); + sqlx::query!( + r#" + UPDATE upstream_oauth_authorization_sessions + SET upstream_oauth_link_id = $1, + completed_at = $2 + "#, + Uuid::from(upstream_oauth_link.id), + completed_at, + ) + .execute(executor) + .await?; + + upstream_oauth_authorization_session.completed_at = Some(completed_at); + + Ok(upstream_oauth_authorization_session) +}