diff --git a/Cargo.lock b/Cargo.lock index 73bf436b..c56eb7ee 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3061,7 +3061,6 @@ dependencies = [ "axum 0.6.1", "serde", "serde_urlencoded", - "serde_with", "ulid", "url", ] diff --git a/crates/handlers/src/compat/login_sso_complete.rs b/crates/handlers/src/compat/login_sso_complete.rs index ee51daf8..e0416cd1 100644 --- a/crates/handlers/src/compat/login_sso_complete.rs +++ b/crates/handlers/src/compat/login_sso_complete.rs @@ -107,7 +107,7 @@ pub async fn get( return Ok((cookie_jar, Html(content)).into_response()); } - let ctx = CompatSsoContext::new(login, PostAuthAction::continue_compat_sso_login(id)) + let ctx = CompatSsoContext::new(login) .with_session(session) .with_csrf(csrf_token.form_value()); diff --git a/crates/handlers/src/oauth2/consent.rs b/crates/handlers/src/oauth2/consent.rs index 4daf43f3..45107783 100644 --- a/crates/handlers/src/oauth2/consent.rs +++ b/crates/handlers/src/oauth2/consent.rs @@ -101,7 +101,7 @@ pub(crate) async fn get( .await?; if res.valid() { - let ctx = ConsentContext::new(grant, PostAuthAction::continue_grant(grant_id)) + let ctx = ConsentContext::new(grant) .with_session(session) .with_csrf(csrf_token.form_value()); @@ -109,7 +109,7 @@ pub(crate) async fn get( Ok((cookie_jar, Html(content)).into_response()) } else { - let ctx = PolicyViolationContext::new(grant, PostAuthAction::continue_grant(grant_id)) + let ctx = PolicyViolationContext::new(grant) .with_session(session) .with_csrf(csrf_token.form_value()); diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index 97badffc..5aeb8978 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -163,7 +163,7 @@ pub(crate) async fn get( (Some(user_session), None) => { // Session not linked, but user logged in: suggest linking account - let ctx = UpstreamSuggestLink::new(link.id) + let ctx = UpstreamSuggestLink::new(&link) .with_session(user_session) .with_csrf(csrf_token.form_value()); @@ -182,7 +182,7 @@ pub(crate) async fn get( (None, None) => { // Session not linked and used not logged in: suggest creating an // account or logging in an existing user - let ctx = UpstreamRegister::new(link.id).with_csrf(csrf_token.form_value()); + let ctx = UpstreamRegister::new(&link).with_csrf(csrf_token.form_value()); templates.render_upstream_oauth2_do_register(&ctx).await? } diff --git a/crates/handlers/src/views/reauth.rs b/crates/handlers/src/views/reauth.rs index a7b9cb33..1b546ee1 100644 --- a/crates/handlers/src/views/reauth.rs +++ b/crates/handlers/src/views/reauth.rs @@ -61,10 +61,7 @@ pub(crate) async fn get( let ctx = ReauthContext::default(); let next = query.load_context(&mut conn).await?; let ctx = if let Some(next) = next { - // SAFETY: we should have an action only if we have a "next" context - // TODO: make that cleaner - let action = query.post_auth_action.unwrap(); - ctx.with_post_action(next, action) + ctx.with_post_action(next) } else { ctx }; diff --git a/crates/handlers/src/views/shared.rs b/crates/handlers/src/views/shared.rs index 941040ce..fcdef3b4 100644 --- a/crates/handlers/src/views/shared.rs +++ b/crates/handlers/src/views/shared.rs @@ -38,22 +38,22 @@ impl OptionalPostAuthAction { self.go_next_or_default(&mas_router::Index) } - pub async fn load_context<'e>( + pub async fn load_context( &self, conn: &mut PgConnection, ) -> anyhow::Result> { let Some(action) = self.post_auth_action.clone() else { return Ok(None) }; let ctx = match action { - PostAuthAction::ContinueAuthorizationGrant { data } => { - let grant = get_grant_by_id(conn, data) + PostAuthAction::ContinueAuthorizationGrant { id } => { + let grant = get_grant_by_id(conn, id) .await? .context("Failed to load authorization grant")?; let grant = Box::new(grant); PostAuthContextInner::ContinueAuthorizationGrant { grant } } - PostAuthAction::ContinueCompatSsoLogin { data } => { - let login = get_compat_sso_login_by_id(conn, data) + PostAuthAction::ContinueCompatSsoLogin { id } => { + let login = get_compat_sso_login_by_id(conn, id) .await? .context("Failed to load compat SSO login")?; let login = Box::new(login); diff --git a/crates/router/Cargo.toml b/crates/router/Cargo.toml index a77c3921..3578d8e4 100644 --- a/crates/router/Cargo.toml +++ b/crates/router/Cargo.toml @@ -9,6 +9,5 @@ license = "Apache-2.0" axum = { version = "0.6.1", default-features = false } serde = { version = "1.0.149", features = ["derive"] } serde_urlencoded = "0.7.1" -serde_with = "2.1.0" url = "2.3.1" ulid = "1.0.0" diff --git a/crates/router/src/endpoints.rs b/crates/router/src/endpoints.rs index 5d11f939..dbfe301a 100644 --- a/crates/router/src/endpoints.rs +++ b/crates/router/src/endpoints.rs @@ -13,45 +13,39 @@ // limitations under the License. use serde::{Deserialize, Serialize}; -use serde_with::{serde_as, DisplayFromStr}; use ulid::Ulid; pub use crate::traits::*; -#[serde_as] #[derive(Deserialize, Serialize, Clone, Debug)] #[serde(rename_all = "snake_case", tag = "next")] pub enum PostAuthAction { - ContinueAuthorizationGrant { - #[serde_as(as = "DisplayFromStr")] - data: Ulid, - }, - ContinueCompatSsoLogin { - #[serde_as(as = "DisplayFromStr")] - data: Ulid, - }, + ContinueAuthorizationGrant { id: Ulid }, + ContinueCompatSsoLogin { id: Ulid }, ChangePassword, - LinkUpstream { - #[serde_as(as = "DisplayFromStr")] - id: Ulid, - }, + LinkUpstream { id: Ulid }, } impl PostAuthAction { #[must_use] - pub fn continue_grant(data: Ulid) -> Self { - PostAuthAction::ContinueAuthorizationGrant { data } + pub const fn continue_grant(id: Ulid) -> Self { + PostAuthAction::ContinueAuthorizationGrant { id } } #[must_use] - pub fn continue_compat_sso_login(data: Ulid) -> Self { - PostAuthAction::ContinueCompatSsoLogin { data } + pub const fn continue_compat_sso_login(id: Ulid) -> Self { + PostAuthAction::ContinueCompatSsoLogin { id } + } + + #[must_use] + pub const fn link_upstream(id: Ulid) -> Self { + PostAuthAction::LinkUpstream { id } } pub fn go_next(&self) -> axum::response::Redirect { match self { - Self::ContinueAuthorizationGrant { data } => ContinueAuthorizationGrant(*data).go(), - Self::ContinueCompatSsoLogin { data } => CompatLoginSsoComplete::new(*data, None).go(), + Self::ContinueAuthorizationGrant { id } => ContinueAuthorizationGrant(*id).go(), + Self::ContinueCompatSsoLogin { id } => CompatLoginSsoComplete::new(*id, None).go(), Self::ChangePassword => AccountPassword.go(), Self::LinkUpstream { id } => UpstreamOAuth2Link::new(*id).go(), } @@ -165,23 +159,30 @@ impl Route for Login { impl Login { #[must_use] - pub fn and_then(action: PostAuthAction) -> Self { + pub const fn and_then(action: PostAuthAction) -> Self { Self { post_auth_action: Some(action), } } #[must_use] - pub fn and_continue_grant(data: Ulid) -> Self { + pub const fn and_continue_grant(id: Ulid) -> Self { Self { - post_auth_action: Some(PostAuthAction::continue_grant(data)), + post_auth_action: Some(PostAuthAction::continue_grant(id)), } } #[must_use] - pub fn and_continue_compat_sso_login(data: Ulid) -> Self { + pub const fn and_continue_compat_sso_login(id: Ulid) -> Self { Self { - post_auth_action: Some(PostAuthAction::continue_compat_sso_login(data)), + post_auth_action: Some(PostAuthAction::continue_compat_sso_login(id)), + } + } + + #[must_use] + pub const fn and_link_upstream(id: Ulid) -> Self { + Self { + post_auth_action: Some(PostAuthAction::link_upstream(id)), } } diff --git a/crates/templates/src/context.rs b/crates/templates/src/context.rs index 1aa5d33c..447ad4ee 100644 --- a/crates/templates/src/context.rs +++ b/crates/templates/src/context.rs @@ -411,7 +411,8 @@ impl TemplateContext for ConsentContext { impl ConsentContext { /// Constructs a context for the client consent page #[must_use] - pub fn new(grant: AuthorizationGrant, action: PostAuthAction) -> Self { + pub fn new(grant: AuthorizationGrant) -> Self { + let action = PostAuthAction::continue_grant(grant.id); Self { grant, action } } } @@ -436,7 +437,8 @@ impl TemplateContext for PolicyViolationContext { impl PolicyViolationContext { /// Constructs a context for the policy violation page #[must_use] - pub fn new(grant: AuthorizationGrant, action: PostAuthAction) -> Self { + pub const fn new(grant: AuthorizationGrant) -> Self { + let action = PostAuthAction::continue_grant(grant.id); Self { grant, action } } } @@ -462,7 +464,6 @@ impl FormField for ReauthFormField { pub struct ReauthContext { form: FormState, next: Option, - action: Option, } impl TemplateContext for ReauthContext { @@ -474,7 +475,6 @@ impl TemplateContext for ReauthContext { vec![ReauthContext { form: FormState::default(), next: None, - action: None, }] } } @@ -488,10 +488,9 @@ impl ReauthContext { /// Add a post authentication action to the context #[must_use] - pub fn with_post_action(self, next: PostAuthContext, action: PostAuthAction) -> Self { + pub fn with_post_action(self, next: PostAuthContext) -> Self { Self { next: Some(next), - action: Some(action), ..self } } @@ -510,24 +509,22 @@ impl TemplateContext for CompatSsoContext { Self: Sized, { let id = Ulid::from_datetime_with_source(now.into(), rng); - vec![CompatSsoContext { - login: CompatSsoLogin { - id, - redirect_uri: Url::parse("https://app.element.io/").unwrap(), - login_token: "abcdefghijklmnopqrstuvwxyz012345".into(), - created_at: now, - state: CompatSsoLoginState::Pending, - }, - action: PostAuthAction::ContinueCompatSsoLogin { data: id }, - }] + vec![CompatSsoContext::new(CompatSsoLogin { + id, + redirect_uri: Url::parse("https://app.element.io/").unwrap(), + login_token: "abcdefghijklmnopqrstuvwxyz012345".into(), + created_at: now, + state: CompatSsoLoginState::Pending, + })] } } impl CompatSsoContext { /// Constructs a context for the legacy SSO login page #[must_use] - pub fn new(login: CompatSsoLogin, action: PostAuthAction) -> Self + pub fn new(login: CompatSsoLogin) -> Self where { + let action = PostAuthAction::continue_compat_sso_login(login.id); Self { login, action } } } @@ -654,13 +651,10 @@ pub struct EmailVerificationPageContext { impl EmailVerificationPageContext { /// Constructs a context for the email verification page #[must_use] - pub fn new(email: T) -> Self - where - T: Into, - { + pub fn new(email: UserEmail) -> Self { Self { form: FormState::default(), - email: email.into(), + email, } } @@ -744,13 +738,9 @@ pub struct UpstreamExistingLinkContext { impl UpstreamExistingLinkContext { /// Constructs a new context with an existing linked user - pub fn new(linked_user: T) -> Self - where - T: Into, - { - Self { - linked_user: linked_user.into(), - } + #[must_use] + pub fn new(linked_user: User) -> Self { + Self { linked_user } } } @@ -776,18 +766,23 @@ pub struct UpstreamSuggestLink { impl UpstreamSuggestLink { /// Constructs a new context with an existing linked user #[must_use] - pub fn new(link_id: Ulid) -> Self { - let post_logout_action = PostAuthAction::LinkUpstream { id: link_id }; + pub fn new(link: &UpstreamOAuthLink) -> Self { + Self::for_link_id(link.id) + } + + fn for_link_id(id: Ulid) -> Self { + let post_logout_action = PostAuthAction::link_upstream(id); Self { post_logout_action } } } impl TemplateContext for UpstreamSuggestLink { - fn sample(_now: chrono::DateTime, _rng: &mut impl Rng) -> Vec + fn sample(now: chrono::DateTime, rng: &mut impl Rng) -> Vec where Self: Sized, { - vec![Self::new(Ulid::nil())] + let id = Ulid::from_datetime_with_source(now.into(), rng); + vec![Self::for_link_id(id)] } } @@ -801,19 +796,26 @@ pub struct UpstreamRegister { impl UpstreamRegister { /// Constructs a new context with an existing linked user #[must_use] - pub fn new(link_id: Ulid) -> Self { - let action = PostAuthAction::LinkUpstream { id: link_id }; - let login_link = mas_router::Login::and_then(action).relative_url().into(); + pub fn new(link: &UpstreamOAuthLink) -> Self { + Self::for_link_id(link.id) + } + + fn for_link_id(id: Ulid) -> Self { + let login_link = mas_router::Login::and_link_upstream(id) + .relative_url() + .into(); + Self { login_link } } } impl TemplateContext for UpstreamRegister { - fn sample(_now: chrono::DateTime, _rng: &mut impl Rng) -> Vec + fn sample(now: chrono::DateTime, rng: &mut impl Rng) -> Vec where Self: Sized, { - vec![Self::new(Ulid::nil())] + let id = Ulid::from_datetime_with_source(now.into(), rng); + vec![Self::for_link_id(id)] } } diff --git a/templates/pages/reauth.html b/templates/pages/reauth.html index 746f792b..c919af1e 100644 --- a/templates/pages/reauth.html +++ b/templates/pages/reauth.html @@ -46,22 +46,10 @@ limitations under the License.
Not {{ current_session.user.username }}? - {{ logout::button(text="Sign out", class=button::text_class(), csrf_token=csrf_token, post_logout_action=action) }} + {% set post_logout_action = next | safe_get(key="params") %} + {{ logout::button(text="Sign out", class=button::text_class(), csrf_token=csrf_token, post_logout_action=post_logout_action) }}
- - {% endblock content %}