You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-31 09:24:31 +03:00
feat: support for MSC3824 action param on SSO redirect (#248)
Co-authored-by: Quentin Gliech <quenting@element.io>
This commit is contained in:
@ -16,7 +16,7 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{Form, Path},
|
extract::{Form, Path, Query},
|
||||||
response::{Html, IntoResponse, Redirect, Response},
|
response::{Html, IntoResponse, Redirect, Response},
|
||||||
Extension,
|
Extension,
|
||||||
};
|
};
|
||||||
@ -28,11 +28,11 @@ use mas_axum_utils::{
|
|||||||
};
|
};
|
||||||
use mas_config::Encrypter;
|
use mas_config::Encrypter;
|
||||||
use mas_data_model::Device;
|
use mas_data_model::Device;
|
||||||
use mas_router::{PostAuthAction, Route};
|
use mas_router::{CompatLoginSsoAction, PostAuthAction, Route};
|
||||||
use mas_storage::compat::{fullfill_compat_sso_login, get_compat_sso_login_by_id};
|
use mas_storage::compat::{fullfill_compat_sso_login, get_compat_sso_login_by_id};
|
||||||
use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates};
|
use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates};
|
||||||
use rand::thread_rng;
|
use rand::thread_rng;
|
||||||
use serde::Serialize;
|
use serde::{Deserialize, Serialize};
|
||||||
use sqlx::PgPool;
|
use sqlx::PgPool;
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
@ -44,11 +44,17 @@ struct AllParams<'s> {
|
|||||||
login_token: &'s str,
|
login_token: &'s str,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct Params {
|
||||||
|
action: Option<CompatLoginSsoAction>,
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn get(
|
pub async fn get(
|
||||||
Extension(pool): Extension<PgPool>,
|
Extension(pool): Extension<PgPool>,
|
||||||
Extension(templates): Extension<Templates>,
|
Extension(templates): Extension<Templates>,
|
||||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||||
Path(id): Path<i64>,
|
Path(id): Path<i64>,
|
||||||
|
Query(params): Query<Params>,
|
||||||
) -> Result<Response, FancyError> {
|
) -> Result<Response, FancyError> {
|
||||||
let mut conn = pool.acquire().await?;
|
let mut conn = pool.acquire().await?;
|
||||||
|
|
||||||
@ -60,9 +66,17 @@ pub async fn get(
|
|||||||
let session = if let Some(session) = maybe_session {
|
let session = if let Some(session) = maybe_session {
|
||||||
session
|
session
|
||||||
} else {
|
} else {
|
||||||
// If there is no session, redirect to the login screen
|
// If there is no session, redirect to the login or register screen
|
||||||
let login = mas_router::Login::and_continue_compat_sso_login(id);
|
let url = match params.action {
|
||||||
return Ok((cookie_jar, login.go()).into_response());
|
Some(CompatLoginSsoAction::Register) => {
|
||||||
|
mas_router::Register::and_continue_compat_sso_login(id).go()
|
||||||
|
}
|
||||||
|
Some(CompatLoginSsoAction::Login) | None => {
|
||||||
|
mas_router::Login::and_continue_compat_sso_login(id).go()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
return Ok((cookie_jar, url).into_response());
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO: make that more generic
|
// TODO: make that more generic
|
||||||
@ -74,7 +88,7 @@ pub async fn get(
|
|||||||
.is_none()
|
.is_none()
|
||||||
{
|
{
|
||||||
let destination = mas_router::AccountAddEmail::default()
|
let destination = mas_router::AccountAddEmail::default()
|
||||||
.and_then(PostAuthAction::ContinueCompatSsoLogin { data: id });
|
.and_then(PostAuthAction::continue_compat_sso_login(id));
|
||||||
return Ok((cookie_jar, destination.go()).into_response());
|
return Ok((cookie_jar, destination.go()).into_response());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -105,6 +119,7 @@ pub async fn post(
|
|||||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||||
Path(id): Path<i64>,
|
Path(id): Path<i64>,
|
||||||
Form(form): Form<ProtectedForm<()>>,
|
Form(form): Form<ProtectedForm<()>>,
|
||||||
|
Query(params): Query<Params>,
|
||||||
) -> Result<Response, FancyError> {
|
) -> Result<Response, FancyError> {
|
||||||
let mut txn = pool.begin().await?;
|
let mut txn = pool.begin().await?;
|
||||||
|
|
||||||
@ -116,9 +131,17 @@ pub async fn post(
|
|||||||
let session = if let Some(session) = maybe_session {
|
let session = if let Some(session) = maybe_session {
|
||||||
session
|
session
|
||||||
} else {
|
} else {
|
||||||
// If there is no session, redirect to the login screen
|
// If there is no session, redirect to the login or register screen
|
||||||
let login = mas_router::Login::and_continue_compat_sso_login(id);
|
let url = match params.action {
|
||||||
return Ok((cookie_jar, login.go()).into_response());
|
Some(CompatLoginSsoAction::Register) => {
|
||||||
|
mas_router::Register::and_continue_compat_sso_login(id).go()
|
||||||
|
}
|
||||||
|
Some(CompatLoginSsoAction::Login) | None => {
|
||||||
|
mas_router::Login::and_continue_compat_sso_login(id).go()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
return Ok((cookie_jar, url).into_response());
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO: make that more generic
|
// TODO: make that more generic
|
||||||
@ -130,7 +153,7 @@ pub async fn post(
|
|||||||
.is_none()
|
.is_none()
|
||||||
{
|
{
|
||||||
let destination = mas_router::AccountAddEmail::default()
|
let destination = mas_router::AccountAddEmail::default()
|
||||||
.and_then(PostAuthAction::ContinueCompatSsoLogin { data: id });
|
.and_then(PostAuthAction::continue_compat_sso_login(id));
|
||||||
return Ok((cookie_jar, destination.go()).into_response());
|
return Ok((cookie_jar, destination.go()).into_response());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -15,13 +15,14 @@
|
|||||||
|
|
||||||
use axum::{extract::Query, response::IntoResponse, Extension};
|
use axum::{extract::Query, response::IntoResponse, Extension};
|
||||||
use hyper::StatusCode;
|
use hyper::StatusCode;
|
||||||
use mas_router::{CompatLoginSsoComplete, UrlBuilder};
|
use mas_router::{CompatLoginSsoAction, CompatLoginSsoComplete, UrlBuilder};
|
||||||
use mas_storage::compat::insert_compat_sso_login;
|
use mas_storage::compat::insert_compat_sso_login;
|
||||||
use rand::{
|
use rand::{
|
||||||
distributions::{Alphanumeric, DistString},
|
distributions::{Alphanumeric, DistString},
|
||||||
thread_rng,
|
thread_rng,
|
||||||
};
|
};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
use serde_with::serde;
|
||||||
use sqlx::PgPool;
|
use sqlx::PgPool;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use url::Url;
|
use url::Url;
|
||||||
@ -30,6 +31,7 @@ use url::Url;
|
|||||||
pub struct Params {
|
pub struct Params {
|
||||||
#[serde(rename = "redirectUrl")]
|
#[serde(rename = "redirectUrl")]
|
||||||
redirect_url: Option<String>,
|
redirect_url: Option<String>,
|
||||||
|
action: Option<CompatLoginSsoAction>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
@ -83,5 +85,5 @@ pub async fn get(
|
|||||||
let mut conn = pool.acquire().await?;
|
let mut conn = pool.acquire().await?;
|
||||||
let login = insert_compat_sso_login(&mut conn, token, redirect_url).await?;
|
let login = insert_compat_sso_login(&mut conn, token, redirect_url).await?;
|
||||||
|
|
||||||
Ok(url_builder.absolute_redirect(&CompatLoginSsoComplete(login.data)))
|
Ok(url_builder.absolute_redirect(&CompatLoginSsoComplete::new(login.data, params.action)))
|
||||||
}
|
}
|
||||||
|
@ -45,7 +45,7 @@ impl PostAuthAction {
|
|||||||
pub fn go_next(&self) -> axum::response::Redirect {
|
pub fn go_next(&self) -> axum::response::Redirect {
|
||||||
match self {
|
match self {
|
||||||
Self::ContinueAuthorizationGrant { data } => ContinueAuthorizationGrant(*data).go(),
|
Self::ContinueAuthorizationGrant { data } => ContinueAuthorizationGrant(*data).go(),
|
||||||
Self::ContinueCompatSsoLogin { data } => CompatLoginSsoComplete(*data).go(),
|
Self::ContinueCompatSsoLogin { data } => CompatLoginSsoComplete::new(*data, None).go(),
|
||||||
Self::ChangePassword => AccountPassword.go(),
|
Self::ChangePassword => AccountPassword.go(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -282,6 +282,13 @@ impl Register {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn and_continue_compat_sso_login(data: i64) -> Self {
|
||||||
|
Self {
|
||||||
|
post_auth_action: Some(PostAuthAction::continue_compat_sso_login(data)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Get a reference to the reauth's post auth action.
|
/// Get a reference to the reauth's post auth action.
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn post_auth_action(&self) -> Option<&PostAuthAction> {
|
pub fn post_auth_action(&self) -> Option<&PostAuthAction> {
|
||||||
@ -473,16 +480,46 @@ impl SimpleRoute for CompatLoginSsoRedirectIdp {
|
|||||||
const PATH: &'static str = "/_matrix/client/:version/login/sso/redirect/:idp";
|
const PATH: &'static str = "/_matrix/client/:version/login/sso/redirect/:idp";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum CompatLoginSsoAction {
|
||||||
|
Login,
|
||||||
|
Register,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
|
||||||
|
pub struct CompatLoginSsoActionParams {
|
||||||
|
action: CompatLoginSsoAction,
|
||||||
|
}
|
||||||
|
|
||||||
/// `GET|POST /complete-compat-sso/:id`
|
/// `GET|POST /complete-compat-sso/:id`
|
||||||
pub struct CompatLoginSsoComplete(pub i64);
|
pub struct CompatLoginSsoComplete {
|
||||||
|
id: i64,
|
||||||
|
query: Option<CompatLoginSsoActionParams>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CompatLoginSsoComplete {
|
||||||
|
#[must_use]
|
||||||
|
pub fn new(id: i64, action: Option<CompatLoginSsoAction>) -> Self {
|
||||||
|
Self {
|
||||||
|
id,
|
||||||
|
query: action.map(|action| CompatLoginSsoActionParams { action }),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Route for CompatLoginSsoComplete {
|
impl Route for CompatLoginSsoComplete {
|
||||||
type Query = ();
|
type Query = CompatLoginSsoActionParams;
|
||||||
|
|
||||||
|
fn query(&self) -> Option<&Self::Query> {
|
||||||
|
self.query.as_ref()
|
||||||
|
}
|
||||||
|
|
||||||
fn route() -> &'static str {
|
fn route() -> &'static str {
|
||||||
"/complete-compat-sso/:grant_id"
|
"/complete-compat-sso/:grant_id"
|
||||||
}
|
}
|
||||||
|
|
||||||
fn path(&self) -> std::borrow::Cow<'static, str> {
|
fn path(&self) -> std::borrow::Cow<'static, str> {
|
||||||
format!("/complete-compat-sso/{}", self.0).into()
|
format!("/complete-compat-sso/{}", self.id).into()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user