diff --git a/cores/esp8266/Arduino.h b/cores/esp8266/Arduino.h index a715f5215..61e5711a4 100644 --- a/cores/esp8266/Arduino.h +++ b/cores/esp8266/Arduino.h @@ -233,6 +233,7 @@ void loop(void); #include "HardwareSerial.h" #include "Esp.h" +#include "Updater.h" #include "debug.h" #define min(a,b) ((a)<(b)?(a):(b)) diff --git a/cores/esp8266/Esp.cpp b/cores/esp8266/Esp.cpp index a33cfa2a0..4e0b75106 100644 --- a/cores/esp8266/Esp.cpp +++ b/cores/esp8266/Esp.cpp @@ -30,7 +30,7 @@ extern struct rst_info resetInfo; } -// #define DEBUG_SERIAL Serial +//#define DEBUG_SERIAL Serial /** @@ -358,96 +358,38 @@ uint32_t EspClass::getFreeSketchSpace() { return freeSpaceEnd - freeSpaceStart; } -bool EspClass::updateSketch(Stream& in, uint32_t size, bool restartOnFail) { - - if (size > getFreeSketchSpace()){ - if(restartOnFail) ESP.restart(); - return false; - } - - uint32_t usedSize = getSketchSize(); - uint32_t freeSpaceStart = (usedSize + FLASH_SECTOR_SIZE - 1) & (~(FLASH_SECTOR_SIZE - 1)); - uint32_t roundedSize = (size + FLASH_SECTOR_SIZE - 1) & (~(FLASH_SECTOR_SIZE - 1)); - +bool EspClass::updateSketch(Stream& in, uint32_t size, bool restartOnFail, bool restartOnSuccess) { + if(!Update.begin(size)){ #ifdef DEBUG_SERIAL - DEBUG_SERIAL.printf("erase @0x%x size=0x%x\r\n", freeSpaceStart, roundedSize); + DEBUG_SERIAL.print("Update "); + Update.printError(DEBUG_SERIAL); #endif + if(restartOnFail) ESP.restart(); + return false; + } - noInterrupts(); - int rc = SPIEraseAreaEx(freeSpaceStart, roundedSize); - interrupts(); - if (rc){ - if(restartOnFail) ESP.restart(); - return false; - } - + if(Update.writeStream(in) != size){ #ifdef DEBUG_SERIAL - DEBUG_SERIAL.println("erase done"); + DEBUG_SERIAL.print("Update "); + Update.printError(DEBUG_SERIAL); #endif + if(restartOnFail) ESP.restart(); + return false; + } - uint32_t addr = freeSpaceStart; - uint32_t left = size; - - const uint32_t bufferSize = FLASH_SECTOR_SIZE; - std::unique_ptr buffer(new uint8_t[bufferSize]); - + if(!Update.end()){ #ifdef DEBUG_SERIAL - DEBUG_SERIAL.println("writing"); + DEBUG_SERIAL.print("Update "); + Update.printError(DEBUG_SERIAL); #endif - while (left > 0) { - size_t willRead = (left < bufferSize) ? left : bufferSize; - size_t rd = in.readBytes(buffer.get(), willRead); - if (rd != willRead) { -#ifdef DEBUG_SERIAL - DEBUG_SERIAL.printf("stream read less: %u/%u\n", rd, willRead); -#endif - if(rd == 0){ //we got nothing from the client - //we should actually give it a bit of a chance to send us something - //connection could be slow ;) - if(restartOnFail) ESP.restart(); - return false; - } - //we at least got some data, lets write it to the flash - willRead = rd; - } - - if(addr == freeSpaceStart) { - // check for valid first magic byte - if(*((uint8 *) buffer.get()) != 0xE9) { - if(restartOnFail) ESP.restart(); - return false; - } - } - - noInterrupts(); - rc = SPIWrite(addr, buffer.get(), willRead); - interrupts(); - if (rc) { -#ifdef DEBUG_SERIAL - DEBUG_SERIAL.println("write failed"); -#endif - if(restartOnFail) ESP.restart(); - return false; - } - - addr += willRead; - left -= willRead; -#ifdef DEBUG_SERIAL - DEBUG_SERIAL.print("."); -#endif - } + if(restartOnFail) ESP.restart(); + return false; + } #ifdef DEBUG_SERIAL - DEBUG_SERIAL.println("\r\nrestarting"); -#endif - eboot_command ebcmd; - ebcmd.action = ACTION_COPY_RAW; - ebcmd.args[0] = freeSpaceStart; - ebcmd.args[1] = 0x00000; - ebcmd.args[2] = size; - eboot_command_write(&ebcmd); - - ESP.restart(); - return true; // never happens + DEBUG_SERIAL.println("Update SUCCESS"); +#endif + if(restartOnSuccess) ESP.restart(); + return true; } diff --git a/cores/esp8266/Esp.h b/cores/esp8266/Esp.h index 8021ed198..8e66f8f88 100644 --- a/cores/esp8266/Esp.h +++ b/cores/esp8266/Esp.h @@ -116,7 +116,7 @@ class EspClass { uint32_t getSketchSize(); uint32_t getFreeSketchSpace(); - bool updateSketch(Stream& in, uint32_t size, bool restartOnFail = false); + bool updateSketch(Stream& in, uint32_t size, bool restartOnFail = false, bool restartOnSuccess = true); String getResetInfo(); struct rst_info * getResetInfoPtr(); diff --git a/cores/esp8266/Updater.cpp b/cores/esp8266/Updater.cpp new file mode 100644 index 000000000..0009f2f4a --- /dev/null +++ b/cores/esp8266/Updater.cpp @@ -0,0 +1,190 @@ +#include "Updater.h" +#include "Arduino.h" +#include "eboot_command.h" +extern "C"{ + #include "mem.h" +} +#define DEBUG_UPDATER Serial + +extern "C" uint32_t _SPIFFS_start; + +UpdaterClass::UpdaterClass() : _error(0), _buffer(0), _bufferLen(0), _size(0), _startAddress(0), _currentAddress(0) {} + +bool UpdaterClass::begin(size_t size){ + if(_size > 0){ +#ifdef DEBUG_UPDATER + DEBUG_UPDATER.println("already running"); +#endif + return false; + } + + if(size == 0){ + _error = UPDATE_ERROR_SIZE; +#ifdef DEBUG_UPDATER + printError(DEBUG_UPDATER); +#endif + return false; + } + + if(_buffer) os_free(_buffer); + + _bufferLen = 0; + _startAddress = 0; + _currentAddress = 0; + _size = 0; + _error = 0; + + uint32_t usedSize = ESP.getSketchSize(); + uint32_t freeSpaceStart = (usedSize + FLASH_SECTOR_SIZE - 1) & (~(FLASH_SECTOR_SIZE - 1)); + uint32_t freeSpaceEnd = (uint32_t)&_SPIFFS_start - 0x40200000; + uint32_t roundedSize = (size + FLASH_SECTOR_SIZE - 1) & (~(FLASH_SECTOR_SIZE - 1)); + + if(roundedSize > (freeSpaceEnd - freeSpaceStart)){ + _error = UPDATE_ERROR_SPACE; +#ifdef DEBUG_UPDATER + printError(DEBUG_UPDATER); +#endif + return false; + } + noInterrupts(); + int rc = SPIEraseAreaEx(freeSpaceStart, roundedSize); + interrupts(); + if (rc){ + _error = UPDATE_ERROR_ERASE; +#ifdef DEBUG_UPDATER + printError(DEBUG_UPDATER); +#endif + return false; + } + _startAddress = freeSpaceStart; + _currentAddress = _startAddress; + _size = size; + _buffer = (uint8_t*)os_malloc(FLASH_SECTOR_SIZE); + + return true; +} + +bool UpdaterClass::end(){ + if(_size == 0){ +#ifdef DEBUG_UPDATER + DEBUG_UPDATER.println("no update"); +#endif + return false; + } + + if(_buffer) os_free(_buffer); + _bufferLen = 0; + + if(hasError() || !isFinished()){ +#ifdef DEBUG_UPDATER + DEBUG_UPDATER.printf("premature end: res:%u, pos:%u/%u\n", getError(), progress(), _size); +#endif + _currentAddress = 0; + _startAddress = 0; + _size = 0; + return false; + } + + eboot_command ebcmd; + ebcmd.action = ACTION_COPY_RAW; + ebcmd.args[0] = _startAddress; + ebcmd.args[1] = 0x00000; + ebcmd.args[2] = _size; + eboot_command_write(&ebcmd); + + _currentAddress = 0; + _startAddress = 0; + _size = 0; + _error = UPDATE_ERROR_OK; + return true; +} + +bool UpdaterClass::_writeBuffer(){ + WDT_FEED(); + noInterrupts(); + int rc = SPIWrite(_currentAddress, _buffer, _bufferLen); + interrupts(); + if (rc) { + _error = UPDATE_ERROR_WRITE; +#ifdef DEBUG_UPDATER + printError(DEBUG_UPDATER); +#endif + return false; + } + _currentAddress += _bufferLen; + _bufferLen = 0; + return true; +} + +size_t UpdaterClass::write(uint8_t *data, size_t len){ + size_t left = len; + if(hasError()) + return 0; + + while((_bufferLen + left) > FLASH_SECTOR_SIZE){ + size_t toBuff = FLASH_SECTOR_SIZE - _bufferLen; + memcpy(_buffer + _bufferLen, data + (len - left), toBuff); + _bufferLen += toBuff; + if(!_writeBuffer()){ + return len - left; + } + left -= toBuff; + yield(); + } + //lets see whats left + memcpy(_buffer + _bufferLen, data + (len - left), left); + _bufferLen += left; + if(_bufferLen == remaining()){ + //we are at the end of the update, so should write what's left to flash + if(!_writeBuffer()){ + return len - left; + } + } + return len; +} + +size_t UpdaterClass::writeStream(Stream &data){ + size_t written = 0; + size_t toRead = 0; + if(hasError()) + return 0; + + while(remaining()){ + toRead = FLASH_SECTOR_SIZE - _bufferLen; + toRead = data.readBytes(_buffer + _bufferLen, toRead); + if(toRead == 0){ //Timeout + _error = UPDATE_ERROR_STREAM; +#ifdef DEBUG_UPDATER + printError(DEBUG_UPDATER); +#endif + return written; + } + _bufferLen += toRead; + if((_bufferLen == remaining() || _bufferLen == FLASH_SECTOR_SIZE) && !_writeBuffer()) + return written; + written += toRead; + yield(); + } + return written; +} + +void UpdaterClass::printError(Stream &out){ + out.printf("ERROR[%u]: ", _error); + if(_error == UPDATE_ERROR_OK){ + out.println("No Error"); + } else if(_error == UPDATE_ERROR_WRITE){ + out.println("Flash Write Failed"); + } else if(_error == UPDATE_ERROR_ERASE){ + out.println("Flash Erase Failed"); + } else if(_error == UPDATE_ERROR_SPACE){ + out.println("Not Enough Space"); + } else if(_error == UPDATE_ERROR_SIZE){ + out.println("Bad Size Given"); + } else if(_error == UPDATE_ERROR_STREAM){ + out.println("Stream Read Timeout"); + } else { + out.println("UNKNOWN"); + } +} + +UpdaterClass Update; diff --git a/cores/esp8266/Updater.h b/cores/esp8266/Updater.h new file mode 100644 index 000000000..107706060 --- /dev/null +++ b/cores/esp8266/Updater.h @@ -0,0 +1,119 @@ +#ifndef ESP8266UPDATER_H +#define ESP8266UPDATER_H + +#include "Arduino.h" +#include "flash_utils.h" + +#define UPDATE_ERROR_OK 0 +#define UPDATE_ERROR_WRITE 1 +#define UPDATE_ERROR_ERASE 2 +#define UPDATE_ERROR_SPACE 3 +#define UPDATE_ERROR_SIZE 4 +#define UPDATE_ERROR_STREAM 5 + +class UpdaterClass { + public: + UpdaterClass(); + /* + Call this to check and erase the space needed for the update + Will return false if there is not enough space + Or the erase of the flash failed + */ + bool begin(size_t size); + + /* + Writes a buffer to the flash and increments the address + Returns the amount written + */ + size_t write(uint8_t *data, size_t len); + + /* + Writes the remaining bytes from the Sream to the flash + Uses readBytes() and sets UPDATE_ERROR_STREAM on timeout + Returns the bytes written + Should be equal to the remaining bytes when called + Usable for slow streams like Serial + */ + size_t writeStream(Stream &data); + + /* + If all bytes are written + this call will write the config to eboot + and return true + If there is already an update running but is not finished + or there is an error + this will clear everything and return false + the last error is available through getError() + */ + bool end(); + + /* + Prints the last error to an output stream + */ + void printError(Stream &out); + + //Helpers + uint8_t getError(){ return _error; } + void clearError(){ _error = UPDATE_ERROR_OK; } + bool hasError(){ return _error != UPDATE_ERROR_OK; } + bool isRunning(){ return _size > 0; } + bool isFinished(){ return hasError()?true:(_currentAddress == (_startAddress + _size)); } + size_t size(){ return _size; } + size_t progress(){ return _currentAddress - _startAddress; } + size_t remaining(){ return hasError()?0:(size() - progress()); } + + /* + Template to write from objects that expose + available() and read(uint8_t*, size_t) methods + faster than the readStream method + writes only what is available + */ + template + size_t write(T &data){ + size_t written = 0; + if(hasError()) + return 0; + size_t available = data.available(); + while(available){ + if((_bufferLen + available) > remaining()){ + available = (remaining() - _bufferLen); + } + if((_bufferLen + available) > FLASH_SECTOR_SIZE){ + size_t toBuff = FLASH_SECTOR_SIZE - _bufferLen; + data.read(_buffer + _bufferLen, toBuff); + _bufferLen += toBuff; + if(!_writeBuffer()) + return written; + written += toBuff; + } else { + data.read(_buffer + _bufferLen, available); + _bufferLen += available; + written += available; + if(_bufferLen == remaining()){ + if(!_writeBuffer()){ + return written; + } + } + } + if(remaining() == 0) + return written; + yield(); + available = data.available(); + } + return written; + } + + private: + uint8_t *_buffer; + uint8_t _error; + size_t _bufferLen; + size_t _size; + uint32_t _startAddress; + uint32_t _currentAddress; + + bool _writeBuffer(); +}; + +extern UpdaterClass Update; + +#endif \ No newline at end of file diff --git a/libraries/ESP8266mDNS/examples/DNS_SD_Arduino_OTA/DNS_SD_Arduino_OTA.ino b/libraries/ESP8266mDNS/examples/DNS_SD_Arduino_OTA/DNS_SD_Arduino_OTA.ino index ca611218d..8ceb2be8a 100644 --- a/libraries/ESP8266mDNS/examples/DNS_SD_Arduino_OTA/DNS_SD_Arduino_OTA.ino +++ b/libraries/ESP8266mDNS/examples/DNS_SD_Arduino_OTA/DNS_SD_Arduino_OTA.ino @@ -7,11 +7,12 @@ const char* ssid = "**********"; const char* pass = "**********"; const uint16_t aport = 8266; -WiFiUDP listener; +WiFiServer TelnetServer(aport); +WiFiClient Telnet; +WiFiUDP OTA; void setup() { Serial.begin(115200); - Serial.setDebugOutput(true); Serial.println(""); Serial.println("Arduino OTA Test"); @@ -22,29 +23,74 @@ void setup() { if(WiFi.waitForConnectResult() == WL_CONNECTED){ MDNS.begin(host); MDNS.addService("arduino", "tcp", aport); - listener.begin(aport); + OTA.begin(aport); + TelnetServer.begin(); + TelnetServer.setNoDelay(true); Serial.print("IP address: "); Serial.println(WiFi.localIP()); } } void loop() { - if (listener.parsePacket()) { - IPAddress remote = listener.remoteIP(); - int cmd = listener.parseInt(); - int port = listener.parseInt(); - int sz = listener.parseInt(); - Serial.printf("Starting Update: cmd:%d, port:%d, size:%d\r\n", cmd, port, sz); - WiFiClient cl; - if (!cl.connect(remote, port)) { - Serial.println("Failed to connect"); + //OTA Sketch + if (OTA.parsePacket()) { + IPAddress remote = OTA.remoteIP(); + int cmd = OTA.parseInt(); + int port = OTA.parseInt(); + int size = OTA.parseInt(); + + Serial.print("Update Start: ip:"); + Serial.print(remote); + Serial.printf(", port:%d, size:%d\n", port, size); + uint32_t startTime = millis(); + + if(!Update.begin(size)){ + Serial.println("Update Begin Error"); return; } - listener.stop(); - if (!ESP.updateSketch(cl, sz)) { - Serial.println("Update failed"); - listener.begin(aport); + + WiFiClient client; + if (client.connect(remote, port)) { + + Serial.setDebugOutput(true); + while(!Update.isFinished()) Update.write(client); + Serial.setDebugOutput(false); + + if(Update.end()){ + client.println("OK"); + Serial.printf("Update Success: %u\nRebooting...\n", millis() - startTime); + ESP.restart(); + } else { + Update.printError(client); + Update.printError(Serial); + } + } else { + Serial.printf("Connect Failed: %u\n", millis() - startTime); } } + //IDE Monitor (connected to Serial) + if (TelnetServer.hasClient()){ + if (!Telnet || !Telnet.connected()){ + if(Telnet) Telnet.stop(); + Telnet = TelnetServer.available(); + } else { + WiFiClient toKill = TelnetServer.available(); + toKill.stop(); + } + } + if (Telnet && Telnet.connected() && Telnet.available()){ + while(Telnet.available()) + Serial.write(Telnet.read()); + } + if(Serial.available()){ + size_t len = Serial.available(); + uint8_t * sbuf = (uint8_t *)malloc(len); + Serial.readBytes(sbuf, len); + if (Telnet && Telnet.connected()){ + Telnet.write((uint8_t *)sbuf, len); + yield(); + } + free(sbuf); + } delay(100); } diff --git a/tools/espota.py b/tools/espota.py index b98d610aa..58305c6df 100755 --- a/tools/espota.py +++ b/tools/espota.py @@ -17,48 +17,66 @@ def serve(remoteAddr, remotePort, filename): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) serverPort = 48266 server_address = ('0.0.0.0', serverPort) - print('starting up on %s port %s' % server_address, file=sys.stderr) - sock.bind(server_address) - sock.listen(1) + print('Starting on %s:%s' % server_address, file=sys.stderr) + try: + sock.bind(server_address) + sock.listen(1) + except: + print('Socket Failed', file=sys.stderr) + return 1 - sock2 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - remote_address = (remoteAddr, int(remotePort)) content_size = os.path.getsize(filename) - print('upload size: %d' % content_size, file=sys.stderr) + print('Upload size: %d' % content_size, file=sys.stderr) message = '%d %d %d\n' % (0, serverPort, content_size) - while True: - # Wait for a connection - print('sending invitation', file=sys.stderr) - sent = sock2.sendto(message, remote_address) + # Wait for a connection + print('Sending invitation to:', remoteAddr, file=sys.stderr) + sock2 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + remote_address = (remoteAddr, int(remotePort)) + sent = sock2.sendto(message, remote_address) + sock2.close() + + print('Waiting for device...\n', file=sys.stderr) + try: sock.settimeout(10) - print('waiting...', file=sys.stderr) connection, client_address = sock.accept() sock.settimeout(None) connection.settimeout(None) + except: + print('No response from device', file=sys.stderr) + sock.close() + return 1 + + try: + f = open(filename, "rb") + sys.stderr.write('Uploading') + sys.stderr.flush() + while True: + chunk = f.read(4096) + if not chunk: break + sys.stderr.write('.') + sys.stderr.flush() + connection.sendall(chunk) + + print('\nWaiting for result...\n', file=sys.stderr) try: - print('connection from', client_address, file=sys.stderr) - - print('sending file %s\n' % filename, file=sys.stderr) - f = open(filename, "rb") - - while True: - chunk = f.read(4096) - if not chunk: - break - - sys.stderr.write('.') - sys.stderr.flush() - #print('sending %d' % len(chunk), file=sys.stderr) - connection.sendall(chunk) - - print('\ndone!', file=sys.stderr) - return 0 - - finally: + connection.settimeout(60) + data = connection.recv(32) + print('Result: %s' % data, file=sys.stderr) connection.close() f.close() - return 1 + return 0 + except: + print('Result: No Answer!', file=sys.stderr) + connection.close() + f.close() + return 1 + + finally: + connection.close() + f.close() + sock.close() + return 1 def main(args): return serve(args[1], args[2], args[3])