diff --git a/Cargo.lock b/Cargo.lock index 6cf91ff4..33aaed3e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1626,6 +1626,7 @@ dependencies = [ "data-encoding", "http", "indoc", + "itertools", "language-tags", "parse-display", "serde", @@ -1633,6 +1634,7 @@ dependencies = [ "serde_with", "sha2", "sqlx", + "thiserror", "url", ] diff --git a/crates/oauth2-types/Cargo.toml b/crates/oauth2-types/Cargo.toml index 4f47417b..f550b747 100644 --- a/crates/oauth2-types/Cargo.toml +++ b/crates/oauth2-types/Cargo.toml @@ -18,6 +18,8 @@ sqlx = { version = "0.5.9", default-features = false, optional = true } chrono = "0.4.19" sha2 = "0.9.8" data-encoding = "2.3.2" +thiserror = "1.0.29" +itertools = "0.10.1" [features] sqlx_type = ["sqlx"] diff --git a/crates/oauth2-types/src/lib.rs b/crates/oauth2-types/src/lib.rs index 57e35926..09d52b4b 100644 --- a/crates/oauth2-types/src/lib.rs +++ b/crates/oauth2-types/src/lib.rs @@ -20,6 +20,7 @@ pub mod errors; pub mod oidc; pub mod pkce; pub mod requests; +pub mod scope; #[cfg(test)] mod test_utils; diff --git a/crates/oauth2-types/src/requests.rs b/crates/oauth2-types/src/requests.rs index 6a7c9763..5ab20a32 100644 --- a/crates/oauth2-types/src/requests.rs +++ b/crates/oauth2-types/src/requests.rs @@ -24,6 +24,8 @@ use serde_with::{ }; use url::Url; +use crate::scope::Scope; + // ref: https://www.iana.org/assignments/oauth-parameters/oauth-parameters.xhtml #[derive( @@ -212,22 +214,18 @@ pub struct AuthorizationCodeGrant { pub code_verifier: Option, } -#[serde_as] #[derive(Serialize, Deserialize, Debug, PartialEq)] pub struct RefreshTokenGrant { pub refresh_token: String, #[serde(default)] - #[serde_as(as = "Option>")] - scope: Option>, + scope: Option, } -#[serde_as] #[derive(Serialize, Deserialize, Debug, PartialEq)] pub struct ClientCredentialsGrant { #[serde(default)] - #[serde_as(as = "Option>")] - scope: Option>, + scope: Option, } #[derive( @@ -275,8 +273,7 @@ pub struct AccessTokenResponse { #[serde_as(as = "Option>")] expires_in: Option, - #[serde_as(as = "Option>")] - scope: Option>, + scope: Option, } impl AccessTokenResponse { @@ -305,7 +302,7 @@ impl AccessTokenResponse { } #[must_use] - pub fn with_scopes(mut self, scope: HashSet) -> Self { + pub fn with_scope(mut self, scope: Scope) -> Self { self.scope = Some(scope); self } @@ -339,8 +336,7 @@ pub struct IntrospectionRequest { pub struct IntrospectionResponse { pub active: bool, - #[serde_as(as = "Option>")] - pub scope: Option>, + pub scope: Option, pub client_id: Option, @@ -368,12 +364,10 @@ pub struct IntrospectionResponse { #[cfg(test)] mod tests { - use std::collections::HashSet; - use serde_json::json; use super::*; - use crate::test_utils::assert_serde_json; + use crate::{scope::OPENID, test_utils::assert_serde_json}; #[test] fn serde_refresh_token_grant() { @@ -383,14 +377,10 @@ mod tests { "scope": "openid", }); - let scope = { - let mut s = HashSet::new(); - // TODO: insert multiple scopes and test it. It's a bit tricky to test since - // HashSet have no guarantees regarding the ordering of items, so right - // now the output is unstable. - s.insert("openid".to_string()); - Some(s) - }; + // TODO: insert multiple scopes and test it. It's a bit tricky to test since + // HashSet have no guarantees regarding the ordering of items, so right + // now the output is unstable. + let scope: Option = Some(vec![OPENID].into_iter().collect()); let req = AccessTokenRequest::RefreshToken(RefreshTokenGrant { refresh_token: "abcd".into(), diff --git a/crates/oauth2-types/src/scope.rs b/crates/oauth2-types/src/scope.rs new file mode 100644 index 00000000..0352f73b --- /dev/null +++ b/crates/oauth2-types/src/scope.rs @@ -0,0 +1,196 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#![allow(clippy::module_name_repetitions)] + +use std::{borrow::Cow, collections::HashSet, iter::FromIterator, ops::Deref, str::FromStr}; + +use itertools::Itertools; +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +#[derive(Debug, Error, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[error("Invalid scope format")] +pub struct InvalidScope; + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct ScopeToken(Cow<'static, str>); + +impl ScopeToken { + const fn well_known(token: &'static str) -> Self { + Self(Cow::Borrowed(token)) + } +} + +pub const OPENID: ScopeToken = ScopeToken::well_known("openid"); +pub const PROFILE: ScopeToken = ScopeToken::well_known("profile"); +pub const EMAIL: ScopeToken = ScopeToken::well_known("email"); +pub const ADDRESS: ScopeToken = ScopeToken::well_known("address"); +pub const PHONE: ScopeToken = ScopeToken::well_known("phone"); +pub const OFFLINE_ACCESS: ScopeToken = ScopeToken::well_known("offline_access"); + +// As per RFC6749 appendix A: +// https://datatracker.ietf.org/doc/html/rfc6749#appendix-A +// +// NQCHAR = %x21 / %x23-5B / %x5D-7E +fn nqchar(c: char) -> bool { + '\x21' == c || ('\x23'..'\x5B').contains(&c) || ('\x5D'..'\x7E').contains(&c) +} + +impl FromStr for ScopeToken { + type Err = InvalidScope; + + fn from_str(s: &str) -> Result { + // As per RFC6749 appendix A.4: + // https://datatracker.ietf.org/doc/html/rfc6749#appendix-A.4 + // + // scope-token = 1*NQCHAR + if !s.is_empty() && s.chars().all(nqchar) { + Ok(ScopeToken(Cow::Owned(s.into()))) + } else { + Err(InvalidScope) + } + } +} + +impl Deref for ScopeToken { + type Target = str; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl ToString for ScopeToken { + fn to_string(&self) -> String { + self.0.to_string() + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Scope(HashSet); + +impl FromStr for Scope { + type Err = InvalidScope; + + fn from_str(s: &str) -> Result { + // As per RFC6749 appendix A.4: + // https://datatracker.ietf.org/doc/html/rfc6749#appendix-A.4 + // + // scope = scope-token *( SP scope-token ) + let scopes: Result, InvalidScope> = + s.split(' ').map(ScopeToken::from_str).collect(); + + Ok(Self(scopes?)) + } +} + +impl Scope { + #[must_use] + pub fn is_empty(&self) -> bool { + // This should never be the case? + self.0.is_empty() + } + + #[must_use] + pub fn len(&self) -> usize { + self.0.len() + } + + #[must_use] + pub fn contains(&self, token: &str) -> bool { + ScopeToken::from_str(token) + .map(|token| self.0.contains(&token)) + .unwrap_or(false) + } +} + +impl ToString for Scope { + fn to_string(&self) -> String { + let it = self.0.iter().map(ScopeToken::to_string); + Itertools::intersperse(it, ' '.to_string()).collect() + } +} + +impl Serialize for Scope { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + self.to_string().serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for Scope { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + // FIXME: seems like there is an unnecessary clone here? + let scope: String = Deserialize::deserialize(deserializer)?; + Scope::from_str(&scope).map_err(serde::de::Error::custom) + } +} + +impl FromIterator for Scope { + fn from_iter>(iter: T) -> Self { + Self(HashSet::from_iter(iter)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_scope_token() { + assert_eq!(ScopeToken::from_str("openid"), Ok(OPENID)); + + assert_eq!(ScopeToken::from_str("invalid\\scope"), Err(InvalidScope)); + } + + #[test] + fn parse_scope() { + let scope = Scope::from_str("openid profile address").unwrap(); + assert_eq!(scope.len(), 3); + assert!(scope.contains("openid")); + assert!(scope.contains("profile")); + assert!(scope.contains("address")); + assert!(!scope.contains("unknown")); + + assert!( + Scope::from_str("").is_err(), + "there should always be at least one token in the scope" + ); + + assert!(Scope::from_str("invalid\\scope").is_err()); + assert!(Scope::from_str("no double space").is_err()); + assert!(Scope::from_str(" no leading space").is_err()); + assert!(Scope::from_str("no trailing space ").is_err()); + + let scope = Scope::from_str("openid").unwrap(); + assert_eq!(scope.len(), 1); + assert!(scope.contains("openid")); + assert!(!scope.contains("profile")); + assert!(!scope.contains("address")); + + assert_eq!( + Scope::from_str("order does not matter"), + Scope::from_str("matter not order does"), + ); + + assert!(Scope::from_str("http://example.com").is_ok()); + assert!(Scope::from_str("urn:matrix:*").is_ok()); + } +} diff --git a/crates/oauth2-types/src/test_utils.rs b/crates/oauth2-types/src/test_utils.rs index 050fd839..09f453d2 100644 --- a/crates/oauth2-types/src/test_utils.rs +++ b/crates/oauth2-types/src/test_utils.rs @@ -24,7 +24,7 @@ pub(crate) fn assert_serde_json