1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

feat: support for MSC3824 action param on SSO redirect (#248)

Co-authored-by: Quentin Gliech <quenting@element.io>
This commit is contained in:
Hugh Nimmo-Smith
2022-06-14 12:34:56 +01:00
committed by GitHub
parent 482bfeecc2
commit 5632f6ba99
3 changed files with 79 additions and 17 deletions

View File

@ -16,7 +16,7 @@
use std::collections::HashMap; use std::collections::HashMap;
use axum::{ use axum::{
extract::{Form, Path}, extract::{Form, Path, Query},
response::{Html, IntoResponse, Redirect, Response}, response::{Html, IntoResponse, Redirect, Response},
Extension, Extension,
}; };
@ -28,11 +28,11 @@ use mas_axum_utils::{
}; };
use mas_config::Encrypter; use mas_config::Encrypter;
use mas_data_model::Device; use mas_data_model::Device;
use mas_router::{PostAuthAction, Route}; use mas_router::{CompatLoginSsoAction, PostAuthAction, Route};
use mas_storage::compat::{fullfill_compat_sso_login, get_compat_sso_login_by_id}; use mas_storage::compat::{fullfill_compat_sso_login, get_compat_sso_login_by_id};
use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates}; use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates};
use rand::thread_rng; use rand::thread_rng;
use serde::Serialize; use serde::{Deserialize, Serialize};
use sqlx::PgPool; use sqlx::PgPool;
#[derive(Serialize)] #[derive(Serialize)]
@ -44,11 +44,17 @@ struct AllParams<'s> {
login_token: &'s str, login_token: &'s str,
} }
#[derive(Debug, Deserialize)]
pub struct Params {
action: Option<CompatLoginSsoAction>,
}
pub async fn get( pub async fn get(
Extension(pool): Extension<PgPool>, Extension(pool): Extension<PgPool>,
Extension(templates): Extension<Templates>, Extension(templates): Extension<Templates>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Path(id): Path<i64>, Path(id): Path<i64>,
Query(params): Query<Params>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let mut conn = pool.acquire().await?; let mut conn = pool.acquire().await?;
@ -60,9 +66,17 @@ pub async fn get(
let session = if let Some(session) = maybe_session { let session = if let Some(session) = maybe_session {
session session
} else { } else {
// If there is no session, redirect to the login screen // If there is no session, redirect to the login or register screen
let login = mas_router::Login::and_continue_compat_sso_login(id); let url = match params.action {
return Ok((cookie_jar, login.go()).into_response()); Some(CompatLoginSsoAction::Register) => {
mas_router::Register::and_continue_compat_sso_login(id).go()
}
Some(CompatLoginSsoAction::Login) | None => {
mas_router::Login::and_continue_compat_sso_login(id).go()
}
};
return Ok((cookie_jar, url).into_response());
}; };
// TODO: make that more generic // TODO: make that more generic
@ -74,7 +88,7 @@ pub async fn get(
.is_none() .is_none()
{ {
let destination = mas_router::AccountAddEmail::default() let destination = mas_router::AccountAddEmail::default()
.and_then(PostAuthAction::ContinueCompatSsoLogin { data: id }); .and_then(PostAuthAction::continue_compat_sso_login(id));
return Ok((cookie_jar, destination.go()).into_response()); return Ok((cookie_jar, destination.go()).into_response());
} }
@ -105,6 +119,7 @@ pub async fn post(
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Path(id): Path<i64>, Path(id): Path<i64>,
Form(form): Form<ProtectedForm<()>>, Form(form): Form<ProtectedForm<()>>,
Query(params): Query<Params>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let mut txn = pool.begin().await?; let mut txn = pool.begin().await?;
@ -116,9 +131,17 @@ pub async fn post(
let session = if let Some(session) = maybe_session { let session = if let Some(session) = maybe_session {
session session
} else { } else {
// If there is no session, redirect to the login screen // If there is no session, redirect to the login or register screen
let login = mas_router::Login::and_continue_compat_sso_login(id); let url = match params.action {
return Ok((cookie_jar, login.go()).into_response()); Some(CompatLoginSsoAction::Register) => {
mas_router::Register::and_continue_compat_sso_login(id).go()
}
Some(CompatLoginSsoAction::Login) | None => {
mas_router::Login::and_continue_compat_sso_login(id).go()
}
};
return Ok((cookie_jar, url).into_response());
}; };
// TODO: make that more generic // TODO: make that more generic
@ -130,7 +153,7 @@ pub async fn post(
.is_none() .is_none()
{ {
let destination = mas_router::AccountAddEmail::default() let destination = mas_router::AccountAddEmail::default()
.and_then(PostAuthAction::ContinueCompatSsoLogin { data: id }); .and_then(PostAuthAction::continue_compat_sso_login(id));
return Ok((cookie_jar, destination.go()).into_response()); return Ok((cookie_jar, destination.go()).into_response());
} }

View File

@ -15,13 +15,14 @@
use axum::{extract::Query, response::IntoResponse, Extension}; use axum::{extract::Query, response::IntoResponse, Extension};
use hyper::StatusCode; use hyper::StatusCode;
use mas_router::{CompatLoginSsoComplete, UrlBuilder}; use mas_router::{CompatLoginSsoAction, CompatLoginSsoComplete, UrlBuilder};
use mas_storage::compat::insert_compat_sso_login; use mas_storage::compat::insert_compat_sso_login;
use rand::{ use rand::{
distributions::{Alphanumeric, DistString}, distributions::{Alphanumeric, DistString},
thread_rng, thread_rng,
}; };
use serde::Deserialize; use serde::Deserialize;
use serde_with::serde;
use sqlx::PgPool; use sqlx::PgPool;
use thiserror::Error; use thiserror::Error;
use url::Url; use url::Url;
@ -30,6 +31,7 @@ use url::Url;
pub struct Params { pub struct Params {
#[serde(rename = "redirectUrl")] #[serde(rename = "redirectUrl")]
redirect_url: Option<String>, redirect_url: Option<String>,
action: Option<CompatLoginSsoAction>,
} }
#[derive(Debug, Error)] #[derive(Debug, Error)]
@ -83,5 +85,5 @@ pub async fn get(
let mut conn = pool.acquire().await?; let mut conn = pool.acquire().await?;
let login = insert_compat_sso_login(&mut conn, token, redirect_url).await?; let login = insert_compat_sso_login(&mut conn, token, redirect_url).await?;
Ok(url_builder.absolute_redirect(&CompatLoginSsoComplete(login.data))) Ok(url_builder.absolute_redirect(&CompatLoginSsoComplete::new(login.data, params.action)))
} }

View File

@ -45,7 +45,7 @@ impl PostAuthAction {
pub fn go_next(&self) -> axum::response::Redirect { pub fn go_next(&self) -> axum::response::Redirect {
match self { match self {
Self::ContinueAuthorizationGrant { data } => ContinueAuthorizationGrant(*data).go(), Self::ContinueAuthorizationGrant { data } => ContinueAuthorizationGrant(*data).go(),
Self::ContinueCompatSsoLogin { data } => CompatLoginSsoComplete(*data).go(), Self::ContinueCompatSsoLogin { data } => CompatLoginSsoComplete::new(*data, None).go(),
Self::ChangePassword => AccountPassword.go(), Self::ChangePassword => AccountPassword.go(),
} }
} }
@ -282,6 +282,13 @@ impl Register {
} }
} }
#[must_use]
pub fn and_continue_compat_sso_login(data: i64) -> Self {
Self {
post_auth_action: Some(PostAuthAction::continue_compat_sso_login(data)),
}
}
/// Get a reference to the reauth's post auth action. /// Get a reference to the reauth's post auth action.
#[must_use] #[must_use]
pub fn post_auth_action(&self) -> Option<&PostAuthAction> { pub fn post_auth_action(&self) -> Option<&PostAuthAction> {
@ -473,16 +480,46 @@ impl SimpleRoute for CompatLoginSsoRedirectIdp {
const PATH: &'static str = "/_matrix/client/:version/login/sso/redirect/:idp"; const PATH: &'static str = "/_matrix/client/:version/login/sso/redirect/:idp";
} }
#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
#[serde(rename_all = "lowercase")]
pub enum CompatLoginSsoAction {
Login,
Register,
}
#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
pub struct CompatLoginSsoActionParams {
action: CompatLoginSsoAction,
}
/// `GET|POST /complete-compat-sso/:id` /// `GET|POST /complete-compat-sso/:id`
pub struct CompatLoginSsoComplete(pub i64); pub struct CompatLoginSsoComplete {
id: i64,
query: Option<CompatLoginSsoActionParams>,
}
impl CompatLoginSsoComplete {
#[must_use]
pub fn new(id: i64, action: Option<CompatLoginSsoAction>) -> Self {
Self {
id,
query: action.map(|action| CompatLoginSsoActionParams { action }),
}
}
}
impl Route for CompatLoginSsoComplete { impl Route for CompatLoginSsoComplete {
type Query = (); type Query = CompatLoginSsoActionParams;
fn query(&self) -> Option<&Self::Query> {
self.query.as_ref()
}
fn route() -> &'static str { fn route() -> &'static str {
"/complete-compat-sso/:grant_id" "/complete-compat-sso/:grant_id"
} }
fn path(&self) -> std::borrow::Cow<'static, str> { fn path(&self) -> std::borrow::Cow<'static, str> {
format!("/complete-compat-sso/{}", self.0).into() format!("/complete-compat-sso/{}", self.id).into()
} }
} }