From 22a337cd45fd22bfd6a663489590f7e95770f784 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 23 Nov 2022 17:26:59 +0100 Subject: [PATCH] WIP: handle account linking --- crates/handlers/src/lib.rs | 4 + .../handlers/src/upstream_oauth2/callback.rs | 7 +- crates/handlers/src/upstream_oauth2/link.rs | 256 ++++++++++++++++++ crates/handlers/src/upstream_oauth2/mod.rs | 1 + crates/router/src/endpoints.rs | 23 ++ crates/storage/sqlx-data.json | 229 +++++++++++++--- crates/storage/src/upstream_oauth2/link.rs | 40 +++ crates/storage/src/upstream_oauth2/mod.rs | 6 +- crates/storage/src/upstream_oauth2/session.rs | 65 ++++- crates/storage/src/user.rs | 57 ++++ crates/templates/Cargo.toml | 2 +- crates/templates/src/context.rs | 31 +++ crates/templates/src/lib.rs | 22 +- .../pages/upstream_oauth2/already_linked.html | 25 ++ templates/pages/upstream_oauth2/do_login.html | 32 +++ .../pages/upstream_oauth2/do_register.html | 33 +++ .../pages/upstream_oauth2/link_mismatch.html | 29 ++ .../pages/upstream_oauth2/suggest_link.html | 36 +++ 18 files changed, 848 insertions(+), 50 deletions(-) create mode 100644 crates/handlers/src/upstream_oauth2/link.rs create mode 100644 templates/pages/upstream_oauth2/already_linked.html create mode 100644 templates/pages/upstream_oauth2/do_login.html create mode 100644 templates/pages/upstream_oauth2/do_register.html create mode 100644 templates/pages/upstream_oauth2/link_mismatch.html create mode 100644 templates/pages/upstream_oauth2/suggest_link.html diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index 36d993bb..fe9163f5 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -308,6 +308,10 @@ where mas_router::UpstreamOAuth2Callback::route(), get(self::upstream_oauth2::callback::get), ) + .route( + mas_router::UpstreamOAuth2Link::route(), + get(self::upstream_oauth2::link::get).post(self::upstream_oauth2::link::post), + ) .layer(AndThenLayer::new( move |response: axum::response::Response| async move { if response.status().is_server_error() { diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index 2ddf43a8..fdbfffba 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -15,7 +15,6 @@ use axum::{ extract::{Path, Query, State}, response::IntoResponse, - Json, }; use axum_extra::extract::PrivateCookieJar; use hyper::StatusCode; @@ -27,7 +26,7 @@ use mas_oidc_client::{ error::{DiscoveryError, JwksError, TokenAuthorizationCodeError}, requests::{authorization_code::AuthorizationValidationData, jose::JwtVerificationData}, }; -use mas_router::UrlBuilder; +use mas_router::{Route, UrlBuilder}; use mas_storage::{ upstream_oauth2::{add_link, complete_session, lookup_link_by_subject, lookup_session}, GenericLookupError, LookupResultExt, @@ -271,7 +270,7 @@ pub(crate) async fn get( .http_service("upstream-exchange-code") .await?; - let (response, id_token) = + let (_response, id_token) = mas_oidc_client::requests::authorization_code::access_token_with_authorization_code( &http_service, client_credentials, @@ -309,5 +308,5 @@ pub(crate) async fn get( txn.commit().await?; - Ok(Json(response)) + Ok(mas_router::UpstreamOAuth2Link::new(link.id).go()) } diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs new file mode 100644 index 00000000..94beb059 --- /dev/null +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -0,0 +1,256 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use axum::{ + extract::{Path, State}, + response::{Html, IntoResponse}, + Form, +}; +use axum_extra::extract::PrivateCookieJar; +use hyper::StatusCode; +use mas_axum_utils::{ + csrf::{CsrfError, CsrfExt, ProtectedForm}, + SessionInfoExt, +}; +use mas_keystore::Encrypter; +use mas_storage::{ + upstream_oauth2::{lookup_link, lookup_session_on_link}, + user::{lookup_user, ActiveSessionLookupError, UserLookupError}, + GenericLookupError, LookupResultExt, +}; +use mas_templates::{ + EmptyContext, TemplateContext, TemplateError, Templates, UpstreamExistingLinkContext, +}; +use serde::Deserialize; +use sqlx::PgPool; +use thiserror::Error; +use ulid::Ulid; + +#[derive(Debug, Error)] +pub(crate) enum RouteError { + /// Couldn't find the link specified in the URL + #[error("Link not found")] + LinkNotFound, + + /// Couldn't find the session on the link + #[error("Session not found")] + SessionNotFound, + + #[error("Missing session cookie")] + MissingCookie, + + #[error("Invalid session cookie")] + InvalidCookie(#[source] ulid::DecodeError), + + #[error("Invalid form action")] + InvalidFormAction, + + #[error(transparent)] + InternalError(Box), + + #[error(transparent)] + Anyhow(#[from] anyhow::Error), +} + +impl From for RouteError { + fn from(e: sqlx::Error) -> Self { + Self::InternalError(Box::new(e)) + } +} + +impl From for RouteError { + fn from(e: TemplateError) -> Self { + Self::InternalError(Box::new(e)) + } +} + +impl From for RouteError { + fn from(e: ActiveSessionLookupError) -> Self { + Self::InternalError(Box::new(e)) + } +} + +impl From for RouteError { + fn from(e: CsrfError) -> Self { + Self::InternalError(Box::new(e)) + } +} + +impl From for RouteError { + fn from(e: UserLookupError) -> Self { + Self::InternalError(Box::new(e)) + } +} + +impl From for RouteError { + fn from(e: GenericLookupError) -> Self { + Self::InternalError(Box::new(e)) + } +} + +impl IntoResponse for RouteError { + fn into_response(self) -> axum::response::Response { + match self { + Self::LinkNotFound => (StatusCode::NOT_FOUND, "Link not found").into_response(), + Self::InternalError(e) => { + (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response() + } + Self::Anyhow(e) => { + (StatusCode::INTERNAL_SERVER_ERROR, format!("{e:?}")).into_response() + } + e => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(), + } + } +} + +#[derive(Deserialize)] +#[serde(rename_all = "lowercase", tag = "action")] +pub(crate) enum FormData { + Register { username: String }, + Link, + Login, +} + +pub(crate) async fn get( + State(pool): State, + State(templates): State, + cookie_jar: PrivateCookieJar, + Path(link_id): Path, +) -> Result { + let mut txn = pool.begin().await?; + let (clock, mut rng) = crate::rng_and_clock()?; + + let (link, _provider_id, maybe_user_id) = lookup_link(&mut txn, link_id) + .await + .to_option()? + .ok_or(RouteError::LinkNotFound)?; + + // XXX: that cookie should be managed elsewhere + let cookie = cookie_jar + .get("upstream-oauth2-session-id") + .ok_or(RouteError::MissingCookie)?; + + let session_id: Ulid = cookie.value().parse().map_err(RouteError::InvalidCookie)?; + + // This checks that we're in a browser session which is allowed to consume this + // link: the upstream auth session should have been started in this browser. + let _upstream_session = lookup_session_on_link(&mut txn, &link, session_id) + .await + .to_option()? + .ok_or(RouteError::SessionNotFound)?; + + let (user_session_info, cookie_jar) = cookie_jar.session_info(); + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); + let maybe_user_session = user_session_info.load_session(&mut txn).await?; + + let render = match (maybe_user_session, maybe_user_id) { + (Some(user_session), Some(user_id)) if user_session.user.data == user_id => { + // Session already linked, and link matches the currently logged + // user. Do nothing? + let ctx = EmptyContext + .with_session(user_session) + .with_csrf(csrf_token.form_value()); + + templates + .render_upstream_oauth2_already_linked(&ctx) + .await? + } + + (Some(user_session), Some(user_id)) => { + // Session already linked, but link doesn't match the currently + // logged user. Suggest logging out of the current user + // and logging in with the new one + let user = lookup_user(&mut txn, user_id).await?; + + let ctx = UpstreamExistingLinkContext::new(user) + .with_session(user_session) + .with_csrf(csrf_token.form_value()); + + templates.render_upstream_oauth2_link_mismatch(&ctx).await? + } + + (Some(user_session), None) => { + // Session not linked, but user logged in: suggest linking account + let ctx = EmptyContext + .with_session(user_session) + .with_csrf(csrf_token.form_value()); + + templates.render_upstream_oauth2_suggest_link(&ctx).await? + } + + (None, Some(user_id)) => { + // Session linked, but user not logged in: do the login + let user = lookup_user(&mut txn, user_id).await?; + + let ctx = UpstreamExistingLinkContext::new(user).with_csrf(csrf_token.form_value()); + + templates.render_upstream_oauth2_do_login(&ctx).await? + } + + (None, None) => { + // Session not linked and used not logged in: suggest creating an + // account or logging in an existing user + let ctx = EmptyContext.with_csrf(csrf_token.form_value()); + + templates.render_upstream_oauth2_do_register(&ctx).await? + } + }; + + Ok((cookie_jar, Html(render))) +} + +pub(crate) async fn post( + State(pool): State, + cookie_jar: PrivateCookieJar, + Path(link_id): Path, + Form(form): Form>, +) -> Result { + let mut txn = pool.begin().await?; + let (clock, _rng) = crate::rng_and_clock()?; + let form = cookie_jar.verify_form(clock.now(), form)?; + + let (link, _provider_id, maybe_user_id) = lookup_link(&mut txn, link_id) + .await + .to_option()? + .ok_or(RouteError::LinkNotFound)?; + + // XXX: that cookie should be managed elsewhere + let cookie = cookie_jar + .get("upstream-oauth2-session-id") + .ok_or(RouteError::MissingCookie)?; + + let session_id: Ulid = cookie.value().parse().map_err(RouteError::InvalidCookie)?; + + // This checks that we're in a browser session which is allowed to consume this + // link: the upstream auth session should have been started in this browser. + let _upstream_session = lookup_session_on_link(&mut txn, &link, session_id) + .await + .to_option()? + .ok_or(RouteError::SessionNotFound)?; + + let (user_session_info, cookie_jar) = cookie_jar.session_info(); + let maybe_user_session = user_session_info.load_session(&mut txn).await?; + + let res = match (maybe_user_session, maybe_user_id, form) { + (Some(_user_session), None, FormData::Link) => "Linked!".to_owned(), + + (None, Some(_user_id), FormData::Login) => "Logged in!".to_owned(), + + (None, None, FormData::Register { username }) => format!("Registered {username}!"), + + _ => return Err(RouteError::InvalidFormAction), + }; + + Ok((cookie_jar, res)) +} diff --git a/crates/handlers/src/upstream_oauth2/mod.rs b/crates/handlers/src/upstream_oauth2/mod.rs index 4cb889bc..154fbeb5 100644 --- a/crates/handlers/src/upstream_oauth2/mod.rs +++ b/crates/handlers/src/upstream_oauth2/mod.rs @@ -22,6 +22,7 @@ use url::Url; pub(crate) mod authorize; pub(crate) mod callback; +pub(crate) mod link; #[derive(Debug, Error)] enum ProviderCredentialsError { diff --git a/crates/router/src/endpoints.rs b/crates/router/src/endpoints.rs index 08e558ff..271baa07 100644 --- a/crates/router/src/endpoints.rs +++ b/crates/router/src/endpoints.rs @@ -570,6 +570,29 @@ impl Route for UpstreamOAuth2Callback { } } +/// `GET /upstream/link/:id` +pub struct UpstreamOAuth2Link { + id: Ulid, +} + +impl UpstreamOAuth2Link { + #[must_use] + pub const fn new(id: Ulid) -> Self { + Self { id } + } +} + +impl Route for UpstreamOAuth2Link { + type Query = (); + fn route() -> &'static str { + "/upstream/link/:link_id" + } + + fn path(&self) -> std::borrow::Cow<'static, str> { + format!("/upstream/link/{}", self.id).into() + } +} + /// `GET /assets` pub struct StaticAsset { path: String, diff --git a/crates/storage/sqlx-data.json b/crates/storage/sqlx-data.json index 4d06047d..208c8185 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage/sqlx-data.json @@ -689,45 +689,6 @@ }, "query": "\n SELECT\n ue.user_email_id,\n ue.email AS \"user_email\",\n ue.created_at AS \"user_email_created_at\",\n ue.confirmed_at AS \"user_email_confirmed_at\"\n FROM user_emails ue\n\n WHERE ue.user_id = $1\n AND ue.email = $2\n " }, - "3a6de39a88ef93a91f3cc0465785bafd58ef7dbd4aae924a8bcfcefaf2f1a0d7": { - "describe": { - "columns": [ - { - "name": "upstream_oauth_link_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "user_id", - "ordinal": 1, - "type_info": "Uuid" - }, - { - "name": "subject", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "created_at", - "ordinal": 3, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - true, - false, - false - ], - "parameters": { - "Left": [ - "Uuid", - "Text" - ] - } - }, - "query": "\n SELECT\n upstream_oauth_link_id,\n user_id,\n subject,\n created_at\n FROM upstream_oauth_links\n WHERE upstream_oauth_provider_id = $1\n AND subject = $2\n " - }, "3df0838b660466f69ee681337fe6753133748defb715e53c8381badcc3e8bca9": { "describe": { "columns": [ @@ -978,6 +939,50 @@ }, "query": "\n UPDATE oauth2_authorization_grants AS og\n SET\n oauth2_session_id = os.oauth2_session_id,\n fulfilled_at = os.created_at\n FROM oauth2_sessions os\n WHERE\n og.oauth2_authorization_grant_id = $1\n AND os.oauth2_session_id = $2\n RETURNING fulfilled_at AS \"fulfilled_at!: DateTime\"\n " }, + "47d4048365144c7bfc14790dfb8fa7f862d2952075a68cd5e90ac76d9e6d1388": { + "describe": { + "columns": [ + { + "name": "upstream_oauth_link_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "upstream_oauth_provider_id", + "ordinal": 1, + "type_info": "Uuid" + }, + { + "name": "user_id", + "ordinal": 2, + "type_info": "Uuid" + }, + { + "name": "subject", + "ordinal": 3, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 4, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + true, + false, + false + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n FROM upstream_oauth_links\n WHERE upstream_oauth_link_id = $1\n " + }, "47fff42fd9871f73baf3e3ebb9e296fa65f7bc99f94639891f29d56d204b659a": { "describe": { "columns": [], @@ -1181,6 +1186,56 @@ }, "query": "\n UPDATE compat_sessions cs\n SET finished_at = $2\n FROM compat_access_tokens ca\n WHERE ca.access_token = $1\n AND ca.compat_session_id = cs.compat_session_id\n AND cs.finished_at IS NULL\n RETURNING cs.compat_session_id\n " }, + "59439585536bb4e547a6cf58a8bc6ac735f29c225bcbeac7d371f09166789a73": { + "describe": { + "columns": [ + { + "name": "user_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "user_username", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "user_email_id?", + "ordinal": 2, + "type_info": "Uuid" + }, + { + "name": "user_email?", + "ordinal": 3, + "type_info": "Text" + }, + { + "name": "user_email_created_at?", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "user_email_confirmed_at?", + "ordinal": 5, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + false, + false, + false, + true + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT\n u.user_id,\n u.username AS user_username,\n ue.user_email_id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n FROM users u\n\n LEFT JOIN user_emails ue\n USING (user_id)\n\n WHERE u.user_id = $1\n " + }, "5b5d5c82da37c6f2d8affacfb02119965c04d1f2a9cc53dbf5bd4c12584969a0": { "describe": { "columns": [], @@ -1193,6 +1248,57 @@ }, "query": "\n DELETE FROM oauth2_access_tokens\n WHERE expires_at < $1\n " }, + "5cb91740580a37044dd37c90a2fadaab9abcd387c7883f47c73c18a8fa260683": { + "describe": { + "columns": [ + { + "name": "upstream_oauth_authorization_session_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "state", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "code_challenge_verifier", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "nonce", + "ordinal": 3, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "completed_at", + "ordinal": 5, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + true, + false, + false, + true + ], + "parameters": { + "Left": [ + "Uuid", + "Uuid" + ] + } + }, + "query": "\n SELECT\n upstream_oauth_authorization_session_id,\n state,\n code_challenge_verifier,\n nonce,\n created_at,\n completed_at\n FROM upstream_oauth_authorization_sessions\n WHERE upstream_oauth_authorization_session_id = $1\n AND upstream_oauth_link_id = $2\n " + }, "5ccde09ee3fe43e7b492d73fa67708b5dcb2b7496c4d05bcfcf0ea63c7576d48": { "describe": { "columns": [ @@ -2497,5 +2603,50 @@ } }, "query": "\n INSERT INTO user_sessions (user_session_id, user_id, created_at)\n VALUES ($1, $2, $3)\n " + }, + "f71cb5761bfc15d8bc3ba7ee49b63fb3c3ea9691745688eb5fd91f4f6e1ec018": { + "describe": { + "columns": [ + { + "name": "upstream_oauth_link_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "upstream_oauth_provider_id", + "ordinal": 1, + "type_info": "Uuid" + }, + { + "name": "user_id", + "ordinal": 2, + "type_info": "Uuid" + }, + { + "name": "subject", + "ordinal": 3, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 4, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + true, + false, + false + ], + "parameters": { + "Left": [ + "Uuid", + "Text" + ] + } + }, + "query": "\n SELECT\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n FROM upstream_oauth_links\n WHERE upstream_oauth_provider_id = $1\n AND subject = $2\n " } } \ No newline at end of file diff --git a/crates/storage/src/upstream_oauth2/link.rs b/crates/storage/src/upstream_oauth2/link.rs index 9b3fb696..e0be2596 100644 --- a/crates/storage/src/upstream_oauth2/link.rs +++ b/crates/storage/src/upstream_oauth2/link.rs @@ -23,11 +23,50 @@ use crate::{Clock, GenericLookupError}; struct LinkLookup { upstream_oauth_link_id: Uuid, + upstream_oauth_provider_id: Uuid, user_id: Option, subject: String, created_at: DateTime, } +#[tracing::instrument( + skip_all, + fields(upstream_oauth_link.id = %id), + err, +)] +pub async fn lookup_link( + executor: impl PgExecutor<'_>, + id: Ulid, +) -> Result<(UpstreamOAuthLink, Ulid, Option), GenericLookupError> { + let res = sqlx::query_as!( + LinkLookup, + r#" + SELECT + upstream_oauth_link_id, + upstream_oauth_provider_id, + user_id, + subject, + created_at + FROM upstream_oauth_links + WHERE upstream_oauth_link_id = $1 + "#, + Uuid::from(id), + ) + .fetch_one(executor) + .await + .map_err(GenericLookupError::what("Upstream OAuth 2.0 link"))?; + + Ok(( + UpstreamOAuthLink { + id: Ulid::from(res.upstream_oauth_link_id), + subject: res.subject, + created_at: res.created_at, + }, + Ulid::from(res.upstream_oauth_provider_id), + res.user_id.map(Ulid::from), + )) +} + #[tracing::instrument( skip_all, fields( @@ -48,6 +87,7 @@ pub async fn lookup_link_by_subject( r#" SELECT upstream_oauth_link_id, + upstream_oauth_provider_id, user_id, subject, created_at diff --git a/crates/storage/src/upstream_oauth2/mod.rs b/crates/storage/src/upstream_oauth2/mod.rs index 4e5a0bff..58dccf0a 100644 --- a/crates/storage/src/upstream_oauth2/mod.rs +++ b/crates/storage/src/upstream_oauth2/mod.rs @@ -17,7 +17,9 @@ mod provider; mod session; pub use self::{ - link::{add_link, lookup_link_by_subject}, + link::{add_link, lookup_link, lookup_link_by_subject}, provider::{add_provider, lookup_provider, ProviderLookupError}, - session::{add_session, complete_session, lookup_session, SessionLookupError}, + session::{ + add_session, complete_session, lookup_session, lookup_session_on_link, SessionLookupError, + }, }; diff --git a/crates/storage/src/upstream_oauth2/session.rs b/crates/storage/src/upstream_oauth2/session.rs index 14a00a2f..43020b94 100644 --- a/crates/storage/src/upstream_oauth2/session.rs +++ b/crates/storage/src/upstream_oauth2/session.rs @@ -20,7 +20,7 @@ use thiserror::Error; use ulid::Ulid; use uuid::Uuid; -use crate::{Clock, DatabaseInconsistencyError, LookupError}; +use crate::{Clock, DatabaseInconsistencyError, GenericLookupError, LookupError}; #[derive(Debug, Error)] #[error("Failed to lookup upstream OAuth 2.0 authorization session")] @@ -35,7 +35,7 @@ impl LookupError for SessionLookupError { } } -struct SessionLookup { +struct SessionAndProviderLookup { upstream_oauth_authorization_session_id: Uuid, upstream_oauth_provider_id: Uuid, state: String, @@ -52,6 +52,7 @@ struct SessionLookup { provider_created_at: DateTime, } +/// Lookup a session and its provider by its ID #[tracing::instrument( skip_all, fields(upstream_oauth_authorization_session.id = %id), @@ -62,7 +63,7 @@ pub async fn lookup_session( id: Ulid, ) -> Result<(UpstreamOAuthProvider, UpstreamOAuthAuthorizationSession), SessionLookupError> { let res = sqlx::query_as!( - SessionLookup, + SessionAndProviderLookup, r#" SELECT ua.upstream_oauth_authorization_session_id, @@ -125,6 +126,7 @@ pub async fn lookup_session( Ok((provider, session)) } +/// Add a session to the database #[tracing::instrument( skip_all, fields( @@ -183,6 +185,7 @@ pub async fn add_session( }) } +/// Mark a session as completed and associate the given link #[tracing::instrument( skip_all, fields( @@ -214,3 +217,59 @@ pub async fn complete_session( Ok(upstream_oauth_authorization_session) } + +struct SessionLookup { + upstream_oauth_authorization_session_id: Uuid, + state: String, + code_challenge_verifier: Option, + nonce: String, + created_at: DateTime, + completed_at: Option>, +} + +/// Lookup a session, which belongs to a link, by its ID +#[tracing::instrument( + skip_all, + fields( + upstream_oauth_authorization_session.id = %id, + %upstream_oauth_link.id, + ), + err, +)] +pub async fn lookup_session_on_link( + executor: impl PgExecutor<'_>, + upstream_oauth_link: &UpstreamOAuthLink, + id: Ulid, +) -> Result { + let res = sqlx::query_as!( + SessionLookup, + r#" + SELECT + upstream_oauth_authorization_session_id, + state, + code_challenge_verifier, + nonce, + created_at, + completed_at + FROM upstream_oauth_authorization_sessions + WHERE upstream_oauth_authorization_session_id = $1 + AND upstream_oauth_link_id = $2 + "#, + Uuid::from(id), + Uuid::from(upstream_oauth_link.id), + ) + .fetch_one(executor) + .await + .map_err(GenericLookupError::what( + "Upstream OAuth 2.0 session on link", + ))?; + + Ok(UpstreamOAuthAuthorizationSession { + id: res.upstream_oauth_authorization_session_id.into(), + state: res.state, + code_challenge_verifier: res.code_challenge_verifier, + nonce: res.nonce, + created_at: res.created_at, + completed_at: res.completed_at, + }) +} diff --git a/crates/storage/src/user.rs b/crates/storage/src/user.rs index d796ef3d..c4c9daf1 100644 --- a/crates/storage/src/user.rs +++ b/crates/storage/src/user.rs @@ -628,6 +628,63 @@ pub async fn lookup_user_by_username( }) } +#[tracing::instrument( + skip_all, + fields(user.id = %id), + err, +)] +pub async fn lookup_user( + executor: impl PgExecutor<'_>, + id: Ulid, +) -> Result, UserLookupError> { + let res = sqlx::query_as!( + UserLookup, + r#" + SELECT + u.user_id, + u.username AS user_username, + ue.user_email_id AS "user_email_id?", + ue.email AS "user_email?", + ue.created_at AS "user_email_created_at?", + ue.confirmed_at AS "user_email_confirmed_at?" + FROM users u + + LEFT JOIN user_emails ue + USING (user_id) + + WHERE u.user_id = $1 + "#, + Uuid::from(id), + ) + .fetch_one(executor) + .instrument(info_span!("Fetch user")) + .await?; + + let primary_email = match ( + res.user_email_id, + res.user_email, + res.user_email_created_at, + res.user_email_confirmed_at, + ) { + (Some(id), Some(email), Some(created_at), confirmed_at) => Some(UserEmail { + data: id.into(), + email, + created_at, + confirmed_at, + }), + (None, None, None, None) => None, + _ => return Err(DatabaseInconsistencyError.into()), + }; + + let id = Ulid::from(res.user_id); + Ok(User { + data: id, + username: res.user_username, + sub: id.to_string(), + primary_email, + }) +} + #[tracing::instrument( skip_all, fields(user.username = username), diff --git a/crates/templates/Cargo.toml b/crates/templates/Cargo.toml index d49168ea..673a2627 100644 --- a/crates/templates/Cargo.toml +++ b/crates/templates/Cargo.toml @@ -7,7 +7,7 @@ license = "Apache-2.0" [dependencies] tracing = "0.1.37" -tokio = { version = "1.22.0", features = ["macros"] } +tokio = { version = "1.22.0", features = ["macros", "rt"] } anyhow = "1.0.66" thiserror = "1.0.37" diff --git a/crates/templates/src/context.rs b/crates/templates/src/context.rs index e9aba32a..c882dd35 100644 --- a/crates/templates/src/context.rs +++ b/crates/templates/src/context.rs @@ -748,6 +748,37 @@ impl TemplateContext for EmailAddContext { } } +/// Context used by the `pages/upstream_oauth2/{link_mismatch,do_login}.html` +/// templates +#[derive(Serialize)] +pub struct UpstreamExistingLinkContext { + linked_user: User<()>, +} + +impl UpstreamExistingLinkContext { + /// Constructs a new context with an existing linked user + pub fn new(linked_user: T) -> Self + where + T: Into>, + { + Self { + linked_user: linked_user.into(), + } + } +} + +impl TemplateContext for UpstreamExistingLinkContext { + fn sample(now: chrono::DateTime) -> Vec + where + Self: Sized, + { + User::samples(now) + .into_iter() + .map(|linked_user| Self { linked_user }) + .collect() + } +} + /// Context used by the `form_post.html` template #[derive(Serialize)] pub struct FormPostContext { diff --git a/crates/templates/src/lib.rs b/crates/templates/src/lib.rs index f2fc0bc6..f2919802 100644 --- a/crates/templates/src/lib.rs +++ b/crates/templates/src/lib.rs @@ -49,7 +49,7 @@ pub use self::{ EmailVerificationContext, EmailVerificationPageContext, EmptyContext, ErrorContext, FormPostContext, IndexContext, LoginContext, LoginFormField, PolicyViolationContext, PostAuthContext, ReauthContext, ReauthFormField, RegisterContext, RegisterFormField, - TemplateContext, WithCsrf, WithOptionalSession, WithSession, + TemplateContext, UpstreamExistingLinkContext, WithCsrf, WithOptionalSession, WithSession, }, forms::{FieldError, FormError, FormField, FormState, ToFormState}, }; @@ -225,6 +225,21 @@ register_templates! { /// Render the email verification subject pub fn render_email_verification_subject(EmailVerificationContext) { "emails/verification.subject" } + + /// Render the upstream already linked message + pub fn render_upstream_oauth2_already_linked(WithCsrf>) { "pages/upstream_oauth2/already_linked.html" } + + /// Render the upstream link mismatch message + pub fn render_upstream_oauth2_link_mismatch(WithCsrf>) { "pages/upstream_oauth2/link_mismatch.html" } + + /// Render the upstream suggest link message + pub fn render_upstream_oauth2_suggest_link(WithCsrf>) { "pages/upstream_oauth2/suggest_link.html" } + + /// Render the upstream login screen + pub fn render_upstream_oauth2_do_login(WithCsrf) { "pages/upstream_oauth2/do_login.html" } + + /// Render the upstream register screen + pub fn render_upstream_oauth2_do_register(WithCsrf) { "pages/upstream_oauth2/do_register.html" } } impl Templates { @@ -248,6 +263,11 @@ impl Templates { check::render_email_verification_txt(self, now).await?; check::render_email_verification_html(self, now).await?; check::render_email_verification_subject(self, now).await?; + check::render_upstream_oauth2_already_linked(self, now).await?; + check::render_upstream_oauth2_link_mismatch(self, now).await?; + check::render_upstream_oauth2_suggest_link(self, now).await?; + check::render_upstream_oauth2_do_login(self, now).await?; + check::render_upstream_oauth2_do_register(self, now).await?; Ok(()) } } diff --git a/templates/pages/upstream_oauth2/already_linked.html b/templates/pages/upstream_oauth2/already_linked.html new file mode 100644 index 00000000..32b3875f --- /dev/null +++ b/templates/pages/upstream_oauth2/already_linked.html @@ -0,0 +1,25 @@ +{# +Copyright 2022 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +#} + +{% extends "base.html" %} + +{% block content %} +
+

+ Your upstream account is already linked. +

+
+{% endblock content %} diff --git a/templates/pages/upstream_oauth2/do_login.html b/templates/pages/upstream_oauth2/do_login.html new file mode 100644 index 00000000..9fbbaeac --- /dev/null +++ b/templates/pages/upstream_oauth2/do_login.html @@ -0,0 +1,32 @@ +{# +Copyright 2022 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +#} + +{% extends "base.html" %} + +{% block content %} +
+
+

+ Continue login +

+ + + + + {{ button::button(text="Continue") }} +
+
+{% endblock content %} diff --git a/templates/pages/upstream_oauth2/do_register.html b/templates/pages/upstream_oauth2/do_register.html new file mode 100644 index 00000000..5aaedec7 --- /dev/null +++ b/templates/pages/upstream_oauth2/do_register.html @@ -0,0 +1,33 @@ +{# +Copyright 2022 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +#} + +{% extends "base.html" %} + +{% block content %} +
+
+

+ Choose your username +

+ + + + {{ field::input(label="Username", name="username", autocomplete="username", autocorrect="off", autocapitalize="none") }} + + {{ button::button(text="Continue") }} +
+
+{% endblock content %} diff --git a/templates/pages/upstream_oauth2/link_mismatch.html b/templates/pages/upstream_oauth2/link_mismatch.html new file mode 100644 index 00000000..856d8385 --- /dev/null +++ b/templates/pages/upstream_oauth2/link_mismatch.html @@ -0,0 +1,29 @@ +{# +Copyright 2022 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +#} + +{% extends "base.html" %} + +{% block content %} +
+
+

+ This upstream account is already linked to another account. +

+ + {{ logout::button(text="Logout", class=button::plain_class(), csrf_token=csrf_token) }} +
+
+{% endblock content %} diff --git a/templates/pages/upstream_oauth2/suggest_link.html b/templates/pages/upstream_oauth2/suggest_link.html new file mode 100644 index 00000000..4244a3d2 --- /dev/null +++ b/templates/pages/upstream_oauth2/suggest_link.html @@ -0,0 +1,36 @@ +{# +Copyright 2022 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +#} + +{% extends "base.html" %} + +{% block content %} +
+
+

+ Link to your existing account +

+ +
+ + + + {{ button::button(text="Link", class="flex-1") }} +
+ +
Or {{ logout::button(text="Logout", class=button::outline_class(), csrf_token=csrf_token) }}
+
+
+{% endblock content %}