From 5c14611b963250a1e1ce41d7c28405e6398db3da Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 19 Apr 2022 12:23:01 +0200 Subject: [PATCH] Simple dynamic client registration --- Cargo.lock | 2 + crates/axum-utils/src/client_authorization.rs | 31 +++- crates/axum-utils/src/url_builder.rs | 8 +- crates/handlers/src/lib.rs | 4 + crates/handlers/src/oauth2/discovery.rs | 2 + crates/handlers/src/oauth2/mod.rs | 1 + crates/handlers/src/oauth2/registration.rs | 94 ++++++++++ crates/handlers/src/oauth2/token.rs | 34 ++-- crates/jose/Cargo.toml | 1 + crates/jose/src/jwk.rs | 15 +- crates/jose/src/keystore/jwks/static_store.rs | 105 +++++++---- crates/oauth2-types/Cargo.toml | 1 + crates/oauth2-types/src/lib.rs | 1 + crates/oauth2-types/src/oidc.rs | 8 +- crates/oauth2-types/src/registration.rs | 168 ++++++++++++++++++ crates/storage/src/oauth2/client.rs | 101 ++++++++++- 16 files changed, 509 insertions(+), 67 deletions(-) create mode 100644 crates/handlers/src/oauth2/registration.rs create mode 100644 crates/oauth2-types/src/registration.rs diff --git a/Cargo.lock b/Cargo.lock index f657813f..27707bef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2197,6 +2197,7 @@ dependencies = [ "thiserror", "tokio", "tower", + "tracing", "url", ] @@ -2466,6 +2467,7 @@ dependencies = [ "itertools", "language-tags", "mas-iana", + "mas-jose", "parse-display", "serde", "serde_json", diff --git a/crates/axum-utils/src/client_authorization.rs b/crates/axum-utils/src/client_authorization.rs index e9e4b8f4..7ec400d3 100644 --- a/crates/axum-utils/src/client_authorization.rs +++ b/crates/axum-utils/src/client_authorization.rs @@ -24,8 +24,9 @@ use axum::{ response::IntoResponse, }; use headers::{authorization::Basic, Authorization}; +use http::StatusCode; use mas_config::Encrypter; -use mas_data_model::{Client, StorageBackend}; +use mas_data_model::{Client, JwksOrJwksUri, StorageBackend}; use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_jose::{ DecodedJsonWebToken, DynamicJwksStore, Either, JsonWebTokenParts, JwtHeader, SharedSecret, @@ -38,6 +39,7 @@ use mas_storage::{ use serde::{de::DeserializeOwned, Deserialize}; use serde_json::Value; use sqlx::PgExecutor; +use thiserror::Error; static JWT_BEARER_CLIENT_ASSERTION: &str = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"; @@ -88,6 +90,7 @@ impl Credentials { lookup_client_by_client_id(executor, client_id).await } + #[tracing::instrument(skip_all, err)] pub async fn verify( &self, encrypter: &Encrypter, @@ -123,7 +126,7 @@ impl Credentials { ( Credentials::ClientAssertionJwtBearer { jwt, header, .. }, - OAuthClientAuthenticationMethod::ClientSecretJwt, + OAuthClientAuthenticationMethod::PrivateKeyJwt, ) => { // Get the client JWKS let jwks = client @@ -139,7 +142,7 @@ impl Credentials { ( Credentials::ClientAssertionJwtBearer { jwt, header, .. }, - OAuthClientAuthenticationMethod::PrivateKeyJwt, + OAuthClientAuthenticationMethod::ClientSecretJwt, ) => { // Decrypt the client_secret let encrypted_client_secret = client @@ -165,17 +168,28 @@ impl Credentials { } } -fn jwks_key_store( - _jwks: &mas_data_model::JwksOrJwksUri, -) -> Either { - todo!() +fn jwks_key_store(jwks: &JwksOrJwksUri) -> Either { + match jwks { + JwksOrJwksUri::Jwks(key_set) => Either::Left(StaticJwksStore::new(key_set.clone())), + JwksOrJwksUri::JwksUri(_uri) => todo!(), + } } +#[derive(Debug, Error)] pub enum CredentialsVerificationError { + #[error("failed to decrypt client credentials")] DecryptionError, + + #[error("invalid client configuration")] InvalidClientConfig, + + #[error("client secret did not match")] ClientSecretMismatch, + + #[error("authentication method mismatch")] AuthenticationMethodMismatch, + + #[error("invalid assertion signature")] InvalidAssertionSignature, } @@ -199,7 +213,8 @@ pub enum ClientAuthorizationError { impl IntoResponse for ClientAuthorizationError { fn into_response(self) -> axum::response::Response { - todo!() + // TODO + StatusCode::INTERNAL_SERVER_ERROR.into_response() } } diff --git a/crates/axum-utils/src/url_builder.rs b/crates/axum-utils/src/url_builder.rs index 56bd061e..f31fa2e9 100644 --- a/crates/axum-utils/src/url_builder.rs +++ b/crates/axum-utils/src/url_builder.rs @@ -61,7 +61,13 @@ impl UrlBuilder { self.base.join("oauth2/introspect").expect("build URL") } - /// OAuth 2.0 introspection endpoint + /// OAuth 2.0 client registration endpoint + #[must_use] + pub fn oauth_registration_endpoint(&self) -> Url { + self.base.join("oauth2/registration").expect("build URL") + } + + /// OpenID Connect userinfo endpoint #[must_use] pub fn oidc_userinfo_endpoint(&self) -> Url { self.base.join("oauth2/userinfo").expect("build URL") diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index 4b39d7da..c5c269db 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -75,6 +75,10 @@ where post(self::oauth2::introspection::post), ) .route("/oauth2/token", post(self::oauth2::token::post)) + .route( + "/oauth2/registration", + post(self::oauth2::registration::post), + ) .layer( CorsLayer::new() .allow_origin(Any) diff --git a/crates/handlers/src/oauth2/discovery.rs b/crates/handlers/src/oauth2/discovery.rs index bdc5252b..ef836253 100644 --- a/crates/handlers/src/oauth2/discovery.rs +++ b/crates/handlers/src/oauth2/discovery.rs @@ -68,6 +68,7 @@ pub(crate) async fn get( let jwks_uri = Some(url_builder.jwks_uri()); let introspection_endpoint = Some(url_builder.oauth_introspection_endpoint()); let userinfo_endpoint = Some(url_builder.oidc_userinfo_endpoint()); + let registration_endpoint = Some(url_builder.oauth_registration_endpoint()); let scopes_supported = Some(vec![scope::OPENID.to_string(), scope::EMAIL.to_string()]); @@ -133,6 +134,7 @@ pub(crate) async fn get( authorization_endpoint, token_endpoint, jwks_uri, + registration_endpoint, scopes_supported, response_types_supported, response_modes_supported, diff --git a/crates/handlers/src/oauth2/mod.rs b/crates/handlers/src/oauth2/mod.rs index d3dac044..865ec8e9 100644 --- a/crates/handlers/src/oauth2/mod.rs +++ b/crates/handlers/src/oauth2/mod.rs @@ -16,6 +16,7 @@ pub mod authorization; pub mod discovery; pub mod introspection; pub mod keys; +pub mod registration; pub mod token; pub mod userinfo; pub mod webfinger; diff --git a/crates/handlers/src/oauth2/registration.rs b/crates/handlers/src/oauth2/registration.rs new file mode 100644 index 00000000..9dd8e64b --- /dev/null +++ b/crates/handlers/src/oauth2/registration.rs @@ -0,0 +1,94 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use axum::{response::IntoResponse, Extension, Json}; +use hyper::StatusCode; +use mas_storage::oauth2::client::insert_client; +use oauth2_types::{ + errors::SERVER_ERROR, + registration::{ClientMetadata, ClientRegistrationResponse}, +}; +use rand::{distributions::Alphanumeric, thread_rng, Rng}; +use sqlx::PgPool; +use thiserror::Error; +use tracing::info; + +#[derive(Debug, Error)] +pub(crate) enum RouteError { + #[error(transparent)] + Internal(Box), +} + +impl From for RouteError { + fn from(e: sqlx::Error) -> Self { + Self::Internal(Box::new(e)) + } +} + +impl IntoResponse for RouteError { + fn into_response(self) -> axum::response::Response { + (StatusCode::INTERNAL_SERVER_ERROR, Json(SERVER_ERROR)).into_response() + } +} + +#[tracing::instrument(skip_all, err)] +pub(crate) async fn post( + Extension(pool): Extension, + Json(body): Json, +) -> Result { + info!(?body, "Client registration"); + + // Grab a txn + let mut txn = pool.begin().await?; + + // Let's generate a random client ID + let client_id: String = thread_rng() + .sample_iter(&Alphanumeric) + .take(10) + .map(char::from) + .collect(); + + insert_client( + &mut txn, + &client_id, + &body.redirect_uris, + None, + &body.response_types, + &body.grant_types, + &body.contacts, + body.client_name.as_deref(), + body.logo_uri.as_ref(), + body.client_uri.as_ref(), + body.policy_uri.as_ref(), + body.tos_uri.as_ref(), + body.jwks_uri.as_ref(), + body.jwks.as_ref(), + body.id_token_signed_response_alg, + body.token_endpoint_auth_method, + body.token_endpoint_auth_signing_alg, + body.initiate_login_uri.as_ref(), + ) + .await?; + + txn.commit().await?; + + let response = ClientRegistrationResponse { + client_id, + client_secret: None, + client_id_issued_at: None, + client_secret_expires_at: None, + }; + + Ok((StatusCode::CREATED, Json(response))) +} diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index 2fde3d52..3ce6e76c 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -56,6 +56,7 @@ use serde::Serialize; use serde_with::{serde_as, skip_serializing_none}; use sha2::{Digest, Sha256}; use sqlx::{PgPool, Postgres, Transaction}; +use thiserror::Error; use tracing::debug; use url::Url; @@ -76,14 +77,30 @@ struct CustomClaims { c_hash: String, } +#[derive(Debug, Error)] pub(crate) enum RouteError { + #[error(transparent)] Internal(Box), - Anyhow(anyhow::Error), + + #[error(transparent)] + Anyhow(#[from] anyhow::Error), + + #[error("bad request")] BadRequest, + + #[error("client not found")] ClientNotFound, + + #[error("client not allowed")] ClientNotAllowed, - ClientCredentialsVerification(CredentialsVerificationError), + + #[error("could not verify client credentials")] + ClientCredentialsVerification(#[from] CredentialsVerificationError), + + #[error("invalid grant")] InvalidGrant, + + #[error("unauthorized client")] UnauthorizedClient, } @@ -138,18 +155,7 @@ impl From for RouteError { } } -impl From for RouteError { - fn from(e: anyhow::Error) -> Self { - Self::Anyhow(e) - } -} - -impl From for RouteError { - fn from(e: CredentialsVerificationError) -> Self { - Self::ClientCredentialsVerification(e) - } -} - +#[tracing::instrument(skip_all, err)] pub(crate) async fn post( client_authorization: ClientAuthorization, Extension(key_store): Extension>, diff --git a/crates/jose/Cargo.toml b/crates/jose/Cargo.toml index dacf2a36..990876c7 100644 --- a/crates/jose/Cargo.toml +++ b/crates/jose/Cargo.toml @@ -32,6 +32,7 @@ signature = "1.4.0" thiserror = "1.0.30" tokio = { version = "1.17.0", features = ["macros", "rt", "sync"] } tower = { version = "0.4.12", features = ["util"] } +tracing = "0.1.34" url = { version = "2.2.2", features = ["serde"] } mas-iana = { path = "../iana" } diff --git a/crates/jose/src/jwk.rs b/crates/jose/src/jwk.rs index 3ead54b5..65e46ebb 100644 --- a/crates/jose/src/jwk.rs +++ b/crates/jose/src/jwk.rs @@ -71,7 +71,7 @@ pub struct JsonWebKey { impl JsonWebKey { #[must_use] - pub fn new(parameters: JsonWebKeyParameters) -> Self { + pub const fn new(parameters: JsonWebKeyParameters) -> Self { Self { parameters, r#use: None, @@ -86,7 +86,7 @@ impl JsonWebKey { } #[must_use] - pub fn with_use(mut self, value: JsonWebKeyUse) -> Self { + pub const fn with_use(mut self, value: JsonWebKeyUse) -> Self { self.r#use = Some(value); self } @@ -98,7 +98,7 @@ impl JsonWebKey { } #[must_use] - pub fn with_alg(mut self, alg: JsonWebSignatureAlg) -> Self { + pub const fn with_alg(mut self, alg: JsonWebSignatureAlg) -> Self { self.alg = Some(alg); self } @@ -110,7 +110,7 @@ impl JsonWebKey { } #[must_use] - pub fn kty(&self) -> JsonWebKeyType { + pub const fn kty(&self) -> JsonWebKeyType { match self.parameters { JsonWebKeyParameters::Ec { .. } => JsonWebKeyType::Ec, JsonWebKeyParameters::Rsa { .. } => JsonWebKeyType::Rsa, @@ -124,7 +124,12 @@ impl JsonWebKey { } #[must_use] - pub fn params(&self) -> &JsonWebKeyParameters { + pub const fn alg(&self) -> Option { + self.alg + } + + #[must_use] + pub const fn params(&self) -> &JsonWebKeyParameters { &self.parameters } } diff --git a/crates/jose/src/keystore/jwks/static_store.rs b/crates/jose/src/keystore/jwks/static_store.rs index d6adfcf0..8f30df1b 100644 --- a/crates/jose/src/keystore/jwks/static_store.rs +++ b/crates/jose/src/keystore/jwks/static_store.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{collections::HashMap, future::Ready}; +use std::future::Ready; use digest::Digest; use mas_iana::jose::{JsonWebKeyType, JsonWebSignatureAlg}; @@ -21,15 +21,15 @@ use sha2::{Sha256, Sha384, Sha512}; use signature::{Signature, Verifier}; use thiserror::Error; -use crate::{JsonWebKeySet, JwtHeader, VerifyingKeystore}; +use crate::{JsonWebKey, JsonWebKeySet, JwtHeader, VerifyingKeystore}; #[derive(Debug, Error)] pub enum Error { #[error("key not found")] KeyNotFound, - #[error("invalid index")] - InvalidIndex, + #[error("multiple key matched")] + MultipleKeyMatched, #[error(r#"missing "kid" field in header"#)] MissingKid, @@ -43,43 +43,77 @@ pub enum Error { #[error(transparent)] Signature(#[from] signature::Error), - #[error("invalid {kty} key {kid}")] + #[error("invalid {kty} key")] InvalidKey { kty: JsonWebKeyType, - kid: String, source: anyhow::Error, }, } +struct KeyConstraint<'a> { + kty: Option, + alg: Option, + kid: Option<&'a str>, +} + +impl<'a> KeyConstraint<'a> { + fn matches(&self, key: &'a JsonWebKey) -> bool { + // If a specific KID was asked, match the key only if it has a matching kid + // field + if let Some(kid) = self.kid { + if key.kid() != Some(kid) { + return false; + } + } + + if let Some(kty) = self.kty { + if key.kty() != kty { + return false; + } + } + + if let Some(alg) = self.alg { + if key.alg() != None && key.alg() != Some(alg) { + return false; + } + } + + true + } + + fn find_keys(&self, key_set: &'a JsonWebKeySet) -> Vec<&'a JsonWebKey> { + key_set.iter().filter(|k| self.matches(k)).collect() + } +} + pub struct StaticJwksStore { key_set: JsonWebKeySet, - index: HashMap<(JsonWebKeyType, String), usize>, } impl StaticJwksStore { #[must_use] pub fn new(key_set: JsonWebKeySet) -> Self { - let index = key_set - .iter() - .enumerate() - .filter_map(|(index, key)| { - let kid = key.kid()?.to_string(); - let kty = key.kty(); - - Some(((kty, kid), index)) - }) - .collect(); - - Self { key_set, index } + Self { key_set } } - fn find_rsa_key(&self, kid: String) -> Result { - let index = *self - .index - .get(&(JsonWebKeyType::Rsa, kid.clone())) - .ok_or(Error::KeyNotFound)?; + fn find_key<'a>(&'a self, constraint: &KeyConstraint<'a>) -> Result<&'a JsonWebKey, Error> { + let keys = constraint.find_keys(&self.key_set); - let key = self.key_set.get(index).ok_or(Error::InvalidIndex)?; + match &keys[..] { + [one] => Ok(one), + [] => Err(Error::KeyNotFound), + _ => Err(Error::MultipleKeyMatched), + } + } + + fn find_rsa_key(&self, kid: Option<&str>) -> Result { + let constraint = KeyConstraint { + kty: Some(JsonWebKeyType::Rsa), + kid, + alg: None, + }; + + let key = self.find_key(&constraint)?; let key = key .params() @@ -87,20 +121,23 @@ impl StaticJwksStore { .try_into() .map_err(|source| Error::InvalidKey { kty: JsonWebKeyType::Rsa, - kid, source, })?; Ok(key) } - fn find_ecdsa_key(&self, kid: String) -> Result, Error> { - let index = *self - .index - .get(&(JsonWebKeyType::Ec, kid.clone())) - .ok_or(Error::KeyNotFound)?; + fn find_ecdsa_key( + &self, + kid: Option<&str>, + ) -> Result, Error> { + let constraint = KeyConstraint { + kty: Some(JsonWebKeyType::Ec), + kid, + alg: None, + }; - let key = self.key_set.get(index).ok_or(Error::InvalidIndex)?; + let key = self.find_key(&constraint)?; let key = key .params() @@ -108,20 +145,20 @@ impl StaticJwksStore { .try_into() .map_err(|source| Error::InvalidKey { kty: JsonWebKeyType::Ec, - kid, source, })?; Ok(key) } + #[tracing::instrument(skip(self))] fn verify_sync( &self, header: &JwtHeader, payload: &[u8], signature: &[u8], ) -> Result<(), Error> { - let kid = header.kid().ok_or(Error::MissingKid)?.to_string(); + let kid = header.kid(); match header.alg() { JsonWebSignatureAlg::Rs256 => { let key = self.find_rsa_key(kid)?; diff --git a/crates/oauth2-types/Cargo.toml b/crates/oauth2-types/Cargo.toml index 70295247..eccb5639 100644 --- a/crates/oauth2-types/Cargo.toml +++ b/crates/oauth2-types/Cargo.toml @@ -21,3 +21,4 @@ thiserror = "1.0.30" itertools = "0.10.3" mas-iana = { path = "../iana" } +mas-jose = { path = "../jose" } diff --git a/crates/oauth2-types/src/lib.rs b/crates/oauth2-types/src/lib.rs index db44958d..fc7fd108 100644 --- a/crates/oauth2-types/src/lib.rs +++ b/crates/oauth2-types/src/lib.rs @@ -50,6 +50,7 @@ impl ResponseTypeExt for OAuthAuthorizationEndpointResponseType { pub mod errors; pub mod oidc; pub mod pkce; +pub mod registration; pub mod requests; pub mod scope; pub mod webfinger; diff --git a/crates/oauth2-types/src/oidc.rs b/crates/oauth2-types/src/oidc.rs index f715d3a2..a06b7a52 100644 --- a/crates/oauth2-types/src/oidc.rs +++ b/crates/oauth2-types/src/oidc.rs @@ -19,27 +19,27 @@ use mas_iana::{ PkceCodeChallengeMethod, }, }; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use serde_with::skip_serializing_none; use url::Url; use crate::requests::{Display, GrantType, Prompt, ResponseMode}; -#[derive(Serialize, Clone, Copy, PartialEq, Eq, Hash, Debug)] +#[derive(Serialize, Deserialize, Clone, Copy, PartialEq, Eq, Hash, Debug)] #[serde(rename_all = "lowercase")] pub enum ApplicationType { Web, Native, } -#[derive(Serialize, Clone, Copy, PartialEq, Eq, Hash, Debug)] +#[derive(Serialize, Deserialize, Clone, Copy, PartialEq, Eq, Hash, Debug)] #[serde(rename_all = "lowercase")] pub enum SubjectType { Public, Pairwise, } -#[derive(Serialize, Clone, Copy, PartialEq, Eq, Hash, Debug)] +#[derive(Serialize, Deserialize, Clone, Copy, PartialEq, Eq, Hash, Debug)] #[serde(rename_all = "lowercase")] pub enum ClaimType { Normal, diff --git a/crates/oauth2-types/src/registration.rs b/crates/oauth2-types/src/registration.rs new file mode 100644 index 00000000..cbb14dd5 --- /dev/null +++ b/crates/oauth2-types/src/registration.rs @@ -0,0 +1,168 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use chrono::{DateTime, Duration, Utc}; +use mas_iana::{ + jose::{JsonWebEncryptionAlg, JsonWebSignatureAlg}, + oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod}, +}; +use mas_jose::JsonWebKeySet; +use serde::{Deserialize, Serialize}; +use serde_with::{serde_as, skip_serializing_none, DurationSeconds, TimestampSeconds}; +use url::Url; + +use crate::{ + oidc::{ApplicationType, SubjectType}, + requests::GrantType, +}; + +fn default_response_types() -> Vec { + vec![OAuthAuthorizationEndpointResponseType::Code] +} + +fn default_grant_types() -> Vec { + vec![GrantType::AuthorizationCode] +} + +const fn default_application_type() -> ApplicationType { + ApplicationType::Web +} + +#[serde_as] +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] +pub struct ClientMetadata { + pub redirect_uris: Vec, + + #[serde(default = "default_response_types")] + pub response_types: Vec, + + #[serde(default = "default_grant_types")] + pub grant_types: Vec, + + #[serde(default = "default_application_type")] + pub application_type: ApplicationType, + + #[serde(default)] + pub contacts: Vec, + + #[serde(default)] + pub client_name: Option, + + #[serde(default)] + pub logo_uri: Option, + + #[serde(default)] + pub client_uri: Option, + + #[serde(default)] + pub policy_uri: Option, + + #[serde(default)] + pub tos_uri: Option, + + #[serde(default)] + pub jwks_uri: Option, + + #[serde(default)] + pub jwks: Option, + + #[serde(default)] + pub sector_identifier_uri: Option, + + #[serde(default)] + pub subject_type: Option, + + #[serde(default)] + pub token_endpoint_auth_method: Option, + + #[serde(default)] + pub token_endpoint_auth_signing_alg: Option, + + #[serde(default)] + pub id_token_signed_response_alg: Option, + + #[serde(default)] + pub id_token_encrypted_response_alg: Option, + + #[serde(default)] + pub id_token_encrypted_response_enc: Option, + + #[serde(default)] + pub userinfo_signed_response_alg: Option, + + #[serde(default)] + pub userinfo_encrypted_response_alg: Option, + + #[serde(default)] + pub userinfo_encrypted_response_enc: Option, + + #[serde(default)] + pub request_object_signing_alg: Option, + + #[serde(default)] + pub request_object_encryption_alg: Option, + + #[serde(default)] + pub request_object_encryption_enc: Option, + + #[serde(default)] + #[serde_as(as = "Option>")] + pub default_max_age: Option, + + #[serde(default)] + pub require_auth_time: bool, + + #[serde(default)] + pub default_acr_values: Vec, + + #[serde(default)] + pub initiate_login_uri: Option, + + #[serde(default)] + pub request_uris: Option>, + + #[serde(default)] + pub require_signed_request_object: bool, + + #[serde(default)] + pub require_pushed_authorization_requests: bool, + + #[serde(default)] + pub introspection_signed_response_alg: Option, + + #[serde(default)] + pub introspection_encrypted_response_alg: Option, + + #[serde(default)] + pub introspection_encrypted_response_enc: Option, +} + +#[serde_as] +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] +pub struct ClientRegistrationResponse { + pub client_id: String, + + #[serde(default)] + pub client_secret: Option, + + #[serde(default)] + #[serde_as(as = "Option>")] + pub client_id_issued_at: Option>, + + #[serde(default)] + #[serde_as(as = "Option>")] + pub client_secret_expires_at: Option>, +} diff --git a/crates/storage/src/oauth2/client.rs b/crates/storage/src/oauth2/client.rs index 22f927f7..2468e4d7 100644 --- a/crates/storage/src/oauth2/client.rs +++ b/crates/storage/src/oauth2/client.rs @@ -15,7 +15,10 @@ use std::string::ToString; use mas_data_model::{Client, JwksOrJwksUri}; -use mas_iana::oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod}; +use mas_iana::{ + jose::JsonWebSignatureAlg, + oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod}, +}; use mas_jose::JsonWebKeySet; use oauth2_types::requests::GrantType; use sqlx::{PgConnection, PgExecutor}; @@ -300,6 +303,102 @@ pub async fn lookup_client_by_client_id( Ok(client) } +#[allow(clippy::too_many_arguments)] +pub async fn insert_client( + conn: &mut PgConnection, + client_id: &str, + redirect_uris: &[Url], + encrypted_client_secret: Option<&str>, + response_types: &[OAuthAuthorizationEndpointResponseType], + grant_types: &[GrantType], + contacts: &[String], + client_name: Option<&str>, + logo_uri: Option<&Url>, + client_uri: Option<&Url>, + policy_uri: Option<&Url>, + tos_uri: Option<&Url>, + jwks_uri: Option<&Url>, + jwks: Option<&JsonWebKeySet>, + id_token_signed_response_alg: Option, + token_endpoint_auth_method: Option, + token_endpoint_auth_signing_alg: Option, + initiate_login_uri: Option<&Url>, +) -> Result<(), sqlx::Error> { + let response_types: Vec = response_types.iter().map(ToString::to_string).collect(); + let grant_type_authorization_code = grant_types.contains(&GrantType::AuthorizationCode); + let grant_type_refresh_token = grant_types.contains(&GrantType::RefreshToken); + let logo_uri = logo_uri.map(Url::as_str); + let client_uri = client_uri.map(Url::as_str); + let policy_uri = policy_uri.map(Url::as_str); + let tos_uri = tos_uri.map(Url::as_str); + let jwks = jwks.map(serde_json::to_value).transpose().unwrap(); // TODO + let jwks_uri = jwks_uri.map(Url::as_str); + let id_token_signed_response_alg = id_token_signed_response_alg.map(|v| v.to_string()); + let token_endpoint_auth_method = token_endpoint_auth_method.map(|v| v.to_string()); + let token_endpoint_auth_signing_alg = token_endpoint_auth_signing_alg.map(|v| v.to_string()); + let initiate_login_uri = initiate_login_uri.map(Url::as_str); + + let id = sqlx::query_scalar!( + r#" + INSERT INTO oauth2_clients + (client_id, + encrypted_client_secret, + response_types, + grant_type_authorization_code, + grant_type_refresh_token, + contacts, + client_name, + logo_uri, + client_uri, + policy_uri, + tos_uri, + jwks_uri, + jwks, + id_token_signed_response_alg, + token_endpoint_auth_method, + token_endpoint_auth_signing_alg, + initiate_login_uri) + VALUES + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17) + RETURNING id + "#, + client_id, + encrypted_client_secret, + &response_types, + grant_type_authorization_code, + grant_type_refresh_token, + contacts, + client_name, + logo_uri, + client_uri, + policy_uri, + tos_uri, + jwks_uri, + jwks, + id_token_signed_response_alg, + token_endpoint_auth_method, + token_endpoint_auth_signing_alg, + initiate_login_uri, + ) + .fetch_one(&mut *conn) + .await?; + + let redirect_uris: Vec = redirect_uris.iter().map(ToString::to_string).collect(); + + sqlx::query!( + r#" + INSERT INTO oauth2_client_redirect_uris (oauth2_client_id, redirect_uri) + SELECT $1, uri FROM UNNEST($2::text[]) uri + "#, + id, + &redirect_uris, + ) + .execute(&mut *conn) + .await?; + + Ok(()) +} + pub async fn insert_client_from_config( conn: &mut PgConnection, client_id: &str,