You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
Simple dynamic client registration
This commit is contained in:
2
Cargo.lock
generated
2
Cargo.lock
generated
@ -2197,6 +2197,7 @@ dependencies = [
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tower",
|
||||
"tracing",
|
||||
"url",
|
||||
]
|
||||
|
||||
@ -2466,6 +2467,7 @@ dependencies = [
|
||||
"itertools",
|
||||
"language-tags",
|
||||
"mas-iana",
|
||||
"mas-jose",
|
||||
"parse-display",
|
||||
"serde",
|
||||
"serde_json",
|
||||
|
@ -24,8 +24,9 @@ use axum::{
|
||||
response::IntoResponse,
|
||||
};
|
||||
use headers::{authorization::Basic, Authorization};
|
||||
use http::StatusCode;
|
||||
use mas_config::Encrypter;
|
||||
use mas_data_model::{Client, StorageBackend};
|
||||
use mas_data_model::{Client, JwksOrJwksUri, StorageBackend};
|
||||
use mas_iana::oauth::OAuthClientAuthenticationMethod;
|
||||
use mas_jose::{
|
||||
DecodedJsonWebToken, DynamicJwksStore, Either, JsonWebTokenParts, JwtHeader, SharedSecret,
|
||||
@ -38,6 +39,7 @@ use mas_storage::{
|
||||
use serde::{de::DeserializeOwned, Deserialize};
|
||||
use serde_json::Value;
|
||||
use sqlx::PgExecutor;
|
||||
use thiserror::Error;
|
||||
|
||||
static JWT_BEARER_CLIENT_ASSERTION: &str = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer";
|
||||
|
||||
@ -88,6 +90,7 @@ impl Credentials {
|
||||
lookup_client_by_client_id(executor, client_id).await
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, err)]
|
||||
pub async fn verify<S: StorageBackend>(
|
||||
&self,
|
||||
encrypter: &Encrypter,
|
||||
@ -123,7 +126,7 @@ impl Credentials {
|
||||
|
||||
(
|
||||
Credentials::ClientAssertionJwtBearer { jwt, header, .. },
|
||||
OAuthClientAuthenticationMethod::ClientSecretJwt,
|
||||
OAuthClientAuthenticationMethod::PrivateKeyJwt,
|
||||
) => {
|
||||
// Get the client JWKS
|
||||
let jwks = client
|
||||
@ -139,7 +142,7 @@ impl Credentials {
|
||||
|
||||
(
|
||||
Credentials::ClientAssertionJwtBearer { jwt, header, .. },
|
||||
OAuthClientAuthenticationMethod::PrivateKeyJwt,
|
||||
OAuthClientAuthenticationMethod::ClientSecretJwt,
|
||||
) => {
|
||||
// Decrypt the client_secret
|
||||
let encrypted_client_secret = client
|
||||
@ -165,17 +168,28 @@ impl Credentials {
|
||||
}
|
||||
}
|
||||
|
||||
fn jwks_key_store(
|
||||
_jwks: &mas_data_model::JwksOrJwksUri,
|
||||
) -> Either<StaticJwksStore, DynamicJwksStore> {
|
||||
todo!()
|
||||
fn jwks_key_store(jwks: &JwksOrJwksUri) -> Either<StaticJwksStore, DynamicJwksStore> {
|
||||
match jwks {
|
||||
JwksOrJwksUri::Jwks(key_set) => Either::Left(StaticJwksStore::new(key_set.clone())),
|
||||
JwksOrJwksUri::JwksUri(_uri) => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum CredentialsVerificationError {
|
||||
#[error("failed to decrypt client credentials")]
|
||||
DecryptionError,
|
||||
|
||||
#[error("invalid client configuration")]
|
||||
InvalidClientConfig,
|
||||
|
||||
#[error("client secret did not match")]
|
||||
ClientSecretMismatch,
|
||||
|
||||
#[error("authentication method mismatch")]
|
||||
AuthenticationMethodMismatch,
|
||||
|
||||
#[error("invalid assertion signature")]
|
||||
InvalidAssertionSignature,
|
||||
}
|
||||
|
||||
@ -199,7 +213,8 @@ pub enum ClientAuthorizationError {
|
||||
|
||||
impl IntoResponse for ClientAuthorizationError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
todo!()
|
||||
// TODO
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -61,7 +61,13 @@ impl UrlBuilder {
|
||||
self.base.join("oauth2/introspect").expect("build URL")
|
||||
}
|
||||
|
||||
/// OAuth 2.0 introspection endpoint
|
||||
/// OAuth 2.0 client registration endpoint
|
||||
#[must_use]
|
||||
pub fn oauth_registration_endpoint(&self) -> Url {
|
||||
self.base.join("oauth2/registration").expect("build URL")
|
||||
}
|
||||
|
||||
/// OpenID Connect userinfo endpoint
|
||||
#[must_use]
|
||||
pub fn oidc_userinfo_endpoint(&self) -> Url {
|
||||
self.base.join("oauth2/userinfo").expect("build URL")
|
||||
|
@ -75,6 +75,10 @@ where
|
||||
post(self::oauth2::introspection::post),
|
||||
)
|
||||
.route("/oauth2/token", post(self::oauth2::token::post))
|
||||
.route(
|
||||
"/oauth2/registration",
|
||||
post(self::oauth2::registration::post),
|
||||
)
|
||||
.layer(
|
||||
CorsLayer::new()
|
||||
.allow_origin(Any)
|
||||
|
@ -68,6 +68,7 @@ pub(crate) async fn get(
|
||||
let jwks_uri = Some(url_builder.jwks_uri());
|
||||
let introspection_endpoint = Some(url_builder.oauth_introspection_endpoint());
|
||||
let userinfo_endpoint = Some(url_builder.oidc_userinfo_endpoint());
|
||||
let registration_endpoint = Some(url_builder.oauth_registration_endpoint());
|
||||
|
||||
let scopes_supported = Some(vec![scope::OPENID.to_string(), scope::EMAIL.to_string()]);
|
||||
|
||||
@ -133,6 +134,7 @@ pub(crate) async fn get(
|
||||
authorization_endpoint,
|
||||
token_endpoint,
|
||||
jwks_uri,
|
||||
registration_endpoint,
|
||||
scopes_supported,
|
||||
response_types_supported,
|
||||
response_modes_supported,
|
||||
|
@ -16,6 +16,7 @@ pub mod authorization;
|
||||
pub mod discovery;
|
||||
pub mod introspection;
|
||||
pub mod keys;
|
||||
pub mod registration;
|
||||
pub mod token;
|
||||
pub mod userinfo;
|
||||
pub mod webfinger;
|
||||
|
94
crates/handlers/src/oauth2/registration.rs
Normal file
94
crates/handlers/src/oauth2/registration.rs
Normal file
@ -0,0 +1,94 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use axum::{response::IntoResponse, Extension, Json};
|
||||
use hyper::StatusCode;
|
||||
use mas_storage::oauth2::client::insert_client;
|
||||
use oauth2_types::{
|
||||
errors::SERVER_ERROR,
|
||||
registration::{ClientMetadata, ClientRegistrationResponse},
|
||||
};
|
||||
use rand::{distributions::Alphanumeric, thread_rng, Rng};
|
||||
use sqlx::PgPool;
|
||||
use thiserror::Error;
|
||||
use tracing::info;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub(crate) enum RouteError {
|
||||
#[error(transparent)]
|
||||
Internal(Box<dyn std::error::Error + Send + Sync>),
|
||||
}
|
||||
|
||||
impl From<sqlx::Error> for RouteError {
|
||||
fn from(e: sqlx::Error) -> Self {
|
||||
Self::Internal(Box::new(e))
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, Json(SERVER_ERROR)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, err)]
|
||||
pub(crate) async fn post(
|
||||
Extension(pool): Extension<PgPool>,
|
||||
Json(body): Json<ClientMetadata>,
|
||||
) -> Result<impl IntoResponse, RouteError> {
|
||||
info!(?body, "Client registration");
|
||||
|
||||
// Grab a txn
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
// Let's generate a random client ID
|
||||
let client_id: String = thread_rng()
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(10)
|
||||
.map(char::from)
|
||||
.collect();
|
||||
|
||||
insert_client(
|
||||
&mut txn,
|
||||
&client_id,
|
||||
&body.redirect_uris,
|
||||
None,
|
||||
&body.response_types,
|
||||
&body.grant_types,
|
||||
&body.contacts,
|
||||
body.client_name.as_deref(),
|
||||
body.logo_uri.as_ref(),
|
||||
body.client_uri.as_ref(),
|
||||
body.policy_uri.as_ref(),
|
||||
body.tos_uri.as_ref(),
|
||||
body.jwks_uri.as_ref(),
|
||||
body.jwks.as_ref(),
|
||||
body.id_token_signed_response_alg,
|
||||
body.token_endpoint_auth_method,
|
||||
body.token_endpoint_auth_signing_alg,
|
||||
body.initiate_login_uri.as_ref(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
txn.commit().await?;
|
||||
|
||||
let response = ClientRegistrationResponse {
|
||||
client_id,
|
||||
client_secret: None,
|
||||
client_id_issued_at: None,
|
||||
client_secret_expires_at: None,
|
||||
};
|
||||
|
||||
Ok((StatusCode::CREATED, Json(response)))
|
||||
}
|
@ -56,6 +56,7 @@ use serde::Serialize;
|
||||
use serde_with::{serde_as, skip_serializing_none};
|
||||
use sha2::{Digest, Sha256};
|
||||
use sqlx::{PgPool, Postgres, Transaction};
|
||||
use thiserror::Error;
|
||||
use tracing::debug;
|
||||
use url::Url;
|
||||
|
||||
@ -76,14 +77,30 @@ struct CustomClaims {
|
||||
c_hash: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub(crate) enum RouteError {
|
||||
#[error(transparent)]
|
||||
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
|
||||
Anyhow(anyhow::Error),
|
||||
|
||||
#[error(transparent)]
|
||||
Anyhow(#[from] anyhow::Error),
|
||||
|
||||
#[error("bad request")]
|
||||
BadRequest,
|
||||
|
||||
#[error("client not found")]
|
||||
ClientNotFound,
|
||||
|
||||
#[error("client not allowed")]
|
||||
ClientNotAllowed,
|
||||
ClientCredentialsVerification(CredentialsVerificationError),
|
||||
|
||||
#[error("could not verify client credentials")]
|
||||
ClientCredentialsVerification(#[from] CredentialsVerificationError),
|
||||
|
||||
#[error("invalid grant")]
|
||||
InvalidGrant,
|
||||
|
||||
#[error("unauthorized client")]
|
||||
UnauthorizedClient,
|
||||
}
|
||||
|
||||
@ -138,18 +155,7 @@ impl From<ClaimError> for RouteError {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<anyhow::Error> for RouteError {
|
||||
fn from(e: anyhow::Error) -> Self {
|
||||
Self::Anyhow(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<CredentialsVerificationError> for RouteError {
|
||||
fn from(e: CredentialsVerificationError) -> Self {
|
||||
Self::ClientCredentialsVerification(e)
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, err)]
|
||||
pub(crate) async fn post(
|
||||
client_authorization: ClientAuthorization<AccessTokenRequest>,
|
||||
Extension(key_store): Extension<Arc<StaticKeystore>>,
|
||||
|
@ -32,6 +32,7 @@ signature = "1.4.0"
|
||||
thiserror = "1.0.30"
|
||||
tokio = { version = "1.17.0", features = ["macros", "rt", "sync"] }
|
||||
tower = { version = "0.4.12", features = ["util"] }
|
||||
tracing = "0.1.34"
|
||||
url = { version = "2.2.2", features = ["serde"] }
|
||||
|
||||
mas-iana = { path = "../iana" }
|
||||
|
@ -71,7 +71,7 @@ pub struct JsonWebKey {
|
||||
|
||||
impl JsonWebKey {
|
||||
#[must_use]
|
||||
pub fn new(parameters: JsonWebKeyParameters) -> Self {
|
||||
pub const fn new(parameters: JsonWebKeyParameters) -> Self {
|
||||
Self {
|
||||
parameters,
|
||||
r#use: None,
|
||||
@ -86,7 +86,7 @@ impl JsonWebKey {
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_use(mut self, value: JsonWebKeyUse) -> Self {
|
||||
pub const fn with_use(mut self, value: JsonWebKeyUse) -> Self {
|
||||
self.r#use = Some(value);
|
||||
self
|
||||
}
|
||||
@ -98,7 +98,7 @@ impl JsonWebKey {
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_alg(mut self, alg: JsonWebSignatureAlg) -> Self {
|
||||
pub const fn with_alg(mut self, alg: JsonWebSignatureAlg) -> Self {
|
||||
self.alg = Some(alg);
|
||||
self
|
||||
}
|
||||
@ -110,7 +110,7 @@ impl JsonWebKey {
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn kty(&self) -> JsonWebKeyType {
|
||||
pub const fn kty(&self) -> JsonWebKeyType {
|
||||
match self.parameters {
|
||||
JsonWebKeyParameters::Ec { .. } => JsonWebKeyType::Ec,
|
||||
JsonWebKeyParameters::Rsa { .. } => JsonWebKeyType::Rsa,
|
||||
@ -124,7 +124,12 @@ impl JsonWebKey {
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn params(&self) -> &JsonWebKeyParameters {
|
||||
pub const fn alg(&self) -> Option<JsonWebSignatureAlg> {
|
||||
self.alg
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub const fn params(&self) -> &JsonWebKeyParameters {
|
||||
&self.parameters
|
||||
}
|
||||
}
|
||||
|
@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{collections::HashMap, future::Ready};
|
||||
use std::future::Ready;
|
||||
|
||||
use digest::Digest;
|
||||
use mas_iana::jose::{JsonWebKeyType, JsonWebSignatureAlg};
|
||||
@ -21,15 +21,15 @@ use sha2::{Sha256, Sha384, Sha512};
|
||||
use signature::{Signature, Verifier};
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::{JsonWebKeySet, JwtHeader, VerifyingKeystore};
|
||||
use crate::{JsonWebKey, JsonWebKeySet, JwtHeader, VerifyingKeystore};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum Error {
|
||||
#[error("key not found")]
|
||||
KeyNotFound,
|
||||
|
||||
#[error("invalid index")]
|
||||
InvalidIndex,
|
||||
#[error("multiple key matched")]
|
||||
MultipleKeyMatched,
|
||||
|
||||
#[error(r#"missing "kid" field in header"#)]
|
||||
MissingKid,
|
||||
@ -43,43 +43,77 @@ pub enum Error {
|
||||
#[error(transparent)]
|
||||
Signature(#[from] signature::Error),
|
||||
|
||||
#[error("invalid {kty} key {kid}")]
|
||||
#[error("invalid {kty} key")]
|
||||
InvalidKey {
|
||||
kty: JsonWebKeyType,
|
||||
kid: String,
|
||||
source: anyhow::Error,
|
||||
},
|
||||
}
|
||||
|
||||
struct KeyConstraint<'a> {
|
||||
kty: Option<JsonWebKeyType>,
|
||||
alg: Option<JsonWebSignatureAlg>,
|
||||
kid: Option<&'a str>,
|
||||
}
|
||||
|
||||
impl<'a> KeyConstraint<'a> {
|
||||
fn matches(&self, key: &'a JsonWebKey) -> bool {
|
||||
// If a specific KID was asked, match the key only if it has a matching kid
|
||||
// field
|
||||
if let Some(kid) = self.kid {
|
||||
if key.kid() != Some(kid) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(kty) = self.kty {
|
||||
if key.kty() != kty {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(alg) = self.alg {
|
||||
if key.alg() != None && key.alg() != Some(alg) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
fn find_keys(&self, key_set: &'a JsonWebKeySet) -> Vec<&'a JsonWebKey> {
|
||||
key_set.iter().filter(|k| self.matches(k)).collect()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct StaticJwksStore {
|
||||
key_set: JsonWebKeySet,
|
||||
index: HashMap<(JsonWebKeyType, String), usize>,
|
||||
}
|
||||
|
||||
impl StaticJwksStore {
|
||||
#[must_use]
|
||||
pub fn new(key_set: JsonWebKeySet) -> Self {
|
||||
let index = key_set
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(index, key)| {
|
||||
let kid = key.kid()?.to_string();
|
||||
let kty = key.kty();
|
||||
|
||||
Some(((kty, kid), index))
|
||||
})
|
||||
.collect();
|
||||
|
||||
Self { key_set, index }
|
||||
Self { key_set }
|
||||
}
|
||||
|
||||
fn find_rsa_key(&self, kid: String) -> Result<RsaPublicKey, Error> {
|
||||
let index = *self
|
||||
.index
|
||||
.get(&(JsonWebKeyType::Rsa, kid.clone()))
|
||||
.ok_or(Error::KeyNotFound)?;
|
||||
fn find_key<'a>(&'a self, constraint: &KeyConstraint<'a>) -> Result<&'a JsonWebKey, Error> {
|
||||
let keys = constraint.find_keys(&self.key_set);
|
||||
|
||||
let key = self.key_set.get(index).ok_or(Error::InvalidIndex)?;
|
||||
match &keys[..] {
|
||||
[one] => Ok(one),
|
||||
[] => Err(Error::KeyNotFound),
|
||||
_ => Err(Error::MultipleKeyMatched),
|
||||
}
|
||||
}
|
||||
|
||||
fn find_rsa_key(&self, kid: Option<&str>) -> Result<RsaPublicKey, Error> {
|
||||
let constraint = KeyConstraint {
|
||||
kty: Some(JsonWebKeyType::Rsa),
|
||||
kid,
|
||||
alg: None,
|
||||
};
|
||||
|
||||
let key = self.find_key(&constraint)?;
|
||||
|
||||
let key = key
|
||||
.params()
|
||||
@ -87,20 +121,23 @@ impl StaticJwksStore {
|
||||
.try_into()
|
||||
.map_err(|source| Error::InvalidKey {
|
||||
kty: JsonWebKeyType::Rsa,
|
||||
kid,
|
||||
source,
|
||||
})?;
|
||||
|
||||
Ok(key)
|
||||
}
|
||||
|
||||
fn find_ecdsa_key(&self, kid: String) -> Result<ecdsa::VerifyingKey<p256::NistP256>, Error> {
|
||||
let index = *self
|
||||
.index
|
||||
.get(&(JsonWebKeyType::Ec, kid.clone()))
|
||||
.ok_or(Error::KeyNotFound)?;
|
||||
fn find_ecdsa_key(
|
||||
&self,
|
||||
kid: Option<&str>,
|
||||
) -> Result<ecdsa::VerifyingKey<p256::NistP256>, Error> {
|
||||
let constraint = KeyConstraint {
|
||||
kty: Some(JsonWebKeyType::Ec),
|
||||
kid,
|
||||
alg: None,
|
||||
};
|
||||
|
||||
let key = self.key_set.get(index).ok_or(Error::InvalidIndex)?;
|
||||
let key = self.find_key(&constraint)?;
|
||||
|
||||
let key = key
|
||||
.params()
|
||||
@ -108,20 +145,20 @@ impl StaticJwksStore {
|
||||
.try_into()
|
||||
.map_err(|source| Error::InvalidKey {
|
||||
kty: JsonWebKeyType::Ec,
|
||||
kid,
|
||||
source,
|
||||
})?;
|
||||
|
||||
Ok(key)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn verify_sync(
|
||||
&self,
|
||||
header: &JwtHeader,
|
||||
payload: &[u8],
|
||||
signature: &[u8],
|
||||
) -> Result<(), Error> {
|
||||
let kid = header.kid().ok_or(Error::MissingKid)?.to_string();
|
||||
let kid = header.kid();
|
||||
match header.alg() {
|
||||
JsonWebSignatureAlg::Rs256 => {
|
||||
let key = self.find_rsa_key(kid)?;
|
||||
|
@ -21,3 +21,4 @@ thiserror = "1.0.30"
|
||||
itertools = "0.10.3"
|
||||
|
||||
mas-iana = { path = "../iana" }
|
||||
mas-jose = { path = "../jose" }
|
||||
|
@ -50,6 +50,7 @@ impl ResponseTypeExt for OAuthAuthorizationEndpointResponseType {
|
||||
pub mod errors;
|
||||
pub mod oidc;
|
||||
pub mod pkce;
|
||||
pub mod registration;
|
||||
pub mod requests;
|
||||
pub mod scope;
|
||||
pub mod webfinger;
|
||||
|
@ -19,27 +19,27 @@ use mas_iana::{
|
||||
PkceCodeChallengeMethod,
|
||||
},
|
||||
};
|
||||
use serde::Serialize;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::skip_serializing_none;
|
||||
use url::Url;
|
||||
|
||||
use crate::requests::{Display, GrantType, Prompt, ResponseMode};
|
||||
|
||||
#[derive(Serialize, Clone, Copy, PartialEq, Eq, Hash, Debug)]
|
||||
#[derive(Serialize, Deserialize, Clone, Copy, PartialEq, Eq, Hash, Debug)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ApplicationType {
|
||||
Web,
|
||||
Native,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Clone, Copy, PartialEq, Eq, Hash, Debug)]
|
||||
#[derive(Serialize, Deserialize, Clone, Copy, PartialEq, Eq, Hash, Debug)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum SubjectType {
|
||||
Public,
|
||||
Pairwise,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Clone, Copy, PartialEq, Eq, Hash, Debug)]
|
||||
#[derive(Serialize, Deserialize, Clone, Copy, PartialEq, Eq, Hash, Debug)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ClaimType {
|
||||
Normal,
|
||||
|
168
crates/oauth2-types/src/registration.rs
Normal file
168
crates/oauth2-types/src/registration.rs
Normal file
@ -0,0 +1,168 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use mas_iana::{
|
||||
jose::{JsonWebEncryptionAlg, JsonWebSignatureAlg},
|
||||
oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod},
|
||||
};
|
||||
use mas_jose::JsonWebKeySet;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{serde_as, skip_serializing_none, DurationSeconds, TimestampSeconds};
|
||||
use url::Url;
|
||||
|
||||
use crate::{
|
||||
oidc::{ApplicationType, SubjectType},
|
||||
requests::GrantType,
|
||||
};
|
||||
|
||||
fn default_response_types() -> Vec<OAuthAuthorizationEndpointResponseType> {
|
||||
vec![OAuthAuthorizationEndpointResponseType::Code]
|
||||
}
|
||||
|
||||
fn default_grant_types() -> Vec<GrantType> {
|
||||
vec![GrantType::AuthorizationCode]
|
||||
}
|
||||
|
||||
const fn default_application_type() -> ApplicationType {
|
||||
ApplicationType::Web
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
|
||||
pub struct ClientMetadata {
|
||||
pub redirect_uris: Vec<Url>,
|
||||
|
||||
#[serde(default = "default_response_types")]
|
||||
pub response_types: Vec<OAuthAuthorizationEndpointResponseType>,
|
||||
|
||||
#[serde(default = "default_grant_types")]
|
||||
pub grant_types: Vec<GrantType>,
|
||||
|
||||
#[serde(default = "default_application_type")]
|
||||
pub application_type: ApplicationType,
|
||||
|
||||
#[serde(default)]
|
||||
pub contacts: Vec<String>,
|
||||
|
||||
#[serde(default)]
|
||||
pub client_name: Option<String>,
|
||||
|
||||
#[serde(default)]
|
||||
pub logo_uri: Option<Url>,
|
||||
|
||||
#[serde(default)]
|
||||
pub client_uri: Option<Url>,
|
||||
|
||||
#[serde(default)]
|
||||
pub policy_uri: Option<Url>,
|
||||
|
||||
#[serde(default)]
|
||||
pub tos_uri: Option<Url>,
|
||||
|
||||
#[serde(default)]
|
||||
pub jwks_uri: Option<Url>,
|
||||
|
||||
#[serde(default)]
|
||||
pub jwks: Option<JsonWebKeySet>,
|
||||
|
||||
#[serde(default)]
|
||||
pub sector_identifier_uri: Option<Url>,
|
||||
|
||||
#[serde(default)]
|
||||
pub subject_type: Option<SubjectType>,
|
||||
|
||||
#[serde(default)]
|
||||
pub token_endpoint_auth_method: Option<OAuthClientAuthenticationMethod>,
|
||||
|
||||
#[serde(default)]
|
||||
pub token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
|
||||
|
||||
#[serde(default)]
|
||||
pub id_token_signed_response_alg: Option<JsonWebSignatureAlg>,
|
||||
|
||||
#[serde(default)]
|
||||
pub id_token_encrypted_response_alg: Option<JsonWebEncryptionAlg>,
|
||||
|
||||
#[serde(default)]
|
||||
pub id_token_encrypted_response_enc: Option<JsonWebEncryptionAlg>,
|
||||
|
||||
#[serde(default)]
|
||||
pub userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
|
||||
|
||||
#[serde(default)]
|
||||
pub userinfo_encrypted_response_alg: Option<JsonWebEncryptionAlg>,
|
||||
|
||||
#[serde(default)]
|
||||
pub userinfo_encrypted_response_enc: Option<JsonWebEncryptionAlg>,
|
||||
|
||||
#[serde(default)]
|
||||
pub request_object_signing_alg: Option<JsonWebSignatureAlg>,
|
||||
|
||||
#[serde(default)]
|
||||
pub request_object_encryption_alg: Option<JsonWebEncryptionAlg>,
|
||||
|
||||
#[serde(default)]
|
||||
pub request_object_encryption_enc: Option<JsonWebEncryptionAlg>,
|
||||
|
||||
#[serde(default)]
|
||||
#[serde_as(as = "Option<DurationSeconds<i64>>")]
|
||||
pub default_max_age: Option<Duration>,
|
||||
|
||||
#[serde(default)]
|
||||
pub require_auth_time: bool,
|
||||
|
||||
#[serde(default)]
|
||||
pub default_acr_values: Vec<String>,
|
||||
|
||||
#[serde(default)]
|
||||
pub initiate_login_uri: Option<Url>,
|
||||
|
||||
#[serde(default)]
|
||||
pub request_uris: Option<Vec<Url>>,
|
||||
|
||||
#[serde(default)]
|
||||
pub require_signed_request_object: bool,
|
||||
|
||||
#[serde(default)]
|
||||
pub require_pushed_authorization_requests: bool,
|
||||
|
||||
#[serde(default)]
|
||||
pub introspection_signed_response_alg: Option<JsonWebSignatureAlg>,
|
||||
|
||||
#[serde(default)]
|
||||
pub introspection_encrypted_response_alg: Option<JsonWebEncryptionAlg>,
|
||||
|
||||
#[serde(default)]
|
||||
pub introspection_encrypted_response_enc: Option<JsonWebEncryptionAlg>,
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
|
||||
pub struct ClientRegistrationResponse {
|
||||
pub client_id: String,
|
||||
|
||||
#[serde(default)]
|
||||
pub client_secret: Option<String>,
|
||||
|
||||
#[serde(default)]
|
||||
#[serde_as(as = "Option<TimestampSeconds<i64>>")]
|
||||
pub client_id_issued_at: Option<DateTime<Utc>>,
|
||||
|
||||
#[serde(default)]
|
||||
#[serde_as(as = "Option<TimestampSeconds<i64>>")]
|
||||
pub client_secret_expires_at: Option<DateTime<Utc>>,
|
||||
}
|
@ -15,7 +15,10 @@
|
||||
use std::string::ToString;
|
||||
|
||||
use mas_data_model::{Client, JwksOrJwksUri};
|
||||
use mas_iana::oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod};
|
||||
use mas_iana::{
|
||||
jose::JsonWebSignatureAlg,
|
||||
oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod},
|
||||
};
|
||||
use mas_jose::JsonWebKeySet;
|
||||
use oauth2_types::requests::GrantType;
|
||||
use sqlx::{PgConnection, PgExecutor};
|
||||
@ -300,6 +303,102 @@ pub async fn lookup_client_by_client_id(
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn insert_client(
|
||||
conn: &mut PgConnection,
|
||||
client_id: &str,
|
||||
redirect_uris: &[Url],
|
||||
encrypted_client_secret: Option<&str>,
|
||||
response_types: &[OAuthAuthorizationEndpointResponseType],
|
||||
grant_types: &[GrantType],
|
||||
contacts: &[String],
|
||||
client_name: Option<&str>,
|
||||
logo_uri: Option<&Url>,
|
||||
client_uri: Option<&Url>,
|
||||
policy_uri: Option<&Url>,
|
||||
tos_uri: Option<&Url>,
|
||||
jwks_uri: Option<&Url>,
|
||||
jwks: Option<&JsonWebKeySet>,
|
||||
id_token_signed_response_alg: Option<JsonWebSignatureAlg>,
|
||||
token_endpoint_auth_method: Option<OAuthClientAuthenticationMethod>,
|
||||
token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
|
||||
initiate_login_uri: Option<&Url>,
|
||||
) -> Result<(), sqlx::Error> {
|
||||
let response_types: Vec<String> = response_types.iter().map(ToString::to_string).collect();
|
||||
let grant_type_authorization_code = grant_types.contains(&GrantType::AuthorizationCode);
|
||||
let grant_type_refresh_token = grant_types.contains(&GrantType::RefreshToken);
|
||||
let logo_uri = logo_uri.map(Url::as_str);
|
||||
let client_uri = client_uri.map(Url::as_str);
|
||||
let policy_uri = policy_uri.map(Url::as_str);
|
||||
let tos_uri = tos_uri.map(Url::as_str);
|
||||
let jwks = jwks.map(serde_json::to_value).transpose().unwrap(); // TODO
|
||||
let jwks_uri = jwks_uri.map(Url::as_str);
|
||||
let id_token_signed_response_alg = id_token_signed_response_alg.map(|v| v.to_string());
|
||||
let token_endpoint_auth_method = token_endpoint_auth_method.map(|v| v.to_string());
|
||||
let token_endpoint_auth_signing_alg = token_endpoint_auth_signing_alg.map(|v| v.to_string());
|
||||
let initiate_login_uri = initiate_login_uri.map(Url::as_str);
|
||||
|
||||
let id = sqlx::query_scalar!(
|
||||
r#"
|
||||
INSERT INTO oauth2_clients
|
||||
(client_id,
|
||||
encrypted_client_secret,
|
||||
response_types,
|
||||
grant_type_authorization_code,
|
||||
grant_type_refresh_token,
|
||||
contacts,
|
||||
client_name,
|
||||
logo_uri,
|
||||
client_uri,
|
||||
policy_uri,
|
||||
tos_uri,
|
||||
jwks_uri,
|
||||
jwks,
|
||||
id_token_signed_response_alg,
|
||||
token_endpoint_auth_method,
|
||||
token_endpoint_auth_signing_alg,
|
||||
initiate_login_uri)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17)
|
||||
RETURNING id
|
||||
"#,
|
||||
client_id,
|
||||
encrypted_client_secret,
|
||||
&response_types,
|
||||
grant_type_authorization_code,
|
||||
grant_type_refresh_token,
|
||||
contacts,
|
||||
client_name,
|
||||
logo_uri,
|
||||
client_uri,
|
||||
policy_uri,
|
||||
tos_uri,
|
||||
jwks_uri,
|
||||
jwks,
|
||||
id_token_signed_response_alg,
|
||||
token_endpoint_auth_method,
|
||||
token_endpoint_auth_signing_alg,
|
||||
initiate_login_uri,
|
||||
)
|
||||
.fetch_one(&mut *conn)
|
||||
.await?;
|
||||
|
||||
let redirect_uris: Vec<String> = redirect_uris.iter().map(ToString::to_string).collect();
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO oauth2_client_redirect_uris (oauth2_client_id, redirect_uri)
|
||||
SELECT $1, uri FROM UNNEST($2::text[]) uri
|
||||
"#,
|
||||
id,
|
||||
&redirect_uris,
|
||||
)
|
||||
.execute(&mut *conn)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn insert_client_from_config(
|
||||
conn: &mut PgConnection,
|
||||
client_id: &str,
|
||||
|
Reference in New Issue
Block a user