You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2026-01-03 17:02:28 +03:00
initial commit
This commit is contained in:
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
target/
|
||||
305
Cargo.lock
generated
Normal file
305
Cargo.lock
generated
Normal file
@@ -0,0 +1,305 @@
|
||||
# This file is automatically @generated by Cargo.
|
||||
# It is not intended for manual editing.
|
||||
[[package]]
|
||||
name = "aho-corasick"
|
||||
version = "0.7.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1e37cfd5e7657ada45f742d6e99ca5788580b5c529dc78faf11ece6dc702656f"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bytes"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b700ce4376041dcd0a327fd0097c41095743c4c8af8887265942faf1100bd040"
|
||||
|
||||
[[package]]
|
||||
name = "fnv"
|
||||
version = "1.0.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
|
||||
|
||||
[[package]]
|
||||
name = "form_urlencoded"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5fc25a87fa4fd2094bffb06925852034d90a17f0d1e05197d4956d3555752191"
|
||||
dependencies = [
|
||||
"matches",
|
||||
"percent-encoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http"
|
||||
version = "0.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "527e8c9ac747e28542699a951517aa9a6945af506cd1f2e1b53a576c17b6cc11"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"fnv",
|
||||
"itoa",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "idna"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "418a0a6fab821475f634efe3ccc45c013f742efe03d853e8d3355d5cb850ecf8"
|
||||
dependencies = [
|
||||
"matches",
|
||||
"unicode-bidi",
|
||||
"unicode-normalization",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indoc"
|
||||
version = "1.0.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e5a75aeaaef0ce18b58056d306c27b07436fbb34b8816c53094b76dd81803136"
|
||||
dependencies = [
|
||||
"unindent",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itoa"
|
||||
version = "0.4.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dd25036021b0de88a0aff6b850051563c6516d0bf53f8638938edbb9de732736"
|
||||
|
||||
[[package]]
|
||||
name = "language-tags"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d4345964bb142484797b161f473a503a434de77149dd8c7427788c6e13379388"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "matches"
|
||||
version = "0.1.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7ffc5c5338469d4d3ea17d269fa8ea3512ad247247c30bd2df69e68309ed0a08"
|
||||
|
||||
[[package]]
|
||||
name = "memchr"
|
||||
version = "2.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b16bd47d9e329435e309c58469fe0791c2d0d1ba96ec0954152a5ae2b04387dc"
|
||||
|
||||
[[package]]
|
||||
name = "oauth2"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"http",
|
||||
"indoc",
|
||||
"language-tags",
|
||||
"parse-display",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "once_cell"
|
||||
version = "1.7.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "af8b08b04175473088b46763e51ee54da5f9a164bc162f615b91bc179dbf15a3"
|
||||
|
||||
[[package]]
|
||||
name = "parse-display"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cc7e98ea043e0880940ef455c6c6e5710b4f670b4f0aeff6edf320bb01143fe9"
|
||||
dependencies = [
|
||||
"once_cell",
|
||||
"parse-display-derive",
|
||||
"regex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "parse-display-derive"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "962e8dc54ebea1392eb2f36a205f2efa9437bfe8e95d7a91f070044c363c9684"
|
||||
dependencies = [
|
||||
"once_cell",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"regex",
|
||||
"regex-syntax",
|
||||
"structmeta",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "percent-encoding"
|
||||
version = "2.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d4fd5641d01c8f18a23da7b6fe29298ff4b55afcccdf78973b24cf3175fee32e"
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.27"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f0d8caf72986c1a598726adc988bb5984792ef84f5ee5aa50209145ee8077038"
|
||||
dependencies = [
|
||||
"unicode-xid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c3d0b9745dc2debf507c8422de05d7226cc1f0644216dfdfead988f9b1ab32a7"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex"
|
||||
version = "1.5.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d07a8629359eb56f1e2fb1652bb04212c072a87ba68546a04065d525673ac461"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"memchr",
|
||||
"regex-syntax",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex-syntax"
|
||||
version = "0.6.25"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b"
|
||||
|
||||
[[package]]
|
||||
name = "ryu"
|
||||
version = "1.0.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "71d301d4193d031abdd79ff7e3dd721168a9572ef3fe51a1517aba235bd8f86e"
|
||||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.126"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ec7505abeacaec74ae4778d9d9328fe5a5d04253220a85c4ee022239fc996d03"
|
||||
dependencies = [
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.126"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "963a7dbc9895aeac7ac90e74f34a5d5261828f79df35cbed41e10189d3804d43"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.64"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "799e97dc9fdae36a5c8b8f2cae9ce2ee9fdce2058c57a93e6099d919fd982f79"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"ryu",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "structmeta"
|
||||
version = "0.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b55b4052fd036e3d1fe74ea978426a3f87997ba803e7a8e69ff0cf99f35a720a"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"structmeta-derive",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "structmeta-derive"
|
||||
version = "0.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3f55502dda4b5fd26b33f6810d7493b4f5d7859bca604bd07ff22a523cd257ee"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "1.0.73"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f71489ff30030d2ae598524f61326b902466f72a0fb1a8564c001cc63425bcc7"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"unicode-xid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tinyvec"
|
||||
version = "1.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5b5220f05bb7de7f3f53c7c065e1199b3172696fe2db9f9c4d8ad9b4ee74c342"
|
||||
dependencies = [
|
||||
"tinyvec_macros",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tinyvec_macros"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-bidi"
|
||||
version = "0.3.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "eeb8be209bb1c96b7c177c7420d26e04eccacb0eeae6b980e35fcb74678107e0"
|
||||
dependencies = [
|
||||
"matches",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "unicode-normalization"
|
||||
version = "0.1.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d54590932941a9e9266f0832deed84ebe1bf2e4c9e4a3554d393d18f5e854bf9"
|
||||
dependencies = [
|
||||
"tinyvec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "unicode-xid"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3"
|
||||
|
||||
[[package]]
|
||||
name = "unindent"
|
||||
version = "0.1.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f14ee04d9415b52b3aeab06258a3f07093182b88ba0f9b8d203f211a7a7d41c7"
|
||||
|
||||
[[package]]
|
||||
name = "url"
|
||||
version = "2.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a507c383b2d33b5fc35d1861e77e6b383d158b2da5e14fe51b83dfedf6fd578c"
|
||||
dependencies = [
|
||||
"form_urlencoded",
|
||||
"idna",
|
||||
"matches",
|
||||
"percent-encoding",
|
||||
"serde",
|
||||
]
|
||||
5
Cargo.toml
Normal file
5
Cargo.toml
Normal file
@@ -0,0 +1,5 @@
|
||||
[workspace]
|
||||
|
||||
members = [
|
||||
"oauth2"
|
||||
]
|
||||
14
oauth2/Cargo.toml
Normal file
14
oauth2/Cargo.toml
Normal file
@@ -0,0 +1,14 @@
|
||||
[package]
|
||||
name = "oauth2"
|
||||
version = "0.1.0"
|
||||
authors = ["Quentin Gliech <quenting@element.io>"]
|
||||
edition = "2018"
|
||||
|
||||
[dependencies]
|
||||
http = "0.2.4"
|
||||
serde = "1.0.123"
|
||||
serde_json = "1.0.64"
|
||||
language-tags = { version = "0.3.2", features = ["serde"] }
|
||||
url = { version = "2", features = ["serde"] }
|
||||
parse-display = "0.5.0"
|
||||
indoc = "1.0.3"
|
||||
237
oauth2/src/errors.rs
Normal file
237
oauth2/src/errors.rs
Normal file
@@ -0,0 +1,237 @@
|
||||
use http::status::StatusCode;
|
||||
use serde::ser::{Serialize, SerializeMap};
|
||||
use url::Url;
|
||||
|
||||
trait OAuth2Error {
|
||||
/// A single ASCII error code.
|
||||
///
|
||||
/// Maps to the required "error" field.
|
||||
fn error(&self) -> &'static str;
|
||||
|
||||
/// Human-readable ASCII text providing additional information, used to assist the client
|
||||
/// developer in understanding the error that occurred.
|
||||
///
|
||||
/// Maps to the optional "error_description" field.
|
||||
fn description(&self) -> Option<String> {
|
||||
None
|
||||
}
|
||||
|
||||
/// A URI identifying a human-readable web page with information about the error, used to
|
||||
/// provide the client developer with additional information about the error.
|
||||
///
|
||||
/// Maps to the optional "error_uri" field.
|
||||
fn uri(&self) -> Option<Url> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Wraps the error with an ErrorResponse to help serializing.
|
||||
fn into_response(self) -> ErrorResponse<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
ErrorResponse(self)
|
||||
}
|
||||
}
|
||||
|
||||
trait OAuth2ErrorCode: OAuth2Error {
|
||||
/// The HTTP status code that must be returned by this error
|
||||
fn status(&self) -> StatusCode;
|
||||
}
|
||||
|
||||
struct ErrorResponse<T: OAuth2Error>(T);
|
||||
|
||||
impl<T: OAuth2ErrorCode> OAuth2ErrorCode for ErrorResponse<T> {
|
||||
fn status(&self) -> StatusCode {
|
||||
self.0.status()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: OAuth2Error> OAuth2Error for ErrorResponse<T> {
|
||||
fn error(&self) -> &'static str {
|
||||
self.0.error()
|
||||
}
|
||||
|
||||
fn description(&self) -> Option<String> {
|
||||
self.0.description()
|
||||
}
|
||||
|
||||
fn uri(&self) -> Option<Url> {
|
||||
self.0.uri()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: OAuth2Error> Serialize for ErrorResponse<T> {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
let error = self.0.error();
|
||||
let description = self.0.description();
|
||||
let uri = self.0.uri();
|
||||
|
||||
// Count the number of fields to serialize
|
||||
let len = {
|
||||
let mut x = 1;
|
||||
if description.is_some() {
|
||||
x += 1;
|
||||
}
|
||||
if uri.is_some() {
|
||||
x += 1;
|
||||
}
|
||||
x
|
||||
};
|
||||
|
||||
let mut map = serializer.serialize_map(Some(len))?;
|
||||
map.serialize_entry("error", error)?;
|
||||
if let Some(ref description) = description {
|
||||
map.serialize_entry("error_description", description)?;
|
||||
}
|
||||
if let Some(ref uri) = uri {
|
||||
map.serialize_entry("error_uri", uri)?;
|
||||
}
|
||||
map.end()
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! oauth2_error_def {
|
||||
($name:ident) => {
|
||||
pub struct $name;
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! oauth2_error_status {
|
||||
($name:ident, $code:ident) => {
|
||||
impl $crate::errors::OAuth2ErrorCode for $name {
|
||||
fn status(&self) -> ::http::status::StatusCode {
|
||||
::http::status::StatusCode::$code
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! oauth2_error_error {
|
||||
($err:literal) => {
|
||||
fn error(&self) -> &'static str {
|
||||
$err
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! oauth2_error_description {
|
||||
($description:expr) => {
|
||||
fn description(&self) -> Option<String> {
|
||||
Some(($description).to_string())
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! oauth2_error {
|
||||
($name:ident, $err:literal => $description:expr) => {
|
||||
oauth2_error_def!($name);
|
||||
impl $crate::errors::OAuth2Error for $name {
|
||||
oauth2_error_error!($err);
|
||||
oauth2_error_description!(indoc::indoc! {$description});
|
||||
}
|
||||
};
|
||||
($name:ident, $err:literal) => {
|
||||
oauth2_error_def!($name);
|
||||
impl $crate::errors::OAuth2Error for $name {
|
||||
oauth2_error_error!($err);
|
||||
}
|
||||
};
|
||||
($name:ident, code: $code:ident, $err:literal => $description:expr) => {
|
||||
oauth2_error!($name, $err => $description);
|
||||
oauth2_error_status!($name, $code);
|
||||
};
|
||||
($name:ident, code: $code:ident, $err:literal) => {
|
||||
oauth2_error!($name, $err);
|
||||
oauth2_error_status!($name, $code);
|
||||
};
|
||||
}
|
||||
|
||||
pub mod rfc6749 {
|
||||
oauth2_error! {
|
||||
InvalidRequest,
|
||||
code: BAD_REQUEST,
|
||||
"invalid_request" =>
|
||||
"The request is missing a required parameter, includes an invalid parameter value, \
|
||||
includes a parameter more than once, or is otherwise malformed."
|
||||
}
|
||||
|
||||
oauth2_error! {
|
||||
InvalidClient,
|
||||
code: BAD_REQUEST,
|
||||
"invalid_client" =>
|
||||
"Client authentication failed."
|
||||
}
|
||||
|
||||
oauth2_error! {
|
||||
InvalidGrant,
|
||||
code: BAD_REQUEST,
|
||||
"invalid_grant"
|
||||
}
|
||||
|
||||
oauth2_error! {
|
||||
UnauthorizedClient,
|
||||
code: BAD_REQUEST,
|
||||
"unauthorized_client" =>
|
||||
"The client is not authorized to request an access token using this method."
|
||||
}
|
||||
|
||||
oauth2_error! {
|
||||
UnsupportedGrantType,
|
||||
code: BAD_REQUEST,
|
||||
"unsupported_grant_type" =>
|
||||
"The authorization grant type is not supported by the authorization server."
|
||||
}
|
||||
|
||||
oauth2_error! {
|
||||
AccessDenied,
|
||||
"access_denied" =>
|
||||
"The resource owner or authorization server denied the request."
|
||||
}
|
||||
|
||||
oauth2_error! {
|
||||
UnsupportedResponseType,
|
||||
"unsupported_response_type" =>
|
||||
"The authorization server does not support obtaining an access token using this method."
|
||||
}
|
||||
|
||||
oauth2_error! {
|
||||
InvalidScope,
|
||||
code: BAD_REQUEST,
|
||||
"invalid_scope" =>
|
||||
"The requested scope is invalid, unknown, or malformed."
|
||||
}
|
||||
|
||||
oauth2_error! {
|
||||
ServerError,
|
||||
"server_error" =>
|
||||
"The authorization server encountered an unexpected \
|
||||
condition that prevented it from fulfilling the request."
|
||||
}
|
||||
|
||||
oauth2_error! {
|
||||
TemporarilyUnavailable,
|
||||
"temporarily_unavailable" =>
|
||||
"The authorization server is currently unable to handle \
|
||||
the request due to a temporary overloading or maintenance \
|
||||
of the server."
|
||||
}
|
||||
}
|
||||
|
||||
pub use rfc6749::*;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use serde_json::json;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn serialize_error() {
|
||||
let expected = json!({"error": "invalid_grant"});
|
||||
let actual = serde_json::to_value(InvalidGrant.into_response()).unwrap();
|
||||
assert_eq!(expected, actual);
|
||||
}
|
||||
}
|
||||
6
oauth2/src/lib.rs
Normal file
6
oauth2/src/lib.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
pub mod errors;
|
||||
pub mod requests;
|
||||
mod types;
|
||||
|
||||
#[cfg(test)]
|
||||
mod test_utils;
|
||||
153
oauth2/src/requests.rs
Normal file
153
oauth2/src/requests.rs
Normal file
@@ -0,0 +1,153 @@
|
||||
use std::hash::Hash;
|
||||
|
||||
use language_tags::LanguageTag;
|
||||
use parse_display::{Display, FromStr};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use url::Url;
|
||||
|
||||
use crate::types::{Seconds, StringHashSet, StringVec};
|
||||
|
||||
// ref: https://www.iana.org/assignments/oauth-parameters/oauth-parameters.xhtml
|
||||
|
||||
#[derive(Hash, PartialEq, Eq, PartialOrd, Ord, Display, FromStr)]
|
||||
#[display(style = "snake_case")]
|
||||
pub enum ResponseType {
|
||||
Code,
|
||||
IdToken,
|
||||
Token,
|
||||
None,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ResponseMode {
|
||||
Query,
|
||||
Fragment,
|
||||
FormPost,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum Display {
|
||||
Page,
|
||||
Popup,
|
||||
Touch,
|
||||
Wap,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, FromStr)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum Prompt {
|
||||
None,
|
||||
Login,
|
||||
Consent,
|
||||
SelectAccount,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct AuthorizationRequest {
|
||||
response_type: StringHashSet<ResponseType>,
|
||||
client_id: String,
|
||||
redirect_uri: Option<Url>,
|
||||
scope: StringHashSet<String>,
|
||||
state: Option<String>,
|
||||
response_mode: Option<ResponseMode>,
|
||||
nonce: Option<String>,
|
||||
display: Option<Display>,
|
||||
max_age: Option<Seconds>,
|
||||
ui_locales: Option<StringVec<LanguageTag>>,
|
||||
id_token_hint: Option<String>,
|
||||
login_hint: Option<String>,
|
||||
acr_values: Option<StringHashSet<String>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct AuthorizationResponse {
|
||||
code: String,
|
||||
state: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum TokenType {
|
||||
Bearer,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
||||
pub struct AuthorizationCodeGrant {
|
||||
code: String,
|
||||
redirect_uri: Option<Url>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
||||
pub struct RefreshTokenGrant {
|
||||
refresh_token: String,
|
||||
scope: Option<StringHashSet<String>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
||||
#[serde(tag = "grant_type", rename_all = "snake_case")]
|
||||
pub enum AccessTokenRequest {
|
||||
AuthorizationCode(AuthorizationCodeGrant),
|
||||
RefreshToken(RefreshTokenGrant),
|
||||
#[serde(skip_deserializing, other)]
|
||||
Unsupported,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
||||
pub struct AccessTokenResponse {
|
||||
access_token: String,
|
||||
token_type: TokenType,
|
||||
expires_in: Option<Seconds>,
|
||||
refresh_token: Option<String>,
|
||||
scope: Option<StringHashSet<String>>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::HashSet;
|
||||
|
||||
use serde_json::json;
|
||||
|
||||
use super::*;
|
||||
use crate::test_utils::assert_serde_json;
|
||||
|
||||
#[test]
|
||||
fn serde_refresh_token_grant() {
|
||||
let expected = json!({
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": "abcd",
|
||||
"scope": "openid profile",
|
||||
});
|
||||
|
||||
let scope = {
|
||||
let mut s = HashSet::new();
|
||||
s.insert("openid".to_string());
|
||||
s.insert("profile".to_string());
|
||||
Some(s.into())
|
||||
};
|
||||
|
||||
let req = AccessTokenRequest::RefreshToken(RefreshTokenGrant {
|
||||
refresh_token: "abcd".into(),
|
||||
scope,
|
||||
});
|
||||
|
||||
assert_serde_json(&req, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serde_authorization_code_grant() {
|
||||
let expected = json!({
|
||||
"grant_type": "authorization_code",
|
||||
"code": "abcd",
|
||||
"redirect_uri": "https://example.com/redirect",
|
||||
});
|
||||
|
||||
let req = AccessTokenRequest::AuthorizationCode(AuthorizationCodeGrant {
|
||||
code: "abcd".into(),
|
||||
redirect_uri: Some("https://example.com/redirect".parse().unwrap()),
|
||||
});
|
||||
|
||||
assert_serde_json(&req, expected);
|
||||
}
|
||||
}
|
||||
16
oauth2/src/test_utils.rs
Normal file
16
oauth2/src/test_utils.rs
Normal file
@@ -0,0 +1,16 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
|
||||
#[track_caller]
|
||||
pub(crate) fn assert_serde_json<T: Serialize + DeserializeOwned + PartialEq + Debug>(
|
||||
got: &T,
|
||||
expected_value: serde_json::Value,
|
||||
) {
|
||||
let got_value = serde_json::to_value(&got).expect("could not serialize object as JSON value");
|
||||
assert_eq!(got_value, expected_value);
|
||||
|
||||
let expected: T =
|
||||
serde_json::from_value(expected_value).expect("could not serialize object as JSON value");
|
||||
assert_eq!(got, &expected);
|
||||
}
|
||||
129
oauth2/src/types.rs
Normal file
129
oauth2/src/types.rs
Normal file
@@ -0,0 +1,129 @@
|
||||
//! Utilitary types for serde
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{collections::HashSet, hash::Hash, time::Duration};
|
||||
|
||||
/// A HashSet that serializes to a space-separated string in alphanumerical order
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct StringHashSet<T: Eq + Hash>(HashSet<T>);
|
||||
|
||||
impl<T: Eq + Hash> From<HashSet<T>> for StringHashSet<T> {
|
||||
fn from(set: HashSet<T>) -> Self {
|
||||
Self(set)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Eq + Hash> From<StringHashSet<T>> for HashSet<T> {
|
||||
fn from(set: StringHashSet<T>) -> Self {
|
||||
set.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Serialize for StringHashSet<T>
|
||||
where
|
||||
T: ToString + PartialOrd + Eq + Hash,
|
||||
{
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
let mut items: Vec<_> = self.0.iter().map(|i| i.to_string()).collect();
|
||||
items.sort();
|
||||
let s = items.join(" ");
|
||||
serializer.serialize_str(&s)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de, T> Deserialize<'de> for StringHashSet<T>
|
||||
where
|
||||
T: std::str::FromStr + Eq + Hash,
|
||||
<T as std::str::FromStr>::Err: std::fmt::Display,
|
||||
{
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let s: String = Deserialize::deserialize(deserializer)?;
|
||||
let items: Result<HashSet<T>, _> = s.split_ascii_whitespace().map(T::from_str).collect();
|
||||
items.map(Into::into).map_err(serde::de::Error::custom)
|
||||
}
|
||||
}
|
||||
|
||||
/// A Vec that serializes to a space-separated string
|
||||
pub struct StringVec<T>(Vec<T>);
|
||||
|
||||
impl<T> From<Vec<T>> for StringVec<T> {
|
||||
fn from(set: Vec<T>) -> Self {
|
||||
Self(set)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<StringVec<T>> for Vec<T> {
|
||||
fn from(v: StringVec<T>) -> Self {
|
||||
v.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Serialize for StringVec<T>
|
||||
where
|
||||
T: ToString,
|
||||
{
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
let items: Vec<_> = self.0.iter().map(|i| i.to_string()).collect();
|
||||
let s = items.join(" ");
|
||||
serializer.serialize_str(&s)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de, T> Deserialize<'de> for StringVec<T>
|
||||
where
|
||||
T: std::str::FromStr + std::hash::Hash + Eq,
|
||||
<T as std::str::FromStr>::Err: std::fmt::Display,
|
||||
{
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let s: String = Deserialize::deserialize(deserializer)?;
|
||||
let items: Result<Vec<T>, _> = s.split_ascii_whitespace().map(T::from_str).collect();
|
||||
items.map(Into::into).map_err(serde::de::Error::custom)
|
||||
}
|
||||
}
|
||||
|
||||
/// A Duration that serializes to seconds
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
|
||||
pub struct Seconds(Duration);
|
||||
|
||||
impl From<Duration> for Seconds {
|
||||
fn from(d: Duration) -> Self {
|
||||
Self(d)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Seconds> for Duration {
|
||||
fn from(val: Seconds) -> Self {
|
||||
val.0
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for Seconds {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
self.0.as_secs().serialize(serializer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for Seconds {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let secs = u64::deserialize(deserializer)?;
|
||||
Ok(Self(Duration::from_secs(secs)))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user