mirror of
https://github.com/esp8266/Arduino.git
synced 2025-07-18 23:03:34 +03:00
Fix mem leak in SSL server, allow for concurrent client and server connections w/o interference (#4305)
* Fix leak on multiple SSL server connections Fixes #4302 The refcnt setup for the WiFiClientSecure's SSLContext and ClientContext had issues in certain conditions, causing a massive memory leak on each SSL server connection. Depending on the state of the machine, after two or three connections it would OOM and crash. This patch replaces most of the refcnt operations with C++11 shared_ptr operations, cleaning up the code substantially and removing the leakage. Also fixes a race condition where ClientContext was free'd before the SSLContext was stopped/shutdown. When the SSLContext tried to do ssl_free, axtls would attempt to send out the real SSL disconnect bits over the wire, however by this time the ClientContext is invalid and it would fault. * Separate client and server SSL_CTX, support both Refactor to use a separate client SSL_CTX and server SSL_CTX. This allows for separate certificates to be installed on each, and means that you can now have both a *single* client and a *single* server running in parallel at the same time, as they'll have separate memory areas. Tested using mqtt_esp8266 SSL client with a client certificate and a WebServerSecure with its own custom certificate and key in parallel. * Add brackets around a couple if-else clauses
This commit is contained in:
committed by
Develo
parent
cda72a07e0
commit
bf5a0f24dc
@ -74,37 +74,47 @@ typedef std::list<BufferItem> BufferList;
|
||||
class SSLContext
|
||||
{
|
||||
public:
|
||||
SSLContext()
|
||||
SSLContext(bool isServer = false)
|
||||
{
|
||||
if (_ssl_ctx_refcnt == 0) {
|
||||
_ssl_ctx = ssl_ctx_new(SSL_SERVER_VERIFY_LATER | SSL_DEBUG_OPTS | SSL_CONNECT_IN_PARTS | SSL_READ_BLOCKING | SSL_NO_DEFAULT_KEY, 0);
|
||||
_isServer = isServer;
|
||||
if (!_isServer) {
|
||||
if (_ssl_client_ctx_refcnt == 0) {
|
||||
_ssl_client_ctx = ssl_ctx_new(SSL_SERVER_VERIFY_LATER | SSL_DEBUG_OPTS | SSL_CONNECT_IN_PARTS | SSL_READ_BLOCKING | SSL_NO_DEFAULT_KEY, 0);
|
||||
}
|
||||
++_ssl_client_ctx_refcnt;
|
||||
} else {
|
||||
if (_ssl_svr_ctx_refcnt == 0) {
|
||||
_ssl_svr_ctx = ssl_ctx_new(SSL_SERVER_VERIFY_LATER | SSL_DEBUG_OPTS | SSL_CONNECT_IN_PARTS | SSL_READ_BLOCKING | SSL_NO_DEFAULT_KEY, 0);
|
||||
}
|
||||
++_ssl_svr_ctx_refcnt;
|
||||
}
|
||||
++_ssl_ctx_refcnt;
|
||||
}
|
||||
|
||||
~SSLContext()
|
||||
{
|
||||
if (_ssl) {
|
||||
ssl_free(_ssl);
|
||||
_ssl = nullptr;
|
||||
if (io_ctx) {
|
||||
io_ctx->unref();
|
||||
io_ctx = nullptr;
|
||||
}
|
||||
|
||||
--_ssl_ctx_refcnt;
|
||||
if (_ssl_ctx_refcnt == 0) {
|
||||
ssl_ctx_free(_ssl_ctx);
|
||||
_ssl = nullptr;
|
||||
if (!_isServer) {
|
||||
--_ssl_client_ctx_refcnt;
|
||||
if (_ssl_client_ctx_refcnt == 0) {
|
||||
ssl_ctx_free(_ssl_client_ctx);
|
||||
_ssl_client_ctx = nullptr;
|
||||
}
|
||||
} else {
|
||||
--_ssl_svr_ctx_refcnt;
|
||||
if (_ssl_svr_ctx_refcnt == 0) {
|
||||
ssl_ctx_free(_ssl_svr_ctx);
|
||||
_ssl_svr_ctx = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ref()
|
||||
static void _delete_shared_SSL(SSL *_to_del)
|
||||
{
|
||||
++_refcnt;
|
||||
}
|
||||
|
||||
void unref()
|
||||
{
|
||||
if (--_refcnt == 0) {
|
||||
delete this;
|
||||
}
|
||||
ssl_free(_to_del);
|
||||
}
|
||||
|
||||
void connect(ClientContext* ctx, const char* hostName, uint32_t timeout_ms)
|
||||
@ -116,17 +126,23 @@ public:
|
||||
ssl_free will want to send a close notify alert, but the old TCP connection
|
||||
is already gone at this point, so reset io_ctx. */
|
||||
io_ctx = nullptr;
|
||||
ssl_free(_ssl);
|
||||
_ssl = nullptr;
|
||||
_available = 0;
|
||||
_read_ptr = nullptr;
|
||||
}
|
||||
io_ctx = ctx;
|
||||
_ssl = ssl_client_new(_ssl_ctx, reinterpret_cast<int>(this), nullptr, 0, ext);
|
||||
ctx->ref();
|
||||
|
||||
// Wrap the new SSL with a smart pointer, custom deleter to call ssl_free
|
||||
SSL *_new_ssl = ssl_client_new(_ssl_client_ctx, reinterpret_cast<int>(this), nullptr, 0, ext);
|
||||
std::shared_ptr<SSL> _new_ssl_shared(_new_ssl, _delete_shared_SSL);
|
||||
_ssl = _new_ssl_shared;
|
||||
|
||||
uint32_t t = millis();
|
||||
|
||||
while (millis() - t < timeout_ms && ssl_handshake_status(_ssl) != SSL_OK) {
|
||||
while (millis() - t < timeout_ms && ssl_handshake_status(_ssl.get()) != SSL_OK) {
|
||||
uint8_t* data;
|
||||
int rc = ssl_read(_ssl, &data);
|
||||
int rc = ssl_read(_ssl.get(), &data);
|
||||
if (rc < SSL_OK) {
|
||||
ssl_display_error(rc);
|
||||
break;
|
||||
@ -134,18 +150,23 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
void connectServer(ClientContext *ctx) {
|
||||
void connectServer(ClientContext *ctx, uint32_t timeout_ms)
|
||||
{
|
||||
io_ctx = ctx;
|
||||
_ssl = ssl_server_new(_ssl_ctx, reinterpret_cast<int>(this));
|
||||
_isServer = true;
|
||||
ctx->ref();
|
||||
|
||||
// Wrap the new SSL with a smart pointer, custom deleter to call ssl_free
|
||||
SSL *_new_ssl = ssl_server_new(_ssl_svr_ctx, reinterpret_cast<int>(this));
|
||||
std::shared_ptr<SSL> _new_ssl_shared(_new_ssl, _delete_shared_SSL);
|
||||
_ssl = _new_ssl_shared;
|
||||
|
||||
uint32_t timeout_ms = 5000;
|
||||
uint32_t t = millis();
|
||||
|
||||
while (millis() - t < timeout_ms && ssl_handshake_status(_ssl) != SSL_OK) {
|
||||
while (millis() - t < timeout_ms && ssl_handshake_status(_ssl.get()) != SSL_OK) {
|
||||
uint8_t* data;
|
||||
int rc = ssl_read(_ssl, &data);
|
||||
int rc = ssl_read(_ssl.get(), &data);
|
||||
if (rc < SSL_OK) {
|
||||
ssl_display_error(rc);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -153,13 +174,19 @@ public:
|
||||
|
||||
void stop()
|
||||
{
|
||||
if (io_ctx) {
|
||||
io_ctx->unref();
|
||||
}
|
||||
io_ctx = nullptr;
|
||||
}
|
||||
|
||||
bool connected()
|
||||
{
|
||||
if (_isServer) return _ssl != nullptr;
|
||||
else return _ssl != nullptr && ssl_handshake_status(_ssl) == SSL_OK;
|
||||
if (_isServer) {
|
||||
return _ssl != nullptr;
|
||||
} else {
|
||||
return _ssl != nullptr && ssl_handshake_status(_ssl.get()) == SSL_OK;
|
||||
}
|
||||
}
|
||||
|
||||
int read(uint8_t* dst, size_t size)
|
||||
@ -289,10 +316,9 @@ public:
|
||||
return loadObject(type, buf.get(), size);
|
||||
}
|
||||
|
||||
|
||||
bool loadObject(int type, const uint8_t* data, size_t size)
|
||||
{
|
||||
int rc = ssl_obj_memory_load(_ssl_ctx, type, data, static_cast<int>(size), nullptr);
|
||||
int rc = ssl_obj_memory_load(_isServer?_ssl_svr_ctx:_ssl_client_ctx, type, data, static_cast<int>(size), nullptr);
|
||||
if (rc != SSL_OK) {
|
||||
DEBUGV("loadObject: ssl_obj_memory_load returned %d\n", rc);
|
||||
return false;
|
||||
@ -302,7 +328,7 @@ public:
|
||||
|
||||
bool verifyCert()
|
||||
{
|
||||
int rc = ssl_verify_cert(_ssl);
|
||||
int rc = ssl_verify_cert(_ssl.get());
|
||||
if (_allowSelfSignedCerts && rc == SSL_X509_ERROR(X509_VFY_ERROR_SELF_SIGNED)) {
|
||||
DEBUGV("Allowing self-signed certificate\n");
|
||||
return true;
|
||||
@ -321,12 +347,16 @@ public:
|
||||
|
||||
operator SSL*()
|
||||
{
|
||||
return _ssl;
|
||||
return _ssl.get();
|
||||
}
|
||||
|
||||
static ClientContext* getIOContext(int fd)
|
||||
{
|
||||
return reinterpret_cast<SSLContext*>(fd)->io_ctx;
|
||||
if (fd) {
|
||||
SSLContext *thisSSL = reinterpret_cast<SSLContext*>(fd);
|
||||
return thisSSL->io_ctx;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
protected:
|
||||
@ -339,10 +369,9 @@ protected:
|
||||
optimistic_yield(100);
|
||||
|
||||
uint8_t* data;
|
||||
int rc = ssl_read(_ssl, &data);
|
||||
int rc = ssl_read(_ssl.get(), &data);
|
||||
if (rc <= 0) {
|
||||
if (rc < SSL_OK && rc != SSL_CLOSE_NOTIFY && rc != SSL_ERROR_CONN_LOST) {
|
||||
ssl_free(_ssl);
|
||||
_ssl = nullptr;
|
||||
}
|
||||
return 0;
|
||||
@ -359,7 +388,7 @@ protected:
|
||||
return 0;
|
||||
}
|
||||
|
||||
int rc = ssl_write(_ssl, src, size);
|
||||
int rc = ssl_write(_ssl.get(), src, size);
|
||||
if (rc >= 0) {
|
||||
return rc;
|
||||
}
|
||||
@ -404,10 +433,11 @@ protected:
|
||||
}
|
||||
|
||||
bool _isServer = false;
|
||||
static SSL_CTX* _ssl_ctx;
|
||||
static int _ssl_ctx_refcnt;
|
||||
SSL* _ssl = nullptr;
|
||||
int _refcnt = 0;
|
||||
static SSL_CTX* _ssl_client_ctx;
|
||||
static int _ssl_client_ctx_refcnt;
|
||||
static SSL_CTX* _ssl_svr_ctx;
|
||||
static int _ssl_svr_ctx_refcnt;
|
||||
std::shared_ptr<SSL> _ssl = nullptr;
|
||||
const uint8_t* _read_ptr = nullptr;
|
||||
size_t _available = 0;
|
||||
BufferList _writeBuffers;
|
||||
@ -415,8 +445,10 @@ protected:
|
||||
ClientContext* io_ctx = nullptr;
|
||||
};
|
||||
|
||||
SSL_CTX* SSLContext::_ssl_ctx = nullptr;
|
||||
int SSLContext::_ssl_ctx_refcnt = 0;
|
||||
SSL_CTX* SSLContext::_ssl_client_ctx = nullptr;
|
||||
int SSLContext::_ssl_client_ctx_refcnt = 0;
|
||||
SSL_CTX* SSLContext::_ssl_svr_ctx = nullptr;
|
||||
int SSLContext::_ssl_svr_ctx_refcnt = 0;
|
||||
|
||||
WiFiClientSecure::WiFiClientSecure()
|
||||
{
|
||||
@ -426,41 +458,25 @@ WiFiClientSecure::WiFiClientSecure()
|
||||
|
||||
WiFiClientSecure::~WiFiClientSecure()
|
||||
{
|
||||
if (_ssl) {
|
||||
_ssl->unref();
|
||||
}
|
||||
}
|
||||
|
||||
WiFiClientSecure::WiFiClientSecure(const WiFiClientSecure& other)
|
||||
: WiFiClient(static_cast<const WiFiClient&>(other))
|
||||
{
|
||||
_ssl = other._ssl;
|
||||
if (_ssl) {
|
||||
_ssl->ref();
|
||||
}
|
||||
}
|
||||
|
||||
WiFiClientSecure& WiFiClientSecure::operator=(const WiFiClientSecure& rhs)
|
||||
{
|
||||
(WiFiClient&) *this = rhs;
|
||||
_ssl = rhs._ssl;
|
||||
if (_ssl) {
|
||||
_ssl->ref();
|
||||
}
|
||||
return *this;
|
||||
_ssl = nullptr;
|
||||
}
|
||||
|
||||
// Only called by the WifiServerSecure, need to get the keys/certs loaded before beginning
|
||||
WiFiClientSecure::WiFiClientSecure(ClientContext* client, bool usePMEM, const uint8_t *rsakey, int rsakeyLen, const uint8_t *cert, int certLen)
|
||||
WiFiClientSecure::WiFiClientSecure(ClientContext* client, bool usePMEM,
|
||||
const uint8_t *rsakey, int rsakeyLen,
|
||||
const uint8_t *cert, int certLen)
|
||||
{
|
||||
_client = client;
|
||||
if (_ssl) {
|
||||
_ssl->unref();
|
||||
_ssl = nullptr;
|
||||
}
|
||||
// TLS handshake may take more than the 5 second default timeout
|
||||
_timeout = 15000;
|
||||
|
||||
_ssl = new SSLContext;
|
||||
_ssl->ref();
|
||||
// We've been given the client context from the available() call
|
||||
_client = client;
|
||||
_client->ref();
|
||||
|
||||
// Make the "_ssl" SSLContext, in the constructor there should be none yet
|
||||
SSLContext *_new_ssl = new SSLContext(true);
|
||||
std::shared_ptr<SSLContext> _new_ssl_shared(_new_ssl);
|
||||
_ssl = _new_ssl_shared;
|
||||
|
||||
if (usePMEM) {
|
||||
if (rsakey && rsakeyLen) {
|
||||
@ -477,8 +493,7 @@ WiFiClientSecure::WiFiClientSecure(ClientContext* client, bool usePMEM, const ui
|
||||
_ssl->loadObject(SSL_OBJ_X509_CERT, cert, certLen);
|
||||
}
|
||||
}
|
||||
_client->ref();
|
||||
_ssl->connectServer(client);
|
||||
_ssl->connectServer(client, _timeout);
|
||||
}
|
||||
|
||||
int WiFiClientSecure::connect(IPAddress ip, uint16_t port)
|
||||
@ -510,14 +525,12 @@ int WiFiClientSecure::connect(const String host, uint16_t port)
|
||||
int WiFiClientSecure::_connectSSL(const char* hostName)
|
||||
{
|
||||
if (!_ssl) {
|
||||
_ssl = new SSLContext;
|
||||
_ssl->ref();
|
||||
_ssl = std::make_shared<SSLContext>();
|
||||
}
|
||||
_ssl->connect(_client, hostName, _timeout);
|
||||
|
||||
auto status = ssl_handshake_status(*_ssl);
|
||||
if (status != SSL_OK) {
|
||||
_ssl->unref();
|
||||
_ssl = nullptr;
|
||||
return 0;
|
||||
}
|
||||
@ -537,7 +550,6 @@ size_t WiFiClientSecure::write(const uint8_t *buf, size_t size)
|
||||
}
|
||||
|
||||
if (rc != SSL_CLOSE_NOTIFY) {
|
||||
_ssl->unref();
|
||||
_ssl = nullptr;
|
||||
}
|
||||
|
||||
@ -640,8 +652,6 @@ void WiFiClientSecure::stop()
|
||||
{
|
||||
if (_ssl) {
|
||||
_ssl->stop();
|
||||
_ssl->unref();
|
||||
_ssl = nullptr;
|
||||
}
|
||||
WiFiClient::stop();
|
||||
}
|
||||
@ -723,9 +733,9 @@ bool WiFiClientSecure::_verifyDN(const char* domain_name)
|
||||
String domain_name_str(domain_name);
|
||||
domain_name_str.toLowerCase();
|
||||
|
||||
const char* san = NULL;
|
||||
const char* san = nullptr;
|
||||
int i = 0;
|
||||
while ((san = ssl_get_cert_subject_alt_dnsname(*_ssl, i)) != NULL) {
|
||||
while ((san = ssl_get_cert_subject_alt_dnsname(*_ssl, i)) != nullptr) {
|
||||
String san_str(san);
|
||||
san_str.toLowerCase();
|
||||
if (matchName(san_str, domain_name_str)) {
|
||||
@ -759,8 +769,7 @@ bool WiFiClientSecure::verifyCertChain(const char* domain_name)
|
||||
void WiFiClientSecure::_initSSLContext()
|
||||
{
|
||||
if (!_ssl) {
|
||||
_ssl = new SSLContext;
|
||||
_ssl->ref();
|
||||
_ssl = std::make_shared<SSLContext>();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -32,8 +32,6 @@ class WiFiClientSecure : public WiFiClient {
|
||||
public:
|
||||
WiFiClientSecure();
|
||||
~WiFiClientSecure() override;
|
||||
WiFiClientSecure(const WiFiClientSecure&);
|
||||
WiFiClientSecure& operator=(const WiFiClientSecure&);
|
||||
|
||||
int connect(IPAddress ip, uint16_t port) override;
|
||||
int connect(const String host, uint16_t port) override;
|
||||
@ -91,7 +89,7 @@ protected:
|
||||
int _connectSSL(const char* hostName);
|
||||
bool _verifyDN(const char* name);
|
||||
|
||||
SSLContext* _ssl = nullptr;
|
||||
std::shared_ptr<SSLContext> _ssl = nullptr;
|
||||
};
|
||||
|
||||
#endif //wificlientsecure_h
|
||||
|
Reference in New Issue
Block a user