1
0
mirror of https://github.com/certbot/certbot.git synced 2026-01-23 07:20:55 +03:00

Merge pull request #586 from kuba/acme-client

acme.client bug fixes and refactor
This commit is contained in:
James Kasten
2015-07-06 16:42:10 -07:00
4 changed files with 635 additions and 528 deletions

View File

@@ -32,150 +32,18 @@ class Client(object): # pylint: disable=too-many-instance-attributes
:ivar key: `.JWK` (private)
:ivar alg: `.JWASignature`
:ivar bool verify_ssl: Verify SSL certificates?
:ivar .ClientNetwork net: Client network. Useful for testing. If not
supplied, it will be initialized using `key`, `alg` and
`verify_ssl`.
"""
DER_CONTENT_TYPE = 'application/pkix-cert'
JSON_CONTENT_TYPE = 'application/json'
JSON_ERROR_CONTENT_TYPE = 'application/problem+json'
REPLAY_NONCE_HEADER = 'Replay-Nonce'
def __init__(self, new_reg_uri, key, alg=jose.RS256, verify_ssl=True):
def __init__(self, new_reg_uri, key, alg=jose.RS256,
verify_ssl=True, net=None):
self.new_reg_uri = new_reg_uri
self.key = key
self.alg = alg
self.verify_ssl = verify_ssl
self._nonces = set()
def _wrap_in_jws(self, obj, nonce):
"""Wrap `JSONDeSerializable` object in JWS.
.. todo:: Implement ``acmePath``.
:param JSONDeSerializable obj:
:rtype: `.JWS`
"""
dumps = obj.json_dumps()
logger.debug('Serialized JSON: %s', dumps)
return jws.JWS.sign(
payload=dumps, key=self.key, alg=self.alg, nonce=nonce).json_dumps()
@classmethod
def _check_response(cls, response, content_type=None):
"""Check response content and its type.
.. note::
Checking is not strict: wrong server response ``Content-Type``
HTTP header is ignored if response is an expected JSON object
(c.f. Boulder #56).
:param str content_type: Expected Content-Type response header.
If JSON is expected and not present in server response, this
function will raise an error. Otherwise, wrong Content-Type
is ignored, but logged.
:raises .messages.Error: If server response body
carries HTTP Problem (draft-ietf-appsawg-http-problem-00).
:raises .ClientError: In case of other networking errors.
"""
logger.debug('Received response %s (headers: %s): %r',
response, response.headers, response.content)
response_ct = response.headers.get('Content-Type')
try:
# TODO: response.json() is called twice, once here, and
# once in _get and _post clients
jobj = response.json()
except ValueError as error:
jobj = None
if not response.ok:
if jobj is not None:
if response_ct != cls.JSON_ERROR_CONTENT_TYPE:
logger.debug(
'Ignoring wrong Content-Type (%r) for JSON Error',
response_ct)
try:
raise messages.Error.from_json(jobj)
except jose.DeserializationError as error:
# Couldn't deserialize JSON object
raise errors.ClientError((response, error))
else:
# response is not JSON object
raise errors.ClientError(response)
else:
if jobj is not None and response_ct != cls.JSON_CONTENT_TYPE:
logger.debug(
'Ignoring wrong Content-Type (%r) for JSON decodable '
'response', response_ct)
if content_type == cls.JSON_CONTENT_TYPE and jobj is None:
raise errors.ClientError(
'Unexpected response Content-Type: {0}'.format(response_ct))
def _get(self, uri, content_type=JSON_CONTENT_TYPE, **kwargs):
"""Send GET request.
:raises .ClientError:
:returns: HTTP Response
:rtype: `requests.Response`
"""
logger.debug('Sending GET request to %s', uri)
kwargs.setdefault('verify', self.verify_ssl)
try:
response = requests.get(uri, **kwargs)
except requests.exceptions.RequestException as error:
raise errors.ClientError(error)
self._check_response(response, content_type=content_type)
return response
def _add_nonce(self, response):
if self.REPLAY_NONCE_HEADER in response.headers:
nonce = response.headers[self.REPLAY_NONCE_HEADER]
error = jws.Header.validate_nonce(nonce)
if error is None:
logger.debug('Storing nonce: %r', nonce)
self._nonces.add(nonce)
else:
raise errors.ClientError('Invalid nonce ({0}): {1}'.format(
nonce, error))
else:
raise errors.ClientError(
'Server {0} response did not include a replay nonce'.format(
response.request.method))
def _get_nonce(self, uri):
if not self._nonces:
logger.debug('Requesting fresh nonce by sending HEAD to %s', uri)
self._add_nonce(requests.head(uri))
return self._nonces.pop()
def _post(self, uri, obj, content_type=JSON_CONTENT_TYPE, **kwargs):
"""Send POST data.
:param JSONDeSerializable obj: Will be wrapped in JWS.
:param str content_type: Expected ``Content-Type``, fails if not set.
:raises acme.messages.ClientError:
:returns: HTTP Response
:rtype: `requests.Response`
"""
data = self._wrap_in_jws(obj, self._get_nonce(uri))
logger.debug('Sending POST data to %s: %s', uri, data)
kwargs.setdefault('verify', self.verify_ssl)
try:
response = requests.post(uri, data=data, **kwargs)
except requests.exceptions.RequestException as error:
raise errors.ClientError(error)
self._add_nonce(response)
self._check_response(response, content_type=content_type)
return response
self.net = ClientNetwork(key, alg, verify_ssl) if net is None else net
@classmethod
def _regr_from_response(cls, response, uri=None, new_authzr_uri=None,
@@ -211,7 +79,7 @@ class Client(object): # pylint: disable=too-many-instance-attributes
"""
new_reg = messages.Registration(contact=contact)
response = self._post(self.new_reg_uri, new_reg)
response = self.net.post(self.new_reg_uri, new_reg)
assert response.status_code == httplib.CREATED # TODO: handle errors
regr = self._regr_from_response(response)
@@ -230,7 +98,7 @@ class Client(object): # pylint: disable=too-many-instance-attributes
:rtype: `.RegistrationResource`
"""
response = self._post(regr.uri, regr.body)
response = self.net.post(regr.uri, regr.body)
# TODO: Boulder returns httplib.ACCEPTED
#assert response.status_code == httplib.OK
@@ -290,7 +158,7 @@ class Client(object): # pylint: disable=too-many-instance-attributes
"""
new_authz = messages.Authorization(identifier=identifier)
response = self._post(new_authzr_uri, new_authz)
response = self.net.post(new_authzr_uri, new_authz)
assert response.status_code == httplib.CREATED # TODO: handle errors
return self._authzr_from_response(response, identifier)
@@ -326,7 +194,7 @@ class Client(object): # pylint: disable=too-many-instance-attributes
:raises .UnexpectedUpdate:
"""
response = self._post(challb.uri, response)
response = self.net.post(challb.uri, response)
try:
authzr_uri = response.links['up']['url']
except KeyError:
@@ -377,7 +245,7 @@ class Client(object): # pylint: disable=too-many-instance-attributes
:rtype: (`.AuthorizationResource`, `requests.Response`)
"""
response = self._get(authzr.uri)
response = self.net.get(authzr.uri)
updated_authzr = self._authzr_from_response(
response, authzr.body.identifier, authzr.uri, authzr.new_cert_uri)
# TODO: check and raise UnexpectedUpdate
@@ -403,7 +271,7 @@ class Client(object): # pylint: disable=too-many-instance-attributes
csr=csr, authorizations=tuple(authzr.uri for authzr in authzrs))
content_type = self.DER_CONTENT_TYPE # TODO: add 'cert_type 'argument
response = self._post(
response = self.net.post(
authzrs[0].new_cert_uri, # TODO: acme-spec #90
req,
content_type=content_type,
@@ -488,8 +356,8 @@ class Client(object): # pylint: disable=too-many-instance-attributes
"""
content_type = self.DER_CONTENT_TYPE # TODO: make it a param
response = self._get(uri, headers={'Accept': content_type},
content_type=content_type)
response = self.net.get(uri, headers={'Accept': content_type},
content_type=content_type)
return response, jose.ComparableX509(
M2Crypto.X509.load_cert_der_string(response.content))
@@ -551,8 +419,155 @@ class Client(object): # pylint: disable=too-many-instance-attributes
:raises .ClientError: If revocation is unsuccessful.
"""
response = self._post(messages.Revocation.url(self.new_reg_uri),
messages.Revocation(certificate=cert))
response = self.net.post(messages.Revocation.url(self.new_reg_uri),
messages.Revocation(certificate=cert))
if response.status_code != httplib.OK:
raise errors.ClientError(
'Successful revocation must return HTTP OK status')
class ClientNetwork(object):
"""Client network."""
JSON_CONTENT_TYPE = 'application/json'
JSON_ERROR_CONTENT_TYPE = 'application/problem+json'
REPLAY_NONCE_HEADER = 'Replay-Nonce'
def __init__(self, key, alg=jose.RS256, verify_ssl=True):
self.key = key
self.alg = alg
self.verify_ssl = verify_ssl
self._nonces = set()
def _wrap_in_jws(self, obj, nonce):
"""Wrap `JSONDeSerializable` object in JWS.
.. todo:: Implement ``acmePath``.
:param JSONDeSerializable obj:
:rtype: `.JWS`
"""
dumps = obj.json_dumps()
logger.debug('Serialized JSON: %s', dumps)
return jws.JWS.sign(
payload=dumps, key=self.key, alg=self.alg, nonce=nonce).json_dumps()
@classmethod
def _check_response(cls, response, content_type=None):
"""Check response content and its type.
.. note::
Checking is not strict: wrong server response ``Content-Type``
HTTP header is ignored if response is an expected JSON object
(c.f. Boulder #56).
:param str content_type: Expected Content-Type response header.
If JSON is expected and not present in server response, this
function will raise an error. Otherwise, wrong Content-Type
is ignored, but logged.
:raises .messages.Error: If server response body
carries HTTP Problem (draft-ietf-appsawg-http-problem-00).
:raises .ClientError: In case of other networking errors.
"""
logger.debug('Received response %s (headers: %s): %r',
response, response.headers, response.content)
response_ct = response.headers.get('Content-Type')
try:
# TODO: response.json() is called twice, once here, and
# once in _get and _post clients
jobj = response.json()
except ValueError as error:
jobj = None
if not response.ok:
if jobj is not None:
if response_ct != cls.JSON_ERROR_CONTENT_TYPE:
logger.debug(
'Ignoring wrong Content-Type (%r) for JSON Error',
response_ct)
try:
raise messages.Error.from_json(jobj)
except jose.DeserializationError as error:
# Couldn't deserialize JSON object
raise errors.ClientError((response, error))
else:
# response is not JSON object
raise errors.ClientError(response)
else:
if jobj is not None and response_ct != cls.JSON_CONTENT_TYPE:
logger.debug(
'Ignoring wrong Content-Type (%r) for JSON decodable '
'response', response_ct)
if content_type == cls.JSON_CONTENT_TYPE and jobj is None:
raise errors.ClientError(
'Unexpected response Content-Type: {0}'.format(response_ct))
return response
def _send_request(self, method, url, *args, **kwargs):
"""Send HTTP request.
Makes sure that `verify_ssl` is respected. Logs request and
response (with headers). For allowed parameters please see
`requests.request`.
:param str method: method for the new `requests.Request` object
:param str url: URL for the new `requests.Request` object
:raises requests.exceptions.RequestException: in case of any problems
:returns: HTTP Response
:rtype: `requests.Response`
"""
logging.debug('Sending %s request to %s', method, url)
kwargs['verify'] = self.verify_ssl
response = requests.request(method, url, *args, **kwargs)
logging.debug('Received %s. Headers: %s. Content: %r',
response, response.headers, response.content)
return response
def head(self, *args, **kwargs):
"""Send HEAD request without checking the response.
Note, that `_check_response` is not called, as it is expected
that status code other than successfuly 2xx will be returned, or
messages2.Error will be raised by the server.
"""
return self._send_request('HEAD', *args, **kwargs)
def get(self, url, content_type=JSON_CONTENT_TYPE, **kwargs):
"""Send GET request and check response."""
return self._check_response(
self._send_request('GET', url, **kwargs), content_type=content_type)
def _add_nonce(self, response):
if self.REPLAY_NONCE_HEADER in response.headers:
nonce = response.headers[self.REPLAY_NONCE_HEADER]
error = jws.Header.validate_nonce(nonce)
if error is None:
logger.debug('Storing nonce: %r', nonce)
self._nonces.add(nonce)
else:
raise errors.BadNonce(nonce, error)
else:
raise errors.MissingNonce(response)
def _get_nonce(self, url):
if not self._nonces:
logging.debug('Requesting fresh nonce')
self._add_nonce(self.head(url))
return self._nonces.pop()
def post(self, url, obj, content_type=JSON_CONTENT_TYPE, **kwargs):
"""POST object wrapped in `.JWS` and check response."""
data = self._wrap_in_jws(obj, self._get_nonce(url))
response = self._send_request('POST', url, data=data, **kwargs)
self._add_nonce(response)
return self._check_response(response, content_type=content_type)

View File

@@ -23,27 +23,20 @@ KEY2 = jose.JWKRSA.load(pkg_resources.resource_string(
class ClientTest(unittest.TestCase):
"""Tests for acme.client.Client."""
"""Tests for acme.client.Client."""
# pylint: disable=too-many-instance-attributes,too-many-public-methods
def setUp(self):
self.verify_ssl = mock.MagicMock()
self.wrap_in_jws = mock.MagicMock(return_value=mock.sentinel.wrapped)
self.response = mock.MagicMock(
ok=True, status_code=httplib.OK, headers={}, links={})
self.net = mock.MagicMock()
self.net.post.return_value = self.response
self.net.get.return_value = self.response
from acme.client import Client
self.net = Client(
self.client = Client(
new_reg_uri='https://www.letsencrypt-demo.org/acme/new-reg',
key=KEY, alg=jose.RS256, verify_ssl=self.verify_ssl)
self.nonce = jose.b64encode('Nonce')
self.net._nonces.add(self.nonce) # pylint: disable=protected-access
self.response = mock.MagicMock(ok=True, status_code=httplib.OK)
self.response.headers = {}
self.response.links = {}
self.post = mock.MagicMock(return_value=self.response)
self.get = mock.MagicMock(return_value=self.response)
key=KEY, alg=jose.RS256, net=self.net)
self.identifier = messages.Identifier(
typ=messages.IDENTIFIER_FQDN, value='example.com')
@@ -78,10 +71,295 @@ class ClientTest(unittest.TestCase):
uri='https://www.letsencrypt-demo.org/acme/cert/1',
cert_chain_uri='https://www.letsencrypt-demo.org/ca')
def _mock_post_get(self):
def test_register(self):
self.response.status_code = httplib.CREATED
self.response.json.return_value = self.regr.body.to_json()
self.response.headers['Location'] = self.regr.uri
self.response.links.update({
'next': {'url': self.regr.new_authzr_uri},
'terms-of-service': {'url': self.regr.terms_of_service},
})
self.assertEqual(self.regr, self.client.register(self.contact))
# TODO: test POST call arguments
# TODO: split here and separate test
reg_wrong_key = self.regr.body.update(key=KEY2.public())
self.response.json.return_value = reg_wrong_key.to_json()
self.assertRaises(
errors.UnexpectedUpdate, self.client.register, self.contact)
def test_register_missing_next(self):
self.response.status_code = httplib.CREATED
self.assertRaises(
errors.ClientError, self.client.register, self.regr.body)
def test_update_registration(self):
self.response.headers['Location'] = self.regr.uri
self.response.json.return_value = self.regr.body.to_json()
self.assertEqual(self.regr, self.client.update_registration(self.regr))
# TODO: split here and separate test
self.response.json.return_value = self.regr.body.update(
contact=()).to_json()
self.assertRaises(
errors.UnexpectedUpdate, self.client.update_registration, self.regr)
def test_agree_to_tos(self):
self.client.update_registration = mock.Mock()
self.client.agree_to_tos(self.regr)
regr = self.client.update_registration.call_args[0][0]
self.assertEqual(self.regr.terms_of_service, regr.body.agreement)
def test_request_challenges(self):
self.response.status_code = httplib.CREATED
self.response.headers['Location'] = self.authzr.uri
self.response.json.return_value = self.authz.to_json()
self.response.links = {
'next': {'url': self.authzr.new_cert_uri},
}
self.client.request_challenges(self.identifier, self.authzr.uri)
# TODO: test POST call arguments
# TODO: split here and separate test
self.response.json.return_value = self.authz.update(
identifier=self.identifier.update(value='foo')).to_json()
self.assertRaises(
errors.UnexpectedUpdate, self.client.request_challenges,
self.identifier, self.authzr.uri)
def test_request_challenges_missing_next(self):
self.response.status_code = httplib.CREATED
self.assertRaises(
errors.ClientError, self.client.request_challenges,
self.identifier, self.regr)
def test_request_domain_challenges(self):
self.client.request_challenges = mock.MagicMock()
self.assertEqual(
self.client.request_challenges(self.identifier),
self.client.request_domain_challenges('example.com', self.regr))
def test_answer_challenge(self):
self.response.links['up'] = {'url': self.challr.authzr_uri}
self.response.json.return_value = self.challr.body.to_json()
chall_response = challenges.DNSResponse()
self.client.answer_challenge(self.challr.body, chall_response)
# TODO: split here and separate test
self.assertRaises(errors.UnexpectedUpdate, self.client.answer_challenge,
self.challr.body.update(uri='foo'), chall_response)
def test_answer_challenge_missing_next(self):
self.assertRaises(errors.ClientError, self.client.answer_challenge,
self.challr.body, challenges.DNSResponse())
def test_retry_after_date(self):
self.response.headers['Retry-After'] = 'Fri, 31 Dec 1999 23:59:59 GMT'
self.assertEqual(
datetime.datetime(1999, 12, 31, 23, 59, 59),
self.client.retry_after(response=self.response, default=10))
@mock.patch('acme.client.datetime')
def test_retry_after_invalid(self, dt_mock):
dt_mock.datetime.now.return_value = datetime.datetime(2015, 3, 27)
dt_mock.timedelta = datetime.timedelta
self.response.headers['Retry-After'] = 'foooo'
self.assertEqual(
datetime.datetime(2015, 3, 27, 0, 0, 10),
self.client.retry_after(response=self.response, default=10))
@mock.patch('acme.client.datetime')
def test_retry_after_seconds(self, dt_mock):
dt_mock.datetime.now.return_value = datetime.datetime(2015, 3, 27)
dt_mock.timedelta = datetime.timedelta
self.response.headers['Retry-After'] = '50'
self.assertEqual(
datetime.datetime(2015, 3, 27, 0, 0, 50),
self.client.retry_after(response=self.response, default=10))
@mock.patch('acme.client.datetime')
def test_retry_after_missing(self, dt_mock):
dt_mock.datetime.now.return_value = datetime.datetime(2015, 3, 27)
dt_mock.timedelta = datetime.timedelta
self.assertEqual(
datetime.datetime(2015, 3, 27, 0, 0, 10),
self.client.retry_after(response=self.response, default=10))
def test_poll(self):
self.response.json.return_value = self.authzr.body.to_json()
self.assertEqual((self.authzr, self.response),
self.client.poll(self.authzr))
# TODO: split here and separate test
self.response.json.return_value = self.authz.update(
identifier=self.identifier.update(value='foo')).to_json()
self.assertRaises(
errors.UnexpectedUpdate, self.client.poll, self.authzr)
def test_request_issuance(self):
self.response.content = messages_test.CERT.as_der()
self.response.headers['Location'] = self.certr.uri
self.response.links['up'] = {'url': self.certr.cert_chain_uri}
self.assertEqual(self.certr, self.client.request_issuance(
messages_test.CSR, (self.authzr,)))
# TODO: check POST args
def test_request_issuance_missing_up(self):
self.response.content = messages_test.CERT.as_der()
self.response.headers['Location'] = self.certr.uri
self.assertEqual(
self.certr.update(cert_chain_uri=None),
self.client.request_issuance(messages_test.CSR, (self.authzr,)))
def test_request_issuance_missing_location(self):
self.assertRaises(
errors.ClientError, self.client.request_issuance,
messages_test.CSR, (self.authzr,))
@mock.patch('acme.client.datetime')
@mock.patch('acme.client.time')
def test_poll_and_request_issuance(self, time_mock, dt_mock):
# clock.dt | pylint: disable=no-member
clock = mock.MagicMock(dt=datetime.datetime(2015, 3, 27))
def sleep(seconds):
"""increment clock"""
clock.dt += datetime.timedelta(seconds=seconds)
time_mock.sleep.side_effect = sleep
def now():
"""return current clock value"""
return clock.dt
dt_mock.datetime.now.side_effect = now
dt_mock.timedelta = datetime.timedelta
def poll(authzr): # pylint: disable=missing-docstring
# record poll start time based on the current clock value
authzr.times.append(clock.dt)
# suppose it takes 2 seconds for server to produce the
# result, increment clock
clock.dt += datetime.timedelta(seconds=2)
if not authzr.retries: # no more retries
done = mock.MagicMock(uri=authzr.uri, times=authzr.times)
done.body.status = messages.STATUS_VALID
return done, []
# response (2nd result tuple element) is reduced to only
# Retry-After header contents represented as integer
# seconds; authzr.retries is a list of Retry-After
# headers, head(retries) is peeled of as a current
# Retry-After header, and tail(retries) is persisted for
# later poll() calls
return (mock.MagicMock(retries=authzr.retries[1:],
uri=authzr.uri + '.', times=authzr.times),
authzr.retries[0])
self.client.poll = mock.MagicMock(side_effect=poll)
mintime = 7
def retry_after(response, default): # pylint: disable=missing-docstring
# check that poll_and_request_issuance correctly passes mintime
self.assertEqual(default, mintime)
return clock.dt + datetime.timedelta(seconds=response)
self.client.retry_after = mock.MagicMock(side_effect=retry_after)
def request_issuance(csr, authzrs): # pylint: disable=missing-docstring
return csr, authzrs
self.client.request_issuance = mock.MagicMock(
side_effect=request_issuance)
csr = mock.MagicMock()
authzrs = (
mock.MagicMock(uri='a', times=[], retries=(8, 20, 30)),
mock.MagicMock(uri='b', times=[], retries=(5,)),
)
cert, updated_authzrs = self.client.poll_and_request_issuance(
csr, authzrs, mintime=mintime)
self.assertTrue(cert[0] is csr)
self.assertTrue(cert[1] is updated_authzrs)
self.assertEqual(updated_authzrs[0].uri, 'a...')
self.assertEqual(updated_authzrs[1].uri, 'b.')
self.assertEqual(updated_authzrs[0].times, [
datetime.datetime(2015, 3, 27),
# a is scheduled for 10, but b is polling [9..11), so it
# will be picked up as soon as b is finished, without
# additional sleeping
datetime.datetime(2015, 3, 27, 0, 0, 11),
datetime.datetime(2015, 3, 27, 0, 0, 33),
datetime.datetime(2015, 3, 27, 0, 1, 5),
])
self.assertEqual(updated_authzrs[1].times, [
datetime.datetime(2015, 3, 27, 0, 0, 2),
datetime.datetime(2015, 3, 27, 0, 0, 9),
])
self.assertEqual(clock.dt, datetime.datetime(2015, 3, 27, 0, 1, 7))
def test_check_cert(self):
self.response.headers['Location'] = self.certr.uri
self.response.content = messages_test.CERT.as_der()
self.assertEqual(self.certr.update(body=messages_test.CERT),
self.client.check_cert(self.certr))
# TODO: split here and separate test
self.response.headers['Location'] = 'foo'
self.assertRaises(
errors.UnexpectedUpdate, self.client.check_cert, self.certr)
def test_check_cert_missing_location(self):
self.response.content = messages_test.CERT.as_der()
self.assertRaises(
errors.ClientError, self.client.check_cert, self.certr)
def test_refresh(self):
self.client.check_cert = mock.MagicMock()
self.assertEqual(
self.client.check_cert(self.certr), self.client.refresh(self.certr))
def test_fetch_chain(self):
# pylint: disable=protected-access
self.net._post = self.post
self.net._get = self.get
self.client._get_cert = mock.MagicMock()
self.client._get_cert.return_value = ("response", "certificate")
self.assertEqual(self.client._get_cert(self.certr.cert_chain_uri)[1],
self.client.fetch_chain(self.certr))
def test_fetch_chain_no_up_link(self):
self.assertTrue(self.client.fetch_chain(self.certr.update(
cert_chain_uri=None)) is None)
def test_revoke(self):
self.client.revoke(self.certr.body)
self.net.post.assert_called_once_with(messages.Revocation.url(
self.client.new_reg_uri), mock.ANY)
def test_revoke_bad_status_raises_error(self):
self.response.status_code = httplib.METHOD_NOT_ALLOWED
self.assertRaises(errors.ClientError, self.client.revoke, self.certr)
class ClientNetworkTest(unittest.TestCase):
"""Tests for acme.client.ClientNetwork."""
def setUp(self):
self.verify_ssl = mock.MagicMock()
self.wrap_in_jws = mock.MagicMock(return_value=mock.sentinel.wrapped)
from acme.client import ClientNetwork
self.net = ClientNetwork(
key=KEY, alg=jose.RS256, verify_ssl=self.verify_ssl)
self.response = mock.MagicMock(ok=True, status_code=httplib.OK)
self.response.headers = {}
self.response.links = {}
def test_init(self):
self.assertTrue(self.net.verify_ssl is self.verify_ssl)
@@ -139,384 +417,127 @@ class ClientTest(unittest.TestCase):
self.response.json.side_effect = ValueError
for response_ct in [self.net.JSON_CONTENT_TYPE, 'foo']:
self.response.headers['Content-Type'] = response_ct
# pylint: disable=protected-access
self.net._check_response(self.response)
# pylint: disable=protected-access,no-value-for-parameter
self.assertEqual(
self.response, self.net._check_response(self.response))
def test_check_response_jobj(self):
self.response.json.return_value = {}
for response_ct in [self.net.JSON_CONTENT_TYPE, 'foo']:
self.response.headers['Content-Type'] = response_ct
# pylint: disable=protected-access,no-value-for-parameter
self.assertEqual(
self.response, self.net._check_response(self.response))
@mock.patch('acme.client.requests')
def test_send_request(self, mock_requests):
mock_requests.request.return_value = self.response
# pylint: disable=protected-access
self.assertEqual(self.response, self.net._send_request(
'HEAD', 'url', 'foo', bar='baz'))
mock_requests.request.assert_called_once_with(
'HEAD', 'url', 'foo', verify=mock.ANY, bar='baz')
@mock.patch('acme.client.requests')
def test_send_request_verify_ssl(self, mock_requests):
# pylint: disable=protected-access
for verify in True, False:
mock_requests.request.reset_mock()
mock_requests.request.return_value = self.response
self.net.verify_ssl = verify
# pylint: disable=protected-access
self.net._check_response(self.response)
self.assertEqual(
self.response, self.net._send_request('GET', 'url'))
mock_requests.request.assert_called_once_with(
'GET', 'url', verify=verify)
@mock.patch('acme.client.requests')
def test_get_requests_error_passthrough(self, requests_mock):
requests_mock.exceptions = requests.exceptions
requests_mock.get.side_effect = requests.exceptions.RequestException
def test_requests_error_passthrough(self, mock_requests):
mock_requests.exceptions = requests.exceptions
mock_requests.request.side_effect = requests.exceptions.RequestException
# pylint: disable=protected-access
self.assertRaises(errors.ClientError, self.net._get, 'uri')
self.assertRaises(requests.exceptions.RequestException,
self.net._send_request, 'GET', 'uri')
class ClientNetworkWithMockedResponseTest(unittest.TestCase):
"""Tests for acme.client.ClientNetwork which mock out response."""
# pylint: disable=too-many-instance-attributes
def setUp(self):
from acme.client import ClientNetwork
self.net = ClientNetwork(key=None, alg=None)
self.response = mock.MagicMock(ok=True, status_code=httplib.OK)
self.response.headers = {}
self.response.links = {}
self.checked_response = mock.MagicMock()
self.obj = mock.MagicMock()
self.wrapped_obj = mock.MagicMock()
self.content_type = mock.sentinel.content_type
self.all_nonces = [jose.b64encode('Nonce'), jose.b64encode('Nonce2')]
self.available_nonces = self.all_nonces[:]
def send_request(*args, **kwargs):
# pylint: disable=unused-argument,missing-docstring
if self.available_nonces:
self.response.headers = {
self.net.REPLAY_NONCE_HEADER: self.available_nonces.pop()}
else:
self.response.headers = {}
return self.response
@mock.patch('acme.client.requests')
def test_get(self, requests_mock):
# pylint: disable=protected-access
self.net._check_response = mock.MagicMock()
self.net._get('uri', content_type='ct')
self.net._check_response.assert_called_once_with(
requests_mock.get('uri'), content_type='ct')
self.net._send_request = self.send_request = mock.MagicMock(
side_effect=send_request)
self.net._check_response = self.check_response
self.net._wrap_in_jws = mock.MagicMock(return_value=self.wrapped_obj)
def _mock_wrap_in_jws(self):
def check_response(self, response, content_type):
# pylint: disable=missing-docstring
self.assertEqual(self.response, response)
self.assertEqual(self.content_type, content_type)
return self.checked_response
def test_head(self):
self.assertEqual(self.response, self.net.head('url', 'foo', bar='baz'))
self.send_request.assert_called_once('HEAD', 'url', 'foo', bar='baz')
def test_get(self):
self.assertEqual(self.checked_response, self.net.get(
'url', content_type=self.content_type, bar='baz'))
self.send_request.assert_called_once_with('GET', 'url', bar='baz')
def test_post(self):
# pylint: disable=protected-access
self.net._wrap_in_jws = self.wrap_in_jws
self.assertEqual(self.checked_response, self.net.post(
'uri', self.obj, content_type=self.content_type))
self.net._wrap_in_jws.assert_called_once_with(
self.obj, self.all_nonces.pop())
@mock.patch('acme.client.requests')
def test_post_requests_error_passthrough(self, requests_mock):
requests_mock.exceptions = requests.exceptions
requests_mock.post.side_effect = requests.exceptions.RequestException
# pylint: disable=protected-access
self._mock_wrap_in_jws()
self.assertRaises(
errors.ClientError, self.net._post, 'uri', mock.sentinel.obj)
assert not self.available_nonces
self.assertRaises(errors.MissingNonce, self.net.post,
'uri', self.obj, content_type=self.content_type)
self.net._wrap_in_jws.assert_called_with(
self.obj, self.all_nonces.pop())
@mock.patch('acme.client.requests')
def test_post(self, requests_mock):
# pylint: disable=protected-access
self.net._check_response = mock.MagicMock()
self._mock_wrap_in_jws()
requests_mock.post().headers = {
self.net.REPLAY_NONCE_HEADER: self.nonce}
self.net._post('uri', mock.sentinel.obj, content_type='ct')
self.net._check_response.assert_called_once_with(
requests_mock.post('uri', mock.sentinel.wrapped), content_type='ct')
def test_post_wrong_initial_nonce(self): # HEAD
self.available_nonces = ['f', jose.b64encode('good')]
self.assertRaises(errors.BadNonce, self.net.post, 'uri',
self.obj, content_type=self.content_type)
@mock.patch('acme.client.requests')
def test_post_replay_nonce_handling(self, requests_mock):
# pylint: disable=protected-access
self.net._check_response = mock.MagicMock()
self._mock_wrap_in_jws()
def test_post_wrong_post_response_nonce(self):
self.available_nonces = [jose.b64encode('good'), 'f']
self.assertRaises(errors.BadNonce, self.net.post, 'uri',
self.obj, content_type=self.content_type)
self.net._nonces.clear()
self.assertRaises(
errors.ClientError, self.net._post, 'uri', mock.sentinel.obj)
nonce2 = jose.b64encode('Nonce2')
requests_mock.head('uri').headers = {
self.net.REPLAY_NONCE_HEADER: nonce2}
requests_mock.post('uri').headers = {
self.net.REPLAY_NONCE_HEADER: self.nonce}
self.net._post('uri', mock.sentinel.obj)
requests_mock.head.assert_called_with('uri')
self.wrap_in_jws.assert_called_once_with(mock.sentinel.obj, nonce2)
self.assertEqual(self.net._nonces, set([self.nonce]))
# wrong nonce
requests_mock.post('uri').headers = {self.net.REPLAY_NONCE_HEADER: 'F'}
self.assertRaises(
errors.ClientError, self.net._post, 'uri', mock.sentinel.obj)
@mock.patch('acme.client.requests')
def test_get_post_verify_ssl(self, requests_mock):
# pylint: disable=protected-access
self._mock_wrap_in_jws()
self.net._check_response = mock.MagicMock()
for verify_ssl in [True, False]:
self.net.verify_ssl = verify_ssl
self.net._get('uri')
self.net._nonces.add('N')
requests_mock.post().headers = {
self.net.REPLAY_NONCE_HEADER: self.nonce}
self.net._post('uri', mock.sentinel.obj)
requests_mock.get.assert_called_once_with('uri', verify=verify_ssl)
requests_mock.post.assert_called_with(
'uri', data=mock.sentinel.wrapped, verify=verify_ssl)
requests_mock.reset_mock()
def test_register(self):
self.response.status_code = httplib.CREATED
self.response.json.return_value = self.regr.body.to_json()
self.response.headers['Location'] = self.regr.uri
self.response.links.update({
'next': {'url': self.regr.new_authzr_uri},
'terms-of-service': {'url': self.regr.terms_of_service},
})
self._mock_post_get()
self.assertEqual(self.regr, self.net.register(self.contact))
# TODO: test POST call arguments
# TODO: split here and separate test
reg_wrong_key = self.regr.body.update(key=KEY2.public())
self.response.json.return_value = reg_wrong_key.to_json()
self.assertRaises(
errors.UnexpectedUpdate, self.net.register, self.contact)
def test_register_missing_next(self):
self.response.status_code = httplib.CREATED
self._mock_post_get()
self.assertRaises(
errors.ClientError, self.net.register, self.regr.body)
def test_update_registration(self):
self.response.headers['Location'] = self.regr.uri
self.response.json.return_value = self.regr.body.to_json()
self._mock_post_get()
self.assertEqual(self.regr, self.net.update_registration(self.regr))
# TODO: split here and separate test
self.response.json.return_value = self.regr.body.update(
contact=()).to_json()
self.assertRaises(
errors.UnexpectedUpdate, self.net.update_registration, self.regr)
def test_agree_to_tos(self):
self.net.update_registration = mock.Mock()
self.net.agree_to_tos(self.regr)
regr = self.net.update_registration.call_args[0][0]
self.assertEqual(self.regr.terms_of_service, regr.body.agreement)
def test_request_challenges(self):
self.response.status_code = httplib.CREATED
self.response.headers['Location'] = self.authzr.uri
self.response.json.return_value = self.authz.to_json()
self.response.links = {
'next': {'url': self.authzr.new_cert_uri},
}
self._mock_post_get()
self.net.request_challenges(self.identifier, self.authzr.uri)
# TODO: test POST call arguments
# TODO: split here and separate test
self.response.json.return_value = self.authz.update(
identifier=self.identifier.update(value='foo')).to_json()
self.assertRaises(errors.UnexpectedUpdate, self.net.request_challenges,
self.identifier, self.authzr.uri)
def test_request_challenges_missing_next(self):
self.response.status_code = httplib.CREATED
self._mock_post_get()
self.assertRaises(
errors.ClientError, self.net.request_challenges,
self.identifier, self.regr)
def test_request_domain_challenges(self):
self.net.request_challenges = mock.MagicMock()
self.assertEqual(
self.net.request_challenges(self.identifier),
self.net.request_domain_challenges('example.com', self.regr))
def test_answer_challenge(self):
self.response.links['up'] = {'url': self.challr.authzr_uri}
self.response.json.return_value = self.challr.body.to_json()
chall_response = challenges.DNSResponse()
self._mock_post_get()
self.net.answer_challenge(self.challr.body, chall_response)
# TODO: split here and separate test
self.assertRaises(errors.UnexpectedUpdate, self.net.answer_challenge,
self.challr.body.update(uri='foo'), chall_response)
def test_answer_challenge_missing_next(self):
self._mock_post_get()
self.assertRaises(errors.ClientError, self.net.answer_challenge,
self.challr.body, challenges.DNSResponse())
def test_retry_after_date(self):
self.response.headers['Retry-After'] = 'Fri, 31 Dec 1999 23:59:59 GMT'
self.assertEqual(
datetime.datetime(1999, 12, 31, 23, 59, 59),
self.net.retry_after(response=self.response, default=10))
@mock.patch('acme.client.datetime')
def test_retry_after_invalid(self, dt_mock):
dt_mock.datetime.now.return_value = datetime.datetime(2015, 3, 27)
dt_mock.timedelta = datetime.timedelta
self.response.headers['Retry-After'] = 'foooo'
self.assertEqual(
datetime.datetime(2015, 3, 27, 0, 0, 10),
self.net.retry_after(response=self.response, default=10))
@mock.patch('acme.client.datetime')
def test_retry_after_seconds(self, dt_mock):
dt_mock.datetime.now.return_value = datetime.datetime(2015, 3, 27)
dt_mock.timedelta = datetime.timedelta
self.response.headers['Retry-After'] = '50'
self.assertEqual(
datetime.datetime(2015, 3, 27, 0, 0, 50),
self.net.retry_after(response=self.response, default=10))
@mock.patch('acme.client.datetime')
def test_retry_after_missing(self, dt_mock):
dt_mock.datetime.now.return_value = datetime.datetime(2015, 3, 27)
dt_mock.timedelta = datetime.timedelta
self.assertEqual(
datetime.datetime(2015, 3, 27, 0, 0, 10),
self.net.retry_after(response=self.response, default=10))
def test_poll(self):
self.response.json.return_value = self.authzr.body.to_json()
self._mock_post_get()
self.assertEqual((self.authzr, self.response),
self.net.poll(self.authzr))
# TODO: split here and separate test
self.response.json.return_value = self.authz.update(
identifier=self.identifier.update(value='foo')).to_json()
self.assertRaises(errors.UnexpectedUpdate, self.net.poll, self.authzr)
def test_request_issuance(self):
self.response.content = messages_test.CERT.as_der()
self.response.headers['Location'] = self.certr.uri
self.response.links['up'] = {'url': self.certr.cert_chain_uri}
self._mock_post_get()
self.assertEqual(self.certr, self.net.request_issuance(
messages_test.CSR, (self.authzr,)))
# TODO: check POST args
def test_request_issuance_missing_up(self):
self.response.content = messages_test.CERT.as_der()
self.response.headers['Location'] = self.certr.uri
self._mock_post_get()
self.assertEqual(
self.certr.update(cert_chain_uri=None),
self.net.request_issuance(messages_test.CSR, (self.authzr,)))
def test_request_issuance_missing_location(self):
self._mock_post_get()
self.assertRaises(
errors.ClientError, self.net.request_issuance,
messages_test.CSR, (self.authzr,))
@mock.patch('acme.client.datetime')
@mock.patch('acme.client.time')
def test_poll_and_request_issuance(self, time_mock, dt_mock):
# clock.dt | pylint: disable=no-member
clock = mock.MagicMock(dt=datetime.datetime(2015, 3, 27))
def sleep(seconds):
"""increment clock"""
clock.dt += datetime.timedelta(seconds=seconds)
time_mock.sleep.side_effect = sleep
def now():
"""return current clock value"""
return clock.dt
dt_mock.datetime.now.side_effect = now
dt_mock.timedelta = datetime.timedelta
def poll(authzr): # pylint: disable=missing-docstring
# record poll start time based on the current clock value
authzr.times.append(clock.dt)
# suppose it takes 2 seconds for server to produce the
# result, increment clock
clock.dt += datetime.timedelta(seconds=2)
if not authzr.retries: # no more retries
done = mock.MagicMock(uri=authzr.uri, times=authzr.times)
done.body.status = messages.STATUS_VALID
return done, []
# response (2nd result tuple element) is reduced to only
# Retry-After header contents represented as integer
# seconds; authzr.retries is a list of Retry-After
# headers, head(retries) is peeled of as a current
# Retry-After header, and tail(retries) is persisted for
# later poll() calls
return (mock.MagicMock(retries=authzr.retries[1:],
uri=authzr.uri + '.', times=authzr.times),
authzr.retries[0])
self.net.poll = mock.MagicMock(side_effect=poll)
mintime = 7
def retry_after(response, default): # pylint: disable=missing-docstring
# check that poll_and_request_issuance correctly passes mintime
self.assertEqual(default, mintime)
return clock.dt + datetime.timedelta(seconds=response)
self.net.retry_after = mock.MagicMock(side_effect=retry_after)
def request_issuance(csr, authzrs): # pylint: disable=missing-docstring
return csr, authzrs
self.net.request_issuance = mock.MagicMock(side_effect=request_issuance)
csr = mock.MagicMock()
authzrs = (
mock.MagicMock(uri='a', times=[], retries=(8, 20, 30)),
mock.MagicMock(uri='b', times=[], retries=(5,)),
)
cert, updated_authzrs = self.net.poll_and_request_issuance(
csr, authzrs, mintime=mintime)
self.assertTrue(cert[0] is csr)
self.assertTrue(cert[1] is updated_authzrs)
self.assertEqual(updated_authzrs[0].uri, 'a...')
self.assertEqual(updated_authzrs[1].uri, 'b.')
self.assertEqual(updated_authzrs[0].times, [
datetime.datetime(2015, 3, 27),
# a is scheduled for 10, but b is polling [9..11), so it
# will be picked up as soon as b is finished, without
# additional sleeping
datetime.datetime(2015, 3, 27, 0, 0, 11),
datetime.datetime(2015, 3, 27, 0, 0, 33),
datetime.datetime(2015, 3, 27, 0, 1, 5),
])
self.assertEqual(updated_authzrs[1].times, [
datetime.datetime(2015, 3, 27, 0, 0, 2),
datetime.datetime(2015, 3, 27, 0, 0, 9),
])
self.assertEqual(clock.dt, datetime.datetime(2015, 3, 27, 0, 1, 7))
def test_check_cert(self):
self.response.headers['Location'] = self.certr.uri
self.response.content = messages_test.CERT.as_der()
self._mock_post_get()
self.assertEqual(self.certr.update(body=messages_test.CERT),
self.net.check_cert(self.certr))
# TODO: split here and separate test
self.response.headers['Location'] = 'foo'
self.assertRaises(
errors.UnexpectedUpdate, self.net.check_cert, self.certr)
def test_check_cert_missing_location(self):
self.response.content = messages_test.CERT.as_der()
self._mock_post_get()
self.assertRaises(errors.ClientError, self.net.check_cert, self.certr)
def test_refresh(self):
self.net.check_cert = mock.MagicMock()
self.assertEqual(
self.net.check_cert(self.certr), self.net.refresh(self.certr))
def test_fetch_chain(self):
# pylint: disable=protected-access
self.net._get_cert = mock.MagicMock()
self.net._get_cert.return_value = ("response", "certificate")
self.assertEqual(self.net._get_cert(self.certr.cert_chain_uri)[1],
self.net.fetch_chain(self.certr))
def test_fetch_chain_no_up_link(self):
self.assertTrue(self.net.fetch_chain(self.certr.update(
cert_chain_uri=None)) is None)
def test_revoke(self):
self._mock_post_get()
self.net.revoke(self.certr.body)
self.post.assert_called_once_with(messages.Revocation.url(
self.net.new_reg_uri), mock.ANY)
def test_revoke_bad_status_raises_error(self):
self.response.status_code = httplib.METHOD_NOT_ALLOWED
self._mock_post_get()
self.assertRaises(errors.ClientError, self.net.revoke, self.certr)
def test_head_get_post_error_passthrough(self):
self.send_request.side_effect = requests.exceptions.RequestException
for method in self.net.head, self.net.get:
self.assertRaises(
requests.exceptions.RequestException, method, 'GET', 'uri')
self.assertRaises(requests.exceptions.RequestException,
self.net.post, 'uri', obj=self.obj)
if __name__ == '__main__':

View File

@@ -5,11 +5,49 @@ from acme.jose import errors as jose_errors
class Error(Exception):
"""Generic ACME error."""
class SchemaValidationError(jose_errors.DeserializationError):
"""JSON schema ACME object validation error."""
class ClientError(Error):
"""Network error."""
class UnexpectedUpdate(ClientError):
"""Unexpected update."""
"""Unexpected update error."""
class NonceError(ClientError):
"""Server response nonce error."""
class BadNonce(NonceError):
"""Bad nonce error."""
def __init__(self, nonce, error, *args, **kwargs):
super(BadNonce, self).__init__(*args, **kwargs)
self.nonce = nonce
self.error = error
def __str__(self):
return 'Invalid nonce ({0!r}): {1}'.format(self.nonce, self.error)
class MissingNonce(NonceError):
"""Missing nonce error.
According to the specification an "ACME server MUST include an
Replay-Nonce header field in each successful response to a POST it
provides to a client (...)".
:ivar requests.Response response: HTTP Response
"""
def __init__(self, response, *args, **kwargs):
super(MissingNonce, self).__init__(*args, **kwargs)
self.response = response
def __str__(self):
return ('Server {0} response did not include a replay '
'nonce, headers: {1}'.format(
self.response.request.method, self.response.headers))

33
acme/errors_test.py Normal file
View File

@@ -0,0 +1,33 @@
"""Tests for acme.errors."""
import unittest
import mock
class BadNonceTest(unittest.TestCase):
"""Tests for acme.errors.BadNonce."""
def setUp(self):
from acme.errors import BadNonce
self.error = BadNonce(nonce="xxx", error="error")
def test_str(self):
self.assertEqual("Invalid nonce ('xxx'): error", str(self.error))
class MissingNonceTest(unittest.TestCase):
"""Tests for acme.errors.MissingNonce."""
def setUp(self):
from acme.errors import MissingNonce
self.response = mock.MagicMock(headers={})
self.response.request.method = 'FOO'
self.error = MissingNonce(self.response)
def test_str(self):
self.assertTrue("FOO" in str(self.error))
self.assertTrue("{}" in str(self.error))
if __name__ == "__main__":
unittest.main() # pragma: no cover