1
0
mirror of https://github.com/esp8266/Arduino.git synced 2025-07-18 23:03:34 +03:00

Fix mem leak in SSL server, allow for concurrent client and server connections w/o interference (#4305)

* Fix leak on multiple SSL server connections

Fixes #4302

The refcnt setup for the WiFiClientSecure's SSLContext and ClientContext
had issues in certain conditions, causing a massive memory leak on each
SSL server connection.  Depending on the state of the machine, after two or
three connections it would OOM and crash.

This patch replaces most of the refcnt operations with C++11 shared_ptr
operations, cleaning up the code substantially and removing the leakage.

Also fixes a race condition where ClientContext was free'd before the SSLContext
was stopped/shutdown.  When the SSLContext tried to do ssl_free, axtls would
attempt to send out the real SSL disconnect bits over the wire, however by
this time the ClientContext is invalid and it would fault.

* Separate client and server SSL_CTX, support both

Refactor to use a separate client SSL_CTX and server SSL_CTX.  This
allows for separate certificates to be installed on each, and means
that you can now have both a *single* client and a *single* server
running in parallel at the same time, as they'll have separate memory
areas.

Tested using mqtt_esp8266 SSL client with a client certificate and a
WebServerSecure with its own custom certificate and key in parallel.

* Add brackets around a couple if-else clauses
This commit is contained in:
Earle F. Philhower, III
2018-02-08 10:25:24 -08:00
committed by Develo
parent cda72a07e0
commit bf5a0f24dc
2 changed files with 98 additions and 91 deletions

View File

@ -74,37 +74,47 @@ typedef std::list<BufferItem> BufferList;
class SSLContext
{
public:
SSLContext()
SSLContext(bool isServer = false)
{
if (_ssl_ctx_refcnt == 0) {
_ssl_ctx = ssl_ctx_new(SSL_SERVER_VERIFY_LATER | SSL_DEBUG_OPTS | SSL_CONNECT_IN_PARTS | SSL_READ_BLOCKING | SSL_NO_DEFAULT_KEY, 0);
_isServer = isServer;
if (!_isServer) {
if (_ssl_client_ctx_refcnt == 0) {
_ssl_client_ctx = ssl_ctx_new(SSL_SERVER_VERIFY_LATER | SSL_DEBUG_OPTS | SSL_CONNECT_IN_PARTS | SSL_READ_BLOCKING | SSL_NO_DEFAULT_KEY, 0);
}
++_ssl_client_ctx_refcnt;
} else {
if (_ssl_svr_ctx_refcnt == 0) {
_ssl_svr_ctx = ssl_ctx_new(SSL_SERVER_VERIFY_LATER | SSL_DEBUG_OPTS | SSL_CONNECT_IN_PARTS | SSL_READ_BLOCKING | SSL_NO_DEFAULT_KEY, 0);
}
++_ssl_svr_ctx_refcnt;
}
++_ssl_ctx_refcnt;
}
~SSLContext()
{
if (_ssl) {
ssl_free(_ssl);
_ssl = nullptr;
if (io_ctx) {
io_ctx->unref();
io_ctx = nullptr;
}
--_ssl_ctx_refcnt;
if (_ssl_ctx_refcnt == 0) {
ssl_ctx_free(_ssl_ctx);
_ssl = nullptr;
if (!_isServer) {
--_ssl_client_ctx_refcnt;
if (_ssl_client_ctx_refcnt == 0) {
ssl_ctx_free(_ssl_client_ctx);
_ssl_client_ctx = nullptr;
}
} else {
--_ssl_svr_ctx_refcnt;
if (_ssl_svr_ctx_refcnt == 0) {
ssl_ctx_free(_ssl_svr_ctx);
_ssl_svr_ctx = nullptr;
}
}
}
void ref()
static void _delete_shared_SSL(SSL *_to_del)
{
++_refcnt;
}
void unref()
{
if (--_refcnt == 0) {
delete this;
}
ssl_free(_to_del);
}
void connect(ClientContext* ctx, const char* hostName, uint32_t timeout_ms)
@ -116,17 +126,23 @@ public:
ssl_free will want to send a close notify alert, but the old TCP connection
is already gone at this point, so reset io_ctx. */
io_ctx = nullptr;
ssl_free(_ssl);
_ssl = nullptr;
_available = 0;
_read_ptr = nullptr;
}
io_ctx = ctx;
_ssl = ssl_client_new(_ssl_ctx, reinterpret_cast<int>(this), nullptr, 0, ext);
ctx->ref();
// Wrap the new SSL with a smart pointer, custom deleter to call ssl_free
SSL *_new_ssl = ssl_client_new(_ssl_client_ctx, reinterpret_cast<int>(this), nullptr, 0, ext);
std::shared_ptr<SSL> _new_ssl_shared(_new_ssl, _delete_shared_SSL);
_ssl = _new_ssl_shared;
uint32_t t = millis();
while (millis() - t < timeout_ms && ssl_handshake_status(_ssl) != SSL_OK) {
while (millis() - t < timeout_ms && ssl_handshake_status(_ssl.get()) != SSL_OK) {
uint8_t* data;
int rc = ssl_read(_ssl, &data);
int rc = ssl_read(_ssl.get(), &data);
if (rc < SSL_OK) {
ssl_display_error(rc);
break;
@ -134,18 +150,23 @@ public:
}
}
void connectServer(ClientContext *ctx) {
void connectServer(ClientContext *ctx, uint32_t timeout_ms)
{
io_ctx = ctx;
_ssl = ssl_server_new(_ssl_ctx, reinterpret_cast<int>(this));
_isServer = true;
ctx->ref();
// Wrap the new SSL with a smart pointer, custom deleter to call ssl_free
SSL *_new_ssl = ssl_server_new(_ssl_svr_ctx, reinterpret_cast<int>(this));
std::shared_ptr<SSL> _new_ssl_shared(_new_ssl, _delete_shared_SSL);
_ssl = _new_ssl_shared;
uint32_t timeout_ms = 5000;
uint32_t t = millis();
while (millis() - t < timeout_ms && ssl_handshake_status(_ssl) != SSL_OK) {
while (millis() - t < timeout_ms && ssl_handshake_status(_ssl.get()) != SSL_OK) {
uint8_t* data;
int rc = ssl_read(_ssl, &data);
int rc = ssl_read(_ssl.get(), &data);
if (rc < SSL_OK) {
ssl_display_error(rc);
break;
}
}
@ -153,13 +174,19 @@ public:
void stop()
{
if (io_ctx) {
io_ctx->unref();
}
io_ctx = nullptr;
}
bool connected()
{
if (_isServer) return _ssl != nullptr;
else return _ssl != nullptr && ssl_handshake_status(_ssl) == SSL_OK;
if (_isServer) {
return _ssl != nullptr;
} else {
return _ssl != nullptr && ssl_handshake_status(_ssl.get()) == SSL_OK;
}
}
int read(uint8_t* dst, size_t size)
@ -289,10 +316,9 @@ public:
return loadObject(type, buf.get(), size);
}
bool loadObject(int type, const uint8_t* data, size_t size)
{
int rc = ssl_obj_memory_load(_ssl_ctx, type, data, static_cast<int>(size), nullptr);
int rc = ssl_obj_memory_load(_isServer?_ssl_svr_ctx:_ssl_client_ctx, type, data, static_cast<int>(size), nullptr);
if (rc != SSL_OK) {
DEBUGV("loadObject: ssl_obj_memory_load returned %d\n", rc);
return false;
@ -302,7 +328,7 @@ public:
bool verifyCert()
{
int rc = ssl_verify_cert(_ssl);
int rc = ssl_verify_cert(_ssl.get());
if (_allowSelfSignedCerts && rc == SSL_X509_ERROR(X509_VFY_ERROR_SELF_SIGNED)) {
DEBUGV("Allowing self-signed certificate\n");
return true;
@ -321,12 +347,16 @@ public:
operator SSL*()
{
return _ssl;
return _ssl.get();
}
static ClientContext* getIOContext(int fd)
{
return reinterpret_cast<SSLContext*>(fd)->io_ctx;
if (fd) {
SSLContext *thisSSL = reinterpret_cast<SSLContext*>(fd);
return thisSSL->io_ctx;
}
return nullptr;
}
protected:
@ -339,10 +369,9 @@ protected:
optimistic_yield(100);
uint8_t* data;
int rc = ssl_read(_ssl, &data);
int rc = ssl_read(_ssl.get(), &data);
if (rc <= 0) {
if (rc < SSL_OK && rc != SSL_CLOSE_NOTIFY && rc != SSL_ERROR_CONN_LOST) {
ssl_free(_ssl);
_ssl = nullptr;
}
return 0;
@ -359,7 +388,7 @@ protected:
return 0;
}
int rc = ssl_write(_ssl, src, size);
int rc = ssl_write(_ssl.get(), src, size);
if (rc >= 0) {
return rc;
}
@ -404,10 +433,11 @@ protected:
}
bool _isServer = false;
static SSL_CTX* _ssl_ctx;
static int _ssl_ctx_refcnt;
SSL* _ssl = nullptr;
int _refcnt = 0;
static SSL_CTX* _ssl_client_ctx;
static int _ssl_client_ctx_refcnt;
static SSL_CTX* _ssl_svr_ctx;
static int _ssl_svr_ctx_refcnt;
std::shared_ptr<SSL> _ssl = nullptr;
const uint8_t* _read_ptr = nullptr;
size_t _available = 0;
BufferList _writeBuffers;
@ -415,8 +445,10 @@ protected:
ClientContext* io_ctx = nullptr;
};
SSL_CTX* SSLContext::_ssl_ctx = nullptr;
int SSLContext::_ssl_ctx_refcnt = 0;
SSL_CTX* SSLContext::_ssl_client_ctx = nullptr;
int SSLContext::_ssl_client_ctx_refcnt = 0;
SSL_CTX* SSLContext::_ssl_svr_ctx = nullptr;
int SSLContext::_ssl_svr_ctx_refcnt = 0;
WiFiClientSecure::WiFiClientSecure()
{
@ -426,41 +458,25 @@ WiFiClientSecure::WiFiClientSecure()
WiFiClientSecure::~WiFiClientSecure()
{
if (_ssl) {
_ssl->unref();
}
}
WiFiClientSecure::WiFiClientSecure(const WiFiClientSecure& other)
: WiFiClient(static_cast<const WiFiClient&>(other))
{
_ssl = other._ssl;
if (_ssl) {
_ssl->ref();
}
}
WiFiClientSecure& WiFiClientSecure::operator=(const WiFiClientSecure& rhs)
{
(WiFiClient&) *this = rhs;
_ssl = rhs._ssl;
if (_ssl) {
_ssl->ref();
}
return *this;
_ssl = nullptr;
}
// Only called by the WifiServerSecure, need to get the keys/certs loaded before beginning
WiFiClientSecure::WiFiClientSecure(ClientContext* client, bool usePMEM, const uint8_t *rsakey, int rsakeyLen, const uint8_t *cert, int certLen)
WiFiClientSecure::WiFiClientSecure(ClientContext* client, bool usePMEM,
const uint8_t *rsakey, int rsakeyLen,
const uint8_t *cert, int certLen)
{
_client = client;
if (_ssl) {
_ssl->unref();
_ssl = nullptr;
}
// TLS handshake may take more than the 5 second default timeout
_timeout = 15000;
_ssl = new SSLContext;
_ssl->ref();
// We've been given the client context from the available() call
_client = client;
_client->ref();
// Make the "_ssl" SSLContext, in the constructor there should be none yet
SSLContext *_new_ssl = new SSLContext(true);
std::shared_ptr<SSLContext> _new_ssl_shared(_new_ssl);
_ssl = _new_ssl_shared;
if (usePMEM) {
if (rsakey && rsakeyLen) {
@ -477,8 +493,7 @@ WiFiClientSecure::WiFiClientSecure(ClientContext* client, bool usePMEM, const ui
_ssl->loadObject(SSL_OBJ_X509_CERT, cert, certLen);
}
}
_client->ref();
_ssl->connectServer(client);
_ssl->connectServer(client, _timeout);
}
int WiFiClientSecure::connect(IPAddress ip, uint16_t port)
@ -510,14 +525,12 @@ int WiFiClientSecure::connect(const String host, uint16_t port)
int WiFiClientSecure::_connectSSL(const char* hostName)
{
if (!_ssl) {
_ssl = new SSLContext;
_ssl->ref();
_ssl = std::make_shared<SSLContext>();
}
_ssl->connect(_client, hostName, _timeout);
auto status = ssl_handshake_status(*_ssl);
if (status != SSL_OK) {
_ssl->unref();
_ssl = nullptr;
return 0;
}
@ -537,7 +550,6 @@ size_t WiFiClientSecure::write(const uint8_t *buf, size_t size)
}
if (rc != SSL_CLOSE_NOTIFY) {
_ssl->unref();
_ssl = nullptr;
}
@ -640,8 +652,6 @@ void WiFiClientSecure::stop()
{
if (_ssl) {
_ssl->stop();
_ssl->unref();
_ssl = nullptr;
}
WiFiClient::stop();
}
@ -723,9 +733,9 @@ bool WiFiClientSecure::_verifyDN(const char* domain_name)
String domain_name_str(domain_name);
domain_name_str.toLowerCase();
const char* san = NULL;
const char* san = nullptr;
int i = 0;
while ((san = ssl_get_cert_subject_alt_dnsname(*_ssl, i)) != NULL) {
while ((san = ssl_get_cert_subject_alt_dnsname(*_ssl, i)) != nullptr) {
String san_str(san);
san_str.toLowerCase();
if (matchName(san_str, domain_name_str)) {
@ -759,8 +769,7 @@ bool WiFiClientSecure::verifyCertChain(const char* domain_name)
void WiFiClientSecure::_initSSLContext()
{
if (!_ssl) {
_ssl = new SSLContext;
_ssl->ref();
_ssl = std::make_shared<SSLContext>();
}
}

View File

@ -32,8 +32,6 @@ class WiFiClientSecure : public WiFiClient {
public:
WiFiClientSecure();
~WiFiClientSecure() override;
WiFiClientSecure(const WiFiClientSecure&);
WiFiClientSecure& operator=(const WiFiClientSecure&);
int connect(IPAddress ip, uint16_t port) override;
int connect(const String host, uint16_t port) override;
@ -91,7 +89,7 @@ protected:
int _connectSSL(const char* hostName);
bool _verifyDN(const char* name);
SSLContext* _ssl = nullptr;
std::shared_ptr<SSLContext> _ssl = nullptr;
};
#endif //wificlientsecure_h