websocket_protocol.cc 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. #include "websocket_protocol.h"
  2. #include "board.h"
  3. #include "system_info.h"
  4. #include "application.h"
  5. #include "settings.h"
  6. #include <cstring>
  7. #include <cJSON.h>
  8. #include <esp_log.h>
  9. #include <arpa/inet.h>
  10. #include "assets/lang_config.h"
  11. #define TAG "WS"
  12. WebsocketProtocol::WebsocketProtocol() {
  13. event_group_handle_ = xEventGroupCreate();
  14. }
  15. WebsocketProtocol::~WebsocketProtocol() {
  16. vEventGroupDelete(event_group_handle_);
  17. }
  18. bool WebsocketProtocol::Start() {
  19. // Only connect to server when audio channel is needed
  20. return true;
  21. }
  22. bool WebsocketProtocol::SendAudio(std::unique_ptr<AudioStreamPacket> packet) {
  23. if (websocket_ == nullptr || !websocket_->IsConnected()) {
  24. return false;
  25. }
  26. if (version_ == 2) {
  27. std::string serialized;
  28. serialized.resize(sizeof(BinaryProtocol2) + packet->payload.size());
  29. auto bp2 = (BinaryProtocol2*)serialized.data();
  30. bp2->version = htons(version_);
  31. bp2->type = 0;
  32. bp2->reserved = 0;
  33. bp2->timestamp = htonl(packet->timestamp);
  34. bp2->payload_size = htonl(packet->payload.size());
  35. memcpy(bp2->payload, packet->payload.data(), packet->payload.size());
  36. return websocket_->Send(serialized.data(), serialized.size(), true);
  37. } else if (version_ == 3) {
  38. std::string serialized;
  39. serialized.resize(sizeof(BinaryProtocol3) + packet->payload.size());
  40. auto bp3 = (BinaryProtocol3*)serialized.data();
  41. bp3->type = 0;
  42. bp3->reserved = 0;
  43. bp3->payload_size = htons(packet->payload.size());
  44. memcpy(bp3->payload, packet->payload.data(), packet->payload.size());
  45. return websocket_->Send(serialized.data(), serialized.size(), true);
  46. } else {
  47. return websocket_->Send(packet->payload.data(), packet->payload.size(), true);
  48. }
  49. }
  50. bool WebsocketProtocol::SendText(const std::string& text) {
  51. if (websocket_ == nullptr || !websocket_->IsConnected()) {
  52. return false;
  53. }
  54. if (!websocket_->Send(text)) {
  55. ESP_LOGE(TAG, "Failed to send text: %s", text.c_str());
  56. SetError(Lang::Strings::SERVER_ERROR);
  57. return false;
  58. }
  59. return true;
  60. }
  61. bool WebsocketProtocol::IsAudioChannelOpened() const {
  62. return websocket_ != nullptr && websocket_->IsConnected() && !error_occurred_ && !IsTimeout();
  63. }
  64. void WebsocketProtocol::CloseAudioChannel() {
  65. websocket_.reset();
  66. }
  67. bool WebsocketProtocol::OpenAudioChannel() {
  68. Settings settings("websocket", false);
  69. std::string url = settings.GetString("url");
  70. std::string token = settings.GetString("token");
  71. int version = settings.GetInt("version");
  72. if (version != 0) {
  73. version_ = version;
  74. }
  75. error_occurred_ = false;
  76. auto network = Board::GetInstance().GetNetwork();
  77. websocket_ = network->CreateWebSocket(1);
  78. if (websocket_ == nullptr) {
  79. ESP_LOGE(TAG, "Failed to create websocket");
  80. return false;
  81. }
  82. if (!token.empty()) {
  83. // If token not has a space, add "Bearer " prefix
  84. if (token.find(" ") == std::string::npos) {
  85. token = "Bearer " + token;
  86. }
  87. websocket_->SetHeader("Authorization", token.c_str());
  88. }
  89. websocket_->SetHeader("Protocol-Version", std::to_string(version_).c_str());
  90. websocket_->SetHeader("Device-Id", SystemInfo::GetMacAddress().c_str());
  91. websocket_->SetHeader("Client-Id", Board::GetInstance().GetUuid().c_str());
  92. websocket_->OnData([this](const char* data, size_t len, bool binary) {
  93. if (binary) {
  94. if (on_incoming_audio_ != nullptr) {
  95. if (version_ == 2) {
  96. BinaryProtocol2* bp2 = (BinaryProtocol2*)data;
  97. bp2->version = ntohs(bp2->version);
  98. bp2->type = ntohs(bp2->type);
  99. bp2->timestamp = ntohl(bp2->timestamp);
  100. bp2->payload_size = ntohl(bp2->payload_size);
  101. auto payload = (uint8_t*)bp2->payload;
  102. on_incoming_audio_(std::make_unique<AudioStreamPacket>(AudioStreamPacket{
  103. .sample_rate = server_sample_rate_,
  104. .frame_duration = server_frame_duration_,
  105. .timestamp = bp2->timestamp,
  106. .payload = std::vector<uint8_t>(payload, payload + bp2->payload_size)
  107. }));
  108. } else if (version_ == 3) {
  109. BinaryProtocol3* bp3 = (BinaryProtocol3*)data;
  110. bp3->type = bp3->type;
  111. bp3->payload_size = ntohs(bp3->payload_size);
  112. auto payload = (uint8_t*)bp3->payload;
  113. on_incoming_audio_(std::make_unique<AudioStreamPacket>(AudioStreamPacket{
  114. .sample_rate = server_sample_rate_,
  115. .frame_duration = server_frame_duration_,
  116. .timestamp = 0,
  117. .payload = std::vector<uint8_t>(payload, payload + bp3->payload_size)
  118. }));
  119. } else {
  120. on_incoming_audio_(std::make_unique<AudioStreamPacket>(AudioStreamPacket{
  121. .sample_rate = server_sample_rate_,
  122. .frame_duration = server_frame_duration_,
  123. .timestamp = 0,
  124. .payload = std::vector<uint8_t>((uint8_t*)data, (uint8_t*)data + len)
  125. }));
  126. }
  127. }
  128. } else {
  129. // Parse JSON data
  130. auto root = cJSON_Parse(data);
  131. auto type = cJSON_GetObjectItem(root, "type");
  132. if (cJSON_IsString(type)) {
  133. if (strcmp(type->valuestring, "hello") == 0) {
  134. ParseServerHello(root);
  135. } else {
  136. if (on_incoming_json_ != nullptr) {
  137. on_incoming_json_(root);
  138. }
  139. }
  140. } else {
  141. ESP_LOGE(TAG, "Missing message type, data: %s", data);
  142. }
  143. cJSON_Delete(root);
  144. }
  145. last_incoming_time_ = std::chrono::steady_clock::now();
  146. });
  147. websocket_->OnDisconnected([this]() {
  148. ESP_LOGI(TAG, "Websocket disconnected");
  149. if (on_audio_channel_closed_ != nullptr) {
  150. on_audio_channel_closed_();
  151. }
  152. });
  153. ESP_LOGI(TAG, "Connecting to websocket server: %s with version: %d", url.c_str(), version_);
  154. if (!websocket_->Connect(url.c_str())) {
  155. ESP_LOGE(TAG, "Failed to connect to websocket server");
  156. SetError(Lang::Strings::SERVER_NOT_CONNECTED);
  157. return false;
  158. }
  159. // Send hello message to describe the client
  160. auto message = GetHelloMessage();
  161. if (!SendText(message)) {
  162. return false;
  163. }
  164. // Wait for server hello
  165. EventBits_t bits = xEventGroupWaitBits(event_group_handle_, WEBSOCKET_PROTOCOL_SERVER_HELLO_EVENT, pdTRUE, pdFALSE, pdMS_TO_TICKS(10000));
  166. if (!(bits & WEBSOCKET_PROTOCOL_SERVER_HELLO_EVENT)) {
  167. ESP_LOGE(TAG, "Failed to receive server hello");
  168. SetError(Lang::Strings::SERVER_TIMEOUT);
  169. return false;
  170. }
  171. if (on_audio_channel_opened_ != nullptr) {
  172. on_audio_channel_opened_();
  173. }
  174. return true;
  175. }
  176. std::string WebsocketProtocol::GetHelloMessage() {
  177. // keys: message type, version, audio_params (format, sample_rate, channels)
  178. cJSON* root = cJSON_CreateObject();
  179. cJSON_AddStringToObject(root, "type", "hello");
  180. cJSON_AddNumberToObject(root, "version", version_);
  181. cJSON* features = cJSON_CreateObject();
  182. #if CONFIG_USE_SERVER_AEC
  183. cJSON_AddBoolToObject(features, "aec", true);
  184. #endif
  185. cJSON_AddBoolToObject(features, "mcp", true);
  186. cJSON_AddItemToObject(root, "features", features);
  187. cJSON_AddStringToObject(root, "transport", "websocket");
  188. cJSON* audio_params = cJSON_CreateObject();
  189. cJSON_AddStringToObject(audio_params, "format", "opus");
  190. cJSON_AddNumberToObject(audio_params, "sample_rate", 16000);
  191. cJSON_AddNumberToObject(audio_params, "channels", 1);
  192. cJSON_AddNumberToObject(audio_params, "frame_duration", OPUS_FRAME_DURATION_MS);
  193. cJSON_AddItemToObject(root, "audio_params", audio_params);
  194. auto json_str = cJSON_PrintUnformatted(root);
  195. std::string message(json_str);
  196. cJSON_free(json_str);
  197. cJSON_Delete(root);
  198. return message;
  199. }
  200. void WebsocketProtocol::ParseServerHello(const cJSON* root) {
  201. auto transport = cJSON_GetObjectItem(root, "transport");
  202. if (transport == nullptr || strcmp(transport->valuestring, "websocket") != 0) {
  203. ESP_LOGE(TAG, "Unsupported transport: %s", transport->valuestring);
  204. return;
  205. }
  206. auto session_id = cJSON_GetObjectItem(root, "session_id");
  207. if (cJSON_IsString(session_id)) {
  208. session_id_ = session_id->valuestring;
  209. ESP_LOGI(TAG, "Session ID: %s", session_id_.c_str());
  210. }
  211. auto audio_params = cJSON_GetObjectItem(root, "audio_params");
  212. if (cJSON_IsObject(audio_params)) {
  213. auto sample_rate = cJSON_GetObjectItem(audio_params, "sample_rate");
  214. if (cJSON_IsNumber(sample_rate)) {
  215. server_sample_rate_ = sample_rate->valueint;
  216. }
  217. auto frame_duration = cJSON_GetObjectItem(audio_params, "frame_duration");
  218. if (cJSON_IsNumber(frame_duration)) {
  219. server_frame_duration_ = frame_duration->valueint;
  220. }
  221. }
  222. xEventGroupSetBits(event_group_handle_, WEBSOCKET_PROTOCOL_SERVER_HELLO_EVENT);
  223. }