You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-08-07 17:03:01 +03:00
axum-utils: Accept-Language header encoder and decoder
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -2697,6 +2697,7 @@ dependencies = [
|
||||
"headers",
|
||||
"http",
|
||||
"http-body",
|
||||
"icu_locid",
|
||||
"mas-data-model",
|
||||
"mas-http",
|
||||
"mas-iana",
|
||||
|
@@ -17,6 +17,7 @@ futures-util = "0.3.28"
|
||||
headers = "0.3.9"
|
||||
http.workspace = true
|
||||
http-body = "0.4.5"
|
||||
icu_locid = "1.3.0"
|
||||
mime = "0.3.17"
|
||||
rand.workspace = true
|
||||
sentry = { version = "0.31.7", default-features = false }
|
||||
|
283
crates/axum-utils/src/language_detection.rs
Normal file
283
crates/axum-utils/src/language_detection.rs
Normal file
@@ -0,0 +1,283 @@
|
||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::cmp::Reverse;
|
||||
|
||||
use headers::{Error, Header};
|
||||
use http::{header::ACCEPT_LANGUAGE, HeaderName, HeaderValue};
|
||||
use icu_locid::Locale;
|
||||
|
||||
#[derive(PartialEq, Eq, Debug)]
|
||||
struct AcceptLanguagePart {
|
||||
// None means *
|
||||
locale: Option<Locale>,
|
||||
|
||||
// Quality is between 0 and 1 with 3 decimal places
|
||||
// Which we map from 0 to 1000, e.g. 0.5 becomes 500
|
||||
quality: u16,
|
||||
}
|
||||
|
||||
impl PartialOrd for AcceptLanguagePart {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||
// When comparing two AcceptLanguage structs, we only consider the
|
||||
// quality, in reverse.
|
||||
Reverse(self.quality).partial_cmp(&Reverse(other.quality))
|
||||
}
|
||||
}
|
||||
|
||||
impl Ord for AcceptLanguagePart {
|
||||
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||
// When comparing two AcceptLanguage structs, we only consider the
|
||||
// quality, in reverse.
|
||||
Reverse(self.quality).cmp(&Reverse(other.quality))
|
||||
}
|
||||
}
|
||||
|
||||
/// A header that represents the `Accept-Language` header.
|
||||
#[derive(PartialEq, Eq, Debug)]
|
||||
pub struct AcceptLanguage {
|
||||
parts: Vec<AcceptLanguagePart>,
|
||||
}
|
||||
|
||||
/// Utility to trim ASCII whitespace from the start and end of a byte slice
|
||||
const fn trim_bytes(mut bytes: &[u8]) -> &[u8] {
|
||||
// Trim leading and trailing whitespace
|
||||
while let [first, rest @ ..] = bytes {
|
||||
if first.is_ascii_whitespace() {
|
||||
bytes = rest;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
while let [rest @ .., last] = bytes {
|
||||
if last.is_ascii_whitespace() {
|
||||
bytes = rest;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
bytes
|
||||
}
|
||||
|
||||
impl Header for AcceptLanguage {
|
||||
fn name() -> &'static HeaderName {
|
||||
&ACCEPT_LANGUAGE
|
||||
}
|
||||
|
||||
fn decode<'i, I>(values: &mut I) -> Result<Self, Error>
|
||||
where
|
||||
Self: Sized,
|
||||
I: Iterator<Item = &'i HeaderValue>,
|
||||
{
|
||||
let mut parts = Vec::new();
|
||||
for value in values {
|
||||
for part in value.as_bytes().split(|b| *b == b',') {
|
||||
let mut it = part.split(|b| *b == b';');
|
||||
let locale = it.next().ok_or(Error::invalid())?;
|
||||
let locale = trim_bytes(locale);
|
||||
|
||||
let locale = match locale {
|
||||
b"*" => None,
|
||||
locale => {
|
||||
let locale =
|
||||
Locale::try_from_bytes(locale).map_err(|_e| Error::invalid())?;
|
||||
Some(locale)
|
||||
}
|
||||
};
|
||||
|
||||
let quality = if let Some(quality) = it.next() {
|
||||
let quality = trim_bytes(quality);
|
||||
let quality = quality.strip_prefix(b"q=").ok_or(Error::invalid())?;
|
||||
let quality = std::str::from_utf8(quality).map_err(|_e| Error::invalid())?;
|
||||
let quality = quality.parse::<f64>().map_err(|_e| Error::invalid())?;
|
||||
// Bound the quality between 0 and 1
|
||||
let quality = quality.min(1_f64).max(0_f64);
|
||||
|
||||
// Make sure the iterator is empty
|
||||
if it.next().is_some() {
|
||||
return Err(Error::invalid());
|
||||
}
|
||||
|
||||
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
|
||||
{
|
||||
f64::round(quality * 1000_f64) as u16
|
||||
}
|
||||
} else {
|
||||
1000
|
||||
};
|
||||
|
||||
parts.push(AcceptLanguagePart { locale, quality });
|
||||
}
|
||||
}
|
||||
|
||||
parts.sort();
|
||||
|
||||
Ok(AcceptLanguage { parts })
|
||||
}
|
||||
|
||||
fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
|
||||
let mut value = String::new();
|
||||
let mut first = true;
|
||||
for part in &self.parts {
|
||||
if first {
|
||||
first = false;
|
||||
} else {
|
||||
value.push_str(", ");
|
||||
}
|
||||
|
||||
if let Some(locale) = &part.locale {
|
||||
value.push_str(&locale.to_string());
|
||||
} else {
|
||||
value.push('*');
|
||||
}
|
||||
|
||||
if part.quality != 1000 {
|
||||
value.push_str(";q=");
|
||||
value.push_str(&(f64::from(part.quality) / 1000_f64).to_string());
|
||||
}
|
||||
}
|
||||
|
||||
// We know this is safe because we only use ASCII characters
|
||||
values.extend(Some(HeaderValue::from_str(&value).unwrap()));
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use headers::HeaderMapExt;
|
||||
use http::{header::ACCEPT_LANGUAGE, HeaderMap, HeaderValue};
|
||||
use icu_locid::locale;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_decode() {
|
||||
let headers = HeaderMap::from_iter([(
|
||||
ACCEPT_LANGUAGE,
|
||||
HeaderValue::from_str("fr-CH, fr;q=0.9, en;q=0.8, de;q=0.7, *;q=0.5").unwrap(),
|
||||
)]);
|
||||
|
||||
let accept_language: Option<AcceptLanguage> = headers.typed_get();
|
||||
assert!(accept_language.is_some());
|
||||
let accept_language = accept_language.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
accept_language,
|
||||
AcceptLanguage {
|
||||
parts: vec![
|
||||
AcceptLanguagePart {
|
||||
locale: Some(locale!("fr-CH")),
|
||||
quality: 1000,
|
||||
},
|
||||
AcceptLanguagePart {
|
||||
locale: Some(locale!("fr")),
|
||||
quality: 900,
|
||||
},
|
||||
AcceptLanguagePart {
|
||||
locale: Some(locale!("en")),
|
||||
quality: 800,
|
||||
},
|
||||
AcceptLanguagePart {
|
||||
locale: Some(locale!("de")),
|
||||
quality: 700,
|
||||
},
|
||||
AcceptLanguagePart {
|
||||
locale: None,
|
||||
quality: 500,
|
||||
},
|
||||
]
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
/// Test that we can decode a header with multiple values unordered, and
|
||||
/// that the output is ordered by quality
|
||||
fn test_decode_order() {
|
||||
let headers = HeaderMap::from_iter([(
|
||||
ACCEPT_LANGUAGE,
|
||||
HeaderValue::from_str("*;q=0.5, fr-CH, en;q=0.8, fr;q=0.9, de;q=0.9").unwrap(),
|
||||
)]);
|
||||
|
||||
let accept_language: Option<AcceptLanguage> = headers.typed_get();
|
||||
assert!(accept_language.is_some());
|
||||
let accept_language = accept_language.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
accept_language,
|
||||
AcceptLanguage {
|
||||
parts: vec![
|
||||
AcceptLanguagePart {
|
||||
locale: Some(locale!("fr-CH")),
|
||||
quality: 1000,
|
||||
},
|
||||
AcceptLanguagePart {
|
||||
locale: Some(locale!("fr")),
|
||||
quality: 900,
|
||||
},
|
||||
AcceptLanguagePart {
|
||||
locale: Some(locale!("de")),
|
||||
quality: 900,
|
||||
},
|
||||
AcceptLanguagePart {
|
||||
locale: Some(locale!("en")),
|
||||
quality: 800,
|
||||
},
|
||||
AcceptLanguagePart {
|
||||
locale: None,
|
||||
quality: 500,
|
||||
},
|
||||
]
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode() {
|
||||
let accept_language = AcceptLanguage {
|
||||
parts: vec![
|
||||
AcceptLanguagePart {
|
||||
locale: Some(locale!("fr-CH")),
|
||||
quality: 1000,
|
||||
},
|
||||
AcceptLanguagePart {
|
||||
locale: Some(locale!("fr")),
|
||||
quality: 900,
|
||||
},
|
||||
AcceptLanguagePart {
|
||||
locale: Some(locale!("de")),
|
||||
quality: 900,
|
||||
},
|
||||
AcceptLanguagePart {
|
||||
locale: Some(locale!("en")),
|
||||
quality: 800,
|
||||
},
|
||||
AcceptLanguagePart {
|
||||
locale: None,
|
||||
quality: 500,
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.typed_insert(accept_language);
|
||||
let header = headers.get(ACCEPT_LANGUAGE).unwrap();
|
||||
assert_eq!(
|
||||
header.to_str().unwrap(),
|
||||
"fr-CH, fr;q=0.9, de;q=0.9, en;q=0.8, *;q=0.5"
|
||||
);
|
||||
}
|
||||
}
|
@@ -29,6 +29,7 @@ pub mod error_wrapper;
|
||||
pub mod fancy_error;
|
||||
pub mod http_client_factory;
|
||||
pub mod jwt;
|
||||
pub mod language_detection;
|
||||
pub mod sentry;
|
||||
pub mod session;
|
||||
pub mod user_authorization;
|
||||
|
Reference in New Issue
Block a user