ota.cc 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476
#include "ota.h"
#include "system_info.h"
#include "settings.h"
#include "assets/lang_config.h"

#include <cJSON.h>
#include <esp_log.h>
#include <esp_partition.h>
#include <esp_ota_ops.h>
#include <esp_app_format.h>
#include <esp_efuse.h>
#include <esp_efuse_table.h>
#ifdef SOC_HMAC_SUPPORTED
#include <esp_hmac.h>
#endif

#include <algorithm>
#include <climits>
#include <cstdlib>
#include <cstring>
#include <sstream>
#include <vector>
  19. #define TAG "Ota"
  20. Ota::Ota() {
  21. #ifdef ESP_EFUSE_BLOCK_USR_DATA
  22. // Read Serial Number from efuse user_data
  23. uint8_t serial_number[33] = {0};
  24. if (esp_efuse_read_field_blob(ESP_EFUSE_USER_DATA, serial_number, 32 * 8) == ESP_OK) {
  25. if (serial_number[0] == 0) {
  26. has_serial_number_ = false;
  27. } else {
  28. serial_number_ = std::string(reinterpret_cast<char*>(serial_number), 32);
  29. has_serial_number_ = true;
  30. }
  31. }
  32. #endif
  33. }
  34. Ota::~Ota() {
  35. }
  36. std::string Ota::GetCheckVersionUrl() {
  37. Settings settings("wifi", false);
  38. std::string url = settings.GetString("ota_url");
  39. if (url.empty()) {
  40. url = CONFIG_OTA_URL;
  41. }
  42. return url;
  43. }
  44. std::unique_ptr<Http> Ota::SetupHttp() {
  45. auto& board = Board::GetInstance();
  46. auto app_desc = esp_app_get_description();
  47. auto network = board.GetNetwork();
  48. auto http = network->CreateHttp(0);
  49. http->SetHeader("Activation-Version", has_serial_number_ ? "2" : "1");
  50. http->SetHeader("Device-Id", SystemInfo::GetMacAddress().c_str());
  51. http->SetHeader("Client-Id", board.GetUuid());
  52. if (has_serial_number_) {
  53. http->SetHeader("Serial-Number", serial_number_.c_str());
  54. }
  55. http->SetHeader("User-Agent", std::string(BOARD_NAME "/") + app_desc->version);
  56. http->SetHeader("Accept-Language", Lang::CODE);
  57. http->SetHeader("Content-Type", "application/json");
  58. return http;
  59. }
  60. /*
  61. * Specification: https://ccnphfhqs21z.feishu.cn/wiki/FjW6wZmisimNBBkov6OcmfvknVd
  62. */
  63. bool Ota::CheckVersion() {
  64. auto& board = Board::GetInstance();
  65. auto app_desc = esp_app_get_description();
  66. // Check if there is a new firmware version available
  67. current_version_ = app_desc->version;
  68. ESP_LOGI(TAG, "Current version: %s", current_version_.c_str());
  69. std::string url = GetCheckVersionUrl();
  70. if (url.length() < 10) {
  71. ESP_LOGE(TAG, "Check version URL is not properly set");
  72. return false;
  73. }
  74. auto http = SetupHttp();
  75. std::string data = board.GetJson();
  76. std::string method = data.length() > 0 ? "POST" : "GET";
  77. http->SetContent(std::move(data));
  78. if (!http->Open(method, url)) {
  79. ESP_LOGE(TAG, "Failed to open HTTP connection");
  80. return false;
  81. }
  82. auto status_code = http->GetStatusCode();
  83. if (status_code != 200) {
  84. ESP_LOGE(TAG, "Failed to check version, status code: %d", status_code);
  85. return false;
  86. }
  87. data = http->ReadAll();
  88. http->Close();
  89. // Response: { "firmware": { "version": "1.0.0", "url": "http://" } }
  90. // Parse the JSON response and check if the version is newer
  91. // If it is, set has_new_version_ to true and store the new version and URL
  92. cJSON *root = cJSON_Parse(data.c_str());
  93. if (root == NULL) {
  94. ESP_LOGE(TAG, "Failed to parse JSON response");
  95. return false;
  96. }
  97. has_activation_code_ = false;
  98. has_activation_challenge_ = false;
  99. cJSON *activation = cJSON_GetObjectItem(root, "activation");
  100. if (cJSON_IsObject(activation)) {
  101. cJSON* message = cJSON_GetObjectItem(activation, "message");
  102. if (cJSON_IsString(message)) {
  103. activation_message_ = message->valuestring;
  104. }
  105. cJSON* code = cJSON_GetObjectItem(activation, "code");
  106. if (cJSON_IsString(code)) {
  107. activation_code_ = code->valuestring;
  108. has_activation_code_ = true;
  109. }
  110. cJSON* challenge = cJSON_GetObjectItem(activation, "challenge");
  111. if (cJSON_IsString(challenge)) {
  112. activation_challenge_ = challenge->valuestring;
  113. has_activation_challenge_ = true;
  114. }
  115. cJSON* timeout_ms = cJSON_GetObjectItem(activation, "timeout_ms");
  116. if (cJSON_IsNumber(timeout_ms)) {
  117. activation_timeout_ms_ = timeout_ms->valueint;
  118. }
  119. }
  120. has_mqtt_config_ = false;
  121. cJSON *mqtt = cJSON_GetObjectItem(root, "mqtt");
  122. if (cJSON_IsObject(mqtt)) {
  123. Settings settings("mqtt", true);
  124. cJSON *item = NULL;
  125. cJSON_ArrayForEach(item, mqtt) {
  126. if (cJSON_IsString(item)) {
  127. if (settings.GetString(item->string) != item->valuestring) {
  128. settings.SetString(item->string, item->valuestring);
  129. }
  130. } else if (cJSON_IsNumber(item)) {
  131. if (settings.GetInt(item->string) != item->valueint) {
  132. settings.SetInt(item->string, item->valueint);
  133. }
  134. }
  135. }
  136. has_mqtt_config_ = true;
  137. } else {
  138. ESP_LOGI(TAG, "No mqtt section found !");
  139. }
  140. has_websocket_config_ = false;
  141. cJSON *websocket = cJSON_GetObjectItem(root, "websocket");
  142. if (cJSON_IsObject(websocket)) {
  143. Settings settings("websocket", true);
  144. cJSON *item = NULL;
  145. cJSON_ArrayForEach(item, websocket) {
  146. if (cJSON_IsString(item)) {
  147. if (settings.GetString(item->string) != item->valuestring) {
  148. settings.SetString(item->string, item->valuestring);
  149. }
  150. } else if (cJSON_IsNumber(item)) {
  151. if (settings.GetInt(item->string) != item->valueint) {
  152. settings.SetInt(item->string, item->valueint);
  153. }
  154. }
  155. }
  156. has_websocket_config_ = true;
  157. } else {
  158. ESP_LOGI(TAG, "No websocket section found!");
  159. }
  160. has_server_time_ = false;
  161. cJSON *server_time = cJSON_GetObjectItem(root, "server_time");
  162. if (cJSON_IsObject(server_time)) {
  163. cJSON *timestamp = cJSON_GetObjectItem(server_time, "timestamp");
  164. cJSON *timezone_offset = cJSON_GetObjectItem(server_time, "timezone_offset");
  165. if (cJSON_IsNumber(timestamp)) {
  166. // 设置系统时间
  167. struct timeval tv;
  168. double ts = timestamp->valuedouble;
  169. // 如果有时区偏移,计算本地时间
  170. if (cJSON_IsNumber(timezone_offset)) {
  171. ts += (timezone_offset->valueint * 60 * 1000); // 转换分钟为毫秒
  172. }
  173. tv.tv_sec = (time_t)(ts / 1000); // 转换毫秒为秒
  174. tv.tv_usec = (suseconds_t)((long long)ts % 1000) * 1000; // 剩余的毫秒转换为微秒
  175. settimeofday(&tv, NULL);
  176. has_server_time_ = true;
  177. }
  178. } else {
  179. ESP_LOGW(TAG, "No server_time section found!");
  180. }
  181. has_new_version_ = false;
  182. cJSON *firmware = cJSON_GetObjectItem(root, "firmware");
  183. if (cJSON_IsObject(firmware)) {
  184. cJSON *version = cJSON_GetObjectItem(firmware, "version");
  185. if (cJSON_IsString(version)) {
  186. firmware_version_ = version->valuestring;
  187. }
  188. cJSON *url = cJSON_GetObjectItem(firmware, "url");
  189. if (cJSON_IsString(url)) {
  190. firmware_url_ = url->valuestring;
  191. }
  192. if (cJSON_IsString(version) && cJSON_IsString(url)) {
  193. // Check if the version is newer, for example, 0.1.0 is newer than 0.0.1
  194. has_new_version_ = IsNewVersionAvailable(current_version_, firmware_version_);
  195. if (has_new_version_) {
  196. ESP_LOGI(TAG, "New version available: %s", firmware_version_.c_str());
  197. } else {
  198. ESP_LOGI(TAG, "Current is the latest version");
  199. }
  200. // If the force flag is set to 1, the given version is forced to be installed
  201. cJSON *force = cJSON_GetObjectItem(firmware, "force");
  202. if (cJSON_IsNumber(force) && force->valueint == 1) {
  203. has_new_version_ = true;
  204. }
  205. }
  206. } else {
  207. ESP_LOGW(TAG, "No firmware section found!");
  208. }
  209. cJSON_Delete(root);
  210. return true;
  211. }
  212. void Ota::MarkCurrentVersionValid() {
  213. auto partition = esp_ota_get_running_partition();
  214. if (strcmp(partition->label, "factory") == 0) {
  215. ESP_LOGI(TAG, "Running from factory partition, skipping");
  216. return;
  217. }
  218. ESP_LOGI(TAG, "Running partition: %s", partition->label);
  219. esp_ota_img_states_t state;
  220. if (esp_ota_get_state_partition(partition, &state) != ESP_OK) {
  221. ESP_LOGE(TAG, "Failed to get state of partition");
  222. return;
  223. }
  224. if (state == ESP_OTA_IMG_PENDING_VERIFY) {
  225. ESP_LOGI(TAG, "Marking firmware as valid");
  226. esp_ota_mark_app_valid_cancel_rollback();
  227. }
  228. }
  229. bool Ota::Upgrade(const std::string& firmware_url) {
  230. ESP_LOGI(TAG, "Upgrading firmware from %s", firmware_url.c_str());
  231. esp_ota_handle_t update_handle = 0;
  232. auto update_partition = esp_ota_get_next_update_partition(NULL);
  233. if (update_partition == NULL) {
  234. ESP_LOGE(TAG, "Failed to get update partition");
  235. return false;
  236. }
  237. ESP_LOGI(TAG, "Writing to partition %s at offset 0x%lx", update_partition->label, update_partition->address);
  238. bool image_header_checked = false;
  239. std::string image_header;
  240. auto network = Board::GetInstance().GetNetwork();
  241. auto http = network->CreateHttp(0);
  242. if (!http->Open("GET", firmware_url)) {
  243. ESP_LOGE(TAG, "Failed to open HTTP connection");
  244. return false;
  245. }
  246. if (http->GetStatusCode() != 200) {
  247. ESP_LOGE(TAG, "Failed to get firmware, status code: %d", http->GetStatusCode());
  248. return false;
  249. }
  250. size_t content_length = http->GetBodyLength();
  251. if (content_length == 0) {
  252. ESP_LOGE(TAG, "Failed to get content length");
  253. return false;
  254. }
  255. char buffer[512];
  256. size_t total_read = 0, recent_read = 0;
  257. auto last_calc_time = esp_timer_get_time();
  258. while (true) {
  259. int ret = http->Read(buffer, sizeof(buffer));
  260. if (ret < 0) {
  261. ESP_LOGE(TAG, "Failed to read HTTP data: %s", esp_err_to_name(ret));
  262. return false;
  263. }
  264. // Calculate speed and progress every second
  265. recent_read += ret;
  266. total_read += ret;
  267. if (esp_timer_get_time() - last_calc_time >= 1000000 || ret == 0) {
  268. size_t progress = total_read * 100 / content_length;
  269. ESP_LOGI(TAG, "Progress: %u%% (%u/%u), Speed: %uB/s", progress, total_read, content_length, recent_read);
  270. if (upgrade_callback_) {
  271. upgrade_callback_(progress, recent_read);
  272. }
  273. last_calc_time = esp_timer_get_time();
  274. recent_read = 0;
  275. }
  276. if (ret == 0) {
  277. break;
  278. }
  279. if (!image_header_checked) {
  280. image_header.append(buffer, ret);
  281. if (image_header.size() >= sizeof(esp_image_header_t) + sizeof(esp_image_segment_header_t) + sizeof(esp_app_desc_t)) {
  282. esp_app_desc_t new_app_info;
  283. memcpy(&new_app_info, image_header.data() + sizeof(esp_image_header_t) + sizeof(esp_image_segment_header_t), sizeof(esp_app_desc_t));
  284. ESP_LOGI(TAG, "New firmware version: %s", new_app_info.version);
  285. auto current_version = esp_app_get_description()->version;
  286. if (memcmp(new_app_info.version, current_version, sizeof(new_app_info.version)) == 0) {
  287. ESP_LOGE(TAG, "Firmware version is the same, skipping upgrade");
  288. return false;
  289. }
  290. if (esp_ota_begin(update_partition, OTA_WITH_SEQUENTIAL_WRITES, &update_handle)) {
  291. esp_ota_abort(update_handle);
  292. ESP_LOGE(TAG, "Failed to begin OTA");
  293. return false;
  294. }
  295. image_header_checked = true;
  296. std::string().swap(image_header);
  297. }
  298. }
  299. auto err = esp_ota_write(update_handle, buffer, ret);
  300. if (err != ESP_OK) {
  301. ESP_LOGE(TAG, "Failed to write OTA data: %s", esp_err_to_name(err));
  302. esp_ota_abort(update_handle);
  303. return false;
  304. }
  305. }
  306. http->Close();
  307. esp_err_t err = esp_ota_end(update_handle);
  308. if (err != ESP_OK) {
  309. if (err == ESP_ERR_OTA_VALIDATE_FAILED) {
  310. ESP_LOGE(TAG, "Image validation failed, image is corrupted");
  311. } else {
  312. ESP_LOGE(TAG, "Failed to end OTA: %s", esp_err_to_name(err));
  313. }
  314. return false;
  315. }
  316. err = esp_ota_set_boot_partition(update_partition);
  317. if (err != ESP_OK) {
  318. ESP_LOGE(TAG, "Failed to set boot partition: %s", esp_err_to_name(err));
  319. return false;
  320. }
  321. ESP_LOGI(TAG, "Firmware upgrade successful");
  322. return true;
  323. }
  324. bool Ota::StartUpgrade(std::function<void(int progress, size_t speed)> callback) {
  325. upgrade_callback_ = callback;
  326. return Upgrade(firmware_url_);
  327. }
  328. std::vector<int> Ota::ParseVersion(const std::string& version) {
  329. std::vector<int> versionNumbers;
  330. std::stringstream ss(version);
  331. std::string segment;
  332. while (std::getline(ss, segment, '.')) {
  333. versionNumbers.push_back(std::stoi(segment));
  334. }
  335. return versionNumbers;
  336. }
  337. bool Ota::IsNewVersionAvailable(const std::string& currentVersion, const std::string& newVersion) {
  338. std::vector<int> current = ParseVersion(currentVersion);
  339. std::vector<int> newer = ParseVersion(newVersion);
  340. for (size_t i = 0; i < std::min(current.size(), newer.size()); ++i) {
  341. if (newer[i] > current[i]) {
  342. return true;
  343. } else if (newer[i] < current[i]) {
  344. return false;
  345. }
  346. }
  347. return newer.size() > current.size();
  348. }
  349. std::string Ota::GetActivationPayload() {
  350. if (!has_serial_number_) {
  351. return "{}";
  352. }
  353. std::string hmac_hex;
  354. #ifdef SOC_HMAC_SUPPORTED
  355. uint8_t hmac_result[32]; // SHA-256 输出为32字节
  356. // 使用Key0计算HMAC
  357. esp_err_t ret = esp_hmac_calculate(HMAC_KEY0, (uint8_t*)activation_challenge_.data(), activation_challenge_.size(), hmac_result);
  358. if (ret != ESP_OK) {
  359. ESP_LOGE(TAG, "HMAC calculation failed: %s", esp_err_to_name(ret));
  360. return "{}";
  361. }
  362. for (size_t i = 0; i < sizeof(hmac_result); i++) {
  363. char buffer[3];
  364. sprintf(buffer, "%02x", hmac_result[i]);
  365. hmac_hex += buffer;
  366. }
  367. #endif
  368. cJSON *payload = cJSON_CreateObject();
  369. cJSON_AddStringToObject(payload, "algorithm", "hmac-sha256");
  370. cJSON_AddStringToObject(payload, "serial_number", serial_number_.c_str());
  371. cJSON_AddStringToObject(payload, "challenge", activation_challenge_.c_str());
  372. cJSON_AddStringToObject(payload, "hmac", hmac_hex.c_str());
  373. auto json_str = cJSON_PrintUnformatted(payload);
  374. std::string json(json_str);
  375. cJSON_free(json_str);
  376. cJSON_Delete(payload);
  377. ESP_LOGI(TAG, "Activation payload: %s", json.c_str());
  378. return json;
  379. }
  380. esp_err_t Ota::Activate() {
  381. if (!has_activation_challenge_) {
  382. ESP_LOGW(TAG, "No activation challenge found");
  383. return ESP_FAIL;
  384. }
  385. std::string url = GetCheckVersionUrl();
  386. if (url.back() != '/') {
  387. url += "/activate";
  388. } else {
  389. url += "activate";
  390. }
  391. auto http = SetupHttp();
  392. std::string data = GetActivationPayload();
  393. http->SetContent(std::move(data));
  394. if (!http->Open("POST", url)) {
  395. ESP_LOGE(TAG, "Failed to open HTTP connection");
  396. return ESP_FAIL;
  397. }
  398. auto status_code = http->GetStatusCode();
  399. if (status_code == 202) {
  400. return ESP_ERR_TIMEOUT;
  401. }
  402. if (status_code != 200) {
  403. ESP_LOGE(TAG, "Failed to activate, code: %d, body: %s", status_code, http->ReadAll().c_str());
  404. return ESP_FAIL;
  405. }
  406. ESP_LOGI(TAG, "Activation successful");
  407. return ESP_OK;
  408. }