1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Save the post auth action during upstream OAuth login

This commit is contained in:
Quentin Gliech
2022-12-05 18:27:56 +01:00
parent 4d93f4d4f0
commit 23fd833d45
18 changed files with 142 additions and 100 deletions

2
Cargo.lock generated
View File

@ -2638,7 +2638,6 @@ dependencies = [
"async-trait", "async-trait",
"axum 0.6.1", "axum 0.6.1",
"axum-extra", "axum-extra",
"bincode",
"chrono", "chrono",
"data-encoding", "data-encoding",
"futures-util", "futures-util",
@ -3125,6 +3124,7 @@ dependencies = [
"anyhow", "anyhow",
"camino", "camino",
"chrono", "chrono",
"http",
"mas-data-model", "mas-data-model",
"mas-router", "mas-router",
"oauth2-types", "oauth2-types",

View File

@ -9,7 +9,6 @@ license = "Apache-2.0"
async-trait = "0.1.59" async-trait = "0.1.59"
axum = { version = "0.6.1", features = ["headers"] } axum = { version = "0.6.1", features = ["headers"] }
axum-extra = { version = "0.4.2", features = ["cookie-private"] } axum-extra = { version = "0.4.2", features = ["cookie-private"] }
bincode = "1.3.3"
chrono = "0.4.23" chrono = "0.4.23"
data-encoding = "2.3.2" data-encoding = "2.3.2"
futures-util = "0.3.25" futures-util = "0.3.25"

View File

@ -14,15 +14,13 @@
//! Private (encrypted) cookie jar, based on axum-extra's cookie jar //! Private (encrypted) cookie jar, based on axum-extra's cookie jar
use data_encoding::BASE64URL_NOPAD;
use serde::{de::DeserializeOwned, Serialize}; use serde::{de::DeserializeOwned, Serialize};
use thiserror::Error; use thiserror::Error;
#[derive(Debug, Error)] #[derive(Debug, Error)]
#[error("could not decode cookie")] #[error("could not decode cookie")]
pub enum CookieDecodeError { pub enum CookieDecodeError {
Deserialize(#[from] bincode::Error), Deserialize(#[from] serde_json::Error),
Decode(#[from] data_encoding::DecodeError),
} }
pub trait CookieExt { pub trait CookieExt {
@ -41,10 +39,7 @@ impl<'a> CookieExt for axum_extra::extract::cookie::Cookie<'a> {
where where
T: DeserializeOwned, T: DeserializeOwned,
{ {
let bytes = BASE64URL_NOPAD.decode(self.value().as_bytes())?; let decoded = serde_json::from_str(self.value())?;
let decoded = bincode::deserialize(&bytes)?;
Ok(decoded) Ok(decoded)
} }
@ -52,8 +47,7 @@ impl<'a> CookieExt for axum_extra::extract::cookie::Cookie<'a> {
where where
T: Serialize, T: Serialize,
{ {
let bytes = bincode::serialize(t).unwrap(); let encoded = serde_json::to_string(t).unwrap();
let encoded = BASE64URL_NOPAD.encode(&bytes);
self.set_value(encoded); self.set_value(encoded);
self self
} }

View File

@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
use axum::{ use axum::{
extract::{Path, State}, extract::{Path, Query, State},
response::{IntoResponse, Redirect}, response::{IntoResponse, Redirect},
}; };
use axum_extra::extract::PrivateCookieJar; use axum_extra::extract::PrivateCookieJar;
@ -28,7 +28,7 @@ use thiserror::Error;
use ulid::Ulid; use ulid::Ulid;
use super::UpstreamSessionsCookie; use super::UpstreamSessionsCookie;
use crate::impl_from_error_for_route; use crate::{impl_from_error_for_route, views::shared::OptionalPostAuthAction};
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub(crate) enum RouteError { pub(crate) enum RouteError {
@ -68,6 +68,7 @@ pub(crate) async fn get(
State(url_builder): State<UrlBuilder>, State(url_builder): State<UrlBuilder>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Path(provider_id): Path<Ulid>, Path(provider_id): Path<Ulid>,
Query(query): Query<OptionalPostAuthAction>,
) -> Result<impl IntoResponse, RouteError> { ) -> Result<impl IntoResponse, RouteError> {
let (clock, mut rng) = crate::rng_and_clock()?; let (clock, mut rng) = crate::rng_and_clock()?;
@ -115,7 +116,7 @@ pub(crate) async fn get(
.await?; .await?;
let cookie_jar = UpstreamSessionsCookie::load(&cookie_jar) let cookie_jar = UpstreamSessionsCookie::load(&cookie_jar)
.add(session.id, provider.id, data.state) .add(session.id, provider.id, data.state, query.post_auth_action)
.save(cookie_jar, clock.now()); .save(cookie_jar, clock.now());
txn.commit().await?; txn.commit().await?;

View File

@ -137,7 +137,7 @@ pub(crate) async fn get(
let mut txn = pool.begin().await?; let mut txn = pool.begin().await?;
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar); let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
let session_id = sessions_cookie let (session_id, _post_auth_action) = sessions_cookie
.find_session(provider_id, &params.state) .find_session(provider_id, &params.state)
.map_err(|_| RouteError::MissingCookie)?; .map_err(|_| RouteError::MissingCookie)?;

View File

@ -17,6 +17,7 @@
use axum_extra::extract::{cookie::Cookie, PrivateCookieJar}; use axum_extra::extract::{cookie::Cookie, PrivateCookieJar};
use chrono::{DateTime, Duration, NaiveDateTime, Utc}; use chrono::{DateTime, Duration, NaiveDateTime, Utc};
use mas_axum_utils::CookieExt; use mas_axum_utils::CookieExt;
use mas_router::PostAuthAction;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use thiserror::Error; use thiserror::Error;
use time::OffsetDateTime; use time::OffsetDateTime;
@ -28,12 +29,13 @@ static COOKIE_NAME: &str = "upstream-oauth2-sessions";
/// Sessions expire after 10 minutes /// Sessions expire after 10 minutes
static SESSION_MAX_TIME_SECS: i64 = 60 * 10; static SESSION_MAX_TIME_SECS: i64 = 60 * 10;
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize, Debug)]
pub struct Payload { pub struct Payload {
session: Ulid, session: Ulid,
provider: Ulid, provider: Ulid,
state: String, state: String,
link: Option<Ulid>, link: Option<Ulid>,
post_auth_action: Option<PostAuthAction>,
} }
impl Payload { impl Payload {
@ -46,7 +48,7 @@ impl Payload {
} }
} }
#[derive(Serialize, Deserialize, Default)] #[derive(Serialize, Deserialize, Default, Debug)]
pub struct UpstreamSessions(Vec<Payload>); pub struct UpstreamSessions(Vec<Payload>);
#[derive(Debug, Error, PartialEq, Eq)] #[derive(Debug, Error, PartialEq, Eq)]
@ -87,12 +89,19 @@ impl UpstreamSessions {
} }
/// Add a new session, for a provider and a random state /// Add a new session, for a provider and a random state
pub fn add(mut self, session: Ulid, provider: Ulid, state: String) -> Self { pub fn add(
mut self,
session: Ulid,
provider: Ulid,
state: String,
post_auth_action: Option<PostAuthAction>,
) -> Self {
self.0.push(Payload { self.0.push(Payload {
session, session,
provider, provider,
state, state,
link: None, link: None,
post_auth_action,
}); });
self self
} }
@ -102,11 +111,11 @@ impl UpstreamSessions {
&self, &self,
provider: Ulid, provider: Ulid,
state: &str, state: &str,
) -> Result<Ulid, UpstreamSessionNotFound> { ) -> Result<(Ulid, Option<&PostAuthAction>), UpstreamSessionNotFound> {
self.0 self.0
.iter() .iter()
.find(|p| p.provider == provider && p.state == state && p.link.is_none()) .find(|p| p.provider == provider && p.state == state && p.link.is_none())
.map(|p| p.session) .map(|p| (p.session, p.post_auth_action.as_ref()))
.ok_or(UpstreamSessionNotFound) .ok_or(UpstreamSessionNotFound)
} }
@ -127,11 +136,14 @@ impl UpstreamSessions {
} }
/// Find a session from its link /// Find a session from its link
pub fn lookup_link(&self, link_id: Ulid) -> Result<Ulid, UpstreamSessionNotFound> { pub fn lookup_link(
&self,
link_id: Ulid,
) -> Result<(Ulid, Option<&PostAuthAction>), UpstreamSessionNotFound> {
self.0 self.0
.iter() .iter()
.find(|p| p.link == Some(link_id)) .find(|p| p.link == Some(link_id))
.map(|p| p.session) .map(|p| (p.session, p.post_auth_action.as_ref()))
.ok_or(UpstreamSessionNotFound) .ok_or(UpstreamSessionNotFound)
} }
@ -171,22 +183,22 @@ mod tests {
let first_session = Ulid::from_datetime_with_source(now.into(), &mut rng); let first_session = Ulid::from_datetime_with_source(now.into(), &mut rng);
let first_state = "first-state"; let first_state = "first-state";
let sessions = sessions.add(first_session, provider_a, first_state.into()); let sessions = sessions.add(first_session, provider_a, first_state.into(), None);
let now = now + Duration::minutes(5); let now = now + Duration::minutes(5);
let second_session = Ulid::from_datetime_with_source(now.into(), &mut rng); let second_session = Ulid::from_datetime_with_source(now.into(), &mut rng);
let second_state = "second-state"; let second_state = "second-state";
let sessions = sessions.add(second_session, provider_b, second_state.into()); let sessions = sessions.add(second_session, provider_b, second_state.into(), None);
let sessions = sessions.expire(now); let sessions = sessions.expire(now);
assert_eq!( assert_eq!(
sessions.find_session(provider_a, first_state), sessions.find_session(provider_a, first_state).unwrap().0,
Ok(first_session) first_session,
); );
assert_eq!( assert_eq!(
sessions.find_session(provider_b, second_state), sessions.find_session(provider_b, second_state).unwrap().0,
Ok(second_session) second_session
); );
assert!(sessions.find_session(provider_b, first_state).is_err()); assert!(sessions.find_session(provider_b, first_state).is_err());
assert!(sessions.find_session(provider_a, second_state).is_err()); assert!(sessions.find_session(provider_a, second_state).is_err());
@ -196,8 +208,8 @@ mod tests {
let sessions = sessions.expire(now); let sessions = sessions.expire(now);
assert!(sessions.find_session(provider_a, first_state).is_err()); assert!(sessions.find_session(provider_a, first_state).is_err());
assert_eq!( assert_eq!(
sessions.find_session(provider_b, second_state), sessions.find_session(provider_b, second_state).unwrap().0,
Ok(second_session) second_session
); );
// Associate a link with the second // Associate a link with the second
@ -210,7 +222,7 @@ mod tests {
assert!(sessions.find_session(provider_b, second_state).is_err()); assert!(sessions.find_session(provider_b, second_state).is_err());
// But it can be looked up by its link // But it can be looked up by its link
assert_eq!(sessions.lookup_link(second_link), Ok(second_session)); assert_eq!(sessions.lookup_link(second_link).unwrap().0, second_session);
// And it can be consumed // And it can be consumed
let sessions = sessions.consume_link(second_link).unwrap(); let sessions = sessions.consume_link(second_link).unwrap();
// But only once // But only once

View File

@ -24,7 +24,6 @@ use mas_axum_utils::{
SessionInfoExt, SessionInfoExt,
}; };
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_router::Route;
use mas_storage::{ use mas_storage::{
upstream_oauth2::{ upstream_oauth2::{
associate_link_to_user, consume_session, lookup_link, lookup_session_on_link, associate_link_to_user, consume_session, lookup_link, lookup_session_on_link,
@ -44,7 +43,7 @@ use thiserror::Error;
use ulid::Ulid; use ulid::Ulid;
use super::UpstreamSessionsCookie; use super::UpstreamSessionsCookie;
use crate::impl_from_error_for_route; use crate::{impl_from_error_for_route, views::shared::OptionalPostAuthAction};
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub(crate) enum RouteError { pub(crate) enum RouteError {
@ -114,7 +113,7 @@ pub(crate) async fn get(
let (clock, mut rng) = crate::rng_and_clock()?; let (clock, mut rng) = crate::rng_and_clock()?;
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar); let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
let session_id = sessions_cookie let (session_id, _post_auth_action) = sessions_cookie
.lookup_link(link_id) .lookup_link(link_id)
.map_err(|_| RouteError::MissingCookie)?; .map_err(|_| RouteError::MissingCookie)?;
@ -213,10 +212,14 @@ pub(crate) async fn post(
let form = cookie_jar.verify_form(clock.now(), form)?; let form = cookie_jar.verify_form(clock.now(), form)?;
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar); let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
let session_id = sessions_cookie let (session_id, post_auth_action) = sessions_cookie
.lookup_link(link_id) .lookup_link(link_id)
.map_err(|_| RouteError::MissingCookie)?; .map_err(|_| RouteError::MissingCookie)?;
let post_auth_action = OptionalPostAuthAction {
post_auth_action: post_auth_action.cloned(),
};
let link = lookup_link(&mut txn, link_id) let link = lookup_link(&mut txn, link_id)
.await .await
.to_option()? .to_option()?
@ -267,5 +270,5 @@ pub(crate) async fn post(
txn.commit().await?; txn.commit().await?;
Ok((cookie_jar, mas_router::Index.go())) Ok((cookie_jar, post_auth_action.go_next()))
} }

View File

@ -22,7 +22,6 @@ use mas_axum_utils::{
FancyError, SessionInfoExt, FancyError, SessionInfoExt,
}; };
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_router::Route;
use mas_storage::user::{login, LoginError}; use mas_storage::user::{login, LoginError};
use mas_templates::{ use mas_templates::{
FieldError, FormError, LoginContext, LoginFormField, TemplateContext, Templates, ToFormState, FieldError, FormError, LoginContext, LoginFormField, TemplateContext, Templates, ToFormState,
@ -160,10 +159,7 @@ async fn render(
} else { } else {
ctx ctx
}; };
let register_link = mas_router::Register::from(action.post_auth_action).relative_url(); let ctx = ctx.with_csrf(csrf_token.form_value());
let ctx = ctx
.with_register_link(register_link.to_string())
.with_csrf(csrf_token.form_value());
let content = templates.render_login(&ctx).await?; let content = templates.render_login(&ctx).await?;
Ok(content) Ok(content)

View File

@ -243,10 +243,7 @@ async fn render(
} else { } else {
ctx ctx
}; };
let login_link = mas_router::Login::from(action.post_auth_action).relative_url(); let ctx = ctx.with_csrf(csrf_token.form_value());
let ctx = ctx
.with_login_link(login_link.to_string())
.with_csrf(csrf_token.form_value());
let content = templates.render_register(&ctx).await?; let content = templates.render_register(&ctx).await?;
Ok(content) Ok(content)

View File

@ -16,7 +16,7 @@ use mas_router::{PostAuthAction, Route};
use mas_storage::{ use mas_storage::{
compat::get_compat_sso_login_by_id, oauth2::authorization_grant::get_grant_by_id, compat::get_compat_sso_login_by_id, oauth2::authorization_grant::get_grant_by_id,
}; };
use mas_templates::PostAuthContext; use mas_templates::{PostAuthContext, PostAuthContextInner};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::PgConnection; use sqlx::PgConnection;
@ -41,23 +41,24 @@ impl OptionalPostAuthAction {
&self, &self,
conn: &mut PgConnection, conn: &mut PgConnection,
) -> anyhow::Result<Option<PostAuthContext>> { ) -> anyhow::Result<Option<PostAuthContext>> {
match &self.post_auth_action { let Some(action) = self.post_auth_action.clone() else { return Ok(None) };
Some(PostAuthAction::ContinueAuthorizationGrant { data }) => { let ctx = match action {
let grant = get_grant_by_id(conn, *data).await?; PostAuthAction::ContinueAuthorizationGrant { data } => {
let grant = get_grant_by_id(conn, data).await?;
let grant = Box::new(grant.into()); let grant = Box::new(grant.into());
Ok(Some(PostAuthContext::ContinueAuthorizationGrant { grant })) PostAuthContextInner::ContinueAuthorizationGrant { grant }
} }
Some(PostAuthAction::ContinueCompatSsoLogin { data }) => { PostAuthAction::ContinueCompatSsoLogin { data } => {
let login = get_compat_sso_login_by_id(conn, *data).await?; let login = get_compat_sso_login_by_id(conn, data).await?;
let login = Box::new(login.into()); let login = Box::new(login.into());
Ok(Some(PostAuthContext::ContinueCompatSsoLogin { login })) PostAuthContextInner::ContinueCompatSsoLogin { login }
} }
Some(PostAuthAction::ChangePassword) => Ok(Some(PostAuthContext::ChangePassword)), PostAuthAction::ChangePassword => PostAuthContextInner::ChangePassword,
Some(PostAuthAction::LinkUpstream { id }) => { PostAuthAction::LinkUpstream { id } => {
let link = mas_storage::upstream_oauth2::lookup_link(&mut *conn, *id).await?; let link = mas_storage::upstream_oauth2::lookup_link(&mut *conn, id).await?;
let provider = let provider =
mas_storage::upstream_oauth2::lookup_provider(&mut *conn, link.provider_id) mas_storage::upstream_oauth2::lookup_provider(&mut *conn, link.provider_id)
@ -65,10 +66,13 @@ impl OptionalPostAuthAction {
let provider = Box::new(provider); let provider = Box::new(provider);
let link = Box::new(link); let link = Box::new(link);
Ok(Some(PostAuthContext::LinkUpstream { provider, link })) PostAuthContextInner::LinkUpstream { provider, link }
} }
};
None => Ok(None), Ok(Some(PostAuthContext {
} params: action.clone(),
ctx,
}))
} }
} }

View File

@ -532,17 +532,27 @@ impl Route for CompatLoginSsoComplete {
/// `GET /upstream/authorize/:id` /// `GET /upstream/authorize/:id`
pub struct UpstreamOAuth2Authorize { pub struct UpstreamOAuth2Authorize {
id: Ulid, id: Ulid,
post_auth_action: Option<PostAuthAction>,
} }
impl UpstreamOAuth2Authorize { impl UpstreamOAuth2Authorize {
#[must_use] #[must_use]
pub const fn new(id: Ulid) -> Self { pub const fn new(id: Ulid) -> Self {
Self { id } Self {
id,
post_auth_action: None,
}
}
#[must_use]
pub fn and_then(mut self, action: PostAuthAction) -> Self {
self.post_auth_action = Some(action);
self
} }
} }
impl Route for UpstreamOAuth2Authorize { impl Route for UpstreamOAuth2Authorize {
type Query = (); type Query = PostAuthAction;
fn route() -> &'static str { fn route() -> &'static str {
"/upstream/authorize/:provider_id" "/upstream/authorize/:provider_id"
} }
@ -550,6 +560,10 @@ impl Route for UpstreamOAuth2Authorize {
fn path(&self) -> std::borrow::Cow<'static, str> { fn path(&self) -> std::borrow::Cow<'static, str> {
format!("/upstream/authorize/{}", self.id).into() format!("/upstream/authorize/{}", self.id).into()
} }
fn query(&self) -> Option<&Self::Query> {
self.post_auth_action.as_ref()
}
} }
/// `GET /upstream/callback/:id` /// `GET /upstream/callback/:id`

View File

@ -20,6 +20,7 @@ serde_urlencoded = "0.7.1"
camino = "1.1.1" camino = "1.1.1"
chrono = "0.4.23" chrono = "0.4.23"
url = "2.3.1" url = "2.3.1"
http = "0.2.8"
ulid = { version = "1.0.0", features = ["serde"] } ulid = { version = "1.0.0", features = ["serde"] }
oauth2-types = { path = "../oauth2-types" } oauth2-types = { path = "../oauth2-types" }

View File

@ -244,10 +244,10 @@ impl FormField for LoginFormField {
} }
} }
/// Context used in login and reauth screens, for the post-auth action to do /// Inner context used in login and reauth screens. See [`PostAuthContext`].
#[derive(Serialize)] #[derive(Serialize)]
#[serde(tag = "kind", rename_all = "snake_case")] #[serde(tag = "kind", rename_all = "snake_case")]
pub enum PostAuthContext { pub enum PostAuthContextInner {
/// Continue an authorization grant /// Continue an authorization grant
ContinueAuthorizationGrant { ContinueAuthorizationGrant {
/// The authorization grant that will be continued after authentication /// The authorization grant that will be continued after authentication
@ -274,13 +274,23 @@ pub enum PostAuthContext {
}, },
} }
/// Context used in login and reauth screens, for the post-auth action to do
#[derive(Serialize)]
pub struct PostAuthContext {
/// The post auth action params from the URL
pub params: PostAuthAction,
/// The loaded post auth context
#[serde(flatten)]
pub ctx: PostAuthContextInner,
}
/// Context used by the `login.html` template /// Context used by the `login.html` template
#[derive(Serialize, Default)] #[derive(Serialize, Default)]
pub struct LoginContext { pub struct LoginContext {
form: FormState<LoginFormField>, form: FormState<LoginFormField>,
next: Option<PostAuthContext>, next: Option<PostAuthContext>,
providers: Vec<UpstreamOAuthProvider>, providers: Vec<UpstreamOAuthProvider>,
register_link: String,
} }
impl TemplateContext for LoginContext { impl TemplateContext for LoginContext {
@ -293,7 +303,6 @@ impl TemplateContext for LoginContext {
form: FormState::default(), form: FormState::default(),
next: None, next: None,
providers: Vec::new(), providers: Vec::new(),
register_link: "/register".to_owned(),
}] }]
} }
} }
@ -313,18 +322,9 @@ impl LoginContext {
/// Add a post authentication action to the context /// Add a post authentication action to the context
#[must_use] #[must_use]
pub fn with_post_action(self, next: PostAuthContext) -> Self { pub fn with_post_action(self, context: PostAuthContext) -> Self {
Self { Self {
next: Some(next), next: Some(context),
..self
}
}
/// Add a registration link to the context
#[must_use]
pub fn with_register_link(self, register_link: String) -> Self {
Self {
register_link,
..self ..self
} }
} }
@ -361,7 +361,6 @@ impl FormField for RegisterFormField {
pub struct RegisterContext { pub struct RegisterContext {
form: FormState<RegisterFormField>, form: FormState<RegisterFormField>,
next: Option<PostAuthContext>, next: Option<PostAuthContext>,
login_link: String,
} }
impl TemplateContext for RegisterContext { impl TemplateContext for RegisterContext {
@ -373,7 +372,6 @@ impl TemplateContext for RegisterContext {
vec![RegisterContext { vec![RegisterContext {
form: FormState::default(), form: FormState::default(),
next: None, next: None,
login_link: "/login".to_owned(),
}] }]
} }
} }
@ -393,12 +391,6 @@ impl RegisterContext {
..self ..self
} }
} }
/// Add a login link to the context
#[must_use]
pub fn with_login_link(self, login_link: String) -> Self {
Self { login_link, ..self }
}
} }
/// Context used by the `consent.html` template /// Context used by the `consent.html` template

View File

@ -22,7 +22,9 @@ use url::Url;
pub fn register(tera: &mut Tera, url_builder: UrlBuilder) { pub fn register(tera: &mut Tera, url_builder: UrlBuilder) {
tera.register_tester("empty", self::tester_empty); tera.register_tester("empty", self::tester_empty);
tera.register_function("add_params_to_uri", function_add_params_to_uri); tera.register_filter("to_params", filter_to_params);
tera.register_filter("safe_get", filter_safe_get);
tera.register_function("add_params_to_url", function_add_params_to_url);
tera.register_function("merge", function_merge); tera.register_function("merge", function_merge);
tera.register_function("dict", function_dict); tera.register_function("dict", function_dict);
tera.register_function("static_asset", make_static_asset(url_builder)); tera.register_function("static_asset", make_static_asset(url_builder));
@ -37,12 +39,42 @@ fn tester_empty(value: Option<&Value>, params: &[Value]) -> Result<bool, tera::E
} }
} }
fn filter_to_params(params: &Value, kv: &HashMap<String, Value>) -> Result<Value, tera::Error> {
let prefix = kv.get("prefix").and_then(Value::as_str).unwrap_or("");
let params = serde_urlencoded::to_string(params)
.map_err(|e| tera::Error::chain(e, "Could not serialize parameters"))?;
if params.is_empty() {
Ok(Value::String(String::new()))
} else {
Ok(Value::String(format!("{prefix}{params}")))
}
}
/// Alternative to `get` which does not crash on `None` and defaults to `None`
pub fn filter_safe_get(value: &Value, args: &HashMap<String, Value>) -> Result<Value, tera::Error> {
let default = args.get("default").unwrap_or(&Value::Null);
let key = args
.get("key")
.and_then(Value::as_str)
.ok_or_else(|| tera::Error::msg("Invalid parameter `uri`"))?;
match value.as_object() {
Some(o) => match o.get(key) {
Some(val) => Ok(val.clone()),
// If the value is not present, allow for an optional default value
None => Ok(default.clone()),
},
None => Ok(default.clone()),
}
}
enum ParamsWhere { enum ParamsWhere {
Fragment, Fragment,
Query, Query,
} }
fn function_add_params_to_uri(params: &HashMap<String, Value>) -> Result<Value, tera::Error> { fn function_add_params_to_url(params: &HashMap<String, Value>) -> Result<Value, tera::Error> {
use ParamsWhere::{Fragment, Query}; use ParamsWhere::{Fragment, Query};
// First, get the `uri`, `mode` and `params` parameters // First, get the `uri`, `mode` and `params` parameters
@ -77,12 +109,7 @@ fn function_add_params_to_uri(params: &HashMap<String, Value>) -> Result<Value,
.unwrap_or_default(); .unwrap_or_default();
// Merge the exising and the additional parameters together // Merge the exising and the additional parameters together
let params: HashMap<&String, &Value> = params let params: HashMap<&String, &Value> = params.iter().chain(existing.iter()).collect();
.iter()
// Filter out the `uri` and `mode` params
.filter(|(k, _v)| k != &"uri" && k != &"mode")
.chain(existing.iter())
.collect();
// Transform them back to urlencoded // Transform them back to urlencoded
let params = serde_urlencoded::to_string(params) let params = serde_urlencoded::to_string(params)

View File

@ -48,9 +48,9 @@ pub use self::{
AccountContext, AccountEmailsContext, CompatSsoContext, ConsentContext, EmailAddContext, AccountContext, AccountEmailsContext, CompatSsoContext, ConsentContext, EmailAddContext,
EmailVerificationContext, EmailVerificationPageContext, EmptyContext, ErrorContext, EmailVerificationContext, EmailVerificationPageContext, EmptyContext, ErrorContext,
FormPostContext, IndexContext, LoginContext, LoginFormField, PolicyViolationContext, FormPostContext, IndexContext, LoginContext, LoginFormField, PolicyViolationContext,
PostAuthContext, ReauthContext, ReauthFormField, RegisterContext, RegisterFormField, PostAuthContext, PostAuthContextInner, ReauthContext, ReauthFormField, RegisterContext,
TemplateContext, UpstreamExistingLinkContext, UpstreamRegister, UpstreamSuggestLink, RegisterFormField, TemplateContext, UpstreamExistingLinkContext, UpstreamRegister,
WithCsrf, WithOptionalSession, WithSession, UpstreamSuggestLink, WithCsrf, WithOptionalSession, WithSession,
}, },
forms::{FieldError, FormError, FormField, FormState, ToFormState}, forms::{FieldError, FormError, FormField, FormState, ToFormState},
}; };

View File

@ -23,7 +23,7 @@ limitations under the License.
<button class="{{ class }}" type="submit">{{ text }}</button> <button class="{{ class }}" type="submit">{{ text }}</button>
</form> </form>
{% elif mode == "fragment" or mode == "query" %} {% elif mode == "fragment" or mode == "query" %}
<a class="{{ class }}" href="{{ add_params_to_uri(uri=uri, mode=mode, params=params) }}">{{ text }}</a> <a class="{{ class }}" href="{{ add_params_to_url(uri=uri, mode=mode, params=params) }}">{{ text }}</a>
{% else %} {% else %}
{{ throw(message="Invalid mode") }} {{ throw(message="Invalid mode") }}
{% endif %} {% endif %}

View File

@ -62,7 +62,8 @@ limitations under the License.
{% if not next or next.kind != "link_upstream" %} {% if not next or next.kind != "link_upstream" %}
<div class="text-center mt-4"> <div class="text-center mt-4">
Don't have an account yet? Don't have an account yet?
{{ button::link_text(text="Create an account", href=register_link) }} {% set params = next | safe_get(key="params") | to_params(prefix="?") %}
{{ button::link_text(text="Create an account", href="/register" ~ params) }}
</div> </div>
{% endif %} {% endif %}
@ -74,7 +75,8 @@ limitations under the License.
</div> </div>
{% for provider in providers %} {% for provider in providers %}
{{ button::link(text="Continue with " ~ provider.issuer, href="/upstream/authorize/" ~ provider.id) }} {% set params = next | safe_get(key="params") | to_params(prefix="?") %}
{{ button::link(text="Continue with " ~ provider.issuer, href="/upstream/authorize/" ~ provider.id ~ params) }}
{% endfor %} {% endfor %}
{% endif %} {% endif %}
</form> </form>

View File

@ -55,8 +55,8 @@ limitations under the License.
{% endif %} {% endif %}
<div class="text-center mt-4"> <div class="text-center mt-4">
Already have an account? Already have an account?
{# TODO: proper link #} {% set params = next | safe_get(key="params") | to_params(prefix="?") %}
{{ button::link_text(text="Sign in instead", href=login_link) }} {{ button::link_text(text="Sign in instead", href="/login" ~ params) }}
</div> </div>
</form> </form>
</section> </section>