mqtt_protocol.cc 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  #include "mqtt_protocol.h"

  #include <arpa/inet.h>
  #include <esp_log.h>

  #include <cstdint>
  #include <cstdlib>
  #include <cstring>
  #include <string>
  #include <utility>

  #include "board.h"
  #include "application.h"
  #include "settings.h"
  #include "assets/lang_config.h"

  #define TAG "MQTT"
  10. MqttProtocol::MqttProtocol() {
  11. event_group_handle_ = xEventGroupCreate();
  12. }
  13. MqttProtocol::~MqttProtocol() {
  14. ESP_LOGI(TAG, "MqttProtocol deinit");
  15. vEventGroupDelete(event_group_handle_);
  16. }
  17. bool MqttProtocol::Start() {
  18. return StartMqttClient(false);
  19. }
  20. bool MqttProtocol::StartMqttClient(bool report_error) {
  21. if (mqtt_ != nullptr) {
  22. ESP_LOGW(TAG, "Mqtt client already started");
  23. mqtt_.reset();
  24. }
  25. Settings settings("mqtt", false);
  26. auto endpoint = settings.GetString("endpoint");
  27. auto client_id = settings.GetString("client_id");
  28. auto username = settings.GetString("username");
  29. auto password = settings.GetString("password");
  30. int keepalive_interval = settings.GetInt("keepalive", 240);
  31. publish_topic_ = settings.GetString("publish_topic");
  32. if (endpoint.empty()) {
  33. ESP_LOGW(TAG, "MQTT endpoint is not specified");
  34. if (report_error) {
  35. SetError(Lang::Strings::SERVER_NOT_FOUND);
  36. }
  37. return false;
  38. }
  39. auto network = Board::GetInstance().GetNetwork();
  40. mqtt_ = network->CreateMqtt(0);
  41. mqtt_->SetKeepAlive(keepalive_interval);
  42. mqtt_->OnDisconnected([this]() {
  43. ESP_LOGI(TAG, "Disconnected from endpoint");
  44. });
  45. mqtt_->OnMessage([this](const std::string& topic, const std::string& payload) {
  46. cJSON* root = cJSON_Parse(payload.c_str());
  47. if (root == nullptr) {
  48. ESP_LOGE(TAG, "Failed to parse json message %s", payload.c_str());
  49. return;
  50. }
  51. cJSON* type = cJSON_GetObjectItem(root, "type");
  52. if (!cJSON_IsString(type)) {
  53. ESP_LOGE(TAG, "Message type is invalid");
  54. cJSON_Delete(root);
  55. return;
  56. }
  57. if (strcmp(type->valuestring, "hello") == 0) {
  58. ParseServerHello(root);
  59. } else if (strcmp(type->valuestring, "goodbye") == 0) {
  60. auto session_id = cJSON_GetObjectItem(root, "session_id");
  61. ESP_LOGI(TAG, "Received goodbye message, session_id: %s", session_id ? session_id->valuestring : "null");
  62. if (session_id == nullptr || session_id_ == session_id->valuestring) {
  63. Application::GetInstance().Schedule([this]() {
  64. CloseAudioChannel();
  65. });
  66. }
  67. } else if (on_incoming_json_ != nullptr) {
  68. on_incoming_json_(root);
  69. }
  70. cJSON_Delete(root);
  71. last_incoming_time_ = std::chrono::steady_clock::now();
  72. });
  73. ESP_LOGI(TAG, "Connecting to endpoint %s", endpoint.c_str());
  74. std::string broker_address;
  75. int broker_port = 8883;
  76. size_t pos = endpoint.find(':');
  77. if (pos != std::string::npos) {
  78. broker_address = endpoint.substr(0, pos);
  79. broker_port = std::stoi(endpoint.substr(pos + 1));
  80. } else {
  81. broker_address = endpoint;
  82. }
  83. if (!mqtt_->Connect(broker_address, broker_port, client_id, username, password)) {
  84. ESP_LOGE(TAG, "Failed to connect to endpoint");
  85. SetError(Lang::Strings::SERVER_NOT_CONNECTED);
  86. return false;
  87. }
  88. ESP_LOGI(TAG, "Connected to endpoint");
  89. return true;
  90. }
  91. bool MqttProtocol::SendText(const std::string& text) {
  92. if (publish_topic_.empty()) {
  93. return false;
  94. }
  95. if (!mqtt_->Publish(publish_topic_, text)) {
  96. ESP_LOGE(TAG, "Failed to publish message: %s", text.c_str());
  97. SetError(Lang::Strings::SERVER_ERROR);
  98. return false;
  99. }
  100. return true;
  101. }
  102. bool MqttProtocol::SendAudio(std::unique_ptr<AudioStreamPacket> packet) {
  103. std::lock_guard<std::mutex> lock(channel_mutex_);
  104. if (udp_ == nullptr) {
  105. return false;
  106. }
  107. std::string nonce(aes_nonce_);
  108. *(uint16_t*)&nonce[2] = htons(packet->payload.size());
  109. *(uint32_t*)&nonce[8] = htonl(packet->timestamp);
  110. *(uint32_t*)&nonce[12] = htonl(++local_sequence_);
  111. std::string encrypted;
  112. encrypted.resize(aes_nonce_.size() + packet->payload.size());
  113. memcpy(encrypted.data(), nonce.data(), nonce.size());
  114. size_t nc_off = 0;
  115. uint8_t stream_block[16] = {0};
  116. if (mbedtls_aes_crypt_ctr(&aes_ctx_, packet->payload.size(), &nc_off, (uint8_t*)nonce.c_str(), stream_block,
  117. (uint8_t*)packet->payload.data(), (uint8_t*)&encrypted[nonce.size()]) != 0) {
  118. ESP_LOGE(TAG, "Failed to encrypt audio data");
  119. return false;
  120. }
  121. return udp_->Send(encrypted) > 0;
  122. }
  123. void MqttProtocol::CloseAudioChannel() {
  124. {
  125. std::lock_guard<std::mutex> lock(channel_mutex_);
  126. udp_.reset();
  127. }
  128. std::string message = "{";
  129. message += "\"session_id\":\"" + session_id_ + "\",";
  130. message += "\"type\":\"goodbye\"";
  131. message += "}";
  132. SendText(message);
  133. if (on_audio_channel_closed_ != nullptr) {
  134. on_audio_channel_closed_();
  135. }
  136. }
  137. bool MqttProtocol::OpenAudioChannel() {
  138. if (mqtt_ == nullptr || !mqtt_->IsConnected()) {
  139. ESP_LOGI(TAG, "MQTT is not connected, try to connect now");
  140. if (!StartMqttClient(true)) {
  141. return false;
  142. }
  143. }
  144. error_occurred_ = false;
  145. session_id_ = "";
  146. xEventGroupClearBits(event_group_handle_, MQTT_PROTOCOL_SERVER_HELLO_EVENT);
  147. auto message = GetHelloMessage();
  148. if (!SendText(message)) {
  149. return false;
  150. }
  151. // 等待服务器响应
  152. EventBits_t bits = xEventGroupWaitBits(event_group_handle_, MQTT_PROTOCOL_SERVER_HELLO_EVENT, pdTRUE, pdFALSE, pdMS_TO_TICKS(10000));
  153. if (!(bits & MQTT_PROTOCOL_SERVER_HELLO_EVENT)) {
  154. ESP_LOGE(TAG, "Failed to receive server hello");
  155. SetError(Lang::Strings::SERVER_TIMEOUT);
  156. return false;
  157. }
  158. std::lock_guard<std::mutex> lock(channel_mutex_);
  159. auto network = Board::GetInstance().GetNetwork();
  160. udp_ = network->CreateUdp(2);
  161. udp_->OnMessage([this](const std::string& data) {
  162. /*
  163. * UDP Encrypted OPUS Packet Format:
  164. * |type 1u|flags 1u|payload_len 2u|ssrc 4u|timestamp 4u|sequence 4u|
  165. * |payload payload_len|
  166. */
  167. if (data.size() < sizeof(aes_nonce_)) {
  168. ESP_LOGE(TAG, "Invalid audio packet size: %u", data.size());
  169. return;
  170. }
  171. if (data[0] != 0x01) {
  172. ESP_LOGE(TAG, "Invalid audio packet type: %x", data[0]);
  173. return;
  174. }
  175. uint32_t timestamp = ntohl(*(uint32_t*)&data[8]);
  176. uint32_t sequence = ntohl(*(uint32_t*)&data[12]);
  177. if (sequence < remote_sequence_) {
  178. ESP_LOGW(TAG, "Received audio packet with old sequence: %lu, expected: %lu", sequence, remote_sequence_);
  179. return;
  180. }
  181. if (sequence != remote_sequence_ + 1) {
  182. ESP_LOGW(TAG, "Received audio packet with wrong sequence: %lu, expected: %lu", sequence, remote_sequence_ + 1);
  183. }
  184. size_t decrypted_size = data.size() - aes_nonce_.size();
  185. size_t nc_off = 0;
  186. uint8_t stream_block[16] = {0};
  187. auto nonce = (uint8_t*)data.data();
  188. auto encrypted = (uint8_t*)data.data() + aes_nonce_.size();
  189. auto packet = std::make_unique<AudioStreamPacket>();
  190. packet->sample_rate = server_sample_rate_;
  191. packet->frame_duration = server_frame_duration_;
  192. packet->timestamp = timestamp;
  193. packet->payload.resize(decrypted_size);
  194. int ret = mbedtls_aes_crypt_ctr(&aes_ctx_, decrypted_size, &nc_off, nonce, stream_block, encrypted, (uint8_t*)packet->payload.data());
  195. if (ret != 0) {
  196. ESP_LOGE(TAG, "Failed to decrypt audio data, ret: %d", ret);
  197. return;
  198. }
  199. if (on_incoming_audio_ != nullptr) {
  200. on_incoming_audio_(std::move(packet));
  201. }
  202. remote_sequence_ = sequence;
  203. last_incoming_time_ = std::chrono::steady_clock::now();
  204. });
  205. udp_->Connect(udp_server_, udp_port_);
  206. if (on_audio_channel_opened_ != nullptr) {
  207. on_audio_channel_opened_();
  208. }
  209. return true;
  210. }
  211. std::string MqttProtocol::GetHelloMessage() {
  212. // 发送 hello 消息申请 UDP 通道
  213. cJSON* root = cJSON_CreateObject();
  214. cJSON_AddStringToObject(root, "type", "hello");
  215. cJSON_AddNumberToObject(root, "version", 3);
  216. cJSON_AddStringToObject(root, "transport", "udp");
  217. cJSON* features = cJSON_CreateObject();
  218. #if CONFIG_USE_SERVER_AEC
  219. cJSON_AddBoolToObject(features, "aec", true);
  220. #endif
  221. cJSON_AddBoolToObject(features, "mcp", true);
  222. cJSON_AddItemToObject(root, "features", features);
  223. cJSON* audio_params = cJSON_CreateObject();
  224. cJSON_AddStringToObject(audio_params, "format", "opus");
  225. cJSON_AddNumberToObject(audio_params, "sample_rate", 16000);
  226. cJSON_AddNumberToObject(audio_params, "channels", 1);
  227. cJSON_AddNumberToObject(audio_params, "frame_duration", OPUS_FRAME_DURATION_MS);
  228. cJSON_AddItemToObject(root, "audio_params", audio_params);
  229. auto json_str = cJSON_PrintUnformatted(root);
  230. std::string message(json_str);
  231. cJSON_free(json_str);
  232. cJSON_Delete(root);
  233. return message;
  234. }
  235. void MqttProtocol::ParseServerHello(const cJSON* root) {
  236. auto transport = cJSON_GetObjectItem(root, "transport");
  237. if (transport == nullptr || strcmp(transport->valuestring, "udp") != 0) {
  238. ESP_LOGE(TAG, "Unsupported transport: %s", transport->valuestring);
  239. return;
  240. }
  241. auto session_id = cJSON_GetObjectItem(root, "session_id");
  242. if (cJSON_IsString(session_id)) {
  243. session_id_ = session_id->valuestring;
  244. ESP_LOGI(TAG, "Session ID: %s", session_id_.c_str());
  245. }
  246. // Get sample rate from hello message
  247. auto audio_params = cJSON_GetObjectItem(root, "audio_params");
  248. if (cJSON_IsObject(audio_params)) {
  249. auto sample_rate = cJSON_GetObjectItem(audio_params, "sample_rate");
  250. if (cJSON_IsNumber(sample_rate)) {
  251. server_sample_rate_ = sample_rate->valueint;
  252. }
  253. auto frame_duration = cJSON_GetObjectItem(audio_params, "frame_duration");
  254. if (cJSON_IsNumber(frame_duration)) {
  255. server_frame_duration_ = frame_duration->valueint;
  256. }
  257. }
  258. auto udp = cJSON_GetObjectItem(root, "udp");
  259. if (!cJSON_IsObject(udp)) {
  260. ESP_LOGE(TAG, "UDP is not specified");
  261. return;
  262. }
  263. udp_server_ = cJSON_GetObjectItem(udp, "server")->valuestring;
  264. udp_port_ = cJSON_GetObjectItem(udp, "port")->valueint;
  265. auto key = cJSON_GetObjectItem(udp, "key")->valuestring;
  266. auto nonce = cJSON_GetObjectItem(udp, "nonce")->valuestring;
  267. // auto encryption = cJSON_GetObjectItem(udp, "encryption")->valuestring;
  268. // ESP_LOGI(TAG, "UDP server: %s, port: %d, encryption: %s", udp_server_.c_str(), udp_port_, encryption);
  269. aes_nonce_ = DecodeHexString(nonce);
  270. mbedtls_aes_init(&aes_ctx_);
  271. mbedtls_aes_setkey_enc(&aes_ctx_, (const unsigned char*)DecodeHexString(key).c_str(), 128);
  272. local_sequence_ = 0;
  273. remote_sequence_ = 0;
  274. xEventGroupSetBits(event_group_handle_, MQTT_PROTOCOL_SERVER_HELLO_EVENT);
  275. }
  static const char hex_chars[] = "0123456789ABCDEF";

  // Maps one hexadecimal digit (0-9, A-F, a-f) to its value 0-15.
  // Any other character maps to 0.
  static inline uint8_t CharToHex(char c) {
      if ('0' <= c && c <= '9') {
          return static_cast<uint8_t>(c - '0');
      }
      // Fold lowercase a-f onto A-F so a single range check covers both cases.
      const char upper = (c >= 'a' && c <= 'f') ? static_cast<char>(c - ('a' - 'A')) : c;
      if ('A' <= upper && upper <= 'F') {
          return static_cast<uint8_t>(upper - 'A' + 10);
      }
      return 0;
  }
  284. std::string MqttProtocol::DecodeHexString(const std::string& hex_string) {
  285. std::string decoded;
  286. decoded.reserve(hex_string.size() / 2);
  287. for (size_t i = 0; i < hex_string.size(); i += 2) {
  288. char byte = (CharToHex(hex_string[i]) << 4) | CharToHex(hex_string[i + 1]);
  289. decoded.push_back(byte);
  290. }
  291. return decoded;
  292. }
  293. bool MqttProtocol::IsAudioChannelOpened() const {
  294. return udp_ != nullptr && !error_occurred_ && !IsTimeout();
  295. }