Prepare for automatic OTA updates

This commit is contained in:
Djuri Baars 2024-09-11 17:40:44 +02:00
parent 1f2110fc5a
commit 5d5b09f56c
7 changed files with 409 additions and 227 deletions

View file

@ -56,7 +56,10 @@ jobs:
run: mkdir -p ${{ matrix.chip.name }}_${{ matrix.epd_variant }} && esptool.py --chip ${{ matrix.chip.version }} merge_bin -o ${{ matrix.chip.name }}_${{ matrix.epd_variant }}/${{ matrix.chip.name }}_${{ matrix.epd_variant }}.bin --flash_mode dio 0x0000 .pio/build/${{ matrix.chip.name }}_${{ matrix.epd_variant }}/bootloader.bin 0x8000 .pio/build/${{ matrix.chip.name }}_${{ matrix.epd_variant }}/partitions.bin 0xe000 .pio/boot_app0.bin 0x10000 .pio/build/${{ matrix.chip.name }}_${{ matrix.epd_variant }}/firmware.bin 0x369000 .pio/build/${{ matrix.chip.name }}_${{ matrix.epd_variant }}/littlefs.bin
- name: Create checksum for merged binary
run: shasum -a 256 ${{ matrix.chip.name }}_${{ matrix.epd_variant }}/${{ matrix.chip.name }}_${{ matrix.epd_variant }}.bin | awk '{print $1}' > ${{ matrix.chip.name }}_${{ matrix.epd_variant }}/${{ matrix.chip.name }}_${{ matrix.epd_variant }}.sha256
run: shasum -a 256 ${{ matrix.chip.name }}_${{ matrix.epd_variant }}/${{ matrix.chip.name }}_${{ matrix.epd_variant }}.bin | awk '{print $1}' > ${{ matrix.chip.name }}_${{ matrix.epd_variant }}/${{ matrix.chip.name }}_${{ matrix.epd_variant }}.bin.sha256
- name: Create checksum for littlefs partition
run: shasum -a 256 ${{ matrix.chip.name }}_${{ matrix.epd_variant }}/littlefs.bin | awk '{print $1}' > ${{ matrix.chip.name }}_${{ matrix.epd_variant }}/littlefs.bin.sha256
- name: Copy all artifacts to output folder
run: cp .pio/build/${{ matrix.chip.name }}_${{ matrix.epd_variant }}/*.bin .pio/boot_app0.bin ${{ matrix.chip.name }}_${{ matrix.epd_variant }}

View file

@ -2,9 +2,12 @@
TaskHandle_t taskOtaHandle = NULL;
bool isOtaUpdating = false;
QueueHandle_t otaQueue;
void setupOTA() {
if (preferences.getBool("otaEnabled", DEFAULT_OTA_ENABLED)) {
void setupOTA()
{
if (preferences.getBool("otaEnabled", DEFAULT_OTA_ENABLED))
{
ArduinoOTA.onStart(onOTAStart);
ArduinoOTA.onProgress(onOTAProgress);
@ -16,31 +19,38 @@ void setupOTA() {
ArduinoOTA.setRebootOnSuccess(false);
ArduinoOTA.begin();
// downloadUpdate();
otaQueue = xQueueCreate(1, sizeof(UpdateMessage));
xTaskCreate(handleOTATask, "handleOTA", 4096, NULL, tskIDLE_PRIORITY,
xTaskCreate(handleOTATask, "handleOTA", 8192, NULL, 20,
&taskOtaHandle);
}
}
void onOTAProgress(unsigned int progress, unsigned int total) {
void onOTAProgress(unsigned int progress, unsigned int total)
{
uint percentage = progress / (total / 100);
pixels.fill(pixels.Color(0, 255, 0));
if (percentage < 100) {
if (percentage < 100)
{
pixels.setPixelColor(0, pixels.Color(0, 0, 0));
}
if (percentage < 75) {
if (percentage < 75)
{
pixels.setPixelColor(1, pixels.Color(0, 0, 0));
}
if (percentage < 50) {
if (percentage < 50)
{
pixels.setPixelColor(2, pixels.Color(0, 0, 0));
}
if (percentage < 25) {
if (percentage < 25)
{
pixels.setPixelColor(3, pixels.Color(0, 0, 0));
}
pixels.show();
}
void onOTAStart() {
void onOTAStart()
{
forceFullRefresh();
std::array<String, NUM_SCREENS> epdContent = {"U", "P", "D", "A",
"T", "E", "!"};
@ -58,76 +68,296 @@ void onOTAStart() {
vTaskSuspend(ledTaskHandle);
vTaskSuspend(buttonTaskHandle);
stopWebServer();
// stopWebServer();
stopBlockNotify();
stopPriceNotify();
}
void handleOTATask(void *parameter) {
for (;;) {
ArduinoOTA.handle(); // Allow OTA updates to occur
void handleOTATask(void *parameter)
{
UpdateMessage msg;
for (;;)
{
if (xQueueReceive(otaQueue, &msg, 0) == pdTRUE)
{
int result = downloadUpdateHandler(msg.updateType);
}
ArduinoOTA.handle(); // Allow OTA updates to occur
vTaskDelay(pdMS_TO_TICKS(2000));
}
}
// void downloadUpdate() {
// WiFiClientSecure client;
// client.setInsecure();
// HTTPClient http;
// http.setUserAgent(USER_AGENT);
String getLatestRelease(const String &fileToDownload)
{
String releaseUrl = "https://api.github.com/repos/btclock/btclock_v3/releases/latest";
WiFiClientSecure client;
client.setCACert(github_root_ca);
HTTPClient http;
http.begin(client, releaseUrl);
http.setUserAgent(USER_AGENT);
// // Send HTTP request to CoinGecko API
// http.useHTTP10(true);
int httpCode = http.GET();
// http.begin(client,
// "https://api.github.com/repos/btclock/btclock_v3/releases/latest");
// int httpCode = http.GET();
String downloadUrl = "";
// if (httpCode == 200) {
// // WiFiClient * stream = http->getStreamPtr();
if (httpCode > 0)
{
String payload = http.getString();
// JsonDocument filter;
JsonDocument doc;
deserializeJson(doc, payload);
// JsonObject filter_assets_0 = filter["assets"].add<JsonObject>();
// filter_assets_0["name"] = true;
// filter_assets_0["browser_download_url"] = true;
JsonArray assets = doc["assets"];
// JsonDocument doc;
for (JsonObject asset : assets)
{
if (asset["name"] == fileToDownload)
{
downloadUrl = asset["browser_download_url"].as<String>();
break;
}
}
Serial.printf("Latest release URL: %s\r\n", downloadUrl.c_str());
}
return downloadUrl;
}
// DeserializationError error = deserializeJson(
// doc, http.getStream(), DeserializationOption::Filter(filter));
int downloadUpdateHandler(char updateType)
{
WiFiClientSecure client;
client.setCACert(github_root_ca);
HTTPClient http;
http.setFollowRedirects(HTTPC_STRICT_FOLLOW_REDIRECTS);
// if (error) {
// Serial.print("deserializeJson() failed: ");
// Serial.println(error.c_str());
// return;
// }
String latestRelease = "";
// String downloadUrl;
// for (JsonObject asset : doc["assets"].as<JsonArray>()) {
// if (asset["name"].as<String>().compareTo("firmware.bin") == 0) {
// downloadUrl = asset["browser_download_url"].as<String>();
// break;
// }
// }
switch (updateType)
{
case UPDATE_FIRMWARE:
{
latestRelease = getLatestRelease(getFirmwareFilename());
}
break;
case UPDATE_WEBUI:
{
latestRelease = getLatestRelease("littlefs.bin");
updateWebUi(latestRelease, U_SPIFFS);
return 0;
}
break;
}
// Serial.printf("Download update from %s", downloadUrl);
if (latestRelease.isEmpty())
{
return 503;
}
// First, download the expected SHA256
String expectedSHA256 = downloadSHA256(getFirmwareFilename());
if (expectedSHA256.isEmpty())
{
Serial.println("Failed to get SHA256 checksum. Aborting update.");
return false;
}
// // esp_http_client_config_t config = {
// // .url = CONFIG_FIRMWARE_UPGRADE_URL,
// // };
// // esp_https_ota_config_t ota_config = {
// // .http_config = &config,
// // };
// // esp_err_t ret = esp_https_ota(&ota_config);
// // if (ret == ESP_OK)
// // {
// // esp_restart();
// // }
// }
// }
http.begin(client, latestRelease);
http.setUserAgent(USER_AGENT);
void onOTAError(ota_error_t error) {
int httpCode = http.GET();
if (httpCode == HTTP_CODE_OK)
{
int contentLength = http.getSize();
if (contentLength > 0)
{
// Allocate memory to store the firmware
uint8_t *firmware = (uint8_t *)malloc(contentLength);
if (!firmware)
{
Serial.println(F("Not enough memory to store firmware"));
return false;
}
WiFiClient *stream = http.getStreamPtr();
size_t bytesRead = 0;
while (bytesRead < contentLength)
{
size_t available = stream->available();
if (available)
{
size_t readBytes = stream->readBytes(firmware + bytesRead, available);
bytesRead += readBytes;
}
yield(); // Allow background tasks to run
}
if (bytesRead != contentLength)
{
Serial.println("Failed to read entire firmware");
free(firmware);
return false;
}
// Calculate SHA256
String calculated_sha256 = calculateSHA256(firmware, contentLength);
Serial.print("Calculated checksum: ");
Serial.println(calculated_sha256);
Serial.print("Expected checksum: ");
Serial.println(expectedSHA256);
if (calculated_sha256 != expectedSHA256)
{
Serial.println("Checksum mismatch. Aborting update.");
free(firmware);
return false;
}
Update.onProgress(onOTAProgress);
int updateType = (updateType == UPDATE_WEBUI) ? U_SPIFFS : U_FLASH;
if (Update.begin(contentLength, updateType))
{
size_t written = Update.writeStream(*stream);
if (written == contentLength)
{
Serial.println("Written : " + String(written) + " successfully");
}
else
{
Serial.println("Written only : " + String(written) + "/" + String(contentLength) + ". Retry?");
}
if (Update.end())
{
Serial.println("OTA done!");
if (Update.isFinished())
{
Serial.println("Update successfully completed. Rebooting.");
ESP.restart();
}
else
{
Serial.println("Update not finished? Something went wrong!");
}
}
else
{
Serial.println("Error Occurred. Error #: " + String(Update.getError()));
}
}
else
{
Serial.println("Not enough space to begin OTA");
}
}
else
{
Serial.println("Invalid content length");
}
}
else
{
Serial.printf("HTTP error: %d\n", httpCode);
return 503;
}
http.end();
return 200;
}
void updateWebUi(String latestRelease, int command)
{
WiFiClientSecure client;
client.setCACert(github_root_ca);
HTTPClient http;
http.setFollowRedirects(HTTPC_STRICT_FOLLOW_REDIRECTS);
http.begin(client, latestRelease);
http.setUserAgent(USER_AGENT);
int httpCode = http.GET();
if (httpCode == HTTP_CODE_OK)
{
int contentLength = http.getSize();
if (contentLength > 0)
{
uint8_t *buffer = (uint8_t *)malloc(contentLength);
if (buffer)
{
WiFiClient *stream = http.getStreamPtr();
size_t written = stream->readBytes(buffer, contentLength);
if (written == contentLength)
{
String expectedSHA256 = "";
if (command == U_FLASH)
{
expectedSHA256 = downloadSHA256(getFirmwareFilename());
Serial.print("Expected checksum: ");
Serial.println(expectedSHA256);
}
String calculated_sha256 = calculateSHA256(buffer, contentLength);
Serial.print("Checksum is ");
Serial.println(calculated_sha256);
if ((command == U_FLASH && expectedSHA256.equals(calculated_sha256)) || command == U_SPIFFS)
{
Serial.println("Checksum verified. Proceeding with update.");
Update.onProgress(onOTAProgress);
if (Update.begin(contentLength, command))
{
onOTAStart();
Update.write(buffer, contentLength);
if (Update.end())
{
Serial.println("Update complete. Rebooting.");
ESP.restart();
}
else
{
Serial.println("Error in update process.");
}
}
else
{
Serial.println("Not enough space to begin OTA");
}
}
else
{
Serial.println("Checksum mismatch. Aborting update.");
}
}
else
{
Serial.println("Error downloading firmware");
}
free(buffer);
}
else
{
Serial.println("Not enough memory to allocate buffer");
}
}
else
{
Serial.println("Invalid content length");
}
}
else
{
Serial.print(httpCode);
Serial.println("Error on HTTP request");
}
}
void onOTAError(ota_error_t error)
{
Serial.println(F("\nOTA update error, restarting"));
Wire.end();
SPI.end();
@ -136,7 +366,8 @@ void onOTAError(ota_error_t error) {
ESP.restart();
}
void onOTAComplete() {
void onOTAComplete()
{
Serial.println(F("\nOTA update finished"));
Wire.end();
SPI.end();
@ -144,6 +375,37 @@ void onOTAComplete() {
ESP.restart();
}
bool getIsOTAUpdating() {
bool getIsOTAUpdating()
{
return isOtaUpdating;
}
String downloadSHA256(const String &filename)
{
String sha256Url = getLatestRelease(filename + ".sha256");
if (sha256Url.isEmpty())
{
Serial.println("Failed to get SHA256 file URL");
return "";
}
WiFiClientSecure client;
client.setCACert(github_root_ca);
HTTPClient http;
http.setFollowRedirects(HTTPC_STRICT_FOLLOW_REDIRECTS);
http.begin(client, sha256Url);
http.setUserAgent(USER_AGENT);
int httpCode = http.GET();
if (httpCode == HTTP_CODE_OK)
{
String sha256 = http.getString();
sha256.trim(); // Remove any whitespace or newline characters
return sha256;
}
else
{
Serial.printf("Failed to download SHA256 file. HTTP error: %d\n", httpCode);
return "";
}
}

View file

@ -1,9 +1,20 @@
#pragma once
#include <Arduino.h>
#include <ArduinoOTA.h>
#include "lib/config.hpp"
#include "lib/shared.hpp"
#ifndef UPDATE_MESSAGE_HPP
#define UPDATE_MESSAGE_HPP
typedef struct {
char updateType;
} UpdateMessage;
#endif
extern QueueHandle_t otaQueue;
void setupOTA();
void onOTAStart();
void handleOTATask(void *parameter);
@ -11,5 +22,10 @@ void onOTAProgress(unsigned int progress, unsigned int total);
// void downloadUpdate();
void onOTAError(ota_error_t error);
void onOTAComplete();
int downloadUpdateHandler(char updateType);
String getLatestRelease(const String& fileToDownload);
bool getIsOTAUpdating();
bool getIsOTAUpdating();
void updateWebUi(String latestRelease, int command);
String downloadSHA256(const String& filename);

View file

@ -72,3 +72,40 @@ String calculateSHA256(uint8_t *data, size_t len)
return String(sha256_str);
}
String calculateSHA256(WiFiClient *stream, size_t contentLength) {
mbedtls_md_context_t ctx;
mbedtls_md_type_t md_type = MBEDTLS_MD_SHA256;
mbedtls_md_init(&ctx);
mbedtls_md_setup(&ctx, mbedtls_md_info_from_type(md_type), 0);
mbedtls_md_starts(&ctx);
uint8_t buff[1024];
size_t bytesRead = 0;
while (bytesRead < contentLength) {
size_t toRead = min((size_t)(contentLength - bytesRead), sizeof(buff));
size_t readBytes = stream->readBytes(buff, toRead);
if (readBytes == 0) {
break;
}
mbedtls_md_update(&ctx, buff, readBytes);
bytesRead += readBytes;
}
byte shaResult[32];
mbedtls_md_finish(&ctx, shaResult);
mbedtls_md_free(&ctx);
String result = "";
for (int i = 0; i < sizeof(shaResult); i++) {
char str[3];
sprintf(str, "%02x", (int)shaResult[i]);
result += str;
}
return result;
}

View file

@ -2,6 +2,7 @@
#include <Adafruit_MCP23X17.h>
#include <ArduinoJson.h>
#include <WiFiClientSecure.h>
#include <Preferences.h>
#include <freertos/FreeRTOS.h>
#include <freertos/task.h>
@ -78,4 +79,4 @@ struct ScreenMapping {
};
String calculateSHA256(uint8_t* data, size_t len);
String calculateSHA256(WiFiClient *stream, size_t contentLength);

View file

@ -142,6 +142,32 @@ void onFirmwareUpdate(AsyncWebServerRequest *request)
request->send(response);
}
void onUpdateWebUi(AsyncWebServerRequest *request)
{
UpdateMessage msg = {UPDATE_WEBUI};
if (xQueueSend(otaQueue, &msg, 0) == pdTRUE)
{
request->send(200, "text/plain", "WebUI update triggered");
}
else
{
request->send(503, "text/plain", "Update already in progress");
}
}
void onUpdateFirmware(AsyncWebServerRequest *request)
{
UpdateMessage msg = {UPDATE_FIRMWARE};
if (xQueueSend(otaQueue, &msg, 0) == pdTRUE)
{
request->send(200, "text/plain", "Firmware update triggered");
}
else
{
request->send(503, "text/plain", "Update already in progress");
}
}
void asyncWebuiUpdateHandler(AsyncWebServerRequest *request, String filename, size_t index, uint8_t *data, size_t len, bool final)
{
asyncFileUpdateHandler(request, filename, index, data, len, final, U_SPIFFS);
@ -1045,167 +1071,6 @@ void onApiShowCurrency(AsyncWebServerRequest *request)
request->send(404);
}
String getLatestRelease(const String &fileToDownload)
{
// const char *fileToDownload = "littlefs.bin";
String releaseUrl = "https://api.github.com/repos/btclock/btclock_v3/releases/latest";
WiFiClientSecure client;
client.setCACert(github_root_ca);
HTTPClient http;
http.begin(client, releaseUrl);
http.setUserAgent(USER_AGENT);
int httpCode = http.GET();
String downloadUrl = "";
if (httpCode > 0)
{
String payload = http.getString();
JsonDocument doc;
deserializeJson(doc, payload);
JsonArray assets = doc["assets"];
for (JsonObject asset : assets)
{
if (asset["name"] == fileToDownload)
{
downloadUrl = asset["browser_download_url"].as<String>();
break;
}
}
Serial.printf("Latest release URL: %s\r\n", downloadUrl.c_str());
}
return downloadUrl;
}
void onUpdateWebUi(AsyncWebServerRequest *request)
{
request->send(downloadUpdateHandler(UPDATE_WEBUI));
}
void onUpdateFirmware(AsyncWebServerRequest *request)
{
request->send(downloadUpdateHandler(UPDATE_FIRMWARE));
}
int downloadUpdateHandler(char updateType)
{
WiFiClientSecure client;
client.setCACert(github_root_ca);
HTTPClient http;
http.setFollowRedirects(HTTPC_STRICT_FOLLOW_REDIRECTS);
String latestRelease = "";
switch (updateType)
{
case UPDATE_FIRMWARE:
latestRelease = getLatestRelease(getFirmwareFilename());
break;
case UPDATE_WEBUI:
latestRelease = getLatestRelease("littlefs.bin");
break;
}
if (latestRelease.equals(""))
{
return 503;
}
http.begin(client, latestRelease);
http.setUserAgent(USER_AGENT);
int httpCode = http.GET();
if (httpCode == HTTP_CODE_OK)
{
int contentLength = http.getSize();
if (contentLength > 0)
{
uint8_t *buffer = (uint8_t *)malloc(contentLength);
if (buffer)
{
WiFiClient *stream = http.getStreamPtr();
size_t written = stream->readBytes(buffer, contentLength);
if (written == contentLength)
{
String calculated_sha256 = calculateSHA256(buffer, contentLength);
Serial.print("Checksum is ");
Serial.println(calculated_sha256);
if (true)
{
Serial.println("Checksum verified. Proceeding with update.");
Update.onProgress(onOTAProgress);
int updateType = U_FLASH;
switch (updateType)
{
case UPDATE_WEBUI:
updateType = U_SPIFFS;
break;
default:
{
updateType = U_FLASH;
}
}
if (Update.begin(contentLength, updateType))
{
Update.write(buffer, contentLength);
if (Update.end())
{
Serial.println("Update complete. Rebooting.");
ESP.restart();
}
else
{
Serial.println("Error in update process.");
}
}
else
{
Serial.println("Not enough space to begin OTA");
}
}
else
{
Serial.println("Checksum mismatch. Aborting update.");
}
}
else
{
Serial.println("Error downloading firmware");
}
free(buffer);
}
else
{
Serial.println("Not enough memory to allocate buffer");
}
}
else
{
Serial.println("Invalid content length");
}
}
else
{
Serial.print(httpCode);
Serial.println("Error on HTTP request");
return 503;
}
http.end();
return 200;
}
#ifdef HAS_FRONTLIGHT
void onApiFrontlightOn(AsyncWebServerRequest *request)
{

View file

@ -29,9 +29,7 @@ void onApiSetWifiTxPower(AsyncWebServerRequest *request);
void onUpdateWebUi(AsyncWebServerRequest *request);
void onUpdateFirmware(AsyncWebServerRequest *request);
int downloadUpdateHandler(char updateType);
String getLatestRelease(const String& fileToDownload);
void onApiScreenNext(AsyncWebServerRequest *request);
void onApiScreenPrevious(AsyncWebServerRequest *request);