You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-08-07 17:03:01 +03:00
axum-utils: Accept-Language header encoder and decoder
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -2697,6 +2697,7 @@ dependencies = [
|
|||||||
"headers",
|
"headers",
|
||||||
"http",
|
"http",
|
||||||
"http-body",
|
"http-body",
|
||||||
|
"icu_locid",
|
||||||
"mas-data-model",
|
"mas-data-model",
|
||||||
"mas-http",
|
"mas-http",
|
||||||
"mas-iana",
|
"mas-iana",
|
||||||
|
@@ -17,6 +17,7 @@ futures-util = "0.3.28"
|
|||||||
headers = "0.3.9"
|
headers = "0.3.9"
|
||||||
http.workspace = true
|
http.workspace = true
|
||||||
http-body = "0.4.5"
|
http-body = "0.4.5"
|
||||||
|
icu_locid = "1.3.0"
|
||||||
mime = "0.3.17"
|
mime = "0.3.17"
|
||||||
rand.workspace = true
|
rand.workspace = true
|
||||||
sentry = { version = "0.31.7", default-features = false }
|
sentry = { version = "0.31.7", default-features = false }
|
||||||
|
283
crates/axum-utils/src/language_detection.rs
Normal file
283
crates/axum-utils/src/language_detection.rs
Normal file
@@ -0,0 +1,283 @@
|
|||||||
|
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
use std::cmp::Reverse;
|
||||||
|
|
||||||
|
use headers::{Error, Header};
|
||||||
|
use http::{header::ACCEPT_LANGUAGE, HeaderName, HeaderValue};
|
||||||
|
use icu_locid::Locale;
|
||||||
|
|
||||||
|
#[derive(PartialEq, Eq, Debug)]
|
||||||
|
struct AcceptLanguagePart {
|
||||||
|
// None means *
|
||||||
|
locale: Option<Locale>,
|
||||||
|
|
||||||
|
// Quality is between 0 and 1 with 3 decimal places
|
||||||
|
// Which we map from 0 to 1000, e.g. 0.5 becomes 500
|
||||||
|
quality: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PartialOrd for AcceptLanguagePart {
|
||||||
|
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||||
|
// When comparing two AcceptLanguage structs, we only consider the
|
||||||
|
// quality, in reverse.
|
||||||
|
Reverse(self.quality).partial_cmp(&Reverse(other.quality))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Ord for AcceptLanguagePart {
|
||||||
|
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||||
|
// When comparing two AcceptLanguage structs, we only consider the
|
||||||
|
// quality, in reverse.
|
||||||
|
Reverse(self.quality).cmp(&Reverse(other.quality))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A header that represents the `Accept-Language` header.
|
||||||
|
#[derive(PartialEq, Eq, Debug)]
|
||||||
|
pub struct AcceptLanguage {
|
||||||
|
parts: Vec<AcceptLanguagePart>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Utility to trim ASCII whitespace from the start and end of a byte slice
|
||||||
|
const fn trim_bytes(mut bytes: &[u8]) -> &[u8] {
|
||||||
|
// Trim leading and trailing whitespace
|
||||||
|
while let [first, rest @ ..] = bytes {
|
||||||
|
if first.is_ascii_whitespace() {
|
||||||
|
bytes = rest;
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
while let [rest @ .., last] = bytes {
|
||||||
|
if last.is_ascii_whitespace() {
|
||||||
|
bytes = rest;
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bytes
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Header for AcceptLanguage {
|
||||||
|
fn name() -> &'static HeaderName {
|
||||||
|
&ACCEPT_LANGUAGE
|
||||||
|
}
|
||||||
|
|
||||||
|
fn decode<'i, I>(values: &mut I) -> Result<Self, Error>
|
||||||
|
where
|
||||||
|
Self: Sized,
|
||||||
|
I: Iterator<Item = &'i HeaderValue>,
|
||||||
|
{
|
||||||
|
let mut parts = Vec::new();
|
||||||
|
for value in values {
|
||||||
|
for part in value.as_bytes().split(|b| *b == b',') {
|
||||||
|
let mut it = part.split(|b| *b == b';');
|
||||||
|
let locale = it.next().ok_or(Error::invalid())?;
|
||||||
|
let locale = trim_bytes(locale);
|
||||||
|
|
||||||
|
let locale = match locale {
|
||||||
|
b"*" => None,
|
||||||
|
locale => {
|
||||||
|
let locale =
|
||||||
|
Locale::try_from_bytes(locale).map_err(|_e| Error::invalid())?;
|
||||||
|
Some(locale)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let quality = if let Some(quality) = it.next() {
|
||||||
|
let quality = trim_bytes(quality);
|
||||||
|
let quality = quality.strip_prefix(b"q=").ok_or(Error::invalid())?;
|
||||||
|
let quality = std::str::from_utf8(quality).map_err(|_e| Error::invalid())?;
|
||||||
|
let quality = quality.parse::<f64>().map_err(|_e| Error::invalid())?;
|
||||||
|
// Bound the quality between 0 and 1
|
||||||
|
let quality = quality.min(1_f64).max(0_f64);
|
||||||
|
|
||||||
|
// Make sure the iterator is empty
|
||||||
|
if it.next().is_some() {
|
||||||
|
return Err(Error::invalid());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
|
||||||
|
{
|
||||||
|
f64::round(quality * 1000_f64) as u16
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
1000
|
||||||
|
};
|
||||||
|
|
||||||
|
parts.push(AcceptLanguagePart { locale, quality });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
parts.sort();
|
||||||
|
|
||||||
|
Ok(AcceptLanguage { parts })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
|
||||||
|
let mut value = String::new();
|
||||||
|
let mut first = true;
|
||||||
|
for part in &self.parts {
|
||||||
|
if first {
|
||||||
|
first = false;
|
||||||
|
} else {
|
||||||
|
value.push_str(", ");
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(locale) = &part.locale {
|
||||||
|
value.push_str(&locale.to_string());
|
||||||
|
} else {
|
||||||
|
value.push('*');
|
||||||
|
}
|
||||||
|
|
||||||
|
if part.quality != 1000 {
|
||||||
|
value.push_str(";q=");
|
||||||
|
value.push_str(&(f64::from(part.quality) / 1000_f64).to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// We know this is safe because we only use ASCII characters
|
||||||
|
values.extend(Some(HeaderValue::from_str(&value).unwrap()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use headers::HeaderMapExt;
|
||||||
|
use http::{header::ACCEPT_LANGUAGE, HeaderMap, HeaderValue};
|
||||||
|
use icu_locid::locale;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_decode() {
|
||||||
|
let headers = HeaderMap::from_iter([(
|
||||||
|
ACCEPT_LANGUAGE,
|
||||||
|
HeaderValue::from_str("fr-CH, fr;q=0.9, en;q=0.8, de;q=0.7, *;q=0.5").unwrap(),
|
||||||
|
)]);
|
||||||
|
|
||||||
|
let accept_language: Option<AcceptLanguage> = headers.typed_get();
|
||||||
|
assert!(accept_language.is_some());
|
||||||
|
let accept_language = accept_language.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
accept_language,
|
||||||
|
AcceptLanguage {
|
||||||
|
parts: vec![
|
||||||
|
AcceptLanguagePart {
|
||||||
|
locale: Some(locale!("fr-CH")),
|
||||||
|
quality: 1000,
|
||||||
|
},
|
||||||
|
AcceptLanguagePart {
|
||||||
|
locale: Some(locale!("fr")),
|
||||||
|
quality: 900,
|
||||||
|
},
|
||||||
|
AcceptLanguagePart {
|
||||||
|
locale: Some(locale!("en")),
|
||||||
|
quality: 800,
|
||||||
|
},
|
||||||
|
AcceptLanguagePart {
|
||||||
|
locale: Some(locale!("de")),
|
||||||
|
quality: 700,
|
||||||
|
},
|
||||||
|
AcceptLanguagePart {
|
||||||
|
locale: None,
|
||||||
|
quality: 500,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
/// Test that we can decode a header with multiple values unordered, and
|
||||||
|
/// that the output is ordered by quality
|
||||||
|
fn test_decode_order() {
|
||||||
|
let headers = HeaderMap::from_iter([(
|
||||||
|
ACCEPT_LANGUAGE,
|
||||||
|
HeaderValue::from_str("*;q=0.5, fr-CH, en;q=0.8, fr;q=0.9, de;q=0.9").unwrap(),
|
||||||
|
)]);
|
||||||
|
|
||||||
|
let accept_language: Option<AcceptLanguage> = headers.typed_get();
|
||||||
|
assert!(accept_language.is_some());
|
||||||
|
let accept_language = accept_language.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
accept_language,
|
||||||
|
AcceptLanguage {
|
||||||
|
parts: vec![
|
||||||
|
AcceptLanguagePart {
|
||||||
|
locale: Some(locale!("fr-CH")),
|
||||||
|
quality: 1000,
|
||||||
|
},
|
||||||
|
AcceptLanguagePart {
|
||||||
|
locale: Some(locale!("fr")),
|
||||||
|
quality: 900,
|
||||||
|
},
|
||||||
|
AcceptLanguagePart {
|
||||||
|
locale: Some(locale!("de")),
|
||||||
|
quality: 900,
|
||||||
|
},
|
||||||
|
AcceptLanguagePart {
|
||||||
|
locale: Some(locale!("en")),
|
||||||
|
quality: 800,
|
||||||
|
},
|
||||||
|
AcceptLanguagePart {
|
||||||
|
locale: None,
|
||||||
|
quality: 500,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_encode() {
|
||||||
|
let accept_language = AcceptLanguage {
|
||||||
|
parts: vec![
|
||||||
|
AcceptLanguagePart {
|
||||||
|
locale: Some(locale!("fr-CH")),
|
||||||
|
quality: 1000,
|
||||||
|
},
|
||||||
|
AcceptLanguagePart {
|
||||||
|
locale: Some(locale!("fr")),
|
||||||
|
quality: 900,
|
||||||
|
},
|
||||||
|
AcceptLanguagePart {
|
||||||
|
locale: Some(locale!("de")),
|
||||||
|
quality: 900,
|
||||||
|
},
|
||||||
|
AcceptLanguagePart {
|
||||||
|
locale: Some(locale!("en")),
|
||||||
|
quality: 800,
|
||||||
|
},
|
||||||
|
AcceptLanguagePart {
|
||||||
|
locale: None,
|
||||||
|
quality: 500,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.typed_insert(accept_language);
|
||||||
|
let header = headers.get(ACCEPT_LANGUAGE).unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
header.to_str().unwrap(),
|
||||||
|
"fr-CH, fr;q=0.9, de;q=0.9, en;q=0.8, *;q=0.5"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
@@ -29,6 +29,7 @@ pub mod error_wrapper;
|
|||||||
pub mod fancy_error;
|
pub mod fancy_error;
|
||||||
pub mod http_client_factory;
|
pub mod http_client_factory;
|
||||||
pub mod jwt;
|
pub mod jwt;
|
||||||
|
pub mod language_detection;
|
||||||
pub mod sentry;
|
pub mod sentry;
|
||||||
pub mod session;
|
pub mod session;
|
||||||
pub mod user_authorization;
|
pub mod user_authorization;
|
||||||
|
Reference in New Issue
Block a user