1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-11-21 23:00:50 +03:00

Refactor the provider client credentials extraction

This commit is contained in:
Quentin Gliech
2022-11-23 11:28:12 +01:00
parent bedcf44741
commit 16088fc11c
2 changed files with 94 additions and 56 deletions

View File

@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::Context;
use axum::{
extract::{Path, Query, State},
response::IntoResponse,
@@ -21,12 +20,10 @@ use axum::{
use axum_extra::extract::PrivateCookieJar;
use hyper::StatusCode;
use mas_http::ClientInitError;
use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_keystore::{Encrypter, Keystore};
use mas_oidc_client::{
error::{DiscoveryError, JwksError, TokenAuthorizationCodeError},
requests::{authorization_code::AuthorizationValidationData, jose::JwtVerificationData},
types::client_credentials::ClientCredentials,
};
use mas_router::UrlBuilder;
use mas_storage::{upstream_oauth2::lookup_session, LookupResultExt};
@@ -36,7 +33,7 @@ use sqlx::PgPool;
use thiserror::Error;
use ulid::Ulid;
use super::http_service;
use super::{client_credentials_for_provider, http_service, ProviderCredentialsError};
#[derive(Deserialize)]
pub struct QueryParams {
@@ -77,9 +74,6 @@ pub(crate) enum RouteError {
error_description: Option<String>,
},
#[error("Provider is missing a client secret")]
MissingClientSecret,
#[error("Missing session cookie")]
MissingCookie,
@@ -129,6 +123,12 @@ impl From<ClientInitError> for RouteError {
}
}
impl From<ProviderCredentialsError> for RouteError {
fn from(e: ProviderCredentialsError) -> Self {
Self::InternalError(Box::new(e))
}
}
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
match self {
@@ -207,54 +207,12 @@ pub(crate) async fn get(
mas_oidc_client::requests::jose::fetch_jwks(&http_service, metadata.jwks_uri()).await?;
// Figure out the client credentials
let client_id = provider.client_id.clone();
// Decrypt the client secret
let client_secret = provider
.encrypted_client_secret
.map(|encrypted_client_secret| {
encrypter
.decrypt_string(&encrypted_client_secret)
.and_then(|client_secret| {
String::from_utf8(client_secret)
.context("Client secret contains non-UTF8 bytes")
})
})
.transpose()?;
let token_endpoint = metadata.token_endpoint();
let client_credentials = match provider.token_endpoint_auth_method {
OAuthClientAuthenticationMethod::None => ClientCredentials::None { client_id },
OAuthClientAuthenticationMethod::ClientSecretPost => ClientCredentials::ClientSecretPost {
client_id,
client_secret: client_secret.ok_or(RouteError::MissingClientSecret)?,
},
OAuthClientAuthenticationMethod::ClientSecretBasic => {
ClientCredentials::ClientSecretBasic {
client_id,
client_secret: client_secret.ok_or(RouteError::MissingClientSecret)?,
}
}
OAuthClientAuthenticationMethod::ClientSecretJwt => ClientCredentials::ClientSecretJwt {
client_id,
client_secret: client_secret.ok_or(RouteError::MissingClientSecret)?,
signing_algorithm: provider
.token_endpoint_signing_alg
.unwrap_or(mas_iana::jose::JsonWebSignatureAlg::Rs256),
token_endpoint: token_endpoint.clone(),
},
OAuthClientAuthenticationMethod::PrivateKeyJwt => ClientCredentials::PrivateKeyJwt {
client_id,
jwt_signing_method:
mas_oidc_client::types::client_credentials::JwtSigningMethod::Keystore(keystore),
signing_algorithm: provider
.token_endpoint_signing_alg
.unwrap_or(mas_iana::jose::JsonWebSignatureAlg::Rs256),
token_endpoint: token_endpoint.clone(),
},
// XXX: The database should never have an unsupported method in it
_ => unreachable!(),
};
let client_credentials = client_credentials_for_provider(
&provider,
metadata.token_endpoint(),
&keystore,
&encrypter,
)?;
let redirect_uri = url_builder.upstream_oauth_callback(provider.id);
@@ -277,7 +235,7 @@ pub(crate) async fn get(
mas_oidc_client::requests::authorization_code::access_token_with_authorization_code(
&http_service,
client_credentials,
token_endpoint,
metadata.token_endpoint(),
code,
validation_data,
Some(id_token_verification_data),