1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-07 17:03:01 +03:00

axum-utils: Accept-Language header encoder and decoder

This commit is contained in:
Quentin Gliech
2023-10-03 18:15:46 +02:00
parent e21d193942
commit 730ad4674b
4 changed files with 286 additions and 0 deletions

1
Cargo.lock generated
View File

@@ -2697,6 +2697,7 @@ dependencies = [
"headers",
"http",
"http-body",
"icu_locid",
"mas-data-model",
"mas-http",
"mas-iana",

View File

@@ -17,6 +17,7 @@ futures-util = "0.3.28"
headers = "0.3.9"
http.workspace = true
http-body = "0.4.5"
icu_locid = "1.3.0"
mime = "0.3.17"
rand.workspace = true
sentry = { version = "0.31.7", default-features = false }

View File

@@ -0,0 +1,283 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::cmp::Reverse;
use headers::{Error, Header};
use http::{header::ACCEPT_LANGUAGE, HeaderName, HeaderValue};
use icu_locid::Locale;
#[derive(PartialEq, Eq, Debug)]
struct AcceptLanguagePart {
// None means *
locale: Option<Locale>,
// Quality is between 0 and 1 with 3 decimal places
// Which we map from 0 to 1000, e.g. 0.5 becomes 500
quality: u16,
}
impl PartialOrd for AcceptLanguagePart {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
// When comparing two AcceptLanguage structs, we only consider the
// quality, in reverse.
Reverse(self.quality).partial_cmp(&Reverse(other.quality))
}
}
impl Ord for AcceptLanguagePart {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
// When comparing two AcceptLanguage structs, we only consider the
// quality, in reverse.
Reverse(self.quality).cmp(&Reverse(other.quality))
}
}
/// A header that represents the `Accept-Language` header.
#[derive(PartialEq, Eq, Debug)]
pub struct AcceptLanguage {
parts: Vec<AcceptLanguagePart>,
}
/// Utility to trim ASCII whitespace from the start and end of a byte slice
const fn trim_bytes(mut bytes: &[u8]) -> &[u8] {
// Trim leading and trailing whitespace
while let [first, rest @ ..] = bytes {
if first.is_ascii_whitespace() {
bytes = rest;
} else {
break;
}
}
while let [rest @ .., last] = bytes {
if last.is_ascii_whitespace() {
bytes = rest;
} else {
break;
}
}
bytes
}
impl Header for AcceptLanguage {
fn name() -> &'static HeaderName {
&ACCEPT_LANGUAGE
}
fn decode<'i, I>(values: &mut I) -> Result<Self, Error>
where
Self: Sized,
I: Iterator<Item = &'i HeaderValue>,
{
let mut parts = Vec::new();
for value in values {
for part in value.as_bytes().split(|b| *b == b',') {
let mut it = part.split(|b| *b == b';');
let locale = it.next().ok_or(Error::invalid())?;
let locale = trim_bytes(locale);
let locale = match locale {
b"*" => None,
locale => {
let locale =
Locale::try_from_bytes(locale).map_err(|_e| Error::invalid())?;
Some(locale)
}
};
let quality = if let Some(quality) = it.next() {
let quality = trim_bytes(quality);
let quality = quality.strip_prefix(b"q=").ok_or(Error::invalid())?;
let quality = std::str::from_utf8(quality).map_err(|_e| Error::invalid())?;
let quality = quality.parse::<f64>().map_err(|_e| Error::invalid())?;
// Bound the quality between 0 and 1
let quality = quality.min(1_f64).max(0_f64);
// Make sure the iterator is empty
if it.next().is_some() {
return Err(Error::invalid());
}
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
{
f64::round(quality * 1000_f64) as u16
}
} else {
1000
};
parts.push(AcceptLanguagePart { locale, quality });
}
}
parts.sort();
Ok(AcceptLanguage { parts })
}
fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
let mut value = String::new();
let mut first = true;
for part in &self.parts {
if first {
first = false;
} else {
value.push_str(", ");
}
if let Some(locale) = &part.locale {
value.push_str(&locale.to_string());
} else {
value.push('*');
}
if part.quality != 1000 {
value.push_str(";q=");
value.push_str(&(f64::from(part.quality) / 1000_f64).to_string());
}
}
// We know this is safe because we only use ASCII characters
values.extend(Some(HeaderValue::from_str(&value).unwrap()));
}
}
#[cfg(test)]
mod tests {
use headers::HeaderMapExt;
use http::{header::ACCEPT_LANGUAGE, HeaderMap, HeaderValue};
use icu_locid::locale;
use super::*;
#[test]
fn test_decode() {
let headers = HeaderMap::from_iter([(
ACCEPT_LANGUAGE,
HeaderValue::from_str("fr-CH, fr;q=0.9, en;q=0.8, de;q=0.7, *;q=0.5").unwrap(),
)]);
let accept_language: Option<AcceptLanguage> = headers.typed_get();
assert!(accept_language.is_some());
let accept_language = accept_language.unwrap();
assert_eq!(
accept_language,
AcceptLanguage {
parts: vec![
AcceptLanguagePart {
locale: Some(locale!("fr-CH")),
quality: 1000,
},
AcceptLanguagePart {
locale: Some(locale!("fr")),
quality: 900,
},
AcceptLanguagePart {
locale: Some(locale!("en")),
quality: 800,
},
AcceptLanguagePart {
locale: Some(locale!("de")),
quality: 700,
},
AcceptLanguagePart {
locale: None,
quality: 500,
},
]
}
);
}
#[test]
/// Test that we can decode a header with multiple values unordered, and
/// that the output is ordered by quality
fn test_decode_order() {
let headers = HeaderMap::from_iter([(
ACCEPT_LANGUAGE,
HeaderValue::from_str("*;q=0.5, fr-CH, en;q=0.8, fr;q=0.9, de;q=0.9").unwrap(),
)]);
let accept_language: Option<AcceptLanguage> = headers.typed_get();
assert!(accept_language.is_some());
let accept_language = accept_language.unwrap();
assert_eq!(
accept_language,
AcceptLanguage {
parts: vec![
AcceptLanguagePart {
locale: Some(locale!("fr-CH")),
quality: 1000,
},
AcceptLanguagePart {
locale: Some(locale!("fr")),
quality: 900,
},
AcceptLanguagePart {
locale: Some(locale!("de")),
quality: 900,
},
AcceptLanguagePart {
locale: Some(locale!("en")),
quality: 800,
},
AcceptLanguagePart {
locale: None,
quality: 500,
},
]
}
);
}
#[test]
fn test_encode() {
let accept_language = AcceptLanguage {
parts: vec![
AcceptLanguagePart {
locale: Some(locale!("fr-CH")),
quality: 1000,
},
AcceptLanguagePart {
locale: Some(locale!("fr")),
quality: 900,
},
AcceptLanguagePart {
locale: Some(locale!("de")),
quality: 900,
},
AcceptLanguagePart {
locale: Some(locale!("en")),
quality: 800,
},
AcceptLanguagePart {
locale: None,
quality: 500,
},
],
};
let mut headers = HeaderMap::new();
headers.typed_insert(accept_language);
let header = headers.get(ACCEPT_LANGUAGE).unwrap();
assert_eq!(
header.to_str().unwrap(),
"fr-CH, fr;q=0.9, de;q=0.9, en;q=0.8, *;q=0.5"
);
}
}

View File

@@ -29,6 +29,7 @@ pub mod error_wrapper;
pub mod fancy_error;
pub mod http_client_factory;
pub mod jwt;
pub mod language_detection;
pub mod sentry;
pub mod session;
pub mod user_authorization;