diff --git a/crates/handlers/src/compat/login_sso_complete.rs b/crates/handlers/src/compat/login_sso_complete.rs index 2ff1df03..6dd7479c 100644 --- a/crates/handlers/src/compat/login_sso_complete.rs +++ b/crates/handlers/src/compat/login_sso_complete.rs @@ -16,7 +16,7 @@ use std::collections::HashMap; use axum::{ - extract::{Form, Path}, + extract::{Form, Path, Query}, response::{Html, IntoResponse, Redirect, Response}, Extension, }; @@ -28,11 +28,11 @@ use mas_axum_utils::{ }; use mas_config::Encrypter; use mas_data_model::Device; -use mas_router::{PostAuthAction, Route}; +use mas_router::{CompatLoginSsoAction, PostAuthAction, Route}; use mas_storage::compat::{fullfill_compat_sso_login, get_compat_sso_login_by_id}; use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates}; use rand::thread_rng; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use sqlx::PgPool; #[derive(Serialize)] @@ -44,11 +44,17 @@ struct AllParams<'s> { login_token: &'s str, } +#[derive(Debug, Deserialize)] +pub struct Params { + action: Option, +} + pub async fn get( Extension(pool): Extension, Extension(templates): Extension, cookie_jar: PrivateCookieJar, Path(id): Path, + Query(params): Query, ) -> Result { let mut conn = pool.acquire().await?; @@ -60,9 +66,17 @@ pub async fn get( let session = if let Some(session) = maybe_session { session } else { - // If there is no session, redirect to the login screen - let login = mas_router::Login::and_continue_compat_sso_login(id); - return Ok((cookie_jar, login.go()).into_response()); + // If there is no session, redirect to the login or register screen + let url = match params.action { + Some(CompatLoginSsoAction::Register) => { + mas_router::Register::and_continue_compat_sso_login(id).go() + } + Some(CompatLoginSsoAction::Login) | None => { + mas_router::Login::and_continue_compat_sso_login(id).go() + } + }; + + return Ok((cookie_jar, url).into_response()); }; // TODO: make that more generic @@ -74,7 +88,7 @@ pub async fn get( .is_none() { let destination = mas_router::AccountAddEmail::default() - .and_then(PostAuthAction::ContinueCompatSsoLogin { data: id }); + .and_then(PostAuthAction::continue_compat_sso_login(id)); return Ok((cookie_jar, destination.go()).into_response()); } @@ -105,6 +119,7 @@ pub async fn post( cookie_jar: PrivateCookieJar, Path(id): Path, Form(form): Form>, + Query(params): Query, ) -> Result { let mut txn = pool.begin().await?; @@ -116,9 +131,17 @@ pub async fn post( let session = if let Some(session) = maybe_session { session } else { - // If there is no session, redirect to the login screen - let login = mas_router::Login::and_continue_compat_sso_login(id); - return Ok((cookie_jar, login.go()).into_response()); + // If there is no session, redirect to the login or register screen + let url = match params.action { + Some(CompatLoginSsoAction::Register) => { + mas_router::Register::and_continue_compat_sso_login(id).go() + } + Some(CompatLoginSsoAction::Login) | None => { + mas_router::Login::and_continue_compat_sso_login(id).go() + } + }; + + return Ok((cookie_jar, url).into_response()); }; // TODO: make that more generic @@ -130,7 +153,7 @@ pub async fn post( .is_none() { let destination = mas_router::AccountAddEmail::default() - .and_then(PostAuthAction::ContinueCompatSsoLogin { data: id }); + .and_then(PostAuthAction::continue_compat_sso_login(id)); return Ok((cookie_jar, destination.go()).into_response()); } diff --git a/crates/handlers/src/compat/login_sso_redirect.rs b/crates/handlers/src/compat/login_sso_redirect.rs index 99e9b3aa..3ce8b8a0 100644 --- a/crates/handlers/src/compat/login_sso_redirect.rs +++ b/crates/handlers/src/compat/login_sso_redirect.rs @@ -15,13 +15,14 @@ use axum::{extract::Query, response::IntoResponse, Extension}; use hyper::StatusCode; -use mas_router::{CompatLoginSsoComplete, UrlBuilder}; +use mas_router::{CompatLoginSsoAction, CompatLoginSsoComplete, UrlBuilder}; use mas_storage::compat::insert_compat_sso_login; use rand::{ distributions::{Alphanumeric, DistString}, thread_rng, }; use serde::Deserialize; +use serde_with::serde; use sqlx::PgPool; use thiserror::Error; use url::Url; @@ -30,6 +31,7 @@ use url::Url; pub struct Params { #[serde(rename = "redirectUrl")] redirect_url: Option, + action: Option, } #[derive(Debug, Error)] @@ -83,5 +85,5 @@ pub async fn get( let mut conn = pool.acquire().await?; let login = insert_compat_sso_login(&mut conn, token, redirect_url).await?; - Ok(url_builder.absolute_redirect(&CompatLoginSsoComplete(login.data))) + Ok(url_builder.absolute_redirect(&CompatLoginSsoComplete::new(login.data, params.action))) } diff --git a/crates/router/src/endpoints.rs b/crates/router/src/endpoints.rs index 2c6d758a..15616a5a 100644 --- a/crates/router/src/endpoints.rs +++ b/crates/router/src/endpoints.rs @@ -45,7 +45,7 @@ impl PostAuthAction { pub fn go_next(&self) -> axum::response::Redirect { match self { Self::ContinueAuthorizationGrant { data } => ContinueAuthorizationGrant(*data).go(), - Self::ContinueCompatSsoLogin { data } => CompatLoginSsoComplete(*data).go(), + Self::ContinueCompatSsoLogin { data } => CompatLoginSsoComplete::new(*data, None).go(), Self::ChangePassword => AccountPassword.go(), } } @@ -282,6 +282,13 @@ impl Register { } } + #[must_use] + pub fn and_continue_compat_sso_login(data: i64) -> Self { + Self { + post_auth_action: Some(PostAuthAction::continue_compat_sso_login(data)), + } + } + /// Get a reference to the reauth's post auth action. #[must_use] pub fn post_auth_action(&self) -> Option<&PostAuthAction> { @@ -473,16 +480,46 @@ impl SimpleRoute for CompatLoginSsoRedirectIdp { const PATH: &'static str = "/_matrix/client/:version/login/sso/redirect/:idp"; } +#[derive(Debug, Serialize, Deserialize, Clone, Copy)] +#[serde(rename_all = "lowercase")] +pub enum CompatLoginSsoAction { + Login, + Register, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Copy)] +pub struct CompatLoginSsoActionParams { + action: CompatLoginSsoAction, +} + /// `GET|POST /complete-compat-sso/:id` -pub struct CompatLoginSsoComplete(pub i64); +pub struct CompatLoginSsoComplete { + id: i64, + query: Option, +} + +impl CompatLoginSsoComplete { + #[must_use] + pub fn new(id: i64, action: Option) -> Self { + Self { + id, + query: action.map(|action| CompatLoginSsoActionParams { action }), + } + } +} impl Route for CompatLoginSsoComplete { - type Query = (); + type Query = CompatLoginSsoActionParams; + + fn query(&self) -> Option<&Self::Query> { + self.query.as_ref() + } + fn route() -> &'static str { "/complete-compat-sso/:grant_id" } fn path(&self) -> std::borrow::Cow<'static, str> { - format!("/complete-compat-sso/{}", self.0).into() + format!("/complete-compat-sso/{}", self.id).into() } }