1
0
mirror of https://github.com/esp8266/Arduino.git synced 2025-04-21 10:26:06 +03:00

Improve receive handling in TLS support (#43)

This commit is contained in:
Ivan Grokhotkov 2015-09-13 22:49:30 +03:00
parent 2a297ca171
commit 098c71ca02
2 changed files with 75 additions and 13 deletions

View File

@ -29,6 +29,7 @@ extern "C"
} }
#include <errno.h> #include <errno.h>
#include "debug.h" #include "debug.h"
#include "cbuf.h"
#include "ESP8266WiFi.h" #include "ESP8266WiFi.h"
#include "WiFiClientSecure.h" #include "WiFiClientSecure.h"
#include "WiFiClient.h" #include "WiFiClient.h"
@ -41,15 +42,23 @@ extern "C"
#include "include/ClientContext.h" #include "include/ClientContext.h"
#include "c_types.h" #include "c_types.h"
//#define DEBUG_SSL
#ifdef DEBUG_SSL
#define SSL_DEBUG_OPTS SSL_DISPLAY_STATES
#else
#define SSL_DEBUG_OPTS 0
#endif
class SSLContext { class SSLContext {
public: public:
SSLContext() { SSLContext() {
if (_ssl_ctx_refcnt == 0) { if (_ssl_ctx_refcnt == 0) {
_ssl_ctx = ssl_ctx_new(SSL_SERVER_VERIFY_LATER | SSL_DISPLAY_STATES, 0); _ssl_ctx = ssl_ctx_new(SSL_SERVER_VERIFY_LATER | SSL_DEBUG_OPTS, 0);
} }
++_ssl_ctx_refcnt; ++_ssl_ctx_refcnt;
_rxbuf = new cbuf(1536);
} }
~SSLContext() { ~SSLContext() {
@ -62,6 +71,8 @@ public:
if (_ssl_ctx_refcnt == 0) { if (_ssl_ctx_refcnt == 0) {
ssl_ctx_free(_ssl_ctx); ssl_ctx_free(_ssl_ctx);
} }
delete _rxbuf;
} }
void ref() { void ref() {
@ -78,27 +89,71 @@ public:
_ssl = ssl_client_new(_ssl_ctx, reinterpret_cast<int>(ctx), nullptr, 0); _ssl = ssl_client_new(_ssl_ctx, reinterpret_cast<int>(ctx), nullptr, 0);
} }
int read(uint8_t* dst, size_t size) {
if (size > _rxbuf->getSize()) {
_readAll();
}
return _rxbuf->read(reinterpret_cast<char*>(dst), size);
}
int read() {
optimistic_yield(100);
if (!_rxbuf->getSize()) {
_readAll();
}
return _rxbuf->read();
}
int peek() {
if (!_rxbuf->getSize()) {
_readAll();
}
return _rxbuf->peek();
}
int available() {
optimistic_yield(100);
return _rxbuf->getSize();
}
operator SSL*() { operator SSL*() {
return _ssl; return _ssl;
} }
protected: protected:
int _readAll() {
uint8_t* data;
int rc = ssl_read(_ssl, &data);
if (rc <= 0)
return 0;
if (rc > _rxbuf->room()) {
DEBUGV("WiFiClientSecure rx overflow");
rc = _rxbuf->room();
}
int result = 0;
size_t sizeBefore = _rxbuf->getSize();
if (rc)
result = _rxbuf->write(reinterpret_cast<const char*>(data), rc);
DEBUGV("*** rb: %d + %d = %d\r\n", sizeBefore, rc, _rxbuf->getSize());
return result;
}
static SSL_CTX* _ssl_ctx; static SSL_CTX* _ssl_ctx;
static int _ssl_ctx_refcnt; static int _ssl_ctx_refcnt;
SSL* _ssl = nullptr; SSL* _ssl = nullptr;
int _refcnt = 0; int _refcnt = 0;
cbuf* _rxbuf;
}; };
SSL_CTX* SSLContext::_ssl_ctx = nullptr; SSL_CTX* SSLContext::_ssl_ctx = nullptr;
int SSLContext::_ssl_ctx_refcnt = 0; int SSLContext::_ssl_ctx_refcnt = 0;
WiFiClientSecure::WiFiClientSecure() WiFiClientSecure::WiFiClientSecure() {
{
} }
WiFiClientSecure::~WiFiClientSecure() WiFiClientSecure::~WiFiClientSecure() {
{
if (_ssl) { if (_ssl) {
_ssl->unref(); _ssl->unref();
} }
@ -164,14 +219,19 @@ size_t WiFiClientSecure::write(const uint8_t *buf, size_t size) {
} }
int WiFiClientSecure::read(uint8_t *buf, size_t size) { int WiFiClientSecure::read(uint8_t *buf, size_t size) {
return _ssl->read(buf, size);
}
uint8_t* data; int WiFiClientSecure::read() {
int rc = ssl_read(*_ssl, &data); return _ssl->read();
if (rc <= 0) }
return 0;
memcpy(buf, data, rc); int WiFiClientSecure::peek() {
return rc; return _ssl->peek();
}
int WiFiClientSecure::available() {
return _ssl->available();
} }
void WiFiClientSecure::stop() { void WiFiClientSecure::stop() {
@ -217,13 +277,13 @@ extern "C" int ax_get_file(const char *filename, uint8_t **buf) {
return 0; return 0;
} }
#ifdef DEBUG_TLS_MEM #ifdef DEBUG_TLS_MEM
#define DEBUG_TLS_MEM_PRINT(...) DEBUGV(__VA_ARGS__) #define DEBUG_TLS_MEM_PRINT(...) DEBUGV(__VA_ARGS__)
#else #else
#define DEBUG_TLS_MEM_PRINT(...) #define DEBUG_TLS_MEM_PRINT(...)
#endif #endif
extern "C" void* ax_port_malloc(size_t size, const char* file, int line) { extern "C" void* ax_port_malloc(size_t size, const char* file, int line) {
void* result = malloc(size); void* result = malloc(size);
@ -254,7 +314,6 @@ extern "C" void* ax_port_realloc(void* ptr, size_t size, const char* file, int l
return result; return result;
} }
extern "C" void ax_port_free(void* ptr) { extern "C" void ax_port_free(void* ptr) {
free(ptr); free(ptr);
uint32_t *p = (uint32_t*) ptr; uint32_t *p = (uint32_t*) ptr;

View File

@ -40,6 +40,9 @@ public:
size_t write(const uint8_t *buf, size_t size) override; size_t write(const uint8_t *buf, size_t size) override;
int read(uint8_t *buf, size_t size) override; int read(uint8_t *buf, size_t size) override;
int available() override;
int read() override;
int peek() override;
void stop() override; void stop() override;
protected: protected: