1
0
mirror of https://github.com/certbot/certbot.git synced 2026-01-21 19:01:07 +03:00

acme: _serve_sni -> SSLSocket

This commit is contained in:
Jakub Warmuz
2015-09-26 14:55:27 +00:00
parent c74bc409d8
commit d73b600eeb
2 changed files with 98 additions and 52 deletions

View File

@@ -26,47 +26,95 @@ logger = logging.getLogger(__name__)
_DEFAULT_DVSNI_SSL_METHOD = OpenSSL.SSL.SSLv23_METHOD
def _serve_sni(certs, sock, reuseaddr=True, method=_DEFAULT_DVSNI_SSL_METHOD,
accept=None):
"""Start SNI-enabled server, that drops connection after handshake.
class SSLSocket(object): # pylint: disable=too-few-public-methods
"""SSL wrapper for sockets."""
:param certs: Mapping from SNI name to ``(key, cert)`` `tuple`.
:param sock: Already bound socket.
:param bool reuseaddr: Should `socket.SO_REUSEADDR` be set?
:param method: See `OpenSSL.SSL.Context` for allowed values.
:param accept: Callable that doesn't take any arguments and
returns ``True`` if more connections should be served.
def __init__(self, sock, certs, method=_DEFAULT_DVSNI_SSL_METHOD):
self.sock = sock
self.certs = certs
self.method = method
"""
def _pick_certificate(connection):
def __getattr__(self, name):
return getattr(self.sock, name)
def _pick_certificate_cb(self, connection):
"""SNI certificate callback.
This method will set a new OpenSSL context object for this
connection when an incoming connection provides an SNI name
(in order to serve the appropriate certificate, if any).
:param connection: The TLS connection object on which the SNI
extension was received.
:type connection: :class:`OpenSSL.Connection`
"""
server_name = connection.get_servername()
try:
key, cert = certs[connection.get_servername()]
key, cert = self.certs[server_name]
except KeyError:
logger.debug("Server name (%s) not recognized, dropping SSL",
server_name)
return
new_context = OpenSSL.SSL.Context(method)
new_context = OpenSSL.SSL.Context(self.method)
new_context.use_privatekey(key)
new_context.use_certificate(cert)
connection.set_context(new_context)
if reuseaddr:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.listen(1) # TODO: add func arg?
class FakeConnection(object):
"""Fake OpenSSL.SSL.Connection."""
while accept is None or accept():
server, addr = sock.accept()
logger.debug('Received connection from %s', addr)
# pylint: disable=missing-docstring
with contextlib.closing(server):
context = OpenSSL.SSL.Context(method)
context.set_tlsext_servername_callback(_pick_certificate)
def __init__(self, connection):
self._wrapped = connection
self._makefile_refs = 0
server_ssl = OpenSSL.SSL.Connection(context, server)
server_ssl.set_accept_state()
try:
server_ssl.do_handshake()
server_ssl.shutdown()
except OpenSSL.SSL.Error as error:
raise errors.Error(error)
def __getattr__(self, name):
return getattr(self._wrapped, name)
def shutdown(self, *unused_args):
# OpenSSL.SSL.Connection.shutdown doesn't accept any args
return self._wrapped.shutdown()
# stuff below ripped off from
# https://hg.python.org/cpython/file/2.7/Lib/ssl.py
# XXX: this uses Python's internal API
def makefile(self, mode='r', bufsize=-1):
self._makefile_refs += 1
# SocketServer.StreamRequesthandler.finish will try to
# close the wfile/rfile. close=True causes curl: (56)
# GnuTLS recv error (-110): The TLS connection was
# non-properly terminated.
# TODO: doesn't work in Python3
# pylint: disable=protected-access
return socket._fileobject(self._wrapped, mode, bufsize, close=False)
def close(self):
if self._makefile_refs < 1:
self._wrapped.close()
else:
self._makefile_refs -= 1
def accept(self): # pylint: disable=missing-docstring
sock, addr = self.sock.accept()
context = OpenSSL.SSL.Context(self.method)
context.set_tlsext_servername_callback(self._pick_certificate_cb)
ssl_sock = self.FakeConnection(OpenSSL.SSL.Connection(context, sock))
ssl_sock.set_accept_state()
logger.debug("Performing handshake with %s", addr)
try:
ssl_sock.do_handshake()
except OpenSSL.SSL.Error as error:
# _pick_certificate_cb might have returned without
# creating SSL context (wrong server name)
raise socket.error(error)
return ssl_sock, addr
def probe_sni(name, host, port=443, timeout=300,

View File

@@ -4,45 +4,43 @@ import threading
import time
import unittest
import mock
import OpenSSL
from six.moves import socketserver # pylint: disable=import-error
from acme import errors
from acme import jose
from acme import test_util
class ServeProbeSNITest(unittest.TestCase):
"""Tests for acme.crypto_util._serve_sni/probe_sni."""
class SSLSocketAndProbeSNITest(unittest.TestCase):
"""Tests for acme.crypto_util.SSLSocket/probe_sni."""
def setUp(self):
self.cert = test_util.load_cert('cert.pem')
key = OpenSSL.crypto.load_privatekey(
OpenSSL.crypto.FILETYPE_PEM,
test_util.load_vector('rsa512_key.pem'))
key = test_util.load_pyopenssl_private_key('rsa512_key.pem')
# pylint: disable=protected-access
certs = {b'foo': (key, self.cert._wrapped)}
sock = socket.socket()
sock.bind(('', 0)) # pick random port
self.port = sock.getsockname()[1]
from acme.crypto_util import SSLSocket
self.server = threading.Thread(target=self._run_server, args=(certs, sock))
self.server.start()
class _TestServer(socketserver.TCPServer):
# pylint: disable=too-few-public-methods
# six.moves.* | pylint: disable=attribute-defined-outside-init,no-init
def server_bind(self): # pylint: disable=missing-docstring
self.socket = SSLSocket(socket.socket(), certs=certs)
socketserver.TCPServer.server_bind(self)
self.server = _TestServer(('', 0), socketserver.BaseRequestHandler)
self.port = self.server.socket.getsockname()[1]
self.server_thread = threading.Thread(
# pylint: disable=no-member
target=self.server.handle_request)
self.server_thread.start()
time.sleep(1) # TODO: avoid race conditions in other way
@classmethod
def _run_server(cls, certs, sock):
from acme.crypto_util import _serve_sni
# TODO: improve testing of server errors and their conditions
try:
return _serve_sni(
certs, sock, accept=mock.Mock(side_effect=[True, False]))
except errors.Error:
pass
def tearDown(self):
self.server.join()
self.server_thread.join()
def _probe(self, name):
from acme.crypto_util import probe_sni