1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Split the service in multiple crates

This commit is contained in:
Quentin Gliech
2021-09-16 14:43:56 +02:00
parent da91564bf9
commit a44e33931c
83 changed files with 311 additions and 174 deletions

View File

@ -0,0 +1,21 @@
[package]
name = "oauth2-types"
version = "0.1.0"
authors = ["Quentin Gliech <quenting@element.io>"]
edition = "2018"
license = "Apache-2.0"
[dependencies]
http = "0.2.4"
serde = "1.0.130"
serde_json = "1.0.68"
language-tags = { version = "0.3.2", features = ["serde"] }
url = { version = "2.2.2", features = ["serde"] }
parse-display = "0.5.1"
indoc = "1.0.3"
serde_with = { version = "1.10.0", features = ["chrono"] }
sqlx = { version = "0.5.7", default-features = false, optional = true }
chrono = "0.4.19"
[features]
sqlx_type = ["sqlx"]

View File

@ -0,0 +1,268 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use http::status::StatusCode;
use serde::ser::{Serialize, SerializeMap};
use url::Url;
pub trait OAuth2Error: std::fmt::Debug + Send + Sync {
/// A single ASCII error code.
///
/// Maps to the required "error" field.
fn error(&self) -> &'static str;
/// Human-readable ASCII text providing additional information, used to
/// assist the client developer in understanding the error that
/// occurred.
///
/// Maps to the optional `error_description` field.
fn description(&self) -> Option<String> {
None
}
/// A URI identifying a human-readable web page with information about the
/// error, used to provide the client developer with additional
/// information about the error.
///
/// Maps to the optional `error_uri` field.
fn uri(&self) -> Option<Url> {
None
}
/// Wraps the error with an `ErrorResponse` to help serializing.
fn into_response(self) -> ErrorResponse
where
Self: Sized + 'static,
{
ErrorResponse(Box::new(self))
}
}
pub trait OAuth2ErrorCode: OAuth2Error + 'static {
/// The HTTP status code that must be returned by this error
fn status(&self) -> StatusCode;
}
impl OAuth2Error for &Box<dyn OAuth2ErrorCode> {
fn error(&self) -> &'static str {
self.as_ref().error()
}
fn description(&self) -> Option<String> {
self.as_ref().description()
}
fn uri(&self) -> Option<Url> {
self.as_ref().uri()
}
}
#[derive(Debug)]
pub struct ErrorResponse(Box<dyn OAuth2Error>);
impl From<Box<dyn OAuth2Error>> for ErrorResponse {
fn from(b: Box<dyn OAuth2Error>) -> Self {
Self(b)
}
}
impl OAuth2Error for ErrorResponse {
fn error(&self) -> &'static str {
self.0.error()
}
fn description(&self) -> Option<String> {
self.0.description()
}
fn uri(&self) -> Option<Url> {
self.0.uri()
}
}
impl Serialize for ErrorResponse {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let error = self.0.error();
let description = self.0.description();
let uri = self.0.uri();
// Count the number of fields to serialize
let len = {
let mut x = 1;
if description.is_some() {
x += 1;
}
if uri.is_some() {
x += 1;
}
x
};
let mut map = serializer.serialize_map(Some(len))?;
map.serialize_entry("error", error)?;
if let Some(ref description) = description {
map.serialize_entry("error_description", description)?;
}
if let Some(ref uri) = uri {
map.serialize_entry("error_uri", uri)?;
}
map.end()
}
}
macro_rules! oauth2_error_def {
($name:ident) => {
#[derive(Debug, Clone)]
pub struct $name;
};
}
macro_rules! oauth2_error_status {
($name:ident, $code:ident) => {
impl $crate::errors::OAuth2ErrorCode for $name {
fn status(&self) -> ::http::status::StatusCode {
::http::status::StatusCode::$code
}
}
};
}
macro_rules! oauth2_error_error {
($err:literal) => {
fn error(&self) -> &'static str {
$err
}
};
}
macro_rules! oauth2_error_description {
($description:expr) => {
fn description(&self) -> Option<String> {
Some(($description).to_string())
}
};
}
macro_rules! oauth2_error {
($name:ident, $err:literal => $description:expr) => {
oauth2_error_def!($name);
impl $crate::errors::OAuth2Error for $name {
oauth2_error_error!($err);
oauth2_error_description!(indoc::indoc! {$description});
}
};
($name:ident, $err:literal) => {
oauth2_error_def!($name);
impl $crate::errors::OAuth2Error for $name {
oauth2_error_error!($err);
}
};
($name:ident, code: $code:ident, $err:literal => $description:expr) => {
oauth2_error!($name, $err => $description);
oauth2_error_status!($name, $code);
};
($name:ident, code: $code:ident, $err:literal) => {
oauth2_error!($name, $err);
oauth2_error_status!($name, $code);
};
}
pub mod rfc6749 {
oauth2_error! {
InvalidRequest,
code: BAD_REQUEST,
"invalid_request" =>
"The request is missing a required parameter, includes an invalid parameter value, \
includes a parameter more than once, or is otherwise malformed."
}
oauth2_error! {
InvalidClient,
code: BAD_REQUEST,
"invalid_client" =>
"Client authentication failed."
}
oauth2_error! {
InvalidGrant,
code: BAD_REQUEST,
"invalid_grant"
}
oauth2_error! {
UnauthorizedClient,
code: BAD_REQUEST,
"unauthorized_client" =>
"The client is not authorized to request an access token using this method."
}
oauth2_error! {
UnsupportedGrantType,
code: BAD_REQUEST,
"unsupported_grant_type" =>
"The authorization grant type is not supported by the authorization server."
}
oauth2_error! {
AccessDenied,
"access_denied" =>
"The resource owner or authorization server denied the request."
}
oauth2_error! {
UnsupportedResponseType,
"unsupported_response_type" =>
"The authorization server does not support obtaining an access token using this method."
}
oauth2_error! {
InvalidScope,
code: BAD_REQUEST,
"invalid_scope" =>
"The requested scope is invalid, unknown, or malformed."
}
oauth2_error! {
ServerError,
"server_error" =>
"The authorization server encountered an unexpected \
condition that prevented it from fulfilling the request."
}
oauth2_error! {
TemporarilyUnavailable,
"temporarily_unavailable" =>
"The authorization server is currently unable to handle \
the request due to a temporary overloading or maintenance \
of the server."
}
}
pub use rfc6749::*;
#[cfg(test)]
mod tests {
use serde_json::json;
use super::*;
#[test]
fn serialize_error() {
let expected = json!({"error": "invalid_grant"});
let actual = serde_json::to_value(InvalidGrant.into_response()).unwrap();
assert_eq!(expected, actual);
}
}

View File

@ -0,0 +1,25 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#![forbid(unsafe_code)]
#![deny(clippy::all)]
#![warn(clippy::pedantic)]
pub mod errors;
pub mod oidc;
pub mod pkce;
pub mod requests;
#[cfg(test)]
mod test_utils;

View File

@ -0,0 +1,75 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashSet;
use serde::Serialize;
use serde_with::skip_serializing_none;
use url::Url;
use crate::{
pkce::CodeChallengeMethod,
requests::{ClientAuthenticationMethod, GrantType, ResponseMode},
};
// TODO: https://datatracker.ietf.org/doc/html/rfc8414#section-2
#[skip_serializing_none]
#[derive(Serialize, Clone)]
pub struct Metadata {
/// The authorization server's issuer identifier, which is a URL that uses
/// the "https" scheme and has no query or fragment components.
pub issuer: Url,
/// URL of the authorization server's authorization endpoint.
pub authorization_endpoint: Option<Url>,
/// URL of the authorization server's token endpoint.
pub token_endpoint: Option<Url>,
/// URL of the authorization server's JWK Set document.
pub jwks_uri: Option<Url>,
/// URL of the authorization server's OAuth 2.0 Dynamic Client Registration
/// endpoint.
pub registration_endpoint: Option<Url>,
/// JSON array containing a list of the OAuth 2.0 "scope" values that this
/// authorization server supports.
pub scopes_supported: Option<HashSet<String>>,
/// JSON array containing a list of the OAuth 2.0 "response_type" values
/// that this authorization server supports.
pub response_types_supported: Option<HashSet<String>>,
/// JSON array containing a list of the OAuth 2.0 "response_mode" values
/// that this authorization server supports, as specified in "OAuth 2.0
/// Multiple Response Type Encoding Practices".
pub response_modes_supported: Option<HashSet<ResponseMode>>,
/// JSON array containing a list of the OAuth 2.0 grant type values that
/// this authorization server supports.
pub grant_types_supported: Option<HashSet<GrantType>>,
/// JSON array containing a list of client authentication methods supported
/// by this token endpoint.
pub token_endpoint_auth_methods_supported: Option<HashSet<ClientAuthenticationMethod>>,
/// PKCE code challenge methods supported by this authorization server
pub code_challenge_methods_supported: Option<HashSet<CodeChallengeMethod>>,
/// URL of the authorization server's OAuth 2.0 introspection endpoint.
pub introspection_endpoint: Option<Url>,
pub userinfo_endpoint: Option<Url>,
}

View File

@ -0,0 +1,48 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use parse_display::{Display, FromStr};
use serde::{Deserialize, Serialize};
#[derive(
Debug,
Hash,
PartialEq,
Eq,
PartialOrd,
Ord,
Clone,
Copy,
Display,
FromStr,
Serialize,
Deserialize,
)]
#[cfg_attr(feature = "sqlx_type", derive(sqlx::Type))]
#[repr(i8)]
pub enum CodeChallengeMethod {
#[serde(rename = "plain")]
#[display("plain")]
Plain = 0,
#[serde(rename = "S256")]
#[display("S256")]
S256 = 1,
}
#[derive(Serialize, Deserialize)]
pub struct Request {
pub code_challenge_method: CodeChallengeMethod,
pub code_challenge: String,
}

View File

@ -0,0 +1,414 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{collections::HashSet, hash::Hash};
use chrono::{DateTime, Duration, Utc};
use language_tags::LanguageTag;
use parse_display::{Display, FromStr};
use serde::{Deserialize, Serialize};
use serde_with::{
rust::StringWithSeparator, serde_as, skip_serializing_none, DurationSeconds, SpaceSeparator,
TimestampSeconds,
};
use url::Url;
// ref: https://www.iana.org/assignments/oauth-parameters/oauth-parameters.xhtml
#[derive(
Debug,
Hash,
PartialEq,
Eq,
PartialOrd,
Ord,
Clone,
Copy,
Display,
FromStr,
Serialize,
Deserialize,
)]
#[display(style = "snake_case")]
#[serde(rename_all = "snake_case")]
pub enum ResponseType {
Code,
IdToken,
Token,
None,
}
#[derive(
Debug,
Hash,
PartialEq,
Eq,
PartialOrd,
Ord,
Clone,
Copy,
Display,
FromStr,
Serialize,
Deserialize,
)]
#[serde(rename_all = "snake_case")]
pub enum ResponseMode {
Query,
Fragment,
FormPost,
}
#[derive(
Debug,
Hash,
PartialEq,
Eq,
PartialOrd,
Ord,
Clone,
Copy,
Display,
FromStr,
Serialize,
Deserialize,
)]
#[serde(rename_all = "snake_case")]
pub enum ClientAuthenticationMethod {
None,
ClientSecretPost,
ClientSecretBasic,
}
#[derive(
Debug,
Hash,
PartialEq,
Eq,
PartialOrd,
Ord,
Clone,
Copy,
Display,
FromStr,
Serialize,
Deserialize,
)]
#[serde(rename_all = "snake_case")]
pub enum Display {
Page,
Popup,
Touch,
Wap,
}
#[derive(
Debug,
Hash,
PartialEq,
Eq,
PartialOrd,
Ord,
Clone,
Copy,
Display,
FromStr,
Serialize,
Deserialize,
)]
#[display(style = "snake_case")]
#[serde(rename_all = "snake_case")]
pub enum Prompt {
None,
Login,
Consent,
SelectAccount,
}
#[serde_as]
#[derive(Serialize, Deserialize)]
pub struct AuthorizationRequest {
#[serde_as(as = "StringWithSeparator::<SpaceSeparator, ResponseType>")]
pub response_type: HashSet<ResponseType>,
pub client_id: String,
pub redirect_uri: Option<Url>,
#[serde_as(as = "StringWithSeparator::<SpaceSeparator, String>")]
pub scope: HashSet<String>,
pub state: Option<String>,
pub response_mode: Option<ResponseMode>,
pub nonce: Option<String>,
display: Option<Display>,
#[serde_as(as = "Option<DurationSeconds<i64>>")]
#[serde(default)]
pub max_age: Option<Duration>,
#[serde_as(as = "Option<StringWithSeparator::<SpaceSeparator, LanguageTag>>")]
#[serde(default)]
ui_locales: Option<Vec<LanguageTag>>,
id_token_hint: Option<String>,
login_hint: Option<String>,
#[serde_as(as = "Option<StringWithSeparator::<SpaceSeparator, String>>")]
#[serde(default)]
acr_values: Option<HashSet<String>>,
}
#[derive(Serialize, Deserialize, Default)]
pub struct AuthorizationResponse<R> {
pub code: Option<String>,
pub state: Option<String>,
#[serde(flatten)]
pub response: R,
}
#[derive(
Debug,
Hash,
PartialEq,
Eq,
PartialOrd,
Ord,
Clone,
Copy,
Display,
FromStr,
Serialize,
Deserialize,
)]
#[serde(rename_all = "snake_case")]
pub enum TokenType {
Bearer,
}
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct AuthorizationCodeGrant {
pub code: String,
#[serde(default)]
pub redirect_uri: Option<Url>,
}
#[serde_as]
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct RefreshTokenGrant {
pub refresh_token: String,
#[serde(default)]
#[serde_as(as = "Option<StringWithSeparator::<SpaceSeparator, String>>")]
scope: Option<HashSet<String>>,
}
#[serde_as]
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct ClientCredentialsGrant {
#[serde(default)]
#[serde_as(as = "Option<StringWithSeparator::<SpaceSeparator, String>>")]
scope: Option<HashSet<String>>,
}
#[derive(
Debug,
Hash,
PartialEq,
Eq,
PartialOrd,
Ord,
Clone,
Copy,
Display,
FromStr,
Serialize,
Deserialize,
)]
#[serde(rename_all = "snake_case")]
pub enum GrantType {
AuthorizationCode,
RefreshToken,
ClientCredentials,
}
#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[serde(tag = "grant_type", rename_all = "snake_case")]
pub enum AccessTokenRequest {
AuthorizationCode(AuthorizationCodeGrant),
RefreshToken(RefreshTokenGrant),
ClientCredentials(ClientCredentialsGrant),
#[serde(skip_deserializing, other)]
Unsupported,
}
#[serde_as]
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct AccessTokenResponse {
access_token: String,
refresh_token: Option<String>,
// TODO: this should be somewhere else
id_token: Option<String>,
token_type: TokenType,
#[serde_as(as = "Option<DurationSeconds<i64>>")]
expires_in: Option<Duration>,
#[serde_as(as = "Option<StringWithSeparator::<SpaceSeparator, String>>")]
scope: Option<HashSet<String>>,
}
impl AccessTokenResponse {
#[must_use]
pub fn new(access_token: String) -> AccessTokenResponse {
AccessTokenResponse {
access_token,
refresh_token: None,
id_token: None,
token_type: TokenType::Bearer,
expires_in: None,
scope: None,
}
}
#[must_use]
pub fn with_refresh_token(mut self, refresh_token: String) -> Self {
self.refresh_token = Some(refresh_token);
self
}
#[must_use]
pub fn with_id_token(mut self, id_token: String) -> Self {
self.id_token = Some(id_token);
self
}
#[must_use]
pub fn with_scopes(mut self, scope: HashSet<String>) -> Self {
self.scope = Some(scope);
self
}
#[must_use]
pub fn with_expires_in(mut self, expires_in: Duration) -> Self {
self.expires_in = Some(expires_in);
self
}
}
#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum TokenTypeHint {
AccessToken,
RefreshToken,
}
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct IntrospectionRequest {
pub token: String,
#[serde(default)]
pub token_type_hint: Option<TokenTypeHint>,
}
#[serde_as]
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, PartialEq, Default)]
pub struct IntrospectionResponse {
pub active: bool,
#[serde_as(as = "Option<StringWithSeparator::<SpaceSeparator, String>>")]
pub scope: Option<HashSet<String>>,
pub client_id: Option<String>,
pub username: Option<String>,
pub token_type: Option<TokenTypeHint>,
#[serde_as(as = "Option<TimestampSeconds>")]
pub exp: Option<DateTime<Utc>>,
#[serde_as(as = "Option<TimestampSeconds>")]
pub iat: Option<DateTime<Utc>>,
#[serde_as(as = "Option<TimestampSeconds>")]
pub nbf: Option<DateTime<Utc>>,
pub sub: Option<String>,
pub aud: Option<String>,
pub iss: Option<String>,
pub jti: Option<String>,
}
#[cfg(test)]
mod tests {
use std::collections::HashSet;
use serde_json::json;
use super::*;
use crate::test_utils::assert_serde_json;
#[test]
fn serde_refresh_token_grant() {
let expected = json!({
"grant_type": "refresh_token",
"refresh_token": "abcd",
"scope": "openid",
});
let scope = {
let mut s = HashSet::new();
// TODO: insert multiple scopes and test it. It's a bit tricky to test since
// HashSet have no guarantees regarding the ordering of items, so right
// now the output is unstable.
s.insert("openid".to_string());
Some(s)
};
let req = AccessTokenRequest::RefreshToken(RefreshTokenGrant {
refresh_token: "abcd".into(),
scope,
});
assert_serde_json(&req, expected);
}
#[test]
fn serde_authorization_code_grant() {
let expected = json!({
"grant_type": "authorization_code",
"code": "abcd",
"redirect_uri": "https://example.com/redirect",
});
let req = AccessTokenRequest::AuthorizationCode(AuthorizationCodeGrant {
code: "abcd".into(),
redirect_uri: Some("https://example.com/redirect".parse().unwrap()),
});
assert_serde_json(&req, expected);
}
}

View File

@ -0,0 +1,30 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::fmt::Debug;
use serde::{de::DeserializeOwned, Serialize};
#[track_caller]
pub(crate) fn assert_serde_json<T: Serialize + DeserializeOwned + PartialEq + Debug>(
got: &T,
expected_value: serde_json::Value,
) {
let got_value = serde_json::to_value(&got).expect("could not serialize object as JSON value");
assert_eq!(got_value, expected_value);
let expected: T =
serde_json::from_value(expected_value).expect("could not serialize object as JSON value");
assert_eq!(got, &expected);
}