mirror of
https://github.com/esp8266/Arduino.git
synced 2025-04-19 23:22:16 +03:00
Avoid a null pointer exception when ArduinoOTA.end() is called more than once and thus the UDP socket is already freed. Also avoid unnecessary teardown if the class is not initialized yet (for example, begin() wasn't called yet, or end() is called multiple times).
409 lines
9.4 KiB
C++
409 lines
9.4 KiB
C++
#ifndef LWIP_OPEN_SRC
|
|
#define LWIP_OPEN_SRC
|
|
#endif
|
|
#include <functional>
|
|
#include <WiFiUdp.h>
|
|
#include <eboot_command.h>
|
|
#include "ArduinoOTA.h"
|
|
#include "MD5Builder.h"
|
|
#include "StreamString.h"
|
|
|
|
extern "C" {
|
|
#include "osapi.h"
|
|
#include "ets_sys.h"
|
|
#include "user_interface.h"
|
|
}
|
|
|
|
#include "lwip/opt.h"
|
|
#include "lwip/udp.h"
|
|
#include "lwip/inet.h"
|
|
#include "lwip/igmp.h"
|
|
#include "lwip/mem.h"
|
|
#include "include/UdpContext.h"
|
|
|
|
#if !defined(NO_GLOBAL_INSTANCES) && !defined(NO_GLOBAL_MDNS)
|
|
#include <ESP8266mDNS.h>
|
|
#endif
|
|
|
|
#if defined(DEBUG_ESP_OTA) && defined(DEBUG_ESP_PORT)
|
|
#define OTA_DEBUG DEBUG_ESP_PORT
|
|
#define OTA_DEBUG_PRINTF(fmt, ...) OTA_DEBUG.printf_P(PSTR(fmt), ##__VA_ARGS__)
|
|
#else
|
|
#define OTA_DEBUG_PRINTF(...)
|
|
#endif
|
|
|
|
ArduinoOTAClass::ArduinoOTAClass()
|
|
{
|
|
}
|
|
|
|
ArduinoOTAClass::~ArduinoOTAClass(){
|
|
if(_udp_ota){
|
|
_udp_ota->unref();
|
|
_udp_ota = 0;
|
|
}
|
|
}
|
|
|
|
void ArduinoOTAClass::onStart(THandlerFunction fn) {
|
|
_start_callback = fn;
|
|
}
|
|
|
|
void ArduinoOTAClass::onEnd(THandlerFunction fn) {
|
|
_end_callback = fn;
|
|
}
|
|
|
|
void ArduinoOTAClass::onProgress(THandlerFunction_Progress fn) {
|
|
_progress_callback = fn;
|
|
}
|
|
|
|
void ArduinoOTAClass::onError(THandlerFunction_Error fn) {
|
|
_error_callback = fn;
|
|
}
|
|
|
|
void ArduinoOTAClass::setPort(uint16_t port) {
|
|
if (!_initialized && !_port && port) {
|
|
_port = port;
|
|
}
|
|
}
|
|
|
|
void ArduinoOTAClass::setHostname(const char * hostname) {
|
|
if (!_initialized && !_hostname.length() && hostname) {
|
|
_hostname = hostname;
|
|
}
|
|
}
|
|
|
|
String ArduinoOTAClass::getHostname() {
|
|
return _hostname;
|
|
}
|
|
|
|
void ArduinoOTAClass::setPassword(const char * password) {
|
|
if (!_initialized && !_password.length() && password) {
|
|
MD5Builder passmd5;
|
|
passmd5.begin();
|
|
passmd5.add(password);
|
|
passmd5.calculate();
|
|
_password = passmd5.toString();
|
|
}
|
|
}
|
|
|
|
void ArduinoOTAClass::setPasswordHash(const char * password) {
|
|
if (!_initialized && !_password.length() && password) {
|
|
_password = password;
|
|
}
|
|
}
|
|
|
|
void ArduinoOTAClass::setRebootOnSuccess(bool reboot){
|
|
_rebootOnSuccess = reboot;
|
|
}
|
|
|
|
void ArduinoOTAClass::setEraseConfig(ota_erase_cfg_t eraseConfig){
|
|
_eraseConfig = eraseConfig;
|
|
}
|
|
|
|
void ArduinoOTAClass::begin(bool useMDNS) {
|
|
if (_initialized)
|
|
return;
|
|
|
|
_useMDNS = useMDNS;
|
|
|
|
if (!_hostname.length()) {
|
|
char tmp[15];
|
|
sprintf(tmp, "esp8266-%06x", ESP.getChipId());
|
|
_hostname = tmp;
|
|
}
|
|
if (!_port) {
|
|
_port = 8266;
|
|
}
|
|
|
|
if(_udp_ota){
|
|
_udp_ota->unref();
|
|
_udp_ota = 0;
|
|
}
|
|
|
|
_udp_ota = new UdpContext;
|
|
_udp_ota->ref();
|
|
|
|
if(!_udp_ota->listen(IP_ADDR_ANY, _port))
|
|
return;
|
|
_udp_ota->onRx(std::bind(&ArduinoOTAClass::_onRx, this));
|
|
|
|
#if !defined(NO_GLOBAL_INSTANCES) && !defined(NO_GLOBAL_MDNS)
|
|
if(_useMDNS) {
|
|
MDNS.begin(_hostname.c_str());
|
|
|
|
if (_password.length()) {
|
|
MDNS.enableArduino(_port, true);
|
|
} else {
|
|
MDNS.enableArduino(_port);
|
|
}
|
|
}
|
|
#endif
|
|
_initialized = true;
|
|
_state = OTA_IDLE;
|
|
OTA_DEBUG_PRINTF("OTA server at: %s.local:%u\n", _hostname.c_str(), _port);
|
|
}
|
|
|
|
int ArduinoOTAClass::parseInt(){
|
|
char data[16];
|
|
uint8_t index;
|
|
char value;
|
|
while(_udp_ota->peek() == ' ') _udp_ota->read();
|
|
for(index = 0; index < sizeof(data); ++index){
|
|
value = _udp_ota->peek();
|
|
if(value < '0' || value > '9'){
|
|
data[index] = '\0';
|
|
return atoi(data);
|
|
}
|
|
data[index] = _udp_ota->read();
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
String ArduinoOTAClass::readStringUntil(char end){
|
|
String res;
|
|
int value;
|
|
while(true){
|
|
value = _udp_ota->read();
|
|
if(value < 0 || value == '\0' || value == end){
|
|
return res;
|
|
}
|
|
res += static_cast<char>(value);
|
|
}
|
|
return res;
|
|
}
|
|
|
|
void ArduinoOTAClass::_onRx(){
|
|
if(!_udp_ota->next()) return;
|
|
IPAddress ota_ip;
|
|
|
|
if (_state == OTA_IDLE) {
|
|
int cmd = parseInt();
|
|
if (cmd != U_FLASH && cmd != U_FS)
|
|
return;
|
|
_ota_ip = _udp_ota->getRemoteAddress();
|
|
_cmd = cmd;
|
|
_ota_port = parseInt();
|
|
_ota_udp_port = _udp_ota->getRemotePort();
|
|
_size = parseInt();
|
|
_udp_ota->read();
|
|
_md5 = readStringUntil('\n');
|
|
_md5.trim();
|
|
if(_md5.length() != 32)
|
|
return;
|
|
|
|
ota_ip = _ota_ip;
|
|
|
|
if (_password.length()){
|
|
MD5Builder nonce_md5;
|
|
nonce_md5.begin();
|
|
nonce_md5.add(String(micros()));
|
|
nonce_md5.calculate();
|
|
_nonce = nonce_md5.toString();
|
|
|
|
char auth_req[38];
|
|
sprintf(auth_req, "AUTH %s", _nonce.c_str());
|
|
_udp_ota->append((const char *)auth_req, strlen(auth_req));
|
|
_udp_ota->send(ota_ip, _ota_udp_port);
|
|
_state = OTA_WAITAUTH;
|
|
return;
|
|
} else {
|
|
_state = OTA_RUNUPDATE;
|
|
}
|
|
} else if (_state == OTA_WAITAUTH) {
|
|
int cmd = parseInt();
|
|
if (cmd != U_AUTH) {
|
|
_state = OTA_IDLE;
|
|
return;
|
|
}
|
|
_udp_ota->read();
|
|
String cnonce = readStringUntil(' ');
|
|
String response = readStringUntil('\n');
|
|
if (cnonce.length() != 32 || response.length() != 32) {
|
|
_state = OTA_IDLE;
|
|
return;
|
|
}
|
|
|
|
String challenge = _password + ':' + String(_nonce) + ':' + cnonce;
|
|
MD5Builder _challengemd5;
|
|
_challengemd5.begin();
|
|
_challengemd5.add(challenge);
|
|
_challengemd5.calculate();
|
|
String result = _challengemd5.toString();
|
|
|
|
ota_ip = _ota_ip;
|
|
if(result.equalsConstantTime(response)) {
|
|
_state = OTA_RUNUPDATE;
|
|
} else {
|
|
_udp_ota->append("Authentication Failed", 21);
|
|
_udp_ota->send(ota_ip, _ota_udp_port);
|
|
if (_error_callback) _error_callback(OTA_AUTH_ERROR);
|
|
_state = OTA_IDLE;
|
|
}
|
|
}
|
|
|
|
while(_udp_ota->next()) _udp_ota->flush();
|
|
}
|
|
|
|
void ArduinoOTAClass::_runUpdate() {
|
|
IPAddress ota_ip = _ota_ip;
|
|
|
|
if (!Update.begin(_size, _cmd)) {
|
|
OTA_DEBUG_PRINTF("Update Begin Error\n");
|
|
if (_error_callback) {
|
|
_error_callback(OTA_BEGIN_ERROR);
|
|
}
|
|
|
|
StreamString ss;
|
|
Update.printError(ss);
|
|
_udp_ota->append("ERR: ", 5);
|
|
_udp_ota->append(ss.c_str(), ss.length());
|
|
_udp_ota->send(ota_ip, _ota_udp_port);
|
|
delay(100);
|
|
_udp_ota->listen(IP_ADDR_ANY, _port);
|
|
_state = OTA_IDLE;
|
|
return;
|
|
}
|
|
_udp_ota->append("OK", 2);
|
|
_udp_ota->send(ota_ip, _ota_udp_port);
|
|
delay(100);
|
|
|
|
Update.setMD5(_md5.c_str());
|
|
|
|
if (_start_callback) {
|
|
_start_callback();
|
|
}
|
|
if (_progress_callback) {
|
|
_progress_callback(0, _size);
|
|
}
|
|
|
|
WiFiClient client;
|
|
if (!client.connect(_ota_ip, _ota_port)) {
|
|
OTA_DEBUG_PRINTF("Connect Failed\n");
|
|
_udp_ota->listen(IP_ADDR_ANY, _port);
|
|
if (_error_callback) {
|
|
_error_callback(OTA_CONNECT_ERROR);
|
|
}
|
|
_state = OTA_IDLE;
|
|
}
|
|
// OTA sends little packets
|
|
client.setNoDelay(true);
|
|
|
|
uint32_t written, total = 0;
|
|
while (!Update.isFinished() && (client.connected() || client.available())) {
|
|
int waited = 1000;
|
|
while (!client.available() && waited--)
|
|
delay(1);
|
|
if (!waited){
|
|
OTA_DEBUG_PRINTF("Receive Failed\n");
|
|
_udp_ota->listen(IP_ADDR_ANY, _port);
|
|
if (_error_callback) {
|
|
_error_callback(OTA_RECEIVE_ERROR);
|
|
}
|
|
_state = OTA_IDLE;
|
|
}
|
|
written = Update.write(client);
|
|
if (written > 0) {
|
|
client.print(written, DEC);
|
|
total += written;
|
|
if(_progress_callback) {
|
|
_progress_callback(total, _size);
|
|
}
|
|
}
|
|
}
|
|
|
|
if (Update.end()) {
|
|
// Ensure last count packet has been sent out and not combined with the final OK
|
|
client.flush();
|
|
delay(1000);
|
|
client.print("OK");
|
|
client.flush();
|
|
delay(1000);
|
|
client.stop();
|
|
OTA_DEBUG_PRINTF("Update Success\n");
|
|
if (_end_callback) {
|
|
_end_callback();
|
|
}
|
|
if(_rebootOnSuccess){
|
|
OTA_DEBUG_PRINTF("Rebooting...\n");
|
|
//let serial/network finish tasks that might be given in _end_callback
|
|
delay(100);
|
|
if (OTA_ERASE_CFG_NO != _eraseConfig) {
|
|
eraseConfigAndReset(); // returns on failure
|
|
if (_error_callback) {
|
|
_error_callback(OTA_ERASE_SETTINGS_ERROR);
|
|
}
|
|
if (OTA_ERASE_CFG_ABORT_ON_ERROR == _eraseConfig) {
|
|
eboot_command_clear();
|
|
return;
|
|
}
|
|
#ifdef OTA_DEBUG
|
|
else if (OTA_ERASE_CFG_IGNORE_ERROR == _eraseConfig) {
|
|
// Fallthrough and restart
|
|
} else {
|
|
panic();
|
|
}
|
|
#endif
|
|
}
|
|
ESP.restart();
|
|
}
|
|
} else {
|
|
_udp_ota->listen(IP_ADDR_ANY, _port);
|
|
if (_error_callback) {
|
|
_error_callback(OTA_END_ERROR);
|
|
}
|
|
Update.printError(client);
|
|
#ifdef OTA_DEBUG
|
|
Update.printError(OTA_DEBUG);
|
|
#endif
|
|
_state = OTA_IDLE;
|
|
}
|
|
}
|
|
|
|
void ArduinoOTAClass::end() {
|
|
if (!_initialized)
|
|
return;
|
|
|
|
_initialized = false;
|
|
if(_udp_ota){
|
|
_udp_ota->unref();
|
|
_udp_ota = 0;
|
|
}
|
|
#if !defined(NO_GLOBAL_INSTANCES) && !defined(NO_GLOBAL_MDNS)
|
|
if(_useMDNS){
|
|
MDNS.end();
|
|
}
|
|
#endif
|
|
_state = OTA_IDLE;
|
|
OTA_DEBUG_PRINTF("OTA server stopped.\n");
|
|
}
|
|
|
|
void ArduinoOTAClass::eraseConfigAndReset() {
|
|
OTA_DEBUG_PRINTF("Erase Config and Hard Reset ...\n");
|
|
if (WiFi.mode(WIFI_OFF)) {
|
|
ESP.eraseConfigAndReset(); // No return testing - Only returns on failure
|
|
OTA_DEBUG_PRINTF(" ESP.eraseConfigAndReset() failed!\n");
|
|
} else {
|
|
OTA_DEBUG_PRINTF(" WiFi.mode(WIFI_OFF) Timeout!\n");
|
|
}
|
|
}
|
|
|
|
//this needs to be called in the loop()
|
|
void ArduinoOTAClass::handle() {
|
|
if (_state == OTA_RUNUPDATE) {
|
|
_runUpdate();
|
|
_state = OTA_IDLE;
|
|
}
|
|
|
|
#if !defined(NO_GLOBAL_INSTANCES) && !defined(NO_GLOBAL_MDNS)
|
|
if(_useMDNS)
|
|
MDNS.update(); //handle MDNS update as well, given that ArduinoOTA relies on it anyways
|
|
#endif
|
|
}
|
|
|
|
int ArduinoOTAClass::getCommand() {
|
|
return _cmd;
|
|
}
|
|
|
|
#if !defined(NO_GLOBAL_INSTANCES) && !defined(NO_GLOBAL_ARDUINOOTA)
|
|
ArduinoOTAClass ArduinoOTA;
|
|
#endif
|