1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

template: more cleanups

This commit is contained in:
Quentin Gliech
2022-12-08 14:43:46 +01:00
parent 13b1ac7c83
commit 0ea9089f7f
10 changed files with 79 additions and 93 deletions

1
Cargo.lock generated
View File

@ -3061,7 +3061,6 @@ dependencies = [
"axum 0.6.1", "axum 0.6.1",
"serde", "serde",
"serde_urlencoded", "serde_urlencoded",
"serde_with",
"ulid", "ulid",
"url", "url",
] ]

View File

@ -107,7 +107,7 @@ pub async fn get(
return Ok((cookie_jar, Html(content)).into_response()); return Ok((cookie_jar, Html(content)).into_response());
} }
let ctx = CompatSsoContext::new(login, PostAuthAction::continue_compat_sso_login(id)) let ctx = CompatSsoContext::new(login)
.with_session(session) .with_session(session)
.with_csrf(csrf_token.form_value()); .with_csrf(csrf_token.form_value());

View File

@ -101,7 +101,7 @@ pub(crate) async fn get(
.await?; .await?;
if res.valid() { if res.valid() {
let ctx = ConsentContext::new(grant, PostAuthAction::continue_grant(grant_id)) let ctx = ConsentContext::new(grant)
.with_session(session) .with_session(session)
.with_csrf(csrf_token.form_value()); .with_csrf(csrf_token.form_value());
@ -109,7 +109,7 @@ pub(crate) async fn get(
Ok((cookie_jar, Html(content)).into_response()) Ok((cookie_jar, Html(content)).into_response())
} else { } else {
let ctx = PolicyViolationContext::new(grant, PostAuthAction::continue_grant(grant_id)) let ctx = PolicyViolationContext::new(grant)
.with_session(session) .with_session(session)
.with_csrf(csrf_token.form_value()); .with_csrf(csrf_token.form_value());

View File

@ -163,7 +163,7 @@ pub(crate) async fn get(
(Some(user_session), None) => { (Some(user_session), None) => {
// Session not linked, but user logged in: suggest linking account // Session not linked, but user logged in: suggest linking account
let ctx = UpstreamSuggestLink::new(link.id) let ctx = UpstreamSuggestLink::new(&link)
.with_session(user_session) .with_session(user_session)
.with_csrf(csrf_token.form_value()); .with_csrf(csrf_token.form_value());
@ -182,7 +182,7 @@ pub(crate) async fn get(
(None, None) => { (None, None) => {
// Session not linked and used not logged in: suggest creating an // Session not linked and used not logged in: suggest creating an
// account or logging in an existing user // account or logging in an existing user
let ctx = UpstreamRegister::new(link.id).with_csrf(csrf_token.form_value()); let ctx = UpstreamRegister::new(&link).with_csrf(csrf_token.form_value());
templates.render_upstream_oauth2_do_register(&ctx).await? templates.render_upstream_oauth2_do_register(&ctx).await?
} }

View File

@ -61,10 +61,7 @@ pub(crate) async fn get(
let ctx = ReauthContext::default(); let ctx = ReauthContext::default();
let next = query.load_context(&mut conn).await?; let next = query.load_context(&mut conn).await?;
let ctx = if let Some(next) = next { let ctx = if let Some(next) = next {
// SAFETY: we should have an action only if we have a "next" context ctx.with_post_action(next)
// TODO: make that cleaner
let action = query.post_auth_action.unwrap();
ctx.with_post_action(next, action)
} else { } else {
ctx ctx
}; };

View File

@ -38,22 +38,22 @@ impl OptionalPostAuthAction {
self.go_next_or_default(&mas_router::Index) self.go_next_or_default(&mas_router::Index)
} }
pub async fn load_context<'e>( pub async fn load_context(
&self, &self,
conn: &mut PgConnection, conn: &mut PgConnection,
) -> anyhow::Result<Option<PostAuthContext>> { ) -> anyhow::Result<Option<PostAuthContext>> {
let Some(action) = self.post_auth_action.clone() else { return Ok(None) }; let Some(action) = self.post_auth_action.clone() else { return Ok(None) };
let ctx = match action { let ctx = match action {
PostAuthAction::ContinueAuthorizationGrant { data } => { PostAuthAction::ContinueAuthorizationGrant { id } => {
let grant = get_grant_by_id(conn, data) let grant = get_grant_by_id(conn, id)
.await? .await?
.context("Failed to load authorization grant")?; .context("Failed to load authorization grant")?;
let grant = Box::new(grant); let grant = Box::new(grant);
PostAuthContextInner::ContinueAuthorizationGrant { grant } PostAuthContextInner::ContinueAuthorizationGrant { grant }
} }
PostAuthAction::ContinueCompatSsoLogin { data } => { PostAuthAction::ContinueCompatSsoLogin { id } => {
let login = get_compat_sso_login_by_id(conn, data) let login = get_compat_sso_login_by_id(conn, id)
.await? .await?
.context("Failed to load compat SSO login")?; .context("Failed to load compat SSO login")?;
let login = Box::new(login); let login = Box::new(login);

View File

@ -9,6 +9,5 @@ license = "Apache-2.0"
axum = { version = "0.6.1", default-features = false } axum = { version = "0.6.1", default-features = false }
serde = { version = "1.0.149", features = ["derive"] } serde = { version = "1.0.149", features = ["derive"] }
serde_urlencoded = "0.7.1" serde_urlencoded = "0.7.1"
serde_with = "2.1.0"
url = "2.3.1" url = "2.3.1"
ulid = "1.0.0" ulid = "1.0.0"

View File

@ -13,45 +13,39 @@
// limitations under the License. // limitations under the License.
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_with::{serde_as, DisplayFromStr};
use ulid::Ulid; use ulid::Ulid;
pub use crate::traits::*; pub use crate::traits::*;
#[serde_as]
#[derive(Deserialize, Serialize, Clone, Debug)] #[derive(Deserialize, Serialize, Clone, Debug)]
#[serde(rename_all = "snake_case", tag = "next")] #[serde(rename_all = "snake_case", tag = "next")]
pub enum PostAuthAction { pub enum PostAuthAction {
ContinueAuthorizationGrant { ContinueAuthorizationGrant { id: Ulid },
#[serde_as(as = "DisplayFromStr")] ContinueCompatSsoLogin { id: Ulid },
data: Ulid,
},
ContinueCompatSsoLogin {
#[serde_as(as = "DisplayFromStr")]
data: Ulid,
},
ChangePassword, ChangePassword,
LinkUpstream { LinkUpstream { id: Ulid },
#[serde_as(as = "DisplayFromStr")]
id: Ulid,
},
} }
impl PostAuthAction { impl PostAuthAction {
#[must_use] #[must_use]
pub fn continue_grant(data: Ulid) -> Self { pub const fn continue_grant(id: Ulid) -> Self {
PostAuthAction::ContinueAuthorizationGrant { data } PostAuthAction::ContinueAuthorizationGrant { id }
} }
#[must_use] #[must_use]
pub fn continue_compat_sso_login(data: Ulid) -> Self { pub const fn continue_compat_sso_login(id: Ulid) -> Self {
PostAuthAction::ContinueCompatSsoLogin { data } PostAuthAction::ContinueCompatSsoLogin { id }
}
#[must_use]
pub const fn link_upstream(id: Ulid) -> Self {
PostAuthAction::LinkUpstream { id }
} }
pub fn go_next(&self) -> axum::response::Redirect { pub fn go_next(&self) -> axum::response::Redirect {
match self { match self {
Self::ContinueAuthorizationGrant { data } => ContinueAuthorizationGrant(*data).go(), Self::ContinueAuthorizationGrant { id } => ContinueAuthorizationGrant(*id).go(),
Self::ContinueCompatSsoLogin { data } => CompatLoginSsoComplete::new(*data, None).go(), Self::ContinueCompatSsoLogin { id } => CompatLoginSsoComplete::new(*id, None).go(),
Self::ChangePassword => AccountPassword.go(), Self::ChangePassword => AccountPassword.go(),
Self::LinkUpstream { id } => UpstreamOAuth2Link::new(*id).go(), Self::LinkUpstream { id } => UpstreamOAuth2Link::new(*id).go(),
} }
@ -165,23 +159,30 @@ impl Route for Login {
impl Login { impl Login {
#[must_use] #[must_use]
pub fn and_then(action: PostAuthAction) -> Self { pub const fn and_then(action: PostAuthAction) -> Self {
Self { Self {
post_auth_action: Some(action), post_auth_action: Some(action),
} }
} }
#[must_use] #[must_use]
pub fn and_continue_grant(data: Ulid) -> Self { pub const fn and_continue_grant(id: Ulid) -> Self {
Self { Self {
post_auth_action: Some(PostAuthAction::continue_grant(data)), post_auth_action: Some(PostAuthAction::continue_grant(id)),
} }
} }
#[must_use] #[must_use]
pub fn and_continue_compat_sso_login(data: Ulid) -> Self { pub const fn and_continue_compat_sso_login(id: Ulid) -> Self {
Self { Self {
post_auth_action: Some(PostAuthAction::continue_compat_sso_login(data)), post_auth_action: Some(PostAuthAction::continue_compat_sso_login(id)),
}
}
#[must_use]
pub const fn and_link_upstream(id: Ulid) -> Self {
Self {
post_auth_action: Some(PostAuthAction::link_upstream(id)),
} }
} }

View File

@ -411,7 +411,8 @@ impl TemplateContext for ConsentContext {
impl ConsentContext { impl ConsentContext {
/// Constructs a context for the client consent page /// Constructs a context for the client consent page
#[must_use] #[must_use]
pub fn new(grant: AuthorizationGrant, action: PostAuthAction) -> Self { pub fn new(grant: AuthorizationGrant) -> Self {
let action = PostAuthAction::continue_grant(grant.id);
Self { grant, action } Self { grant, action }
} }
} }
@ -436,7 +437,8 @@ impl TemplateContext for PolicyViolationContext {
impl PolicyViolationContext { impl PolicyViolationContext {
/// Constructs a context for the policy violation page /// Constructs a context for the policy violation page
#[must_use] #[must_use]
pub fn new(grant: AuthorizationGrant, action: PostAuthAction) -> Self { pub const fn new(grant: AuthorizationGrant) -> Self {
let action = PostAuthAction::continue_grant(grant.id);
Self { grant, action } Self { grant, action }
} }
} }
@ -462,7 +464,6 @@ impl FormField for ReauthFormField {
pub struct ReauthContext { pub struct ReauthContext {
form: FormState<ReauthFormField>, form: FormState<ReauthFormField>,
next: Option<PostAuthContext>, next: Option<PostAuthContext>,
action: Option<PostAuthAction>,
} }
impl TemplateContext for ReauthContext { impl TemplateContext for ReauthContext {
@ -474,7 +475,6 @@ impl TemplateContext for ReauthContext {
vec![ReauthContext { vec![ReauthContext {
form: FormState::default(), form: FormState::default(),
next: None, next: None,
action: None,
}] }]
} }
} }
@ -488,10 +488,9 @@ impl ReauthContext {
/// Add a post authentication action to the context /// Add a post authentication action to the context
#[must_use] #[must_use]
pub fn with_post_action(self, next: PostAuthContext, action: PostAuthAction) -> Self { pub fn with_post_action(self, next: PostAuthContext) -> Self {
Self { Self {
next: Some(next), next: Some(next),
action: Some(action),
..self ..self
} }
} }
@ -510,24 +509,22 @@ impl TemplateContext for CompatSsoContext {
Self: Sized, Self: Sized,
{ {
let id = Ulid::from_datetime_with_source(now.into(), rng); let id = Ulid::from_datetime_with_source(now.into(), rng);
vec![CompatSsoContext { vec![CompatSsoContext::new(CompatSsoLogin {
login: CompatSsoLogin { id,
id, redirect_uri: Url::parse("https://app.element.io/").unwrap(),
redirect_uri: Url::parse("https://app.element.io/").unwrap(), login_token: "abcdefghijklmnopqrstuvwxyz012345".into(),
login_token: "abcdefghijklmnopqrstuvwxyz012345".into(), created_at: now,
created_at: now, state: CompatSsoLoginState::Pending,
state: CompatSsoLoginState::Pending, })]
},
action: PostAuthAction::ContinueCompatSsoLogin { data: id },
}]
} }
} }
impl CompatSsoContext { impl CompatSsoContext {
/// Constructs a context for the legacy SSO login page /// Constructs a context for the legacy SSO login page
#[must_use] #[must_use]
pub fn new(login: CompatSsoLogin, action: PostAuthAction) -> Self pub fn new(login: CompatSsoLogin) -> Self
where { where {
let action = PostAuthAction::continue_compat_sso_login(login.id);
Self { login, action } Self { login, action }
} }
} }
@ -654,13 +651,10 @@ pub struct EmailVerificationPageContext {
impl EmailVerificationPageContext { impl EmailVerificationPageContext {
/// Constructs a context for the email verification page /// Constructs a context for the email verification page
#[must_use] #[must_use]
pub fn new<T>(email: T) -> Self pub fn new(email: UserEmail) -> Self {
where
T: Into<UserEmail>,
{
Self { Self {
form: FormState::default(), form: FormState::default(),
email: email.into(), email,
} }
} }
@ -744,13 +738,9 @@ pub struct UpstreamExistingLinkContext {
impl UpstreamExistingLinkContext { impl UpstreamExistingLinkContext {
/// Constructs a new context with an existing linked user /// Constructs a new context with an existing linked user
pub fn new<T>(linked_user: T) -> Self #[must_use]
where pub fn new(linked_user: User) -> Self {
T: Into<User>, Self { linked_user }
{
Self {
linked_user: linked_user.into(),
}
} }
} }
@ -776,18 +766,23 @@ pub struct UpstreamSuggestLink {
impl UpstreamSuggestLink { impl UpstreamSuggestLink {
/// Constructs a new context with an existing linked user /// Constructs a new context with an existing linked user
#[must_use] #[must_use]
pub fn new(link_id: Ulid) -> Self { pub fn new(link: &UpstreamOAuthLink) -> Self {
let post_logout_action = PostAuthAction::LinkUpstream { id: link_id }; Self::for_link_id(link.id)
}
fn for_link_id(id: Ulid) -> Self {
let post_logout_action = PostAuthAction::link_upstream(id);
Self { post_logout_action } Self { post_logout_action }
} }
} }
impl TemplateContext for UpstreamSuggestLink { impl TemplateContext for UpstreamSuggestLink {
fn sample(_now: chrono::DateTime<Utc>, _rng: &mut impl Rng) -> Vec<Self> fn sample(now: chrono::DateTime<Utc>, rng: &mut impl Rng) -> Vec<Self>
where where
Self: Sized, Self: Sized,
{ {
vec![Self::new(Ulid::nil())] let id = Ulid::from_datetime_with_source(now.into(), rng);
vec![Self::for_link_id(id)]
} }
} }
@ -801,19 +796,26 @@ pub struct UpstreamRegister {
impl UpstreamRegister { impl UpstreamRegister {
/// Constructs a new context with an existing linked user /// Constructs a new context with an existing linked user
#[must_use] #[must_use]
pub fn new(link_id: Ulid) -> Self { pub fn new(link: &UpstreamOAuthLink) -> Self {
let action = PostAuthAction::LinkUpstream { id: link_id }; Self::for_link_id(link.id)
let login_link = mas_router::Login::and_then(action).relative_url().into(); }
fn for_link_id(id: Ulid) -> Self {
let login_link = mas_router::Login::and_link_upstream(id)
.relative_url()
.into();
Self { login_link } Self { login_link }
} }
} }
impl TemplateContext for UpstreamRegister { impl TemplateContext for UpstreamRegister {
fn sample(_now: chrono::DateTime<Utc>, _rng: &mut impl Rng) -> Vec<Self> fn sample(now: chrono::DateTime<Utc>, rng: &mut impl Rng) -> Vec<Self>
where where
Self: Sized, Self: Sized,
{ {
vec![Self::new(Ulid::nil())] let id = Ulid::from_datetime_with_source(now.into(), rng);
vec![Self::for_link_id(id)]
} }
} }

View File

@ -46,22 +46,10 @@ limitations under the License.
</form> </form>
<div class="text-center mt-4"> <div class="text-center mt-4">
Not {{ current_session.user.username }}? Not {{ current_session.user.username }}?
{{ logout::button(text="Sign out", class=button::text_class(), csrf_token=csrf_token, post_logout_action=action) }} {% set post_logout_action = next | safe_get(key="params") %}
{{ logout::button(text="Sign out", class=button::text_class(), csrf_token=csrf_token, post_logout_action=post_logout_action) }}
</div> </div>
</div> </div>
</section> </section>
<!-- <div class="flex justify-center">
<div class="w-96 m-2">
<h3 class="title is-3">Current session data:</h3>
<pre class="text-sm whitespace-pre-wrap"><code>{{ current_session | json_encode(pretty=True) | safe }}</code></pre>
</div>
{% if next %}
<div class="w-96 m-2">
<h3 class="title is-3">Next action:</h3>
<pre class="text-sm whitespace-pre-wrap"><code>{{ next | json_encode(pretty=True) | safe }}</code></pre>
</div>
{% endif %}
</div> -->
{% endblock content %} {% endblock content %}