import logging import re from calendar import timegm from datetime import datetime, timedelta from authlib.jose import ECKey, JsonWebKey, RSAKey from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicNumbers from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers from jwt import PyJWT, get_unverified_header from jwt.exceptions import ( DecodeError, ExpiredSignatureError, ImmatureSignatureError, InvalidAlgorithmError, InvalidAudienceError, InvalidIssuedAtError, InvalidIssuerError, InvalidTokenError, MissingRequiredClaimError, ) logger = logging.getLogger(__name__) # TOKEN_REGEX defines a regular expression for matching JWT bearer tokens. TOKEN_REGEX = re.compile(r"\ABearer (([a-zA-Z0-9+\-_/]+\.)+[a-zA-Z0-9+\-_/]+)\Z") # ALGORITHM_WHITELIST defines a whitelist of allowed algorithms to be used in JWTs. DO NOT ADD # `none` here! ALGORITHM_WHITELIST = ["rs256", "hs256", "rs384"] class _StrictJWT(PyJWT): """ _StrictJWT defines a JWT decoder with extra checks. """ @staticmethod def _get_default_options(): # Weird syntax to call super on a staticmethod defaults = super(_StrictJWT, _StrictJWT)._get_default_options() defaults.update( { "require": ["nbf", "iat", "exp"], "exp_max_s": None, } ) return defaults def _validate_claims(self, payload, options, audience=None, issuer=None, leeway=0, **kwargs): if options.get("exp_max_s") is not None: if "verify_expiration" in kwargs and not kwargs.get("verify_expiration"): raise ValueError("exp_max_s option implies verify_expiration") options["verify_exp"] = True # Do all of the other checks super(_StrictJWT, self)._validate_claims( payload, options, audience, issuer, leeway, **kwargs ) now = timegm(datetime.utcnow().utctimetuple()) self._reject_future_iat(payload, now, leeway) if "exp" in payload and options.get("exp_max_s") is not None: # Validate that the expiration was not more than exp_max_s seconds after the issue time # or in the absence of an issue time, more than exp_max_s in the future from now # This will work because the parent method already checked the type of exp expiration = datetime.utcfromtimestamp(int(payload["exp"])) max_signed_s = options.get("exp_max_s") start_time = datetime.utcnow() if "iat" in payload: start_time = datetime.utcfromtimestamp(int(payload["iat"])) if expiration > start_time + timedelta(seconds=max_signed_s): raise InvalidTokenError( "Token was signed for more than %s seconds from %s", max_signed_s, start_time ) def _reject_future_iat(self, payload, now, leeway): try: iat = int(payload["iat"]) except ValueError: raise DecodeError("Issued At claim (iat) must be an integer.") if iat > (now + leeway): raise InvalidIssuedAtError("Issued At claim (iat) cannot be in" " the future.") def decode(jwt, key="", verify=True, algorithms=None, options=None, **kwargs): """ Decodes a JWT. """ if not algorithms: raise InvalidAlgorithmError("algorithms must be specified") normalized = set([a.lower() for a in algorithms]) if "none" in normalized: raise InvalidAlgorithmError("`none` algorithm is not allowed") if set(normalized).intersection(set(ALGORITHM_WHITELIST)) != set(normalized): raise InvalidAlgorithmError( "Algorithms `%s` are not whitelisted. Allowed: %s" % (algorithms, ALGORITHM_WHITELIST) ) # verify is a legacy option in PyJWT, should be moved to options as verify_signature if options is None: options = {"verify_signature": verify} elif "verify_signature" not in options: options["verify_signature"] = verify return _StrictJWT().decode(jwt, key, algorithms, options, **kwargs) def exp_max_s_option(max_exp_s): """ Returns an options dictionary that sets the maximum expiration seconds for a JWT. """ return { "exp_max_s": max_exp_s, } def jwk_dict_to_public_key(jwk_dict): """ Converts the specified JWK into a public key. """ jwk = JsonWebKey.import_key(jwk_dict) if isinstance(jwk, RSAKey): rsa_pk = jwk.as_key() return RSAPublicNumbers( e=rsa_pk.public_numbers().e, n=rsa_pk.public_numbers().n ).public_key(default_backend()) elif isinstance(jwk, ECKey): ec_pk = jwk.as_key() return EllipticCurvePublicNumbers( x=ec_pk.public_numbers().x, y=ec_pk.public_numbers().y, curve=ec_pk.public_numbers().curve, ).public_key(default_backend()) raise Exception("Unsupported kind of JWK: %s", str(type(jwk))) def is_jwt(token): try: headers = get_unverified_header(token) return headers.get("typ", "").lower() == "jwt" except DecodeError: pass return False