diff --git a/crates/cli/src/main.rs b/crates/cli/src/main.rs index 7da9154c..a2a93a60 100644 --- a/crates/cli/src/main.rs +++ b/crates/cli/src/main.rs @@ -27,12 +27,14 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilte use self::{ config::ConfigCommand, database::DatabaseCommand, manage::ManageCommand, server::ServerCommand, + templates::TemplatesCommand, }; mod config; mod database; mod manage; mod server; +mod templates; #[derive(Clap, Debug)] enum Subcommand { @@ -47,6 +49,9 @@ enum Subcommand { /// Manage the instance Manage(ManageCommand), + + /// Templates-related commands + Templates(TemplatesCommand), } #[derive(Clap, Debug)] @@ -67,6 +72,7 @@ impl RootCommand { Some(S::Database(c)) => c.run(self).await, Some(S::Server(c)) => c.run(self).await, Some(S::Manage(c)) => c.run(self).await, + Some(S::Templates(c)) => c.run(self).await, None => ServerCommand::default().run(self).await, } } diff --git a/crates/cli/src/server.rs b/crates/cli/src/server.rs index f23fdad9..aef413bb 100644 --- a/crates/cli/src/server.rs +++ b/crates/cli/src/server.rs @@ -49,7 +49,8 @@ impl ServerCommand { let pool = config.database.connect().await?; // Load and compile the templates - let templates = Templates::load().context("could not load templates")?; + // TODO: custom template path from the config + let templates = Templates::load(None, true).context("could not load templates")?; // Start the server let root = mas_core::handlers::root(&pool, &templates, &config); diff --git a/crates/cli/src/templates.rs b/crates/cli/src/templates.rs new file mode 100644 index 00000000..c9a3e06e --- /dev/null +++ b/crates/cli/src/templates.rs @@ -0,0 +1,64 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::path::PathBuf; + +use clap::Clap; +use mas_core::templates::Templates; + +use super::RootCommand; + +#[derive(Clap, Debug)] +pub(super) struct TemplatesCommand { + #[clap(subcommand)] + subcommand: TemplatesSubcommand, +} + +#[derive(Clap, Debug)] +enum TemplatesSubcommand { + /// Save the builtin templates to a folder + Save { + /// Where the templates should be saved + path: PathBuf, + + /// Overwrite existing template files + #[clap(long)] + overwrite: bool, + }, + + /// Check for template validity at given path. + Check { + /// Path where the templates are + path: String, + }, +} + +impl TemplatesCommand { + pub async fn run(&self, _root: &RootCommand) -> anyhow::Result<()> { + use TemplatesSubcommand as SC; + match &self.subcommand { + SC::Save { path, overwrite } => { + Templates::save(path, *overwrite).await?; + + Ok(()) + } + + SC::Check { path } => { + Templates::load(Some(path.clone()), false)?; + + Ok(()) + } + } + } +} diff --git a/crates/core/src/templates.rs b/crates/core/src/templates.rs deleted file mode 100644 index ff439ae9..00000000 --- a/crates/core/src/templates.rs +++ /dev/null @@ -1,305 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::{collections::HashSet, string::ToString, sync::Arc}; - -use oauth2_types::errors::OAuth2Error; -use serde::Serialize; -use tera::{Context, Error as TeraError, Tera}; -use thiserror::Error; -use tracing::{debug, info}; -use url::Url; -use warp::reject::Reject; - -use crate::{errors::ErroredForm, filters::CsrfToken, storage::SessionInfo}; - -#[derive(Clone)] -pub struct Templates(Arc); - -#[derive(Error, Debug)] -pub enum TemplateLoadingError { - #[error("could not load and compile some templates")] - Compile(#[from] TeraError), - - #[error("missing templates {missing:?}")] - MissingTemplates { - missing: HashSet, - loaded: HashSet, - }, -} - -impl Templates { - /// Load the templates and check all needed templates are properly loaded - pub fn load() -> Result { - let path = format!("{}/templates/**/*.{{html,txt}}", env!("CARGO_MANIFEST_DIR")); - info!(%path, "Loading templates"); - let tera = Tera::new(&path)?; - - let loaded: HashSet<_> = tera.get_template_names().collect(); - let needed: HashSet<_> = std::array::IntoIter::new(TEMPLATES).collect(); - debug!(?loaded, ?needed, "Templates loaded"); - let missing: HashSet<_> = needed.difference(&loaded).collect(); - - if missing.is_empty() { - Ok(Self(Arc::new(tera))) - } else { - let missing = missing.into_iter().map(ToString::to_string).collect(); - let loaded = loaded.into_iter().map(ToString::to_string).collect(); - Err(TemplateLoadingError::MissingTemplates { missing, loaded }) - } - } -} - -#[derive(Error, Debug)] -pub enum TemplateError { - #[error("could not prepare context for template {template:?}")] - Context { - template: &'static str, - #[source] - source: TeraError, - }, - - #[error("could not render template {template:?}")] - Render { - template: &'static str, - #[source] - source: TeraError, - }, -} - -impl Reject for TemplateError {} - -/// Count the number of tokens. Used to have a fixed-sized array for the -/// templates list. -macro_rules! count { - () => (0_usize); - ( $x:tt $($xs:tt)* ) => (1_usize + count!($($xs)*)); -} - -/// Macro that helps generating helper function that renders a specific template -/// with a strongly-typed context. It also register the template in a static -/// array to help detecting missing templates at startup time. -/// -/// The syntax looks almost like a function to confuse syntax highlighter as -/// little as possible. -macro_rules! register_templates { - { - $( - // Match any attribute on the function, such as #[doc], #[allow(dead_code)], etc. - $( #[ $attr:meta ] )* - // The function name - pub fn $name:ident - // Optional list of generics. Taken from - // https://newbedev.com/rust-macro-accepting-type-with-generic-parameters - $(< $( $lt:tt $( : $clt:tt $(+ $dlt:tt )* )? ),+ >)? - // Type of context taken by the template - ( $param:ty ) - { - // The name of the template file - $template:expr - } - )* - } => { - /// List of registered templates - static TEMPLATES: [&'static str; count!( $( $template )* )] = [ $( $template ),* ]; - - impl Templates { - $( - $(#[$attr])? - pub fn $name - $(< $( $lt $( : $clt $(+ $dlt )* )? ),+ >)? - (&self, context: &$param) - -> Result { - let ctx = Context::from_serialize(context) - .map_err(|source| TemplateError::Context { template: $template, source })?; - - self.0.render($template, &ctx) - .map_err(|source| TemplateError::Render { template: $template, source }) - } - )* - } - }; -} - -register_templates! { - /// Render the login page - pub fn render_login(WithCsrf) { "login.html" } - - /// Render the registration page - pub fn render_register(WithCsrf<()>) { "register.html" } - - /// Render the home page - pub fn render_index(WithCsrf>) { "index.html" } - - /// Render the re-authentication form - pub fn render_reauth(WithCsrf>) { "reauth.html" } - - /// Render the form used by the form_post response mode - pub fn render_form_post(FormPostContext) { "form_post.html" } - - /// Render the HTML error page - pub fn render_error(ErrorContext) { "error.html" } -} - -/// Helper trait to construct context wrappers -pub trait TemplateContext: Sized { - fn with_session(self, current_session: SessionInfo) -> WithSession { - WithSession { - current_session, - inner: self, - } - } - - fn maybe_with_session(self, current_session: Option) -> WithOptionalSession { - WithOptionalSession { - current_session, - inner: self, - } - } - - fn with_csrf(self, token: &CsrfToken) -> WithCsrf { - WithCsrf { - csrf_token: token.form_value(), - inner: self, - } - } -} - -impl TemplateContext for T {} - -/// Context with a CSRF token in it -#[derive(Serialize)] -pub struct WithCsrf { - csrf_token: String, - - #[serde(flatten)] - inner: T, -} - -/// Context with a user session in it -#[derive(Serialize)] -pub struct WithSession { - current_session: SessionInfo, - - #[serde(flatten)] - inner: T, -} - -/// Context with an optional user session in it -#[derive(Serialize)] -pub struct WithOptionalSession { - current_session: Option, - - #[serde(flatten)] - inner: T, -} - -// Context used by the `index.html` template -#[derive(Serialize)] -pub struct IndexContext { - discovery_url: Url, -} - -impl IndexContext { - #[must_use] - pub fn new(discovery_url: Url) -> Self { - Self { discovery_url } - } -} - -#[derive(Serialize, Debug, Clone, Copy, Hash, PartialEq, Eq)] -#[serde(rename_all = "kebab-case")] -pub enum LoginFormField { - Username, - Password, -} - -#[derive(Serialize)] -pub struct LoginContext { - form: ErroredForm, -} - -impl LoginContext { - #[must_use] - pub fn with_form_error(form: ErroredForm) -> Self { - Self { form } - } -} - -impl Default for LoginContext { - fn default() -> Self { - Self { - form: ErroredForm::new(), - } - } -} - -/// Context used by the `form_post.html` template -#[derive(Serialize)] -pub struct FormPostContext { - redirect_uri: Url, - params: T, -} - -impl FormPostContext { - pub fn new(redirect_uri: Url, params: T) -> Self { - Self { - redirect_uri, - params, - } - } -} - -#[derive(Default, Serialize)] -pub struct ErrorContext { - code: Option<&'static str>, - description: Option, - details: Option, -} - -impl ErrorContext { - #[must_use] - pub fn new() -> Self { - Self::default() - } - - #[must_use] - pub fn with_code(mut self, code: &'static str) -> Self { - self.code = Some(code); - self - } - - #[must_use] - pub fn with_description(mut self, description: String) -> Self { - self.description = Some(description); - self - } - - #[allow(dead_code)] - #[must_use] - pub fn with_details(mut self, details: String) -> Self { - self.details = Some(details); - self - } -} - -impl From> for ErrorContext { - fn from(err: Box) -> Self { - let mut ctx = ErrorContext::new().with_code(err.error()); - if let Some(desc) = err.description() { - ctx = ctx.with_description(desc); - } - ctx - } -} diff --git a/crates/core/src/templates/context.rs b/crates/core/src/templates/context.rs new file mode 100644 index 00000000..ce38bb9e --- /dev/null +++ b/crates/core/src/templates/context.rs @@ -0,0 +1,177 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use oauth2_types::errors::OAuth2Error; +use serde::Serialize; +use url::Url; + +use crate::{errors::ErroredForm, filters::CsrfToken, storage::SessionInfo}; + +/// Helper trait to construct context wrappers +pub trait TemplateContext: Sized { + fn with_session(self, current_session: SessionInfo) -> WithSession { + WithSession { + current_session, + inner: self, + } + } + + fn maybe_with_session(self, current_session: Option) -> WithOptionalSession { + WithOptionalSession { + current_session, + inner: self, + } + } + + fn with_csrf(self, token: &CsrfToken) -> WithCsrf { + WithCsrf { + csrf_token: token.form_value(), + inner: self, + } + } +} + +impl TemplateContext for () {} +impl TemplateContext for IndexContext {} +impl TemplateContext for LoginContext {} +impl TemplateContext for FormPostContext {} +impl TemplateContext for WithSession {} +impl TemplateContext for WithOptionalSession {} +impl TemplateContext for WithCsrf {} + +/// Context with a CSRF token in it +#[derive(Serialize)] +pub struct WithCsrf { + csrf_token: String, + + #[serde(flatten)] + inner: T, +} + +/// Context with a user session in it +#[derive(Serialize)] +pub struct WithSession { + current_session: SessionInfo, + + #[serde(flatten)] + inner: T, +} + +/// Context with an optional user session in it +#[derive(Serialize)] +pub struct WithOptionalSession { + current_session: Option, + + #[serde(flatten)] + inner: T, +} + +// Context used by the `index.html` template +#[derive(Serialize)] +pub struct IndexContext { + discovery_url: Url, +} + +impl IndexContext { + #[must_use] + pub fn new(discovery_url: Url) -> Self { + Self { discovery_url } + } +} + +#[derive(Serialize, Debug, Clone, Copy, Hash, PartialEq, Eq)] +#[serde(rename_all = "kebab-case")] +pub enum LoginFormField { + Username, + Password, +} + +#[derive(Serialize)] +pub struct LoginContext { + form: ErroredForm, +} + +impl LoginContext { + #[must_use] + pub fn with_form_error(form: ErroredForm) -> Self { + Self { form } + } +} + +impl Default for LoginContext { + fn default() -> Self { + Self { + form: ErroredForm::new(), + } + } +} + +/// Context used by the `form_post.html` template +#[derive(Serialize)] +pub struct FormPostContext { + redirect_uri: Url, + params: T, +} + +impl FormPostContext { + pub fn new(redirect_uri: Url, params: T) -> Self { + Self { + redirect_uri, + params, + } + } +} + +#[derive(Default, Serialize)] +pub struct ErrorContext { + code: Option<&'static str>, + description: Option, + details: Option, +} + +impl ErrorContext { + #[must_use] + pub fn new() -> Self { + Self::default() + } + + #[must_use] + pub fn with_code(mut self, code: &'static str) -> Self { + self.code = Some(code); + self + } + + #[must_use] + pub fn with_description(mut self, description: String) -> Self { + self.description = Some(description); + self + } + + #[allow(dead_code)] + #[must_use] + pub fn with_details(mut self, details: String) -> Self { + self.details = Some(details); + self + } +} + +impl From> for ErrorContext { + fn from(err: Box) -> Self { + let mut ctx = ErrorContext::new().with_code(err.error()); + if let Some(desc) = err.description() { + ctx = ctx.with_description(desc); + } + ctx + } +} diff --git a/crates/core/src/templates/macros.rs b/crates/core/src/templates/macros.rs new file mode 100644 index 00000000..0905ca53 --- /dev/null +++ b/crates/core/src/templates/macros.rs @@ -0,0 +1,77 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/// Count the number of tokens. Used to have a fixed-sized array for the +/// templates list. +macro_rules! count { + () => (0_usize); + ( $x:tt $($xs:tt)* ) => (1_usize + count!($($xs)*)); +} + +/// Macro that helps generating helper function that renders a specific template +/// with a strongly-typed context. It also register the template in a static +/// array to help detecting missing templates at startup time. +/// +/// The syntax looks almost like a function to confuse syntax highlighter as +/// little as possible. +#[macro_export] +macro_rules! register_templates { + { + $( + extra = { $( $extra_template:expr ),* }; + )? + + $( + // Match any attribute on the function, such as #[doc], #[allow(dead_code)], etc. + $( #[ $attr:meta ] )* + // The function name + pub fn $name:ident + // Optional list of generics. Taken from + // https://newbedev.com/rust-macro-accepting-type-with-generic-parameters + $(< $( $lt:tt $( : $clt:tt $(+ $dlt:tt )* )? ),+ >)? + // Type of context taken by the template + ( $param:ty ) + { + // The name of the template file + $template:expr + } + )* + } => { + /// List of registered templates + static TEMPLATES: [(&'static str, &'static str); count!( $( $template )* )] = [ + $( ($template, include_str!(concat!("res/", $template))) ),* + ]; + + /// List of extra templates used by other templates + static EXTRA_TEMPLATES: [(&'static str, &'static str); count!( $( $( $extra_template )* )? )] = [ + $( $( ($extra_template, include_str!(concat!("res/", $extra_template))) ),* )? + ]; + + impl Templates { + $( + $(#[$attr])? + pub fn $name + $(< $( $lt $( : $clt $(+ $dlt )* )? ),+ >)? + (&self, context: &$param) + -> Result { + let ctx = Context::from_serialize(context) + .map_err(|source| TemplateError::Context { template: $template, source })?; + + self.0.render($template, &ctx) + .map_err(|source| TemplateError::Render { template: $template, source }) + } + )* + } + }; +} diff --git a/crates/core/src/templates/mod.rs b/crates/core/src/templates/mod.rs new file mode 100644 index 00000000..cb35821b --- /dev/null +++ b/crates/core/src/templates/mod.rs @@ -0,0 +1,179 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{collections::HashSet, io::Cursor, path::Path, string::ToString, sync::Arc}; + +use anyhow::Context as _; +use serde::Serialize; +use tera::{Context, Error as TeraError, Tera}; +use thiserror::Error; +use tokio::{fs::OpenOptions, io::AsyncWriteExt}; +use tracing::{debug, info, warn}; +use warp::reject::Reject; + +mod context; +#[macro_use] +mod macros; + +pub use self::context::{ + ErrorContext, FormPostContext, IndexContext, LoginContext, LoginFormField, TemplateContext, + WithCsrf, WithOptionalSession, WithSession, +}; + +#[derive(Clone)] +pub struct Templates(Arc); + +#[derive(Error, Debug)] +pub enum TemplateLoadingError { + #[error("could not load and compile some templates")] + Compile(#[from] TeraError), + + #[error("missing templates {missing:?}")] + MissingTemplates { + missing: HashSet, + loaded: HashSet, + }, +} + +impl Templates { + /// Load the templates and check all needed templates are properly loaded + /// + /// # Arguments + /// + /// * `path` - An optional path to where templates should be loaded + /// * `builtin` - Set to `true` to load the builtin templates as well + pub fn load(path: Option, builtin: bool) -> Result { + let tera = { + let mut tera = Tera::default(); + + if builtin { + info!("Loading builtin templates"); + + for (name, source) in EXTRA_TEMPLATES { + tera.add_raw_template(name, source)?; + } + + for (name, source) in TEMPLATES { + tera.add_raw_template(name, source)?; + } + } + + if let Some(path) = path { + let path = format!("{}/**/*.{{html,txt}}", path); + info!(%path, "Loading templates from filesystem"); + tera.extend(&Tera::parse(&path)?)?; + } + + tera.build_inheritance_chains()?; + tera.check_macro_files()?; + + tera + }; + + let loaded: HashSet<_> = tera.get_template_names().collect(); + let needed: HashSet<_> = std::array::IntoIter::new(TEMPLATES) + .map(|(name, _)| name) + .collect(); + debug!(?loaded, ?needed, "Templates loaded"); + let missing: HashSet<_> = needed.difference(&loaded).collect(); + + if missing.is_empty() { + Ok(Self(Arc::new(tera))) + } else { + let missing = missing.into_iter().map(ToString::to_string).collect(); + let loaded = loaded.into_iter().map(ToString::to_string).collect(); + Err(TemplateLoadingError::MissingTemplates { missing, loaded }) + } + } + + /// Save the builtin templates to a folder + pub async fn save(path: &Path, overwrite: bool) -> anyhow::Result<()> { + tokio::fs::create_dir_all(&path) + .await + .context("could not create destination folder")?; + + let templates = std::array::IntoIter::new(TEMPLATES).chain(EXTRA_TEMPLATES); + + let mut options = OpenOptions::new(); + if overwrite { + options.create(true).truncate(true).write(true); + } else { + // With the `create_new` flag, `open` fails with an `AlreadyExists` error to + // avoid overwriting + options.create_new(true).write(true); + }; + + for (name, source) in templates { + let path = path.join(name); + + let mut file = match options.open(&path).await { + Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => { + // Not overwriting a template is a soft error + warn!(?path, "Not overwriting template"); + continue; + } + x => x.context(format!("could not open file {:?}", path))?, + }; + + let mut buffer = Cursor::new(source); + file.write_all_buf(&mut buffer) + .await + .context(format!("could not write file {:?}", path))?; + info!(?path, "Wrote template"); + } + + Ok(()) + } +} + +#[derive(Error, Debug)] +pub enum TemplateError { + #[error("could not prepare context for template {template:?}")] + Context { + template: &'static str, + #[source] + source: TeraError, + }, + + #[error("could not render template {template:?}")] + Render { + template: &'static str, + #[source] + source: TeraError, + }, +} + +impl Reject for TemplateError {} + +register_templates! { + extra = { "base.html" }; + + /// Render the login page + pub fn render_login(WithCsrf) { "login.html" } + + /// Render the registration page + pub fn render_register(WithCsrf<()>) { "register.html" } + + /// Render the home page + pub fn render_index(WithCsrf>) { "index.html" } + + /// Render the re-authentication form + pub fn render_reauth(WithCsrf>) { "reauth.html" } + + /// Render the form used by the form_post response mode + pub fn render_form_post(FormPostContext) { "form_post.html" } + + /// Render the HTML error page + pub fn render_error(ErrorContext) { "error.html" } +} diff --git a/crates/core/templates/base.html b/crates/core/src/templates/res/base.html similarity index 100% rename from crates/core/templates/base.html rename to crates/core/src/templates/res/base.html diff --git a/crates/core/templates/error.html b/crates/core/src/templates/res/error.html similarity index 100% rename from crates/core/templates/error.html rename to crates/core/src/templates/res/error.html diff --git a/crates/core/templates/error.txt b/crates/core/src/templates/res/error.txt similarity index 100% rename from crates/core/templates/error.txt rename to crates/core/src/templates/res/error.txt diff --git a/crates/core/templates/form_post.html b/crates/core/src/templates/res/form_post.html similarity index 100% rename from crates/core/templates/form_post.html rename to crates/core/src/templates/res/form_post.html diff --git a/crates/core/templates/index.html b/crates/core/src/templates/res/index.html similarity index 100% rename from crates/core/templates/index.html rename to crates/core/src/templates/res/index.html diff --git a/crates/core/templates/login.html b/crates/core/src/templates/res/login.html similarity index 100% rename from crates/core/templates/login.html rename to crates/core/src/templates/res/login.html diff --git a/crates/core/templates/reauth.html b/crates/core/src/templates/res/reauth.html similarity index 100% rename from crates/core/templates/reauth.html rename to crates/core/src/templates/res/reauth.html diff --git a/crates/core/templates/register.html b/crates/core/src/templates/res/register.html similarity index 100% rename from crates/core/templates/register.html rename to crates/core/src/templates/res/register.html