diff --git a/crates/static-files/src/lib.rs b/crates/static-files/src/lib.rs index d62414d2..f5fd2c40 100644 --- a/crates/static-files/src/lib.rs +++ b/crates/static-files/src/lib.rs @@ -24,11 +24,11 @@ use std::{ }; use axum::{ - body::{boxed, Full}, response::{IntoResponse, Response}, + TypedHeader, }; -use headers::{ContentLength, ContentType, HeaderMapExt}; -use http::{Request, StatusCode}; +use headers::{ContentLength, ContentType, ETag, HeaderMapExt, IfNoneMatch}; +use http::{Method, Request, StatusCode}; use rust_embed::RustEmbed; use tower::Service; @@ -40,15 +40,49 @@ use tower::Service; pub struct Assets; impl Assets { - fn get_response(path: &str) -> Option { + fn get_response( + is_head: bool, + path: &str, + if_none_match: Option, + ) -> Option { let asset = Self::get(path)?; + let etag: String = asset + .metadata + .sha256_hash() + .iter() + .map(|x| format!("{:02x}", x)) + .collect(); + let etag: ETag = format!("\"{}\"", etag).parse().unwrap(); + + if let Some(if_none_match) = if_none_match { + if if_none_match.precondition_passes(&etag) { + return Some(StatusCode::NOT_MODIFIED.into_response()); + } + } + let len = asset.data.len().try_into().unwrap(); let mime = mime_guess::from_path(path).first_or_octet_stream(); - let mut res = Response::new(boxed(Full::from(asset.data))); - res.headers_mut().typed_insert(ContentType::from(mime)); - res.headers_mut().typed_insert(ContentLength(len)); + let res = if is_head { + ( + StatusCode::OK, + TypedHeader(ContentType::from(mime)), + TypedHeader(ContentLength(len)), + TypedHeader(etag), + ) + .into_response() + } else { + ( + StatusCode::OK, + TypedHeader(ContentType::from(mime)), + TypedHeader(ContentLength(len)), + TypedHeader(etag), + asset.data, + ) + .into_response() + }; + Some(res) } } @@ -67,11 +101,16 @@ impl Service> for Assets { fn call(&mut self, req: Request) -> Self::Future { let path = req.uri().path().trim_start_matches('/'); - // TODO: support HEAD requests - // TODO: support ETag + let if_none_match = req.headers().typed_get(); + let is_head = match *req.method() { + Method::GET => false, + Method::HEAD => true, + _ => return ready(Ok(StatusCode::METHOD_NOT_ALLOWED.into_response())), + }; + // TODO: support range requests - let response = - Self::get_response(path).unwrap_or_else(|| StatusCode::NOT_FOUND.into_response()); + let response = Self::get_response(is_head, path, if_none_match) + .unwrap_or_else(|| StatusCode::NOT_FOUND.into_response()); ready(Ok(response)) } }