diff --git a/Cargo.lock b/Cargo.lock index a3b4325e..d2c189a2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2294,6 +2294,7 @@ name = "mas-templates" version = "0.1.0" dependencies = [ "anyhow", + "chrono", "mas-config", "mas-data-model", "oauth2-types", diff --git a/crates/data-model/src/compat.rs b/crates/data-model/src/compat.rs index 9d9cf46e..7f6e0b1e 100644 --- a/crates/data-model/src/compat.rs +++ b/crates/data-model/src/compat.rs @@ -90,6 +90,18 @@ pub struct CompatSession { pub deleted_at: Option>, } +impl From> for CompatSession<()> { + fn from(t: CompatSession) -> Self { + Self { + data: (), + user: t.user.into(), + device: t.device, + created_at: t.created_at, + deleted_at: t.deleted_at, + } + } +} + #[derive(Debug, Clone, PartialEq)] pub struct CompatAccessToken { pub data: T::CompatAccessTokenData, @@ -98,6 +110,17 @@ pub struct CompatAccessToken { pub expires_at: Option>, } +impl From> for CompatAccessToken<()> { + fn from(t: CompatAccessToken) -> Self { + Self { + data: (), + token: t.token, + created_at: t.created_at, + expires_at: t.expires_at, + } + } +} + #[derive(Debug, Clone, PartialEq)] pub struct CompatRefreshToken { pub data: T::RefreshTokenData, @@ -105,13 +128,12 @@ pub struct CompatRefreshToken { pub created_at: DateTime, } -impl From> for CompatAccessToken<()> { - fn from(t: CompatAccessToken) -> Self { - CompatAccessToken { +impl From> for CompatRefreshToken<()> { + fn from(t: CompatRefreshToken) -> Self { + Self { data: (), token: t.token, created_at: t.created_at, - expires_at: t.expires_at, } } } @@ -131,6 +153,30 @@ pub enum CompatSsoLoginState { }, } +impl From> for CompatSsoLoginState<()> { + fn from(t: CompatSsoLoginState) -> Self { + match t { + CompatSsoLoginState::Pending => Self::Pending, + CompatSsoLoginState::Fullfilled { + fullfilled_at, + session, + } => Self::Fullfilled { + fullfilled_at, + session: session.into(), + }, + CompatSsoLoginState::Exchanged { + fullfilled_at, + exchanged_at, + session, + } => Self::Exchanged { + fullfilled_at, + exchanged_at, + session: session.into(), + }, + } + } +} + #[derive(Debug, Clone, PartialEq, Serialize)] #[serde(bound = "T: StorageBackend")] pub struct CompatSsoLogin { @@ -141,3 +187,15 @@ pub struct CompatSsoLogin { pub created_at: DateTime, pub state: CompatSsoLoginState, } + +impl From> for CompatSsoLogin<()> { + fn from(t: CompatSsoLogin) -> Self { + Self { + data: (), + redirect_uri: t.redirect_uri, + token: t.token, + created_at: t.created_at, + state: t.state.into(), + } + } +} diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index 1513ac63..ee4a8536 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -40,6 +40,7 @@ enum LoginType { #[serde(rename = "m.login.sso")] Sso { + #[serde(skip_serializing_if = "Vec::is_empty")] identity_providers: Vec, }, } diff --git a/crates/handlers/src/compat/login_sso_complete.rs b/crates/handlers/src/compat/login_sso_complete.rs index 6d0dc01d..54a03c98 100644 --- a/crates/handlers/src/compat/login_sso_complete.rs +++ b/crates/handlers/src/compat/login_sso_complete.rs @@ -16,16 +16,20 @@ use std::collections::HashMap; use axum::{ - extract::Path, - response::{IntoResponse, Redirect, Response}, + extract::{Form, Path}, + response::{Html, IntoResponse, Redirect, Response}, Extension, }; use axum_extra::extract::PrivateCookieJar; -use mas_axum_utils::{FancyError, SessionInfoExt}; +use mas_axum_utils::{ + csrf::{CsrfExt, ProtectedForm}, + FancyError, SessionInfoExt, +}; use mas_config::Encrypter; use mas_data_model::Device; use mas_router::Route; use mas_storage::compat::{fullfill_compat_sso_login, get_compat_sso_login_by_id}; +use mas_templates::{CompatSsoContext, TemplateContext, Templates}; use rand::thread_rng; use serde::Serialize; use sqlx::PgPool; @@ -41,12 +45,46 @@ struct AllParams<'s> { pub async fn get( Extension(pool): Extension, + Extension(templates): Extension, cookie_jar: PrivateCookieJar, Path(id): Path, +) -> Result { + let mut conn = pool.acquire().await?; + + let (session_info, cookie_jar) = cookie_jar.session_info(); + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(); + + let maybe_session = session_info.load_session(&mut conn).await?; + + let session = if let Some(session) = maybe_session { + session + } else { + // If there is no session, redirect to the login screen + let login = mas_router::Login::and_continue_compat_sso_login(id); + return Ok((cookie_jar, login.go()).into_response()); + }; + + let login = get_compat_sso_login_by_id(&mut conn, id).await?; + + let ctx = CompatSsoContext::new(login) + .with_session(session) + .with_csrf(csrf_token.form_value()); + + let content = templates.render_sso_login(&ctx).await?; + + Ok((cookie_jar, Html(content)).into_response()) +} + +pub async fn post( + Extension(pool): Extension, + cookie_jar: PrivateCookieJar, + Path(id): Path, + Form(form): Form>, ) -> Result { let mut txn = pool.begin().await?; let (session_info, cookie_jar) = cookie_jar.session_info(); + cookie_jar.verify_form(form)?; let maybe_session = session_info.load_session(&mut txn).await?; diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index 4c82e2b4..635ed10d 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -193,7 +193,8 @@ where ) .route( mas_router::CompatLoginSsoComplete::route(), - get(self::compat::login_sso_complete::get), + get(self::compat::login_sso_complete::get) + .post(self::compat::login_sso_complete::post), ) .layer(ThenLayer::new( move |result: Result| async move { diff --git a/crates/handlers/src/views/shared.rs b/crates/handlers/src/views/shared.rs index bd54e831..f9274edc 100644 --- a/crates/handlers/src/views/shared.rs +++ b/crates/handlers/src/views/shared.rs @@ -13,7 +13,9 @@ // limitations under the License. use mas_router::{PostAuthAction, Route}; -use mas_storage::oauth2::authorization_grant::get_grant_by_id; +use mas_storage::{ + compat::get_compat_sso_login_by_id, oauth2::authorization_grant::get_grant_by_id, +}; use mas_templates::PostAuthContext; use serde::{Deserialize, Serialize}; use sqlx::PgConnection; @@ -41,8 +43,10 @@ impl OptionalPostAuthAction { let grant = Box::new(grant.into()); Ok(Some(PostAuthContext::ContinueAuthorizationGrant { grant })) } - Some(PostAuthAction::ContinueCompatSsoLogin { .. }) => { - Ok(Some(PostAuthContext::ContinueCompatSsoLogin)) + Some(PostAuthAction::ContinueCompatSsoLogin { data }) => { + let login = get_compat_sso_login_by_id(conn, *data).await?; + let login = Box::new(login.into()); + Ok(Some(PostAuthContext::ContinueCompatSsoLogin { login })) } Some(PostAuthAction::ChangePassword) => Ok(Some(PostAuthContext::ChangePassword)), None => Ok(None), diff --git a/crates/templates/Cargo.toml b/crates/templates/Cargo.toml index edce42d5..caa4d9cf 100644 --- a/crates/templates/Cargo.toml +++ b/crates/templates/Cargo.toml @@ -20,6 +20,7 @@ serde = { version = "1.0.137", features = ["derive"] } serde_json = "1.0.81" serde_urlencoded = "0.7.1" +chrono = "0.4.19" url = "2.2.2" oauth2-types = { path = "../oauth2-types" } diff --git a/crates/templates/src/context.rs b/crates/templates/src/context.rs index 7d2d70cf..9a9632a9 100644 --- a/crates/templates/src/context.rs +++ b/crates/templates/src/context.rs @@ -16,7 +16,11 @@ #![allow(clippy::trait_duplication_in_bounds)] -use mas_data_model::{AuthorizationGrant, BrowserSession, StorageBackend, User, UserEmail}; +use chrono::Utc; +use mas_data_model::{ + AuthorizationGrant, BrowserSession, CompatSsoLogin, CompatSsoLoginState, StorageBackend, User, + UserEmail, +}; use serde::{ser::SerializeStruct, Deserialize, Serialize}; use url::Url; @@ -250,7 +254,10 @@ pub enum PostAuthContext { /// Continue legacy login /// TODO: add the login context in there - ContinueCompatSsoLogin, + ContinueCompatSsoLogin { + /// The compat SSO login request + login: Box>, + }, /// Change the account password ChangePassword, @@ -454,6 +461,42 @@ impl ReauthContext { } } +/// Context used by the `sso.html` template +#[derive(Serialize)] +pub struct CompatSsoContext { + login: CompatSsoLogin<()>, +} + +impl TemplateContext for CompatSsoContext { + fn sample() -> Vec + where + Self: Sized, + { + vec![CompatSsoContext { + login: CompatSsoLogin { + data: (), + redirect_uri: Url::parse("https://app.element.io/").unwrap(), + token: "abcdefghijklmnopqrstuvwxyz012345".into(), + created_at: Utc::now(), + state: CompatSsoLoginState::Pending, + }, + }] + } +} + +impl CompatSsoContext { + /// Constructs a context for the legacy SSO login page + #[must_use] + pub fn new(login: T) -> Self + where + T: Into>, + { + Self { + login: login.into(), + } + } +} + /// Context used by the `account/index.html` template #[derive(Serialize)] pub struct AccountContext { diff --git a/crates/templates/src/lib.rs b/crates/templates/src/lib.rs index 232cf994..e755e749 100644 --- a/crates/templates/src/lib.rs +++ b/crates/templates/src/lib.rs @@ -45,10 +45,11 @@ mod macros; pub use self::{ context::{ - AccountContext, AccountEmailsContext, ConsentContext, EmailVerificationContext, - EmptyContext, ErrorContext, FormPostContext, IndexContext, LoginContext, LoginFormField, - PostAuthContext, ReauthContext, ReauthFormField, RegisterContext, RegisterFormField, - TemplateContext, WithCsrf, WithOptionalSession, WithSession, + AccountContext, AccountEmailsContext, CompatSsoContext, ConsentContext, + EmailVerificationContext, EmptyContext, ErrorContext, FormPostContext, IndexContext, + LoginContext, LoginFormField, PostAuthContext, ReauthContext, ReauthFormField, + RegisterContext, RegisterFormField, TemplateContext, WithCsrf, WithOptionalSession, + WithSession, }, forms::{FieldError, FormError, FormField, FormState, ToFormState}, }; @@ -294,9 +295,12 @@ register_templates! { /// Render the registration page pub fn render_register(WithCsrf) { "pages/register.html" } - /// Render the registration page + /// Render the client consent page pub fn render_consent(WithCsrf>) { "pages/consent.html" } + /// Render the client consent page + pub fn render_sso_login(WithCsrf>) { "pages/sso.html" } + /// Render the home page pub fn render_index(WithCsrf>) { "pages/index.html" } diff --git a/crates/templates/src/res/pages/sso.html b/crates/templates/src/res/pages/sso.html new file mode 100644 index 00000000..33cd9aa6 --- /dev/null +++ b/crates/templates/src/res/pages/sso.html @@ -0,0 +1,45 @@ +{# +Copyright 2022 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +#} + +{% extends "base.html" %} + +{% block content %} +
+
+
+
+
+

{{ login.redirect_uri }}

+

wants to access your Matrix account

+
+
+ + + + {{ button::button(text="Allow") }} +
+
+
+ +
+ Not {{ current_session.user.username }}? + {{ button::button_text(text="Sign out", name="logout", type="submit") }} +
+
+
+
+
+{% endblock content %}