diff --git a/libraries/ArduinoOTA/src/ArduinoOTA.cpp b/libraries/ArduinoOTA/src/ArduinoOTA.cpp index a5b0d09de58..e6e17cb9781 100644 --- a/libraries/ArduinoOTA/src/ArduinoOTA.cpp +++ b/libraries/ArduinoOTA/src/ArduinoOTA.cpp @@ -26,9 +26,9 @@ // #define OTA_DEBUG Serial -ArduinoOTAClass::ArduinoOTAClass() - : _port(0), _initialized(false), _rebootOnSuccess(true), _mdnsEnabled(true), _state(OTA_IDLE), _size(0), _cmd(0), _ota_port(0), _ota_timeout(1000), - _start_callback(NULL), _end_callback(NULL), _error_callback(NULL), _progress_callback(NULL) {} +ArduinoOTAClass::ArduinoOTAClass(UpdateClass *updater) + : _updater(updater), _port(0), _initialized(false), _rebootOnSuccess(true), _mdnsEnabled(true), _state(OTA_IDLE), _size(0), _cmd(0), _ota_port(0), + _ota_timeout(1000), _start_callback(NULL), _end_callback(NULL), _error_callback(NULL), _progress_callback(NULL) {} ArduinoOTAClass::~ArduinoOTAClass() { end(); @@ -297,10 +297,14 @@ void ArduinoOTAClass::_onRx() { } void ArduinoOTAClass::_runUpdate() { + if (!_updater) { + log_e("UpdateClass is NULL!"); + return; + } const char *partition_label = _partition_label.length() ? _partition_label.c_str() : NULL; - if (!Update.begin(_size, _cmd, -1, LOW, partition_label)) { + if (!_updater->begin(_size, _cmd, -1, LOW, partition_label)) { - log_e("Begin ERROR: %s", Update.errorString()); + log_e("Begin ERROR: %s", _updater->errorString()); if (_error_callback) { _error_callback(OTA_BEGIN_ERROR); @@ -309,7 +313,7 @@ void ArduinoOTAClass::_runUpdate() { return; } - Update.setMD5(_md5.c_str()); // Note: Update library still uses MD5 for firmware integrity, this is separate from authentication + _updater->setMD5(_md5.c_str()); // Note: Update library still uses MD5 for firmware integrity, this is separate from authentication if (_start_callback) { _start_callback(); @@ -328,7 +332,7 @@ void ArduinoOTAClass::_runUpdate() { uint32_t written = 0, total = 0, tried = 0; - while (!Update.isFinished() && client.connected()) { + while (!_updater->isFinished() && client.connected()) { size_t waited = _ota_timeout; size_t available = client.available(); while (!available && waited) { @@ -351,7 +355,7 @@ void ArduinoOTAClass::_runUpdate() { _error_callback(OTA_RECEIVE_ERROR); } _state = OTA_IDLE; - Update.abort(); + _updater->abort(); return; } if (!available) { @@ -373,7 +377,7 @@ void ArduinoOTAClass::_runUpdate() { } } - written = Update.write(buf, r); + written = _updater->write(buf, r); if (written > 0) { if (written != r) { log_w("didn't write enough! %u != %u", written, r); @@ -386,11 +390,11 @@ void ArduinoOTAClass::_runUpdate() { _progress_callback(total, _size); } } else { - log_e("Write ERROR: %s", Update.errorString()); + log_e("Write ERROR: %s", _updater->errorString()); } } - if (Update.end()) { + if (_updater->end()) { client.print("OK"); client.stop(); delay(10); @@ -406,10 +410,10 @@ void ArduinoOTAClass::_runUpdate() { if (_error_callback) { _error_callback(OTA_END_ERROR); } - Update.printError(client); + _updater->printError(client); client.stop(); delay(10); - log_e("Update ERROR: %s", Update.errorString()); + log_e("Update ERROR: %s", _updater->errorString()); _state = OTA_IDLE; } } @@ -448,6 +452,11 @@ void ArduinoOTAClass::setTimeout(int timeoutInMillis) { _ota_timeout = timeoutInMillis; } +ArduinoOTAClass &ArduinoOTAClass::setUpdaterInstance(UpdateClass *updater) { + _updater = updater; + return *this; +} + #if !defined(NO_GLOBAL_INSTANCES) && !defined(NO_GLOBAL_ARDUINOOTA) ArduinoOTAClass ArduinoOTA; #endif diff --git a/libraries/ArduinoOTA/src/ArduinoOTA.h b/libraries/ArduinoOTA/src/ArduinoOTA.h index a946388c4aa..d95c9d798f8 100644 --- a/libraries/ArduinoOTA/src/ArduinoOTA.h +++ b/libraries/ArduinoOTA/src/ArduinoOTA.h @@ -41,7 +41,11 @@ class ArduinoOTAClass { typedef std::function THandlerFunction_Error; typedef std::function THandlerFunction_Progress; - ArduinoOTAClass(); +#if !defined(NO_GLOBAL_INSTANCES) && !defined(NO_GLOBAL_UPDATE) + ArduinoOTAClass(UpdateClass *updater = &Update); +#else + ArduinoOTAClass(UpdateClass *updater = nullptr); +#endif ~ArduinoOTAClass(); //Sets the service port. Default 3232 @@ -61,6 +65,9 @@ class ArduinoOTAClass { ArduinoOTAClass &setPartitionLabel(const char *partition_label); String getPartitionLabel(); + //Sets instance of UpdateClass to perform updating operations + ArduinoOTAClass &setUpdaterInstance(UpdateClass *updater); + //Sets if the device should be rebooted after successful update. Default true ArduinoOTAClass &setRebootOnSuccess(bool reboot); @@ -94,6 +101,7 @@ class ArduinoOTAClass { void setTimeout(int timeoutInMillis); private: + UpdateClass *_updater; int _port; String _password; String _hostname; diff --git a/libraries/HTTPUpdate/src/HTTPUpdate.cpp b/libraries/HTTPUpdate/src/HTTPUpdate.cpp index 5183afac017..3a54572a925 100644 --- a/libraries/HTTPUpdate/src/HTTPUpdate.cpp +++ b/libraries/HTTPUpdate/src/HTTPUpdate.cpp @@ -35,15 +35,8 @@ // To do extern "C" uint32_t _SPIFFS_start; // To do extern "C" uint32_t _SPIFFS_end; -HTTPUpdate::HTTPUpdate(void) : HTTPUpdate(8000) {} - -HTTPUpdate::HTTPUpdate(int httpClientTimeout) : _httpClientTimeout(httpClientTimeout), _ledPin(-1) { - _followRedirects = HTTPC_DISABLE_FOLLOW_REDIRECTS; - _md5Sum = String(); - _user = String(); - _password = String(); - _auth = String(); -} +HTTPUpdate::HTTPUpdate(int httpClientTimeout, UpdateClass *updater) + : _httpClientTimeout(httpClientTimeout), _updater(updater), _followRedirects(HTTPC_DISABLE_FOLLOW_REDIRECTS) {} HTTPUpdate::~HTTPUpdate(void) {} @@ -129,6 +122,9 @@ int HTTPUpdate::getLastError(void) { * @return String error */ String HTTPUpdate::getLastErrorString(void) { + if (!_updater) { + return {}; + } if (_lastError == 0) { return String(); // no error @@ -137,7 +133,7 @@ String HTTPUpdate::getLastErrorString(void) { // error from Update class if (_lastError > 0) { StreamString error; - Update.printError(error); + _updater->printError(error); error.trim(); // remove line ending return String("Update error: ") + error; } @@ -444,16 +440,19 @@ HTTPUpdateResult HTTPUpdate::handleUpdate(HTTPClient &http, const String ¤ * @return true if Update ok */ bool HTTPUpdate::runUpdate(Stream &in, uint32_t size, String md5, int command) { + if (!_updater) { + return false; + } StreamString error; if (_cbProgress) { - Update.onProgress(_cbProgress); + _updater->onProgress(_cbProgress); } - if (!Update.begin(size, command, _ledPin, _ledOn)) { - _lastError = Update.getError(); - Update.printError(error); + if (!_updater->begin(size, command, _ledPin, _ledOn)) { + _lastError = _updater->getError(); + _updater->printError(error); error.trim(); // remove line ending log_e("Update.begin failed! (%s)\n", error.c_str()); return false; @@ -464,7 +463,7 @@ bool HTTPUpdate::runUpdate(Stream &in, uint32_t size, String md5, int command) { } if (md5.length()) { - if (!Update.setMD5(md5.c_str())) { + if (!_updater->setMD5(md5.c_str())) { _lastError = HTTP_UE_SERVER_FAULTY_MD5; log_e("Update.setMD5 failed! (%s)\n", md5.c_str()); return false; @@ -473,9 +472,9 @@ bool HTTPUpdate::runUpdate(Stream &in, uint32_t size, String md5, int command) { // To do: the SHA256 could be checked if the server sends it - if (Update.writeStream(in) != size) { - _lastError = Update.getError(); - Update.printError(error); + if (_updater->writeStream(in) != size) { + _lastError = _updater->getError(); + _updater->printError(error); error.trim(); // remove line ending log_e("Update.writeStream failed! (%s)\n", error.c_str()); return false; @@ -485,9 +484,9 @@ bool HTTPUpdate::runUpdate(Stream &in, uint32_t size, String md5, int command) { _cbProgress(size, size); } - if (!Update.end()) { - _lastError = Update.getError(); - Update.printError(error); + if (!_updater->end()) { + _lastError = _updater->getError(); + _updater->printError(error); error.trim(); // remove line ending log_e("Update.end failed! (%s)\n", error.c_str()); return false; diff --git a/libraries/HTTPUpdate/src/HTTPUpdate.h b/libraries/HTTPUpdate/src/HTTPUpdate.h index ad38701b948..9517c0b45c4 100644 --- a/libraries/HTTPUpdate/src/HTTPUpdate.h +++ b/libraries/HTTPUpdate/src/HTTPUpdate.h @@ -58,8 +58,13 @@ using HTTPUpdateProgressCB = std::function; class HTTPUpdate { public: - HTTPUpdate(void); - HTTPUpdate(int httpClientTimeout); +#if !defined(NO_GLOBAL_INSTANCES) && !defined(NO_GLOBAL_UPDATE) + HTTPUpdate(UpdateClass *updater = &Update) : HTTPUpdate(8000, updater){}; + HTTPUpdate(int httpClientTimeout, UpdateClass *updater = &Update); +#else + HTTPUpdate(UpdateClass *updater = nullptr) : HTTPUpdate(8000, updater){}; + HTTPUpdate(int httpClientTimeout, UpdateClass *updater = nullptr); +#endif ~HTTPUpdate(void); void rebootOnUpdate(bool reboot) { @@ -92,6 +97,11 @@ class HTTPUpdate { _auth = auth; } + //Sets instance of UpdateClass to perform updating operations + void setUpdaterInstance(UpdateClass *updater) { + _updater = updater; + }; + t_httpUpdate_return update(NetworkClient &client, const String &url, const String ¤tVersion = "", HTTPUpdateRequestCB requestCB = NULL); t_httpUpdate_return update( @@ -143,6 +153,7 @@ class HTTPUpdate { private: int _httpClientTimeout; + UpdateClass *_updater; followRedirects_t _followRedirects; String _user; String _password; @@ -155,7 +166,7 @@ class HTTPUpdate { HTTPUpdateErrorCB _cbError; HTTPUpdateProgressCB _cbProgress; - int _ledPin; + int _ledPin{-1}; uint8_t _ledOn; };