1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-12-23 16:42:08 +03:00

initial commit

This commit is contained in:
Quentin Gliech
2021-06-10 13:40:54 +02:00
commit c9eb6ca1b4
9 changed files with 866 additions and 0 deletions

1
.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
target/

305
Cargo.lock generated Normal file
View File

@@ -0,0 +1,305 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
[[package]]
name = "aho-corasick"
version = "0.7.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e37cfd5e7657ada45f742d6e99ca5788580b5c529dc78faf11ece6dc702656f"
dependencies = [
"memchr",
]
[[package]]
name = "bytes"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b700ce4376041dcd0a327fd0097c41095743c4c8af8887265942faf1100bd040"
[[package]]
name = "fnv"
version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "form_urlencoded"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5fc25a87fa4fd2094bffb06925852034d90a17f0d1e05197d4956d3555752191"
dependencies = [
"matches",
"percent-encoding",
]
[[package]]
name = "http"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "527e8c9ac747e28542699a951517aa9a6945af506cd1f2e1b53a576c17b6cc11"
dependencies = [
"bytes",
"fnv",
"itoa",
]
[[package]]
name = "idna"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "418a0a6fab821475f634efe3ccc45c013f742efe03d853e8d3355d5cb850ecf8"
dependencies = [
"matches",
"unicode-bidi",
"unicode-normalization",
]
[[package]]
name = "indoc"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5a75aeaaef0ce18b58056d306c27b07436fbb34b8816c53094b76dd81803136"
dependencies = [
"unindent",
]
[[package]]
name = "itoa"
version = "0.4.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd25036021b0de88a0aff6b850051563c6516d0bf53f8638938edbb9de732736"
[[package]]
name = "language-tags"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d4345964bb142484797b161f473a503a434de77149dd8c7427788c6e13379388"
dependencies = [
"serde",
]
[[package]]
name = "matches"
version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ffc5c5338469d4d3ea17d269fa8ea3512ad247247c30bd2df69e68309ed0a08"
[[package]]
name = "memchr"
version = "2.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b16bd47d9e329435e309c58469fe0791c2d0d1ba96ec0954152a5ae2b04387dc"
[[package]]
name = "oauth2"
version = "0.1.0"
dependencies = [
"http",
"indoc",
"language-tags",
"parse-display",
"serde",
"serde_json",
"url",
]
[[package]]
name = "once_cell"
version = "1.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af8b08b04175473088b46763e51ee54da5f9a164bc162f615b91bc179dbf15a3"
[[package]]
name = "parse-display"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc7e98ea043e0880940ef455c6c6e5710b4f670b4f0aeff6edf320bb01143fe9"
dependencies = [
"once_cell",
"parse-display-derive",
"regex",
]
[[package]]
name = "parse-display-derive"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "962e8dc54ebea1392eb2f36a205f2efa9437bfe8e95d7a91f070044c363c9684"
dependencies = [
"once_cell",
"proc-macro2",
"quote",
"regex",
"regex-syntax",
"structmeta",
"syn",
]
[[package]]
name = "percent-encoding"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d4fd5641d01c8f18a23da7b6fe29298ff4b55afcccdf78973b24cf3175fee32e"
[[package]]
name = "proc-macro2"
version = "1.0.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0d8caf72986c1a598726adc988bb5984792ef84f5ee5aa50209145ee8077038"
dependencies = [
"unicode-xid",
]
[[package]]
name = "quote"
version = "1.0.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3d0b9745dc2debf507c8422de05d7226cc1f0644216dfdfead988f9b1ab32a7"
dependencies = [
"proc-macro2",
]
[[package]]
name = "regex"
version = "1.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d07a8629359eb56f1e2fb1652bb04212c072a87ba68546a04065d525673ac461"
dependencies = [
"aho-corasick",
"memchr",
"regex-syntax",
]
[[package]]
name = "regex-syntax"
version = "0.6.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b"
[[package]]
name = "ryu"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "71d301d4193d031abdd79ff7e3dd721168a9572ef3fe51a1517aba235bd8f86e"
[[package]]
name = "serde"
version = "1.0.126"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec7505abeacaec74ae4778d9d9328fe5a5d04253220a85c4ee022239fc996d03"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.126"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "963a7dbc9895aeac7ac90e74f34a5d5261828f79df35cbed41e10189d3804d43"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "serde_json"
version = "1.0.64"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "799e97dc9fdae36a5c8b8f2cae9ce2ee9fdce2058c57a93e6099d919fd982f79"
dependencies = [
"itoa",
"ryu",
"serde",
]
[[package]]
name = "structmeta"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b55b4052fd036e3d1fe74ea978426a3f87997ba803e7a8e69ff0cf99f35a720a"
dependencies = [
"proc-macro2",
"quote",
"structmeta-derive",
"syn",
]
[[package]]
name = "structmeta-derive"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f55502dda4b5fd26b33f6810d7493b4f5d7859bca604bd07ff22a523cd257ee"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "syn"
version = "1.0.73"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f71489ff30030d2ae598524f61326b902466f72a0fb1a8564c001cc63425bcc7"
dependencies = [
"proc-macro2",
"quote",
"unicode-xid",
]
[[package]]
name = "tinyvec"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b5220f05bb7de7f3f53c7c065e1199b3172696fe2db9f9c4d8ad9b4ee74c342"
dependencies = [
"tinyvec_macros",
]
[[package]]
name = "tinyvec_macros"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c"
[[package]]
name = "unicode-bidi"
version = "0.3.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eeb8be209bb1c96b7c177c7420d26e04eccacb0eeae6b980e35fcb74678107e0"
dependencies = [
"matches",
]
[[package]]
name = "unicode-normalization"
version = "0.1.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d54590932941a9e9266f0832deed84ebe1bf2e4c9e4a3554d393d18f5e854bf9"
dependencies = [
"tinyvec",
]
[[package]]
name = "unicode-xid"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3"
[[package]]
name = "unindent"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f14ee04d9415b52b3aeab06258a3f07093182b88ba0f9b8d203f211a7a7d41c7"
[[package]]
name = "url"
version = "2.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a507c383b2d33b5fc35d1861e77e6b383d158b2da5e14fe51b83dfedf6fd578c"
dependencies = [
"form_urlencoded",
"idna",
"matches",
"percent-encoding",
"serde",
]

5
Cargo.toml Normal file
View File

@@ -0,0 +1,5 @@
[workspace]
members = [
"oauth2"
]

14
oauth2/Cargo.toml Normal file
View File

@@ -0,0 +1,14 @@
[package]
name = "oauth2"
version = "0.1.0"
authors = ["Quentin Gliech <quenting@element.io>"]
edition = "2018"
[dependencies]
http = "0.2.4"
serde = "1.0.123"
serde_json = "1.0.64"
language-tags = { version = "0.3.2", features = ["serde"] }
url = { version = "2", features = ["serde"] }
parse-display = "0.5.0"
indoc = "1.0.3"

237
oauth2/src/errors.rs Normal file
View File

@@ -0,0 +1,237 @@
use http::status::StatusCode;
use serde::ser::{Serialize, SerializeMap};
use url::Url;
trait OAuth2Error {
/// A single ASCII error code.
///
/// Maps to the required "error" field.
fn error(&self) -> &'static str;
/// Human-readable ASCII text providing additional information, used to assist the client
/// developer in understanding the error that occurred.
///
/// Maps to the optional "error_description" field.
fn description(&self) -> Option<String> {
None
}
/// A URI identifying a human-readable web page with information about the error, used to
/// provide the client developer with additional information about the error.
///
/// Maps to the optional "error_uri" field.
fn uri(&self) -> Option<Url> {
None
}
/// Wraps the error with an ErrorResponse to help serializing.
fn into_response(self) -> ErrorResponse<Self>
where
Self: Sized,
{
ErrorResponse(self)
}
}
trait OAuth2ErrorCode: OAuth2Error {
/// The HTTP status code that must be returned by this error
fn status(&self) -> StatusCode;
}
struct ErrorResponse<T: OAuth2Error>(T);
impl<T: OAuth2ErrorCode> OAuth2ErrorCode for ErrorResponse<T> {
fn status(&self) -> StatusCode {
self.0.status()
}
}
impl<T: OAuth2Error> OAuth2Error for ErrorResponse<T> {
fn error(&self) -> &'static str {
self.0.error()
}
fn description(&self) -> Option<String> {
self.0.description()
}
fn uri(&self) -> Option<Url> {
self.0.uri()
}
}
impl<T: OAuth2Error> Serialize for ErrorResponse<T> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let error = self.0.error();
let description = self.0.description();
let uri = self.0.uri();
// Count the number of fields to serialize
let len = {
let mut x = 1;
if description.is_some() {
x += 1;
}
if uri.is_some() {
x += 1;
}
x
};
let mut map = serializer.serialize_map(Some(len))?;
map.serialize_entry("error", error)?;
if let Some(ref description) = description {
map.serialize_entry("error_description", description)?;
}
if let Some(ref uri) = uri {
map.serialize_entry("error_uri", uri)?;
}
map.end()
}
}
macro_rules! oauth2_error_def {
($name:ident) => {
pub struct $name;
};
}
macro_rules! oauth2_error_status {
($name:ident, $code:ident) => {
impl $crate::errors::OAuth2ErrorCode for $name {
fn status(&self) -> ::http::status::StatusCode {
::http::status::StatusCode::$code
}
}
};
}
macro_rules! oauth2_error_error {
($err:literal) => {
fn error(&self) -> &'static str {
$err
}
};
}
macro_rules! oauth2_error_description {
($description:expr) => {
fn description(&self) -> Option<String> {
Some(($description).to_string())
}
};
}
macro_rules! oauth2_error {
($name:ident, $err:literal => $description:expr) => {
oauth2_error_def!($name);
impl $crate::errors::OAuth2Error for $name {
oauth2_error_error!($err);
oauth2_error_description!(indoc::indoc! {$description});
}
};
($name:ident, $err:literal) => {
oauth2_error_def!($name);
impl $crate::errors::OAuth2Error for $name {
oauth2_error_error!($err);
}
};
($name:ident, code: $code:ident, $err:literal => $description:expr) => {
oauth2_error!($name, $err => $description);
oauth2_error_status!($name, $code);
};
($name:ident, code: $code:ident, $err:literal) => {
oauth2_error!($name, $err);
oauth2_error_status!($name, $code);
};
}
pub mod rfc6749 {
oauth2_error! {
InvalidRequest,
code: BAD_REQUEST,
"invalid_request" =>
"The request is missing a required parameter, includes an invalid parameter value, \
includes a parameter more than once, or is otherwise malformed."
}
oauth2_error! {
InvalidClient,
code: BAD_REQUEST,
"invalid_client" =>
"Client authentication failed."
}
oauth2_error! {
InvalidGrant,
code: BAD_REQUEST,
"invalid_grant"
}
oauth2_error! {
UnauthorizedClient,
code: BAD_REQUEST,
"unauthorized_client" =>
"The client is not authorized to request an access token using this method."
}
oauth2_error! {
UnsupportedGrantType,
code: BAD_REQUEST,
"unsupported_grant_type" =>
"The authorization grant type is not supported by the authorization server."
}
oauth2_error! {
AccessDenied,
"access_denied" =>
"The resource owner or authorization server denied the request."
}
oauth2_error! {
UnsupportedResponseType,
"unsupported_response_type" =>
"The authorization server does not support obtaining an access token using this method."
}
oauth2_error! {
InvalidScope,
code: BAD_REQUEST,
"invalid_scope" =>
"The requested scope is invalid, unknown, or malformed."
}
oauth2_error! {
ServerError,
"server_error" =>
"The authorization server encountered an unexpected \
condition that prevented it from fulfilling the request."
}
oauth2_error! {
TemporarilyUnavailable,
"temporarily_unavailable" =>
"The authorization server is currently unable to handle \
the request due to a temporary overloading or maintenance \
of the server."
}
}
pub use rfc6749::*;
#[cfg(test)]
mod tests {
use serde_json::json;
use super::*;
#[test]
fn serialize_error() {
let expected = json!({"error": "invalid_grant"});
let actual = serde_json::to_value(InvalidGrant.into_response()).unwrap();
assert_eq!(expected, actual);
}
}

6
oauth2/src/lib.rs Normal file
View File

@@ -0,0 +1,6 @@
pub mod errors;
pub mod requests;
mod types;
#[cfg(test)]
mod test_utils;

153
oauth2/src/requests.rs Normal file
View File

@@ -0,0 +1,153 @@
use std::hash::Hash;
use language_tags::LanguageTag;
use parse_display::{Display, FromStr};
use serde::{Deserialize, Serialize};
use url::Url;
use crate::types::{Seconds, StringHashSet, StringVec};
// ref: https://www.iana.org/assignments/oauth-parameters/oauth-parameters.xhtml
#[derive(Hash, PartialEq, Eq, PartialOrd, Ord, Display, FromStr)]
#[display(style = "snake_case")]
pub enum ResponseType {
Code,
IdToken,
Token,
None,
}
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ResponseMode {
Query,
Fragment,
FormPost,
}
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Display {
Page,
Popup,
Touch,
Wap,
}
#[derive(Serialize, Deserialize, FromStr)]
#[serde(rename_all = "snake_case")]
pub enum Prompt {
None,
Login,
Consent,
SelectAccount,
}
#[derive(Serialize, Deserialize)]
pub struct AuthorizationRequest {
response_type: StringHashSet<ResponseType>,
client_id: String,
redirect_uri: Option<Url>,
scope: StringHashSet<String>,
state: Option<String>,
response_mode: Option<ResponseMode>,
nonce: Option<String>,
display: Option<Display>,
max_age: Option<Seconds>,
ui_locales: Option<StringVec<LanguageTag>>,
id_token_hint: Option<String>,
login_hint: Option<String>,
acr_values: Option<StringHashSet<String>>,
}
#[derive(Serialize, Deserialize)]
pub struct AuthorizationResponse {
code: String,
state: Option<String>,
}
#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum TokenType {
Bearer,
}
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct AuthorizationCodeGrant {
code: String,
redirect_uri: Option<Url>,
}
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct RefreshTokenGrant {
refresh_token: String,
scope: Option<StringHashSet<String>>,
}
#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[serde(tag = "grant_type", rename_all = "snake_case")]
pub enum AccessTokenRequest {
AuthorizationCode(AuthorizationCodeGrant),
RefreshToken(RefreshTokenGrant),
#[serde(skip_deserializing, other)]
Unsupported,
}
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct AccessTokenResponse {
access_token: String,
token_type: TokenType,
expires_in: Option<Seconds>,
refresh_token: Option<String>,
scope: Option<StringHashSet<String>>,
}
#[cfg(test)]
mod tests {
use std::collections::HashSet;
use serde_json::json;
use super::*;
use crate::test_utils::assert_serde_json;
#[test]
fn serde_refresh_token_grant() {
let expected = json!({
"grant_type": "refresh_token",
"refresh_token": "abcd",
"scope": "openid profile",
});
let scope = {
let mut s = HashSet::new();
s.insert("openid".to_string());
s.insert("profile".to_string());
Some(s.into())
};
let req = AccessTokenRequest::RefreshToken(RefreshTokenGrant {
refresh_token: "abcd".into(),
scope,
});
assert_serde_json(&req, expected);
}
#[test]
fn serde_authorization_code_grant() {
let expected = json!({
"grant_type": "authorization_code",
"code": "abcd",
"redirect_uri": "https://example.com/redirect",
});
let req = AccessTokenRequest::AuthorizationCode(AuthorizationCodeGrant {
code: "abcd".into(),
redirect_uri: Some("https://example.com/redirect".parse().unwrap()),
});
assert_serde_json(&req, expected);
}
}

16
oauth2/src/test_utils.rs Normal file
View File

@@ -0,0 +1,16 @@
use std::fmt::Debug;
use serde::{de::DeserializeOwned, Serialize};
#[track_caller]
pub(crate) fn assert_serde_json<T: Serialize + DeserializeOwned + PartialEq + Debug>(
got: &T,
expected_value: serde_json::Value,
) {
let got_value = serde_json::to_value(&got).expect("could not serialize object as JSON value");
assert_eq!(got_value, expected_value);
let expected: T =
serde_json::from_value(expected_value).expect("could not serialize object as JSON value");
assert_eq!(got, &expected);
}

129
oauth2/src/types.rs Normal file
View File

@@ -0,0 +1,129 @@
//! Utilitary types for serde
use serde::{Deserialize, Serialize};
use std::{collections::HashSet, hash::Hash, time::Duration};
/// A HashSet that serializes to a space-separated string in alphanumerical order
#[derive(Debug, PartialEq)]
pub struct StringHashSet<T: Eq + Hash>(HashSet<T>);
impl<T: Eq + Hash> From<HashSet<T>> for StringHashSet<T> {
fn from(set: HashSet<T>) -> Self {
Self(set)
}
}
impl<T: Eq + Hash> From<StringHashSet<T>> for HashSet<T> {
fn from(set: StringHashSet<T>) -> Self {
set.0
}
}
impl<T> Serialize for StringHashSet<T>
where
T: ToString + PartialOrd + Eq + Hash,
{
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut items: Vec<_> = self.0.iter().map(|i| i.to_string()).collect();
items.sort();
let s = items.join(" ");
serializer.serialize_str(&s)
}
}
impl<'de, T> Deserialize<'de> for StringHashSet<T>
where
T: std::str::FromStr + Eq + Hash,
<T as std::str::FromStr>::Err: std::fmt::Display,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s: String = Deserialize::deserialize(deserializer)?;
let items: Result<HashSet<T>, _> = s.split_ascii_whitespace().map(T::from_str).collect();
items.map(Into::into).map_err(serde::de::Error::custom)
}
}
/// A Vec that serializes to a space-separated string
pub struct StringVec<T>(Vec<T>);
impl<T> From<Vec<T>> for StringVec<T> {
fn from(set: Vec<T>) -> Self {
Self(set)
}
}
impl<T> From<StringVec<T>> for Vec<T> {
fn from(v: StringVec<T>) -> Self {
v.0
}
}
impl<T> Serialize for StringVec<T>
where
T: ToString,
{
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let items: Vec<_> = self.0.iter().map(|i| i.to_string()).collect();
let s = items.join(" ");
serializer.serialize_str(&s)
}
}
impl<'de, T> Deserialize<'de> for StringVec<T>
where
T: std::str::FromStr + std::hash::Hash + Eq,
<T as std::str::FromStr>::Err: std::fmt::Display,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s: String = Deserialize::deserialize(deserializer)?;
let items: Result<Vec<T>, _> = s.split_ascii_whitespace().map(T::from_str).collect();
items.map(Into::into).map_err(serde::de::Error::custom)
}
}
/// A Duration that serializes to seconds
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct Seconds(Duration);
impl From<Duration> for Seconds {
fn from(d: Duration) -> Self {
Self(d)
}
}
impl From<Seconds> for Duration {
fn from(val: Seconds) -> Self {
val.0
}
}
impl Serialize for Seconds {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
self.0.as_secs().serialize(serializer)
}
}
impl<'de> Deserialize<'de> for Seconds {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let secs = u64::deserialize(deserializer)?;
Ok(Self(Duration::from_secs(secs)))
}
}