You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-31 09:24:31 +03:00
Better data structure to handle scopes
This commit is contained in:
2
Cargo.lock
generated
2
Cargo.lock
generated
@ -1626,6 +1626,7 @@ dependencies = [
|
|||||||
"data-encoding",
|
"data-encoding",
|
||||||
"http",
|
"http",
|
||||||
"indoc",
|
"indoc",
|
||||||
|
"itertools",
|
||||||
"language-tags",
|
"language-tags",
|
||||||
"parse-display",
|
"parse-display",
|
||||||
"serde",
|
"serde",
|
||||||
@ -1633,6 +1634,7 @@ dependencies = [
|
|||||||
"serde_with",
|
"serde_with",
|
||||||
"sha2",
|
"sha2",
|
||||||
"sqlx",
|
"sqlx",
|
||||||
|
"thiserror",
|
||||||
"url",
|
"url",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -18,6 +18,8 @@ sqlx = { version = "0.5.9", default-features = false, optional = true }
|
|||||||
chrono = "0.4.19"
|
chrono = "0.4.19"
|
||||||
sha2 = "0.9.8"
|
sha2 = "0.9.8"
|
||||||
data-encoding = "2.3.2"
|
data-encoding = "2.3.2"
|
||||||
|
thiserror = "1.0.29"
|
||||||
|
itertools = "0.10.1"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
sqlx_type = ["sqlx"]
|
sqlx_type = ["sqlx"]
|
||||||
|
@ -20,6 +20,7 @@ pub mod errors;
|
|||||||
pub mod oidc;
|
pub mod oidc;
|
||||||
pub mod pkce;
|
pub mod pkce;
|
||||||
pub mod requests;
|
pub mod requests;
|
||||||
|
pub mod scope;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test_utils;
|
mod test_utils;
|
||||||
|
@ -24,6 +24,8 @@ use serde_with::{
|
|||||||
};
|
};
|
||||||
use url::Url;
|
use url::Url;
|
||||||
|
|
||||||
|
use crate::scope::Scope;
|
||||||
|
|
||||||
// ref: https://www.iana.org/assignments/oauth-parameters/oauth-parameters.xhtml
|
// ref: https://www.iana.org/assignments/oauth-parameters/oauth-parameters.xhtml
|
||||||
|
|
||||||
#[derive(
|
#[derive(
|
||||||
@ -212,22 +214,18 @@ pub struct AuthorizationCodeGrant {
|
|||||||
pub code_verifier: Option<String>,
|
pub code_verifier: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[serde_as]
|
|
||||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
||||||
pub struct RefreshTokenGrant {
|
pub struct RefreshTokenGrant {
|
||||||
pub refresh_token: String,
|
pub refresh_token: String,
|
||||||
|
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[serde_as(as = "Option<StringWithSeparator::<SpaceSeparator, String>>")]
|
scope: Option<Scope>,
|
||||||
scope: Option<HashSet<String>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[serde_as]
|
|
||||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
||||||
pub struct ClientCredentialsGrant {
|
pub struct ClientCredentialsGrant {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[serde_as(as = "Option<StringWithSeparator::<SpaceSeparator, String>>")]
|
scope: Option<Scope>,
|
||||||
scope: Option<HashSet<String>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(
|
#[derive(
|
||||||
@ -275,8 +273,7 @@ pub struct AccessTokenResponse {
|
|||||||
#[serde_as(as = "Option<DurationSeconds<i64>>")]
|
#[serde_as(as = "Option<DurationSeconds<i64>>")]
|
||||||
expires_in: Option<Duration>,
|
expires_in: Option<Duration>,
|
||||||
|
|
||||||
#[serde_as(as = "Option<StringWithSeparator::<SpaceSeparator, String>>")]
|
scope: Option<Scope>,
|
||||||
scope: Option<HashSet<String>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AccessTokenResponse {
|
impl AccessTokenResponse {
|
||||||
@ -305,7 +302,7 @@ impl AccessTokenResponse {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn with_scopes(mut self, scope: HashSet<String>) -> Self {
|
pub fn with_scope(mut self, scope: Scope) -> Self {
|
||||||
self.scope = Some(scope);
|
self.scope = Some(scope);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
@ -339,8 +336,7 @@ pub struct IntrospectionRequest {
|
|||||||
pub struct IntrospectionResponse {
|
pub struct IntrospectionResponse {
|
||||||
pub active: bool,
|
pub active: bool,
|
||||||
|
|
||||||
#[serde_as(as = "Option<StringWithSeparator::<SpaceSeparator, String>>")]
|
pub scope: Option<Scope>,
|
||||||
pub scope: Option<HashSet<String>>,
|
|
||||||
|
|
||||||
pub client_id: Option<String>,
|
pub client_id: Option<String>,
|
||||||
|
|
||||||
@ -368,12 +364,10 @@ pub struct IntrospectionResponse {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use std::collections::HashSet;
|
|
||||||
|
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::test_utils::assert_serde_json;
|
use crate::{scope::OPENID, test_utils::assert_serde_json};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn serde_refresh_token_grant() {
|
fn serde_refresh_token_grant() {
|
||||||
@ -383,14 +377,10 @@ mod tests {
|
|||||||
"scope": "openid",
|
"scope": "openid",
|
||||||
});
|
});
|
||||||
|
|
||||||
let scope = {
|
|
||||||
let mut s = HashSet::new();
|
|
||||||
// TODO: insert multiple scopes and test it. It's a bit tricky to test since
|
// TODO: insert multiple scopes and test it. It's a bit tricky to test since
|
||||||
// HashSet have no guarantees regarding the ordering of items, so right
|
// HashSet have no guarantees regarding the ordering of items, so right
|
||||||
// now the output is unstable.
|
// now the output is unstable.
|
||||||
s.insert("openid".to_string());
|
let scope: Option<Scope> = Some(vec![OPENID].into_iter().collect());
|
||||||
Some(s)
|
|
||||||
};
|
|
||||||
|
|
||||||
let req = AccessTokenRequest::RefreshToken(RefreshTokenGrant {
|
let req = AccessTokenRequest::RefreshToken(RefreshTokenGrant {
|
||||||
refresh_token: "abcd".into(),
|
refresh_token: "abcd".into(),
|
||||||
|
196
crates/oauth2-types/src/scope.rs
Normal file
196
crates/oauth2-types/src/scope.rs
Normal file
@ -0,0 +1,196 @@
|
|||||||
|
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#![allow(clippy::module_name_repetitions)]
|
||||||
|
|
||||||
|
use std::{borrow::Cow, collections::HashSet, iter::FromIterator, ops::Deref, str::FromStr};
|
||||||
|
|
||||||
|
use itertools::Itertools;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
#[derive(Debug, Error, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||||
|
#[error("Invalid scope format")]
|
||||||
|
pub struct InvalidScope;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||||
|
pub struct ScopeToken(Cow<'static, str>);
|
||||||
|
|
||||||
|
impl ScopeToken {
|
||||||
|
const fn well_known(token: &'static str) -> Self {
|
||||||
|
Self(Cow::Borrowed(token))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub const OPENID: ScopeToken = ScopeToken::well_known("openid");
|
||||||
|
pub const PROFILE: ScopeToken = ScopeToken::well_known("profile");
|
||||||
|
pub const EMAIL: ScopeToken = ScopeToken::well_known("email");
|
||||||
|
pub const ADDRESS: ScopeToken = ScopeToken::well_known("address");
|
||||||
|
pub const PHONE: ScopeToken = ScopeToken::well_known("phone");
|
||||||
|
pub const OFFLINE_ACCESS: ScopeToken = ScopeToken::well_known("offline_access");
|
||||||
|
|
||||||
|
// As per RFC6749 appendix A:
|
||||||
|
// https://datatracker.ietf.org/doc/html/rfc6749#appendix-A
|
||||||
|
//
|
||||||
|
// NQCHAR = %x21 / %x23-5B / %x5D-7E
|
||||||
|
fn nqchar(c: char) -> bool {
|
||||||
|
'\x21' == c || ('\x23'..'\x5B').contains(&c) || ('\x5D'..'\x7E').contains(&c)
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromStr for ScopeToken {
|
||||||
|
type Err = InvalidScope;
|
||||||
|
|
||||||
|
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||||
|
// As per RFC6749 appendix A.4:
|
||||||
|
// https://datatracker.ietf.org/doc/html/rfc6749#appendix-A.4
|
||||||
|
//
|
||||||
|
// scope-token = 1*NQCHAR
|
||||||
|
if !s.is_empty() && s.chars().all(nqchar) {
|
||||||
|
Ok(ScopeToken(Cow::Owned(s.into())))
|
||||||
|
} else {
|
||||||
|
Err(InvalidScope)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Deref for ScopeToken {
|
||||||
|
type Target = str;
|
||||||
|
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
&self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToString for ScopeToken {
|
||||||
|
fn to_string(&self) -> String {
|
||||||
|
self.0.to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct Scope(HashSet<ScopeToken>);
|
||||||
|
|
||||||
|
impl FromStr for Scope {
|
||||||
|
type Err = InvalidScope;
|
||||||
|
|
||||||
|
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||||
|
// As per RFC6749 appendix A.4:
|
||||||
|
// https://datatracker.ietf.org/doc/html/rfc6749#appendix-A.4
|
||||||
|
//
|
||||||
|
// scope = scope-token *( SP scope-token )
|
||||||
|
let scopes: Result<HashSet<ScopeToken>, InvalidScope> =
|
||||||
|
s.split(' ').map(ScopeToken::from_str).collect();
|
||||||
|
|
||||||
|
Ok(Self(scopes?))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Scope {
|
||||||
|
#[must_use]
|
||||||
|
pub fn is_empty(&self) -> bool {
|
||||||
|
// This should never be the case?
|
||||||
|
self.0.is_empty()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn len(&self) -> usize {
|
||||||
|
self.0.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn contains(&self, token: &str) -> bool {
|
||||||
|
ScopeToken::from_str(token)
|
||||||
|
.map(|token| self.0.contains(&token))
|
||||||
|
.unwrap_or(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToString for Scope {
|
||||||
|
fn to_string(&self) -> String {
|
||||||
|
let it = self.0.iter().map(ScopeToken::to_string);
|
||||||
|
Itertools::intersperse(it, ' '.to_string()).collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Serialize for Scope {
|
||||||
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||||
|
where
|
||||||
|
S: serde::Serializer,
|
||||||
|
{
|
||||||
|
self.to_string().serialize(serializer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'de> Deserialize<'de> for Scope {
|
||||||
|
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||||
|
where
|
||||||
|
D: serde::Deserializer<'de>,
|
||||||
|
{
|
||||||
|
// FIXME: seems like there is an unnecessary clone here?
|
||||||
|
let scope: String = Deserialize::deserialize(deserializer)?;
|
||||||
|
Scope::from_str(&scope).map_err(serde::de::Error::custom)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromIterator<ScopeToken> for Scope {
|
||||||
|
fn from_iter<T: IntoIterator<Item = ScopeToken>>(iter: T) -> Self {
|
||||||
|
Self(HashSet::from_iter(iter))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_scope_token() {
|
||||||
|
assert_eq!(ScopeToken::from_str("openid"), Ok(OPENID));
|
||||||
|
|
||||||
|
assert_eq!(ScopeToken::from_str("invalid\\scope"), Err(InvalidScope));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_scope() {
|
||||||
|
let scope = Scope::from_str("openid profile address").unwrap();
|
||||||
|
assert_eq!(scope.len(), 3);
|
||||||
|
assert!(scope.contains("openid"));
|
||||||
|
assert!(scope.contains("profile"));
|
||||||
|
assert!(scope.contains("address"));
|
||||||
|
assert!(!scope.contains("unknown"));
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
Scope::from_str("").is_err(),
|
||||||
|
"there should always be at least one token in the scope"
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(Scope::from_str("invalid\\scope").is_err());
|
||||||
|
assert!(Scope::from_str("no double space").is_err());
|
||||||
|
assert!(Scope::from_str(" no leading space").is_err());
|
||||||
|
assert!(Scope::from_str("no trailing space ").is_err());
|
||||||
|
|
||||||
|
let scope = Scope::from_str("openid").unwrap();
|
||||||
|
assert_eq!(scope.len(), 1);
|
||||||
|
assert!(scope.contains("openid"));
|
||||||
|
assert!(!scope.contains("profile"));
|
||||||
|
assert!(!scope.contains("address"));
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
Scope::from_str("order does not matter"),
|
||||||
|
Scope::from_str("matter not order does"),
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(Scope::from_str("http://example.com").is_ok());
|
||||||
|
assert!(Scope::from_str("urn:matrix:*").is_ok());
|
||||||
|
}
|
||||||
|
}
|
@ -24,7 +24,7 @@ pub(crate) fn assert_serde_json<T: Serialize + DeserializeOwned + PartialEq + De
|
|||||||
let got_value = serde_json::to_value(&got).expect("could not serialize object as JSON value");
|
let got_value = serde_json::to_value(&got).expect("could not serialize object as JSON value");
|
||||||
assert_eq!(got_value, expected_value);
|
assert_eq!(got_value, expected_value);
|
||||||
|
|
||||||
let expected: T =
|
let expected: T = serde_json::from_value(expected_value)
|
||||||
serde_json::from_value(expected_value).expect("could not serialize object as JSON value");
|
.expect("could not deserialize object from JSON value");
|
||||||
assert_eq!(got, &expected);
|
assert_eq!(got, &expected);
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user