1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Better frontend assets handling and move the react app to /account/ (#1324)

This makes the Vite assets handling better, namely:

 - make it possible to include any vite assets in the templates
 - include the right `<link rel="preload">` tags for assets
 - include Subresource Integrity hashes
 - pre-compress assets and remove on-the-fly compression by the Rust server
 - build the CSS used by templates through Vite

It also moves the React app from /app/ to /account/, and remove some of the old SSR account screens.
This commit is contained in:
Quentin Gliech
2023-07-06 15:30:26 +02:00
committed by GitHub
parent 6cae2adc08
commit 76653f9638
47 changed files with 1096 additions and 1011 deletions

View File

@ -25,7 +25,7 @@ serde_yaml = "0.9.22"
sqlx = { version = "0.6.3", features = ["runtime-tokio-rustls", "postgres"] }
tokio = { version = "1.29.1", features = ["full"] }
tower = { version = "0.4.13", features = ["full"] }
tower-http = { version = "0.4.1", features = ["fs", "compression-full"] }
tower-http = { version = "0.4.1", features = ["fs"] }
url = "2.4.0"
watchman_client = "0.8.0"
zeroize = "1.6.0"

View File

@ -83,8 +83,11 @@ impl Options {
let policy_factory = policy_factory_from_config(&config.policy).await?;
let policy_factory = Arc::new(policy_factory);
let url_builder =
UrlBuilder::new(config.http.public_base.clone(), config.http.issuer.clone());
let url_builder = UrlBuilder::new(
config.http.public_base.clone(),
config.http.issuer.clone(),
None,
);
// Load and compile the templates
let templates = templates_from_config(&config.templates, &url_builder).await?;

View File

@ -12,13 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use camino::Utf8PathBuf;
use clap::Parser;
use mas_config::TemplatesConfig;
use mas_storage::{Clock, SystemClock};
use mas_templates::Templates;
use rand::SeedableRng;
use tracing::info_span;
use crate::util::templates_from_config;
#[derive(Parser, Debug)]
pub(super) struct Options {
#[clap(subcommand)]
@ -27,26 +28,24 @@ pub(super) struct Options {
#[derive(Parser, Debug)]
enum Subcommand {
/// Check for template validity at given path.
Check {
/// Path where the templates are
path: Utf8PathBuf,
},
/// Check that the templates specified in the config are valid
Check,
}
impl Options {
pub async fn run(self, _root: &super::Options) -> anyhow::Result<()> {
pub async fn run(self, root: &super::Options) -> anyhow::Result<()> {
use Subcommand as SC;
match self.subcommand {
SC::Check { path } => {
SC::Check => {
let _span = info_span!("cli.templates.check").entered();
let config: TemplatesConfig = root.load_config()?;
let clock = SystemClock::default();
// XXX: we should disallow SeedableRng::from_entropy
let mut rng = rand_chacha::ChaChaRng::from_entropy();
let url_builder =
mas_router::UrlBuilder::new("https://example.com/".parse()?, None);
let templates = Templates::load(path, url_builder).await?;
mas_router::UrlBuilder::new("https://example.com/".parse()?, None, None);
let templates = templates_from_config(&config, &url_builder).await?;
templates.check_render(clock.now(), &mut rng).await?;
Ok(())

View File

@ -37,8 +37,11 @@ impl Options {
info!("Connecting to the database");
let pool = database_from_config(&config.database).await?;
let url_builder =
UrlBuilder::new(config.http.public_base.clone(), config.http.issuer.clone());
let url_builder = UrlBuilder::new(
config.http.public_base.clone(),
config.http.issuer.clone(),
None,
);
// Load and compile the templates
let templates = templates_from_config(&config.templates, &url_builder).await?;

View File

@ -26,13 +26,15 @@ use axum::{
extract::{FromRef, MatchedPath},
Extension, Router,
};
use hyper::{Method, Request, Response, StatusCode, Version};
use hyper::{
header::{HeaderValue, CACHE_CONTROL},
Method, Request, Response, StatusCode, Version,
};
use listenfd::ListenFd;
use mas_config::{HttpBindConfig, HttpResource, HttpTlsConfig, UnixOrTcp};
use mas_handlers::AppState;
use mas_listener::{unix_or_tcp::UnixOrTcpListener, ConnectionInfo};
use mas_router::Route;
use mas_spa::ViteManifestService;
use mas_templates::Templates;
use mas_tower::{
make_span_fn, metrics_attributes_fn, DurationRecorderLayer, InFlightCounterLayer, TraceLayer,
@ -46,8 +48,8 @@ use opentelemetry_semantic_conventions::trace::{
use rustls::ServerConfig;
use sentry_tower::{NewSentryLayer, SentryHttpLayer};
use tower::Layer;
use tower_http::{compression::CompressionLayer, services::ServeDir};
use tracing::Span;
use tower_http::{services::ServeDir, set_header::SetResponseHeaderLayer};
use tracing::{warn, Span};
use tracing_opentelemetry::OpenTelemetrySpanExt;
const NET_PROTOCOL_NAME: Key = Key::from_static_str("net.protocol.name");
@ -192,13 +194,23 @@ where
router.merge(mas_handlers::graphql_router::<AppState, B>(*playground))
}
mas_config::HttpResource::Assets { path } => {
let static_service = ServeDir::new(path).append_index_html_on_directories(false);
let static_service = ServeDir::new(path)
.append_index_html_on_directories(false)
.precompressed_br()
.precompressed_gzip()
.precompressed_deflate();
let error_layer =
HandleErrorLayer::new(|_e| ready(StatusCode::INTERNAL_SERVER_ERROR));
let cache_layer = SetResponseHeaderLayer::overriding(
CACHE_CONTROL,
HeaderValue::from_static("public, max-age=31536000, immutable"),
);
router.nest_service(
mas_router::StaticAsset::route(),
error_layer.layer(static_service),
(error_layer, cache_layer).layer(static_service),
)
}
mas_config::HttpResource::OAuth => {
@ -215,25 +227,10 @@ where
}),
),
mas_config::HttpResource::Spa { manifest } => {
let error_layer =
HandleErrorLayer::new(|_e| ready(StatusCode::INTERNAL_SERVER_ERROR));
// TODO: make those paths configurable
let app_base = "/app/";
// TODO: make that config typed and configurable
let config = serde_json::json!({
"root": app_base,
});
let index_service = ViteManifestService::new(
manifest.clone(),
mas_router::StaticAsset::route().into(),
config,
);
router.nest_service(app_base, error_layer.layer(index_service))
#[allow(deprecated)]
mas_config::HttpResource::Spa { .. } => {
warn!("The SPA HTTP resource is deprecated");
router
}
}
}
@ -266,7 +263,6 @@ where
)
.layer(SentryHttpLayer::new())
.layer(NewSentryLayer::new_from_top())
.layer(CompressionLayer::new())
.with_state(state)
}

View File

@ -111,7 +111,12 @@ pub async fn templates_from_config(
config: &TemplatesConfig,
url_builder: &UrlBuilder,
) -> Result<Templates, TemplateLoadingError> {
Templates::load(config.path.clone(), url_builder.clone()).await
Templates::load(
config.path.clone(),
url_builder.clone(),
config.assets_manifest.clone(),
)
.await
}
#[tracing::instrument(name = "db.connect", skip_all, err(Debug))]

View File

@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#![allow(deprecated)]
use std::{borrow::Cow, io::Cursor, ops::Deref};
use anyhow::bail;
@ -43,21 +45,11 @@ fn http_address_example_4() -> &'static str {
"0.0.0.0:8080"
}
#[cfg(not(feature = "docker"))]
fn http_listener_spa_manifest_default() -> Utf8PathBuf {
"./frontend/dist/manifest.json".into()
}
#[cfg(not(feature = "docker"))]
fn http_listener_assets_path_default() -> Utf8PathBuf {
"./frontend/dist/".into()
}
#[cfg(feature = "docker")]
fn http_listener_spa_manifest_default() -> Utf8PathBuf {
"/usr/local/share/mas-cli/manifest.json".into()
}
#[cfg(feature = "docker")]
fn http_listener_assets_path_default() -> Utf8PathBuf {
"/usr/local/share/mas-cli/assets/".into()
@ -285,12 +277,10 @@ pub enum Resource {
ConnectionInfo,
/// Mount the single page app
Spa {
/// Path to the vite manifest.json
#[serde(default = "http_listener_spa_manifest_default")]
#[schemars(with = "String")]
manifest: Utf8PathBuf,
},
///
/// This is deprecated and will be removed in a future release.
#[deprecated = "This resource is deprecated and will be removed in a future release"]
Spa,
}
/// Configuration of a listener
@ -346,9 +336,6 @@ impl Default for HttpConfig {
Resource::Assets {
path: http_listener_assets_path_default(),
},
Resource::Spa {
manifest: http_listener_spa_manifest_default(),
},
],
tls: None,
proxy_protocol: false,

View File

@ -30,6 +30,16 @@ fn default_path() -> Utf8PathBuf {
"/usr/local/share/mas-cli/templates/".into()
}
#[cfg(not(feature = "docker"))]
fn default_assets_path() -> Utf8PathBuf {
"./frontend/dist/manifest.json".into()
}
#[cfg(feature = "docker")]
fn default_assets_path() -> Utf8PathBuf {
"/usr/local/share/mas-cli/manifest.json".into()
}
/// Configuration related to templates
#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)]
pub struct TemplatesConfig {
@ -37,12 +47,18 @@ pub struct TemplatesConfig {
#[serde(default = "default_path")]
#[schemars(with = "Option<String>")]
pub path: Utf8PathBuf,
/// Path to the assets manifest
#[serde(default = "default_assets_path")]
#[schemars(with = "Option<String>")]
pub assets_manifest: Utf8PathBuf,
}
impl Default for TemplatesConfig {
fn default() -> Self {
Self {
path: default_path(),
assets_manifest: default_assets_path(),
}
}
}

View File

@ -68,6 +68,7 @@ mas-matrix = { path = "../matrix" }
mas-oidc-client = { path = "../oidc-client" }
mas-policy = { path = "../policy" }
mas-router = { path = "../router" }
mas-spa = { path = "../spa" }
mas-storage = { path = "../storage" }
mas-storage-pg = { path = "../storage-pg" }
mas-templates = { path = "../templates" }

View File

@ -268,6 +268,12 @@ where
BoxRng: FromRequestParts<S>,
{
Router::new()
// TODO: mount this route somewhere else?
.route(mas_router::Account::route(), get(self::views::app::get))
.route(
mas_router::AccountWildcard::route(),
get(self::views::app::get),
)
.route(
mas_router::ChangePasswordDiscovery::route(),
get(|| async { mas_router::AccountPassword.go() }),
@ -286,15 +292,10 @@ where
mas_router::Register::route(),
get(self::views::register::get).post(self::views::register::post),
)
.route(mas_router::Account::route(), get(self::views::account::get))
.route(
mas_router::AccountPassword::route(),
get(self::views::account::password::get).post(self::views::account::password::post),
)
.route(
mas_router::AccountEmails::route(),
get(self::views::account::emails::get).post(self::views::account::emails::post),
)
.route(
mas_router::AccountVerifyEmail::route(),
get(self::views::account::emails::verify::get)

View File

@ -110,10 +110,14 @@ impl TestState {
.join("..")
.join("..");
let url_builder = UrlBuilder::new("https://example.com/".parse()?, None);
let url_builder = UrlBuilder::new("https://example.com/".parse()?, None, None);
let templates =
Templates::load(workspace_root.join("templates"), url_builder.clone()).await?;
let templates = Templates::load(
workspace_root.join("templates"),
url_builder.clone(),
workspace_root.join("frontend/dist/manifest.json"),
)
.await?;
// TODO: add more test keys to the store
let rsa =

View File

@ -12,190 +12,5 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::{anyhow, Context};
use axum::{
extract::{Form, State},
response::{Html, IntoResponse, Response},
};
use axum_extra::extract::PrivateCookieJar;
use mas_axum_utils::{
csrf::{CsrfExt, ProtectedForm},
FancyError, SessionInfoExt,
};
use mas_data_model::BrowserSession;
use mas_keystore::Encrypter;
use mas_router::Route;
use mas_storage::{
job::{JobRepositoryExt, ProvisionUserJob, VerifyEmailJob},
user::UserEmailRepository,
BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess,
};
use mas_templates::{AccountEmailsContext, TemplateContext, Templates};
use rand::Rng;
use serde::Deserialize;
pub mod add;
pub mod verify;
#[derive(Deserialize, Debug)]
#[serde(tag = "action", rename_all = "snake_case")]
pub enum ManagementForm {
Add { email: String },
ResendConfirmation { id: String },
SetPrimary { id: String },
Remove { id: String },
}
#[tracing::instrument(name = "handlers.views.account_email_list.get", skip_all, err)]
pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(templates): State<Templates>,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> {
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?;
if let Some(session) = maybe_session {
render(&mut rng, &clock, templates, session, cookie_jar, &mut repo).await
} else {
let login = mas_router::Login::default();
Ok((cookie_jar, login.go()).into_response())
}
}
async fn render<E: std::error::Error>(
rng: impl Rng + Send,
clock: &impl Clock,
templates: Templates,
session: BrowserSession,
cookie_jar: PrivateCookieJar<Encrypter>,
repo: &mut impl RepositoryAccess<Error = E>,
) -> Result<Response, FancyError> {
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock, rng);
let emails = repo.user_email().all(&session.user).await?;
let ctx = AccountEmailsContext::new(emails)
.with_session(session)
.with_csrf(csrf_token.form_value());
let content = templates.render_account_emails(&ctx).await?;
Ok((cookie_jar, Html(content)).into_response())
}
#[tracing::instrument(name = "handlers.views.account_email_list.post", skip_all, err)]
pub(crate) async fn post(
mut rng: BoxRng,
clock: BoxClock,
State(templates): State<Templates>,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<ManagementForm>>,
) -> Result<Response, FancyError> {
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?;
let Some(mut session) = maybe_session else {
let login = mas_router::Login::default();
return Ok((cookie_jar, login.go()).into_response());
};
let form = cookie_jar.verify_form(&clock, form)?;
match form {
ManagementForm::Add { email } => {
let user_email = repo
.user_email()
.add(&mut rng, &clock, &session.user, email)
.await?;
let next = mas_router::AccountVerifyEmail::new(user_email.id);
repo.job()
.schedule_job(VerifyEmailJob::new(&user_email))
.await?;
repo.save().await?;
return Ok((cookie_jar, next.go()).into_response());
}
ManagementForm::ResendConfirmation { id } => {
let id = id.parse()?;
let user_email = repo
.user_email()
.lookup(id)
.await?
.context("Email not found")?;
if user_email.user_id != session.user.id {
return Err(anyhow!("Email not found").into());
}
let next = mas_router::AccountVerifyEmail::new(user_email.id);
repo.job()
.schedule_job(VerifyEmailJob::new(&user_email))
.await?;
repo.save().await?;
return Ok((cookie_jar, next.go()).into_response());
}
ManagementForm::Remove { id } => {
let id = id.parse()?;
let email = repo
.user_email()
.lookup(id)
.await?
.context("Email not found")?;
if email.user_id != session.user.id {
return Err(anyhow!("Email not found").into());
}
repo.user_email().remove(email).await?;
}
ManagementForm::SetPrimary { id } => {
let id = id.parse()?;
let email = repo
.user_email()
.lookup(id)
.await?
.context("Email not found")?;
if email.user_id != session.user.id {
return Err(anyhow!("Email not found").into());
}
repo.user_email().set_as_primary(&email).await?;
session.user.primary_user_email_id = Some(email.id);
}
};
// XXX: It shouldn't hurt to do this even if the user didn't change their emails
// in a meaningful way
repo.job()
.schedule_job(ProvisionUserJob::new(&session.user))
.await?;
let reply = render(
&mut rng,
&clock,
templates.clone(),
session,
cookie_jar,
&mut repo,
)
.await?;
repo.save().await?;
Ok(reply)
}

View File

@ -74,7 +74,7 @@ pub(crate) async fn get(
if user_email.confirmed_at.is_some() {
// This email was already verified, skip
let destination = query.go_next_or_default(&mas_router::AccountEmails);
let destination = query.go_next_or_default(&mas_router::Account);
return Ok((cookie_jar, destination).into_response());
}
@ -146,6 +146,6 @@ pub(crate) async fn post(
repo.save().await?;
let destination = query.go_next_or_default(&mas_router::AccountEmails);
let destination = query.go_next_or_default(&mas_router::Account);
Ok((cookie_jar, destination).into_response())
}

View File

@ -14,48 +14,3 @@
pub mod emails;
pub mod password;
use axum::{
extract::State,
response::{Html, IntoResponse, Response},
};
use axum_extra::extract::PrivateCookieJar;
use mas_axum_utils::{csrf::CsrfExt, FancyError, SessionInfoExt};
use mas_keystore::Encrypter;
use mas_router::Route;
use mas_storage::{
user::{BrowserSessionRepository, UserEmailRepository},
BoxClock, BoxRepository, BoxRng,
};
use mas_templates::{AccountContext, TemplateContext, Templates};
#[tracing::instrument(name = "handlers.views.account.get", skip_all, err)]
pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(templates): State<Templates>,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> {
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?;
let Some(session) = maybe_session else {
let login = mas_router::Login::default();
return Ok((cookie_jar, login.go()).into_response());
};
let active_sessions = repo.browser_session().count_active(&session.user).await?;
let emails = repo.user_email().all(&session.user).await?;
let ctx = AccountContext::new(active_sessions, emails)
.with_session(session)
.with_csrf(csrf_token.form_value());
let content = templates.render_account_index(&ctx).await?;
Ok((cookie_jar, Html(content)).into_response())
}

View File

@ -0,0 +1,48 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use axum::{
extract::State,
response::{Html, IntoResponse},
};
use axum_extra::extract::PrivateCookieJar;
use mas_axum_utils::{FancyError, SessionInfoExt};
use mas_keystore::Encrypter;
use mas_router::{PostAuthAction, Route};
use mas_storage::BoxRepository;
use mas_templates::{AppContext, Templates};
#[tracing::instrument(name = "handlers.views.app.get", skip_all, err)]
pub async fn get(
State(templates): State<Templates>,
mut repo: BoxRepository,
cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<impl IntoResponse, FancyError> {
let (session_info, cookie_jar) = cookie_jar.session_info();
let session = session_info.load_session(&mut repo).await?;
// TODO: keep the full path
if session.is_none() {
return Ok((
cookie_jar,
mas_router::Login::and_then(PostAuthAction::ManageAccount).go(),
)
.into_response());
}
let ctx = AppContext::default();
let content = templates.render_app(&ctx).await?;
Ok((cookie_jar, Html(content)).into_response())
}

View File

@ -13,6 +13,7 @@
// limitations under the License.
pub mod account;
pub mod app;
pub mod index;
pub mod login;
pub mod logout;

View File

@ -87,6 +87,8 @@ impl OptionalPostAuthAction {
let link = Box::new(link);
PostAuthContextInner::LinkUpstream { provider, link }
}
PostAuthAction::ManageAccount => PostAuthContextInner::ManageAccount,
};
Ok(Some(PostAuthContext {

View File

@ -60,7 +60,7 @@ features = ["client", "http1", "http2", "stream", "runtime" ]
optional = true
[dependencies.tower-http]
version = "0.4.1"
features = ["follow-redirect", "decompression-full", "set-header", "timeout"]
features = ["follow-redirect", "set-header", "timeout", "map-request-body", "util"]
optional = true
[dev-dependencies]

View File

@ -19,17 +19,13 @@
use std::time::Duration;
use http::{header::USER_AGENT, HeaderValue};
use http_body::Full;
use hyper::client::{connect::dns::GaiResolver, HttpConnector};
use hyper_rustls::{ConfigBuilderExt, HttpsConnectorBuilder};
use tower::{limit::ConcurrencyLimitLayer, BoxError, ServiceBuilder};
use tower_http::{
decompression::DecompressionLayer, follow_redirect::FollowRedirectLayer,
set_header::SetRequestHeaderLayer, timeout::TimeoutLayer,
};
use mas_http::BodyToBytesResponseLayer;
use tower::{BoxError, ServiceBuilder};
use tower_http::{timeout::TimeoutLayer, ServiceBuilderExt};
mod body_layer;
use self::body_layer::BodyLayer;
use super::HttpService;
static MAS_USER_AGENT: HeaderValue = HeaderValue::from_static("mas-oidc-client/0.0.1");
@ -60,14 +56,11 @@ pub fn hyper_service() -> HttpService {
let client = ServiceBuilder::new()
.map_err(BoxError::from)
.layer(BodyLayer::default())
.layer(DecompressionLayer::new())
.layer(SetRequestHeaderLayer::overriding(
USER_AGENT,
MAS_USER_AGENT.clone(),
))
.layer(ConcurrencyLimitLayer::new(10))
.layer(FollowRedirectLayer::new())
.map_request_body(Full::new)
.layer(BodyToBytesResponseLayer::default())
.override_request_header(USER_AGENT, MAS_USER_AGENT.clone())
.concurrency_limit(10)
.follow_redirects()
.layer(TimeoutLayer::new(Duration::from_secs(10)))
.service(client);

View File

@ -1,88 +0,0 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::task::Poll;
use bytes::Bytes;
use futures_util::future::BoxFuture;
use http::{Request, Response};
use http_body::{Body, Full};
use hyper::body::to_bytes;
use thiserror::Error;
use tower::{BoxError, Layer, Service};
#[derive(Debug, Error)]
#[error(transparent)]
pub enum BodyError<E> {
Decompression(BoxError),
Service(E),
}
#[derive(Clone)]
pub struct BodyService<S> {
inner: S,
}
impl<S> BodyService<S> {
pub const fn new(inner: S) -> Self {
Self { inner }
}
}
impl<S, E, ResBody> Service<Request<Bytes>> for BodyService<S>
where
S: Service<Request<Full<Bytes>>, Response = Response<ResBody>, Error = E>,
ResBody: Body<Data = Bytes, Error = BoxError> + Send,
S::Future: Send + 'static,
{
type Error = BodyError<E>;
type Response = Response<Bytes>;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(BodyError::Service)
}
fn call(&mut self, request: Request<Bytes>) -> Self::Future {
let (parts, body) = request.into_parts();
let body = Full::new(body);
let request = Request::from_parts(parts, body);
let fut = self.inner.call(request);
let fut = async {
let response = fut.await.map_err(BodyError::Service)?;
let (parts, body) = response.into_parts();
let body = to_bytes(body).await.map_err(BodyError::Decompression)?;
let response = Response::from_parts(parts, body);
Ok(response)
};
Box::pin(fut)
}
}
#[derive(Default, Clone, Copy)]
pub struct BodyLayer(());
impl<S> Layer<S> for BodyLayer {
type Service = BodyService<S>;
fn layer(&self, inner: S) -> Self::Service {
BodyService::new(inner)
}
}

View File

@ -24,6 +24,7 @@ pub enum PostAuthAction {
ContinueCompatSsoLogin { id: Ulid },
ChangePassword,
LinkUpstream { id: Ulid },
ManageAccount,
}
impl PostAuthAction {
@ -48,6 +49,7 @@ impl PostAuthAction {
Self::ContinueCompatSsoLogin { id } => CompatLoginSsoComplete::new(*id, None).go(),
Self::ChangePassword => AccountPassword.go(),
Self::LinkUpstream { id } => UpstreamOAuth2Link::new(*id).go(),
Self::ManageAccount => Account.go(),
}
}
}
@ -335,7 +337,7 @@ impl From<Option<PostAuthAction>> for Register {
}
}
/// `GET|POST /account/emails/verify/:id`
/// `GET|POST /verify-email/:id`
#[derive(Debug, Clone)]
pub struct AccountVerifyEmail {
id: Ulid,
@ -367,19 +369,19 @@ impl AccountVerifyEmail {
impl Route for AccountVerifyEmail {
type Query = PostAuthAction;
fn route() -> &'static str {
"/account/emails/verify/:id"
}
fn path(&self) -> std::borrow::Cow<'static, str> {
format!("/account/emails/verify/{}", self.id).into()
"/verify-email/:id"
}
fn query(&self) -> Option<&Self::Query> {
self.post_auth_action.as_ref()
}
fn path(&self) -> std::borrow::Cow<'static, str> {
format!("/verify-email/{}", self.id).into()
}
}
/// `GET /account/emails/add`
/// `GET /add-email`
#[derive(Default, Debug, Clone)]
pub struct AccountAddEmail {
post_auth_action: Option<PostAuthAction>,
@ -388,7 +390,7 @@ pub struct AccountAddEmail {
impl Route for AccountAddEmail {
type Query = PostAuthAction;
fn route() -> &'static str {
"/account/emails/add"
"/add-email"
}
fn query(&self) -> Option<&Self::Query> {
@ -404,28 +406,28 @@ impl AccountAddEmail {
}
}
/// `GET /account`
/// `GET /account/`
#[derive(Default, Debug, Clone)]
pub struct Account;
impl SimpleRoute for Account {
const PATH: &'static str = "/account";
const PATH: &'static str = "/account/";
}
/// `GET|POST /account/password`
/// `GET /account/*`
#[derive(Default, Debug, Clone)]
pub struct AccountWildcard;
impl SimpleRoute for AccountWildcard {
const PATH: &'static str = "/account/*rest";
}
/// `GET|POST /change-password`
#[derive(Default, Debug, Clone)]
pub struct AccountPassword;
impl SimpleRoute for AccountPassword {
const PATH: &'static str = "/account/password";
}
/// `GET|POST /account/emails`
#[derive(Default, Debug, Clone)]
pub struct AccountEmails;
impl SimpleRoute for AccountEmails {
const PATH: &'static str = "/account/emails";
const PATH: &'static str = "/change-password";
}
/// `GET /authorize/:grant_id`

View File

@ -1,4 +1,4 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -14,6 +14,8 @@
//! Utility to build URLs
use std::borrow::Cow;
use ulid::Ulid;
use url::Url;
@ -21,7 +23,8 @@ use crate::traits::Route;
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct UrlBuilder {
base: Url,
http_base: Url,
assets_base: Cow<'static, str>,
issuer: Url,
}
@ -30,21 +33,26 @@ impl UrlBuilder {
where
U: Route,
{
destination.absolute_url(&self.base)
destination.absolute_url(&self.http_base)
}
pub fn absolute_redirect<U>(&self, destination: &U) -> axum::response::Redirect
where
U: Route,
{
destination.go_absolute(&self.base)
destination.go_absolute(&self.http_base)
}
/// Create a new [`UrlBuilder`] from a base URL
#[must_use]
pub fn new(base: Url, issuer: Option<Url>) -> Self {
pub fn new(base: Url, issuer: Option<Url>, assets_base: Option<String>) -> Self {
let issuer = issuer.unwrap_or_else(|| base.clone());
Self { base, issuer }
let assets_base = assets_base.map_or(Cow::Borrowed("/assets/"), Cow::Owned);
Self {
http_base: base,
assets_base,
issuer,
}
}
/// OIDC issuer
@ -107,6 +115,12 @@ impl UrlBuilder {
self.url_for(&crate::endpoints::StaticAsset::new(path))
}
/// Static asset base
#[must_use]
pub fn assets_base(&self) -> &str {
&self.assets_base
}
/// Upstream redirect URI
#[must_use]
pub fn upstream_oauth_callback(&self, id: Ulid) -> Url {

View File

@ -7,14 +7,6 @@ license = "Apache-2.0"
[dependencies]
serde = { version = "1.0.166", features = ["derive"] }
serde_json = "1.0.100"
thiserror = "1.0.41"
camino = { version = "1.1.4", features = ["serde1"] }
headers = "0.3.8"
http = "0.2.9"
tower-service = "0.3.2"
tower-http = { version = "0.4.1", features = ["fs"] }
tokio = { version = "1.29.1", features = ["fs"] }
[[bin]]
name = "render"

View File

@ -1,32 +0,0 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use camino::Utf8Path;
use mas_spa::ViteManifest;
fn main() {
let mut stdin = std::io::stdin();
let manifest: ViteManifest =
serde_json::from_reader(&mut stdin).expect("failed to read manifest from stdin");
let assets_base = Utf8Path::new("/assets/");
let config = serde_json::json!({
"root": "/app/",
});
let html = manifest
.render(assets_base, &config)
.expect("failed to render");
println!("{html}");
}

View File

@ -25,73 +25,4 @@
mod vite;
use std::{future::Future, pin::Pin};
use camino::Utf8PathBuf;
use headers::{ContentType, HeaderMapExt};
use http::Response;
use serde::Serialize;
use tower_service::Service;
pub use self::vite::Manifest as ViteManifest;
/// Service which renders an `index.html` based on the files in the manifest
#[derive(Debug, Clone)]
pub struct ViteManifestService<T> {
manifest: Utf8PathBuf,
assets_base: Utf8PathBuf,
config: T,
}
impl<T> ViteManifestService<T> {
#[must_use]
pub const fn new(manifest: Utf8PathBuf, assets_base: Utf8PathBuf, config: T) -> Self {
Self {
manifest,
assets_base,
config,
}
}
}
impl<T, R> Service<R> for ViteManifestService<T>
where
T: Clone + Serialize + Send + Sync + 'static,
{
type Error = std::io::Error;
type Response = Response<String>;
type Future =
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + Sync + 'static>>;
fn poll_ready(
&mut self,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
std::task::Poll::Ready(Ok(()))
}
fn call(&mut self, _req: R) -> Self::Future {
let manifest = self.manifest.clone();
let assets_base = self.assets_base.clone();
let config = self.config.clone();
Box::pin(async move {
// Read the manifest from disk
let manifest = tokio::fs::read(manifest).await?;
// Parse it
let manifest: ViteManifest = serde_json::from_slice(&manifest)
.map_err(|error| std::io::Error::new(std::io::ErrorKind::Other, error))?;
// Render the HTML out of the manifest
let html = manifest
.render(&assets_base, &config)
.map_err(|error| std::io::Error::new(std::io::ErrorKind::Other, error))?;
let mut response = Response::new(html);
response.headers_mut().typed_insert(ContentType::html());
Ok(response)
})
}
}

View File

@ -1,9 +1,23 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::{BTreeSet, HashMap};
use camino::{Utf8Path, Utf8PathBuf};
use thiserror::Error;
#[derive(serde::Deserialize, Debug)]
#[derive(serde::Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase", deny_unknown_fields)]
pub struct ManifestEntry {
#[allow(dead_code)]
@ -13,7 +27,6 @@ pub struct ManifestEntry {
css: Option<Vec<Utf8PathBuf>>,
#[allow(dead_code)]
assets: Option<Vec<Utf8PathBuf>>,
#[allow(dead_code)]
@ -22,73 +35,14 @@ pub struct ManifestEntry {
#[allow(dead_code)]
is_dynamic_entry: Option<bool>,
#[allow(dead_code)]
imports: Option<Vec<Utf8PathBuf>>,
dynamic_imports: Option<Vec<Utf8PathBuf>>,
integrity: Option<String>,
}
/// Render the HTML template
fn template(head: impl Iterator<Item = String>, config: &impl serde::Serialize) -> String {
// This should be kept in sync with `../../../frontend/index.html`
// Render the items to insert in the <head>
let head: String = head.map(|f| format!(" {f}\n")).collect();
// Serialize the config
let config = serde_json::to_string(config).expect("failed to serialize config");
// Script in the <head> which manages the dark mode class on the <html> element
let dark_mode_script = r#"
(function () {
const query = window.matchMedia("(prefers-color-scheme: dark)");
function handleChange(e) {
if (e.matches) {
document.documentElement.classList.add("dark")
} else {
document.documentElement.classList.remove("dark")
}
}
query.addListener(handleChange);
handleChange(query);
})();
"#;
format!(
r#"<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>matrix-authentication-service</title>
<script>window.APP_CONFIG = {config};</script>
<script>{dark_mode_script}</script>
{head}</head>
<body>
<div id="root"></div>
</body>
</html>"#
)
}
impl ManifestEntry {
/// Get a list of items to insert in the `<head>`
fn head<'a>(&'a self, assets_base: &'a Utf8Path) -> impl Iterator<Item = String> + 'a {
let css = self.css.iter().flat_map(|css| {
css.iter().map(|href| {
let href = assets_base.join(href);
format!(r#"<link rel="stylesheet" href="{href}" />"#)
})
});
let script = assets_base.join(&self.file);
let script = format!(r#"<script type="module" crossorigin src="{script}"></script>"#);
css.chain(std::iter::once(script))
}
}
#[derive(serde::Deserialize, Debug)]
#[derive(serde::Deserialize, Debug, Clone)]
pub struct Manifest {
#[serde(flatten)]
inner: HashMap<Utf8PathBuf, ManifestEntry>,
@ -98,6 +52,8 @@ pub struct Manifest {
enum FileType {
Script,
Stylesheet,
Woff,
Woff2,
}
impl FileType {
@ -105,6 +61,8 @@ impl FileType {
match name.extension() {
Some("css") => Some(Self::Stylesheet),
Some("js") => Some(Self::Script),
Some("woff") => Some(Self::Woff),
Some("woff2") => Some(Self::Woff2),
_ => None,
}
}
@ -112,104 +70,168 @@ impl FileType {
#[derive(Debug, Error)]
#[error("Invalid Vite manifest")]
pub enum InvalidManifest {
#[error("No index.html")]
NoIndex,
pub enum InvalidManifest<'a> {
#[error("Can't find asset for name {name:?}")]
CantFindAssetByName { name: &'a Utf8Path },
#[error("Can't find preloaded entry")]
CantFindPreload,
#[error("Can't find asset for file {file:?}")]
CantFindAssetByFile { file: &'a Utf8Path },
#[error("Invalid file type")]
InvalidFileType,
}
/// Represents an entry which should be preloaded
/// Represents an entry which should be preloaded and included
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
struct Preload<'name> {
name: &'name Utf8Path,
pub struct Asset<'a> {
file_type: FileType,
name: &'a Utf8Path,
integrity: Option<&'a str>,
}
impl<'a> Preload<'a> {
/// Generate a `<link>` tag for this entry
fn link(&self, assets_base: &Utf8Path) -> String {
let href = assets_base.join(self.name);
impl<'a> Asset<'a> {
fn new(entry: &'a ManifestEntry) -> Result<Self, InvalidManifest<'a>> {
let name = &entry.file;
let integrity = entry.integrity.as_deref();
let file_type = FileType::from_name(name).ok_or(InvalidManifest::InvalidFileType)?;
Ok(Self {
file_type,
name,
integrity,
})
}
fn src(&self, assets_base: &Utf8Path) -> Utf8PathBuf {
assets_base.join(self.name)
}
/// Generate a `<link rel="preload">` tag to preload this entry
pub fn preload_tag(&self, assets_base: &Utf8Path) -> String {
let href = self.src(assets_base);
let integrity = self
.integrity
.map(|i| format!(r#"integrity="{i}" "#))
.unwrap_or_default();
match self.file_type {
FileType::Stylesheet => {
format!(r#"<link rel="preload" href="{href}" as="style" />"#)
format!(r#"<link rel="preload" href="{href}" as="style" crossorigin {integrity}/>"#)
}
FileType::Script => format!(
r#"<link rel="preload" href="{href}" as="script" crossorigin="anonymous" />"#
),
FileType::Script => {
format!(r#"<link rel="modulepreload" href="{href}" crossorigin {integrity}/>"#)
}
FileType::Woff | FileType::Woff2 => {
format!(r#"<link rel="preload" href="{href}" as="font" crossorigin {integrity}/>"#,)
}
}
}
/// Generate a `<link>` or `<script>` tag to include this entry
pub fn include_tag(&self, assets_base: &Utf8Path) -> Option<String> {
let src = self.src(assets_base);
let integrity = self
.integrity
.map(|i| format!(r#"integrity="{i}" "#))
.unwrap_or_default();
match self.file_type {
FileType::Stylesheet => Some(format!(
r#"<link rel="stylesheet" href="{src}" crossorigin {integrity}/>"#
)),
FileType::Script => Some(format!(
r#"<script type="module" src="{src}" crossorigin {integrity}></script>"#
)),
FileType::Woff | FileType::Woff2 => None,
}
}
}
impl Manifest {
/// Render an `index.html` page
/// Find all assets which should be loaded for a given entrypoint
///
/// # Errors
///
/// Returns an error if the manifest is invalid.
pub fn render(
&self,
assets_base: &Utf8Path,
config: &impl serde::Serialize,
) -> Result<String, InvalidManifest> {
let entrypoint = Utf8Path::new("index.html");
let entry = self.inner.get(entrypoint).ok_or(InvalidManifest::NoIndex)?;
// Find the items that should be pre-loaded
let preload = self.find_preload(entrypoint)?;
let head = preload
/// Returns an error if the entrypoint is invalid for this manifest
pub fn assets_for<'a>(
&'a self,
entrypoint: &'a Utf8Path,
) -> Result<BTreeSet<Asset<'a>>, InvalidManifest<'a>> {
let entry = self.lookup_by_name(entrypoint)?;
let main_asset = Asset::new(entry)?;
entry
.css
.iter()
.map(|p| p.link(assets_base))
.chain(entry.head(assets_base));
let html = template(head, config);
Ok(html)
.flatten()
.map(|name| self.lookup_by_file(name).and_then(Asset::new))
.chain(std::iter::once(Ok(main_asset)))
.collect()
}
/// Find entries to preload
/// Find all assets which should be preloaded for a given entrypoint
///
/// # Errors
///
/// Returns an error if the entrypoint is invalid for this manifest
pub fn preload_for<'a>(
&'a self,
entrypoint: &'a Utf8Path,
) -> Result<BTreeSet<Asset<'a>>, InvalidManifest<'a>> {
let entry = self.lookup_by_name(entrypoint)?;
self.find_preload(entry)
}
/// Lookup an entry in the manifest by its original name
fn lookup_by_name<'a>(
&self,
name: &'a Utf8Path,
) -> Result<&ManifestEntry, InvalidManifest<'a>> {
self.inner
.get(name)
.ok_or(InvalidManifest::CantFindAssetByName { name })
}
/// Lookup an entry in the manifest by its output name
fn lookup_by_file<'a>(
&self,
file: &'a Utf8Path,
) -> Result<&ManifestEntry, InvalidManifest<'a>> {
self.inner
.values()
.find(|e| e.file == file)
.ok_or(InvalidManifest::CantFindAssetByFile { file })
}
/// Recursively find all the assets that should be preloaded
fn find_preload<'a>(
&'a self,
entrypoint: &Utf8Path,
) -> Result<BTreeSet<Preload<'a>>, InvalidManifest> {
// TODO: we're preoading the whole tree. We should instead guess which component
// should be loaded based on the route.
entry: &'a ManifestEntry,
) -> Result<BTreeSet<Asset<'a>>, InvalidManifest<'a>> {
let mut entries = BTreeSet::new();
self.find_preload_rec(entrypoint, &mut entries)?;
self.find_preload_rec(entry, &mut entries)?;
Ok(entries)
}
fn find_preload_rec<'a>(
&'a self,
entrypoint: &Utf8Path,
entries: &mut BTreeSet<Preload<'a>>,
) -> Result<(), InvalidManifest> {
let entry = self
.inner
.get(entrypoint)
.ok_or(InvalidManifest::CantFindPreload)?;
let name = &entry.file;
let file_type = FileType::from_name(name).ok_or(InvalidManifest::InvalidFileType)?;
let preload = Preload { name, file_type };
let inserted = entries.insert(preload);
current_entry: &'a ManifestEntry,
entries: &mut BTreeSet<Asset<'a>>,
) -> Result<(), InvalidManifest<'a>> {
let asset = Asset::new(current_entry)?;
let inserted = entries.insert(asset);
// If we inserted the entry, we need to find its dependencies
if inserted {
if let Some(css) = &entry.css {
let file_type = FileType::Stylesheet;
for name in css {
let preload = Preload { name, file_type };
entries.insert(preload);
}
let css = current_entry.css.iter().flatten();
let assets = current_entry.assets.iter().flatten();
for name in css.chain(assets) {
let entry = self.lookup_by_file(name)?;
self.find_preload_rec(entry, entries)?;
}
if let Some(dynamic_imports) = &entry.dynamic_imports {
for import in dynamic_imports {
self.find_preload_rec(import, entries)?;
}
let dynamic_imports = current_entry.dynamic_imports.iter().flatten();
let imports = current_entry.imports.iter().flatten();
for import in dynamic_imports.chain(imports) {
let entry = self.lookup_by_name(import)?;
self.find_preload_rec(entry, entries)?;
}
}

View File

@ -27,3 +27,4 @@ rand = "0.8.5"
oauth2-types = { path = "../oauth2-types" }
mas-data-model = { path = "../data-model" }
mas-router = { path = "../router" }
mas-spa = { path = "../spa" }

View File

@ -220,6 +220,45 @@ impl TemplateContext for IndexContext {
}
}
/// Config used by the frontend app
#[derive(Serialize)]
pub struct AppConfig {
root: String,
}
impl Default for AppConfig {
fn default() -> Self {
Self {
root: "/account/".into(),
}
}
}
/// Context used by the `app.html` template
#[derive(Serialize, Default)]
pub struct AppContext {
app_config: AppConfig,
}
impl AppContext {
/// Constructs the context for the app page with the given app root
#[must_use]
pub fn with_app_root(root: String) -> Self {
Self {
app_config: AppConfig { root },
}
}
}
impl TemplateContext for AppContext {
fn sample(_now: chrono::DateTime<Utc>, _rng: &mut impl Rng) -> Vec<Self>
where
Self: Sized,
{
vec![Self::default()]
}
}
/// Fields of the login form
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Hash, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
@ -268,6 +307,9 @@ pub enum PostAuthContextInner {
/// The link
link: Box<UpstreamOAuthLink>,
},
/// Go to the account management page
ManageAccount,
}
/// Context used in login and reauth screens, for the post-auth action to do
@ -580,58 +622,6 @@ where {
}
}
/// Context used by the `account/index.html` template
#[derive(Serialize)]
pub struct AccountContext {
active_sessions: usize,
emails: Vec<UserEmail>,
}
impl AccountContext {
/// Constructs a context for the "my account" page
#[must_use]
pub fn new(active_sessions: usize, emails: Vec<UserEmail>) -> Self {
Self {
active_sessions,
emails,
}
}
}
impl TemplateContext for AccountContext {
fn sample(now: chrono::DateTime<Utc>, rng: &mut impl Rng) -> Vec<Self>
where
Self: Sized,
{
let emails: Vec<UserEmail> = UserEmail::samples(now, rng);
vec![Self::new(5, emails)]
}
}
/// Context used by the `account/emails.html` template
#[derive(Serialize)]
pub struct AccountEmailsContext {
emails: Vec<UserEmail>,
}
impl AccountEmailsContext {
/// Constructs a context for the email management page
#[must_use]
pub fn new(emails: Vec<UserEmail>) -> Self {
Self { emails }
}
}
impl TemplateContext for AccountEmailsContext {
fn sample(now: chrono::DateTime<Utc>, rng: &mut impl Rng) -> Vec<Self>
where
Self: Sized,
{
let emails: Vec<UserEmail> = UserEmail::samples(now, rng);
vec![Self::new(emails)]
}
}
/// Context used by the `emails/verification.{txt,html,subject}` templates
#[derive(Serialize)]
pub struct EmailVerificationContext {

View File

@ -16,18 +16,26 @@
use std::{collections::HashMap, str::FromStr};
use mas_router::{Route, UrlBuilder};
use camino::Utf8Path;
use mas_router::UrlBuilder;
use mas_spa::ViteManifest;
use tera::{helpers::tests::number_args_allowed, Tera, Value};
use url::Url;
pub fn register(tera: &mut Tera, url_builder: UrlBuilder) {
pub fn register(tera: &mut Tera, url_builder: UrlBuilder, vite_manifest: ViteManifest) {
tera.register_tester("empty", self::tester_empty);
tera.register_filter("to_params", filter_to_params);
tera.register_filter("safe_get", filter_safe_get);
tera.register_function("add_params_to_url", function_add_params_to_url);
tera.register_function("merge", function_merge);
tera.register_function("dict", function_dict);
tera.register_function("static_asset", make_static_asset(url_builder));
tera.register_function(
"include_asset",
IncludeAsset {
url_builder,
vite_manifest,
},
);
}
fn tester_empty(value: Option<&Value>, params: &[Value]) -> Result<bool, tera::Error> {
@ -145,25 +153,53 @@ fn function_dict(params: &HashMap<String, Value>) -> Result<Value, tera::Error>
Ok(Value::Object(ret))
}
fn make_static_asset(url_builder: UrlBuilder) -> impl tera::Function {
Box::new(
move |args: &HashMap<String, Value>| -> Result<Value, tera::Error> {
if let Some(path) = args.get("path").and_then(Value::as_str) {
let absolute = args
.get("absolute")
.and_then(Value::as_bool)
.unwrap_or(false);
let path = path.to_owned();
let url = if absolute {
url_builder.static_asset(path).into()
} else {
let destination = mas_router::StaticAsset::new(path);
destination.relative_url().into_owned()
};
Ok(Value::String(url))
} else {
Err(tera::Error::msg("Invalid parameter 'path'"))
}
},
)
struct IncludeAsset {
url_builder: UrlBuilder,
vite_manifest: ViteManifest,
}
impl tera::Function for IncludeAsset {
fn call(&self, args: &HashMap<String, Value>) -> tera::Result<Value> {
let path = args.get("path").ok_or(tera::Error::msg(
"Function `include_asset` was missing parameter `path`",
))?;
let path: &Utf8Path = path
.as_str()
.ok_or_else(|| {
tera::Error::msg(
"Function `include_asset` received an incorrect type for arg `path`",
)
})?
.into();
let assets = self.vite_manifest.assets_for(path).map_err(|e| {
tera::Error::chain(
"Invalid assets manifest while calling function `include_asset`",
e.to_string(),
)
})?;
let preloads = self.vite_manifest.preload_for(path).map_err(|e| {
tera::Error::chain(
"Invalid assets manifest while calling function `include_asset`",
e.to_string(),
)
})?;
let tags: Vec<String> = preloads
.iter()
.map(|asset| asset.preload_tag(self.url_builder.assets_base().into()))
.chain(
assets
.iter()
.filter_map(|asset| asset.include_tag(self.url_builder.assets_base().into())),
)
.collect();
Ok(Value::String(tags.join("\n")))
}
fn is_safe(&self) -> bool {
true
}
}

View File

@ -29,6 +29,7 @@ use std::{collections::HashSet, string::ToString, sync::Arc};
use anyhow::Context as _;
use camino::{Utf8Path, Utf8PathBuf};
use mas_router::UrlBuilder;
use mas_spa::ViteManifest;
use rand::Rng;
use serde::Serialize;
pub use tera::escape_html;
@ -46,12 +47,12 @@ mod macros;
pub use self::{
context::{
AccountContext, AccountEmailsContext, CompatSsoContext, ConsentContext, EmailAddContext,
EmailVerificationContext, EmailVerificationPageContext, EmptyContext, ErrorContext,
FormPostContext, IndexContext, LoginContext, LoginFormField, PolicyViolationContext,
PostAuthContext, PostAuthContextInner, ReauthContext, ReauthFormField, RegisterContext,
RegisterFormField, TemplateContext, UpstreamExistingLinkContext, UpstreamRegister,
UpstreamSuggestLink, WithCsrf, WithOptionalSession, WithSession,
AppContext, CompatSsoContext, ConsentContext, EmailAddContext, EmailVerificationContext,
EmailVerificationPageContext, EmptyContext, ErrorContext, FormPostContext, IndexContext,
LoginContext, LoginFormField, PolicyViolationContext, PostAuthContext,
PostAuthContextInner, ReauthContext, ReauthFormField, RegisterContext, RegisterFormField,
TemplateContext, UpstreamExistingLinkContext, UpstreamRegister, UpstreamSuggestLink,
WithCsrf, WithOptionalSession, WithSession,
},
forms::{FieldError, FormError, FormField, FormState, ToFormState},
};
@ -61,6 +62,7 @@ pub use self::{
pub struct Templates {
tera: Arc<RwLock<Tera>>,
url_builder: UrlBuilder,
vite_manifest_path: Utf8PathBuf,
path: Utf8PathBuf,
}
@ -71,6 +73,14 @@ pub enum TemplateLoadingError {
#[error(transparent)]
IO(#[from] std::io::Error),
/// Failed to read the assets manifest
#[error("failed to read the assets manifest")]
ViteManifestIO(#[source] std::io::Error),
/// Failed to deserialize the assets manifest
#[error("invalid assets manifest")]
ViteManifest(#[from] serde_json::Error),
/// Some templates failed to compile
#[error("could not load and compile some templates")]
Compile(#[from] TeraError),
@ -106,19 +116,34 @@ impl Templates {
pub async fn load(
path: Utf8PathBuf,
url_builder: UrlBuilder,
vite_manifest_path: Utf8PathBuf,
) -> Result<Self, TemplateLoadingError> {
let tera = Self::load_(&path, url_builder.clone()).await?;
let tera = Self::load_(&path, url_builder.clone(), &vite_manifest_path).await?;
Ok(Self {
tera: Arc::new(RwLock::new(tera)),
path,
url_builder,
vite_manifest_path,
})
}
async fn load_(path: &Utf8Path, url_builder: UrlBuilder) -> Result<Tera, TemplateLoadingError> {
async fn load_(
path: &Utf8Path,
url_builder: UrlBuilder,
vite_manifest_path: &Utf8Path,
) -> Result<Tera, TemplateLoadingError> {
let path = path.to_owned();
let span = tracing::Span::current();
// Read the assets manifest from disk
let vite_manifest = tokio::fs::read(vite_manifest_path)
.await
.map_err(TemplateLoadingError::ViteManifestIO)?;
// Parse it
let vite_manifest: ViteManifest =
serde_json::from_slice(&vite_manifest).map_err(TemplateLoadingError::ViteManifest)?;
// This uses blocking I/Os, do that in a blocking task
let mut tera = tokio::task::spawn_blocking(move || {
span.in_scope(move || {
@ -131,7 +156,7 @@ impl Templates {
})
.await??;
self::functions::register(&mut tera, url_builder);
self::functions::register(&mut tera, url_builder, vite_manifest);
let loaded: HashSet<_> = tera.get_template_names().collect();
let needed: HashSet<_> = TEMPLATES.into_iter().collect();
@ -156,7 +181,12 @@ impl Templates {
)]
pub async fn reload(&self) -> Result<(), TemplateLoadingError> {
// Prepare the new Tera instance
let new_tera = Self::load_(&self.path, self.url_builder.clone()).await?;
let new_tera = Self::load_(
&self.path,
self.url_builder.clone(),
&self.vite_manifest_path,
)
.await?;
// Swap it
*self.tera.write().await = new_tera;
@ -192,6 +222,9 @@ pub enum TemplateError {
}
register_templates! {
/// Render the frontend app
pub fn render_app(AppContext) { "app.html" }
/// Render the login page
pub fn render_login(WithCsrf<LoginContext>) { "pages/login.html" }
@ -210,15 +243,9 @@ register_templates! {
/// Render the home page
pub fn render_index(WithCsrf<WithOptionalSession<IndexContext>>) { "pages/index.html" }
/// Render the account management page
pub fn render_account_index(WithCsrf<WithSession<AccountContext>>) { "pages/account/index.html" }
/// Render the password change page
pub fn render_account_password(WithCsrf<WithSession<EmptyContext>>) { "pages/account/password.html" }
/// Render the emails management
pub fn render_account_emails(WithCsrf<WithSession<AccountEmailsContext>>) { "pages/account/emails/index.html" }
/// Render the email verification page
pub fn render_account_verify_email(WithCsrf<WithSession<EmailVerificationPageContext>>) { "pages/account/emails/verify.html" }
@ -267,15 +294,14 @@ impl Templates {
now: chrono::DateTime<chrono::Utc>,
rng: &mut impl Rng,
) -> anyhow::Result<()> {
check::render_app(self, now, rng).await?;
check::render_login(self, now, rng).await?;
check::render_register(self, now, rng).await?;
check::render_consent(self, now, rng).await?;
check::render_policy_violation(self, now, rng).await?;
check::render_sso_login(self, now, rng).await?;
check::render_index(self, now, rng).await?;
check::render_account_index(self, now, rng).await?;
check::render_account_password(self, now, rng).await?;
check::render_account_emails(self, now, rng).await?;
check::render_account_add_email(self, now, rng).await?;
check::render_account_verify_email(self, now, rng).await?;
check::render_reauth(self, now, rng).await?;
@ -305,8 +331,12 @@ mod tests {
let mut rng = rand::thread_rng();
let path = Utf8Path::new(env!("CARGO_MANIFEST_DIR")).join("../../templates/");
let url_builder = UrlBuilder::new("https://example.com/".parse().unwrap(), None);
let templates = Templates::load(path, url_builder).await.unwrap();
let url_builder = UrlBuilder::new("https://example.com/".parse().unwrap(), None, None);
let vite_manifest_path =
Utf8Path::new(env!("CARGO_MANIFEST_DIR")).join("../../frontend/dist/manifest.json");
let templates = Templates::load(path, url_builder, vite_manifest_path)
.await
.unwrap();
templates.check_render(now, &mut rng).await.unwrap();
}
}