From fd39479810db4bcd73604253ec6e4dff955f9afb Mon Sep 17 00:00:00 2001 From: Jakub Warmuz Date: Thu, 11 Jun 2015 10:11:49 +0000 Subject: [PATCH] Add an anti-replay nonce facility (fixes: #488). --- acme/jose/__init__.py | 6 ++- acme/jose/json_util.py | 12 ++++-- acme/jose/json_util_test.py | 41 ++++++++++++++++++ acme/jose/jws.py | 14 +++--- acme/jws.py | 59 +++++++++++++++++++++++++ acme/jws_test.py | 58 +++++++++++++++++++++++++ acme/messages2.py | 1 + letsencrypt/network2.py | 47 +++++++++++++++----- letsencrypt/tests/network2_test.py | 69 ++++++++++++++++++++++++------ 9 files changed, 275 insertions(+), 32 deletions(-) create mode 100644 acme/jws.py create mode 100644 acme/jws_test.py diff --git a/acme/jose/__init__.py b/acme/jose/__init__.py index db3258a3d..a4fe7008b 100644 --- a/acme/jose/__init__.py +++ b/acme/jose/__init__.py @@ -66,7 +66,11 @@ from acme.jose.jwk import ( JWKRSA, ) -from acme.jose.jws import JWS +from acme.jose.jws import ( + Header, + JWS, + Signature, +) from acme.jose.util import ( ComparableX509, diff --git a/acme/jose/json_util.py b/acme/jose/json_util.py index 0c91c3412..c7698ed8d 100644 --- a/acme/jose/json_util.py +++ b/acme/jose/json_util.py @@ -129,7 +129,8 @@ class JSONObjectWithFieldsMeta(abc.ABCMeta): keys are field attribute names and values are fields themselves. 2. ``cls.__slots__`` is extended by all field attribute names - (i.e. not :attr:`Field.json_name`). + (i.e. not :attr:`Field.json_name`). Original ``cls.__slots__`` + are stored in ``cls._orig_slots``. In a consequence, for a field attribute name ``some_field``, ``cls.some_field`` will be a slot descriptor and not an instance @@ -143,6 +144,7 @@ class JSONObjectWithFieldsMeta(abc.ABCMeta): some_field = some_field assert Foo.__slots__ == ('some_field', 'baz') + assert Foo._orig_slots == () assert Foo.some_field is not Field assert Foo._fields.keys() == ['some_field'] @@ -158,12 +160,16 @@ class JSONObjectWithFieldsMeta(abc.ABCMeta): def __new__(mcs, name, bases, dikt): fields = {} + + for base in bases: + fields.update(getattr(base, '_fields', {})) + # Do not reorder, this class might override fields from base classes! for key, value in dikt.items(): # not iterkeys() (in-place edit!) if isinstance(value, Field): fields[key] = dikt.pop(key) - dikt['__slots__'] = tuple( - list(dikt.get('__slots__', ())) + fields.keys()) + dikt['_orig_slots'] = dikt.get('__slots__', ()) + dikt['__slots__'] = tuple(list(dikt['_orig_slots']) + fields.keys()) dikt['_fields'] = fields return abc.ABCMeta.__new__(mcs, name, bases, dikt) diff --git a/acme/jose/json_util_test.py b/acme/jose/json_util_test.py index 5726ef2a8..a37ac08de 100644 --- a/acme/jose/json_util_test.py +++ b/acme/jose/json_util_test.py @@ -77,6 +77,47 @@ class FieldTest(unittest.TestCase): self.assertTrue(Field.default_decoder(mock_value) is mock_value) +class JSONObjectWithFieldsMetaTest(unittest.TestCase): + """Tests for acme.jose.json_util.JSONObjectWithFieldsMeta.""" + + def setUp(self): + from acme.jose.json_util import Field + from acme.jose.json_util import JSONObjectWithFieldsMeta + self.field = Field('Baz') + self.field2 = Field('Baz2') + # pylint: disable=invalid-name,missing-docstring,too-few-public-methods + # pylint: disable=blacklisted-name + class A(object): + __metaclass__ = JSONObjectWithFieldsMeta + __slots__ = ('bar',) + baz = self.field + class B(A): + pass + class C(A): + baz = self.field2 + self.a_cls = A + self.b_cls = B + self.c_cls = C + + def test_fields(self): + # pylint: disable=protected-access,no-member + self.assertEqual({'baz': self.field}, self.a_cls._fields) + self.assertEqual({'baz': self.field}, self.b_cls._fields) + + def test_fields_inheritance(self): + # pylint: disable=protected-access,no-member + self.assertEqual({'baz': self.field2}, self.c_cls._fields) + + def test_slots(self): + self.assertEqual(('bar', 'baz'), self.a_cls.__slots__) + self.assertEqual(('baz',), self.b_cls.__slots__) + + def test_orig_slots(self): + # pylint: disable=protected-access,no-member + self.assertEqual(('bar',), self.a_cls._orig_slots) + self.assertEqual((), self.b_cls._orig_slots) + + class JSONObjectWithFieldsTest(unittest.TestCase): """Tests for acme.jose.json_util.JSONObjectWithFields.""" # pylint: disable=protected-access diff --git a/acme/jose/jws.py b/acme/jose/jws.py index 06923e145..3ba60d40c 100644 --- a/acme/jose/jws.py +++ b/acme/jose/jws.py @@ -247,6 +247,8 @@ class JWS(json_util.JSONObjectWithFields): """ __slots__ = ('payload', 'signatures') + signature_cls = Signature + def verify(self, key=None): """Verify.""" return all(sig.verify(self.payload, key) for sig in self.signatures) @@ -255,13 +257,13 @@ class JWS(json_util.JSONObjectWithFields): def sign(cls, payload, **kwargs): """Sign.""" return cls(payload=payload, signatures=( - Signature.sign(payload=payload, **kwargs),)) + cls.signature_cls.sign(payload=payload, **kwargs),)) @property def signature(self): """Get a singleton signature. - :rtype: :class:`Signature` + :rtype: `signature_cls` """ assert len(self.signatures) == 1 @@ -288,8 +290,8 @@ class JWS(json_util.JSONObjectWithFields): raise errors.DeserializationError( 'Compact JWS serialization should comprise of exactly' ' 3 dot-separated components') - sig = Signature(protected=json_util.decode_b64jose(protected), - signature=json_util.decode_b64jose(signature)) + sig = cls.signature_cls(protected=json_util.decode_b64jose(protected), + signature=json_util.decode_b64jose(signature)) return cls(payload=json_util.decode_b64jose(payload), signatures=(sig,)) def to_partial_json(self, flat=True): # pylint: disable=arguments-differ @@ -312,10 +314,10 @@ class JWS(json_util.JSONObjectWithFields): raise errors.DeserializationError('Flat mixed with non-flat') elif 'signature' in jobj: # flat return cls(payload=json_util.decode_b64jose(jobj.pop('payload')), - signatures=(Signature.from_json(jobj),)) + signatures=(cls.signature_cls.from_json(jobj),)) else: return cls(payload=json_util.decode_b64jose(jobj['payload']), - signatures=tuple(Signature.from_json(sig) + signatures=tuple(cls.signature_cls.from_json(sig) for sig in jobj['signatures'])) class CLI(object): diff --git a/acme/jws.py b/acme/jws.py new file mode 100644 index 000000000..a23015d93 --- /dev/null +++ b/acme/jws.py @@ -0,0 +1,59 @@ +"""ACME JOSE JWS.""" +from acme import errors +from acme import jose + + +class Header(jose.Header): + """ACME JOSE Header. + + .. todo:: Implement ``acmePath``. + + """ + nonce = jose.Field('nonce', omitempty=True) + + @classmethod + def validate_nonce(cls, nonce): + """Validate nonce. + + :returns: ``None`` if ``nonce`` is valid, decoding errors otherwise. + + """ + try: + jose.b64decode(nonce) + except (ValueError, TypeError) as error: + return error + else: + return None + + @nonce.decoder + def nonce(value): # pylint: disable=missing-docstring,no-self-argument + error = Header.validate_nonce(value) + if error is not None: + # TODO: custom error + raise errors.Error("Invalid nonce: {0}".format(error)) + return value + + +class Signature(jose.Signature): + """ACME Signature.""" + __slots__ = jose.Signature._orig_slots # pylint: disable=no-member + + # TODO: decoder/encoder should accept cls? Otherwise, subclassing + # JSONObjectWithFields is tricky... + header_cls = Header + header = jose.Field( + 'header', omitempty=True, default=header_cls(), + decoder=header_cls.from_json) + + # TODO: decoder should check that nonce is in the protected header + + +class JWS(jose.JWS): + """ACME JWS.""" + signature_cls = Signature + __slots__ = jose.JWS._orig_slots # pylint: disable=no-member + + @classmethod + def sign(cls, payload, key, alg, nonce): # pylint: disable=arguments-differ + return super(JWS, cls).sign(payload, key=key, alg=alg, + protect=frozenset(['nonce']), nonce=nonce) diff --git a/acme/jws_test.py b/acme/jws_test.py new file mode 100644 index 000000000..f4a03f70d --- /dev/null +++ b/acme/jws_test.py @@ -0,0 +1,58 @@ +"""Tests for acme.jws.""" +import os +import pkg_resources +import unittest + +import Crypto.PublicKey.RSA + +from acme import errors +from acme import jose + + +RSA512_KEY = Crypto.PublicKey.RSA.importKey(pkg_resources.resource_string( + 'acme.jose', os.path.join('testdata', 'rsa512_key.pem'))) + + +class HeaderTest(unittest.TestCase): + """Tests for acme.jws.Header.""" + + good_nonce = jose.b64encode('foo') + wrong_nonce = 'F' + # Following just makes sure wrong_nonce is wrong + try: + jose.b64decode(wrong_nonce) + except (ValueError, TypeError): + assert True + else: + assert False # pragma: no cover + + def test_validate_nonce(self): + from acme.jws import Header + self.assertTrue(Header.validate_nonce(self.good_nonce) is None) + self.assertFalse(Header.validate_nonce(self.wrong_nonce) is None) + + def test_nonce_decoder(self): + from acme.jws import Header + nonce_field = Header._fields['nonce'] + + self.assertRaises(errors.Error, nonce_field.decode, self.wrong_nonce) + self.assertEqual(self.good_nonce, nonce_field.decode(self.good_nonce)) + + +class JWSTest(unittest.TestCase): + """Tests for acme.jws.JWS.""" + + def setUp(self): + self.privkey = jose.JWKRSA(key=RSA512_KEY) + self.pubkey = self.privkey.public() + self.nonce = jose.b64encode('Nonce') + + def test_it(self): + from acme.jws import JWS + jws = JWS.sign(payload='foo', key=self.privkey, + alg=jose.RS256, nonce=self.nonce) + JWS.from_json(jws.to_json()) + + +if __name__ == '__main__': + unittest.main() # pragma: no cover diff --git a/acme/messages2.py b/acme/messages2.py index 253aaa95b..15b4521de 100644 --- a/acme/messages2.py +++ b/acme/messages2.py @@ -16,6 +16,7 @@ class Error(jose.JSONObjectWithFields, Exception): 'unauthorized': 'The client lacks sufficient authorization', 'serverInternal': 'The server experienced an internal error', 'badCSR': 'The CSR is unacceptable (e.g., due to a short key)', + 'badNonce': 'The client sent an unacceptable anti-replay nonce', } typ = jose.Field('type') diff --git a/letsencrypt/network2.py b/letsencrypt/network2.py index faf23f414..ae8aa43af 100644 --- a/letsencrypt/network2.py +++ b/letsencrypt/network2.py @@ -10,6 +10,7 @@ import requests import werkzeug from acme import jose +from acme import jws as acme_jws from acme import messages2 from letsencrypt import errors @@ -33,26 +34,32 @@ class Network(object): """ + # TODO: Move below to acme module? DER_CONTENT_TYPE = 'application/pkix-cert' JSON_CONTENT_TYPE = 'application/json' JSON_ERROR_CONTENT_TYPE = 'application/problem+json' + REPLAY_NONCE_HEADER = 'Replay-Nonce' def __init__(self, new_reg_uri, key, alg=jose.RS256, verify_ssl=True): self.new_reg_uri = new_reg_uri self.key = key self.alg = alg self.verify_ssl = verify_ssl + self._nonces = set() - def _wrap_in_jws(self, obj): + def _wrap_in_jws(self, obj, nonce): """Wrap `JSONDeSerializable` object in JWS. + .. todo:: Implement ``acmePath``. + + :param JSONDeSerializable obj: :rtype: `.JWS` """ dumps = obj.json_dumps() logging.debug('Serialized JSON: %s', dumps) - return jose.JWS.sign( - payload=dumps, key=self.key, alg=self.alg).json_dumps() + return acme_jws.JWS.sign( + payload=dumps, key=self.key, alg=self.alg, nonce=nonce).json_dumps() @classmethod def _check_response(cls, response, content_type=None): @@ -126,9 +133,27 @@ class Network(object): self._check_response(response, content_type=content_type) return response - def _post(self, uri, data, content_type=JSON_CONTENT_TYPE, **kwargs): + def _add_nonce(self, response): + if self.REPLAY_NONCE_HEADER in response.headers: + nonce = response.headers[self.REPLAY_NONCE_HEADER] + error = acme_jws.Header.validate_nonce(nonce) + if error is None: + logging.debug('Storing nonce: %r', nonce) + self._nonces.add(nonce) + else: + raise errors.NetworkError('Invalid nonce ({0}): {1}'.format( + nonce, error)) + + def _get_nonce(self, uri): + if not self._nonces: + logging.debug('Requesting fresh nonce by sending HEAD to %s', uri) + self._add_nonce(requests.head(uri)) + return self._nonces.pop() + + def _post(self, uri, obj, content_type=JSON_CONTENT_TYPE, **kwargs): """Send POST data. + :param JSONDeSerializable obj: Will be wrapped in JWS. :param str content_type: Expected ``Content-Type``, fails if not set. :raises acme.messages2.NetworkError: @@ -137,6 +162,7 @@ class Network(object): :rtype: `requests.Response` """ + data = self._wrap_in_jws(obj, self._get_nonce(uri)) logging.debug('Sending POST data to %s: %s', uri, data) kwargs.setdefault('verify', self.verify_ssl) try: @@ -145,6 +171,7 @@ class Network(object): raise errors.NetworkError(error) logging.debug('Received response %s: %r', response, response.text) + self._add_nonce(response) self._check_response(response, content_type=content_type) return response @@ -182,7 +209,7 @@ class Network(object): """ new_reg = messages2.Registration(contact=contact) - response = self._post(self.new_reg_uri, self._wrap_in_jws(new_reg)) + response = self._post(self.new_reg_uri, new_reg) assert response.status_code == httplib.CREATED # TODO: handle errors regr = self._regr_from_response(response) @@ -219,7 +246,7 @@ class Network(object): :rtype: `.RegistrationResource` """ - response = self._post(regr.uri, self._wrap_in_jws(regr.body)) + response = self._post(regr.uri, regr.body) # TODO: Boulder returns httplib.ACCEPTED #assert response.status_code == httplib.OK @@ -280,7 +307,7 @@ class Network(object): """ new_authz = messages2.Authorization(identifier=identifier) - response = self._post(new_authzr_uri, self._wrap_in_jws(new_authz)) + response = self._post(new_authzr_uri, new_authz) assert response.status_code == httplib.CREATED # TODO: handle errors return self._authzr_from_response(response, identifier) @@ -316,7 +343,7 @@ class Network(object): :raises errors.UnexpectedUpdate: """ - response = self._post(challb.uri, self._wrap_in_jws(response)) + response = self._post(challb.uri, response) try: authzr_uri = response.links['up']['url'] except KeyError: @@ -395,7 +422,7 @@ class Network(object): content_type = self.DER_CONTENT_TYPE # TODO: add 'cert_type 'argument response = self._post( authzrs[0].new_cert_uri, # TODO: acme-spec #90 - self._wrap_in_jws(req), + req, content_type=content_type, headers={'Accept': content_type}) @@ -546,7 +573,7 @@ class Network(object): """ rev = messages2.Revocation(revoke=when, authorizations=tuple( authzr.uri for authzr in certr.authzrs)) - response = self._post(certr.uri, self._wrap_in_jws(rev)) + response = self._post(certr.uri, rev) if response.status_code != httplib.OK: raise errors.NetworkError( 'Successful revocation must return HTTP OK status') diff --git a/letsencrypt/tests/network2_test.py b/letsencrypt/tests/network2_test.py index 7bffcf0f4..ed155df2e 100644 --- a/letsencrypt/tests/network2_test.py +++ b/letsencrypt/tests/network2_test.py @@ -13,6 +13,7 @@ import requests from acme import challenges from acme import jose +from acme import jws as acme_jws from acme import messages2 from letsencrypt import account @@ -40,15 +41,23 @@ class NetworkTest(unittest.TestCase): # pylint: disable=too-many-instance-attributes,too-many-public-methods def setUp(self): - from letsencrypt.network2 import Network self.verify_ssl = mock.MagicMock() + self.wrap_in_jws = mock.MagicMock(return_value=mock.sentinel.wrapped) + + from letsencrypt.network2 import Network self.net = Network( new_reg_uri='https://www.letsencrypt-demo.org/acme/new-reg', key=KEY, alg=jose.RS256, verify_ssl=self.verify_ssl) + self.nonce = jose.b64encode('Nonce') + self.net._nonces.add(self.nonce) # pylint: disable=protected-access + self.response = mock.MagicMock(ok=True, status_code=httplib.OK) self.response.headers = {} self.response.links = {} + self.post = mock.MagicMock(return_value=self.response) + self.get = mock.MagicMock(return_value=self.response) + self.identifier = messages2.Identifier( typ=messages2.IDENTIFIER_FQDN, value='example.com') @@ -89,8 +98,8 @@ class NetworkTest(unittest.TestCase): def _mock_post_get(self): # pylint: disable=protected-access - self.net._post = mock.MagicMock(return_value=self.response) - self.net._get = mock.MagicMock(return_value=self.response) + self.net._post = self.post + self.net._get = self.get def test_init(self): self.assertTrue(self.net.verify_ssl is self.verify_ssl) @@ -106,8 +115,12 @@ class NetworkTest(unittest.TestCase): def from_json(cls, value): pass # pragma: no cover # pylint: disable=protected-access - jws = self.net._wrap_in_jws(MockJSONDeSerializable('foo')) - self.assertEqual(jose.JWS.json_loads(jws).payload, '"foo"') + jws_dump = self.net._wrap_in_jws( + MockJSONDeSerializable('foo'), nonce='Tg') + jws = acme_jws.JWS.json_loads(jws_dump) + self.assertEqual(jws.payload, '"foo"') + self.assertEqual(jws.signature.combined.nonce, 'Tg') + # TODO: check that nonce is in protected header def test_check_response_not_ok_jobj_no_error(self): self.response.ok = False @@ -169,33 +182,66 @@ class NetworkTest(unittest.TestCase): self.net._check_response.assert_called_once_with( requests_mock.get('uri'), content_type='ct') + def _mock_wrap_in_jws(self): + # pylint: disable=protected-access + self.net._wrap_in_jws = self.wrap_in_jws + @mock.patch('letsencrypt.network2.requests') def test_post_requests_error_passthrough(self, requests_mock): requests_mock.exceptions = requests.exceptions requests_mock.post.side_effect = requests.exceptions.RequestException # pylint: disable=protected-access - self.assertRaises(errors.NetworkError, self.net._post, 'uri', 'data') + self._mock_wrap_in_jws() + self.assertRaises( + errors.NetworkError, self.net._post, 'uri', mock.sentinel.obj) @mock.patch('letsencrypt.network2.requests') def test_post(self, requests_mock): # pylint: disable=protected-access self.net._check_response = mock.MagicMock() - self.net._post('uri', 'data', content_type='ct') + self._mock_wrap_in_jws() + self.net._post('uri', mock.sentinel.obj, content_type='ct') self.net._check_response.assert_called_once_with( - requests_mock.post('uri', 'data'), content_type='ct') + requests_mock.post('uri', mock.sentinel.wrapped), content_type='ct') + + @mock.patch('letsencrypt.network2.requests') + def test_post_reply_nonce_handling(self, requests_mock): + # pylint: disable=protected-access + self.net._check_response = mock.MagicMock() + self._mock_wrap_in_jws() + + self.net._nonces.clear() + nonce2 = jose.b64encode('Nonce2') + requests_mock.head('uri').headers = { + self.net.REPLAY_NONCE_HEADER: nonce2} + requests_mock.post('uri').headers = { + self.net.REPLAY_NONCE_HEADER: self.nonce} + + self.net._post('uri', mock.sentinel.obj) + + requests_mock.head.assert_called_with('uri') + self.wrap_in_jws.assert_called_once_with(mock.sentinel.obj, nonce2) + self.assertEqual(self.net._nonces, set([self.nonce])) + + # wrong nonce + requests_mock.post('uri').headers = {self.net.REPLAY_NONCE_HEADER: 'F'} + self.assertRaises( + errors.NetworkError, self.net._post, 'uri', mock.sentinel.obj) @mock.patch('letsencrypt.client.network2.requests') def test_get_post_verify_ssl(self, requests_mock): # pylint: disable=protected-access + self._mock_wrap_in_jws() self.net._check_response = mock.MagicMock() for verify_ssl in [True, False]: self.net.verify_ssl = verify_ssl self.net._get('uri') - self.net._post('uri', 'data') + self.net._nonces.add('N') + self.net._post('uri', mock.sentinel.obj) requests_mock.get.assert_called_once_with('uri', verify=verify_ssl) requests_mock.post.assert_called_once_with( - 'uri', data='data', verify=verify_ssl) + 'uri', data=mock.sentinel.wrapped, verify=verify_ssl) requests_mock.reset_mock() def test_register(self): @@ -498,8 +544,7 @@ class NetworkTest(unittest.TestCase): def test_revoke(self): self._mock_post_get() self.net.revoke(self.certr, when=messages2.Revocation.NOW) - # pylint: disable=protected-access - self.net._post.assert_called_once_with(self.certr.uri, mock.ANY) + self.post.assert_called_once_with(self.certr.uri, mock.ANY) def test_revoke_bad_status_raises_error(self): self.response.status_code = httplib.METHOD_NOT_ALLOWED