mirror of
https://github.com/quay/quay.git
synced 2026-01-27 18:42:52 +03:00
145 lines
4.9 KiB
Python
145 lines
4.9 KiB
Python
import re
|
|
|
|
from calendar import timegm
|
|
from datetime import datetime, timedelta
|
|
from jwt import PyJWT
|
|
from jwt.exceptions import (
|
|
InvalidTokenError,
|
|
DecodeError,
|
|
InvalidAudienceError,
|
|
ExpiredSignatureError,
|
|
ImmatureSignatureError,
|
|
InvalidIssuedAtError,
|
|
InvalidIssuerError,
|
|
MissingRequiredClaimError,
|
|
InvalidAlgorithmError,
|
|
)
|
|
|
|
from authlib.jose import JsonWebKey, RSAKey, ECKey
|
|
from cryptography.hazmat.backends import default_backend
|
|
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicNumbers
|
|
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers
|
|
|
|
|
|
# Matches a `Bearer <token>` string carrying a JWT. Group 1 captures the token:
# two or more dot-separated segments drawn from the base64/base64url alphabet
# (padding `=` is not accepted).
TOKEN_REGEX = re.compile(r"\ABearer (([a-zA-Z0-9+\-_/]+\.)+[a-zA-Z0-9+\-_/]+)\Z")

# Lower-cased signing algorithms that `decode` is willing to accept; any other
# requested algorithm is rejected before the token is parsed.
# DO NOT ADD `none` here -- it would allow unsigned tokens!
ALGORITHM_WHITELIST = [
    "rs256",
    "hs256",
    "rs384",
]
|
|
|
|
|
|
class _StrictJWT(PyJWT):
    """
    _StrictJWT defines a JWT decoder with extra checks.

    Beyond PyJWT's own claim validation it:

    * requires the ``nbf``, ``iat`` and ``exp`` claims by default;
    * rejects tokens whose ``iat`` lies in the future (beyond ``leeway``);
    * when the ``exp_max_s`` option is set, rejects tokens whose ``exp`` is more
      than ``exp_max_s`` seconds after ``iat`` (or after now if there is no ``iat``).
    """

    @staticmethod
    def _get_default_options():
        """
        Return PyJWT's default options extended with the strict defaults.

        ``require`` lists the claims that must be present; ``exp_max_s`` caps the
        token lifetime in seconds (``None`` disables the cap).
        """
        # Weird syntax to call super on a staticmethod: with no implicit
        # class/instance available, both arguments must be given explicitly.
        defaults = super(_StrictJWT, _StrictJWT)._get_default_options()
        defaults.update(
            {
                "require": ["nbf", "iat", "exp"],
                "exp_max_s": None,
            }
        )
        return defaults

    def _validate_claims(self, payload, options, audience=None, issuer=None, leeway=0, **kwargs):
        """
        Validate ``payload`` with PyJWT's checks plus the strict checks of this class.

        When ``exp_max_s`` is set, ``options`` is updated in place so that
        ``verify_exp`` is forced on.

        Raises (in addition to whatever the parent validation raises):
            ValueError: if ``exp_max_s`` is set but the legacy ``verify_expiration``
                keyword argument disables expiration checking.
            DecodeError: if ``iat`` cannot be converted to an integer.
            InvalidIssuedAtError: if ``iat`` is later than now + ``leeway``.
            InvalidTokenError: if ``exp`` is more than ``exp_max_s`` seconds after the
                start time (``iat``, or the current time when ``iat`` is absent).
        """
        max_signed_s = options.get("exp_max_s")
        if max_signed_s is not None:
            if "verify_expiration" in kwargs and not kwargs.get("verify_expiration"):
                raise ValueError("exp_max_s option implies verify_expiration")

            # The lifetime check below relies on the parent class having verified
            # (and type-checked) exp, so expiration checking is forced on.
            options["verify_exp"] = True

        # Do all of the other checks
        super(_StrictJWT, self)._validate_claims(
            payload, options, audience, issuer, leeway, **kwargs
        )

        now = timegm(datetime.utcnow().utctimetuple())
        self._reject_future_iat(payload, now, leeway)

        if "exp" in payload and max_signed_s is not None:
            # Validate that the expiration was not more than exp_max_s seconds after the
            # issue time or, in the absence of an issue time, more than exp_max_s in the
            # future from now.

            # This will work because the parent method already checked the type of exp
            expiration = datetime.utcfromtimestamp(int(payload["exp"]))

            start_time = datetime.utcnow()
            if "iat" in payload:
                start_time = datetime.utcfromtimestamp(int(payload["iat"]))

            if expiration > start_time + timedelta(seconds=max_signed_s):
                # Interpolate here: passing the values as extra exception arguments
                # would leave the %s placeholders unfilled in the message.
                raise InvalidTokenError(
                    "Token was signed for more than %s seconds from %s"
                    % (max_signed_s, start_time)
                )

    def _reject_future_iat(self, payload, now, leeway):
        """
        Raise if the token's ``iat`` (issued-at) claim lies in the future.

        Args:
            payload: the decoded claims.
            now: the current time as integer seconds since the epoch (UTC).
            leeway: allowed clock skew, in seconds or as a ``timedelta``.

        Raises:
            DecodeError: if ``iat`` cannot be converted to an integer.
            InvalidIssuedAtError: if ``iat`` is later than ``now + leeway``.
        """
        # PyJWT accepts leeway as a timedelta as well; normalize it so the
        # comparison below does not fail with a TypeError.
        if isinstance(leeway, timedelta):
            leeway = leeway.total_seconds()

        # iat can only be absent here when the caller removed it from the
        # `require` option (NOTE(review): relies on PyJWT enforcing `require`
        # in the parent validation); there is then nothing to check.
        if "iat" not in payload:
            return

        try:
            iat = int(payload["iat"])
        except (TypeError, ValueError):
            # TypeError covers non-numeric JSON values such as lists or objects.
            raise DecodeError("Issued At claim (iat) must be an integer.")

        if iat > (now + leeway):
            raise InvalidIssuedAtError("Issued At claim (iat) cannot be in the future.")
|
|
|
|
|
|
def decode(jwt, key="", verify=True, algorithms=None, options=None, **kwargs):
    """
    Decodes a JWT with the strict decoder after vetting the requested algorithms.

    Args:
        jwt: the encoded token.
        key: the key used to verify the signature.
        verify: legacy flag; used as ``options["verify_signature"]`` when the
            caller's ``options`` does not set that key itself.
        algorithms: the algorithms the caller accepts. Must be non-empty and,
            compared case-insensitively, a subset of ``ALGORITHM_WHITELIST``.
        options: decoder options (e.g. from ``exp_max_s_option``). The dict is
            copied and never modified in place.
        **kwargs: forwarded unchanged to ``_StrictJWT().decode``.

    Returns:
        The result of ``_StrictJWT().decode`` (PyJWT's decoded payload).

    Raises:
        InvalidAlgorithmError: if ``algorithms`` is empty, contains ``none``, or
            contains anything outside ``ALGORITHM_WHITELIST``.
    """
    if not algorithms:
        raise InvalidAlgorithmError("algorithms must be specified")

    normalized = {a.lower() for a in algorithms}
    if "none" in normalized:
        raise InvalidAlgorithmError("`none` algorithm is not allowed")

    if not normalized.issubset(ALGORITHM_WHITELIST):
        raise InvalidAlgorithmError(
            "Algorithms `%s` are not whitelisted. Allowed: %s" % (algorithms, ALGORITHM_WHITELIST)
        )

    # verify is a legacy option in PyJWT, should be moved to options as verify_signature.
    # Work on a copy: the caller's dict may be shared between calls, and mutating it
    # would make one call's `verify` value leak into every later call.
    options = dict(options) if options is not None else {}
    options.setdefault("verify_signature", verify)

    return _StrictJWT().decode(jwt, key, algorithms, options, **kwargs)
|
|
|
|
|
|
def exp_max_s_option(max_exp_s):
    """
    Build a new options dict that caps how long a decoded JWT may be valid.

    The result holds a single ``exp_max_s`` entry equal to ``max_exp_s``: the
    largest number of seconds a token's ``exp`` may lie after its ``iat`` (or
    after the current time when ``iat`` is absent). A fresh dict is returned on
    every call.
    """
    options = {"exp_max_s": max_exp_s}
    return options
|
|
|
|
|
|
def jwk_dict_to_public_key(jwk_dict):
    """
    Converts the specified JWK into a public key.

    The JWK is parsed with authlib's ``JsonWebKey.import_key``. RSA and EC keys
    are rebuilt from their public numbers into ``cryptography`` public-key objects.

    Raises:
        Exception: if the imported key is neither an ``RSAKey`` nor an ``ECKey``.
    """
    jwk = JsonWebKey.import_key(jwk_dict)

    if isinstance(jwk, RSAKey):
        # Compute the public numbers once rather than once per field.
        rsa_numbers = jwk.as_key().public_numbers()
        return RSAPublicNumbers(e=rsa_numbers.e, n=rsa_numbers.n).public_key(default_backend())

    if isinstance(jwk, ECKey):
        ec_numbers = jwk.as_key().public_numbers()
        return EllipticCurvePublicNumbers(
            x=ec_numbers.x,
            y=ec_numbers.y,
            curve=ec_numbers.curve,
        ).public_key(default_backend())

    # Interpolate the type into the message; passing it as a second exception
    # argument would leave the %s placeholder unfilled.
    raise Exception("Unsupported kind of JWK: %s" % str(type(jwk)))
|