1
0
mirror of https://github.com/quay/quay.git synced 2026-01-26 06:21:37 +03:00
Files
quay/util/security/test/test_jwtutil.py
dependabot[bot] d2e5a69b26 build(deps): bump pyjwt from 2.4.0 to 2.8.0 (#2166)
* build(deps): bump pyjwt from 2.4.0 to 2.8.0

Bumps [pyjwt](https://github.com/jpadilla/pyjwt) from 2.4.0 to 2.8.0.
- [Release notes](https://github.com/jpadilla/pyjwt/releases)
- [Changelog](https://github.com/jpadilla/pyjwt/blob/master/CHANGELOG.rst)
- [Commits](https://github.com/jpadilla/pyjwt/compare/2.4.0...2.8.0)

---
updated-dependencies:
- dependency-name: pyjwt
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Expect new messages

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Oleg Bulatov <oleg@bulatov.me>
2023-10-03 14:08:00 +02:00

250 lines
7.0 KiB
Python

import time
import jwt
import pytest
from authlib.jose import jwk
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from util.security.jwtutil import (
InvalidAlgorithmError,
InvalidTokenError,
decode,
exp_max_s_option,
jwk_dict_to_public_key,
)
@pytest.fixture(scope="session")
def private_key():
    """Session-scoped fixture: a freshly generated RSA private key.

    The key uses a 2048-bit modulus and public exponent 65537. Because the
    fixture is session-scoped, generation happens once per test session.
    """
    rsa_params = {"public_exponent": 65537, "key_size": 2048}
    return rsa.generate_private_key(**rsa_params)
@pytest.fixture(scope="session")
def private_key_pem(private_key):
    """Session-scoped fixture: the test private key serialized as PEM bytes.

    The key is emitted in the TraditionalOpenSSL private format with no
    encryption applied.
    """
    pem_encoding = serialization.Encoding.PEM
    openssl_format = serialization.PrivateFormat.TraditionalOpenSSL
    unencrypted = serialization.NoEncryption()
    return private_key.private_bytes(
        encoding=pem_encoding,
        format=openssl_format,
        encryption_algorithm=unencrypted,
    )
@pytest.fixture(scope="session")
def public_key(private_key):
    """Session-scoped fixture: the public half of the test key as PEM bytes.

    The key is serialized in the SubjectPublicKeyInfo format.
    """
    public_half = private_key.public_key()
    return public_half.public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
def _token_data(audience, subject, iss, iat=None, exp=None, nbf=None):
return {
"iss": iss,
"aud": audience,
"nbf": nbf() if nbf is not None else int(time.time()),
"iat": iat() if iat is not None else int(time.time()),
"exp": exp() if exp is not None else int(time.time() + 3600),
"sub": subject,
}
@pytest.mark.parametrize("algorithm", [pytest.param("RS256"), pytest.param("RS384")])
@pytest.mark.parametrize(
    "aud, iss, nbf, iat, exp, expected_exception",
    [
        # --- Cases expected to be rejected -------------------------------
        pytest.param(
            "invalidaudience",
            "someissuer",
            None,
            None,
            None,
            "Audience doesn't match",
            id="invalid audience",
        ),
        pytest.param(
            "someaudience",
            "invalidissuer",
            None,
            None,
            None,
            "Invalid issuer",
            id="invalid issuer",
        ),
        pytest.param(
            "someaudience",
            "someissuer",
            lambda: time.time() + 120,
            None,
            None,
            "The token is not yet valid",
            id="invalid not before",
        ),
        pytest.param(
            "someaudience",
            "someissuer",
            None,
            lambda: time.time() + 120,
            None,
            "The token is not yet valid \\(iat\\)",
            id="issued at in future",
        ),
        pytest.param(
            "someaudience",
            "someissuer",
            None,
            None,
            lambda: time.time() - 100,
            "Signature has expired",
            id="already expired",
        ),
        pytest.param(
            "someaudience",
            "someissuer",
            None,
            None,
            lambda: time.time() + 10000,
            "Token was signed for more than",
            id="expiration too far in future",
        ),
        # --- Cases expected to be accepted (offsets below the 60 s leeway) ---
        pytest.param(
            "someaudience",
            "someissuer",
            lambda: time.time() + 10,
            None,
            None,
            None,
            id="not before in future by within leeway",
        ),
        pytest.param(
            "someaudience",
            "someissuer",
            None,
            lambda: time.time() + 10,
            None,
            None,
            id="issued at in future but within leeway",
        ),
        pytest.param(
            "someaudience",
            "someissuer",
            None,
            None,
            lambda: time.time() - 10,
            None,
            id="expiration in past but within leeway",
        ),
    ],
)
def test_decode_jwt_validation(
    aud,
    iss,
    nbf,
    iat,
    exp,
    expected_exception,
    private_key_pem,
    public_key,
    algorithm,
):
    """Check claim validation performed by ``decode`` for each signing algorithm.

    A token is signed with the test private key using the parametrized
    claims, then decoded against audience ``"someaudience"`` and issuer
    ``"someissuer"`` with a 60-second leeway and ``exp_max_s_option(3600)``.

    ``expected_exception`` is a regex searched for in the message of the
    raised ``InvalidTokenError``; ``None`` means decoding must succeed.
    """
    claims = _token_data(aud, "subject", iss, iat, exp, nbf)
    token = jwt.encode(claims, private_key_pem, algorithm)

    def _attempt_decode():
        # Single definition of the decode call shared by both outcomes.
        return decode(
            token,
            public_key,
            algorithms=[algorithm],
            audience="someaudience",
            issuer="someissuer",
            options=exp_max_s_option(3600),
            leeway=60,
        )

    if expected_exception is None:
        # Valid token: decoding must complete without raising.
        _attempt_decode()
        return

    with pytest.raises(InvalidTokenError) as ite:
        _attempt_decode()
    assert ite.match(expected_exception)
@pytest.mark.parametrize("algorithm", [pytest.param("RS256"), pytest.param("RS384")])
def test_decode_jwt_invalid_key(private_key_pem, algorithm):
    """Decoding with a public key that did not sign the token must fail.

    Expects ``InvalidTokenError`` with a message matching
    ``"Signature verification failed"``.
    """
    # Sign the token with the shared test private key.
    claims = _token_data("aud", "subject", "someissuer")
    token = jwt.encode(claims, private_key_pem, algorithm)

    # Generate an unrelated key pair and export only its public half.
    unrelated_private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    unrelated_public_half = unrelated_private_key.public_key()
    another_public_key = unrelated_public_half.public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )

    with pytest.raises(InvalidTokenError) as ite:
        decode(
            token,
            another_public_key,
            algorithms=[algorithm],
            audience="aud",
            issuer="someissuer",
            options=exp_max_s_option(3600),
            leeway=60,
        )
    assert ite.match("Signature verification failed")
@pytest.mark.parametrize("algorithm", [pytest.param("RS256"), pytest.param("RS384")])
def test_decode_jwt_invalid_algorithm(private_key_pem, public_key, algorithm):
    """Decoding must fail when the signing algorithm is not in the allowed list.

    The token is signed with the parametrized RSA algorithm, but ``decode``
    is given only ``"ES256"`` as acceptable. Expects ``InvalidAlgorithmError``
    with a message matching ``"are not whitelisted"``.
    """
    # Sign the token with the shared test private key.
    claims = _token_data("aud", "subject", "someissuer")
    token = jwt.encode(claims, private_key_pem, algorithm)

    # Offer only an algorithm that differs from the one used for signing.
    with pytest.raises(InvalidAlgorithmError) as ite:
        decode(
            token,
            public_key,
            algorithms=["ES256"],
            audience="aud",
            issuer="someissuer",
            options=exp_max_s_option(3600),
            leeway=60,
        )
    assert ite.match("are not whitelisted")
@pytest.mark.parametrize("algorithm", [pytest.param("RS256"), pytest.param("RS384")])
def test_jwk_dict_to_public_key(private_key, private_key_pem, algorithm):
    """A key round-tripped through a JWK dict must still verify tokens.

    The public half of the test key is exported as PEM, dumped to a JWK
    dictionary via ``jwk.dumps``, and converted back with
    ``jwk_dict_to_public_key``. A token signed with the test private key is
    then decoded with the converted key; the test passes if no exception
    is raised.
    """
    # PEM export of the public key -> JWK dict -> converted key object.
    public_pem = private_key.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    converted = jwk_dict_to_public_key(jwk.dumps(public_pem))

    # Sign the token with the shared test private key.
    claims = _token_data("aud", "subject", "someissuer")
    token = jwt.encode(claims, private_key_pem, algorithm)

    # Decoding with the converted key must succeed.
    decode(
        token,
        converted,
        algorithms=[algorithm],
        audience="aud",
        issuer="someissuer",
        options=exp_max_s_option(3600),
        leeway=60,
    )