1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Embed templates in binary & add command to export them

This commit is contained in:
Quentin Gliech
2021-09-16 23:39:07 +02:00
parent e44197a2cc
commit 76c69485e9
15 changed files with 505 additions and 306 deletions

View File

@ -27,12 +27,14 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilte
use self::{
config::ConfigCommand, database::DatabaseCommand, manage::ManageCommand, server::ServerCommand,
templates::TemplatesCommand,
};
mod config;
mod database;
mod manage;
mod server;
mod templates;
#[derive(Clap, Debug)]
enum Subcommand {
@ -47,6 +49,9 @@ enum Subcommand {
/// Manage the instance
Manage(ManageCommand),
/// Templates-related commands
Templates(TemplatesCommand),
}
#[derive(Clap, Debug)]
@ -67,6 +72,7 @@ impl RootCommand {
Some(S::Database(c)) => c.run(self).await,
Some(S::Server(c)) => c.run(self).await,
Some(S::Manage(c)) => c.run(self).await,
Some(S::Templates(c)) => c.run(self).await,
None => ServerCommand::default().run(self).await,
}
}

View File

@ -49,7 +49,8 @@ impl ServerCommand {
let pool = config.database.connect().await?;
// Load and compile the templates
let templates = Templates::load().context("could not load templates")?;
// TODO: custom template path from the config
let templates = Templates::load(None, true).context("could not load templates")?;
// Start the server
let root = mas_core::handlers::root(&pool, &templates, &config);

View File

@ -0,0 +1,64 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::path::PathBuf;
use clap::Clap;
use mas_core::templates::Templates;
use super::RootCommand;
#[derive(Clap, Debug)]
pub(super) struct TemplatesCommand {
#[clap(subcommand)]
subcommand: TemplatesSubcommand,
}
#[derive(Clap, Debug)]
enum TemplatesSubcommand {
/// Save the builtin templates to a folder
Save {
/// Where the templates should be saved
path: PathBuf,
/// Overwrite existing template files
#[clap(long)]
overwrite: bool,
},
/// Check for template validity at given path.
Check {
/// Path where the templates are
path: String,
},
}
impl TemplatesCommand {
pub async fn run(&self, _root: &RootCommand) -> anyhow::Result<()> {
use TemplatesSubcommand as SC;
match &self.subcommand {
SC::Save { path, overwrite } => {
Templates::save(path, *overwrite).await?;
Ok(())
}
SC::Check { path } => {
Templates::load(Some(path.clone()), false)?;
Ok(())
}
}
}
}

View File

@ -1,305 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{collections::HashSet, string::ToString, sync::Arc};
use oauth2_types::errors::OAuth2Error;
use serde::Serialize;
use tera::{Context, Error as TeraError, Tera};
use thiserror::Error;
use tracing::{debug, info};
use url::Url;
use warp::reject::Reject;
use crate::{errors::ErroredForm, filters::CsrfToken, storage::SessionInfo};
#[derive(Clone)]
pub struct Templates(Arc<Tera>);
#[derive(Error, Debug)]
pub enum TemplateLoadingError {
#[error("could not load and compile some templates")]
Compile(#[from] TeraError),
#[error("missing templates {missing:?}")]
MissingTemplates {
missing: HashSet<String>,
loaded: HashSet<String>,
},
}
impl Templates {
/// Load the templates and check all needed templates are properly loaded
pub fn load() -> Result<Self, TemplateLoadingError> {
let path = format!("{}/templates/**/*.{{html,txt}}", env!("CARGO_MANIFEST_DIR"));
info!(%path, "Loading templates");
let tera = Tera::new(&path)?;
let loaded: HashSet<_> = tera.get_template_names().collect();
let needed: HashSet<_> = std::array::IntoIter::new(TEMPLATES).collect();
debug!(?loaded, ?needed, "Templates loaded");
let missing: HashSet<_> = needed.difference(&loaded).collect();
if missing.is_empty() {
Ok(Self(Arc::new(tera)))
} else {
let missing = missing.into_iter().map(ToString::to_string).collect();
let loaded = loaded.into_iter().map(ToString::to_string).collect();
Err(TemplateLoadingError::MissingTemplates { missing, loaded })
}
}
}
#[derive(Error, Debug)]
pub enum TemplateError {
#[error("could not prepare context for template {template:?}")]
Context {
template: &'static str,
#[source]
source: TeraError,
},
#[error("could not render template {template:?}")]
Render {
template: &'static str,
#[source]
source: TeraError,
},
}
impl Reject for TemplateError {}
/// Count the number of tokens. Used to have a fixed-sized array for the
/// templates list.
macro_rules! count {
() => (0_usize);
( $x:tt $($xs:tt)* ) => (1_usize + count!($($xs)*));
}
/// Macro that helps generating helper function that renders a specific template
/// with a strongly-typed context. It also register the template in a static
/// array to help detecting missing templates at startup time.
///
/// The syntax looks almost like a function to confuse syntax highlighter as
/// little as possible.
macro_rules! register_templates {
{
$(
// Match any attribute on the function, such as #[doc], #[allow(dead_code)], etc.
$( #[ $attr:meta ] )*
// The function name
pub fn $name:ident
// Optional list of generics. Taken from
// https://newbedev.com/rust-macro-accepting-type-with-generic-parameters
$(< $( $lt:tt $( : $clt:tt $(+ $dlt:tt )* )? ),+ >)?
// Type of context taken by the template
( $param:ty )
{
// The name of the template file
$template:expr
}
)*
} => {
/// List of registered templates
static TEMPLATES: [&'static str; count!( $( $template )* )] = [ $( $template ),* ];
impl Templates {
$(
$(#[$attr])?
pub fn $name
$(< $( $lt $( : $clt $(+ $dlt )* )? ),+ >)?
(&self, context: &$param)
-> Result<String, TemplateError> {
let ctx = Context::from_serialize(context)
.map_err(|source| TemplateError::Context { template: $template, source })?;
self.0.render($template, &ctx)
.map_err(|source| TemplateError::Render { template: $template, source })
}
)*
}
};
}
register_templates! {
/// Render the login page
pub fn render_login(WithCsrf<LoginContext>) { "login.html" }
/// Render the registration page
pub fn render_register(WithCsrf<()>) { "register.html" }
/// Render the home page
pub fn render_index(WithCsrf<WithOptionalSession<IndexContext>>) { "index.html" }
/// Render the re-authentication form
pub fn render_reauth(WithCsrf<WithSession<()>>) { "reauth.html" }
/// Render the form used by the form_post response mode
pub fn render_form_post<T: Serialize>(FormPostContext<T>) { "form_post.html" }
/// Render the HTML error page
pub fn render_error(ErrorContext) { "error.html" }
}
/// Helper trait to construct context wrappers
pub trait TemplateContext: Sized {
fn with_session(self, current_session: SessionInfo) -> WithSession<Self> {
WithSession {
current_session,
inner: self,
}
}
fn maybe_with_session(self, current_session: Option<SessionInfo>) -> WithOptionalSession<Self> {
WithOptionalSession {
current_session,
inner: self,
}
}
fn with_csrf(self, token: &CsrfToken) -> WithCsrf<Self> {
WithCsrf {
csrf_token: token.form_value(),
inner: self,
}
}
}
impl<T: Sized> TemplateContext for T {}
/// Context with a CSRF token in it
#[derive(Serialize)]
pub struct WithCsrf<T> {
csrf_token: String,
#[serde(flatten)]
inner: T,
}
/// Context with a user session in it
#[derive(Serialize)]
pub struct WithSession<T> {
current_session: SessionInfo,
#[serde(flatten)]
inner: T,
}
/// Context with an optional user session in it
#[derive(Serialize)]
pub struct WithOptionalSession<T> {
current_session: Option<SessionInfo>,
#[serde(flatten)]
inner: T,
}
// Context used by the `index.html` template
#[derive(Serialize)]
pub struct IndexContext {
discovery_url: Url,
}
impl IndexContext {
#[must_use]
pub fn new(discovery_url: Url) -> Self {
Self { discovery_url }
}
}
#[derive(Serialize, Debug, Clone, Copy, Hash, PartialEq, Eq)]
#[serde(rename_all = "kebab-case")]
pub enum LoginFormField {
Username,
Password,
}
#[derive(Serialize)]
pub struct LoginContext {
form: ErroredForm<LoginFormField>,
}
impl LoginContext {
#[must_use]
pub fn with_form_error(form: ErroredForm<LoginFormField>) -> Self {
Self { form }
}
}
impl Default for LoginContext {
fn default() -> Self {
Self {
form: ErroredForm::new(),
}
}
}
/// Context used by the `form_post.html` template
#[derive(Serialize)]
pub struct FormPostContext<T> {
redirect_uri: Url,
params: T,
}
impl<T> FormPostContext<T> {
pub fn new(redirect_uri: Url, params: T) -> Self {
Self {
redirect_uri,
params,
}
}
}
#[derive(Default, Serialize)]
pub struct ErrorContext {
code: Option<&'static str>,
description: Option<String>,
details: Option<String>,
}
impl ErrorContext {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn with_code(mut self, code: &'static str) -> Self {
self.code = Some(code);
self
}
#[must_use]
pub fn with_description(mut self, description: String) -> Self {
self.description = Some(description);
self
}
#[allow(dead_code)]
#[must_use]
pub fn with_details(mut self, details: String) -> Self {
self.details = Some(details);
self
}
}
impl From<Box<dyn OAuth2Error>> for ErrorContext {
fn from(err: Box<dyn OAuth2Error>) -> Self {
let mut ctx = ErrorContext::new().with_code(err.error());
if let Some(desc) = err.description() {
ctx = ctx.with_description(desc);
}
ctx
}
}

View File

@ -0,0 +1,177 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use oauth2_types::errors::OAuth2Error;
use serde::Serialize;
use url::Url;
use crate::{errors::ErroredForm, filters::CsrfToken, storage::SessionInfo};
/// Helper trait to construct context wrappers
pub trait TemplateContext: Sized {
fn with_session(self, current_session: SessionInfo) -> WithSession<Self> {
WithSession {
current_session,
inner: self,
}
}
fn maybe_with_session(self, current_session: Option<SessionInfo>) -> WithOptionalSession<Self> {
WithOptionalSession {
current_session,
inner: self,
}
}
fn with_csrf(self, token: &CsrfToken) -> WithCsrf<Self> {
WithCsrf {
csrf_token: token.form_value(),
inner: self,
}
}
}
impl TemplateContext for () {}
impl TemplateContext for IndexContext {}
impl TemplateContext for LoginContext {}
impl<T: Sized> TemplateContext for FormPostContext<T> {}
impl<T: Sized> TemplateContext for WithSession<T> {}
impl<T: Sized> TemplateContext for WithOptionalSession<T> {}
impl<T: Sized> TemplateContext for WithCsrf<T> {}
/// Context with a CSRF token in it
#[derive(Serialize)]
pub struct WithCsrf<T> {
csrf_token: String,
#[serde(flatten)]
inner: T,
}
/// Context with a user session in it
#[derive(Serialize)]
pub struct WithSession<T> {
current_session: SessionInfo,
#[serde(flatten)]
inner: T,
}
/// Context with an optional user session in it
#[derive(Serialize)]
pub struct WithOptionalSession<T> {
current_session: Option<SessionInfo>,
#[serde(flatten)]
inner: T,
}
// Context used by the `index.html` template
#[derive(Serialize)]
pub struct IndexContext {
discovery_url: Url,
}
impl IndexContext {
#[must_use]
pub fn new(discovery_url: Url) -> Self {
Self { discovery_url }
}
}
#[derive(Serialize, Debug, Clone, Copy, Hash, PartialEq, Eq)]
#[serde(rename_all = "kebab-case")]
pub enum LoginFormField {
Username,
Password,
}
#[derive(Serialize)]
pub struct LoginContext {
form: ErroredForm<LoginFormField>,
}
impl LoginContext {
#[must_use]
pub fn with_form_error(form: ErroredForm<LoginFormField>) -> Self {
Self { form }
}
}
impl Default for LoginContext {
fn default() -> Self {
Self {
form: ErroredForm::new(),
}
}
}
/// Context used by the `form_post.html` template
#[derive(Serialize)]
pub struct FormPostContext<T> {
redirect_uri: Url,
params: T,
}
impl<T> FormPostContext<T> {
pub fn new(redirect_uri: Url, params: T) -> Self {
Self {
redirect_uri,
params,
}
}
}
#[derive(Default, Serialize)]
pub struct ErrorContext {
code: Option<&'static str>,
description: Option<String>,
details: Option<String>,
}
impl ErrorContext {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn with_code(mut self, code: &'static str) -> Self {
self.code = Some(code);
self
}
#[must_use]
pub fn with_description(mut self, description: String) -> Self {
self.description = Some(description);
self
}
#[allow(dead_code)]
#[must_use]
pub fn with_details(mut self, details: String) -> Self {
self.details = Some(details);
self
}
}
impl From<Box<dyn OAuth2Error>> for ErrorContext {
fn from(err: Box<dyn OAuth2Error>) -> Self {
let mut ctx = ErrorContext::new().with_code(err.error());
if let Some(desc) = err.description() {
ctx = ctx.with_description(desc);
}
ctx
}
}

View File

@ -0,0 +1,77 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/// Count the number of tokens. Used to have a fixed-sized array for the
/// templates list.
macro_rules! count {
() => (0_usize);
( $x:tt $($xs:tt)* ) => (1_usize + count!($($xs)*));
}
/// Macro that helps generating helper function that renders a specific template
/// with a strongly-typed context. It also register the template in a static
/// array to help detecting missing templates at startup time.
///
/// The syntax looks almost like a function to confuse syntax highlighter as
/// little as possible.
#[macro_export]
macro_rules! register_templates {
{
$(
extra = { $( $extra_template:expr ),* };
)?
$(
// Match any attribute on the function, such as #[doc], #[allow(dead_code)], etc.
$( #[ $attr:meta ] )*
// The function name
pub fn $name:ident
// Optional list of generics. Taken from
// https://newbedev.com/rust-macro-accepting-type-with-generic-parameters
$(< $( $lt:tt $( : $clt:tt $(+ $dlt:tt )* )? ),+ >)?
// Type of context taken by the template
( $param:ty )
{
// The name of the template file
$template:expr
}
)*
} => {
/// List of registered templates
static TEMPLATES: [(&'static str, &'static str); count!( $( $template )* )] = [
$( ($template, include_str!(concat!("res/", $template))) ),*
];
/// List of extra templates used by other templates
static EXTRA_TEMPLATES: [(&'static str, &'static str); count!( $( $( $extra_template )* )? )] = [
$( $( ($extra_template, include_str!(concat!("res/", $extra_template))) ),* )?
];
impl Templates {
$(
$(#[$attr])?
pub fn $name
$(< $( $lt $( : $clt $(+ $dlt )* )? ),+ >)?
(&self, context: &$param)
-> Result<String, TemplateError> {
let ctx = Context::from_serialize(context)
.map_err(|source| TemplateError::Context { template: $template, source })?;
self.0.render($template, &ctx)
.map_err(|source| TemplateError::Render { template: $template, source })
}
)*
}
};
}

View File

@ -0,0 +1,179 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{collections::HashSet, io::Cursor, path::Path, string::ToString, sync::Arc};
use anyhow::Context as _;
use serde::Serialize;
use tera::{Context, Error as TeraError, Tera};
use thiserror::Error;
use tokio::{fs::OpenOptions, io::AsyncWriteExt};
use tracing::{debug, info, warn};
use warp::reject::Reject;
mod context;
#[macro_use]
mod macros;
pub use self::context::{
ErrorContext, FormPostContext, IndexContext, LoginContext, LoginFormField, TemplateContext,
WithCsrf, WithOptionalSession, WithSession,
};
#[derive(Clone)]
pub struct Templates(Arc<Tera>);
#[derive(Error, Debug)]
pub enum TemplateLoadingError {
#[error("could not load and compile some templates")]
Compile(#[from] TeraError),
#[error("missing templates {missing:?}")]
MissingTemplates {
missing: HashSet<String>,
loaded: HashSet<String>,
},
}
impl Templates {
/// Load the templates and check all needed templates are properly loaded
///
/// # Arguments
///
/// * `path` - An optional path to where templates should be loaded
/// * `builtin` - Set to `true` to load the builtin templates as well
pub fn load(path: Option<String>, builtin: bool) -> Result<Self, TemplateLoadingError> {
let tera = {
let mut tera = Tera::default();
if builtin {
info!("Loading builtin templates");
for (name, source) in EXTRA_TEMPLATES {
tera.add_raw_template(name, source)?;
}
for (name, source) in TEMPLATES {
tera.add_raw_template(name, source)?;
}
}
if let Some(path) = path {
let path = format!("{}/**/*.{{html,txt}}", path);
info!(%path, "Loading templates from filesystem");
tera.extend(&Tera::parse(&path)?)?;
}
tera.build_inheritance_chains()?;
tera.check_macro_files()?;
tera
};
let loaded: HashSet<_> = tera.get_template_names().collect();
let needed: HashSet<_> = std::array::IntoIter::new(TEMPLATES)
.map(|(name, _)| name)
.collect();
debug!(?loaded, ?needed, "Templates loaded");
let missing: HashSet<_> = needed.difference(&loaded).collect();
if missing.is_empty() {
Ok(Self(Arc::new(tera)))
} else {
let missing = missing.into_iter().map(ToString::to_string).collect();
let loaded = loaded.into_iter().map(ToString::to_string).collect();
Err(TemplateLoadingError::MissingTemplates { missing, loaded })
}
}
/// Save the builtin templates to a folder
pub async fn save(path: &Path, overwrite: bool) -> anyhow::Result<()> {
tokio::fs::create_dir_all(&path)
.await
.context("could not create destination folder")?;
let templates = std::array::IntoIter::new(TEMPLATES).chain(EXTRA_TEMPLATES);
let mut options = OpenOptions::new();
if overwrite {
options.create(true).truncate(true).write(true);
} else {
// With the `create_new` flag, `open` fails with an `AlreadyExists` error to
// avoid overwriting
options.create_new(true).write(true);
};
for (name, source) in templates {
let path = path.join(name);
let mut file = match options.open(&path).await {
Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => {
// Not overwriting a template is a soft error
warn!(?path, "Not overwriting template");
continue;
}
x => x.context(format!("could not open file {:?}", path))?,
};
let mut buffer = Cursor::new(source);
file.write_all_buf(&mut buffer)
.await
.context(format!("could not write file {:?}", path))?;
info!(?path, "Wrote template");
}
Ok(())
}
}
#[derive(Error, Debug)]
pub enum TemplateError {
#[error("could not prepare context for template {template:?}")]
Context {
template: &'static str,
#[source]
source: TeraError,
},
#[error("could not render template {template:?}")]
Render {
template: &'static str,
#[source]
source: TeraError,
},
}
impl Reject for TemplateError {}
register_templates! {
extra = { "base.html" };
/// Render the login page
pub fn render_login(WithCsrf<LoginContext>) { "login.html" }
/// Render the registration page
pub fn render_register(WithCsrf<()>) { "register.html" }
/// Render the home page
pub fn render_index(WithCsrf<WithOptionalSession<IndexContext>>) { "index.html" }
/// Render the re-authentication form
pub fn render_reauth(WithCsrf<WithSession<()>>) { "reauth.html" }
/// Render the form used by the form_post response mode
pub fn render_form_post<T: Serialize>(FormPostContext<T>) { "form_post.html" }
/// Render the HTML error page
pub fn render_error(ErrorContext) { "error.html" }
}