1
0
mirror of https://github.com/esp8266/Arduino.git synced 2025-04-21 10:26:06 +03:00

WiFiClientSecure: certificate loading refactoring, support for CA root cert verification

This commit is contained in:
Ivan Grokhotkov 2016-08-25 13:01:10 +08:00
parent 7f6e0c98f6
commit b41266097f
2 changed files with 195 additions and 132 deletions

View File

@ -24,8 +24,8 @@
extern "C" extern "C"
{ {
#include "osapi.h" #include "osapi.h"
#include "ets_sys.h" #include "ets_sys.h"
} }
#include <errno.h> #include <errno.h>
#include "debug.h" #include "debug.h"
@ -50,28 +50,19 @@ extern "C"
#define SSL_DEBUG_OPTS 0 #define SSL_DEBUG_OPTS 0
#endif #endif
uint8_t* default_private_key = 0; class SSLContext
uint32_t default_private_key_len = 0; {
static bool default_private_key_dynamic = false;
static int s_pk_refcnt = 0;
uint8_t* default_certificate = 0;
uint32_t default_certificate_len = 0;
static bool default_certificate_dynamic = false;
static void clear_private_key();
static void clear_certificate();
class SSLContext {
public: public:
SSLContext() { SSLContext()
{
if (_ssl_ctx_refcnt == 0) { if (_ssl_ctx_refcnt == 0) {
_ssl_ctx = ssl_ctx_new(SSL_SERVER_VERIFY_LATER | SSL_DEBUG_OPTS | SSL_CONNECT_IN_PARTS | SSL_READ_BLOCKING, 0); _ssl_ctx = ssl_ctx_new(SSL_SERVER_VERIFY_LATER | SSL_DEBUG_OPTS | SSL_CONNECT_IN_PARTS | SSL_READ_BLOCKING | SSL_NO_DEFAULT_KEY, 0);
} }
++_ssl_ctx_refcnt; ++_ssl_ctx_refcnt;
} }
~SSLContext() { ~SSLContext()
{
if (_ssl) { if (_ssl) {
ssl_free(_ssl); ssl_free(_ssl);
_ssl = nullptr; _ssl = nullptr;
@ -85,17 +76,20 @@ public:
s_io_ctx = nullptr; s_io_ctx = nullptr;
} }
void ref() { void ref()
{
++_refcnt; ++_refcnt;
} }
void unref() { void unref()
{
if (--_refcnt == 0) { if (--_refcnt == 0) {
delete this; delete this;
} }
} }
void connect(ClientContext* ctx, const char* hostName, uint32_t timeout_ms) { void connect(ClientContext* ctx, const char* hostName, uint32_t timeout_ms)
{
s_io_ctx = ctx; s_io_ctx = ctx;
_ssl = ssl_client_new(_ssl_ctx, 0, nullptr, 0, hostName); _ssl = ssl_client_new(_ssl_ctx, 0, nullptr, 0, hostName);
uint32_t t = millis(); uint32_t t = millis();
@ -109,18 +103,22 @@ public:
} }
} }
void stop() { void stop()
{
s_io_ctx = nullptr; s_io_ctx = nullptr;
} }
bool connected() { bool connected()
{
return _ssl != nullptr && ssl_handshake_status(_ssl) == SSL_OK; return _ssl != nullptr && ssl_handshake_status(_ssl) == SSL_OK;
} }
int read(uint8_t* dst, size_t size) { int read(uint8_t* dst, size_t size)
{
if (!_available) { if (!_available) {
if (!_readAll()) if (!_readAll()) {
return 0; return 0;
}
} }
size_t will_copy = (_available < size) ? _available : size; size_t will_copy = (_available < size) ? _available : size;
memcpy(dst, _read_ptr, will_copy); memcpy(dst, _read_ptr, will_copy);
@ -132,10 +130,12 @@ public:
return will_copy; return will_copy;
} }
int read() { int read()
{
if (!_available) { if (!_available) {
if (!_readAll()) if (!_readAll()) {
return -1; return -1;
}
} }
int result = _read_ptr[0]; int result = _read_ptr[0];
++_read_ptr; ++_read_ptr;
@ -146,18 +146,22 @@ public:
return result; return result;
} }
int peek() { int peek()
{
if (!_available) { if (!_available) {
if (!_readAll()) if (!_readAll()) {
return -1; return -1;
}
} }
return _read_ptr[0]; return _read_ptr[0];
} }
size_t peekBytes(char *dst, size_t size) { size_t peekBytes(char *dst, size_t size)
if(!_available) { {
if(!_readAll()) if (!_available) {
if (!_readAll()) {
return -1; return -1;
}
} }
size_t will_copy = (_available < size) ? _available : size; size_t will_copy = (_available < size) ? _available : size;
@ -165,7 +169,8 @@ public:
return will_copy; return will_copy;
} }
int available() { int available()
{
auto cb = _available; auto cb = _available;
if (cb == 0) { if (cb == 0) {
cb = _readAll(); cb = _readAll();
@ -175,18 +180,49 @@ public:
return cb; return cb;
} }
operator SSL*() { bool loadObject(int type, Stream& stream, size_t size)
{
std::unique_ptr<uint8_t[]> buf(new uint8_t[size]);
if (!buf.get()) {
DEBUGV("loadObject: failed to allocate memory\n");
return false;
}
size_t cb = stream.readBytes(buf.get(), size);
if (cb != size) {
DEBUGV("loadObject: reading %u bytes, got %u\n", size, cb);
return false;
}
return loadObject(type, buf.get(), size);
}
bool loadObject(int type, const uint8_t* data, size_t size)
{
int rc = ssl_obj_memory_load(_ssl_ctx, type, data, static_cast<int>(size), nullptr);
if (rc != SSL_OK) {
DEBUGV("loadObject: ssl_obj_memory_load returned %d\n", rc);
return false;
}
return true;
}
operator SSL*()
{
return _ssl; return _ssl;
} }
static ClientContext* getIOContext(int fd) { static ClientContext* getIOContext(int fd)
{
return s_io_ctx; return s_io_ctx;
} }
protected: protected:
int _readAll() { int _readAll()
if (!_ssl) {
if (!_ssl) {
return 0; return 0;
}
optimistic_yield(100); optimistic_yield(100);
@ -218,22 +254,19 @@ SSL_CTX* SSLContext::_ssl_ctx = nullptr;
int SSLContext::_ssl_ctx_refcnt = 0; int SSLContext::_ssl_ctx_refcnt = 0;
ClientContext* SSLContext::s_io_ctx = nullptr; ClientContext* SSLContext::s_io_ctx = nullptr;
WiFiClientSecure::WiFiClientSecure() { WiFiClientSecure::WiFiClientSecure()
++s_pk_refcnt; {
} }
WiFiClientSecure::~WiFiClientSecure() { WiFiClientSecure::~WiFiClientSecure()
{
if (_ssl) { if (_ssl) {
_ssl->unref(); _ssl->unref();
} }
if (--s_pk_refcnt == 0) {
clear_private_key();
clear_certificate();
}
} }
WiFiClientSecure::WiFiClientSecure(const WiFiClientSecure& other) WiFiClientSecure::WiFiClientSecure(const WiFiClientSecure& other)
: WiFiClient(static_cast<const WiFiClient&>(other)) : WiFiClient(static_cast<const WiFiClient&>(other))
{ {
_ssl = other._ssl; _ssl = other._ssl;
if (_ssl) { if (_ssl) {
@ -241,7 +274,8 @@ WiFiClientSecure::WiFiClientSecure(const WiFiClientSecure& other)
} }
} }
WiFiClientSecure& WiFiClientSecure::operator=(const WiFiClientSecure& rhs) { WiFiClientSecure& WiFiClientSecure::operator=(const WiFiClientSecure& rhs)
{
(WiFiClient&) *this = rhs; (WiFiClient&) *this = rhs;
_ssl = rhs._ssl; _ssl = rhs._ssl;
if (_ssl) { if (_ssl) {
@ -250,14 +284,17 @@ WiFiClientSecure& WiFiClientSecure::operator=(const WiFiClientSecure& rhs) {
return *this; return *this;
} }
int WiFiClientSecure::connect(IPAddress ip, uint16_t port) { int WiFiClientSecure::connect(IPAddress ip, uint16_t port)
if (!WiFiClient::connect(ip, port)) {
if (!WiFiClient::connect(ip, port)) {
return 0; return 0;
}
return _connectSSL(nullptr); return _connectSSL(nullptr);
} }
int WiFiClientSecure::connect(const char* name, uint16_t port) { int WiFiClientSecure::connect(const char* name, uint16_t port)
{
IPAddress remote_addr; IPAddress remote_addr;
if (!WiFi.hostByName(name, remote_addr)) { if (!WiFi.hostByName(name, remote_addr)) {
return 0; return 0;
@ -268,7 +305,8 @@ int WiFiClientSecure::connect(const char* name, uint16_t port) {
return _connectSSL(name); return _connectSSL(name);
} }
int WiFiClientSecure::_connectSSL(const char* hostName) { int WiFiClientSecure::_connectSSL(const char* hostName)
{
if (_ssl) { if (_ssl) {
_ssl->unref(); _ssl->unref();
_ssl = nullptr; _ssl = nullptr;
@ -288,13 +326,16 @@ int WiFiClientSecure::_connectSSL(const char* hostName) {
return 1; return 1;
} }
size_t WiFiClientSecure::write(const uint8_t *buf, size_t size) { size_t WiFiClientSecure::write(const uint8_t *buf, size_t size)
if (!_ssl) {
if (!_ssl) {
return 0; return 0;
}
int rc = ssl_write(*_ssl, buf, size); int rc = ssl_write(*_ssl, buf, size);
if (rc >= 0) if (rc >= 0) {
return rc; return rc;
}
if (rc != SSL_CLOSE_NOTIFY) { if (rc != SSL_CLOSE_NOTIFY) {
_ssl->unref(); _ssl->unref();
@ -304,44 +345,51 @@ size_t WiFiClientSecure::write(const uint8_t *buf, size_t size) {
return 0; return 0;
} }
int WiFiClientSecure::read(uint8_t *buf, size_t size) { int WiFiClientSecure::read(uint8_t *buf, size_t size)
if (!_ssl) {
if (!_ssl) {
return 0; return 0;
}
return _ssl->read(buf, size); return _ssl->read(buf, size);
} }
int WiFiClientSecure::read() { int WiFiClientSecure::read()
if (!_ssl) {
if (!_ssl) {
return -1; return -1;
}
return _ssl->read(); return _ssl->read();
} }
int WiFiClientSecure::peek() { int WiFiClientSecure::peek()
if (!_ssl) {
if (!_ssl) {
return -1; return -1;
}
return _ssl->peek(); return _ssl->peek();
} }
size_t WiFiClientSecure::peekBytes(uint8_t *buffer, size_t length) { size_t WiFiClientSecure::peekBytes(uint8_t *buffer, size_t length)
{
size_t count = 0; size_t count = 0;
if(!_ssl) { if (!_ssl) {
return 0; return 0;
} }
_startMillis = millis(); _startMillis = millis();
while((available() < (int) length) && ((millis() - _startMillis) < _timeout)) { while ((available() < (int) length) && ((millis() - _startMillis) < _timeout)) {
yield(); yield();
} }
if(!_ssl) { if (!_ssl) {
return 0; return 0;
} }
if(available() < (int) length) { if (available() < (int) length) {
count = available(); count = available();
} else { } else {
count = length; count = length;
@ -350,9 +398,11 @@ size_t WiFiClientSecure::peekBytes(uint8_t *buffer, size_t length) {
return _ssl->peekBytes((char *)buffer, count); return _ssl->peekBytes((char *)buffer, count);
} }
int WiFiClientSecure::available() { int WiFiClientSecure::available()
if (!_ssl) {
if (!_ssl) {
return 0; return 0;
}
return _ssl->available(); return _ssl->available();
} }
@ -366,7 +416,8 @@ Y Y x Y
x N N N x N N N
err x N N err x N N
*/ */
uint8_t WiFiClientSecure::connected() { uint8_t WiFiClientSecure::connected()
{
if (_ssl) { if (_ssl) {
if (_ssl->available()) { if (_ssl->available()) {
return true; return true;
@ -378,21 +429,21 @@ uint8_t WiFiClientSecure::connected() {
return false; return false;
} }
void WiFiClientSecure::stop() { void WiFiClientSecure::stop()
{
if (_ssl) { if (_ssl) {
_ssl->stop(); _ssl->stop();
} }
WiFiClient::stop(); WiFiClient::stop();
} }
static bool parseHexNibble(char pb, uint8_t* res) { static bool parseHexNibble(char pb, uint8_t* res)
{
if (pb >= '0' && pb <= '9') { if (pb >= '0' && pb <= '9') {
*res = (uint8_t) (pb - '0'); return true; *res = (uint8_t) (pb - '0'); return true;
} } else if (pb >= 'a' && pb <= 'f') {
else if (pb >= 'a' && pb <= 'f') {
*res = (uint8_t) (pb - 'a' + 10); return true; *res = (uint8_t) (pb - 'a' + 10); return true;
} } else if (pb >= 'A' && pb <= 'F') {
else if (pb >= 'A' && pb <= 'F') {
*res = (uint8_t) (pb - 'A' + 10); return true; *res = (uint8_t) (pb - 'A' + 10); return true;
} }
return false; return false;
@ -424,9 +475,11 @@ static bool matchName(const String& name, const String& domainName)
return domainName.substring(domainNameFirstDotPos) == name.substring(firstDotPos); return domainName.substring(domainNameFirstDotPos) == name.substring(firstDotPos);
} }
bool WiFiClientSecure::verify(const char* fp, const char* domain_name) { bool WiFiClientSecure::verify(const char* fp, const char* domain_name)
if (!_ssl) {
if (!_ssl) {
return false; return false;
}
uint8_t sha1[20]; uint8_t sha1[20];
int len = strlen(fp); int len = strlen(fp);
@ -452,13 +505,18 @@ bool WiFiClientSecure::verify(const char* fp, const char* domain_name) {
return false; return false;
} }
return _verifyDN(domain_name);
}
bool WiFiClientSecure::_verifyDN(const char* domain_name)
{
DEBUGV("domain name: '%s'\r\n", (domain_name)?domain_name:"(null)"); DEBUGV("domain name: '%s'\r\n", (domain_name)?domain_name:"(null)");
String domain_name_str(domain_name); String domain_name_str(domain_name);
domain_name_str.toLowerCase(); domain_name_str.toLowerCase();
const char* san = NULL; const char* san = NULL;
int i = 0; int i = 0;
while((san = ssl_get_cert_subject_alt_dnsname(*_ssl, i)) != NULL) { while ((san = ssl_get_cert_subject_alt_dnsname(*_ssl, i)) != NULL) {
if (matchName(String(san), domain_name_str)) { if (matchName(String(san), domain_name_str)) {
return true; return true;
} }
@ -474,67 +532,62 @@ bool WiFiClientSecure::verify(const char* fp, const char* domain_name) {
return false; return false;
} }
void WiFiClientSecure::setCertificate(const uint8_t* cert_data, size_t size) { bool WiFiClientSecure::verifyCertChain(const char* domain_name)
clear_certificate(); {
default_certificate = (uint8_t*) cert_data; if (!_ssl) {
default_certificate_len = size; return false;
}
int rc = ssl_verify_cert(*_ssl);
if (rc != SSL_OK) {
DEBUGV("ssl_verify_cert returned %d\n", rc);
return false;
}
return _verifyDN(domain_name);
} }
void WiFiClientSecure::setPrivateKey(const uint8_t* pk, size_t size) { void WiFiClientSecure::setCertificate(const uint8_t* cert_data, size_t size)
clear_private_key(); {
default_private_key = (uint8_t*) pk; if (!_ssl) {
default_private_key_len = size; return;
}
_ssl->loadObject(SSL_OBJ_X509_CERT, cert_data, size);
} }
bool WiFiClientSecure::loadCertificate(Stream& stream, size_t size) { void WiFiClientSecure::setPrivateKey(const uint8_t* pk, size_t size)
clear_certificate(); {
default_certificate = new uint8_t[size]; if (!_ssl) {
if (!default_certificate) { return;
return false; }
} _ssl->loadObject(SSL_OBJ_RSA_KEY, pk, size);
if (stream.readBytes(default_certificate, size) != size) {
delete[] default_certificate;
return false;
}
default_certificate_dynamic = true;
default_certificate_len = size;
return true;
} }
bool WiFiClientSecure::loadPrivateKey(Stream& stream, size_t size) { bool WiFiClientSecure::loadCACert(Stream& stream, size_t size)
clear_private_key(); {
default_private_key = new uint8_t[size]; if (!_ssl) {
if (!default_private_key) { return false;
return false; }
} return _ssl->loadObject(SSL_OBJ_X509_CACERT, stream, size);
if (stream.readBytes(default_private_key, size) != size) {
delete[] default_private_key;
return false;
}
default_private_key_dynamic = true;
default_private_key_len = size;
return true;
} }
static void clear_private_key() { bool WiFiClientSecure::loadCertificate(Stream& stream, size_t size)
if (default_private_key && default_private_key_dynamic) { {
delete[] default_private_key; if (!_ssl) {
default_private_key_dynamic = false; return false;
} }
default_private_key = 0; return _ssl->loadObject(SSL_OBJ_X509_CERT, stream, size);
default_private_key_len = 0;
} }
static void clear_certificate() { bool WiFiClientSecure::loadPrivateKey(Stream& stream, size_t size)
if (default_certificate && default_certificate_dynamic) { {
delete[] default_certificate; if (!_ssl) {
default_certificate_dynamic = false; return false;
} }
default_certificate = 0; return _ssl->loadObject(SSL_OBJ_RSA_KEY, stream, size);
default_certificate_len = 0;
} }
extern "C" int __ax_port_read(int fd, uint8_t* buffer, size_t count) { extern "C" int __ax_port_read(int fd, uint8_t* buffer, size_t count)
{
ClientContext* _client = SSLContext::getIOContext(fd); ClientContext* _client = SSLContext::getIOContext(fd);
if (!_client || _client->state() != ESTABLISHED && !_client->getSize()) { if (!_client || _client->state() != ESTABLISHED && !_client->getSize()) {
errno = EIO; errno = EIO;
@ -552,7 +605,8 @@ extern "C" int __ax_port_read(int fd, uint8_t* buffer, size_t count) {
} }
extern "C" void ax_port_read() __attribute__ ((weak, alias("__ax_port_read"))); extern "C" void ax_port_read() __attribute__ ((weak, alias("__ax_port_read")));
extern "C" int __ax_port_write(int fd, uint8_t* buffer, size_t count) { extern "C" int __ax_port_write(int fd, uint8_t* buffer, size_t count)
{
ClientContext* _client = SSLContext::getIOContext(fd); ClientContext* _client = SSLContext::getIOContext(fd);
if (!_client || _client->state() != ESTABLISHED) { if (!_client || _client->state() != ESTABLISHED) {
errno = EIO; errno = EIO;
@ -567,7 +621,8 @@ extern "C" int __ax_port_write(int fd, uint8_t* buffer, size_t count) {
} }
extern "C" void ax_port_write() __attribute__ ((weak, alias("__ax_port_write"))); extern "C" void ax_port_write() __attribute__ ((weak, alias("__ax_port_write")));
extern "C" int __ax_get_file(const char *filename, uint8_t **buf) { extern "C" int __ax_get_file(const char *filename, uint8_t **buf)
{
*buf = 0; *buf = 0;
return 0; return 0;
} }
@ -580,7 +635,8 @@ extern "C" void ax_get_file() __attribute__ ((weak, alias("__ax_get_file")));
#define DEBUG_TLS_MEM_PRINT(...) #define DEBUG_TLS_MEM_PRINT(...)
#endif #endif
extern "C" void* ax_port_malloc(size_t size, const char* file, int line) { extern "C" void* ax_port_malloc(size_t size, const char* file, int line)
{
void* result = malloc(size); void* result = malloc(size);
if (result == nullptr) { if (result == nullptr) {
DEBUG_TLS_MEM_PRINT("%s:%d malloc %d failed, left %d\r\n", file, line, size, ESP.getFreeHeap()); DEBUG_TLS_MEM_PRINT("%s:%d malloc %d failed, left %d\r\n", file, line, size, ESP.getFreeHeap());
@ -591,13 +647,15 @@ extern "C" void* ax_port_malloc(size_t size, const char* file, int line) {
return result; return result;
} }
extern "C" void* ax_port_calloc(size_t size, size_t count, const char* file, int line) { extern "C" void* ax_port_calloc(size_t size, size_t count, const char* file, int line)
{
void* result = ax_port_malloc(size * count, file, line); void* result = ax_port_malloc(size * count, file, line);
memset(result, 0, size * count); memset(result, 0, size * count);
return result; return result;
} }
extern "C" void* ax_port_realloc(void* ptr, size_t size, const char* file, int line) { extern "C" void* ax_port_realloc(void* ptr, size_t size, const char* file, int line)
{
void* result = realloc(ptr, size); void* result = realloc(ptr, size);
if (result == nullptr) { if (result == nullptr) {
DEBUG_TLS_MEM_PRINT("%s:%d realloc %d failed, left %d\r\n", file, line, size, ESP.getFreeHeap()); DEBUG_TLS_MEM_PRINT("%s:%d realloc %d failed, left %d\r\n", file, line, size, ESP.getFreeHeap());
@ -608,11 +666,13 @@ extern "C" void* ax_port_realloc(void* ptr, size_t size, const char* file, int l
return result; return result;
} }
extern "C" void ax_port_free(void* ptr) { extern "C" void ax_port_free(void* ptr)
{
free(ptr); free(ptr);
} }
extern "C" void __ax_wdt_feed() { extern "C" void __ax_wdt_feed()
{
optimistic_yield(10000); optimistic_yield(10000);
} }
extern "C" void ax_wdt_feed() __attribute__ ((weak, alias("__ax_wdt_feed"))); extern "C" void ax_wdt_feed() __attribute__ ((weak, alias("__ax_wdt_feed")));

View File

@ -39,6 +39,7 @@ public:
int connect(const char* name, uint16_t port) override; int connect(const char* name, uint16_t port) override;
bool verify(const char* fingerprint, const char* domain_name); bool verify(const char* fingerprint, const char* domain_name);
bool verifyCertChain(const char* domain_name);
uint8_t connected() override; uint8_t connected() override;
size_t write(const uint8_t *buf, size_t size) override; size_t write(const uint8_t *buf, size_t size) override;
@ -54,6 +55,7 @@ public:
bool loadCertificate(Stream& stream, size_t size); bool loadCertificate(Stream& stream, size_t size);
bool loadPrivateKey(Stream& stream, size_t size); bool loadPrivateKey(Stream& stream, size_t size);
bool loadCACert(Stream& stream, size_t size);
template<typename TFile> template<typename TFile>
bool loadCertificate(TFile& file) { bool loadCertificate(TFile& file) {
@ -67,6 +69,7 @@ public:
protected: protected:
int _connectSSL(const char* hostName); int _connectSSL(const char* hostName);
bool _verifyDN(const char* name);
SSLContext* _ssl = nullptr; SSLContext* _ssl = nullptr;
}; };