diff --git a/acme/acme/crypto_util.py b/acme/acme/crypto_util.py index 030946f82..32533630b 100644 --- a/acme/acme/crypto_util.py +++ b/acme/acme/crypto_util.py @@ -26,47 +26,95 @@ logger = logging.getLogger(__name__) _DEFAULT_DVSNI_SSL_METHOD = OpenSSL.SSL.SSLv23_METHOD -def _serve_sni(certs, sock, reuseaddr=True, method=_DEFAULT_DVSNI_SSL_METHOD, - accept=None): - """Start SNI-enabled server, that drops connection after handshake. +class SSLSocket(object): # pylint: disable=too-few-public-methods + """SSL wrapper for sockets.""" - :param certs: Mapping from SNI name to ``(key, cert)`` `tuple`. - :param sock: Already bound socket. - :param bool reuseaddr: Should `socket.SO_REUSEADDR` be set? - :param method: See `OpenSSL.SSL.Context` for allowed values. - :param accept: Callable that doesn't take any arguments and - returns ``True`` if more connections should be served. + def __init__(self, sock, certs, method=_DEFAULT_DVSNI_SSL_METHOD): + self.sock = sock + self.certs = certs + self.method = method - """ - def _pick_certificate(connection): + def __getattr__(self, name): + return getattr(self.sock, name) + + def _pick_certificate_cb(self, connection): + """SNI certificate callback. + + This method will set a new OpenSSL context object for this + connection when an incoming connection provides an SNI name + (in order to serve the appropriate certificate, if any). + + :param connection: The TLS connection object on which the SNI + extension was received. + :type connection: :class:`OpenSSL.Connection` + + """ + server_name = connection.get_servername() try: - key, cert = certs[connection.get_servername()] + key, cert = self.certs[server_name] except KeyError: + logger.debug("Server name (%s) not recognized, dropping SSL", + server_name) return - new_context = OpenSSL.SSL.Context(method) + new_context = OpenSSL.SSL.Context(self.method) new_context.use_privatekey(key) new_context.use_certificate(cert) connection.set_context(new_context) - if reuseaddr: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.listen(1) # TODO: add func arg? + class FakeConnection(object): + """Fake OpenSSL.SSL.Connection.""" - while accept is None or accept(): - server, addr = sock.accept() - logger.debug('Received connection from %s', addr) + # pylint: disable=missing-docstring - with contextlib.closing(server): - context = OpenSSL.SSL.Context(method) - context.set_tlsext_servername_callback(_pick_certificate) + def __init__(self, connection): + self._wrapped = connection + self._makefile_refs = 0 - server_ssl = OpenSSL.SSL.Connection(context, server) - server_ssl.set_accept_state() - try: - server_ssl.do_handshake() - server_ssl.shutdown() - except OpenSSL.SSL.Error as error: - raise errors.Error(error) + def __getattr__(self, name): + return getattr(self._wrapped, name) + + def shutdown(self, *unused_args): + # OpenSSL.SSL.Connection.shutdown doesn't accept any args + return self._wrapped.shutdown() + + # stuff below ripped off from + # https://hg.python.org/cpython/file/2.7/Lib/ssl.py + # XXX: this uses Python's internal API + + def makefile(self, mode='r', bufsize=-1): + self._makefile_refs += 1 + # SocketServer.StreamRequesthandler.finish will try to + # close the wfile/rfile. close=True causes curl: (56) + # GnuTLS recv error (-110): The TLS connection was + # non-properly terminated. + # TODO: doesn't work in Python3 + # pylint: disable=protected-access + return socket._fileobject(self._wrapped, mode, bufsize, close=False) + + def close(self): + if self._makefile_refs < 1: + self._wrapped.close() + else: + self._makefile_refs -= 1 + + def accept(self): # pylint: disable=missing-docstring + sock, addr = self.sock.accept() + + context = OpenSSL.SSL.Context(self.method) + context.set_tlsext_servername_callback(self._pick_certificate_cb) + + ssl_sock = self.FakeConnection(OpenSSL.SSL.Connection(context, sock)) + ssl_sock.set_accept_state() + + logger.debug("Performing handshake with %s", addr) + try: + ssl_sock.do_handshake() + except OpenSSL.SSL.Error as error: + # _pick_certificate_cb might have returned without + # creating SSL context (wrong server name) + raise socket.error(error) + + return ssl_sock, addr def probe_sni(name, host, port=443, timeout=300, diff --git a/acme/acme/crypto_util_test.py b/acme/acme/crypto_util_test.py index 64c7cb552..bfd16388c 100644 --- a/acme/acme/crypto_util_test.py +++ b/acme/acme/crypto_util_test.py @@ -4,45 +4,43 @@ import threading import time import unittest -import mock -import OpenSSL +from six.moves import socketserver # pylint: disable=import-error from acme import errors from acme import jose from acme import test_util -class ServeProbeSNITest(unittest.TestCase): - """Tests for acme.crypto_util._serve_sni/probe_sni.""" +class SSLSocketAndProbeSNITest(unittest.TestCase): + """Tests for acme.crypto_util.SSLSocket/probe_sni.""" def setUp(self): self.cert = test_util.load_cert('cert.pem') - key = OpenSSL.crypto.load_privatekey( - OpenSSL.crypto.FILETYPE_PEM, - test_util.load_vector('rsa512_key.pem')) + key = test_util.load_pyopenssl_private_key('rsa512_key.pem') # pylint: disable=protected-access certs = {b'foo': (key, self.cert._wrapped)} - sock = socket.socket() - sock.bind(('', 0)) # pick random port - self.port = sock.getsockname()[1] + from acme.crypto_util import SSLSocket - self.server = threading.Thread(target=self._run_server, args=(certs, sock)) - self.server.start() + class _TestServer(socketserver.TCPServer): + + # pylint: disable=too-few-public-methods + # six.moves.* | pylint: disable=attribute-defined-outside-init,no-init + + def server_bind(self): # pylint: disable=missing-docstring + self.socket = SSLSocket(socket.socket(), certs=certs) + socketserver.TCPServer.server_bind(self) + + self.server = _TestServer(('', 0), socketserver.BaseRequestHandler) + self.port = self.server.socket.getsockname()[1] + self.server_thread = threading.Thread( + # pylint: disable=no-member + target=self.server.handle_request) + self.server_thread.start() time.sleep(1) # TODO: avoid race conditions in other way - @classmethod - def _run_server(cls, certs, sock): - from acme.crypto_util import _serve_sni - # TODO: improve testing of server errors and their conditions - try: - return _serve_sni( - certs, sock, accept=mock.Mock(side_effect=[True, False])) - except errors.Error: - pass - def tearDown(self): - self.server.join() + self.server_thread.join() def _probe(self, name): from acme.crypto_util import probe_sni