1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Create mas-oidc-client crate

Methods to interact as an RP with an OIDC OP.
This commit is contained in:
Kévin Commaille
2022-11-07 11:15:22 +01:00
committed by Quentin Gliech
parent c590e8df92
commit 90d0e12b7f
35 changed files with 6200 additions and 40 deletions

View File

@ -0,0 +1,714 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! The error types used in this crate.
use std::{str::Utf8Error, sync::Arc};
use headers::authorization::InvalidBearerToken;
use http::{header::ToStrError, StatusCode};
use mas_http::{catch_http_codes, form_urlencoded_request, json_request, json_response};
use mas_jose::{
claims::ClaimError,
jwa::InvalidAlgorithm,
jwt::{JwtDecodeError, JwtSignatureError, NoKeyWorked},
};
use mas_keystore::WrongAlgorithmError;
use oauth2_types::{
errors::ClientErrorCode, oidc::ProviderMetadataVerificationError, pkce::CodeChallengeError,
};
use serde::{Deserialize, Serialize};
use thiserror::Error;
pub use tower::BoxError;
/// All possible errors when using this crate.
#[derive(Debug, Error)]
#[error(transparent)]
pub enum Error {
/// An error occurred fetching provider metadata.
Discovery(#[from] DiscoveryError),
/// An error occurred fetching the provider JWKS.
Jwks(#[from] JwksError),
/// An error occurred during client registration.
Registration(#[from] RegistrationError),
/// An error occurred building the authorization URL.
Authorization(#[from] AuthorizationError),
/// An error occurred exchanging an authorization code for an access token.
TokenAuthorizationCode(#[from] TokenAuthorizationCodeError),
/// An error occurred requesting an access token with client credentials.
TokenClientCredentials(#[from] TokenRequestError),
/// An error occurred refreshing an access token.
TokenRefresh(#[from] TokenRefreshError),
/// An error occurred revoking a token.
TokenRevoke(#[from] TokenRevokeError),
/// An error occurred requesting user info.
UserInfo(#[from] UserInfoError),
/// An error occurred introspecting a token.
Introspection(#[from] IntrospectionError),
}
/// All possible errors when fetching provider metadata.
#[derive(Debug, Error)]
pub enum DiscoveryError {
/// An error occurred building the request's URL.
#[error(transparent)]
IntoUrl(#[from] url::ParseError),
/// An error occurred building the request.
#[error(transparent)]
IntoHttp(#[from] http::Error),
/// The server returned an HTTP error status code.
#[error(transparent)]
Http(#[from] HttpError),
/// An error occurred deserializing the response.
#[error(transparent)]
FromJson(#[from] serde_json::Error),
/// An error occurred validating the metadata.
#[error(transparent)]
Validation(#[from] ProviderMetadataVerificationError),
/// An error occurred sending the request.
#[error(transparent)]
Service(BoxError),
}
impl<S> From<json_response::Error<S>> for DiscoveryError
where
S: Into<DiscoveryError>,
{
fn from(err: json_response::Error<S>) -> Self {
match err {
json_response::Error::Deserialize { inner } => inner.into(),
json_response::Error::Service { inner } => inner.into(),
}
}
}
impl<S> From<catch_http_codes::Error<S, Option<ErrorBody>>> for DiscoveryError
where
S: Into<BoxError>,
{
fn from(err: catch_http_codes::Error<S, Option<ErrorBody>>) -> Self {
match err {
catch_http_codes::Error::HttpError { status_code, inner } => {
Self::Http(HttpError::new(status_code, inner))
}
catch_http_codes::Error::Service { inner } => Self::Service(inner.into()),
}
}
}
/// All possible errors when registering the client.
#[derive(Debug, Error)]
pub enum RegistrationError {
/// An error occurred building the request.
#[error(transparent)]
IntoHttp(#[from] http::Error),
/// An error occurred serializing the request or deserializing the response.
#[error(transparent)]
Json(#[from] serde_json::Error),
/// The server returned an HTTP error status code.
#[error(transparent)]
Http(#[from] HttpError),
/// No client secret was received although one was expected because of the
/// authentication method.
#[error("missing client secret in response")]
MissingClientSecret,
/// An error occurred sending the request.
#[error(transparent)]
Service(BoxError),
}
impl<S> From<json_request::Error<S>> for RegistrationError
where
S: Into<RegistrationError>,
{
fn from(err: json_request::Error<S>) -> Self {
match err {
json_request::Error::Serialize { inner } => inner.into(),
json_request::Error::Service { inner } => inner.into(),
}
}
}
impl<S> From<json_response::Error<S>> for RegistrationError
where
S: Into<RegistrationError>,
{
fn from(err: json_response::Error<S>) -> Self {
match err {
json_response::Error::Deserialize { inner } => inner.into(),
json_response::Error::Service { inner } => inner.into(),
}
}
}
impl<S> From<catch_http_codes::Error<S, Option<ErrorBody>>> for RegistrationError
where
S: Into<BoxError>,
{
fn from(err: catch_http_codes::Error<S, Option<ErrorBody>>) -> Self {
match err {
catch_http_codes::Error::HttpError { status_code, inner } => {
HttpError::new(status_code, inner).into()
}
catch_http_codes::Error::Service { inner } => Self::Service(inner.into()),
}
}
}
/// All possible errors when making a pushed authorization request.
#[derive(Debug, Error)]
pub enum PushedAuthorizationError {
/// An error occurred serializing the request.
#[error(transparent)]
UrlEncoded(#[from] serde_urlencoded::ser::Error),
/// An error occurred building the request.
#[error(transparent)]
IntoHttp(#[from] http::Error),
/// An error occurred adding the client credentials to the request.
#[error(transparent)]
Credentials(#[from] CredentialsError),
/// The server returned an HTTP error status code.
#[error(transparent)]
Http(#[from] HttpError),
/// An error occurred deserializing the response.
#[error(transparent)]
Json(#[from] serde_json::Error),
/// An error occurred sending the request.
#[error(transparent)]
Service(BoxError),
}
impl<S> From<form_urlencoded_request::Error<S>> for PushedAuthorizationError
where
S: Into<PushedAuthorizationError>,
{
fn from(err: form_urlencoded_request::Error<S>) -> Self {
match err {
form_urlencoded_request::Error::Serialize { inner } => inner.into(),
form_urlencoded_request::Error::Service { inner } => inner.into(),
}
}
}
impl<S> From<json_response::Error<S>> for PushedAuthorizationError
where
S: Into<PushedAuthorizationError>,
{
fn from(err: json_response::Error<S>) -> Self {
match err {
json_response::Error::Deserialize { inner } => inner.into(),
json_response::Error::Service { inner } => inner.into(),
}
}
}
impl<S> From<catch_http_codes::Error<S, Option<ErrorBody>>> for PushedAuthorizationError
where
S: Into<BoxError>,
{
fn from(err: catch_http_codes::Error<S, Option<ErrorBody>>) -> Self {
match err {
catch_http_codes::Error::HttpError { status_code, inner } => {
HttpError::new(status_code, inner).into()
}
catch_http_codes::Error::Service { inner } => Self::Service(inner.into()),
}
}
}
/// All possible errors when authorizing the client.
#[derive(Debug, Error)]
pub enum AuthorizationError {
/// An error occurred constructing the PKCE code challenge.
#[error(transparent)]
Pkce(#[from] CodeChallengeError),
/// An error occurred serializing the request.
#[error(transparent)]
UrlEncoded(#[from] serde_urlencoded::ser::Error),
/// An error occurred making the PAR request.
#[error(transparent)]
PushedAuthorization(#[from] PushedAuthorizationError),
}
/// All possible errors when requesting an access token.
#[derive(Debug, Error)]
pub enum TokenRequestError {
/// An error occurred building the request.
#[error(transparent)]
IntoHttp(#[from] http::Error),
/// An error occurred adding the client credentials to the request.
#[error(transparent)]
Credentials(#[from] CredentialsError),
/// An error occurred serializing the request.
#[error(transparent)]
UrlEncoded(#[from] serde_urlencoded::ser::Error),
/// The server returned an HTTP error status code.
#[error(transparent)]
Http(#[from] HttpError),
/// An error occurred deserializing the response.
#[error(transparent)]
Json(#[from] serde_json::Error),
/// An error occurred sending the request.
#[error(transparent)]
Service(BoxError),
}
impl<S> From<form_urlencoded_request::Error<S>> for TokenRequestError
where
S: Into<TokenRequestError>,
{
fn from(err: form_urlencoded_request::Error<S>) -> Self {
match err {
form_urlencoded_request::Error::Serialize { inner } => inner.into(),
form_urlencoded_request::Error::Service { inner } => inner.into(),
}
}
}
impl<S> From<json_response::Error<S>> for TokenRequestError
where
S: Into<TokenRequestError>,
{
fn from(err: json_response::Error<S>) -> Self {
match err {
json_response::Error::Deserialize { inner } => inner.into(),
json_response::Error::Service { inner } => inner.into(),
}
}
}
impl<S> From<catch_http_codes::Error<S, Option<ErrorBody>>> for TokenRequestError
where
S: Into<BoxError>,
{
fn from(err: catch_http_codes::Error<S, Option<ErrorBody>>) -> Self {
match err {
catch_http_codes::Error::HttpError { status_code, inner } => {
HttpError::new(status_code, inner).into()
}
catch_http_codes::Error::Service { inner } => Self::Service(inner.into()),
}
}
}
/// All possible errors when exchanging a code for an access token.
#[derive(Debug, Error)]
pub enum TokenAuthorizationCodeError {
/// The nonce doesn't match the one that was sent.
#[error("wrong nonce")]
WrongNonce,
/// An error occurred requesting the access token.
#[error(transparent)]
Token(#[from] TokenRequestError),
/// An error occurred validating the ID Token.
#[error(transparent)]
IdToken(#[from] IdTokenError),
}
/// All possible errors when refreshing an access token.
#[derive(Debug, Error)]
pub enum TokenRefreshError {
/// An error occurred requesting the access token.
#[error(transparent)]
Token(#[from] TokenRequestError),
/// An error occurred validating the ID Token.
#[error(transparent)]
IdToken(#[from] IdTokenError),
}
/// All possible errors when revoking a token.
#[derive(Debug, Error)]
pub enum TokenRevokeError {
/// An error occurred building the request.
#[error(transparent)]
IntoHttp(#[from] http::Error),
/// An error occurred adding the client credentials to the request.
#[error(transparent)]
Credentials(#[from] CredentialsError),
/// An error occurred serializing the request.
#[error(transparent)]
UrlEncoded(#[from] serde_urlencoded::ser::Error),
/// An error occurred deserializing the error response.
#[error(transparent)]
Json(#[from] serde_json::Error),
/// The server returned an HTTP error status code.
#[error(transparent)]
Http(#[from] HttpError),
/// An error occurred sending the request.
#[error(transparent)]
Service(BoxError),
}
impl<S> From<form_urlencoded_request::Error<S>> for TokenRevokeError
where
S: Into<TokenRevokeError>,
{
fn from(err: form_urlencoded_request::Error<S>) -> Self {
match err {
form_urlencoded_request::Error::Serialize { inner } => inner.into(),
form_urlencoded_request::Error::Service { inner } => inner.into(),
}
}
}
impl<S> From<catch_http_codes::Error<S, Option<ErrorBody>>> for TokenRevokeError
where
S: Into<BoxError>,
{
fn from(err: catch_http_codes::Error<S, Option<ErrorBody>>) -> Self {
match err {
catch_http_codes::Error::HttpError { status_code, inner } => {
HttpError::new(status_code, inner).into()
}
catch_http_codes::Error::Service { inner } => Self::Service(inner.into()),
}
}
}
/// All possible errors when requesting user info.
#[derive(Debug, Error)]
pub enum UserInfoError {
/// An error occurred getting the provider metadata.
#[error(transparent)]
Discovery(#[from] Arc<DiscoveryError>),
/// The provider doesn't support requesting user info.
#[error("missing UserInfo support")]
MissingUserInfoSupport,
/// No token is available to get info from.
#[error("missing token")]
MissingToken,
/// No client metadata is available.
#[error("missing client metadata")]
MissingClientMetadata,
/// The access token is invalid.
#[error(transparent)]
Token(#[from] InvalidBearerToken),
/// An error occurred building the request.
#[error(transparent)]
IntoHttp(#[from] http::Error),
/// The content-type header is missing from the response.
#[error("missing response content-type")]
MissingResponseContentType,
/// The content-type header could not be decoded.
#[error("could not decoded response content-type: {0}")]
DecodeResponseContentType(#[from] ToStrError),
/// The content-type is not the one that was expected.
#[error("invalid response content-type {got:?}, expected {expected:?}")]
InvalidResponseContentType {
/// The expected content-type.
expected: String,
/// The returned content-type.
got: String,
},
/// An error occurred reading the response.
#[error(transparent)]
FromUtf8(#[from] Utf8Error),
/// An error occurred deserializing the JSON or error response.
#[error(transparent)]
Json(#[from] serde_json::Error),
/// An error occurred verifying the Id Token.
#[error(transparent)]
IdToken(#[from] IdTokenError),
/// The server returned an HTTP error status code.
#[error(transparent)]
Http(#[from] HttpError),
/// An error occurred sending the request.
#[error(transparent)]
Service(BoxError),
}
impl<S> From<catch_http_codes::Error<S, Option<ErrorBody>>> for UserInfoError
where
S: Into<BoxError>,
{
fn from(err: catch_http_codes::Error<S, Option<ErrorBody>>) -> Self {
match err {
catch_http_codes::Error::HttpError { status_code, inner } => {
HttpError::new(status_code, inner).into()
}
catch_http_codes::Error::Service { inner } => Self::Service(inner.into()),
}
}
}
/// All possible errors when introspecting a token.
#[derive(Debug, Error)]
pub enum IntrospectionError {
/// An error occurred building the request.
#[error(transparent)]
IntoHttp(#[from] http::Error),
/// An error occurred adding the client credentials to the request.
#[error(transparent)]
Credentials(#[from] CredentialsError),
/// The access token is invalid.
#[error(transparent)]
Token(#[from] InvalidBearerToken),
/// An error occurred serializing the request.
#[error(transparent)]
UrlEncoded(#[from] serde_urlencoded::ser::Error),
/// An error occurred deserializing the JSON or error response.
#[error(transparent)]
Json(#[from] serde_json::Error),
/// The server returned an HTTP error status code.
#[error(transparent)]
Http(#[from] HttpError),
/// An error occurred sending the request.
#[error(transparent)]
Service(BoxError),
}
impl<S> From<form_urlencoded_request::Error<S>> for IntrospectionError
where
S: Into<IntrospectionError>,
{
fn from(err: form_urlencoded_request::Error<S>) -> Self {
match err {
form_urlencoded_request::Error::Serialize { inner } => inner.into(),
form_urlencoded_request::Error::Service { inner } => inner.into(),
}
}
}
impl<S> From<json_response::Error<S>> for IntrospectionError
where
S: Into<IntrospectionError>,
{
fn from(err: json_response::Error<S>) -> Self {
match err {
json_response::Error::Deserialize { inner } => inner.into(),
json_response::Error::Service { inner } => inner.into(),
}
}
}
impl<S> From<catch_http_codes::Error<S, Option<ErrorBody>>> for IntrospectionError
where
S: Into<BoxError>,
{
fn from(err: catch_http_codes::Error<S, Option<ErrorBody>>) -> Self {
match err {
catch_http_codes::Error::HttpError { status_code, inner } => {
HttpError::new(status_code, inner).into()
}
catch_http_codes::Error::Service { inner } => Self::Service(inner.into()),
}
}
}
/// All possible errors when requesting a JWKS.
#[derive(Debug, Error)]
pub enum JwksError {
/// An error occurred building the request.
#[error(transparent)]
IntoHttp(#[from] http::Error),
/// An error occurred deserializing the response.
#[error(transparent)]
Json(#[from] serde_json::Error),
/// An error occurred sending the request.
#[error(transparent)]
Service(BoxError),
}
impl<S> From<json_response::Error<S>> for JwksError
where
S: Into<BoxError>,
{
fn from(err: json_response::Error<S>) -> Self {
match err {
json_response::Error::Service { inner } => Self::Service(inner.into()),
json_response::Error::Deserialize { inner } => Self::Json(inner),
}
}
}
/// All possible errors when verifying a JWT.
#[derive(Debug, Error)]
pub enum JwtVerificationError {
/// An error occured decoding the JWT.
#[error(transparent)]
JwtDecode(#[from] JwtDecodeError),
/// No key worked for verifying the JWT's signature.
#[error(transparent)]
JwtSignature(#[from] NoKeyWorked),
/// An error occurred extracting a claim.
#[error(transparent)]
Claim(#[from] ClaimError),
/// The issuer is not the one that sent the JWT.
#[error("wrong issuer claim")]
WrongIssuer,
/// The audience of the JWT is not this client.
#[error("wrong aud claim")]
WrongAudience,
/// The algorithm used for signing the JWT is not the one that was
/// requested.
#[error("wrong signature alg")]
WrongSignatureAlg,
}
/// All possible errors when verifying an ID token.
#[derive(Debug, Error)]
pub enum IdTokenError {
/// No ID Token was found in the response although one was expected.
#[error("ID token is missing")]
MissingIdToken,
/// The ID Token from the latest Authorization was not provided although
/// this request expects to be verified against one.
#[error("Authorization ID token is missing")]
MissingAuthIdToken,
/// An error occurred validating the ID Token's signature and basic claims.
#[error(transparent)]
Jwt(#[from] JwtVerificationError),
/// An error occurred extracting a claim.
#[error(transparent)]
Claim(#[from] ClaimError),
/// The subject identifier returned by the issuer is not the same as the one
/// we got before.
#[error("wrong subject identifier")]
WrongSubjectIdentifier,
/// The authentication time returned by the issuer is not the same as the
/// one we got before.
#[error("wrong authentication time")]
WrongAuthTime,
}
/// An error that can be returned by an OpenID Provider.
#[derive(Debug, Clone, Error)]
#[error("{status}: {body:?}")]
pub struct HttpError {
/// The status code of the error.
pub status: StatusCode,
/// The body of the error, if any.
pub body: Option<ErrorBody>,
}
impl HttpError {
/// Creates a new `HttpError` with the given status code and optional body.
#[must_use]
pub fn new(status: StatusCode, body: Option<ErrorBody>) -> Self {
Self { status, body }
}
}
/// The body of an error that can be returned by an OpenID Provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorBody {
/// The error code.
pub error: ClientErrorCode,
/// Additional text description of the error for debugging.
pub error_description: Option<String>,
}
/// All errors that can occur when adding client credentials to the request.
#[derive(Debug, Error)]
pub enum CredentialsError {
/// Trying to use an unsupported authentication method.
#[error("unsupported authentication method")]
UnsupportedMethod,
/// When authenticationg with `private_key_jwt`, no private key was found
/// for the given algorithm.
#[error("no private key was found for the given algorithm")]
NoPrivateKeyFound,
/// The signing algorithm is invalid for this authentication method.
#[error("invalid algorithm: {0}")]
InvalidSigningAlgorithm(#[from] InvalidAlgorithm),
/// An error occurred when building the claims of the JWT.
#[error(transparent)]
JwtClaims(#[from] ClaimError),
/// The key found cannot be used with the algorithm.
#[error(transparent)]
JwtWrongAlgorithm(#[from] WrongAlgorithmError),
/// An error occurred when signing the JWT.
#[error(transparent)]
JwtSignature(#[from] JwtSignatureError),
/// An error occurred with a custom signing method.
#[error(transparent)]
Custom(BoxError),
}

View File

@ -0,0 +1,88 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::task::Poll;
use bytes::Bytes;
use futures_util::future::BoxFuture;
use http::{Request, Response};
use http_body::{Body, Full};
use hyper::body::to_bytes;
use thiserror::Error;
use tower::{BoxError, Layer, Service};
#[derive(Debug, Error)]
#[error(transparent)]
pub enum BodyError<E> {
Decompression(BoxError),
Service(E),
}
#[derive(Clone)]
pub struct BodyService<S> {
inner: S,
}
impl<S> BodyService<S> {
pub const fn new(inner: S) -> Self {
Self { inner }
}
}
impl<S, E, ResBody> Service<Request<Bytes>> for BodyService<S>
where
S: Service<Request<Full<Bytes>>, Response = Response<ResBody>, Error = E>,
ResBody: Body<Data = Bytes, Error = BoxError> + Send,
S::Future: Send + 'static,
{
type Error = BodyError<E>;
type Response = Response<Bytes>;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(BodyError::Service)
}
fn call(&mut self, request: Request<Bytes>) -> Self::Future {
let (parts, body) = request.into_parts();
let body = Full::new(body);
let request = Request::from_parts(parts, body);
let fut = self.inner.call(request);
let fut = async {
let response = fut.await.map_err(BodyError::Service)?;
let (parts, body) = response.into_parts();
let body = to_bytes(body).await.map_err(BodyError::Decompression)?;
let response = Response::from_parts(parts, body);
Ok(response)
};
Box::pin(fut)
}
}
#[derive(Default, Clone, Copy)]
pub struct BodyLayer(());
impl<S> Layer<S> for BodyLayer {
type Service = BodyService<S>;
fn layer(&self, inner: S) -> Self::Service {
BodyService::new(inner)
}
}

View File

@ -0,0 +1,75 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! A [`HttpService`] that uses [hyper] as a backend.
//!
//! [hyper]: https://crates.io/crates/hyper
use std::time::Duration;
use http::{header::USER_AGENT, HeaderValue};
use hyper::client::{connect::dns::GaiResolver, HttpConnector};
use hyper_rustls::{ConfigBuilderExt, HttpsConnectorBuilder};
use tower::{limit::ConcurrencyLimitLayer, BoxError, ServiceBuilder};
use tower_http::{
decompression::DecompressionLayer, follow_redirect::FollowRedirectLayer,
set_header::SetRequestHeaderLayer, timeout::TimeoutLayer,
};
mod body_layer;
use self::body_layer::BodyLayer;
use super::HttpService;
static MAS_USER_AGENT: HeaderValue = HeaderValue::from_static("mas-oidc-client/0.0.1");
/// Constructs a [`HttpService`] using [hyper] as a backend.
///
/// [hyper]: https://crates.io/crates/hyper
#[must_use]
pub fn hyper_service() -> HttpService {
let resolver = ServiceBuilder::new().service(GaiResolver::new());
let mut http = HttpConnector::new_with_resolver(resolver);
http.enforce_http(false);
let tls_config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_native_roots()
.with_no_client_auth();
let https = HttpsConnectorBuilder::new()
.with_tls_config(tls_config)
.https_or_http()
.enable_http1()
.enable_http2()
.wrap_connector(http);
let client = hyper::Client::builder().build(https);
let client = ServiceBuilder::new()
.map_err(BoxError::from)
.layer(BodyLayer::default())
.layer(DecompressionLayer::new())
.layer(SetRequestHeaderLayer::overriding(
USER_AGENT,
MAS_USER_AGENT.clone(),
))
.layer(ConcurrencyLimitLayer::new(10))
.layer(FollowRedirectLayer::new())
.layer(TimeoutLayer::new(Duration::from_secs(10)))
.service(client);
HttpService::new(client)
}

View File

@ -0,0 +1,109 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Traits to implement to provide a custom HTTP service for `Client`.
use std::{
fmt,
task::{Context, Poll},
};
use bytes::Bytes;
use futures::future::BoxFuture;
use tower::{BoxError, Service, ServiceExt};
#[cfg(feature = "hyper")]
pub mod hyper;
/// Type for the underlying HTTP service.
///
/// Allows implementors to use different libraries that provide a [`Service`]
/// that implements [`Clone`] + [`Send`] + [`Sync`].
pub type HttpService = BoxCloneSyncService<http::Request<Bytes>, http::Response<Bytes>, BoxError>;
impl fmt::Debug for HttpService {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("HttpService").finish()
}
}
/// A [`Clone`] + [`Send`] + [`Sync`] boxed [`Service`].
///
/// [`BoxCloneSyncService`] turns a service into a trait object, allowing the
/// response future type to be dynamic, and allowing the service to be cloned.
#[allow(clippy::type_complexity)]
pub struct BoxCloneSyncService<T, U, E>(
Box<
dyn CloneSyncService<T, Response = U, Error = E, Future = BoxFuture<'static, Result<U, E>>>,
>,
);
impl<T, U, E> BoxCloneSyncService<T, U, E> {
/// Create a new `BoxCloneSyncService`.
pub fn new<S>(inner: S) -> Self
where
S: Service<T, Response = U, Error = E> + Clone + Send + Sync + 'static,
S::Future: Send + 'static,
{
let inner = inner.map_future(|f| Box::pin(f) as _);
Self(Box::new(inner))
}
}
impl<T, U, E> Service<T> for BoxCloneSyncService<T, U, E> {
type Response = U;
type Error = E;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
#[inline]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.0.poll_ready(cx)
}
#[inline]
fn call(&mut self, request: T) -> Self::Future {
self.0.call(request)
}
}
impl<T, U, E> Clone for BoxCloneSyncService<T, U, E> {
fn clone(&self) -> Self {
Self(self.0.clone_sync_box())
}
}
trait CloneSyncService<R>: Service<R> + Send + Sync {
fn clone_sync_box(
&self,
) -> Box<
dyn CloneSyncService<
R,
Response = Self::Response,
Error = Self::Error,
Future = Self::Future,
>,
>;
}
impl<R, T> CloneSyncService<R> for T
where
T: Service<R> + Send + Sync + Clone + 'static,
{
fn clone_sync_box(
&self,
) -> Box<dyn CloneSyncService<R, Response = T::Response, Error = T::Error, Future = T::Future>>
{
Box::new(self.clone())
}
}

View File

@ -0,0 +1,88 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! An [OpenID Connect] client library for the [Matrix] specification.
//!
//! This is part of the [Matrix Authentication Service] project.
//!
//! # Scope
//!
//! The scope of this crate is to support OIDC features required by the
//! Matrix specification according to [MSC3861] and its sub-proposals.
//!
//! As such, it is compatible with the OpenID Connect 1.0 specification, but
//! also enforces Matrix-specific requirements or adds compatibility with new
//! [OAuth 2.0] features.
//!
//! # OpenID Connect and OAuth 2.0 Features
//!
//! - Grant Types:
//! - [Authorization Code](https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowAuth)
//! - [Client Credentials](https://www.rfc-editor.org/rfc/rfc6749#section-4.4)
//! - [Device Code](https://www.rfc-editor.org/rfc/rfc8628) (TBD)
//! - [User Info](https://openid.net/specs/openid-connect-core-1_0.html#UserInfo)
//! - Token:
//! - [Refresh Token](https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens)
//! - [Introspection](https://www.rfc-editor.org/rfc/rfc7662)
//! - [Revocation](https://www.rfc-editor.org/rfc/rfc7009)
//! - [Dynamic Client Registration](https://openid.net/specs/openid-connect-registration-1_0.html)
//! - [PKCE](https://www.rfc-editor.org/rfc/rfc7636)
//! - [Pushed Authorization Requests](https://www.rfc-editor.org/rfc/rfc9126)
//!
//! # Matrix features
//!
//! - Client registration
//! - Login
//! - Matrix API Scopes
//! - Logout
//!
//! [OpenID Connect]: https://openid.net/connect/
//! [Matrix]: https://matrix.org/
//! [Matrix Authentication Service]: https://github.com/matrix-org/matrix-authentication-service
//! [MSC3861]: https://github.com/matrix-org/matrix-spec-proposals/pull/3861
//! [OAuth 2.0]: https://oauth.net/2/
#![forbid(unsafe_code)]
#![deny(
clippy::all,
clippy::str_to_string,
rustdoc::broken_intra_doc_links,
missing_docs
)]
#![warn(clippy::pedantic)]
#![allow(clippy::module_name_repetitions, clippy::implicit_hasher)]
pub mod error;
pub mod http_service;
pub mod requests;
pub mod types;
mod utils;
use std::fmt;
#[doc(inline)]
pub use mas_jose as jose;
// Wrapper around `String` that cannot be used in a meaningful way outside of
// this crate. Used for string enums that only allow certain characters because
// their variant can't be private.
#[doc(hidden)]
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct PrivString(String);
impl fmt::Debug for PrivString {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}

View File

@ -0,0 +1,460 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Requests for the [Authorization Code flow].
//!
//! [Authorization Code flow]: https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowAuth
use base64ct::{Base64UrlUnpadded, Encoding};
use chrono::{DateTime, Utc};
use http::header::CONTENT_TYPE;
use mas_http::{CatchHttpCodesLayer, FormUrlencodedRequestLayer, JsonResponseLayer};
use mas_iana::oauth::{OAuthAuthorizationEndpointResponseType, PkceCodeChallengeMethod};
use mas_jose::claims::{self, TokenHash};
use oauth2_types::{
pkce,
prelude::CodeChallengeMethodExt,
requests::{
AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, AuthorizationRequest,
Prompt, PushedAuthorizationResponse,
},
scope::Scope,
};
use rand::{
distributions::{Alphanumeric, DistString},
Rng,
};
use serde::Serialize;
use serde_with::skip_serializing_none;
use tower::{Layer, Service, ServiceExt};
use url::Url;
use super::jose::JwtVerificationData;
use crate::{
error::{
AuthorizationError, IdTokenError, PushedAuthorizationError, TokenAuthorizationCodeError,
},
http_service::HttpService,
requests::{jose::verify_id_token, token::request_access_token},
types::{
client_credentials::ClientCredentials,
scope::{ScopeExt, ScopeToken},
IdToken,
},
utils::{http_all_error_status_codes, http_error_mapper},
};
/// The data necessary to build an authorization request.
#[derive(Debug, Clone, Copy)]
pub struct AuthorizationRequestData<'a> {
/// The ID obtained when registering the client.
pub client_id: &'a str,
/// The PKCE methods supported by the issuer, from its metadata.
pub code_challenge_methods_supported: Option<&'a [PkceCodeChallengeMethod]>,
/// The scope to authorize.
///
/// If the OpenID Connect scope token (`openid`) is not included, it will be
/// added.
pub scope: &'a Scope,
/// The URI to redirect the end-user to after the authorization.
///
/// It must be one of the redirect URIs provided during registration.
pub redirect_uri: &'a Url,
/// Optional hints for the action to be performed.
pub prompt: Option<&'a [Prompt]>,
}
/// The data necessary to validate a response from the Token endpoint in the
/// Authorization Code flow.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AuthorizationValidationData {
/// A unique identifier for the request.
pub state: String,
/// A string to mitigate replay attacks.
pub nonce: String,
/// The URI where the end-user will be redirected after authorization.
pub redirect_uri: Url,
/// A string to correlate the authorization request to the token request.
pub code_challenge_verifier: Option<String>,
}
#[skip_serializing_none]
#[derive(Clone, Serialize)]
struct FullAuthorizationRequest {
#[serde(flatten)]
inner: AuthorizationRequest,
#[serde(flatten)]
pkce: Option<pkce::AuthorizationRequest>,
}
/// Build the authorization request.
fn build_authorization_request(
authorization_data: AuthorizationRequestData<'_>,
rng: &mut impl Rng,
) -> Result<(FullAuthorizationRequest, AuthorizationValidationData), AuthorizationError> {
let AuthorizationRequestData {
client_id,
code_challenge_methods_supported,
scope,
redirect_uri,
prompt,
} = authorization_data;
let mut scope = scope.clone();
// Generate a random CSRF "state" token and a nonce.
let state = Alphanumeric.sample_string(rng, 16);
let nonce = Alphanumeric.sample_string(rng, 16);
// Use PKCE, whenever possible.
let (pkce, code_challenge_verifier) = if code_challenge_methods_supported
.iter()
.any(|methods| methods.contains(&PkceCodeChallengeMethod::S256))
{
let mut verifier = [0u8; 32];
rng.fill(&mut verifier);
let method = PkceCodeChallengeMethod::S256;
let verifier = Base64UrlUnpadded::encode_string(&verifier);
let code_challenge = method.compute_challenge(&verifier)?.into();
let pkce = pkce::AuthorizationRequest {
code_challenge_method: method,
code_challenge,
};
(Some(pkce), Some(verifier))
} else {
(None, None)
};
scope.insert_token(ScopeToken::Openid);
let auth_request = FullAuthorizationRequest {
inner: AuthorizationRequest {
response_type: OAuthAuthorizationEndpointResponseType::Code.into(),
client_id: client_id.to_owned(),
redirect_uri: Some(redirect_uri.clone()),
scope,
state: Some(state.clone()),
response_mode: None,
nonce: Some(nonce.clone()),
display: None,
prompt: prompt.map(ToOwned::to_owned),
max_age: None,
ui_locales: None,
id_token_hint: None,
login_hint: None,
acr_values: None,
request: None,
request_uri: None,
registration: None,
},
pkce,
};
let auth_data = AuthorizationValidationData {
state,
nonce,
redirect_uri: redirect_uri.clone(),
code_challenge_verifier,
};
Ok((auth_request, auth_data))
}
/// Build the URL for authenticating at the Authorization endpoint.
///
/// # Arguments
///
/// * `authorization_endpoint` - The URL of the issuer's authorization endpoint.
///
/// * `authorization_data` - The data necessary to build the authorization
/// request.
///
/// * `rng` - A random number generator.
///
/// # Returns
///
/// A URL to be opened in a web browser where the end-user will be able to
/// authorize the given scope, and the [`AuthorizationValidationData`] to
/// validate this request.
///
/// The redirect URI will receive parameters in its query:
///
/// * A successful response will receive a `code` and a `state`.
///
/// * If the authorization fails, it should receive an `error` parameter with a
/// [`ClientErrorCode`] and optionally an `error_description`.
///
/// # Errors
///
/// Returns an error if preparing the URL fails.
///
/// [`VerifiedClientMetadata`]: oauth2_types::registration::VerifiedClientMetadata
/// [`ClientErrorCode`]: oauth2_types::errors::ClientErrorCode
#[allow(clippy::too_many_lines)]
pub fn build_authorization_url(
authorization_endpoint: Url,
authorization_data: AuthorizationRequestData<'_>,
rng: &mut impl Rng,
) -> Result<(Url, AuthorizationValidationData), AuthorizationError> {
tracing::debug!(
scope = ?authorization_data.scope,
"Authorizing..."
);
let (authorization_request, validation_data) =
build_authorization_request(authorization_data, rng)?;
let authorization_query = serde_urlencoded::to_string(authorization_request)?;
let mut authorization_url = authorization_endpoint;
// Add our parameters to the query, because the URL might already have one.
let mut full_query = authorization_url
.query()
.map(ToOwned::to_owned)
.unwrap_or_default();
if !full_query.is_empty() {
full_query.push('&');
}
full_query.push_str(&authorization_query);
authorization_url.set_query(Some(&full_query));
Ok((authorization_url, validation_data))
}
/// Make a [Pushed Authorization Request] and build the URL for authenticating
/// at the Authorization endpoint.
///
/// # Arguments
///
/// * `http_service` - The service to use for making HTTP requests.
///
/// * `client_credentials` - The credentials obtained when registering the
/// client.
///
/// * `par_endpoint` - The URL of the issuer's Pushed Authorization Request
/// endpoint.
///
/// * `authorization_endpoint` - The URL of the issuer's Authorization endpoint.
///
/// * `authorization_data` - The data necessary to build the authorization
/// request.
///
/// * `now` - The current time.
///
/// * `rng` - A random number generator.
///
/// # Returns
///
/// A URL to be opened in a web browser where the end-user will be able to
/// authorize the given scope, and the [`AuthorizationValidationData`] to
/// validate this request.
///
/// The redirect URI will receive parameters in its query:
///
/// * A successful response will receive a `code` and a `state`.
///
/// * If the authorization fails, it should receive an `error` parameter with a
/// [`ClientErrorCode`] and optionally an `error_description`.
///
/// # Errors
///
/// Returns an error if the request fails, the response is invalid or building
/// the URL fails.
///
/// [Pushed Authorization Request]: https://oauth.net/2/pushed-authorization-requests/
/// [`ClientErrorCode`]: oauth2_types::errors::ClientErrorCode
#[allow(clippy::too_many_lines)]
#[tracing::instrument(skip_all, fields(par_endpoint))]
pub async fn build_par_authorization_url(
http_service: &HttpService,
client_credentials: ClientCredentials,
par_endpoint: &Url,
authorization_endpoint: Url,
authorization_data: AuthorizationRequestData<'_>,
now: DateTime<Utc>,
rng: &mut impl Rng,
) -> Result<(Url, AuthorizationValidationData), AuthorizationError> {
tracing::debug!(
scope = ?authorization_data.scope,
"Authorizing with a PAR..."
);
let client_id = client_credentials.client_id().to_owned();
let (authorization_request, validation_data) =
build_authorization_request(authorization_data, rng)?;
let par_request = http::Request::post(par_endpoint.as_str())
.header(CONTENT_TYPE, mime::APPLICATION_WWW_FORM_URLENCODED.as_ref())
.body(authorization_request)
.map_err(PushedAuthorizationError::from)?;
let par_request = client_credentials
.apply_to_request(par_request, now, rng)
.map_err(PushedAuthorizationError::from)?;
let service = (
FormUrlencodedRequestLayer::default(),
JsonResponseLayer::<PushedAuthorizationResponse>::default(),
CatchHttpCodesLayer::new(http_all_error_status_codes(), http_error_mapper),
)
.layer(http_service.clone());
let par_response = service
.ready_oneshot()
.await
.map_err(PushedAuthorizationError::from)?
.call(par_request)
.await
.map_err(PushedAuthorizationError::from)?
.into_body();
let authorization_query = serde_urlencoded::to_string([
("request_uri", par_response.request_uri.as_str()),
("client_id", &client_id),
])?;
let mut authorization_url = authorization_endpoint;
// Add our parameters to the query, because the URL might already have one.
let mut full_query = authorization_url
.query()
.map(ToOwned::to_owned)
.unwrap_or_default();
if !full_query.is_empty() {
full_query.push('&');
}
full_query.push_str(&authorization_query);
authorization_url.set_query(Some(&full_query));
Ok((authorization_url, validation_data))
}
/// Exchange an authorization code for an access token.
///
/// This should be used as the first step for logging in, and to request a
/// token with a new scope.
///
/// # Arguments
///
/// * `http_service` - The service to use for making HTTP requests.
///
/// * `client_credentials` - The credentials obtained when registering the
/// client.
///
/// * `token_endpoint` - The URL of the issuer's Token endpoint.
///
/// * `code` - The authorization code returned at the Authorization endpoint.
///
/// * `validation_data` - The validation data that was returned when building
/// the Authorization URL, for the state returned at the Authorization
/// endpoint.
///
/// * `id_token_verification_data` - The data required to verify the ID Token in
/// the response.
///
/// The signing algorithm corresponds to the `id_token_signed_response_alg`
/// field in the client metadata.
///
/// If it is not provided, the ID Token won't be verified. Note that in the
/// OpenID Connect specification, this verification is required.
///
/// * `now` - The current time.
///
/// * `rng` - A random number generator.
///
/// # Errors
///
/// Returns an error if the request fails, the response is invalid or the
/// verification of the ID Token fails.
#[allow(clippy::too_many_arguments)]
#[tracing::instrument(skip_all, fields(token_endpoint))]
pub async fn access_token_with_authorization_code(
http_service: &HttpService,
client_credentials: ClientCredentials,
token_endpoint: &Url,
code: String,
validation_data: AuthorizationValidationData,
id_token_verification_data: Option<JwtVerificationData<'_>>,
now: DateTime<Utc>,
rng: &mut impl Rng,
) -> Result<(AccessTokenResponse, Option<IdToken<'static>>), TokenAuthorizationCodeError> {
tracing::debug!("Exchanging authorization code for access token...");
let token_response = request_access_token(
http_service,
client_credentials,
token_endpoint,
AccessTokenRequest::AuthorizationCode(AuthorizationCodeGrant {
code: code.clone(),
redirect_uri: Some(validation_data.redirect_uri),
code_verifier: validation_data.code_challenge_verifier,
}),
now,
rng,
)
.await?;
let id_token = if let Some(verification_data) = id_token_verification_data {
let signing_alg = verification_data.signing_algorithm;
let id_token = token_response
.id_token
.as_deref()
.ok_or(IdTokenError::MissingIdToken)?;
let id_token = verify_id_token(id_token, verification_data, None, now)?;
let mut claims = id_token.payload().clone();
// Access token hash must match.
claims::AT_HASH
.extract_optional_with_options(
&mut claims,
TokenHash::new(signing_alg, &token_response.access_token),
)
.map_err(IdTokenError::from)?;
// Code hash must match.
claims::C_HASH
.extract_optional_with_options(&mut claims, TokenHash::new(signing_alg, &code))
.map_err(IdTokenError::from)?;
// Nonce must match.
let token_nonce = claims::NONCE
.extract_required(&mut claims)
.map_err(IdTokenError::from)?;
if token_nonce != validation_data.nonce {
return Err(TokenAuthorizationCodeError::WrongNonce);
}
Some(id_token.into_owned())
} else {
None
};
Ok((token_response, id_token))
}

View File

@ -0,0 +1,75 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Requests for the [Client Credentials flow].
//!
//! [Client Credentials flow]: https://www.rfc-editor.org/rfc/rfc6749#section-4.4
use chrono::{DateTime, Utc};
use oauth2_types::{
requests::{AccessTokenRequest, AccessTokenResponse, ClientCredentialsGrant},
scope::Scope,
};
use rand::Rng;
use url::Url;
use crate::{
error::TokenRequestError, http_service::HttpService, requests::token::request_access_token,
types::client_credentials::ClientCredentials,
};
/// Exchange an authorization code for an access token.
///
/// This should be used as the first step for logging in, and to request a
/// token with a new scope.
///
/// # Arguments
///
/// * `http_service` - The service to use for making HTTP requests.
///
/// * `client_credentials` - The credentials obtained when registering the
/// client.
///
/// * `token_endpoint` - The URL of the issuer's Token endpoint.
///
/// * `scope` - The scope to authorize.
///
/// * `now` - The current time.
///
/// * `rng` - A random number generator.
///
/// # Errors
///
/// Returns an error if the request fails or the response is invalid.
#[tracing::instrument(skip_all, fields(token_endpoint))]
pub async fn access_token_with_client_credentials(
http_service: &HttpService,
client_credentials: ClientCredentials,
token_endpoint: &Url,
scope: Option<Scope>,
now: DateTime<Utc>,
rng: &mut impl Rng,
) -> Result<AccessTokenResponse, TokenRequestError> {
tracing::debug!("Requesting access token with client credentials...");
request_access_token(
http_service,
client_credentials,
token_endpoint,
AccessTokenRequest::ClientCredentials(ClientCredentialsGrant { scope }),
now,
rng,
)
.await
}

View File

@ -0,0 +1,109 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Requests for OpenID Connect Provider [Discovery].
//!
//! [Discovery]: https://openid.net/specs/openid-connect-discovery-1_0.html
use bytes::Bytes;
use mas_http::{CatchHttpCodesLayer, JsonResponseLayer};
use oauth2_types::oidc::{ProviderMetadata, VerifiedProviderMetadata};
use tower::{Layer, Service, ServiceExt};
use url::Url;
use crate::{
error::DiscoveryError,
http_service::HttpService,
utils::{http_all_error_status_codes, http_error_mapper},
};
/// Fetch the provider metadata.
async fn discover_inner(
http_service: &HttpService,
issuer: &Url,
) -> Result<ProviderMetadata, DiscoveryError> {
tracing::debug!("Fetching provider metadata...");
let mut config_url = issuer.clone();
// If the path doesn't end with a slash, the last segment is removed when
// using `join`.
if !config_url.path().ends_with('/') {
let mut path = config_url.path().to_owned();
path.push('/');
config_url.set_path(&path);
}
let config_url = config_url.join(".well-known/openid-configuration")?;
let config_req = http::Request::get(config_url.as_str()).body(Bytes::new())?;
let service = (
JsonResponseLayer::<ProviderMetadata>::default(),
CatchHttpCodesLayer::new(http_all_error_status_codes(), http_error_mapper),
)
.layer(http_service.clone());
let response = service.ready_oneshot().await?.call(config_req).await?;
tracing::debug!(?response);
Ok(response.into_body())
}
/// Fetch the provider metadata and validate it.
///
/// # Errors
///
/// Returns an error if the request fails or if the data is invalid.
#[tracing::instrument(skip_all, fields(issuer))]
pub async fn discover(
http_service: &HttpService,
issuer: &Url,
) -> Result<VerifiedProviderMetadata, DiscoveryError> {
let provider_metadata = discover_inner(http_service, issuer).await?;
Ok(provider_metadata.validate(issuer)?)
}
/// Fetch the [provider metadata] and make basic checks.
///
/// Contrary to [`discover()`], this uses
/// [`ProviderMetadata::insecure_verify_metadata()`] to check the received
/// metadata instead of validating it according to the specification.
///
/// # Arguments
///
/// * `http_service` - The service to use for making HTTP requests.
///
/// * `issuer` - The URL of the OpenID Connect Provider to fetch metadata for.
///
/// # Errors
///
/// Returns an error if the request fails or if the data is invalid.
///
/// # Warning
///
/// It is not recommended to use this method in production as it doesn't
/// ensure that the issuer implements the proper security practices.
///
/// [provider metadata]: https://openid.net/specs/openid-connect-discovery-1_0.html
#[tracing::instrument(skip_all, fields(issuer))]
pub async fn insecure_discover(
http_service: &HttpService,
issuer: &Url,
) -> Result<VerifiedProviderMetadata, DiscoveryError> {
let provider_metadata = discover_inner(http_service, issuer).await?;
Ok(provider_metadata.insecure_verify_metadata()?)
}

View File

@ -0,0 +1,153 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Requests for [Token Introspection].
//!
//! [Token Introspection]: https://www.rfc-editor.org/rfc/rfc7662
use chrono::{DateTime, Utc};
use headers::{Authorization, HeaderMapExt};
use http::Request;
use mas_http::{CatchHttpCodesLayer, FormUrlencodedRequestLayer, JsonResponseLayer};
use mas_iana::oauth::OAuthTokenTypeHint;
use oauth2_types::requests::{IntrospectionRequest, IntrospectionResponse};
use rand::Rng;
use serde::Serialize;
use tower::{Layer, Service, ServiceExt};
use url::Url;
use crate::{
error::IntrospectionError,
http_service::HttpService,
types::client_credentials::{ClientCredentials, RequestWithClientCredentials},
utils::{http_all_error_status_codes, http_error_mapper},
};
/// The method used to authenticate at the introspection endpoint.
pub enum IntrospectionAuthentication<'a> {
/// Using client authentication.
Credentials(ClientCredentials),
/// Using a bearer token.
BearerToken(&'a str),
}
impl<'a> IntrospectionAuthentication<'a> {
/// Constructs an `IntrospectionAuthentication` from the given client
/// credentials.
#[must_use]
pub fn with_client_credentials(credentials: ClientCredentials) -> Self {
Self::Credentials(credentials)
}
/// Constructs an `IntrospectionAuthentication` from the given bearer token.
#[must_use]
pub fn with_bearer_token(token: &'a str) -> Self {
Self::BearerToken(token)
}
fn apply_to_request<T: Serialize>(
self,
request: Request<T>,
now: DateTime<Utc>,
rng: &mut impl Rng,
) -> Result<Request<RequestWithClientCredentials<T>>, IntrospectionError> {
let res = match self {
IntrospectionAuthentication::Credentials(client_credentials) => {
client_credentials.apply_to_request(request, now, rng)?
}
IntrospectionAuthentication::BearerToken(access_token) => {
let (mut parts, body) = request.into_parts();
parts
.headers
.typed_insert(Authorization::bearer(access_token)?);
let body = RequestWithClientCredentials {
body,
credentials: None,
};
http::Request::from_parts(parts, body)
}
};
Ok(res)
}
}
impl<'a> From<ClientCredentials> for IntrospectionAuthentication<'a> {
fn from(credentials: ClientCredentials) -> Self {
Self::with_client_credentials(credentials)
}
}
/// Obtain information about a token.
///
/// # Arguments
///
/// * `http_service` - The service to use for making HTTP requests.
///
/// * `authentication` - The method used to authenticate the request.
///
/// * `revocation_endpoint` - The URL of the issuer's Revocation endpoint.
///
/// * `token` - The token to introspect.
///
/// * `token_type_hint` - Hint about the type of the token.
///
/// * `now` - The current time.
///
/// * `rng` - A random number generator.
///
/// # Errors
///
/// Returns an error if the request fails or the response is invalid.
#[tracing::instrument(skip_all, fields(introspection_endpoint))]
pub async fn introspect_token(
http_service: &HttpService,
authentication: IntrospectionAuthentication<'_>,
introspection_endpoint: &Url,
token: String,
token_type_hint: Option<OAuthTokenTypeHint>,
now: DateTime<Utc>,
rng: &mut impl Rng,
) -> Result<IntrospectionResponse, IntrospectionError> {
tracing::debug!("Introspecting token…");
let introspection_request = IntrospectionRequest {
token,
token_type_hint,
};
let introspection_request =
http::Request::post(introspection_endpoint.as_str()).body(introspection_request)?;
let introspection_request = authentication.apply_to_request(introspection_request, now, rng)?;
let service = (
FormUrlencodedRequestLayer::default(),
JsonResponseLayer::<IntrospectionResponse>::default(),
CatchHttpCodesLayer::new(http_all_error_status_codes(), http_error_mapper),
)
.layer(http_service.clone());
let introspection_response = service
.ready_oneshot()
.await?
.call(introspection_request)
.await?
.into_body();
Ok(introspection_response)
}

View File

@ -0,0 +1,223 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Requests and method related to JSON Object Signing and Encryption.
use std::collections::HashMap;
use bytes::Bytes;
use chrono::{DateTime, Utc};
use mas_http::JsonResponseLayer;
use mas_iana::jose::JsonWebSignatureAlg;
use mas_jose::{
claims::{self, TimeOptions},
jwk::PublicJsonWebKeySet,
jwt::Jwt,
};
use serde_json::Value;
use tower::{Layer, Service, ServiceExt};
use url::Url;
use crate::{
error::{IdTokenError, JwksError, JwtVerificationError},
http_service::HttpService,
types::IdToken,
};
/// Fetch a JWKS at the given URL.
///
/// # Arguments
///
/// * `http_service` - The service to use for making HTTP requests.
///
/// * `jwks_uri` - The URL where the JWKS can be retrieved.
///
/// # Errors
///
/// Returns an error if the request fails or if the data is invalid.
#[tracing::instrument(skip_all, fields(jwks_uri))]
pub async fn fetch_jwks(
http_service: &HttpService,
jwks_uri: &Url,
) -> Result<PublicJsonWebKeySet, JwksError> {
tracing::debug!("Fetching JWKS...");
let jwks_request = http::Request::get(jwks_uri.as_str()).body(Bytes::new())?;
let service = JsonResponseLayer::<PublicJsonWebKeySet>::default().layer(http_service.clone());
let response = service.ready_oneshot().await?.call(jwks_request).await?;
Ok(response.into_body())
}
/// The data required to verify a JWT.
#[derive(Clone, Copy)]
pub struct JwtVerificationData<'a> {
/// The URL of the issuer that generated the ID Token.
pub issuer: &'a Url,
/// The issuer's JWKS.
pub jwks: &'a PublicJsonWebKeySet,
/// The ID obtained when registering the client.
pub client_id: &'a String,
/// The JWA that should have been used to sign the JWT, as set during
/// client registration.
pub signing_algorithm: &'a JsonWebSignatureAlg,
}
/// Decode and verify a signed JWT.
///
/// The following checks are performed:
///
/// * The signature is verified with the given JWKS.
///
/// * The `iss` claim must be present and match the issuer.
///
/// * The `aud` claim must be present and match the client ID.
///
/// * The `alg` in the header must match the signing algorithm.
///
/// # Arguments
///
/// * `jwt` - The serialized JWT to decode and verify.
///
/// * `jwks` - The JWKS that should contain the public key to verify the JWT's
/// signature.
///
/// * `issuer` - The issuer of the JWT.
///
/// * `audience` - The audience that the JWT is intended for.
///
/// * `signing_algorithm` - The JWA that should have been used to sign the JWT.
///
/// # Errors
///
/// Returns an error if the data is invalid or verification fails.
pub fn verify_signed_jwt<'a>(
jwt: &'a str,
verification_data: JwtVerificationData<'_>,
) -> Result<Jwt<'a, HashMap<String, Value>>, JwtVerificationError> {
tracing::debug!("Validating JWT...");
let JwtVerificationData {
issuer,
jwks,
client_id,
signing_algorithm,
} = verification_data;
let jwt: Jwt<HashMap<String, Value>> = jwt.try_into()?;
jwt.verify_with_jwks(jwks)?;
let (header, mut claims) = jwt.clone().into_parts();
// Must have the proper issuer.
let iss = claims::ISS.extract_required(&mut claims)?;
if iss != issuer.as_str() {
return Err(JwtVerificationError::WrongIssuer);
}
// Must have the proper audience.
let aud = claims::AUD.extract_required(&mut claims)?;
if !aud.contains(client_id) {
return Err(JwtVerificationError::WrongAudience);
}
// Must use the proper algorithm.
if header.alg() != signing_algorithm {
return Err(JwtVerificationError::WrongSignatureAlg);
}
Ok(jwt)
}
/// Decode and verify an ID Token.
///
/// Besides the checks of [`verify_signed_jwt()`], the following checks are
/// performed:
///
/// * The `exp` claim must be present and the token must not have expired.
///
/// * The `iat` claim must be present must be in the past.
///
/// * The `sub` claim must be present.
///
/// If an authorization ID token is provided, these extra checks are performed:
///
/// * The `sub` claims must match.
///
/// * The `auth_time` claims must match.
///
/// # Arguments
///
/// * `id_token` - The serialized ID Token to decode and verify.
///
/// * `verification_data` - The data necessary to verify the ID Token.
///
/// * `auth_id_token` - If the ID Token is not verified during an authorization
/// request, the ID token that was returned from the latest authorization
/// request.
///
/// # Errors
///
/// Returns an error if the data is invalid or verification fails.
pub fn verify_id_token<'a>(
id_token: &'a str,
verification_data: JwtVerificationData<'_>,
auth_id_token: Option<&IdToken<'_>>,
now: DateTime<Utc>,
) -> Result<IdToken<'a>, IdTokenError> {
let id_token = verify_signed_jwt(id_token, verification_data)?;
let mut claims = id_token.payload().clone();
let time_options = TimeOptions::new(now);
// Must not have expired.
claims::EXP.extract_required_with_options(&mut claims, &time_options)?;
// `iat` claim must be present.
claims::IAT.extract_required_with_options(&mut claims, time_options)?;
// Subject identifier must be present.
let sub = claims::SUB.extract_required(&mut claims)?;
// No more checks if there is no previous ID token.
let auth_id_token = match auth_id_token {
Some(id_token) => id_token,
None => return Ok(id_token),
};
let mut auth_claims = auth_id_token.payload().clone();
// Subject identifier must always be the same.
let auth_sub = claims::SUB.extract_required(&mut auth_claims)?;
if sub != auth_sub {
return Err(IdTokenError::WrongSubjectIdentifier);
}
// If the authentication time is present, it must be unchanged.
if let Some(auth_time) = claims::AUTH_TIME.extract_optional(&mut claims)? {
let prev_auth_time = claims::AUTH_TIME.extract_required(&mut auth_claims)?;
if prev_auth_time != auth_time {
return Err(IdTokenError::WrongAuthTime);
}
}
Ok(id_token)
}

View File

@ -0,0 +1,26 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Methods to interact with OpenID Connect and OAuth2.0 endpoints.
pub mod authorization_code;
pub mod client_credentials;
pub mod discovery;
pub mod introspection;
pub mod jose;
pub mod refresh_token;
pub mod registration;
pub mod revocation;
pub mod token;
pub mod userinfo;

View File

@ -0,0 +1,128 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Requests for using [Refresh Tokens].
//!
//! [Refresh Tokens]: https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens
use chrono::{DateTime, Utc};
use mas_jose::claims::{self, TokenHash};
use oauth2_types::{
requests::{AccessTokenRequest, AccessTokenResponse, RefreshTokenGrant},
scope::Scope,
};
use rand::Rng;
use url::Url;
use super::jose::JwtVerificationData;
use crate::{
error::{IdTokenError, TokenRefreshError},
http_service::HttpService,
requests::{jose::verify_id_token, token::request_access_token},
types::{client_credentials::ClientCredentials, IdToken},
};
/// Exchange an authorization code for an access token.
///
/// This should be used as the first step for logging in, and to request a
/// token with a new scope.
///
/// # Arguments
///
/// * `http_service` - The service to use for making HTTP requests.
///
/// * `client_credentials` - The credentials obtained when registering the
/// client.
///
/// * `token_endpoint` - The URL of the issuer's Token endpoint.
///
/// * `refresh_token` - The token used to refresh the access token returned at
/// the Token endpoint.
///
/// * `scope` - The scope of the access token. The requested scope must not
/// include any scope not originally granted to the access token, and if
/// omitted is treated as equal to the scope originally granted by the issuer.
///
/// * `id_token_verification_data` - The data required to verify the ID Token in
/// the response.
///
/// The signing algorithm corresponds to the `id_token_signed_response_alg`
/// field in the client metadata.
///
/// If it is not provided, the ID Token won't be verified.
///
/// * `auth_id_token` - If an ID Token is expected in the response, the ID token
/// that was returned from the latest authorization request.
///
/// * `now` - The current time.
///
/// * `rng` - A random number generator.
///
/// # Errors
///
/// Returns an error if the request fails, the response is invalid or the
/// verification of the ID Token fails.
#[allow(clippy::too_many_arguments)]
#[tracing::instrument(skip_all, fields(token_endpoint))]
pub async fn refresh_access_token(
http_service: &HttpService,
client_credentials: ClientCredentials,
token_endpoint: &Url,
refresh_token: String,
scope: Option<Scope>,
id_token_verification_data: Option<JwtVerificationData<'_>>,
auth_id_token: Option<&IdToken<'_>>,
now: DateTime<Utc>,
rng: &mut impl Rng,
) -> Result<(AccessTokenResponse, Option<IdToken<'static>>), TokenRefreshError> {
tracing::debug!("Refreshing access token…");
let token_response = request_access_token(
http_service,
client_credentials,
token_endpoint,
AccessTokenRequest::RefreshToken(RefreshTokenGrant {
refresh_token,
scope,
}),
now,
rng,
)
.await?;
let id_token = if let Some((verification_data, id_token)) =
id_token_verification_data.zip(token_response.id_token.as_ref())
{
let auth_id_token = auth_id_token.ok_or(IdTokenError::MissingAuthIdToken)?;
let signing_alg = verification_data.signing_algorithm;
let id_token = verify_id_token(id_token, verification_data, Some(auth_id_token), now)?;
let mut claims = id_token.payload().clone();
// Access token hash must match.
claims::AT_HASH
.extract_optional_with_options(
&mut claims,
TokenHash::new(signing_alg, &token_response.access_token),
)
.map_err(IdTokenError::from)?;
Some(id_token.into_owned())
} else {
None
};
Ok((token_response, id_token))
}

View File

@ -0,0 +1,82 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Requests for [Dynamic Registration].
//!
//! [Dynamic Registration]: https://openid.net/specs/openid-connect-registration-1_0.html
use mas_http::{CatchHttpCodesLayer, JsonRequestLayer, JsonResponseLayer};
use mas_iana::oauth::OAuthClientAuthenticationMethod;
use oauth2_types::registration::{ClientRegistrationResponse, VerifiedClientMetadata};
use tower::{Layer, Service, ServiceExt};
use url::Url;
use crate::{
error::RegistrationError,
http_service::HttpService,
utils::{http_all_error_status_codes, http_error_mapper},
};
/// Register a client with an OpenID Provider.
///
/// # Arguments
///
/// * `http_service` - The service to use for making HTTP requests.
///
/// * `registration_endpoint` - The URL of the issuer's Registration endpoint.
///
/// * `client_metadata` - The metadata to register with the issuer.
///
/// # Errors
///
/// Returns an error if the request fails or the response is invalid.
#[tracing::instrument(skip_all, fields(registration_endpoint))]
pub async fn register_client(
http_service: &HttpService,
registration_endpoint: &Url,
client_metadata: VerifiedClientMetadata,
) -> Result<ClientRegistrationResponse, RegistrationError> {
tracing::debug!("Registering client...");
let registration_req =
http::Request::post(registration_endpoint.as_str()).body(client_metadata.clone())?;
let service = (
JsonRequestLayer::default(),
JsonResponseLayer::<ClientRegistrationResponse>::default(),
CatchHttpCodesLayer::new(http_all_error_status_codes(), http_error_mapper),
)
.layer(http_service.clone());
let response = service
.ready_oneshot()
.await?
.call(registration_req)
.await?
.into_body();
match client_metadata.token_endpoint_auth_method() {
OAuthClientAuthenticationMethod::ClientSecretPost
| OAuthClientAuthenticationMethod::ClientSecretBasic
| OAuthClientAuthenticationMethod::ClientSecretJwt => {
response
.client_secret
.as_ref()
.ok_or(RegistrationError::MissingClientSecret)?;
}
_ => {}
}
Ok(response)
}

View File

@ -0,0 +1,90 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Requests for [Token Revocation].
//!
//! [Token Revocation]: https://www.rfc-editor.org/rfc/rfc7009.html
use chrono::{DateTime, Utc};
use mas_http::{CatchHttpCodesLayer, FormUrlencodedRequestLayer};
use mas_iana::oauth::OAuthTokenTypeHint;
use oauth2_types::requests::IntrospectionRequest;
use rand::Rng;
use tower::{Layer, Service, ServiceExt};
use url::Url;
use crate::{
error::TokenRevokeError,
http_service::HttpService,
types::client_credentials::ClientCredentials,
utils::{http_all_error_status_codes, http_error_mapper},
};
/// Revoke a token.
///
/// # Arguments
///
/// * `http_service` - The service to use for making HTTP requests.
///
/// * `client_credentials` - The credentials obtained when registering the
/// client.
///
/// * `revocation_endpoint` - The URL of the issuer's Revocation endpoint.
///
/// * `token` - The token to revoke.
///
/// * `token_type_hint` - Hint about the type of the token.
///
/// * `now` - The current time.
///
/// * `rng` - A random number generator.
///
/// # Errors
///
/// Returns an error if the request fails or the response is invalid.
#[tracing::instrument(skip_all, fields(revocation_endpoint))]
pub async fn revoke_token(
http_service: &HttpService,
client_credentials: ClientCredentials,
revocation_endpoint: &Url,
token: String,
token_type_hint: Option<OAuthTokenTypeHint>,
now: DateTime<Utc>,
rng: &mut impl Rng,
) -> Result<(), TokenRevokeError> {
tracing::debug!("Revoking token…");
let request = IntrospectionRequest {
token,
token_type_hint,
};
let revocation_request = http::Request::post(revocation_endpoint.as_str()).body(request)?;
let revocation_request = client_credentials.apply_to_request(revocation_request, now, rng)?;
let service = (
FormUrlencodedRequestLayer::default(),
CatchHttpCodesLayer::new(http_all_error_status_codes(), http_error_mapper),
)
.layer(http_service.clone());
service
.ready_oneshot()
.await?
.call(revocation_request)
.await?;
Ok(())
}

View File

@ -0,0 +1,78 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Requests for the Token endpoint.
use chrono::{DateTime, Utc};
use mas_http::{CatchHttpCodesLayer, FormUrlencodedRequestLayer, JsonResponseLayer};
use oauth2_types::requests::{AccessTokenRequest, AccessTokenResponse};
use rand::Rng;
use tower::{Layer, Service, ServiceExt};
use url::Url;
use crate::{
error::TokenRequestError,
http_service::HttpService,
types::client_credentials::ClientCredentials,
utils::{http_all_error_status_codes, http_error_mapper},
};
/// Request an access token.
///
/// # Arguments
///
/// * `http_service` - The service to use for making HTTP requests.
///
/// * `client_credentials` - The credentials obtained when registering the
/// client.
///
/// * `token_endpoint` - The URL of the issuer's Token endpoint.
///
/// * `request` - The request to make at the Token endpoint.
///
/// * `now` - The current time.
///
/// * `rng` - A random number generator.
///
/// # Errors
///
/// Returns an error if the request fails or the response is invalid.
#[tracing::instrument(skip_all, fields(token_endpoint, request))]
pub async fn request_access_token(
http_service: &HttpService,
client_credentials: ClientCredentials,
token_endpoint: &Url,
request: AccessTokenRequest,
now: DateTime<Utc>,
rng: &mut impl Rng,
) -> Result<AccessTokenResponse, TokenRequestError> {
tracing::debug!(?request, "Requesting access token...");
let token_request = http::Request::post(token_endpoint.as_str()).body(request)?;
let token_request = client_credentials.apply_to_request(token_request, now, rng)?;
let service = (
FormUrlencodedRequestLayer::default(),
JsonResponseLayer::<AccessTokenResponse>::default(),
CatchHttpCodesLayer::new(http_all_error_status_codes(), http_error_mapper),
)
.layer(http_service.clone());
let res = service.ready_oneshot().await?.call(token_request).await?;
let token_response = res.into_body();
Ok(token_response)
}

View File

@ -0,0 +1,139 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Requests for obtaining [Claims] about an end-user.
//!
//! [Claims]: https://openid.net/specs/openid-connect-core-1_0.html#Claims
use std::collections::HashMap;
use bytes::Bytes;
use headers::{Authorization, HeaderMapExt, HeaderValue};
use http::header::{ACCEPT, CONTENT_TYPE};
use mas_http::CatchHttpCodesLayer;
use mas_jose::claims;
use serde_json::Value;
use tower::{Layer, Service, ServiceExt};
use url::Url;
use super::jose::JwtVerificationData;
use crate::{
error::{IdTokenError, UserInfoError},
http_service::HttpService,
requests::jose::verify_signed_jwt,
types::IdToken,
utils::{http_all_error_status_codes, http_error_mapper},
};
/// Obtain information about an authenticated end-user.
///
/// Returns a map of claims with their value, that should be extracted with
/// one of the [`Claim`] methods.
///
/// # Arguments
///
/// * `http_service` - The service to use for making HTTP requests.
///
/// * `userinfo_endpoint` - The URL of the issuer's User Info endpoint.
///
/// * `access_token` - The access token of the end-user.
///
/// * `jwt_verification_data` - The data required to verify the response if a
/// signed response was requested during client registration.
///
/// The signing algorithm corresponds to the `userinfo_signed_response_alg`
/// field in the client metadata.
///
/// * `auth_id_token` - The ID token that was returned from the latest
/// authorization request.
///
/// # Errors
///
/// Returns an error if the request fails, the response is invalid or the
/// validation of the signed response fails.
///
/// [`Claim`]: mas_jose::claims::Claim
#[tracing::instrument(skip_all, fields(userinfo_endpoint))]
pub async fn fetch_userinfo(
http_service: &HttpService,
userinfo_endpoint: &Url,
access_token: &str,
jwt_verification_data: Option<JwtVerificationData<'_>>,
auth_id_token: &IdToken<'_>,
) -> Result<HashMap<String, Value>, UserInfoError> {
tracing::debug!("Obtaining user info…");
let mut userinfo_request = http::Request::get(userinfo_endpoint.as_str());
let expected_content_type = if jwt_verification_data.is_some() {
"application/jwt"
} else {
mime::APPLICATION_JSON.as_ref()
};
if let Some(headers) = userinfo_request.headers_mut() {
headers.typed_insert(Authorization::bearer(access_token)?);
headers.insert(ACCEPT, HeaderValue::from_static(expected_content_type));
}
let userinfo_request = userinfo_request.body(Bytes::new())?;
let service = CatchHttpCodesLayer::new(http_all_error_status_codes(), http_error_mapper)
.layer(http_service.clone());
let userinfo_response = service
.ready_oneshot()
.await?
.call(userinfo_request)
.await?;
let content_type = userinfo_response
.headers()
.get(CONTENT_TYPE)
.ok_or(UserInfoError::MissingResponseContentType)?
.to_str()?;
if content_type != expected_content_type {
return Err(UserInfoError::InvalidResponseContentType {
expected: expected_content_type.to_owned(),
got: content_type.to_owned(),
});
}
let response_body = std::str::from_utf8(userinfo_response.body())?;
let mut claims = if let Some(verification_data) = jwt_verification_data {
verify_signed_jwt(response_body, verification_data)
.map_err(IdTokenError::from)?
.into_parts()
.1
} else {
serde_json::from_str(response_body)?
};
let mut auth_claims = auth_id_token.payload().clone();
// Subject identifier must always be the same.
let sub = claims::SUB
.extract_required(&mut claims)
.map_err(IdTokenError::from)?;
let auth_sub = claims::SUB
.extract_required(&mut auth_claims)
.map_err(IdTokenError::from)?;
if sub != auth_sub {
return Err(IdTokenError::WrongSubjectIdentifier.into());
}
Ok(claims)
}

View File

@ -0,0 +1,669 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Types and methods for client credentials.
use std::{collections::HashMap, fmt};
use base64ct::{Base64UrlUnpadded, Encoding};
use chrono::{DateTime, Duration, Utc};
use headers::{Authorization, HeaderMapExt};
use http::Request;
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use mas_jose::{
claims::{self, ClaimError},
jwa::SymmetricKey,
jwt::{JsonWebSignatureHeader, Jwt},
};
#[cfg(feature = "keystore")]
use mas_keystore::Keystore;
use rand::Rng;
use serde::Serialize;
use serde_json::Value;
use serde_with::skip_serializing_none;
use tower::BoxError;
use url::Url;
use crate::error::CredentialsError;
/// The supported authentication methods of this library.
///
/// During client registration, make sure that you only use one of the values
/// defined here.
pub const CLIENT_SUPPORTED_AUTH_METHODS: &[OAuthClientAuthenticationMethod] = &[
OAuthClientAuthenticationMethod::None,
OAuthClientAuthenticationMethod::ClientSecretBasic,
OAuthClientAuthenticationMethod::ClientSecretPost,
OAuthClientAuthenticationMethod::ClientSecretJwt,
OAuthClientAuthenticationMethod::PrivateKeyJwt,
];
/// A function that takes a map of claims and a signing algorithm and returns a
/// signed JWT.
pub type JwtSigningFn =
dyn Fn(HashMap<String, Value>, JsonWebSignatureAlg) -> Result<String, BoxError> + Send + Sync;
/// The method used to sign JWTs with a private key.
pub enum JwtSigningMethod {
/// Sign the JWTs with this library, by providing the signing keys.
#[cfg(feature = "keystore")]
Keystore(Keystore),
/// Sign the JWTs in a callback.
Custom(Box<JwtSigningFn>),
}
impl JwtSigningMethod {
/// Creates a new [`JwtSigningMethod`] from a [`Keystore`].
#[cfg(feature = "keystore")]
#[must_use]
pub fn with_keystore(keystore: Keystore) -> Self {
Self::Keystore(keystore)
}
/// Creates a new [`JwtSigningMethod`] from a [`JwtSigningFn`].
#[must_use]
pub fn with_custom_signing_method<F>(signing_fn: F) -> Self
where
F: Fn(HashMap<String, Value>, JsonWebSignatureAlg) -> Result<String, BoxError>
+ Send
+ Sync
+ 'static,
{
Self::Custom(Box::new(signing_fn))
}
/// Get the [`Keystore`] from this [`JwtSigningMethod`].
#[cfg(feature = "keystore")]
#[must_use]
pub fn keystore(&self) -> Option<&Keystore> {
match self {
JwtSigningMethod::Keystore(k) => Some(k),
JwtSigningMethod::Custom(_) => None,
}
}
/// Get the [`JwtSigningFn`] from this [`JwtSigningMethod`].
#[must_use]
pub fn jwt_custom(&self) -> Option<&JwtSigningFn> {
match self {
JwtSigningMethod::Custom(s) => Some(s),
JwtSigningMethod::Keystore(_) => None,
}
}
}
/// The credentials obtained during registration, to authenticate a client on
/// endpoints that require it.
pub enum ClientCredentials {
/// No client authentication is used.
///
/// This is used if the client is public.
None {
/// The unique ID for the client.
client_id: String,
},
/// The client authentication is sent via the Authorization HTTP header.
ClientSecretBasic {
/// The unique ID for the client.
client_id: String,
/// The secret of the client.
client_secret: String,
},
/// The client authentication is sent with the body of the request.
ClientSecretPost {
/// The unique ID for the client.
client_id: String,
/// The secret of the client.
client_secret: String,
},
/// The client authentication uses a JWT signed with a key derived from the
/// client secret.
ClientSecretJwt {
/// The unique ID for the client.
client_id: String,
/// The secret of the client.
client_secret: String,
/// The algorithm used to sign the JWT.
signing_algorithm: JsonWebSignatureAlg,
/// The URL of the issuer's Token endpoint.
token_endpoint: Url,
},
/// The client authentication uses a JWT signed with a private key.
PrivateKeyJwt {
/// The unique ID for the client.
client_id: String,
/// The method used to sign the JWT.
jwt_signing_method: JwtSigningMethod,
/// The algorithm used to sign the JWT.
signing_algorithm: JsonWebSignatureAlg,
/// The URL of the issuer's Token endpoint.
token_endpoint: Url,
},
}
impl ClientCredentials {
/// Get the client ID of these `ClientCredentials`.
#[must_use]
pub fn client_id(&self) -> &str {
match self {
ClientCredentials::None { client_id }
| ClientCredentials::ClientSecretBasic { client_id, .. }
| ClientCredentials::ClientSecretPost { client_id, .. }
| ClientCredentials::ClientSecretJwt { client_id, .. }
| ClientCredentials::PrivateKeyJwt { client_id, .. } => client_id,
}
}
/// Apply these `ClientCredentials` to the given request.
pub(crate) fn apply_to_request<T: Serialize>(
self,
request: Request<T>,
now: DateTime<Utc>,
rng: &mut impl Rng,
) -> Result<Request<RequestWithClientCredentials<T>>, CredentialsError> {
let credentials = RequestClientCredentials::try_from_credentials(self, now, rng)?;
let (parts, body) = request.into_parts();
let mut body = RequestWithClientCredentials {
body,
credentials: None,
};
let request = match credentials {
RequestClientCredentials::Body(credentials) => {
body.credentials = Some(credentials);
Request::from_parts(parts, body)
}
RequestClientCredentials::Header(credentials) => {
let HeaderClientCredentials {
client_id,
client_secret,
} = credentials;
let mut request = Request::from_parts(parts, body);
// Encode the values with `application/x-www-form-urlencoded`.
let client_id =
form_urlencoded::byte_serialize(client_id.as_bytes()).collect::<String>();
let client_secret =
form_urlencoded::byte_serialize(client_secret.as_bytes()).collect::<String>();
let auth = Authorization::basic(&client_id, &client_secret);
request.headers_mut().typed_insert(auth);
request
}
};
Ok(request)
}
}
impl fmt::Debug for ClientCredentials {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::None { client_id } => f
.debug_struct("None")
.field("client_id", client_id)
.finish(),
Self::ClientSecretBasic { client_id, .. } => f
.debug_struct("ClientSecretBasic")
.field("client_id", client_id)
.finish_non_exhaustive(),
Self::ClientSecretPost { client_id, .. } => f
.debug_struct("ClientSecretPost")
.field("client_id", client_id)
.finish_non_exhaustive(),
Self::ClientSecretJwt {
client_id,
signing_algorithm,
token_endpoint,
..
} => f
.debug_struct("ClientSecretJwt")
.field("client_id", client_id)
.field("signing_algorithm", signing_algorithm)
.field("token_endpoint", token_endpoint)
.finish_non_exhaustive(),
Self::PrivateKeyJwt {
client_id,
signing_algorithm,
token_endpoint,
..
} => f
.debug_struct("PrivateKeyJwt")
.field("client_id", client_id)
.field("signing_algorithm", signing_algorithm)
.field("token_endpoint", token_endpoint)
.finish_non_exhaustive(),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")]
pub(crate) struct JwtBearerClientAssertionType;
enum RequestClientCredentials {
Body(BodyClientCredentials),
Header(HeaderClientCredentials),
}
impl RequestClientCredentials {
fn try_from_credentials(
credentials: ClientCredentials,
now: DateTime<Utc>,
rng: &mut impl Rng,
) -> Result<Self, CredentialsError> {
let res = match credentials {
ClientCredentials::None { client_id } => Self::Body(BodyClientCredentials {
client_id,
client_secret: None,
client_assertion: None,
client_assertion_type: None,
}),
ClientCredentials::ClientSecretBasic {
client_id,
client_secret,
} => Self::Header(HeaderClientCredentials {
client_id,
client_secret,
}),
ClientCredentials::ClientSecretPost {
client_id,
client_secret,
} => Self::Body(BodyClientCredentials {
client_id,
client_secret: Some(client_secret),
client_assertion: None,
client_assertion_type: None,
}),
ClientCredentials::ClientSecretJwt {
client_id,
client_secret,
signing_algorithm,
token_endpoint,
} => {
let claims =
prepare_claims(client_id.clone(), token_endpoint.to_string(), now, rng)?;
let key = SymmetricKey::new_for_alg(client_secret.into(), &signing_algorithm)?;
let header = JsonWebSignatureHeader::new(signing_algorithm);
let jwt = Jwt::sign(header, claims, &key)?;
Self::Body(BodyClientCredentials {
client_id,
client_secret: None,
client_assertion: Some(jwt.to_string()),
client_assertion_type: Some(JwtBearerClientAssertionType),
})
}
ClientCredentials::PrivateKeyJwt {
client_id,
jwt_signing_method,
signing_algorithm,
token_endpoint,
} => {
let claims =
prepare_claims(client_id.clone(), token_endpoint.to_string(), now, rng)?;
let client_assertion = match jwt_signing_method {
#[cfg(feature = "keystore")]
JwtSigningMethod::Keystore(keystore) => {
let key = keystore
.signing_key_for_algorithm(&signing_algorithm)
.ok_or(CredentialsError::NoPrivateKeyFound)?;
let signer = key.params().signing_key_for_alg(&signing_algorithm)?;
let header = JsonWebSignatureHeader::new(signing_algorithm);
Jwt::sign(header, claims, &signer)?.to_string()
}
JwtSigningMethod::Custom(jwt_signing_fn) => {
jwt_signing_fn(claims, signing_algorithm)
.map_err(CredentialsError::Custom)?
}
};
Self::Body(BodyClientCredentials {
client_id,
client_secret: None,
client_assertion: Some(client_assertion),
client_assertion_type: Some(JwtBearerClientAssertionType),
})
}
};
Ok(res)
}
}
#[skip_serializing_none]
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub(crate) struct BodyClientCredentials {
client_id: String,
client_secret: Option<String>,
client_assertion: Option<String>,
client_assertion_type: Option<JwtBearerClientAssertionType>,
}
#[derive(Debug, Clone)]
struct HeaderClientCredentials {
client_id: String,
client_secret: String,
}
fn prepare_claims(
iss: String,
aud: String,
now: DateTime<Utc>,
rng: &mut impl Rng,
) -> Result<HashMap<String, Value>, ClaimError> {
let mut claims = HashMap::new();
claims::ISS.insert(&mut claims, iss.clone())?;
claims::SUB.insert(&mut claims, iss)?;
claims::AUD.insert(&mut claims, aud)?;
claims::IAT.insert(&mut claims, now)?;
claims::EXP.insert(&mut claims, now + Duration::minutes(5))?;
let mut jti = [0u8; 16];
rng.fill(&mut jti);
let jti = Base64UrlUnpadded::encode_string(&jti);
claims::JTI.insert(&mut claims, jti)?;
Ok(claims)
}
/// A request with client credentials added to it.
#[derive(Clone, Serialize)]
#[skip_serializing_none]
pub struct RequestWithClientCredentials<T: Serialize> {
#[serde(flatten)]
pub(crate) body: T,
#[serde(flatten)]
pub(crate) credentials: Option<BodyClientCredentials>,
}
#[cfg(test)]
mod test {
use assert_matches::assert_matches;
use headers::authorization::Basic;
#[cfg(feature = "keystore")]
use mas_keystore::{JsonWebKey, JsonWebKeySet, Keystore, PrivateKey};
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use super::*;
const CLIENT_ID: &str = "abcd$++";
const CLIENT_SECRET: &str = "xyz!;?";
const REQUEST_BODY: &str = "some_body";
#[derive(Serialize)]
struct Body {
body: &'static str,
}
fn now() -> DateTime<Utc> {
#[allow(clippy::disallowed_methods)]
Utc::now()
}
#[test]
fn serialize_credentials() {
assert_eq!(
serde_urlencoded::to_string(BodyClientCredentials {
client_id: CLIENT_ID.to_owned(),
client_secret: None,
client_assertion: None,
client_assertion_type: None,
})
.unwrap(),
"client_id=abcd%24%2B%2B"
);
assert_eq!(
serde_urlencoded::to_string(BodyClientCredentials {
client_id: CLIENT_ID.to_owned(),
client_secret: Some(CLIENT_SECRET.to_owned()),
client_assertion: None,
client_assertion_type: None,
})
.unwrap(),
"client_id=abcd%24%2B%2B&client_secret=xyz%21%3B%3F"
);
assert_eq!(
serde_urlencoded::to_string(BodyClientCredentials {
client_id: CLIENT_ID.to_owned(),
client_secret: None,
client_assertion: Some(CLIENT_SECRET.to_owned()),
client_assertion_type: Some(JwtBearerClientAssertionType)
})
.unwrap(),
"client_id=abcd%24%2B%2B&client_assertion=xyz%21%3B%3F&client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer"
);
}
#[test]
fn serialize_request_with_credentials() {
let req = RequestWithClientCredentials {
body: Body { body: REQUEST_BODY },
credentials: None,
};
assert_eq!(serde_urlencoded::to_string(req).unwrap(), "body=some_body");
let req = RequestWithClientCredentials {
body: Body { body: REQUEST_BODY },
credentials: Some(BodyClientCredentials {
client_id: CLIENT_ID.to_owned(),
client_secret: None,
client_assertion: None,
client_assertion_type: None,
}),
};
assert_eq!(
serde_urlencoded::to_string(req).unwrap(),
"body=some_body&client_id=abcd%24%2B%2B"
);
let req = RequestWithClientCredentials {
body: Body { body: REQUEST_BODY },
credentials: Some(BodyClientCredentials {
client_id: CLIENT_ID.to_owned(),
client_secret: Some(CLIENT_SECRET.to_owned()),
client_assertion: None,
client_assertion_type: None,
}),
};
assert_eq!(
serde_urlencoded::to_string(req).unwrap(),
"body=some_body&client_id=abcd%24%2B%2B&client_secret=xyz%21%3B%3F"
);
let req = RequestWithClientCredentials {
body: Body { body: REQUEST_BODY },
credentials: Some(BodyClientCredentials {
client_id: CLIENT_ID.to_owned(),
client_secret: None,
client_assertion: Some(CLIENT_SECRET.to_owned()),
client_assertion_type: Some(JwtBearerClientAssertionType),
}),
};
assert_eq!(
serde_urlencoded::to_string(req).unwrap(),
"body=some_body&client_id=abcd%24%2B%2B&client_assertion=xyz%21%3B%3F&client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer"
);
}
#[tokio::test]
async fn build_request_none() {
let credentials = ClientCredentials::None {
client_id: CLIENT_ID.to_owned(),
};
let request = Request::new(Body { body: REQUEST_BODY });
let now = now();
let mut rng = ChaCha8Rng::seed_from_u64(42);
let request = credentials
.apply_to_request(request, now, &mut rng)
.unwrap();
assert_eq!(request.headers().typed_get::<Authorization<Basic>>(), None);
let body = request.into_body();
assert_eq!(body.body.body, REQUEST_BODY);
let credentials = body.credentials.unwrap();
assert_eq!(credentials.client_id, CLIENT_ID);
assert_eq!(credentials.client_secret, None);
assert_eq!(credentials.client_assertion, None);
assert_eq!(credentials.client_assertion_type, None);
}
#[tokio::test]
async fn build_request_client_secret_basic() {
let credentials = ClientCredentials::ClientSecretBasic {
client_id: CLIENT_ID.to_owned(),
client_secret: CLIENT_SECRET.to_owned(),
};
let now = now();
let mut rng = ChaCha8Rng::seed_from_u64(42);
let request = Request::new(Body { body: REQUEST_BODY });
let request = credentials
.apply_to_request(request, now, &mut rng)
.unwrap();
let auth = assert_matches!(
request.headers().typed_get::<Authorization<Basic>>(),
Some(auth) => auth
);
assert_eq!(
form_urlencoded::parse(auth.username().as_bytes())
.next()
.unwrap()
.0,
CLIENT_ID
);
assert_eq!(
form_urlencoded::parse(auth.password().as_bytes())
.next()
.unwrap()
.0,
CLIENT_SECRET
);
let body = request.into_body();
assert_eq!(body.body.body, REQUEST_BODY);
assert_eq!(body.credentials, None);
}
#[tokio::test]
async fn build_request_client_secret_post() {
let credentials = ClientCredentials::ClientSecretPost {
client_id: CLIENT_ID.to_owned(),
client_secret: CLIENT_SECRET.to_owned(),
};
let now = now();
let mut rng = ChaCha8Rng::seed_from_u64(42);
let request = Request::new(Body { body: REQUEST_BODY });
let request = credentials
.apply_to_request(request, now, &mut rng)
.unwrap();
assert_eq!(request.headers().typed_get::<Authorization<Basic>>(), None);
let body = request.into_body();
assert_eq!(body.body.body, REQUEST_BODY);
let credentials = body.credentials.unwrap();
assert_eq!(credentials.client_id, CLIENT_ID);
assert_eq!(credentials.client_secret.unwrap(), CLIENT_SECRET);
assert_eq!(credentials.client_assertion, None);
assert_eq!(credentials.client_assertion_type, None);
}
#[tokio::test]
async fn build_request_client_secret_jwt() {
let credentials = ClientCredentials::ClientSecretJwt {
client_id: CLIENT_ID.to_owned(),
client_secret: CLIENT_SECRET.to_owned(),
signing_algorithm: JsonWebSignatureAlg::Hs256,
token_endpoint: Url::parse("http://localhost").unwrap(),
};
let now = now();
let mut rng = ChaCha8Rng::seed_from_u64(42);
let request = Request::new(Body { body: REQUEST_BODY });
let request = credentials
.apply_to_request(request, now, &mut rng)
.unwrap();
assert_eq!(request.headers().typed_get::<Authorization<Basic>>(), None);
let body = request.into_body();
assert_eq!(body.body.body, REQUEST_BODY);
let credentials = body.credentials.unwrap();
assert_eq!(credentials.client_id, CLIENT_ID);
assert_eq!(credentials.client_secret, None);
credentials.client_assertion.unwrap();
credentials.client_assertion_type.unwrap();
}
#[tokio::test]
#[cfg(feature = "keystore")]
async fn build_request_private_key_jwt() {
let rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
let key = PrivateKey::generate_rsa(rng).unwrap();
let keystore = Keystore::new(JsonWebKeySet::<PrivateKey>::new(vec![JsonWebKey::new(key)]));
let jwt_signing_method = JwtSigningMethod::with_keystore(keystore);
let now = now();
let mut rng = ChaCha8Rng::seed_from_u64(42);
let credentials = ClientCredentials::PrivateKeyJwt {
client_id: CLIENT_ID.to_owned(),
jwt_signing_method,
signing_algorithm: JsonWebSignatureAlg::Rs256,
token_endpoint: Url::parse("http://localhost").unwrap(),
};
let request = Request::new(Body { body: REQUEST_BODY });
let request = credentials
.apply_to_request(request, now, &mut rng)
.unwrap();
assert_eq!(request.headers().typed_get::<Authorization<Basic>>(), None);
let body = request.into_body();
assert_eq!(body.body.body, REQUEST_BODY);
let credentials = body.credentials.unwrap();
assert_eq!(credentials.client_id, CLIENT_ID);
assert_eq!(credentials.client_secret, None);
credentials.client_assertion.unwrap();
credentials.client_assertion_type.unwrap();
}
}

View File

@ -0,0 +1,31 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! OAuth 2.0 and OpenID Connect types.
pub mod client_credentials;
pub mod scope;
use std::collections::HashMap;
#[doc(inline)]
pub use mas_iana as iana;
use mas_jose::jwt::Jwt;
pub use oauth2_types::*;
use serde_json::Value;
/// An OpenID Connect [ID Token].
///
/// [ID Token]: https://openid.net/specs/openid-connect-core-1_0.html#IDToken
pub type IdToken<'a> = Jwt<'a, HashMap<String, Value>>;

View File

@ -0,0 +1,226 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Helpers types to use scopes.
use std::{fmt, str::FromStr};
use oauth2_types::scope::{InvalidScope, Scope, ScopeToken as StrScopeToken};
use crate::PrivString;
/// Tokens to define the scope of an access token or to request specific claims.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ScopeToken {
/// `openid`
///
/// Required for OpenID Connect requests.
Openid,
/// `profile`
///
/// Requests access to the end-user's profile.
Profile,
/// `email`
///
/// Requests access to the end-user's email address.
Email,
/// `address`
///
/// Requests access to the end-user's address.
Address,
/// `phone`
///
/// Requests access to the end-user's phone number.
Phone,
/// `offline_access`
///
/// Requests that an OAuth 2.0 refresh token be issued that can be used to
/// obtain an access token that grants access to the end-user's UserInfo
/// Endpoint even when the end-user is not present (not logged in).
OfflineAccess,
/// `urn:matrix:org.matrix.msc2967.client:api:*`
///
/// Requests access to the Matrix Client API.
MatrixApi,
/// `urn:matrix:org.matrix.msc2967.client:device:{device_id}`
///
/// Requests access to the Matrix device with the given `device_id`.
///
/// To access the device ID, use [`ScopeToken::matrix_device_id`].
MatrixDevice(PrivString),
/// Another scope token.
///
/// To access it's value use this type's `Display` implementation.
Custom(PrivString),
}
impl ScopeToken {
/// Creates a Matrix device scope token with the given device ID.
///
/// # Errors
///
/// Returns an error if the device ID string is not compatible with the
/// scope syntax.
pub fn try_with_matrix_device(device_id: String) -> Result<Self, InvalidScope> {
// Check that the device ID is compatible with the scope format.
StrScopeToken::from_str(&device_id)?;
Ok(Self::MatrixDevice(PrivString(device_id)))
}
/// Get the device ID of this scope token, if it is a
/// [`ScopeToken::MatrixDevice`].
#[must_use]
pub fn matrix_device_id(&self) -> Option<&str> {
match &self {
Self::MatrixDevice(id) => Some(&id.0),
_ => None,
}
}
}
impl fmt::Display for ScopeToken {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ScopeToken::Openid => write!(f, "openid"),
ScopeToken::Profile => write!(f, "profile"),
ScopeToken::Email => write!(f, "email"),
ScopeToken::Address => write!(f, "address"),
ScopeToken::Phone => write!(f, "phone"),
ScopeToken::OfflineAccess => write!(f, "offline_access"),
ScopeToken::MatrixApi => write!(f, "urn:matrix:org.matrix.msc2967.client:api:*"),
ScopeToken::MatrixDevice(s) => {
write!(f, "urn:matrix:org.matrix.msc2967.client:device:{}", s.0)
}
ScopeToken::Custom(s) => f.write_str(&s.0),
}
}
}
impl From<StrScopeToken> for ScopeToken {
fn from(t: StrScopeToken) -> Self {
match &*t {
"openid" => Self::Openid,
"profile" => Self::Profile,
"email" => Self::Email,
"address" => Self::Address,
"phone" => Self::Phone,
"offline_access" => Self::OfflineAccess,
"urn:matrix:org.matrix.msc2967.client:api:*" => Self::MatrixApi,
s => {
if let Some(device_id) =
s.strip_prefix("urn:matrix:org.matrix.msc2967.client:device:")
{
Self::MatrixDevice(PrivString(device_id.to_owned()))
} else {
Self::Custom(PrivString(s.to_owned()))
}
}
}
}
}
impl From<ScopeToken> for StrScopeToken {
fn from(t: ScopeToken) -> Self {
let s = t.to_string();
match StrScopeToken::from_str(&s) {
Ok(t) => t,
Err(_) => unreachable!(),
}
}
}
impl FromStr for ScopeToken {
type Err = InvalidScope;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let t = StrScopeToken::from_str(s)?;
Ok(t.into())
}
}
/// Helpers for [`Scope`] to work with [`ScopeToken`].
pub trait ScopeExt {
/// Insert the given `ScopeToken` into this `Scope`.
fn insert_token(&mut self, token: ScopeToken) -> bool;
/// Whether this `Scope` contains the given `ScopeToken`.
fn contains_token(&self, token: &ScopeToken) -> bool;
}
impl ScopeExt for Scope {
fn insert_token(&mut self, token: ScopeToken) -> bool {
self.insert(token.into())
}
fn contains_token(&self, token: &ScopeToken) -> bool {
self.contains(&token.to_string())
}
}
impl FromIterator<ScopeToken> for Scope {
fn from_iter<T: IntoIterator<Item = ScopeToken>>(iter: T) -> Self {
iter.into_iter().map(Into::<StrScopeToken>::into).collect()
}
}
#[cfg(test)]
mod tests {
use assert_matches::assert_matches;
use super::*;
#[test]
fn parse_scope_token() {
assert_eq!(ScopeToken::from_str("openid"), Ok(ScopeToken::Openid));
let scope =
ScopeToken::from_str("urn:matrix:org.matrix.msc2967.client:device:ABCDEFGHIJKL")
.unwrap();
assert_matches!(scope, ScopeToken::MatrixDevice(_));
assert_eq!(scope.matrix_device_id(), Some("ABCDEFGHIJKL"));
assert_eq!(ScopeToken::from_str("invalid\\scope"), Err(InvalidScope));
}
#[test]
fn parse_scope() {
let scope = Scope::from_str("openid profile address").unwrap();
assert_eq!(scope.len(), 3);
assert!(scope.contains_token(&ScopeToken::Openid));
assert!(scope.contains_token(&ScopeToken::Profile));
assert!(scope.contains_token(&ScopeToken::Address));
assert!(!scope.contains_token(&ScopeToken::OfflineAccess));
}
#[test]
fn display_scope() {
let mut scope: Scope = [ScopeToken::Profile].into_iter().collect();
assert_eq!(scope.to_string(), "profile");
scope.insert_token(ScopeToken::MatrixApi);
assert_eq!(
scope.to_string(),
"profile urn:matrix:org.matrix.msc2967.client:api:*"
);
}
}

View File

@ -0,0 +1,41 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::ops::RangeBounds;
use bytes::Buf;
use http::{Response, StatusCode};
use crate::error::ErrorBody;
pub fn http_error_mapper<T>(response: Response<T>) -> Option<ErrorBody>
where
T: Buf,
{
let body = response.into_body();
serde_json::from_reader(body.reader()).ok()
}
pub fn http_all_error_status_codes() -> impl RangeBounds<StatusCode> {
let client_errors_start_code = match StatusCode::from_u16(400) {
Ok(code) => code,
Err(_) => unreachable!(),
};
let server_errors_end_code = match StatusCode::from_u16(599) {
Ok(code) => code,
Err(_) => unreachable!(),
};
client_errors_start_code..=server_errors_end_code
}