mcp_server.h 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. #ifndef MCP_SERVER_H
  2. #define MCP_SERVER_H
#include <functional>
#include <map>
#include <optional>
#include <stdexcept>
#include <string>
#include <thread>
#include <type_traits>
#include <utility>
#include <variant>
#include <vector>

#include <cJSON.h>
  12. // 添加类型别名
  13. using ReturnValue = std::variant<bool, int, std::string>;
  14. enum PropertyType {
  15. kPropertyTypeBoolean,
  16. kPropertyTypeInteger,
  17. kPropertyTypeString
  18. };
  19. class Property {
  20. private:
  21. std::string name_;
  22. PropertyType type_;
  23. std::variant<bool, int, std::string> value_;
  24. bool has_default_value_;
  25. std::optional<int> min_value_; // 新增:整数最小值
  26. std::optional<int> max_value_; // 新增:整数最大值
  27. public:
  28. // Required field constructor
  29. Property(const std::string& name, PropertyType type)
  30. : name_(name), type_(type), has_default_value_(false) {}
  31. // Optional field constructor with default value
  32. template<typename T>
  33. Property(const std::string& name, PropertyType type, const T& default_value)
  34. : name_(name), type_(type), has_default_value_(true) {
  35. value_ = default_value;
  36. }
  37. Property(const std::string& name, PropertyType type, int min_value, int max_value)
  38. : name_(name), type_(type), has_default_value_(false), min_value_(min_value), max_value_(max_value) {
  39. if (type != kPropertyTypeInteger) {
  40. throw std::invalid_argument("Range limits only apply to integer properties");
  41. }
  42. }
  43. Property(const std::string& name, PropertyType type, int default_value, int min_value, int max_value)
  44. : name_(name), type_(type), has_default_value_(true), min_value_(min_value), max_value_(max_value) {
  45. if (type != kPropertyTypeInteger) {
  46. throw std::invalid_argument("Range limits only apply to integer properties");
  47. }
  48. if (default_value < min_value || default_value > max_value) {
  49. throw std::invalid_argument("Default value must be within the specified range");
  50. }
  51. value_ = default_value;
  52. }
  53. inline const std::string& name() const { return name_; }
  54. inline PropertyType type() const { return type_; }
  55. inline bool has_default_value() const { return has_default_value_; }
  56. inline bool has_range() const { return min_value_.has_value() && max_value_.has_value(); }
  57. inline int min_value() const { return min_value_.value_or(0); }
  58. inline int max_value() const { return max_value_.value_or(0); }
  59. template<typename T>
  60. inline T value() const {
  61. return std::get<T>(value_);
  62. }
  63. template<typename T>
  64. inline void set_value(const T& value) {
  65. // 添加对设置的整数值进行范围检查
  66. if constexpr (std::is_same_v<T, int>) {
  67. if (min_value_.has_value() && value < min_value_.value()) {
  68. throw std::invalid_argument("Value is below minimum allowed: " + std::to_string(min_value_.value()));
  69. }
  70. if (max_value_.has_value() && value > max_value_.value()) {
  71. throw std::invalid_argument("Value exceeds maximum allowed: " + std::to_string(max_value_.value()));
  72. }
  73. }
  74. value_ = value;
  75. }
  76. std::string to_json() const {
  77. cJSON *json = cJSON_CreateObject();
  78. if (type_ == kPropertyTypeBoolean) {
  79. cJSON_AddStringToObject(json, "type", "boolean");
  80. if (has_default_value_) {
  81. cJSON_AddBoolToObject(json, "default", value<bool>());
  82. }
  83. } else if (type_ == kPropertyTypeInteger) {
  84. cJSON_AddStringToObject(json, "type", "integer");
  85. if (has_default_value_) {
  86. cJSON_AddNumberToObject(json, "default", value<int>());
  87. }
  88. if (min_value_.has_value()) {
  89. cJSON_AddNumberToObject(json, "minimum", min_value_.value());
  90. }
  91. if (max_value_.has_value()) {
  92. cJSON_AddNumberToObject(json, "maximum", max_value_.value());
  93. }
  94. } else if (type_ == kPropertyTypeString) {
  95. cJSON_AddStringToObject(json, "type", "string");
  96. if (has_default_value_) {
  97. cJSON_AddStringToObject(json, "default", value<std::string>().c_str());
  98. }
  99. }
  100. char *json_str = cJSON_PrintUnformatted(json);
  101. std::string result(json_str);
  102. cJSON_free(json_str);
  103. cJSON_Delete(json);
  104. return result;
  105. }
  106. };
  107. class PropertyList {
  108. private:
  109. std::vector<Property> properties_;
  110. public:
  111. PropertyList() = default;
  112. PropertyList(const std::vector<Property>& properties) : properties_(properties) {}
  113. void AddProperty(const Property& property) {
  114. properties_.push_back(property);
  115. }
  116. const Property& operator[](const std::string& name) const {
  117. for (const auto& property : properties_) {
  118. if (property.name() == name) {
  119. return property;
  120. }
  121. }
  122. throw std::runtime_error("Property not found: " + name);
  123. }
  124. auto begin() { return properties_.begin(); }
  125. auto end() { return properties_.end(); }
  126. std::vector<std::string> GetRequired() const {
  127. std::vector<std::string> required;
  128. for (auto& property : properties_) {
  129. if (!property.has_default_value()) {
  130. required.push_back(property.name());
  131. }
  132. }
  133. return required;
  134. }
  135. std::string to_json() const {
  136. cJSON *json = cJSON_CreateObject();
  137. for (const auto& property : properties_) {
  138. cJSON *prop_json = cJSON_Parse(property.to_json().c_str());
  139. cJSON_AddItemToObject(json, property.name().c_str(), prop_json);
  140. }
  141. char *json_str = cJSON_PrintUnformatted(json);
  142. std::string result(json_str);
  143. cJSON_free(json_str);
  144. cJSON_Delete(json);
  145. return result;
  146. }
  147. };
  148. class McpTool {
  149. private:
  150. std::string name_;
  151. std::string description_;
  152. PropertyList properties_;
  153. std::function<ReturnValue(const PropertyList&)> callback_;
  154. public:
  155. McpTool(const std::string& name,
  156. const std::string& description,
  157. const PropertyList& properties,
  158. std::function<ReturnValue(const PropertyList&)> callback)
  159. : name_(name),
  160. description_(description),
  161. properties_(properties),
  162. callback_(callback) {}
  163. inline const std::string& name() const { return name_; }
  164. inline const std::string& description() const { return description_; }
  165. inline const PropertyList& properties() const { return properties_; }
  166. std::string to_json() const {
  167. std::vector<std::string> required = properties_.GetRequired();
  168. cJSON *json = cJSON_CreateObject();
  169. cJSON_AddStringToObject(json, "name", name_.c_str());
  170. cJSON_AddStringToObject(json, "description", description_.c_str());
  171. cJSON *input_schema = cJSON_CreateObject();
  172. cJSON_AddStringToObject(input_schema, "type", "object");
  173. cJSON *properties = cJSON_Parse(properties_.to_json().c_str());
  174. cJSON_AddItemToObject(input_schema, "properties", properties);
  175. if (!required.empty()) {
  176. cJSON *required_array = cJSON_CreateArray();
  177. for (const auto& property : required) {
  178. cJSON_AddItemToArray(required_array, cJSON_CreateString(property.c_str()));
  179. }
  180. cJSON_AddItemToObject(input_schema, "required", required_array);
  181. }
  182. cJSON_AddItemToObject(json, "inputSchema", input_schema);
  183. char *json_str = cJSON_PrintUnformatted(json);
  184. std::string result(json_str);
  185. cJSON_free(json_str);
  186. cJSON_Delete(json);
  187. return result;
  188. }
  189. std::string Call(const PropertyList& properties) {
  190. ReturnValue return_value = callback_(properties);
  191. // 返回结果
  192. cJSON* result = cJSON_CreateObject();
  193. cJSON* content = cJSON_CreateArray();
  194. cJSON* text = cJSON_CreateObject();
  195. cJSON_AddStringToObject(text, "type", "text");
  196. if (std::holds_alternative<std::string>(return_value)) {
  197. cJSON_AddStringToObject(text, "text", std::get<std::string>(return_value).c_str());
  198. } else if (std::holds_alternative<bool>(return_value)) {
  199. cJSON_AddStringToObject(text, "text", std::get<bool>(return_value) ? "true" : "false");
  200. } else if (std::holds_alternative<int>(return_value)) {
  201. cJSON_AddStringToObject(text, "text", std::to_string(std::get<int>(return_value)).c_str());
  202. }
  203. cJSON_AddItemToArray(content, text);
  204. cJSON_AddItemToObject(result, "content", content);
  205. cJSON_AddBoolToObject(result, "isError", false);
  206. auto json_str = cJSON_PrintUnformatted(result);
  207. std::string result_str(json_str);
  208. cJSON_free(json_str);
  209. cJSON_Delete(result);
  210. return result_str;
  211. }
  212. };
  213. class McpServer {
  214. public:
  215. static McpServer& GetInstance() {
  216. static McpServer instance;
  217. return instance;
  218. }
  219. void AddCommonTools();
  220. void AddTool(McpTool* tool);
  221. void AddTool(const std::string& name, const std::string& description, const PropertyList& properties, std::function<ReturnValue(const PropertyList&)> callback);
  222. void ParseMessage(const cJSON* json);
  223. void ParseMessage(const std::string& message);
  224. private:
  225. McpServer();
  226. ~McpServer();
  227. void ParseCapabilities(const cJSON* capabilities);
  228. void ReplyResult(int id, const std::string& result);
  229. void ReplyError(int id, const std::string& message);
  230. void GetToolsList(int id, const std::string& cursor);
  231. void DoToolCall(int id, const std::string& tool_name, const cJSON* tool_arguments, int stack_size);
  232. std::vector<McpTool*> tools_;
  233. std::thread tool_call_thread_;
  234. };
  235. #endif // MCP_SERVER_H