1
0
mirror of https://github.com/esp8266/Arduino.git synced 2025-04-19 23:22:16 +03:00

Add sanity check so we do not trigger an update from wrong data

This commit is contained in:
Me No Dev 2015-11-09 01:47:51 +02:00
parent 14bb946896
commit 13b8cc0a27
2 changed files with 20 additions and 2 deletions

View File

@ -6,6 +6,8 @@
//#define OTA_DEBUG 1 //#define OTA_DEBUG 1
#define U_AUTH 200
ArduinoOTAClass::ArduinoOTAClass() ArduinoOTAClass::ArduinoOTAClass()
{ {
_udp_ota = new WiFiUDP(); _udp_ota = new WiFiUDP();
@ -169,12 +171,17 @@ void ArduinoOTAClass::handle() {
if (!_udp_ota->parsePacket()) return; if (!_udp_ota->parsePacket()) return;
if(_state == OTA_IDLE){ if(_state == OTA_IDLE){
int cmd = _udp_ota->parseInt();
if(cmd != U_FLASH && cmd != U_SPIFFS)
return;
_ota_ip = _udp_ota->remoteIP(); _ota_ip = _udp_ota->remoteIP();
_cmd = _udp_ota->parseInt(); _cmd = cmd;
_ota_port = _udp_ota->parseInt(); _ota_port = _udp_ota->parseInt();
_size = _udp_ota->parseInt(); _size = _udp_ota->parseInt();
_udp_ota->read(); _udp_ota->read();
sprintf(_md5, "%s", _udp_ota->readStringUntil('\n').c_str()); sprintf(_md5, "%s", _udp_ota->readStringUntil('\n').c_str());
if(strlen(_md5) != 32)
return;
#if OTA_DEBUG #if OTA_DEBUG
Serial.print("Update Start: ip:"); Serial.print("Update Start: ip:");
@ -199,8 +206,18 @@ void ArduinoOTAClass::handle() {
_state = OTA_RUNUPDATE; _state = OTA_RUNUPDATE;
} }
} else if(_state == OTA_WAITAUTH){ } else if(_state == OTA_WAITAUTH){
int cmd = _udp_ota->parseInt();
if(cmd != U_AUTH){
_state = OTA_IDLE;
return;
}
_udp_ota->read();
String cnonce = _udp_ota->readStringUntil(' '); String cnonce = _udp_ota->readStringUntil(' ');
String response = _udp_ota->readStringUntil('\n'); String response = _udp_ota->readStringUntil('\n');
if(cnonce.length() != 32 || response.length() != 32){
_state = OTA_IDLE;
return;
}
MD5Builder _passmd5; MD5Builder _passmd5;
_passmd5.begin(); _passmd5.begin();

View File

@ -34,6 +34,7 @@ import hashlib
# Commands # Commands
FLASH = 0 FLASH = 0
SPIFFS = 100 SPIFFS = 100
AUTH = 200
def serve(remoteAddr, remotePort, password, filename, command = FLASH): def serve(remoteAddr, remotePort, password, filename, command = FLASH):
@ -78,7 +79,7 @@ def serve(remoteAddr, remotePort, password, filename, command = FLASH):
result = hashlib.md5(result_text).hexdigest() result = hashlib.md5(result_text).hexdigest()
sys.stderr.write('Authenticating...') sys.stderr.write('Authenticating...')
sys.stderr.flush() sys.stderr.flush()
message = '%s %s\n' % (cnonce, result) message = '%d %s %s\n' % (AUTH, cnonce, result)
sock2.sendto(message, remote_address) sock2.sendto(message, remote_address)
sock2.settimeout(10) sock2.settimeout(10)
try: try: