diff --git a/Cargo.lock b/Cargo.lock index 31efe91f..f7e44ccc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2697,6 +2697,7 @@ dependencies = [ "headers", "http", "http-body", + "icu_locid", "mas-data-model", "mas-http", "mas-iana", diff --git a/crates/axum-utils/Cargo.toml b/crates/axum-utils/Cargo.toml index 18754121..6a3d6dee 100644 --- a/crates/axum-utils/Cargo.toml +++ b/crates/axum-utils/Cargo.toml @@ -17,6 +17,7 @@ futures-util = "0.3.28" headers = "0.3.9" http.workspace = true http-body = "0.4.5" +icu_locid = "1.3.0" mime = "0.3.17" rand.workspace = true sentry = { version = "0.31.7", default-features = false } diff --git a/crates/axum-utils/src/language_detection.rs b/crates/axum-utils/src/language_detection.rs new file mode 100644 index 00000000..8b907f60 --- /dev/null +++ b/crates/axum-utils/src/language_detection.rs @@ -0,0 +1,283 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::cmp::Reverse; + +use headers::{Error, Header}; +use http::{header::ACCEPT_LANGUAGE, HeaderName, HeaderValue}; +use icu_locid::Locale; + +#[derive(PartialEq, Eq, Debug)] +struct AcceptLanguagePart { + // None means * + locale: Option, + + // Quality is between 0 and 1 with 3 decimal places + // Which we map from 0 to 1000, e.g. 0.5 becomes 500 + quality: u16, +} + +impl PartialOrd for AcceptLanguagePart { + fn partial_cmp(&self, other: &Self) -> Option { + // When comparing two AcceptLanguage structs, we only consider the + // quality, in reverse. + Reverse(self.quality).partial_cmp(&Reverse(other.quality)) + } +} + +impl Ord for AcceptLanguagePart { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + // When comparing two AcceptLanguage structs, we only consider the + // quality, in reverse. + Reverse(self.quality).cmp(&Reverse(other.quality)) + } +} + +/// A header that represents the `Accept-Language` header. +#[derive(PartialEq, Eq, Debug)] +pub struct AcceptLanguage { + parts: Vec, +} + +/// Utility to trim ASCII whitespace from the start and end of a byte slice +const fn trim_bytes(mut bytes: &[u8]) -> &[u8] { + // Trim leading and trailing whitespace + while let [first, rest @ ..] = bytes { + if first.is_ascii_whitespace() { + bytes = rest; + } else { + break; + } + } + + while let [rest @ .., last] = bytes { + if last.is_ascii_whitespace() { + bytes = rest; + } else { + break; + } + } + + bytes +} + +impl Header for AcceptLanguage { + fn name() -> &'static HeaderName { + &ACCEPT_LANGUAGE + } + + fn decode<'i, I>(values: &mut I) -> Result + where + Self: Sized, + I: Iterator, + { + let mut parts = Vec::new(); + for value in values { + for part in value.as_bytes().split(|b| *b == b',') { + let mut it = part.split(|b| *b == b';'); + let locale = it.next().ok_or(Error::invalid())?; + let locale = trim_bytes(locale); + + let locale = match locale { + b"*" => None, + locale => { + let locale = + Locale::try_from_bytes(locale).map_err(|_e| Error::invalid())?; + Some(locale) + } + }; + + let quality = if let Some(quality) = it.next() { + let quality = trim_bytes(quality); + let quality = quality.strip_prefix(b"q=").ok_or(Error::invalid())?; + let quality = std::str::from_utf8(quality).map_err(|_e| Error::invalid())?; + let quality = quality.parse::().map_err(|_e| Error::invalid())?; + // Bound the quality between 0 and 1 + let quality = quality.min(1_f64).max(0_f64); + + // Make sure the iterator is empty + if it.next().is_some() { + return Err(Error::invalid()); + } + + #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)] + { + f64::round(quality * 1000_f64) as u16 + } + } else { + 1000 + }; + + parts.push(AcceptLanguagePart { locale, quality }); + } + } + + parts.sort(); + + Ok(AcceptLanguage { parts }) + } + + fn encode>(&self, values: &mut E) { + let mut value = String::new(); + let mut first = true; + for part in &self.parts { + if first { + first = false; + } else { + value.push_str(", "); + } + + if let Some(locale) = &part.locale { + value.push_str(&locale.to_string()); + } else { + value.push('*'); + } + + if part.quality != 1000 { + value.push_str(";q="); + value.push_str(&(f64::from(part.quality) / 1000_f64).to_string()); + } + } + + // We know this is safe because we only use ASCII characters + values.extend(Some(HeaderValue::from_str(&value).unwrap())); + } +} + +#[cfg(test)] +mod tests { + use headers::HeaderMapExt; + use http::{header::ACCEPT_LANGUAGE, HeaderMap, HeaderValue}; + use icu_locid::locale; + + use super::*; + + #[test] + fn test_decode() { + let headers = HeaderMap::from_iter([( + ACCEPT_LANGUAGE, + HeaderValue::from_str("fr-CH, fr;q=0.9, en;q=0.8, de;q=0.7, *;q=0.5").unwrap(), + )]); + + let accept_language: Option = headers.typed_get(); + assert!(accept_language.is_some()); + let accept_language = accept_language.unwrap(); + + assert_eq!( + accept_language, + AcceptLanguage { + parts: vec![ + AcceptLanguagePart { + locale: Some(locale!("fr-CH")), + quality: 1000, + }, + AcceptLanguagePart { + locale: Some(locale!("fr")), + quality: 900, + }, + AcceptLanguagePart { + locale: Some(locale!("en")), + quality: 800, + }, + AcceptLanguagePart { + locale: Some(locale!("de")), + quality: 700, + }, + AcceptLanguagePart { + locale: None, + quality: 500, + }, + ] + } + ); + } + + #[test] + /// Test that we can decode a header with multiple values unordered, and + /// that the output is ordered by quality + fn test_decode_order() { + let headers = HeaderMap::from_iter([( + ACCEPT_LANGUAGE, + HeaderValue::from_str("*;q=0.5, fr-CH, en;q=0.8, fr;q=0.9, de;q=0.9").unwrap(), + )]); + + let accept_language: Option = headers.typed_get(); + assert!(accept_language.is_some()); + let accept_language = accept_language.unwrap(); + + assert_eq!( + accept_language, + AcceptLanguage { + parts: vec![ + AcceptLanguagePart { + locale: Some(locale!("fr-CH")), + quality: 1000, + }, + AcceptLanguagePart { + locale: Some(locale!("fr")), + quality: 900, + }, + AcceptLanguagePart { + locale: Some(locale!("de")), + quality: 900, + }, + AcceptLanguagePart { + locale: Some(locale!("en")), + quality: 800, + }, + AcceptLanguagePart { + locale: None, + quality: 500, + }, + ] + } + ); + } + + #[test] + fn test_encode() { + let accept_language = AcceptLanguage { + parts: vec![ + AcceptLanguagePart { + locale: Some(locale!("fr-CH")), + quality: 1000, + }, + AcceptLanguagePart { + locale: Some(locale!("fr")), + quality: 900, + }, + AcceptLanguagePart { + locale: Some(locale!("de")), + quality: 900, + }, + AcceptLanguagePart { + locale: Some(locale!("en")), + quality: 800, + }, + AcceptLanguagePart { + locale: None, + quality: 500, + }, + ], + }; + + let mut headers = HeaderMap::new(); + headers.typed_insert(accept_language); + let header = headers.get(ACCEPT_LANGUAGE).unwrap(); + assert_eq!( + header.to_str().unwrap(), + "fr-CH, fr;q=0.9, de;q=0.9, en;q=0.8, *;q=0.5" + ); + } +} diff --git a/crates/axum-utils/src/lib.rs b/crates/axum-utils/src/lib.rs index 9336eb08..82764f09 100644 --- a/crates/axum-utils/src/lib.rs +++ b/crates/axum-utils/src/lib.rs @@ -29,6 +29,7 @@ pub mod error_wrapper; pub mod fancy_error; pub mod http_client_factory; pub mod jwt; +pub mod language_detection; pub mod sentry; pub mod session; pub mod user_authorization;