diff --git a/acme/client.py b/acme/client.py index d1bf84a21..4a4192528 100644 --- a/acme/client.py +++ b/acme/client.py @@ -32,150 +32,18 @@ class Client(object): # pylint: disable=too-many-instance-attributes :ivar key: `.JWK` (private) :ivar alg: `.JWASignature` :ivar bool verify_ssl: Verify SSL certificates? + :ivar .ClientNetwork net: Client network. Useful for testing. If not + supplied, it will be initialized using `key`, `alg` and + `verify_ssl`. """ DER_CONTENT_TYPE = 'application/pkix-cert' - JSON_CONTENT_TYPE = 'application/json' - JSON_ERROR_CONTENT_TYPE = 'application/problem+json' - REPLAY_NONCE_HEADER = 'Replay-Nonce' - def __init__(self, new_reg_uri, key, alg=jose.RS256, verify_ssl=True): + def __init__(self, new_reg_uri, key, alg=jose.RS256, + verify_ssl=True, net=None): self.new_reg_uri = new_reg_uri self.key = key - self.alg = alg - self.verify_ssl = verify_ssl - self._nonces = set() - - def _wrap_in_jws(self, obj, nonce): - """Wrap `JSONDeSerializable` object in JWS. - - .. todo:: Implement ``acmePath``. - - :param JSONDeSerializable obj: - :rtype: `.JWS` - - """ - dumps = obj.json_dumps() - logger.debug('Serialized JSON: %s', dumps) - return jws.JWS.sign( - payload=dumps, key=self.key, alg=self.alg, nonce=nonce).json_dumps() - - @classmethod - def _check_response(cls, response, content_type=None): - """Check response content and its type. - - .. note:: - Checking is not strict: wrong server response ``Content-Type`` - HTTP header is ignored if response is an expected JSON object - (c.f. Boulder #56). - - :param str content_type: Expected Content-Type response header. - If JSON is expected and not present in server response, this - function will raise an error. Otherwise, wrong Content-Type - is ignored, but logged. - - :raises .messages.Error: If server response body - carries HTTP Problem (draft-ietf-appsawg-http-problem-00). - :raises .ClientError: In case of other networking errors. - - """ - logger.debug('Received response %s (headers: %s): %r', - response, response.headers, response.content) - - response_ct = response.headers.get('Content-Type') - try: - # TODO: response.json() is called twice, once here, and - # once in _get and _post clients - jobj = response.json() - except ValueError as error: - jobj = None - - if not response.ok: - if jobj is not None: - if response_ct != cls.JSON_ERROR_CONTENT_TYPE: - logger.debug( - 'Ignoring wrong Content-Type (%r) for JSON Error', - response_ct) - try: - raise messages.Error.from_json(jobj) - except jose.DeserializationError as error: - # Couldn't deserialize JSON object - raise errors.ClientError((response, error)) - else: - # response is not JSON object - raise errors.ClientError(response) - else: - if jobj is not None and response_ct != cls.JSON_CONTENT_TYPE: - logger.debug( - 'Ignoring wrong Content-Type (%r) for JSON decodable ' - 'response', response_ct) - - if content_type == cls.JSON_CONTENT_TYPE and jobj is None: - raise errors.ClientError( - 'Unexpected response Content-Type: {0}'.format(response_ct)) - - def _get(self, uri, content_type=JSON_CONTENT_TYPE, **kwargs): - """Send GET request. - - :raises .ClientError: - - :returns: HTTP Response - :rtype: `requests.Response` - - """ - logger.debug('Sending GET request to %s', uri) - kwargs.setdefault('verify', self.verify_ssl) - try: - response = requests.get(uri, **kwargs) - except requests.exceptions.RequestException as error: - raise errors.ClientError(error) - self._check_response(response, content_type=content_type) - return response - - def _add_nonce(self, response): - if self.REPLAY_NONCE_HEADER in response.headers: - nonce = response.headers[self.REPLAY_NONCE_HEADER] - error = jws.Header.validate_nonce(nonce) - if error is None: - logger.debug('Storing nonce: %r', nonce) - self._nonces.add(nonce) - else: - raise errors.ClientError('Invalid nonce ({0}): {1}'.format( - nonce, error)) - else: - raise errors.ClientError( - 'Server {0} response did not include a replay nonce'.format( - response.request.method)) - - def _get_nonce(self, uri): - if not self._nonces: - logger.debug('Requesting fresh nonce by sending HEAD to %s', uri) - self._add_nonce(requests.head(uri)) - return self._nonces.pop() - - def _post(self, uri, obj, content_type=JSON_CONTENT_TYPE, **kwargs): - """Send POST data. - - :param JSONDeSerializable obj: Will be wrapped in JWS. - :param str content_type: Expected ``Content-Type``, fails if not set. - - :raises acme.messages.ClientError: - - :returns: HTTP Response - :rtype: `requests.Response` - - """ - data = self._wrap_in_jws(obj, self._get_nonce(uri)) - logger.debug('Sending POST data to %s: %s', uri, data) - kwargs.setdefault('verify', self.verify_ssl) - try: - response = requests.post(uri, data=data, **kwargs) - except requests.exceptions.RequestException as error: - raise errors.ClientError(error) - - self._add_nonce(response) - self._check_response(response, content_type=content_type) - return response + self.net = ClientNetwork(key, alg, verify_ssl) if net is None else net @classmethod def _regr_from_response(cls, response, uri=None, new_authzr_uri=None, @@ -211,7 +79,7 @@ class Client(object): # pylint: disable=too-many-instance-attributes """ new_reg = messages.Registration(contact=contact) - response = self._post(self.new_reg_uri, new_reg) + response = self.net.post(self.new_reg_uri, new_reg) assert response.status_code == httplib.CREATED # TODO: handle errors regr = self._regr_from_response(response) @@ -230,7 +98,7 @@ class Client(object): # pylint: disable=too-many-instance-attributes :rtype: `.RegistrationResource` """ - response = self._post(regr.uri, regr.body) + response = self.net.post(regr.uri, regr.body) # TODO: Boulder returns httplib.ACCEPTED #assert response.status_code == httplib.OK @@ -290,7 +158,7 @@ class Client(object): # pylint: disable=too-many-instance-attributes """ new_authz = messages.Authorization(identifier=identifier) - response = self._post(new_authzr_uri, new_authz) + response = self.net.post(new_authzr_uri, new_authz) assert response.status_code == httplib.CREATED # TODO: handle errors return self._authzr_from_response(response, identifier) @@ -326,7 +194,7 @@ class Client(object): # pylint: disable=too-many-instance-attributes :raises .UnexpectedUpdate: """ - response = self._post(challb.uri, response) + response = self.net.post(challb.uri, response) try: authzr_uri = response.links['up']['url'] except KeyError: @@ -377,7 +245,7 @@ class Client(object): # pylint: disable=too-many-instance-attributes :rtype: (`.AuthorizationResource`, `requests.Response`) """ - response = self._get(authzr.uri) + response = self.net.get(authzr.uri) updated_authzr = self._authzr_from_response( response, authzr.body.identifier, authzr.uri, authzr.new_cert_uri) # TODO: check and raise UnexpectedUpdate @@ -403,7 +271,7 @@ class Client(object): # pylint: disable=too-many-instance-attributes csr=csr, authorizations=tuple(authzr.uri for authzr in authzrs)) content_type = self.DER_CONTENT_TYPE # TODO: add 'cert_type 'argument - response = self._post( + response = self.net.post( authzrs[0].new_cert_uri, # TODO: acme-spec #90 req, content_type=content_type, @@ -488,8 +356,8 @@ class Client(object): # pylint: disable=too-many-instance-attributes """ content_type = self.DER_CONTENT_TYPE # TODO: make it a param - response = self._get(uri, headers={'Accept': content_type}, - content_type=content_type) + response = self.net.get(uri, headers={'Accept': content_type}, + content_type=content_type) return response, jose.ComparableX509( M2Crypto.X509.load_cert_der_string(response.content)) @@ -551,8 +419,155 @@ class Client(object): # pylint: disable=too-many-instance-attributes :raises .ClientError: If revocation is unsuccessful. """ - response = self._post(messages.Revocation.url(self.new_reg_uri), - messages.Revocation(certificate=cert)) + response = self.net.post(messages.Revocation.url(self.new_reg_uri), + messages.Revocation(certificate=cert)) if response.status_code != httplib.OK: raise errors.ClientError( 'Successful revocation must return HTTP OK status') + + +class ClientNetwork(object): + """Client network.""" + JSON_CONTENT_TYPE = 'application/json' + JSON_ERROR_CONTENT_TYPE = 'application/problem+json' + REPLAY_NONCE_HEADER = 'Replay-Nonce' + + def __init__(self, key, alg=jose.RS256, verify_ssl=True): + self.key = key + self.alg = alg + self.verify_ssl = verify_ssl + self._nonces = set() + + def _wrap_in_jws(self, obj, nonce): + """Wrap `JSONDeSerializable` object in JWS. + + .. todo:: Implement ``acmePath``. + + :param JSONDeSerializable obj: + :rtype: `.JWS` + + """ + dumps = obj.json_dumps() + logger.debug('Serialized JSON: %s', dumps) + return jws.JWS.sign( + payload=dumps, key=self.key, alg=self.alg, nonce=nonce).json_dumps() + + @classmethod + def _check_response(cls, response, content_type=None): + """Check response content and its type. + + .. note:: + Checking is not strict: wrong server response ``Content-Type`` + HTTP header is ignored if response is an expected JSON object + (c.f. Boulder #56). + + :param str content_type: Expected Content-Type response header. + If JSON is expected and not present in server response, this + function will raise an error. Otherwise, wrong Content-Type + is ignored, but logged. + + :raises .messages.Error: If server response body + carries HTTP Problem (draft-ietf-appsawg-http-problem-00). + :raises .ClientError: In case of other networking errors. + + """ + logger.debug('Received response %s (headers: %s): %r', + response, response.headers, response.content) + + response_ct = response.headers.get('Content-Type') + try: + # TODO: response.json() is called twice, once here, and + # once in _get and _post clients + jobj = response.json() + except ValueError as error: + jobj = None + + if not response.ok: + if jobj is not None: + if response_ct != cls.JSON_ERROR_CONTENT_TYPE: + logger.debug( + 'Ignoring wrong Content-Type (%r) for JSON Error', + response_ct) + try: + raise messages.Error.from_json(jobj) + except jose.DeserializationError as error: + # Couldn't deserialize JSON object + raise errors.ClientError((response, error)) + else: + # response is not JSON object + raise errors.ClientError(response) + else: + if jobj is not None and response_ct != cls.JSON_CONTENT_TYPE: + logger.debug( + 'Ignoring wrong Content-Type (%r) for JSON decodable ' + 'response', response_ct) + + if content_type == cls.JSON_CONTENT_TYPE and jobj is None: + raise errors.ClientError( + 'Unexpected response Content-Type: {0}'.format(response_ct)) + + return response + + def _send_request(self, method, url, *args, **kwargs): + """Send HTTP request. + + Makes sure that `verify_ssl` is respected. Logs request and + response (with headers). For allowed parameters please see + `requests.request`. + + :param str method: method for the new `requests.Request` object + :param str url: URL for the new `requests.Request` object + + :raises requests.exceptions.RequestException: in case of any problems + + :returns: HTTP Response + :rtype: `requests.Response` + + + """ + logging.debug('Sending %s request to %s', method, url) + kwargs['verify'] = self.verify_ssl + response = requests.request(method, url, *args, **kwargs) + logging.debug('Received %s. Headers: %s. Content: %r', + response, response.headers, response.content) + return response + + def head(self, *args, **kwargs): + """Send HEAD request without checking the response. + + Note, that `_check_response` is not called, as it is expected + that status code other than successfuly 2xx will be returned, or + messages2.Error will be raised by the server. + + """ + return self._send_request('HEAD', *args, **kwargs) + + def get(self, url, content_type=JSON_CONTENT_TYPE, **kwargs): + """Send GET request and check response.""" + return self._check_response( + self._send_request('GET', url, **kwargs), content_type=content_type) + + def _add_nonce(self, response): + if self.REPLAY_NONCE_HEADER in response.headers: + nonce = response.headers[self.REPLAY_NONCE_HEADER] + error = jws.Header.validate_nonce(nonce) + if error is None: + logger.debug('Storing nonce: %r', nonce) + self._nonces.add(nonce) + else: + raise errors.BadNonce(nonce, error) + else: + raise errors.MissingNonce(response) + + def _get_nonce(self, url): + if not self._nonces: + logging.debug('Requesting fresh nonce') + self._add_nonce(self.head(url)) + return self._nonces.pop() + + def post(self, url, obj, content_type=JSON_CONTENT_TYPE, **kwargs): + """POST object wrapped in `.JWS` and check response.""" + data = self._wrap_in_jws(obj, self._get_nonce(url)) + response = self._send_request('POST', url, data=data, **kwargs) + self._add_nonce(response) + return self._check_response(response, content_type=content_type) diff --git a/acme/client_test.py b/acme/client_test.py index d408f0564..b934e1efd 100644 --- a/acme/client_test.py +++ b/acme/client_test.py @@ -23,27 +23,20 @@ KEY2 = jose.JWKRSA.load(pkg_resources.resource_string( class ClientTest(unittest.TestCase): - """Tests for acme.client.Client.""" - + """Tests for acme.client.Client.""" # pylint: disable=too-many-instance-attributes,too-many-public-methods def setUp(self): - self.verify_ssl = mock.MagicMock() - self.wrap_in_jws = mock.MagicMock(return_value=mock.sentinel.wrapped) + self.response = mock.MagicMock( + ok=True, status_code=httplib.OK, headers={}, links={}) + self.net = mock.MagicMock() + self.net.post.return_value = self.response + self.net.get.return_value = self.response from acme.client import Client - self.net = Client( + self.client = Client( new_reg_uri='https://www.letsencrypt-demo.org/acme/new-reg', - key=KEY, alg=jose.RS256, verify_ssl=self.verify_ssl) - self.nonce = jose.b64encode('Nonce') - self.net._nonces.add(self.nonce) # pylint: disable=protected-access - - self.response = mock.MagicMock(ok=True, status_code=httplib.OK) - self.response.headers = {} - self.response.links = {} - - self.post = mock.MagicMock(return_value=self.response) - self.get = mock.MagicMock(return_value=self.response) + key=KEY, alg=jose.RS256, net=self.net) self.identifier = messages.Identifier( typ=messages.IDENTIFIER_FQDN, value='example.com') @@ -78,10 +71,295 @@ class ClientTest(unittest.TestCase): uri='https://www.letsencrypt-demo.org/acme/cert/1', cert_chain_uri='https://www.letsencrypt-demo.org/ca') - def _mock_post_get(self): + def test_register(self): + self.response.status_code = httplib.CREATED + self.response.json.return_value = self.regr.body.to_json() + self.response.headers['Location'] = self.regr.uri + self.response.links.update({ + 'next': {'url': self.regr.new_authzr_uri}, + 'terms-of-service': {'url': self.regr.terms_of_service}, + }) + + self.assertEqual(self.regr, self.client.register(self.contact)) + # TODO: test POST call arguments + + # TODO: split here and separate test + reg_wrong_key = self.regr.body.update(key=KEY2.public()) + self.response.json.return_value = reg_wrong_key.to_json() + self.assertRaises( + errors.UnexpectedUpdate, self.client.register, self.contact) + + def test_register_missing_next(self): + self.response.status_code = httplib.CREATED + self.assertRaises( + errors.ClientError, self.client.register, self.regr.body) + + def test_update_registration(self): + self.response.headers['Location'] = self.regr.uri + self.response.json.return_value = self.regr.body.to_json() + self.assertEqual(self.regr, self.client.update_registration(self.regr)) + + # TODO: split here and separate test + self.response.json.return_value = self.regr.body.update( + contact=()).to_json() + self.assertRaises( + errors.UnexpectedUpdate, self.client.update_registration, self.regr) + + def test_agree_to_tos(self): + self.client.update_registration = mock.Mock() + self.client.agree_to_tos(self.regr) + regr = self.client.update_registration.call_args[0][0] + self.assertEqual(self.regr.terms_of_service, regr.body.agreement) + + def test_request_challenges(self): + self.response.status_code = httplib.CREATED + self.response.headers['Location'] = self.authzr.uri + self.response.json.return_value = self.authz.to_json() + self.response.links = { + 'next': {'url': self.authzr.new_cert_uri}, + } + + self.client.request_challenges(self.identifier, self.authzr.uri) + # TODO: test POST call arguments + + # TODO: split here and separate test + self.response.json.return_value = self.authz.update( + identifier=self.identifier.update(value='foo')).to_json() + self.assertRaises( + errors.UnexpectedUpdate, self.client.request_challenges, + self.identifier, self.authzr.uri) + + def test_request_challenges_missing_next(self): + self.response.status_code = httplib.CREATED + self.assertRaises( + errors.ClientError, self.client.request_challenges, + self.identifier, self.regr) + + def test_request_domain_challenges(self): + self.client.request_challenges = mock.MagicMock() + self.assertEqual( + self.client.request_challenges(self.identifier), + self.client.request_domain_challenges('example.com', self.regr)) + + def test_answer_challenge(self): + self.response.links['up'] = {'url': self.challr.authzr_uri} + self.response.json.return_value = self.challr.body.to_json() + + chall_response = challenges.DNSResponse() + + self.client.answer_challenge(self.challr.body, chall_response) + + # TODO: split here and separate test + self.assertRaises(errors.UnexpectedUpdate, self.client.answer_challenge, + self.challr.body.update(uri='foo'), chall_response) + + def test_answer_challenge_missing_next(self): + self.assertRaises(errors.ClientError, self.client.answer_challenge, + self.challr.body, challenges.DNSResponse()) + + def test_retry_after_date(self): + self.response.headers['Retry-After'] = 'Fri, 31 Dec 1999 23:59:59 GMT' + self.assertEqual( + datetime.datetime(1999, 12, 31, 23, 59, 59), + self.client.retry_after(response=self.response, default=10)) + + @mock.patch('acme.client.datetime') + def test_retry_after_invalid(self, dt_mock): + dt_mock.datetime.now.return_value = datetime.datetime(2015, 3, 27) + dt_mock.timedelta = datetime.timedelta + + self.response.headers['Retry-After'] = 'foooo' + self.assertEqual( + datetime.datetime(2015, 3, 27, 0, 0, 10), + self.client.retry_after(response=self.response, default=10)) + + @mock.patch('acme.client.datetime') + def test_retry_after_seconds(self, dt_mock): + dt_mock.datetime.now.return_value = datetime.datetime(2015, 3, 27) + dt_mock.timedelta = datetime.timedelta + + self.response.headers['Retry-After'] = '50' + self.assertEqual( + datetime.datetime(2015, 3, 27, 0, 0, 50), + self.client.retry_after(response=self.response, default=10)) + + @mock.patch('acme.client.datetime') + def test_retry_after_missing(self, dt_mock): + dt_mock.datetime.now.return_value = datetime.datetime(2015, 3, 27) + dt_mock.timedelta = datetime.timedelta + + self.assertEqual( + datetime.datetime(2015, 3, 27, 0, 0, 10), + self.client.retry_after(response=self.response, default=10)) + + def test_poll(self): + self.response.json.return_value = self.authzr.body.to_json() + self.assertEqual((self.authzr, self.response), + self.client.poll(self.authzr)) + + # TODO: split here and separate test + self.response.json.return_value = self.authz.update( + identifier=self.identifier.update(value='foo')).to_json() + self.assertRaises( + errors.UnexpectedUpdate, self.client.poll, self.authzr) + + def test_request_issuance(self): + self.response.content = messages_test.CERT.as_der() + self.response.headers['Location'] = self.certr.uri + self.response.links['up'] = {'url': self.certr.cert_chain_uri} + self.assertEqual(self.certr, self.client.request_issuance( + messages_test.CSR, (self.authzr,))) + # TODO: check POST args + + def test_request_issuance_missing_up(self): + self.response.content = messages_test.CERT.as_der() + self.response.headers['Location'] = self.certr.uri + self.assertEqual( + self.certr.update(cert_chain_uri=None), + self.client.request_issuance(messages_test.CSR, (self.authzr,))) + + def test_request_issuance_missing_location(self): + self.assertRaises( + errors.ClientError, self.client.request_issuance, + messages_test.CSR, (self.authzr,)) + + @mock.patch('acme.client.datetime') + @mock.patch('acme.client.time') + def test_poll_and_request_issuance(self, time_mock, dt_mock): + # clock.dt | pylint: disable=no-member + clock = mock.MagicMock(dt=datetime.datetime(2015, 3, 27)) + + def sleep(seconds): + """increment clock""" + clock.dt += datetime.timedelta(seconds=seconds) + time_mock.sleep.side_effect = sleep + + def now(): + """return current clock value""" + return clock.dt + dt_mock.datetime.now.side_effect = now + dt_mock.timedelta = datetime.timedelta + + def poll(authzr): # pylint: disable=missing-docstring + # record poll start time based on the current clock value + authzr.times.append(clock.dt) + + # suppose it takes 2 seconds for server to produce the + # result, increment clock + clock.dt += datetime.timedelta(seconds=2) + + if not authzr.retries: # no more retries + done = mock.MagicMock(uri=authzr.uri, times=authzr.times) + done.body.status = messages.STATUS_VALID + return done, [] + + # response (2nd result tuple element) is reduced to only + # Retry-After header contents represented as integer + # seconds; authzr.retries is a list of Retry-After + # headers, head(retries) is peeled of as a current + # Retry-After header, and tail(retries) is persisted for + # later poll() calls + return (mock.MagicMock(retries=authzr.retries[1:], + uri=authzr.uri + '.', times=authzr.times), + authzr.retries[0]) + self.client.poll = mock.MagicMock(side_effect=poll) + + mintime = 7 + + def retry_after(response, default): # pylint: disable=missing-docstring + # check that poll_and_request_issuance correctly passes mintime + self.assertEqual(default, mintime) + return clock.dt + datetime.timedelta(seconds=response) + self.client.retry_after = mock.MagicMock(side_effect=retry_after) + + def request_issuance(csr, authzrs): # pylint: disable=missing-docstring + return csr, authzrs + self.client.request_issuance = mock.MagicMock( + side_effect=request_issuance) + + csr = mock.MagicMock() + authzrs = ( + mock.MagicMock(uri='a', times=[], retries=(8, 20, 30)), + mock.MagicMock(uri='b', times=[], retries=(5,)), + ) + + cert, updated_authzrs = self.client.poll_and_request_issuance( + csr, authzrs, mintime=mintime) + self.assertTrue(cert[0] is csr) + self.assertTrue(cert[1] is updated_authzrs) + self.assertEqual(updated_authzrs[0].uri, 'a...') + self.assertEqual(updated_authzrs[1].uri, 'b.') + self.assertEqual(updated_authzrs[0].times, [ + datetime.datetime(2015, 3, 27), + # a is scheduled for 10, but b is polling [9..11), so it + # will be picked up as soon as b is finished, without + # additional sleeping + datetime.datetime(2015, 3, 27, 0, 0, 11), + datetime.datetime(2015, 3, 27, 0, 0, 33), + datetime.datetime(2015, 3, 27, 0, 1, 5), + ]) + self.assertEqual(updated_authzrs[1].times, [ + datetime.datetime(2015, 3, 27, 0, 0, 2), + datetime.datetime(2015, 3, 27, 0, 0, 9), + ]) + self.assertEqual(clock.dt, datetime.datetime(2015, 3, 27, 0, 1, 7)) + + def test_check_cert(self): + self.response.headers['Location'] = self.certr.uri + self.response.content = messages_test.CERT.as_der() + self.assertEqual(self.certr.update(body=messages_test.CERT), + self.client.check_cert(self.certr)) + + # TODO: split here and separate test + self.response.headers['Location'] = 'foo' + self.assertRaises( + errors.UnexpectedUpdate, self.client.check_cert, self.certr) + + def test_check_cert_missing_location(self): + self.response.content = messages_test.CERT.as_der() + self.assertRaises( + errors.ClientError, self.client.check_cert, self.certr) + + def test_refresh(self): + self.client.check_cert = mock.MagicMock() + self.assertEqual( + self.client.check_cert(self.certr), self.client.refresh(self.certr)) + + def test_fetch_chain(self): # pylint: disable=protected-access - self.net._post = self.post - self.net._get = self.get + self.client._get_cert = mock.MagicMock() + self.client._get_cert.return_value = ("response", "certificate") + self.assertEqual(self.client._get_cert(self.certr.cert_chain_uri)[1], + self.client.fetch_chain(self.certr)) + + def test_fetch_chain_no_up_link(self): + self.assertTrue(self.client.fetch_chain(self.certr.update( + cert_chain_uri=None)) is None) + + def test_revoke(self): + self.client.revoke(self.certr.body) + self.net.post.assert_called_once_with(messages.Revocation.url( + self.client.new_reg_uri), mock.ANY) + + def test_revoke_bad_status_raises_error(self): + self.response.status_code = httplib.METHOD_NOT_ALLOWED + self.assertRaises(errors.ClientError, self.client.revoke, self.certr) + + +class ClientNetworkTest(unittest.TestCase): + """Tests for acme.client.ClientNetwork.""" + + def setUp(self): + self.verify_ssl = mock.MagicMock() + self.wrap_in_jws = mock.MagicMock(return_value=mock.sentinel.wrapped) + + from acme.client import ClientNetwork + self.net = ClientNetwork( + key=KEY, alg=jose.RS256, verify_ssl=self.verify_ssl) + + self.response = mock.MagicMock(ok=True, status_code=httplib.OK) + self.response.headers = {} + self.response.links = {} def test_init(self): self.assertTrue(self.net.verify_ssl is self.verify_ssl) @@ -139,384 +417,127 @@ class ClientTest(unittest.TestCase): self.response.json.side_effect = ValueError for response_ct in [self.net.JSON_CONTENT_TYPE, 'foo']: self.response.headers['Content-Type'] = response_ct - # pylint: disable=protected-access - self.net._check_response(self.response) + # pylint: disable=protected-access,no-value-for-parameter + self.assertEqual( + self.response, self.net._check_response(self.response)) def test_check_response_jobj(self): self.response.json.return_value = {} for response_ct in [self.net.JSON_CONTENT_TYPE, 'foo']: self.response.headers['Content-Type'] = response_ct + # pylint: disable=protected-access,no-value-for-parameter + self.assertEqual( + self.response, self.net._check_response(self.response)) + + @mock.patch('acme.client.requests') + def test_send_request(self, mock_requests): + mock_requests.request.return_value = self.response + # pylint: disable=protected-access + self.assertEqual(self.response, self.net._send_request( + 'HEAD', 'url', 'foo', bar='baz')) + mock_requests.request.assert_called_once_with( + 'HEAD', 'url', 'foo', verify=mock.ANY, bar='baz') + + @mock.patch('acme.client.requests') + def test_send_request_verify_ssl(self, mock_requests): + # pylint: disable=protected-access + for verify in True, False: + mock_requests.request.reset_mock() + mock_requests.request.return_value = self.response + self.net.verify_ssl = verify # pylint: disable=protected-access - self.net._check_response(self.response) + self.assertEqual( + self.response, self.net._send_request('GET', 'url')) + mock_requests.request.assert_called_once_with( + 'GET', 'url', verify=verify) @mock.patch('acme.client.requests') - def test_get_requests_error_passthrough(self, requests_mock): - requests_mock.exceptions = requests.exceptions - requests_mock.get.side_effect = requests.exceptions.RequestException + def test_requests_error_passthrough(self, mock_requests): + mock_requests.exceptions = requests.exceptions + mock_requests.request.side_effect = requests.exceptions.RequestException # pylint: disable=protected-access - self.assertRaises(errors.ClientError, self.net._get, 'uri') + self.assertRaises(requests.exceptions.RequestException, + self.net._send_request, 'GET', 'uri') + + +class ClientNetworkWithMockedResponseTest(unittest.TestCase): + """Tests for acme.client.ClientNetwork which mock out response.""" + # pylint: disable=too-many-instance-attributes + + def setUp(self): + from acme.client import ClientNetwork + self.net = ClientNetwork(key=None, alg=None) + + self.response = mock.MagicMock(ok=True, status_code=httplib.OK) + self.response.headers = {} + self.response.links = {} + self.checked_response = mock.MagicMock() + self.obj = mock.MagicMock() + self.wrapped_obj = mock.MagicMock() + self.content_type = mock.sentinel.content_type + + self.all_nonces = [jose.b64encode('Nonce'), jose.b64encode('Nonce2')] + self.available_nonces = self.all_nonces[:] + def send_request(*args, **kwargs): + # pylint: disable=unused-argument,missing-docstring + if self.available_nonces: + self.response.headers = { + self.net.REPLAY_NONCE_HEADER: self.available_nonces.pop()} + else: + self.response.headers = {} + return self.response - @mock.patch('acme.client.requests') - def test_get(self, requests_mock): # pylint: disable=protected-access - self.net._check_response = mock.MagicMock() - self.net._get('uri', content_type='ct') - self.net._check_response.assert_called_once_with( - requests_mock.get('uri'), content_type='ct') + self.net._send_request = self.send_request = mock.MagicMock( + side_effect=send_request) + self.net._check_response = self.check_response + self.net._wrap_in_jws = mock.MagicMock(return_value=self.wrapped_obj) - def _mock_wrap_in_jws(self): + def check_response(self, response, content_type): + # pylint: disable=missing-docstring + self.assertEqual(self.response, response) + self.assertEqual(self.content_type, content_type) + return self.checked_response + + def test_head(self): + self.assertEqual(self.response, self.net.head('url', 'foo', bar='baz')) + self.send_request.assert_called_once('HEAD', 'url', 'foo', bar='baz') + + def test_get(self): + self.assertEqual(self.checked_response, self.net.get( + 'url', content_type=self.content_type, bar='baz')) + self.send_request.assert_called_once_with('GET', 'url', bar='baz') + + def test_post(self): # pylint: disable=protected-access - self.net._wrap_in_jws = self.wrap_in_jws + self.assertEqual(self.checked_response, self.net.post( + 'uri', self.obj, content_type=self.content_type)) + self.net._wrap_in_jws.assert_called_once_with( + self.obj, self.all_nonces.pop()) - @mock.patch('acme.client.requests') - def test_post_requests_error_passthrough(self, requests_mock): - requests_mock.exceptions = requests.exceptions - requests_mock.post.side_effect = requests.exceptions.RequestException - # pylint: disable=protected-access - self._mock_wrap_in_jws() - self.assertRaises( - errors.ClientError, self.net._post, 'uri', mock.sentinel.obj) + assert not self.available_nonces + self.assertRaises(errors.MissingNonce, self.net.post, + 'uri', self.obj, content_type=self.content_type) + self.net._wrap_in_jws.assert_called_with( + self.obj, self.all_nonces.pop()) - @mock.patch('acme.client.requests') - def test_post(self, requests_mock): - # pylint: disable=protected-access - self.net._check_response = mock.MagicMock() - self._mock_wrap_in_jws() - requests_mock.post().headers = { - self.net.REPLAY_NONCE_HEADER: self.nonce} - self.net._post('uri', mock.sentinel.obj, content_type='ct') - self.net._check_response.assert_called_once_with( - requests_mock.post('uri', mock.sentinel.wrapped), content_type='ct') + def test_post_wrong_initial_nonce(self): # HEAD + self.available_nonces = ['f', jose.b64encode('good')] + self.assertRaises(errors.BadNonce, self.net.post, 'uri', + self.obj, content_type=self.content_type) - @mock.patch('acme.client.requests') - def test_post_replay_nonce_handling(self, requests_mock): - # pylint: disable=protected-access - self.net._check_response = mock.MagicMock() - self._mock_wrap_in_jws() + def test_post_wrong_post_response_nonce(self): + self.available_nonces = [jose.b64encode('good'), 'f'] + self.assertRaises(errors.BadNonce, self.net.post, 'uri', + self.obj, content_type=self.content_type) - self.net._nonces.clear() - self.assertRaises( - errors.ClientError, self.net._post, 'uri', mock.sentinel.obj) - - nonce2 = jose.b64encode('Nonce2') - requests_mock.head('uri').headers = { - self.net.REPLAY_NONCE_HEADER: nonce2} - requests_mock.post('uri').headers = { - self.net.REPLAY_NONCE_HEADER: self.nonce} - - self.net._post('uri', mock.sentinel.obj) - - requests_mock.head.assert_called_with('uri') - self.wrap_in_jws.assert_called_once_with(mock.sentinel.obj, nonce2) - self.assertEqual(self.net._nonces, set([self.nonce])) - - # wrong nonce - requests_mock.post('uri').headers = {self.net.REPLAY_NONCE_HEADER: 'F'} - self.assertRaises( - errors.ClientError, self.net._post, 'uri', mock.sentinel.obj) - - @mock.patch('acme.client.requests') - def test_get_post_verify_ssl(self, requests_mock): - # pylint: disable=protected-access - self._mock_wrap_in_jws() - self.net._check_response = mock.MagicMock() - - for verify_ssl in [True, False]: - self.net.verify_ssl = verify_ssl - self.net._get('uri') - self.net._nonces.add('N') - requests_mock.post().headers = { - self.net.REPLAY_NONCE_HEADER: self.nonce} - self.net._post('uri', mock.sentinel.obj) - requests_mock.get.assert_called_once_with('uri', verify=verify_ssl) - requests_mock.post.assert_called_with( - 'uri', data=mock.sentinel.wrapped, verify=verify_ssl) - requests_mock.reset_mock() - - def test_register(self): - self.response.status_code = httplib.CREATED - self.response.json.return_value = self.regr.body.to_json() - self.response.headers['Location'] = self.regr.uri - self.response.links.update({ - 'next': {'url': self.regr.new_authzr_uri}, - 'terms-of-service': {'url': self.regr.terms_of_service}, - }) - - self._mock_post_get() - self.assertEqual(self.regr, self.net.register(self.contact)) - # TODO: test POST call arguments - - # TODO: split here and separate test - reg_wrong_key = self.regr.body.update(key=KEY2.public()) - self.response.json.return_value = reg_wrong_key.to_json() - self.assertRaises( - errors.UnexpectedUpdate, self.net.register, self.contact) - - def test_register_missing_next(self): - self.response.status_code = httplib.CREATED - self._mock_post_get() - self.assertRaises( - errors.ClientError, self.net.register, self.regr.body) - - def test_update_registration(self): - self.response.headers['Location'] = self.regr.uri - self.response.json.return_value = self.regr.body.to_json() - self._mock_post_get() - self.assertEqual(self.regr, self.net.update_registration(self.regr)) - - # TODO: split here and separate test - self.response.json.return_value = self.regr.body.update( - contact=()).to_json() - self.assertRaises( - errors.UnexpectedUpdate, self.net.update_registration, self.regr) - - def test_agree_to_tos(self): - self.net.update_registration = mock.Mock() - self.net.agree_to_tos(self.regr) - regr = self.net.update_registration.call_args[0][0] - self.assertEqual(self.regr.terms_of_service, regr.body.agreement) - - def test_request_challenges(self): - self.response.status_code = httplib.CREATED - self.response.headers['Location'] = self.authzr.uri - self.response.json.return_value = self.authz.to_json() - self.response.links = { - 'next': {'url': self.authzr.new_cert_uri}, - } - - self._mock_post_get() - self.net.request_challenges(self.identifier, self.authzr.uri) - # TODO: test POST call arguments - - # TODO: split here and separate test - self.response.json.return_value = self.authz.update( - identifier=self.identifier.update(value='foo')).to_json() - self.assertRaises(errors.UnexpectedUpdate, self.net.request_challenges, - self.identifier, self.authzr.uri) - - def test_request_challenges_missing_next(self): - self.response.status_code = httplib.CREATED - self._mock_post_get() - self.assertRaises( - errors.ClientError, self.net.request_challenges, - self.identifier, self.regr) - - def test_request_domain_challenges(self): - self.net.request_challenges = mock.MagicMock() - self.assertEqual( - self.net.request_challenges(self.identifier), - self.net.request_domain_challenges('example.com', self.regr)) - - def test_answer_challenge(self): - self.response.links['up'] = {'url': self.challr.authzr_uri} - self.response.json.return_value = self.challr.body.to_json() - - chall_response = challenges.DNSResponse() - - self._mock_post_get() - self.net.answer_challenge(self.challr.body, chall_response) - - # TODO: split here and separate test - self.assertRaises(errors.UnexpectedUpdate, self.net.answer_challenge, - self.challr.body.update(uri='foo'), chall_response) - - def test_answer_challenge_missing_next(self): - self._mock_post_get() - self.assertRaises(errors.ClientError, self.net.answer_challenge, - self.challr.body, challenges.DNSResponse()) - - def test_retry_after_date(self): - self.response.headers['Retry-After'] = 'Fri, 31 Dec 1999 23:59:59 GMT' - self.assertEqual( - datetime.datetime(1999, 12, 31, 23, 59, 59), - self.net.retry_after(response=self.response, default=10)) - - @mock.patch('acme.client.datetime') - def test_retry_after_invalid(self, dt_mock): - dt_mock.datetime.now.return_value = datetime.datetime(2015, 3, 27) - dt_mock.timedelta = datetime.timedelta - - self.response.headers['Retry-After'] = 'foooo' - self.assertEqual( - datetime.datetime(2015, 3, 27, 0, 0, 10), - self.net.retry_after(response=self.response, default=10)) - - @mock.patch('acme.client.datetime') - def test_retry_after_seconds(self, dt_mock): - dt_mock.datetime.now.return_value = datetime.datetime(2015, 3, 27) - dt_mock.timedelta = datetime.timedelta - - self.response.headers['Retry-After'] = '50' - self.assertEqual( - datetime.datetime(2015, 3, 27, 0, 0, 50), - self.net.retry_after(response=self.response, default=10)) - - @mock.patch('acme.client.datetime') - def test_retry_after_missing(self, dt_mock): - dt_mock.datetime.now.return_value = datetime.datetime(2015, 3, 27) - dt_mock.timedelta = datetime.timedelta - - self.assertEqual( - datetime.datetime(2015, 3, 27, 0, 0, 10), - self.net.retry_after(response=self.response, default=10)) - - def test_poll(self): - self.response.json.return_value = self.authzr.body.to_json() - self._mock_post_get() - self.assertEqual((self.authzr, self.response), - self.net.poll(self.authzr)) - - # TODO: split here and separate test - self.response.json.return_value = self.authz.update( - identifier=self.identifier.update(value='foo')).to_json() - self.assertRaises(errors.UnexpectedUpdate, self.net.poll, self.authzr) - - def test_request_issuance(self): - self.response.content = messages_test.CERT.as_der() - self.response.headers['Location'] = self.certr.uri - self.response.links['up'] = {'url': self.certr.cert_chain_uri} - self._mock_post_get() - self.assertEqual(self.certr, self.net.request_issuance( - messages_test.CSR, (self.authzr,))) - # TODO: check POST args - - def test_request_issuance_missing_up(self): - self.response.content = messages_test.CERT.as_der() - self.response.headers['Location'] = self.certr.uri - self._mock_post_get() - self.assertEqual( - self.certr.update(cert_chain_uri=None), - self.net.request_issuance(messages_test.CSR, (self.authzr,))) - - def test_request_issuance_missing_location(self): - self._mock_post_get() - self.assertRaises( - errors.ClientError, self.net.request_issuance, - messages_test.CSR, (self.authzr,)) - - @mock.patch('acme.client.datetime') - @mock.patch('acme.client.time') - def test_poll_and_request_issuance(self, time_mock, dt_mock): - # clock.dt | pylint: disable=no-member - clock = mock.MagicMock(dt=datetime.datetime(2015, 3, 27)) - - def sleep(seconds): - """increment clock""" - clock.dt += datetime.timedelta(seconds=seconds) - time_mock.sleep.side_effect = sleep - - def now(): - """return current clock value""" - return clock.dt - dt_mock.datetime.now.side_effect = now - dt_mock.timedelta = datetime.timedelta - - def poll(authzr): # pylint: disable=missing-docstring - # record poll start time based on the current clock value - authzr.times.append(clock.dt) - - # suppose it takes 2 seconds for server to produce the - # result, increment clock - clock.dt += datetime.timedelta(seconds=2) - - if not authzr.retries: # no more retries - done = mock.MagicMock(uri=authzr.uri, times=authzr.times) - done.body.status = messages.STATUS_VALID - return done, [] - - # response (2nd result tuple element) is reduced to only - # Retry-After header contents represented as integer - # seconds; authzr.retries is a list of Retry-After - # headers, head(retries) is peeled of as a current - # Retry-After header, and tail(retries) is persisted for - # later poll() calls - return (mock.MagicMock(retries=authzr.retries[1:], - uri=authzr.uri + '.', times=authzr.times), - authzr.retries[0]) - self.net.poll = mock.MagicMock(side_effect=poll) - - mintime = 7 - - def retry_after(response, default): # pylint: disable=missing-docstring - # check that poll_and_request_issuance correctly passes mintime - self.assertEqual(default, mintime) - return clock.dt + datetime.timedelta(seconds=response) - self.net.retry_after = mock.MagicMock(side_effect=retry_after) - - def request_issuance(csr, authzrs): # pylint: disable=missing-docstring - return csr, authzrs - self.net.request_issuance = mock.MagicMock(side_effect=request_issuance) - - csr = mock.MagicMock() - authzrs = ( - mock.MagicMock(uri='a', times=[], retries=(8, 20, 30)), - mock.MagicMock(uri='b', times=[], retries=(5,)), - ) - - cert, updated_authzrs = self.net.poll_and_request_issuance( - csr, authzrs, mintime=mintime) - self.assertTrue(cert[0] is csr) - self.assertTrue(cert[1] is updated_authzrs) - self.assertEqual(updated_authzrs[0].uri, 'a...') - self.assertEqual(updated_authzrs[1].uri, 'b.') - self.assertEqual(updated_authzrs[0].times, [ - datetime.datetime(2015, 3, 27), - # a is scheduled for 10, but b is polling [9..11), so it - # will be picked up as soon as b is finished, without - # additional sleeping - datetime.datetime(2015, 3, 27, 0, 0, 11), - datetime.datetime(2015, 3, 27, 0, 0, 33), - datetime.datetime(2015, 3, 27, 0, 1, 5), - ]) - self.assertEqual(updated_authzrs[1].times, [ - datetime.datetime(2015, 3, 27, 0, 0, 2), - datetime.datetime(2015, 3, 27, 0, 0, 9), - ]) - self.assertEqual(clock.dt, datetime.datetime(2015, 3, 27, 0, 1, 7)) - - def test_check_cert(self): - self.response.headers['Location'] = self.certr.uri - self.response.content = messages_test.CERT.as_der() - self._mock_post_get() - self.assertEqual(self.certr.update(body=messages_test.CERT), - self.net.check_cert(self.certr)) - - # TODO: split here and separate test - self.response.headers['Location'] = 'foo' - self.assertRaises( - errors.UnexpectedUpdate, self.net.check_cert, self.certr) - - def test_check_cert_missing_location(self): - self.response.content = messages_test.CERT.as_der() - self._mock_post_get() - self.assertRaises(errors.ClientError, self.net.check_cert, self.certr) - - def test_refresh(self): - self.net.check_cert = mock.MagicMock() - self.assertEqual( - self.net.check_cert(self.certr), self.net.refresh(self.certr)) - - def test_fetch_chain(self): - # pylint: disable=protected-access - self.net._get_cert = mock.MagicMock() - self.net._get_cert.return_value = ("response", "certificate") - self.assertEqual(self.net._get_cert(self.certr.cert_chain_uri)[1], - self.net.fetch_chain(self.certr)) - - def test_fetch_chain_no_up_link(self): - self.assertTrue(self.net.fetch_chain(self.certr.update( - cert_chain_uri=None)) is None) - - def test_revoke(self): - self._mock_post_get() - self.net.revoke(self.certr.body) - self.post.assert_called_once_with(messages.Revocation.url( - self.net.new_reg_uri), mock.ANY) - - def test_revoke_bad_status_raises_error(self): - self.response.status_code = httplib.METHOD_NOT_ALLOWED - self._mock_post_get() - self.assertRaises(errors.ClientError, self.net.revoke, self.certr) + def test_head_get_post_error_passthrough(self): + self.send_request.side_effect = requests.exceptions.RequestException + for method in self.net.head, self.net.get: + self.assertRaises( + requests.exceptions.RequestException, method, 'GET', 'uri') + self.assertRaises(requests.exceptions.RequestException, + self.net.post, 'uri', obj=self.obj) if __name__ == '__main__': diff --git a/acme/errors.py b/acme/errors.py index 5046d7aee..9a96ec43a 100644 --- a/acme/errors.py +++ b/acme/errors.py @@ -5,11 +5,49 @@ from acme.jose import errors as jose_errors class Error(Exception): """Generic ACME error.""" + class SchemaValidationError(jose_errors.DeserializationError): """JSON schema ACME object validation error.""" + class ClientError(Error): """Network error.""" + class UnexpectedUpdate(ClientError): - """Unexpected update.""" + """Unexpected update error.""" + + +class NonceError(ClientError): + """Server response nonce error.""" + + +class BadNonce(NonceError): + """Bad nonce error.""" + def __init__(self, nonce, error, *args, **kwargs): + super(BadNonce, self).__init__(*args, **kwargs) + self.nonce = nonce + self.error = error + + def __str__(self): + return 'Invalid nonce ({0!r}): {1}'.format(self.nonce, self.error) + + +class MissingNonce(NonceError): + """Missing nonce error. + + According to the specification an "ACME server MUST include an + Replay-Nonce header field in each successful response to a POST it + provides to a client (...)". + + :ivar requests.Response response: HTTP Response + + """ + def __init__(self, response, *args, **kwargs): + super(MissingNonce, self).__init__(*args, **kwargs) + self.response = response + + def __str__(self): + return ('Server {0} response did not include a replay ' + 'nonce, headers: {1}'.format( + self.response.request.method, self.response.headers)) diff --git a/acme/errors_test.py b/acme/errors_test.py new file mode 100644 index 000000000..3790d91ed --- /dev/null +++ b/acme/errors_test.py @@ -0,0 +1,33 @@ +"""Tests for acme.errors.""" +import unittest + +import mock + + +class BadNonceTest(unittest.TestCase): + """Tests for acme.errors.BadNonce.""" + + def setUp(self): + from acme.errors import BadNonce + self.error = BadNonce(nonce="xxx", error="error") + + def test_str(self): + self.assertEqual("Invalid nonce ('xxx'): error", str(self.error)) + + +class MissingNonceTest(unittest.TestCase): + """Tests for acme.errors.MissingNonce.""" + + def setUp(self): + from acme.errors import MissingNonce + self.response = mock.MagicMock(headers={}) + self.response.request.method = 'FOO' + self.error = MissingNonce(self.response) + + def test_str(self): + self.assertTrue("FOO" in str(self.error)) + self.assertTrue("{}" in str(self.error)) + + +if __name__ == "__main__": + unittest.main() # pragma: no cover