diff --git a/WFConsulClient.cc b/WFConsulClient.cc new file mode 100644 index 0000000..8e1bef2 --- /dev/null +++ b/WFConsulClient.cc @@ -0,0 +1,1477 @@ +/* + Copyright (c) 2022 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Authors: Wang Zhenpeng (wangzhenpeng@sogou-inc.com) +*/ + +#include +#include +#include +#include +#include +#include "json_parser.h" // JSON解析工具 +#include "StringUtil.h" // 字符串工具 +#include "URIParser.h" // URI解析工具 +#include "HttpUtil.h" // HTTP相关工具 +#include "WFConsulClient.h" // Consul客户端相关定义 + +using namespace protocol; // 使用协议相关命名空间 + +// WFConsulTask 构造函数 +WFConsulTask::WFConsulTask(const std::string& proxy_url, + const std::string& service_namespace, + const std::string& service_name, + const std::string& service_id, + int retry_max, + consul_callback_t&& cb) : + proxy_url(proxy_url), // 初始化代理地址 + callback(std::move(cb)) // 设置回调函数 +{ + this->service.service_name = service_name; // 初始化服务名称 + this->service.service_namespace = service_namespace; // 初始化服务命名空间 + this->service.service_id = service_id; // 初始化服务 ID + this->api_type = CONSUL_API_TYPE_UNKNOWN; // 设置 API 类型为未知 + this->retry_max = retry_max; // 设置最大重试次数 + this->finish = false; // 标记任务未完成 + this->consul_index = 0; // 初始化 Consul 索引为 0 +} + +// 设置服务相关信息 +void WFConsulTask::set_service(const struct protocol::ConsulService *service) +{ + this->service.tags = service->tags; // 设置服务标签 + this->service.meta = service->meta; // 设置服务元信息 + this->service.tag_override = service->tag_override; // 设置标签覆盖标志 + this->service.service_address = service->service_address; // 设置服务地址 + this->service.lan = service->lan; // 设置本地局域网地址 + this->service.lan_ipv4 = service->lan_ipv4; // 设置 IPv4 地址 + this->service.lan_ipv6 = service->lan_ipv6; // 设置 IPv6 地址 + this->service.virtual_address = service->virtual_address; // 设置虚拟地址 + this->service.wan = service->wan; // 设置广域网地址 + this->service.wan_ipv4 = service->wan_ipv4; // 设置 IPv4 地址 + this->service.wan_ipv6 = service->wan_ipv6; // 设置 IPv6 地址 +} + + +// 解析服务发现结果的静态函数 +static bool parse_discover_result(const json_value_t *root, + std::vector& result); + +// 解析服务列表结果的静态函数 +static bool parse_list_service_result(const json_value_t *root, + std::vector& result); + + +// 获取服务发现结果 +bool WFConsulTask::get_discover_result( + std::vector& result) +{ + json_value_t *root; // JSON 根节点 + int errno_bak; // 备份当前 errno 值 + bool ret; // 返回结果标志 + + // 检查当前 API 类型是否为服务发现 + if (this->api_type != CONSUL_API_TYPE_DISCOVER) + { + errno = EPERM; // 设置权限错误 + return false; // 返回失败 + } + + errno_bak = errno; // 备份 errno + errno = EBADMSG; // 设置消息错误 + std::string body = HttpUtil::decode_chunked_body(&this->http_resp); // 解码 HTTP 响应的分块数据 + root = json_value_parse(body.c_str()); // 解析 JSON 数据 + if (!root) // 如果解析失败 + return false; // 返回失败 + + ret = parse_discover_result(root, result); // 解析服务发现结果 + json_value_destroy(root); // 销毁 JSON 对象 + if (ret) // 如果解析成功 + errno = errno_bak; // 恢复原始 errno + + return ret; // 返回解析结果 +} + + +// 获取服务列表结果 +bool WFConsulTask::get_list_service_result( + std::vector& result) +{ + json_value_t *root; // JSON 根节点 + int errno_bak; // 备份当前 errno 值 + bool ret; // 返回结果标志 + + // 检查当前 API 类型是否为服务列表 + if (this->api_type != CONSUL_API_TYPE_LIST_SERVICE) + { + errno = EPERM; // 设置权限错误 + return false; // 返回失败 + } + + errno_bak = errno; // 备份 errno + errno = EBADMSG; // 设置消息错误 + std::string body = HttpUtil::decode_chunked_body(&this->http_resp); // 解码 HTTP 响应的分块数据 + root = json_value_parse(body.c_str()); // 解析 JSON 数据 + if (!root) // 如果解析失败 + return false; // 返回失败 + + ret = parse_list_service_result(root, result); // 解析服务列表结果 + json_value_destroy(root); // 销毁 JSON 对象 + if (ret) // 如果解析成功 + errno = errno_bak; // 恢复原始 errno + + return ret; // 返回解析结果 +} + + +// 分发任务,根据 API 类型选择相应的任务处理逻辑 +void WFConsulTask::dispatch() +{ + WFHttpTask *task; // 定义 HTTP 任务指针 + + if (this->finish) // 如果任务已完成 + { + this->subtask_done(); // 标记子任务完成 + return; // 退出 + } + + // 根据 API 类型选择处理任务 + switch(this->api_type) + { + case CONSUL_API_TYPE_DISCOVER: // 如果是服务发现任务 + task = create_discover_task(); // 创建服务发现任务 + break; + + case CONSUL_API_TYPE_LIST_SERVICE: // 如果是服务列表任务 + task = create_list_service_task(); // 创建服务列表任务 + break; + + case CONSUL_API_TYPE_DEREGISTER: // 如果是注销服务任务 + task = create_deregister_task(); // 创建注销服务任务 + break; + + case CONSUL_API_TYPE_REGISTER: // 如果是注册服务任务 + task = create_register_task(); // 创建注册服务任务 + if (task) + break; + + // 如果任务创建失败,设置错误状态 + if (1) + { + this->state = WFT_STATE_SYS_ERROR; // 设置系统错误状态 + this->error = errno; // 设置错误码 + } + else + { + default: + this->state = WFT_STATE_TASK_ERROR; // 设置任务错误状态 + this->error = WFT_ERR_CONSUL_API_UNKNOWN; // 设置未知 API 错误 + } + + this->finish = true; // 标记任务完成 + this->subtask_done(); // 标记子任务完成 + return; + } + + // 将任务插入到工作队列 + series_of(this)->push_front(this); // 将当前任务加入队列 + series_of(this)->push_front(task); // 将新任务加入队列 + this->subtask_done(); // 标记子任务完成 +} + + +// 标记任务完成,并执行回调或清理逻辑 +SubTask *WFConsulTask::done() +{ + SeriesWork *series = series_of(this); // 获取当前任务的工作队列 + + if (finish) // 如果任务已完成 + { + if (this->callback) // 如果有回调函数 + this->callback(this); // 执行回调函数 + + delete this; // 删除任务对象 + } + + return series->pop(); // 从队列中弹出任务 +} + + +// 将毫秒时间转换为字符串形式(如 "5s" 或 "3m") +static std::string convert_time_to_str(int milliseconds) +{ + std::string str_time; // 保存时间字符串 + int seconds = milliseconds / 1000; // 将毫秒转换为秒 + + if (seconds >= 180) // 如果时间超过 3 分钟 + str_time = std::to_string(seconds / 60) + "m"; // 转换为分钟形式 + else + str_time = std::to_string(seconds) + "s"; // 转换为秒形式 + + return str_time; // 返回时间字符串 +} + + +// 生成服务发现请求的 URL +std::string WFConsulTask::generate_discover_request() +{ + std::string url = this->proxy_url; // 初始化 URL 为代理地址 + + // 添加服务健康检查路径 + url += "/v1/health/service/" + this->service.service_name; + + // 添加数据中心参数 + url += "?dc=" + this->config.get_datacenter(); + + // 添加服务命名空间参数 + url += "&ns=" + this->service.service_namespace; + + // 获取 passing 参数值,表示是否仅返回通过健康检查的服务实例 + std::string passing = this->config.get_passing() ? "true" : "false"; + url += "&passing=" + passing; // 添加 passing 参数 + + // 添加 token 参数用于鉴权 + url += "&token=" + this->config.get_token(); + + // 添加过滤表达式参数 + url += "&filter=" + this->config.get_filter_expr(); + + // 检查是否启用 Consul 阻塞查询 + if (this->config.blocking_query()) + { + // 添加索引参数,用于阻塞查询的初始状态 + url += "&index=" + std::to_string(this->get_consul_index()); + + // 添加等待时间参数,表示阻塞查询的超时时间 + url += "&wait=" + convert_time_to_str(this->config.get_wait_ttl()); + } + + return url; // 返回生成的 URL +} + + +// 创建服务发现任务 +WFHttpTask *WFConsulTask::create_discover_task() +{ + std::string url = generate_discover_request(); // 调用生成服务发现 URL 的函数 + WFHttpTask *task = WFTaskFactory::create_http_task( // 创建 HTTP 任务 + url, 0, this->retry_max, discover_callback); + + HttpRequest *req = task->get_req(); // 获取 HTTP 请求对象 + + // 添加 HTTP 请求头,指定内容类型为 JSON + req->add_header_pair("Content-Type", "application/json"); + + task->user_data = this; // 设置用户数据为当前任务对象 + return task; // 返回创建的任务 +} + + +// 创建服务列表任务 +WFHttpTask *WFConsulTask::create_list_service_task() +{ + std::string url = this->proxy_url; // 初始化 URL 为代理地址 + + // 添加服务目录路径和鉴权 token + url += "/v1/catalog/services?token=" + this->config.get_token(); + + // 添加数据中心参数 + url += "&dc=" + this->config.get_datacenter(); + + // 添加服务命名空间参数 + url += "&ns=" + this->service.service_namespace; + + WFHttpTask *task = WFTaskFactory::create_http_task( // 创建 HTTP 任务 + url, 0, this->retry_max, list_service_callback); + + HttpRequest *req = task->get_req(); // 获取 HTTP 请求对象 + + // 添加 HTTP 请求头,指定内容类型为 JSON + req->add_header_pair("Content-Type", "application/json"); + + task->user_data = this; // 设置用户数据为当前任务对象 + return task; // 返回创建的任务 +} + + +static void print_json_value(const json_value_t *val, int depth, + std::string& json_str); + +static bool create_register_request(const json_value_t *root, + const struct ConsulService *service, + const ConsulConfig& config); + +// 创建服务注册任务 +WFHttpTask *WFConsulTask::create_register_task() +{ + std::string payload; // 定义用于存储请求体的字符串 + + std::string url = this->proxy_url; // 初始化 URL 为代理地址 + // 添加服务注册路径和替换检查选项 + url += "/v1/agent/service/register?replace-existing-checks="; + url += this->config.get_replace_checks() ? "true" : "false"; + + WFHttpTask *task = WFTaskFactory::create_http_task( // 创建 HTTP 任务 + url, 0, this->retry_max, register_callback); + + HttpRequest *req = task->get_req(); // 获取 HTTP 请求对象 + + req->set_method(HttpMethodPut); // 设置 HTTP 方法为 PUT + req->add_header_pair("Content-Type", "application/json"); // 添加内容类型头 + + // 如果存在鉴权 token,则添加到请求头 + if (!this->config.get_token().empty()) + req->add_header_pair("X-Consul-Token", this->config.get_token()); + + json_value_t *root = json_value_create(JSON_VALUE_OBJECT); // 创建 JSON 对象根节点 + if (root) // 如果创建成功 + { + // 构建服务注册请求 JSON 数据 + if (create_register_request(root, &this->service, this->config)) + print_json_value(root, 0, payload); // 将 JSON 数据转为字符串 + + json_value_destroy(root); // 销毁 JSON 对象 + // 如果请求体不为空且追加到请求成功 + if (!payload.empty() && req->append_output_body(payload)) + { + task->user_data = this; // 设置用户数据为当前任务对象 + return task; // 返回创建的任务 + } + } + + task->dismiss(); // 如果失败,销毁任务 + return NULL; // 返回空指针 +} + + +// 创建服务注销任务 +WFHttpTask *WFConsulTask::create_deregister_task() +{ + std::string url = this->proxy_url; // 初始化 URL 为代理地址 + + // 添加服务注销路径,包含服务 ID 和命名空间参数 + url += "/v1/agent/service/deregister/" + this->service.service_id; + url += "?ns=" + this->service.service_namespace; + + WFHttpTask *task = WFTaskFactory::create_http_task( // 创建 HTTP 任务 + url, 0, this->retry_max, register_callback); + + HttpRequest *req = task->get_req(); // 获取 HTTP 请求对象 + + req->set_method(HttpMethodPut); // 设置 HTTP 方法为 PUT + req->add_header_pair("Content-Type", "application/json"); // 添加内容类型头 + + // 如果存在鉴权 token,则添加到请求头 + std::string token = this->config.get_token(); + if (!token.empty()) + req->add_header_pair("X-Consul-Token", token); + + task->user_data = this; // 设置用户数据为当前任务对象 + return task; // 返回创建的任务 +} + + +// 检查任务执行结果 +bool WFConsulTask::check_task_result(WFHttpTask *task, WFConsulTask *consul_task) +{ + // 如果任务状态不是成功状态 + if (task->get_state() != WFT_STATE_SUCCESS) + { + consul_task->state = task->get_state(); // 设置任务状态 + consul_task->error = task->get_error(); // 设置错误信息 + return false; // 返回失败 + } + + // 获取任务的 HTTP 响应 + protocol::HttpResponse *resp = task->get_resp(); + + // 检查 HTTP 响应状态码是否为 "200" + if (strcmp(resp->get_status_code(), "200") != 0) + { + consul_task->state = WFT_STATE_TASK_ERROR; // 设置任务错误状态 + consul_task->error = WFT_ERR_CONSUL_CHECK_RESPONSE_FAILED; // 设置响应失败错误 + return false; // 返回失败 + } + + return true; // 返回成功 +} + + +// 从 HTTP 响应头中获取 Consul 索引 +long long WFConsulTask::get_consul_index(HttpResponse *resp) +{ + long long consul_index = 0; // 初始化 Consul 索引为 0 + + // 创建 HTTP 头游标以遍历响应头 + protocol::HttpHeaderCursor cursor(resp); + std::string consul_index_str; // 用于存储 Consul 索引字符串 + + // 查找 "X-Consul-Index" 响应头 + if (cursor.find("X-Consul-Index", consul_index_str)) + { + // 将字符串转换为长整型 + consul_index = strtoll(consul_index_str.c_str(), NULL, 10); + if (consul_index < 0) // 如果索引小于 0,则重置为 0 + consul_index = 0; + } + + return consul_index; // 返回 Consul 索引 +} + + +// 服务发现任务的回调函数 +void WFConsulTask::discover_callback(WFHttpTask *task) +{ + WFConsulTask *t = (WFConsulTask*)task->user_data; // 获取任务的用户数据 + + // 检查任务执行结果 + if (WFConsulTask::check_task_result(task, t)) + { + protocol::HttpResponse *resp = task->get_resp(); // 获取 HTTP 响应 + long long consul_index = t->get_consul_index(resp); // 获取当前 Consul 索引 + long long last_consul_index = t->get_consul_index(); // 获取上一次的 Consul 索引 + + // 更新 Consul 索引,如果当前索引小于上次索引,则重置为 0 + t->set_consul_index(consul_index < last_consul_index ? 0 : consul_index); + t->state = task->get_state(); // 设置任务状态 + } + + // 将 HTTP 响应转移到任务对象 + t->http_resp = std::move(*task->get_resp()); + t->finish = true; // 标记任务完成 +} + + +// 服务列表任务的回调函数 +void WFConsulTask::list_service_callback(WFHttpTask *task) +{ + WFConsulTask *t = (WFConsulTask*)task->user_data; // 获取任务的用户数据 + + // 检查任务执行结果 + if (WFConsulTask::check_task_result(task, t)) + t->state = task->get_state(); // 设置任务状态 + + // 将 HTTP 响应转移到任务对象 + t->http_resp = std::move(*task->get_resp()); + t->finish = true; // 标记任务完成 +} + + +// 服务注册任务的回调函数 +void WFConsulTask::register_callback(WFHttpTask *task) +{ + WFConsulTask *t = (WFConsulTask *)task->user_data; // 获取任务的用户数据 + + // 检查任务执行结果 + if (WFConsulTask::check_task_result(task, t)) + t->state = task->get_state(); // 设置任务状态 + + // 将 HTTP 响应转移到任务对象 + t->http_resp = std::move(*task->get_resp()); + t->finish = true; // 标记任务完成 +} + + +// 初始化 Consul 客户端 +int WFConsulClient::init(const std::string& proxy_url, ConsulConfig config) +{ + ParsedURI uri; // 定义用于存储解析后的 URI 的对象 + + // 解析代理 URL + if (URIParser::parse(proxy_url, uri) >= 0) + { + // 构建完整的代理 URL + this->proxy_url = uri.scheme; + this->proxy_url += "://"; + this->proxy_url += uri.host; + if (uri.port) // 如果指定了端口,则追加端口信息 + { + this->proxy_url += ":"; + this->proxy_url += uri.port; + } + + this->config = std::move(config); // 将配置赋值给客户端 + return 0; // 返回成功 + } + else if (uri.state == URI_STATE_INVALID) // 如果 URI 无效 + errno = EINVAL; // 设置错误码为无效参数 + + return -1; // 返回失败 +} + + +// 初始化 Consul 客户端(使用默认配置) +int WFConsulClient::init(const std::string& proxy_url) +{ + return this->init(proxy_url, ConsulConfig()); // 调用带配置参数的 init 函数 +} + + +// 创建服务发现任务 +WFConsulTask *WFConsulClient::create_discover_task( + const std::string& service_namespace, + const std::string& service_name, + int retry_max, + consul_callback_t cb) +{ + // 创建一个新的 Consul 任务 + WFConsulTask *task = new WFConsulTask(this->proxy_url, service_namespace, + service_name, "", retry_max, + std::move(cb)); + + task->set_api_type(CONSUL_API_TYPE_DISCOVER); // 设置任务类型为服务发现 + task->set_config(this->config); // 设置任务的配置信息 + return task; // 返回创建的任务 +} + + +// 创建服务列表任务 +WFConsulTask *WFConsulClient::create_list_service_task( + const std::string& service_namespace, + int retry_max, + consul_callback_t cb) +{ + // 创建一个新的 Consul 任务 + WFConsulTask *task = new WFConsulTask(this->proxy_url, service_namespace, + "", "", retry_max, + std::move(cb)); + + task->set_api_type(CONSUL_API_TYPE_LIST_SERVICE); // 设置任务类型为服务列表 + task->set_config(this->config); // 设置任务的配置信息 + return task; // 返回创建的任务 +} + + +// 创建服务注册任务 +WFConsulTask *WFConsulClient::create_register_task( + const std::string& service_namespace, + const std::string& service_name, + const std::string& service_id, + int retry_max, + consul_callback_t cb) +{ + // 创建一个新的 Consul 任务 + WFConsulTask *task = new WFConsulTask(this->proxy_url, service_namespace, + service_name, service_id, retry_max, + std::move(cb)); + + task->set_api_type(CONSUL_API_TYPE_REGISTER); // 设置任务类型为服务注册 + task->set_config(this->config); // 设置任务的配置信息 + return task; // 返回创建的任务 +} + + +// 创建服务注销任务 +WFConsulTask *WFConsulClient::create_deregister_task( + const std::string& service_namespace, + const std::string& service_id, + int retry_max, + consul_callback_t cb) +{ + // 创建一个新的 Consul 任务 + WFConsulTask *task = new WFConsulTask(this->proxy_url, service_namespace, + "", service_id, retry_max, + std::move(cb)); + + task->set_api_type(CONSUL_API_TYPE_DEREGISTER); // 设置任务类型为服务注销 + task->set_config(this->config); // 设置任务的配置信息 + return task; // 返回创建的任务 +} + + +// 创建带标签的地址对象 +static bool create_tagged_address(const ConsulAddress& consul_address, + const std::string& name, + json_object_t *tagged_obj) +{ + // 如果地址为空,直接返回成功 + if (consul_address.first.empty()) + return true; + + // 向 JSON 对象中添加名为 name 的嵌套对象 + const json_value_t *val = json_object_append(tagged_obj, name.c_str(), + JSON_VALUE_OBJECT); + if (!val) // 如果添加失败,返回 false + return false; + + // 获取新创建的嵌套对象 + json_object_t *obj = json_value_object(val); + if (!obj) // 如果对象获取失败,返回 false + return false; + + // 向嵌套对象中添加 "Address" 字段 + if (!json_object_append(obj, "Address", JSON_VALUE_STRING, + consul_address.first.c_str())) + return false; + + // 向嵌套对象中添加 "Port" 字段 + if (!json_object_append(obj, "Port", JSON_VALUE_NUMBER, + (double)consul_address.second)) + return false; + + return true; // 地址对象创建成功 +} + + +// 创建健康检查配置 +static bool create_health_check(const ConsulConfig& config, json_object_t *obj) +{ + const json_value_t *val; // 存储添加字段的返回值 + std::string str; // 临时字符串存储变量 + + // 如果未启用健康检查,直接返回成功 + if (!config.get_health_check()) + return true; + + // 在 JSON 对象中添加 "Check" 字段 + val = json_object_append(obj, "Check", JSON_VALUE_OBJECT); + if (!val) // 如果添加失败,返回 false + return false; + + obj = json_value_object(val); // 获取 "Check" 对象 + if (!obj) // 如果对象获取失败,返回 false + return false; + + // 添加健康检查的名称 + str = config.get_check_name(); + if (!json_object_append(obj, "Name", JSON_VALUE_STRING, str.c_str())) + return false; + + // 添加健康检查的备注 + str = config.get_check_notes(); + if (!json_object_append(obj, "Notes", JSON_VALUE_STRING, str.c_str())) + return false; + + // 添加健康检查的 HTTP URL + str = config.get_check_http_url(); + if (!str.empty()) // 如果 URL 不为空 + { + if (!json_object_append(obj, "HTTP", JSON_VALUE_STRING, str.c_str())) + return false; + + // 添加 HTTP 方法 + str = config.get_check_http_method(); + if (!json_object_append(obj, "Method", JSON_VALUE_STRING, str.c_str())) + return false; + + // 添加 HTTP 请求体 + str = config.get_http_body(); + if (!json_object_append(obj, "Body", JSON_VALUE_STRING, str.c_str())) + return false; + + // 添加 HTTP 头部信息 + val = json_object_append(obj, "Header", JSON_VALUE_OBJECT); + if (!val) + return false; + + json_object_t *header_obj = json_value_object(val); + if (!header_obj) + return false; + + // 遍历并添加所有 HTTP 头部字段 + for (const auto& header : *config.get_http_headers()) + { + val = json_object_append(header_obj, header.first.c_str(), + JSON_VALUE_ARRAY); + if (!val) + return false; + + json_array_t *arr = json_value_array(val); + if (!arr) + return false; + + // 添加头部字段对应的多个值 + for (const auto& value : header.second) + { + if (!json_array_append(arr, JSON_VALUE_STRING, value.c_str())) + return false; + } + } + } + + // 添加 TCP 检查地址 + str = config.get_check_tcp(); + if (!str.empty()) + { + if (!json_object_append(obj, "TCP", JSON_VALUE_STRING, str.c_str())) + return false; + } + + // 添加初始状态 + str = config.get_initial_status(); + if (!json_object_append(obj, "Status", JSON_VALUE_STRING, str.c_str())) + return false; + + // 添加自动注销时间 + str = convert_time_to_str(config.get_auto_deregister_time()); + if (!json_object_append(obj, "DeregisterCriticalServiceAfter", + JSON_VALUE_STRING, str.c_str())) + return false; + + // 添加检查间隔时间 + str = convert_time_to_str(config.get_check_interval()); + if (!json_object_append(obj, "Interval", JSON_VALUE_STRING, str.c_str())) + return false; + + // 添加检查超时时间 + str = convert_time_to_str(config.get_check_timeout()); + if (!json_object_append(obj, "Timeout", JSON_VALUE_STRING, str.c_str())) + return false; + + // 添加连续成功次数限制 + if (!json_object_append(obj, "SuccessBeforePassing", JSON_VALUE_NUMBER, + (double)config.get_success_times())) + return false; + + // 添加连续失败次数限制 + if (!json_object_append(obj, "FailuresBeforeCritical", JSON_VALUE_NUMBER, + (double)config.get_failure_times())) + return false; + + return true; // 健康检查配置创建成功 +} + + +// 创建服务注册请求 +static bool create_register_request(const json_value_t *root, + const struct ConsulService *service, + const ConsulConfig& config) +{ + const json_value_t *val; // 临时存储 JSON 值 + json_object_t *obj; // 临时存储 JSON 对象 + + obj = json_value_object(root); // 获取根对象 + if (!obj) + return false; + + // 添加服务 ID + if (!json_object_append(obj, "ID", JSON_VALUE_STRING, + service->service_id.c_str())) + return false; + + // 添加服务名称 + if (!json_object_append(obj, "Name", JSON_VALUE_STRING, + service->service_name.c_str())) + return false; + + // 如果命名空间非空,则添加命名空间 + if (!service->service_namespace.empty()) + { + if (!json_object_append(obj, "ns", JSON_VALUE_STRING, + service->service_namespace.c_str())) + return false; + } + + // 添加服务标签数组 + val = json_object_append(obj, "Tags", JSON_VALUE_ARRAY); + if (!val) + return false; + + json_array_t *arr = json_value_array(val); + if (!arr) + return false; + + // 遍历并添加所有标签 + for (const auto& tag : service->tags) + { + if (!json_array_append(arr, JSON_VALUE_STRING, tag.c_str())) + return false; + } + + // 添加服务地址 + if (!json_object_append(obj, "Address", JSON_VALUE_STRING, + service->service_address.first.c_str())) + return false; + + // 添加服务端口 + if (!json_object_append(obj, "Port", JSON_VALUE_NUMBER, + (double)service->service_address.second)) + return false; + + // 添加元数据对象 + val = json_object_append(obj, "Meta", JSON_VALUE_OBJECT); + if (!val) + return false; + + json_object_t *meta_obj = json_value_object(val); + if (!meta_obj) + return false; + + // 遍历并添加所有元数据 + for (const auto& meta_kv : service->meta) + { + if (!json_object_append(meta_obj, meta_kv.first.c_str(), + JSON_VALUE_STRING, meta_kv.second.c_str())) + return false; + } + + // 添加标签覆盖选项 + int type = service->tag_override ? JSON_VALUE_TRUE : JSON_VALUE_FALSE; + if (!json_object_append(obj, "EnableTagOverride", type)) + return false; + + // 添加带标签的地址 + val = json_object_append(obj, "TaggedAddresses", JSON_VALUE_OBJECT); + if (!val) + return false; + + json_object_t *tagged_obj = json_value_object(val); + if (!tagged_obj) + return false; + + // 创建所有类型的带标签地址 + if (!create_tagged_address(service->lan, "lan", tagged_obj)) + return false; + + if (!create_tagged_address(service->lan_ipv4, "lan_ipv4", tagged_obj)) + return false; + + if (!create_tagged_address(service->lan_ipv6, "lan_ipv6", tagged_obj)) + return false; + + if (!create_tagged_address(service->virtual_address, "virtual", tagged_obj)) + return false; + + if (!create_tagged_address(service->wan, "wan", tagged_obj)) + return false; + + if (!create_tagged_address(service->wan_ipv4, "wan_ipv4", tagged_obj)) + return false; + + if (!create_tagged_address(service->wan_ipv6, "wan_ipv6", tagged_obj)) + return false; + + // 创建健康检查配置 + if (!create_health_check(config, obj)) + return false; + + return true; // 服务注册请求创建成功 +} + + +// 解析服务列表结果 +static bool parse_list_service_result(const json_value_t *root, + std::vector& result) +{ + const json_object_t *obj; // 存储 JSON 对象 + const json_value_t *val; // 存储 JSON 值 + const json_array_t *arr; // 存储 JSON 数组 + const char *key; // 存储键 + const char *str; // 存储字符串值 + + // 将根值转换为 JSON 对象 + obj = json_value_object(root); + if (!obj) // 如果转换失败,返回 false + return false; + + // 遍历 JSON 对象中的每个键值对 + json_object_for_each(key, val, obj) + { + struct ConsulServiceTags instance; // 定义一个服务标签实例 + + instance.service_name = key; // 将键作为服务名称 + arr = json_value_array(val); // 获取值作为数组 + if (!arr) // 如果值不是数组,返回 false + return false; + + const json_value_t *tag_val; // 遍历数组的每个值 + json_array_for_each(tag_val, arr) + { + str = json_value_string(tag_val); // 获取值的字符串形式 + if (!str) // 如果值不是字符串,返回 false + return false; + + instance.tags.emplace_back(str); // 将标签添加到实例 + } + + result.emplace_back(std::move(instance)); // 将实例添加到结果列表 + } + + return true; // 返回成功 +} + + +// 解析发现节点信息 +static bool parse_discover_node(const json_object_t *obj, + struct ConsulServiceInstance *instance) +{ + const json_value_t *val; // 存储 JSON 值 + const char *str; // 存储字符串值 + + // 查找并解析 "Node" 信息 + val = json_object_find("Node", obj); + if (!val) + return false; + + obj = json_value_object(val); // 获取 "Node" 对象 + if (!obj) + return false; + + // 获取 "ID" 字段 + val = json_object_find("ID", obj); + if (!val) + return false; + + str = json_value_string(val); // 转换为字符串 + if (!str) + return false; + + instance->node_id = str; // 设置节点 ID + + // 获取 "Node" 字段 + val = json_object_find("Node", obj); + if (!val) + return false; + + str = json_value_string(val); + if (!str) + return false; + + instance->node_name = str; // 设置节点名称 + + // 获取 "Address" 字段 + val = json_object_find("Address", obj); + if (!val) + return false; + + str = json_value_string(val); + if (!str) + return false; + + instance->node_address = str; // 设置节点地址 + + // 获取 "Datacenter" 字段 + val = json_object_find("Datacenter", obj); + if (!val) + return false; + + str = json_value_string(val); + if (!str) + return false; + + instance->dc = str; // 设置数据中心 + + // 解析 "Meta" 信息 + val = json_object_find("Meta", obj); + if (!val) + return false; + + const json_object_t *meta_obj = json_value_object(val); // 获取 "Meta" 对象 + if (!meta_obj) + return false; + + const char *meta_k; // 存储元数据键 + const json_value_t *meta_v; // 存储元数据值 + + // 遍历 "Meta" 对象 + json_object_for_each(meta_k, meta_v, meta_obj) + { + str = json_value_string(meta_v); // 获取元数据值 + if (!str) + return false; + + instance->node_meta[meta_k] = str; // 添加元数据键值对 + } + + // 获取 "CreateIndex" 字段 + val = json_object_find("CreateIndex", obj); + if (val) + instance->create_index = json_value_number(val); // 设置创建索引 + + // 获取 "ModifyIndex" 字段 + val = json_object_find("ModifyIndex", obj); + if (val) + instance->modify_index = json_value_number(val); // 设置修改索引 + + return true; // 返回成功 +} + + +// 解析服务发现结果 +static bool parse_discover_result(const json_value_t *root, + std::vector& result) +{ + const json_array_t *arr = json_value_array(root); // 获取 JSON 数组 + const json_value_t *val; // 存储数组值 + const json_object_t *obj; // 存储 JSON 对象 + + if (!arr) // 如果不是数组,返回 false + return false; + + // 遍历数组 + json_array_for_each(val, arr) + { + struct ConsulServiceInstance instance; // 创建服务实例 + + obj = json_value_object(val); // 获取数组中的对象 + if (!obj) + return false; + + // 解析节点信息 + if (!parse_discover_node(obj, &instance)) + return false; + + // 解析服务信息 + if (!parse_service(obj, &instance.service)) + return false; + + // 解析健康检查 + parse_health_check(obj, &instance); + + result.emplace_back(std::move(instance)); // 添加实例到结果列表 + } + + return true; // 返回成功 +} + + +// 解析服务信息 +static bool parse_service(const json_object_t *obj, + struct ConsulService *service) +{ + const json_value_t *val; // 用于存储 JSON 值 + const char *str; // 用于存储字符串值 + + // 获取 "Service" 对象 + val = json_object_find("Service", obj); + if (!val) + return false; + + obj = json_value_object(val); // 将值转换为对象 + if (!obj) + return false; + + // 解析服务 ID + val = json_object_find("ID", obj); + if (!val) + return false; + + str = json_value_string(val); // 将值转换为字符串 + if (!str) + return false; + + service->service_id = str; // 设置服务 ID + + // 解析服务名称 + val = json_object_find("Service", obj); + if (!val) + return false; + + str = json_value_string(val); + if (!str) + return false; + + service->service_name = str; // 设置服务名称 + + // 解析命名空间 + val = json_object_find("Namespace", obj); + if (val) + { + str = json_value_string(val); + if (!str) + return false; + + service->service_namespace = str; // 设置命名空间 + } + + // 解析服务地址 + val = json_object_find("Address", obj); + if (!val) + return false; + + str = json_value_string(val); + if (!str) + return false; + + service->service_address.first = str; // 设置服务地址 + + // 解析服务端口 + val = json_object_find("Port", obj); + if (!val) + return false; + + service->service_address.second = json_value_number(val); // 设置服务端口 + + // 解析带标签地址 + val = json_object_find("TaggedAddresses", obj); + if (!val) + return false; + + parse_tagged_address("lan", val, service->lan); + parse_tagged_address("lan_ipv4", val, service->lan_ipv4); + parse_tagged_address("lan_ipv6", val, service->lan_ipv6); + parse_tagged_address("virtual", val, service->virtual_address); + parse_tagged_address("wan", val, service->wan); + parse_tagged_address("wan_ipv4", val, service->wan_ipv4); + parse_tagged_address("wan_ipv6", val, service->wan_ipv6); + + // 解析标签列表 + val = json_object_find("Tags", obj); + if (!val) + return false; + + const json_array_t *tags_arr = json_value_array(val); // 获取数组 + if (tags_arr) + { + const json_value_t *tags_value; + json_array_for_each(tags_value, tags_arr) // 遍历数组 + { + str = json_value_string(tags_value); // 获取标签字符串 + if (!str) + return false; + + service->tags.emplace_back(str); // 添加标签到服务 + } + } + + // 解析元数据 + val = json_object_find("Meta", obj); + if (!val) + return false; + + const json_object_t *meta_obj = json_value_object(val); // 获取元数据对象 + if (!meta_obj) + return false; + + const char *meta_k; + const json_value_t *meta_v; + json_object_for_each(meta_k, meta_v, meta_obj) // 遍历元数据 + { + str = json_value_string(meta_v); + if (!str) + return false; + + service->meta[meta_k] = str; // 添加元数据键值对 + } + + // 检查标签覆盖选项 + val = json_object_find("EnableTagOverride", obj); + if (val) + service->tag_override = (json_value_type(val) == JSON_VALUE_TRUE); + + return true; // 返回成功 +} + + +// 解析健康检查信息 +static bool parse_health_check(const json_object_t *obj, + struct ConsulServiceInstance *instance) +{ + const json_value_t *val; // 存储 JSON 值 + const char *str; // 存储字符串值 + + // 获取 "Checks" 数组 + val = json_object_find("Checks", obj); + if (!val) + return false; + + const json_array_t *check_arr = json_value_array(val); // 转换为数组 + if (!check_arr) + return false; + + const json_value_t *arr_val; + json_array_for_each(arr_val, check_arr) // 遍历数组 + { + obj = json_value_object(arr_val); // 获取数组中的对象 + if (!obj) + return false; + + // 解析服务名称 + val = json_object_find("ServiceName", obj); + if (!val) + return false; + + str = json_value_string(val); + if (!str) + return false; + + std::string check_service_name = str; + + // 解析服务 ID + val = json_object_find("ServiceID", obj); + if (!val) + return false; + + str = json_value_string(val); + if (!str) + return false; + + std::string check_service_id = str; + if (check_service_id.empty() || check_service_name.empty()) + continue; + + // 解析检查 ID + val = json_object_find("CheckID", obj); + if (!val) + return false; + + str = json_value_string(val); + if (!str) + return false; + + instance->check_id = str; + + // 解析检查名称 + val = json_object_find("Name", obj); + if (!val) + return false; + + str = json_value_string(val); + if (!str) + return false; + + instance->check_name = str; + + // 解析检查状态 + val = json_object_find("Status", obj); + if (!val) + return false; + + str = json_value_string(val); + if (!str) + return false; + + instance->check_status = str; + + // 解析备注 + val = json_object_find("Notes", obj); + if (val) + { + str = json_value_string(val); + if (!str) + return false; + + instance->check_notes = str; + } + + // 解析输出 + val = json_object_find("Output", obj); + if (val) + { + str = json_value_string(val); + if (!str) + return false; + + instance->check_output = str; + } + + // 解析检查类型 + val = json_object_find("Type", obj); + if (val) + { + str = json_value_string(val); + if (!str) + return false; + + instance->check_type = str; + } + + break; // 只解析一个有效的健康检查 + } + + return true; // 返回成功 +} + +// 解析服务发现结果 +static bool parse_discover_result(const json_value_t *root, + std::vector& result) +{ + const json_array_t *arr = json_value_array(root); // 转换为数组 + const json_value_t *val; // 存储数组中的值 + const json_object_t *obj; // 存储 JSON 对象 + + if (!arr) + return false; + + // 遍历数组 + json_array_for_each(val, arr) + { + struct ConsulServiceInstance instance; // 定义服务实例 + + obj = json_value_object(val); // 获取数组中的对象 + if (!obj) + return false; + + // 解析节点信息 + if (!parse_discover_node(obj, &instance)) + return false; + + // 解析服务信息 + if (!parse_service(obj, &instance.service)) + return false; + + // 解析健康检查 + parse_health_check(obj, &instance); + + result.emplace_back(std::move(instance)); // 添加实例到结果 + } + + return true; // 返回成功 +} + + +// 打印 JSON 对象,递归地将其转换为字符串格式 +static void print_json_object(const json_object_t *obj, int depth, + std::string& json_str) +{ + const char *name; // JSON 对象的键 + const json_value_t *val; // JSON 对象的值 + int n = 0; // 键值对的计数器 + int i; + + json_str += "{\n"; // 开始对象 + + // 遍历 JSON 对象中的每个键值对 + json_object_for_each(name, val, obj) + { + if (n != 0) // 如果不是第一个键值对,添加逗号换行 + json_str += ",\n"; + + n++; // 键值对计数加 1 + for (i = 0; i < depth + 1; i++) // 根据深度添加缩进 + json_str += " "; + + // 添加键 + json_str += "\""; + json_str += name; + json_str += "\": "; + + // 递归打印值 + print_json_value(val, depth + 1, json_str); + } + + json_str += "\n"; // 对象内容结束 + for (i = 0; i < depth; i++) // 根据深度添加缩进 + json_str += " "; + + json_str += "}"; // 结束对象 +} + + +// 打印 JSON 数组,递归地将其转换为字符串格式 +static void print_json_array(const json_array_t *arr, int depth, + std::string& json_str) +{ + const json_value_t *val; // JSON 数组的值 + int n = 0; // 数组值的计数器 + int i; + + json_str += "[\n"; // 开始数组 + + // 遍历 JSON 数组中的每个值 + json_array_for_each(val, arr) + { + if (n != 0) // 如果不是第一个值,添加逗号换行 + json_str += ",\n"; + + n++; // 值计数加 1 + for (i = 0; i < depth + 1; i++) // 根据深度添加缩进 + json_str += " "; + + // 递归打印值 + print_json_value(val, depth + 1, json_str); + } + + json_str += "\n"; // 数组内容结束 + for (i = 0; i < depth; i++) // 根据深度添加缩进 + json_str += " "; + + json_str += "]"; // 结束数组 +} + + +// 打印 JSON 字符串,处理特殊字符转义 +static void print_json_string(const char *str, std::string& json_str) +{ + json_str += "\""; // 开始字符串 + + // 遍历字符串中的每个字符 + while (*str) + { + switch (*str) + { + case '\r': // 回车符转义 + json_str += "\\r"; + break; + case '\n': // 换行符转义 + json_str += "\\n"; + break; + case '\f': // 换页符转义 + json_str += "\\f"; + break; + case '\b': // 退格符转义 + json_str += "\\b"; + break; + case '\"': // 双引号转义 + json_str += "\\\""; + break; + case '\t': // 制表符转义 + json_str += "\\t"; + break; + case '\\': // 反斜杠转义 + json_str += "\\\\"; + break; + default: // 其他字符直接添加 + json_str += *str; + break; + } + str++; // 指向下一个字符 + } + json_str += "\""; // 结束字符串 +} + + +// 打印 JSON 数字 +static void print_json_number(double number, std::string& json_str) +{ + long long integer = number; // 检查数字是否是整数 + + if (integer == number) // 如果是整数 + json_str += std::to_string(integer); // 转换为整数字符串 + else // 如果是浮点数 + json_str += std::to_string(number); // 转换为浮点数字符串 +} + + +// 根据值的类型打印 JSON 值 +static void print_json_value(const json_value_t *val, int depth, + std::string& json_str) +{ + switch (json_value_type(val)) // 获取值的类型 + { + case JSON_VALUE_STRING: // 如果是字符串 + print_json_string(json_value_string(val), json_str); + break; + case JSON_VALUE_NUMBER: // 如果是数字 + print_json_number(json_value_number(val), json_str); + break; + case JSON_VALUE_OBJECT: // 如果是对象 + print_json_object(json_value_object(val), depth, json_str); + break; + case JSON_VALUE_ARRAY: // 如果是数组 + print_json_array(json_value_array(val), depth, json_str); + break; + case JSON_VALUE_TRUE: // 如果是布尔值 true + json_str += "true"; + break; + case JSON_VALUE_FALSE: // 如果是布尔值 false + json_str += "false"; + break; + case JSON_VALUE_NULL: // 如果是 null + json_str += "null"; + break; + } +} + diff --git a/WFDnsClient.cc b/WFDnsClient.cc new file mode 100644 index 0000000..b74afca --- /dev/null +++ b/WFDnsClient.cc @@ -0,0 +1,308 @@ +/* + Copyright (c) 2021 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Author: Liu Kai (liukaidx@sogou-inc.com) +*/ + +#include +#include +#include +#include "URIParser.h" +#include "StringUtil.h" +#include "WFDnsClient.h" + +using namespace protocol; + +using DnsCtx = std::function; +using ComplexTask = WFComplexClientTask; + +// 定义一个DnsParams类,用于存储DNS查询参数 +class DnsParams +{ +public: + // 定义一个内部结构体dns_params,用于存储具体的DNS参数 + struct dns_params + { + std::vector uris; // 存储解析后的URI + std::vector search_list; // 搜索列表 + int ndots; // 域名中允许的最小点数 + int attempts; // 最大尝试次数 + bool rotate; // 是否轮询 + }; + +public: + // DnsParams类的构造函数 + DnsParams() + { + this->ref = new std::atomic(1); // 初始化引用计数为1 + this->params = new dns_params(); // 初始化参数结构体 + } + + // DnsParams类的拷贝构造函数 + DnsParams(const DnsParams& p) + { + this->ref = p.ref; // 拷贝引用计数指针 + this->params = p.params; // 拷贝参数结构体指针 + this->incref(); // 增加引用计数 + } + + // DnsParams类的赋值运算符 + DnsParams& operator=(const DnsParams& p) + { + if (this != &p) // 如果不是自赋值 + { + this->decref(); // 减少当前对象的引用计数 + this->ref = p.ref; // 拷贝引用计数指针 + this->params = p.params; // 拷贝参数结构体指针 + this->incref(); // 增加引用计数 + } + return *this; // 返回当前对象的引用 + } + + // DnsParams类的析构函数 + ~DnsParams() { this->decref(); } // 减少引用计数 + + // 获取const类型的参数指针 + const dns_params *get_params() const { return this->params; } + // 获取非const类型的参数指针 + dns_params *get_params() { return this->params; } + +private: + // 增加引用计数 + void incref() { (*this->ref)++; } + // 减少引用计数,并在计数为0时释放资源 + void decref() + { + if (--*this->ref == 0) + { + delete this->params; // 删除参数结构体 + delete this->ref; // 删除引用计数 + } + } + +private: + dns_params *params; // 指向参数结构体的指针 + std::atomic *ref; // 指向引用计数的指针 +}; + +// 定义DNS状态枚举 +enum +{ + DNS_STATUS_TRY_ORIGIN_DONE = 0, + DNS_STATUS_TRY_ORIGIN_FIRST = 1, + DNS_STATUS_TRY_ORIGIN_LAST = 2 +}; + +// 定义DnsStatus结构体,用于存储DNS查询状态 +struct DnsStatus +{ + std::string origin_name; // 原始域名 + std::string current_name; // 当前域名 + size_t next_server; // 下一个要尝试的服务器 + size_t last_server; // 上一个尝试的服务器 + size_t next_domain; // 下一个要尝试的搜索域 + int attempts_left; // 剩余尝试次数 + int try_origin_state; // 尝试原始域名的状态 +}; + +// 计算字符串中点的数量 +static int __get_ndots(const std::string& s) +{ + int ndots = 0; + for (size_t i = 0; i < s.size(); i++) + ndots += s[i] == '.'; // 统计点的数量 + return ndots; +} + +// 检查是否有下一个域名要尝试 +static bool __has_next_name(const DnsParams::dns_params *p, + struct DnsStatus *s) +{ + if (s->try_origin_state == DNS_STATUS_TRY_ORIGIN_FIRST) + { + s->current_name = s->origin_name; // 设置当前域名为原始域名 + s->try_origin_state = DNS_STATUS_TRY_ORIGIN_DONE; // 更新状态 + return true; + } + + if (s->next_domain < p->search_list.size()) // 如果还有搜索域要尝试 + { + s->current_name = s->origin_name; // 设置当前域名为原始域名 + s->current_name.push_back('.'); // 添加点 + s->current_name.append(p->search_list[s->next_domain]); // 添加搜索域 + + s->next_domain++; // 移动到下一个搜索域 + return true; + } + + if (s->try_origin_state == DNS_STATUS_TRY_ORIGIN_LAST) + { + s->current_name = s->origin_name; // 设置当前域名为原始域名 + s->try_origin_state = DNS_STATUS_TRY_ORIGIN_DONE; // 更新状态 + return true; + } + + return false; // 没有下一个域名要尝试 +} + +// DNS查询回调内部函数 +static void __callback_internal(WFDnsTask *task, const DnsParams& params, + struct DnsStatus& s) +{ + ComplexTask *ctask = static_cast(task); // 转换任务类型 + int state = task->get_state(); // 获取任务状态 + DnsRequest *req = task->get_req(); // 获取DNS请求 + DnsResponse *resp = task->get_resp(); // 获取DNS响应 + const auto *p = params.get_params(); // 获取DNS参数 + int rcode = resp->get_rcode(); // 获取响应码 + + bool try_next_server = state != WFT_STATE_SUCCESS || // 如果状态不是成功 + rcode == DNS_RCODE_SERVER_FAILURE || // 或者响应码是服务器失败 + rcode == DNS_RCODE_NOT_IMPLEMENTED || // 或者响应码是未实现 + rcode == DNS_RCODE_REFUSED; // 或者响应码是拒绝 + bool try_next_name = rcode == DNS_RCODE_FORMAT_ERROR || // 如果响应码是格式错误 + rcode == DNS_RCODE_NAME_ERROR || // 或者响应码是名字错误 + resp->get_ancount() == 0; // 或者响应中没有答案 + + if (try_next_server) + { + if (s.last_server == s.next_server) // 如果已经是最后一个服务器 + s.attempts_left--; // 减少尝试次数 + if (s.attempts_left <= 0) // 如果尝试次数用完 + return; // 返回 + + s.next_server = (s.next_server + 1) % p->uris.size(); // 计算下一个服务器 + ctask->set_redirect(p->uris[s.next_server]); // 设置重定向 + return; // 返回 + } + + if (try_next_name && __has_next_name(p, &s)) // 如果需要尝试下一个名字 + { + req->set_question_name(s.current_name.c_str()); // 设置查询名字 + ctask->set_redirect(p->uris[s.next_server]); // 设置重定向 + return; // 返回 + } +} + +// WFDnsClient类的初始化函数,带一个URL参数 +int WFDnsClient::init(const std::string& url) +{ + return this->init(url, "", 1, 2, false); // 调用另一个初始化函数 +} + +// WFDnsClient类的初始化函数,用于配置DNS客户端 +int WFDnsClient::init(const std::string& url, const std::string& search_list, + int ndots, int attempts, bool rotate) +{ + std::vector hosts; // 用于存储分割后的主机名 + std::vector uris; // 用于存储解析后的URI + std::string host; // 单个主机名字符串 + ParsedURI uri; // 用于存储解析后的单个URI + + this->id = 0; // 初始化客户端ID为0 + hosts = StringUtil::split_filter_empty(url, ','); // 根据逗号分割URL字符串,获取主机名列表 + + // 遍历主机名列表,对每个主机名进行处理 + for (size_t i = 0; i < hosts.size(); i++) + { + host = hosts[i]; // 获取当前主机名 + // 检查主机名是否以"dns://"或"dnss://"开头,如果不是,则添加"dns://"前缀 + if (strncasecmp(host.c_str(), "dns://", 6) != 0 && + strncasecmp(host.c_str(), "dnss://", 7) != 0) + { + host = "dns://" + host; + } + + // 使用URIParser解析当前主机名,如果解析失败则返回错误码-1 + if (URIParser::parse(host, uri) != 0) + return -1; + + // 将解析后的URI添加到uris列表中 + uris.emplace_back(std::move(uri)); + } + + // 如果uris列表为空,或者ndots小于0,或者attempts小于1,则设置errno为EINVAL并返回错误码-1 + if (uris.empty() || ndots < 0 || attempts < 1) + { + errno = EINVAL; + return -1; + } + + // 创建一个新的DnsParams对象来存储DNS参数 + this->params = new DnsParams; + DnsParams::dns_params *q = ((DnsParams *)this->params)->get_params(); // 获取DNS参数的指针 + q->uris = std::move(uris); // 将解析后的URI列表移动到DNS参数中 + q->search_list = StringUtil::split_filter_empty(search_list, ','); // 根据逗号分割search_list字符串,获取搜索域列表 + // 设置ndots的值,如果大于15则设置为15 + q->ndots = ndots > 15 ? 15 : ndots; + // 设置attempts的值,如果大于5则设置为5 + q->attempts = attempts > 5 ? 5 : attempts; + q->rotate = rotate; // 设置是否轮询 + + return 0; // 初始化成功,返回0 +} + +// WFDnsClient类的析构函数,用于释放资源 +void WFDnsClient::deinit() +{ + delete (DnsParams *)this->params; // 删除分配的DnsParams对象 + this->params = NULL; // 将params指针设置为NULL +} + +// WFDnsClient类的方法,用于创建一个新的DNS任务 +WFDnsTask *WFDnsClient::create_dns_task(const std::string& name, + dns_callback_t callback) +{ + DnsParams::dns_params *p = ((DnsParams *)this->params)->get_params(); // 获取DNS参数 + struct DnsStatus status; // 创建DNS状态结构体 + size_t next_server; // 下一个要尝试的服务器索引 + WFDnsTask *task; // 指向新创建的DNS任务的指针 + DnsRequest *req; // 指向DNS请求的指针 + + // 如果启用轮询,则计算下一个服务器索引,否则使用第一个服务器 + next_server = p->rotate ? this->id++ % p->uris.size() : 0; + + status.origin_name = name; // 设置原始域名 + status.next_domain = 0; // 设置下一个要尝试的搜索域索引 + status.attempts_left = p->attempts; // 设置剩余尝试次数 + status.try_origin_state = DNS_STATUS_TRY_ORIGIN_FIRST; // 设置尝试原始域名的状态 + + // 如果域名以点结尾,则跳过搜索域 + if (!name.empty() && name.back() == '.') + status.next_domain = p->search_list.size(); + // 如果域名中的点数少于ndots,则设置尝试原始域名的状态为最后尝试 + else if (__get_ndots(name) < p->ndots) + status.try_origin_state = DNS_STATUS_TRY_ORIGIN_LAST; + + // 检查是否有下一个域名要尝试,并更新状态 + __has_next_name(p, &status); + + // 创建一个新的DNS任务,使用下一个服务器的URI和提供的回调函数 + task = WFTaskFactory::create_dns_task(p->uris[next_server], 0, std::move(callback)); + status.next_server = next_server; // 设置当前服务器索引 + status.last_server = (next_server + p->uris.size() - 1) % p->uris.size(); // 设置最后一个服务器索引 + + req = task->get_req(); // 获取DNS请求对象 + req->set_question(status.current_name.c_str(), DNS_TYPE_A, DNS_CLASS_IN); // 设置DNS请求的问题部分 + req->set_rd(1); // 设置DNS请求的递归标志 + + ComplexTask *ctask = static_cast(task); // 将任务转换为ComplexTask类型 + // 设置任务的上下文回调,当DNS任务完成时,将调用__callback_internal函数 + *ctask->get_mutable_ctx() = std::bind(__callback_internal, + std::placeholders::_1, + *(DnsParams *)params, status); + + return task; // 返回新创建的DNS任务 +} diff --git a/dns_parser.c b/dns_parser.c new file mode 100644 index 0000000..78022ed --- /dev/null +++ b/dns_parser.c @@ -0,0 +1,1206 @@ +/* + Copyright (c) 2021 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Author: Liu Kai (liukaidx@sogou-inc.com) +*/ + +#include +#include +#include +#include "dns_parser.h" + +#define DNS_LABELS_MAX 63 +#define DNS_NAMES_MAX 256 +#define DNS_MSGBASE_INIT_SIZE 514 // 512 + 2(leading length) +#define MAX(x, y) ((x) <= (y) ? (y) : (x)) + +struct __dns_record_entry +{ + struct list_head entry_list; + struct dns_record record; +}; + + +static inline uint8_t __dns_parser_uint8(const char *ptr) +{ + return (unsigned char)ptr[0]; +} + +static inline uint16_t __dns_parser_uint16(const char *ptr) +{ + const unsigned char *p = (const unsigned char *)ptr; + return ((uint16_t)p[0] << 8) + + ((uint16_t)p[1]); +} + +static inline uint32_t __dns_parser_uint32(const char *ptr) +{ + const unsigned char *p = (const unsigned char *)ptr; + return ((uint32_t)p[0] << 24) + + ((uint32_t)p[1] << 16) + + ((uint32_t)p[2] << 8) + + ((uint32_t)p[3]); +} + +/* + * Parse a single . + * is a domain name represented as a series of labels, and + * terminated by a label with zero length. + * + * phost must point to an char array with at least DNS_NAMES_MAX+1 size + */ +static int __dns_parser_parse_host(char *phost, dns_parser_t *parser) +{ + uint8_t len; + uint16_t pointer; + size_t hcur; + const char *msgend; + const char **cur; + const char *curbackup; // backup cur when host label is pointer + + msgend = (const char *)parser->msgbuf + parser->msgsize; + cur = &(parser->cur); + curbackup = NULL; + hcur = 0; + + if (*cur >= msgend) + return -2; + + while (*cur < msgend) + { + len = __dns_parser_uint8(*cur); + + if ((len & 0xC0) == 0) + { + (*cur)++; + if (len == 0) + break; + if (len > DNS_LABELS_MAX || *cur + len > msgend || + hcur + len + 1 > DNS_NAMES_MAX) + return -2; + + memcpy(phost + hcur, *cur, len); + *cur += len; + hcur += len; + phost[hcur++] = '.'; + } + // RFC 1035, 4.1.4 Message compression + else if ((len & 0xC0) == 0xC0) + { + pointer = __dns_parser_uint16(*cur) & 0x3FFF; + + if (pointer >= parser->msgsize) + return -2; + + // pointer must point to a prior position + if ((const char *)parser->msgbase + pointer >= *cur) + return -2; + + *cur += 2; + + // backup cur only when the first pointer occurs + if (curbackup == NULL) + curbackup = *cur; + *cur = (const char *)parser->msgbase + pointer; + } + else + return -2; + } + if (curbackup != NULL) + *cur = curbackup; + + if (hcur > 1 && phost[hcur - 1] == '.') + hcur--; + + if (hcur == 0) + phost[hcur++] = '.'; + phost[hcur++] = '\0'; + + return 0; +} + +static void __dns_parser_free_record(struct __dns_record_entry *r) +{ + switch (r->record.type) + { + case DNS_TYPE_SOA: + { + struct dns_record_soa *soa; + soa = (struct dns_record_soa *)(r->record.rdata); + free(soa->mname); + free(soa->rname); + break; + } + case DNS_TYPE_SRV: + { + struct dns_record_srv *srv; + srv = (struct dns_record_srv *)(r->record.rdata); + free(srv->target); + break; + } + case DNS_TYPE_MX: + { + struct dns_record_mx *mx; + mx = (struct dns_record_mx *)(r->record.rdata); + free(mx->exchange); + break; + } + } + free(r->record.name); + free(r); +} + +static void __dns_parser_free_record_list(struct list_head *head) +{ + struct list_head *pos, *tmp; + struct __dns_record_entry *entry; + + list_for_each_safe(pos, tmp, head) + { + entry = list_entry(pos, struct __dns_record_entry, entry_list); + list_del(pos); + __dns_parser_free_record(entry); + } +} + +/* + * A RDATA format, from RFC 1035: + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | ADDRESS | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * + * ADDRESS: A 32 bit Internet address. + * Hosts that have multiple Internet addresses will have multiple A records. + */ +static int __dns_parser_parse_a(struct __dns_record_entry **r, + uint16_t rdlength, + dns_parser_t *parser) +{ + const char **cur; + struct __dns_record_entry *entry; + size_t entry_size; + + if (sizeof (struct in_addr) != rdlength) + return -2; + + cur = &(parser->cur); + entry_size = sizeof (struct __dns_record_entry) + sizeof (struct in_addr); + entry = (struct __dns_record_entry *)malloc(entry_size); + if (!entry) + return -1; + + entry->record.rdata = (void *)(entry + 1); + memcpy(entry->record.rdata, *cur, rdlength); + *cur += rdlength; + *r = entry; + + return 0; +} + +/* + * AAAA RDATA format, from RFC 3596: + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | ADDRESS | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * + * ADDRESS: A 128 bit Internet address. + * Hosts that have multiple addresses will have multiple AAAA records. + */ +static int __dns_parser_parse_aaaa(struct __dns_record_entry **r, + uint16_t rdlength, + dns_parser_t *parser) +{ + const char **cur; + struct __dns_record_entry *entry; + size_t entry_size; + + if (sizeof (struct in6_addr) != rdlength) + return -2; + + cur = &(parser->cur); + entry_size = sizeof (struct __dns_record_entry) + sizeof (struct in6_addr); + entry = (struct __dns_record_entry *)malloc(entry_size); + if (!entry) + return -1; + + entry->record.rdata = (void *)(entry + 1); + memcpy(entry->record.rdata, *cur, rdlength); + *cur += rdlength; + *r = entry; + + return 0; +} + +/* + * Parse any record. + */ +static int __dns_parser_parse_names(struct __dns_record_entry **r, + uint16_t rdlength, + dns_parser_t *parser) +{ + const char *rcdend; + const char **cur; + struct __dns_record_entry *entry; + size_t entry_size; + size_t name_len; + char name[DNS_NAMES_MAX + 2]; + int ret; + + cur = &(parser->cur); + rcdend = *cur + rdlength; + ret = __dns_parser_parse_host(name, parser); + if (ret < 0) + return ret; + + if (*cur != rcdend) + return -2; + + name_len = strlen(name); + entry_size = sizeof (struct __dns_record_entry) + name_len + 1; + entry = (struct __dns_record_entry *)malloc(entry_size); + if (!entry) + return -1; + + entry->record.rdata = (void *)(entry + 1); + memcpy(entry->record.rdata, name, name_len + 1); + *r = entry; + + return 0; +} + +/* + * SOA RDATA format, from RFC 1035: + * 1 1 1 1 1 1 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * / MNAME / + * / / + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * / RNAME / + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | SERIAL | + * | | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | REFRESH | + * | | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | RETRY | + * | | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | EXPIRE | + * | | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | MINIMUM | + * | | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * + * MNAME: + * RNAME: + * SERIAL: The unsigned 32 bit version number. + * REFRESH: A 32 bit time interval. + * RETRY: A 32 bit time interval. + * EXPIRE: A 32 bit time value. + * MINIMUM: The unsigned 32 bit integer. + */ +static int __dns_parser_parse_soa(struct __dns_record_entry **r, + uint16_t rdlength, + dns_parser_t *parser) +{ + const char *rcdend; + const char **cur; + struct __dns_record_entry *entry; + struct dns_record_soa *soa; + size_t entry_size; + char mname[DNS_NAMES_MAX + 2]; + char rname[DNS_NAMES_MAX + 2]; + int ret; + + cur = &(parser->cur); + rcdend = *cur + rdlength; + ret = __dns_parser_parse_host(mname, parser); + if (ret < 0) + return ret; + ret = __dns_parser_parse_host(rname, parser); + if (ret < 0) + return ret; + + if (*cur + 20 != rcdend) + return -2; + + entry_size = sizeof (struct __dns_record_entry) + + sizeof (struct dns_record_soa); + entry = (struct __dns_record_entry *)malloc(entry_size); + if (!entry) + return -1; + + entry->record.rdata = (void *)(entry + 1); + soa = (struct dns_record_soa *)(entry->record.rdata); + + soa->mname = strdup(mname); + soa->rname = strdup(rname); + soa->serial = __dns_parser_uint32(*cur); + soa->refresh = __dns_parser_uint32(*cur + 4); + soa->retry = __dns_parser_uint32(*cur + 8); + soa->expire = __dns_parser_uint32(*cur + 12); + soa->minimum = __dns_parser_uint32(*cur + 16); + + if (!soa->mname || !soa->rname) + { + free(soa->mname); + free(soa->rname); + free(entry); + return -1; + } + + *cur += 20; + *r = entry; + + return 0; +} + +/* + * SRV RDATA format, from RFC 2782: + * 1 1 1 1 1 1 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | PRIORITY | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | WEIGHT | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | PORT | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * / TARGET / + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * + * PRIORITY: A 16 bit unsigned integer in network byte order. + * WEIGHT: A 16 bit unsigned integer in network byte order. + * PORT: A 16 bit unsigned integer in network byte order. + * TARGET: + */ +static int __dns_parser_parse_srv(struct __dns_record_entry **r, + uint16_t rdlength, + dns_parser_t *parser) +{ + const char *rcdend; + const char **cur; + struct __dns_record_entry *entry; + struct dns_record_srv *srv; + size_t entry_size; + char target[DNS_NAMES_MAX + 2]; + uint16_t priority; + uint16_t weight; + uint16_t port; + int ret; + + cur = &(parser->cur); + rcdend = *cur + rdlength; + + if (*cur + 6 > rcdend) + return -2; + + priority = __dns_parser_uint16(*cur); + weight = __dns_parser_uint16(*cur + 2); + port = __dns_parser_uint16(*cur + 4); + *cur += 6; + + ret = __dns_parser_parse_host(target, parser); + if (ret < 0) + return ret; + if (*cur != rcdend) + return -2; + + entry_size = sizeof (struct __dns_record_entry) + + sizeof (struct dns_record_srv); + entry = (struct __dns_record_entry *)malloc(entry_size); + if (!entry) + return -1; + + entry->record.rdata = (void *)(entry + 1); + srv = (struct dns_record_srv *)(entry->record.rdata); + + srv->priority = priority; + srv->weight = weight; + srv->port = port; + srv->target = strdup(target); + + if (!srv->target) + { + free(entry); + return -1; + } + + *r = entry; + + return 0; +} + +static int __dns_parser_parse_mx(struct __dns_record_entry **r, + uint16_t rdlength, + dns_parser_t *parser) +{ + const char *rcdend; + const char **cur; + struct __dns_record_entry *entry; + struct dns_record_mx *mx; + size_t entry_size; + char exchange[DNS_NAMES_MAX + 2]; + int16_t preference; + int ret; + + cur = &(parser->cur); + rcdend = *cur + rdlength; + + if (*cur + 2 > rcdend) + return -2; + preference = __dns_parser_uint16(*cur); + *cur += 2; + + ret = __dns_parser_parse_host(exchange, parser); + if (ret < 0) + return ret; + if (*cur != rcdend) + return -2; + + entry_size = sizeof (struct __dns_record_entry) + + sizeof (struct dns_record_mx); + entry = (struct __dns_record_entry *)malloc(entry_size); + if (!entry) + return -1; + + entry->record.rdata = (void *)(entry + 1); + mx = (struct dns_record_mx *)(entry->record.rdata); + mx->exchange = strdup(exchange); + mx->preference = preference; + + if (!mx->exchange) + { + free(entry); + return -1; + } + + *r = entry; + + return 0; +} + +static int __dns_parser_parse_others(struct __dns_record_entry **r, + uint16_t rdlength, + dns_parser_t *parser) +{ + const char **cur; + struct __dns_record_entry *entry; + size_t entry_size; + + cur = &(parser->cur); + entry_size = sizeof (struct __dns_record_entry) + rdlength; + entry = (struct __dns_record_entry *)malloc(entry_size); + if (!entry) + return -1; + + entry->record.rdata = (void *)(entry + 1); + memcpy(entry->record.rdata, *cur, rdlength); + *cur += rdlength; + *r = entry; + + return 0; +} + +/* + * RR format, from RFC 1035: + * 1 1 1 1 1 1 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | | + * / NAME / + * / / + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | TYPE | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | CLASS | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | TTL | + * | | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | RDLENGTH | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * / RDATA / + * / / + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * + */ +static int __dns_parser_parse_record(int idx, dns_parser_t *parser) +{ + uint16_t i; + uint16_t type; + uint16_t rclass; + uint32_t ttl; + uint16_t rdlength; + uint16_t count; + const char *msgend; + const char **cur; + int ret; + struct __dns_record_entry *entry; + char host[DNS_NAMES_MAX + 2]; + struct list_head *list; + + switch (idx) + { + case 0: + count = parser->header.ancount; + list = &parser->answer_list; + break; + case 1: + count = parser->header.nscount; + list = &parser->authority_list; + break; + case 2: + count = parser->header.arcount; + list = &parser->additional_list; + break; + default: + return -2; + } + + msgend = (const char *)parser->msgbuf + parser->msgsize; + cur = &(parser->cur); + + for (i = 0; i < count; i++) + { + ret = __dns_parser_parse_host(host, parser); + if (ret < 0) + return ret; + + // TYPE(2) + CLASS(2) + TTL(4) + RDLENGTH(2) = 10 + if (*cur + 10 > msgend) + return -2; + type = __dns_parser_uint16(*cur); + rclass = __dns_parser_uint16(*cur + 2); + ttl = __dns_parser_uint32(*cur + 4); + rdlength = __dns_parser_uint16(*cur + 8); + *cur += 10; + if (*cur + rdlength > msgend) + return -2; + + entry = NULL; + switch (type) + { + case DNS_TYPE_A: + ret = __dns_parser_parse_a(&entry, rdlength, parser); + break; + case DNS_TYPE_AAAA: + ret = __dns_parser_parse_aaaa(&entry, rdlength, parser); + break; + case DNS_TYPE_NS: + case DNS_TYPE_CNAME: + case DNS_TYPE_PTR: + ret = __dns_parser_parse_names(&entry, rdlength, parser); + break; + case DNS_TYPE_SOA: + ret = __dns_parser_parse_soa(&entry, rdlength, parser); + break; + case DNS_TYPE_SRV: + ret = __dns_parser_parse_srv(&entry, rdlength, parser); + break; + case DNS_TYPE_MX: + ret = __dns_parser_parse_mx(&entry, rdlength, parser); + break; + default: + ret = __dns_parser_parse_others(&entry, rdlength, parser); + } + + if (ret < 0) + return ret; + + entry->record.name = strdup(host); + if (!entry->record.name) + { + __dns_parser_free_record(entry); + return -1; + } + + entry->record.type = type; + entry->record.rclass = rclass; + entry->record.ttl = ttl; + entry->record.rdlength = rdlength; + list_add_tail(&entry->entry_list, list); + } + + return 0; +} + +/* + * Question format, from RFC 1035: + * 1 1 1 1 1 1 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | | + * / QNAME / + * / / + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | QTYPE | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | QCLASS | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * + * The query name is encoded as a series of labels, each represented + * as a one-byte length (maximum 63) followed by the text of the + * label. The list is terminated by a label of length zero (which can + * be thought of as the root domain). + */ +static int __dns_parser_parse_question(dns_parser_t *parser) +{ + uint16_t qtype; + uint16_t qclass; + const char *msgend; + const char **cur; + int ret; + char host[DNS_NAMES_MAX + 2]; + + msgend = (const char *)parser->msgbuf + parser->msgsize; + cur = &(parser->cur); + + // question count != 1 is an error + if (parser->header.qdcount != 1) + return -2; + + // parse qname + ret = __dns_parser_parse_host(host, parser); + if (ret < 0) + return ret; + + // parse qtype and qclass + if (*cur + 4 > msgend) + return -2; + + qtype = __dns_parser_uint16(*cur); + qclass = __dns_parser_uint16(*cur + 2); + *cur += 4; + + if (parser->question.qname) + free(parser->question.qname); + + parser->question.qname = strdup(host); + if (parser->question.qname == NULL) + return -1; + + parser->question.qtype = qtype; + parser->question.qclass = qclass; + + return 0; +} + +void dns_parser_init(dns_parser_t *parser) +{ + // 鍒濆鍖栬В鏋愬櫒缁撴瀯浣擄紝鍖呮嫭鍒嗛厤鍐呭瓨銆佽缃寚閽堛佸垵濮嬪寲璁℃暟鍣ㄥ拰鍒楄〃澶 + parser->msgbuf = NULL; + parser->msgbase = NULL; + parser->cur = NULL; + parser->msgsize = 0; + parser->bufsize = 0; + parser->complete = 0; + parser->single_packet = 0; + memset(&parser->header, 0, sizeof (struct dns_header)); + memset(&parser->question, 0, sizeof (struct dns_question)); + INIT_LIST_HEAD(&parser->answer_list); + INIT_LIST_HEAD(&parser->authority_list); + INIT_LIST_HEAD(&parser->additional_list); +} + +int dns_parser_set_question(const char *name, + uint16_t qtype, + uint16_t qclass, + dns_parser_t *parser) +{ + int ret; + + // 璁剧疆DNS鏌ヨ鐨勯棶棰橀儴鍒嗭紝鍖呮嫭鍩熷悕銆佹煡璇㈢被鍨嬪拰鏌ヨ绫 + ret = dns_parser_set_question_name(name, parser); + if (ret < 0) + return ret; + + parser->question.qtype = qtype; + parser->question.qclass = qclass; + parser->header.qdcount = 1; + + return 0; +} + +int dns_parser_set_question_name(const char *name, dns_parser_t *parser) +{ + char *newname; + size_t len; + + len = strlen(name); + newname = (char *)malloc(len + 1); + + if (!newname) + return -1; + + memcpy(newname, name, len + 1); + // Remove trailing dot, except name is "." + if (len > 1 && newname[len - 1] == '.') + newname[len - 1] = '\0'; + + if (parser->question.qname) + free(parser->question.qname); + parser->question.qname = newname; + + return 0; +} + +void dns_parser_set_id(uint16_t id, dns_parser_t *parser) +{ + parser->header.id = id; +} + +int dns_parser_parse_all(dns_parser_t *parser) +{ + struct dns_header *h; + int ret; + int i; + + parser->complete = 1; + parser->cur = (const char *)parser->msgbase; + h = &parser->header; + + if (parser->msgsize < sizeof (struct dns_header)) + return -2; + + memcpy(h, parser->msgbase, sizeof (struct dns_header)); + h->id = ntohs(h->id); + h->qdcount = ntohs(h->qdcount); + h->ancount = ntohs(h->ancount); + h->nscount = ntohs(h->nscount); + h->arcount = ntohs(h->arcount); + parser->cur += sizeof (struct dns_header); + + ret = __dns_parser_parse_question(parser); + if (ret < 0) + return ret; + + for (i = 0; i < 3; i++) + { + ret = __dns_parser_parse_record(i, parser); + if (ret < 0) + return ret; + } + + return 0; +} + +int dns_parser_append_message(const void *buf, + size_t *n, + dns_parser_t *parser) +{ + int ret; + size_t total; + size_t new_size; + size_t msgsize_bak; + void *new_buf; + + if (parser->complete) + { + *n = 0; + return 1; + } + + if (!parser->single_packet) + { + msgsize_bak = parser->msgsize; + if (parser->msgsize + *n > parser->bufsize) + { + new_size = MAX(DNS_MSGBASE_INIT_SIZE, 2 * parser->bufsize); + + while (new_size < parser->msgsize + *n) + new_size *= 2; + + new_buf = realloc(parser->msgbuf, new_size); + if (!new_buf) + return -1; + + parser->msgbuf = new_buf; + parser->bufsize = new_size; + } + + memcpy((char*)parser->msgbuf + parser->msgsize, buf, *n); + parser->msgsize += *n; + + if (parser->msgsize < 2) + return 0; + + total = __dns_parser_uint16((char*)parser->msgbuf); + if (parser->msgsize < total + 2) + return 0; + + *n = total + 2 - msgsize_bak; + parser->msgsize = total + 2; + parser->msgbase = (char*)parser->msgbuf + 2; + } + else + { + parser->msgbuf = malloc(*n); + memcpy(parser->msgbuf, buf, *n); + parser->msgbase = parser->msgbuf; + parser->msgsize = *n; + parser->bufsize = *n; + } + + ret = dns_parser_parse_all(parser); + if (ret < 0) + return ret; + + return 1; +} + +void dns_parser_deinit(dns_parser_t *parser) +{ + free(parser->msgbuf); + free(parser->question.qname); + + __dns_parser_free_record_list(&parser->answer_list); + __dns_parser_free_record_list(&parser->authority_list); + __dns_parser_free_record_list(&parser->additional_list); +} + +int dns_record_cursor_next(struct dns_record **record, + dns_record_cursor_t *cursor) +{ + struct __dns_record_entry *e; + + if (cursor->next->next != cursor->head) + { + cursor->next = cursor->next->next; + e = list_entry(cursor->next, struct __dns_record_entry, entry_list); + *record = &e->record; + return 0; + } + + return 1; +} + +int dns_record_cursor_find_cname(const char *name, + const char **cname, + dns_record_cursor_t *cursor) +{ + struct __dns_record_entry *e; + + if (!name || !cname) + return 1; + + cursor->next = cursor->head; + while (cursor->next->next != cursor->head) + { + cursor->next = cursor->next->next; + e = list_entry(cursor->next, struct __dns_record_entry, entry_list); + + if (e->record.type == DNS_TYPE_CNAME && + strcasecmp(name, e->record.name) == 0) + { + *cname = (const char *)e->record.rdata; + return 0; + } + } + + return 1; +} + +int dns_add_raw_record(const char *name, uint16_t type, uint16_t rclass, + uint32_t ttl, uint16_t rlen, const void *rdata, + struct list_head *list) +{ + struct __dns_record_entry *entry; + size_t entry_size = sizeof (struct __dns_record_entry) + rlen; + + entry = (struct __dns_record_entry *)malloc(entry_size); + if (!entry) + return -1; + + entry->record.name = strdup(name); + if (!entry->record.name) + { + free(entry); + return -1; + } + + entry->record.type = type; + entry->record.rclass = rclass; + entry->record.ttl = ttl; + entry->record.rdlength = rlen; + entry->record.rdata = (void *)(entry + 1); + memcpy(entry->record.rdata, rdata, rlen); + list_add_tail(&entry->entry_list, list); + + return 0; +} + +int dns_add_str_record(const char *name, uint16_t type, uint16_t rclass, + uint32_t ttl, const char *rdata, + struct list_head *list) +{ + size_t rlen = strlen(rdata); + // record.rdlength has no meaning for parsed record types, ignore its + // correctness, same for soa/srv/mx record + return dns_add_raw_record(name, type, rclass, ttl, rlen+1, rdata, list); +} + +int dns_add_soa_record(const char *name, uint16_t rclass, uint32_t ttl, + const char *mname, const char *rname, + uint32_t serial, int32_t refresh, + int32_t retry, int32_t expire, uint32_t minimum, + struct list_head *list) +{ + struct __dns_record_entry *entry; + struct dns_record_soa *soa; + size_t entry_size; + char *pname, *pmname, *prname; + + entry_size = sizeof (struct __dns_record_entry) + + sizeof (struct dns_record_soa); + + entry = (struct __dns_record_entry *)malloc(entry_size); + if (!entry) + return -1; + + entry->record.rdata = (void *)(entry + 1); + entry->record.rdlength = 0; + soa = (struct dns_record_soa *)(entry->record.rdata); + + pname = strdup(name); + pmname = strdup(mname); + prname = strdup(rname); + + if (!pname || !pmname || !prname) + { + free(pname); + free(pmname); + free(prname); + free(entry); + return -1; + } + + soa->mname = pmname; + soa->rname = prname; + soa->serial = serial; + soa->refresh = refresh; + soa->retry = retry; + soa->expire = expire; + soa->minimum = minimum; + + entry->record.name = pname; + entry->record.type = DNS_TYPE_SOA; + entry->record.rclass = rclass; + entry->record.ttl = ttl; + list_add_tail(&entry->entry_list, list); + + return 0; +} + +int dns_add_srv_record(const char *name, uint16_t rclass, uint32_t ttl, + uint16_t priority, uint16_t weight, + uint16_t port, const char *target, + struct list_head *list) +{ + struct __dns_record_entry *entry; + struct dns_record_srv *srv; + size_t entry_size; + char *pname, *ptarget; + + entry_size = sizeof (struct __dns_record_entry) + + sizeof (struct dns_record_srv); + + entry = (struct __dns_record_entry *)malloc(entry_size); + if (!entry) + return -1; + + entry->record.rdata = (void *)(entry + 1); + entry->record.rdlength = 0; + srv = (struct dns_record_srv *)(entry->record.rdata); + + pname = strdup(name); + ptarget = strdup(target); + + if (!pname || !ptarget) + { + free(pname); + free(ptarget); + free(entry); + return -1; + } + + srv->priority = priority; + srv->weight = weight; + srv->port = port; + srv->target = ptarget; + + entry->record.name = pname; + entry->record.type = DNS_TYPE_SRV; + entry->record.rclass = rclass; + entry->record.ttl = ttl; + list_add_tail(&entry->entry_list, list); + + return 0; +} + +int dns_add_mx_record(const char *name, uint16_t rclass, uint32_t ttl, + int16_t preference, const char *exchange, + struct list_head *list) +{ + struct __dns_record_entry *entry; + struct dns_record_mx *mx; + size_t entry_size; + char *pname, *pexchange; + + entry_size = sizeof (struct __dns_record_entry) + + sizeof (struct dns_record_mx); + + entry = (struct __dns_record_entry *)malloc(entry_size); + if (!entry) + return -1; + + entry->record.rdata = (void *)(entry + 1); + entry->record.rdlength = 0; + mx = (struct dns_record_mx *)(entry->record.rdata); + + pname = strdup(name); + pexchange = strdup(exchange); + + if (!pname || !pexchange) + { + free(pname); + free(pexchange); + free(entry); + return -1; + } + + mx->preference = preference; + mx->exchange = pexchange; + + entry->record.name = pname; + entry->record.type = DNS_TYPE_MX; + entry->record.rclass = rclass; + entry->record.ttl = ttl; + list_add_tail(&entry->entry_list, list); + + return 0; +} + +const char *dns_type2str(int type) +{ + switch (type) + { + case DNS_TYPE_A: + return "A"; + case DNS_TYPE_NS: + return "NS"; + case DNS_TYPE_MD: + return "MD"; + case DNS_TYPE_MF: + return "MF"; + case DNS_TYPE_CNAME: + return "CNAME"; + case DNS_TYPE_SOA: + return "SOA"; + case DNS_TYPE_MB: + return "MB"; + case DNS_TYPE_MG: + return "MG"; + case DNS_TYPE_MR: + return "MR"; + case DNS_TYPE_NULL: + return "NULL"; + case DNS_TYPE_WKS: + return "WKS"; + case DNS_TYPE_PTR: + return "PTR"; + case DNS_TYPE_HINFO: + return "HINFO"; + case DNS_TYPE_MINFO: + return "MINFO"; + case DNS_TYPE_MX: + return "MX"; + case DNS_TYPE_AAAA: + return "AAAA"; + case DNS_TYPE_SRV: + return "SRV"; + case DNS_TYPE_TXT: + return "TXT"; + case DNS_TYPE_AXFR: + return "AXFR"; + case DNS_TYPE_MAILB: + return "MAILB"; + case DNS_TYPE_MAILA: + return "MAILA"; + case DNS_TYPE_ALL: + return "ALL"; + default: + return "Unknown"; + } +} + +const char *dns_class2str(int dnsclass) +{ + switch (dnsclass) + { + case DNS_CLASS_IN: + return "IN"; + case DNS_CLASS_CS: + return "CS"; + case DNS_CLASS_CH: + return "CH"; + case DNS_CLASS_HS: + return "HS"; + case DNS_CLASS_ALL: + return "ALL"; + default: + return "Unknown"; + } +} + +const char *dns_opcode2str(int opcode) +{ + switch (opcode) + { + case DNS_OPCODE_QUERY: + return "QUERY"; + case DNS_OPCODE_IQUERY: + return "IQUERY"; + case DNS_OPCODE_STATUS: + return "STATUS"; + default: + return "Unknown"; + } +} + +const char *dns_rcode2str(int rcode) +{ + switch (rcode) + { + case DNS_RCODE_NO_ERROR: + return "NO_ERROR"; + case DNS_RCODE_FORMAT_ERROR: + return "FORMAT_ERROR"; + case DNS_RCODE_SERVER_FAILURE: + return "SERVER_FAILURE"; + case DNS_RCODE_NAME_ERROR: + return "NAME_ERROR"; + case DNS_RCODE_NOT_IMPLEMENTED: + return "NOT_IMPLEMENTED"; + case DNS_RCODE_REFUSED: + return "REFUSED"; + default: + return "Unknown"; + } +} + diff --git a/src%2Fserver/CMakeLists.txt b/src%2Fserver/CMakeLists.txt new file mode 100644 index 0000000..ed11ded --- /dev/null +++ b/src%2Fserver/CMakeLists.txt @@ -0,0 +1,18 @@ +cmake_minimum_required(VERSION 3.6) # 璁剧疆 CMake 鐨勬渶浣庣増鏈姹備负 3.6 +project(server) # 瀹氫箟椤圭洰鍚嶇О涓 "server" + +# 璁剧疆婧愭枃浠跺垪琛 +set(SRC + WFServer.cc # 娣诲姞 WFServer.cc 鏂囦欢鍒版簮鏂囦欢鍒楄〃 +) + +# 濡傛灉 MySQL 閰嶇疆鏈璁剧疆涓 "n" +if (NOT MYSQL STREQUAL "n") + set(SRC + ${SRC} # 淇濈暀鍘熸湁鐨勬簮鏂囦欢 + WFMySQLServer.cc # 娣诲姞 WFMySQLServer.cc 鏂囦欢鍒版簮鏂囦欢鍒楄〃 + ) +endif () + +# 灏嗘簮鏂囦欢鍒楄〃鏋勫缓涓轰竴涓璞″簱 +add_library(${PROJECT_NAME} OBJECT ${SRC}) # 鐢熸垚鍚嶄负 "server" 鐨勫璞″簱 diff --git a/src%2Fserver/WFDnsServer.h b/src%2Fserver/WFDnsServer.h new file mode 100644 index 0000000..25c2ae8 --- /dev/null +++ b/src%2Fserver/WFDnsServer.h @@ -0,0 +1,70 @@ +/* + Copyright (c) 2021 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Authors: Liu Kai (liukaidx@sogou-inc.com) +*/ + +#ifndef _WFDNSSERVER_H_ +#define _WFDNSSERVER_H_ + +#include "DnsMessage.h" // 寮曞叆 DNS 娑堟伅鐩稿叧鐨勫ご鏂囦欢 +#include "WFServer.h" // 寮曞叆鏈嶅姟鍣ㄥ姛鑳界殑澶存枃浠 +#include "WFTaskFactory.h" // 寮曞叆浠诲姟宸ュ巶浠ュ垱寤 DNS 浠诲姟 + +// 瀹氫箟涓涓 dns_process_t 绫诲瀷锛屽畠鏄竴涓帴鍙 WFDnsTask 鎸囬拡骞惰繑鍥 void 鐨勫嚱鏁板璞 +using dns_process_t = std::function; + +// 瀹氫箟 WFDnsServer 绫诲瀷锛屽熀浜 WFServer 妯℃澘绫伙紝浣跨敤 DnsRequest 鍜 DnsResponse 浣滀负鍗忚 +using WFDnsServer = WFServer; + +// 璁剧疆 DNS 鏈嶅姟鍣ㄧ殑榛樿鍙傛暟 +static constexpr struct WFServerParams DNS_SERVER_PARAMS_DEFAULT = +{ + .transport_type = TT_UDP, // 浼犺緭绫诲瀷涓 UDP + .max_connections = 2000, // 鏈澶ц繛鎺ユ暟 + .peer_response_timeout = 10 * 1000, // 瀵圭鍝嶅簲瓒呮椂鏃堕棿 + .receive_timeout = -1, // 鎺ユ敹瓒呮椂鏃堕棿锛-1 琛ㄧず涓嶈秴鏃 + .keep_alive_timeout = 300 * 1000, // 淇濇寔杩炴帴鐨勮秴鏃舵椂闂 + .request_size_limit = (size_t)-1, // 璇锋眰澶у皬闄愬埗锛-1 琛ㄧず鏃犻檺鍒 + .ssl_accept_timeout = 5000, // SSL 鎺ュ彈瓒呮椂鏃堕棿 +}; + +// WFDnsServer 鏋勯犲嚱鏁扮殑鐗瑰寲妯℃澘 +template<> inline +WFDnsServer::WFServer(dns_process_t proc) : + WFServerBase(&DNS_SERVER_PARAMS_DEFAULT), // 璋冪敤鍩虹被鏋勯犲嚱鏁帮紝浼犲叆榛樿鍙傛暟 + process(std::move(proc)) // 绉诲姩鏋勯犲嚱鏁颁互鎻愰珮鎬ц兘 +{ +} + +// 鍒涘缓鏂扮殑浼氳瘽锛岄噸鍐欏熀绫绘柟娉 +template<> inline +CommSession *WFDnsServer::new_session(long long seq, CommConnection *conn) +{ + WFDnsTask *task; + + // 浣跨敤浠诲姟宸ュ巶鍒涘缓涓涓柊鐨 DNS 浠诲姟 + task = WFServerTaskFactory::create_dns_task(this, this->process); + // 璁剧疆淇濇寔杩炴帴鐨勮秴鏃舵椂闂 + task->set_keep_alive(this->params.keep_alive_timeout); + // 璁剧疆鎺ユ敹瓒呮椂鏃堕棿 + task->set_receive_timeout(this->params.receive_timeout); + // 璁剧疆璇锋眰澶у皬闄愬埗 + task->get_req()->set_size_limit(this->params.request_size_limit); + + return task; // 杩斿洖鍒涘缓鐨勪换鍔 +} + +#endif // _WFDNSSERVER_H_ diff --git a/src%2Fserver/xmake.lua b/src%2Fserver/xmake.lua new file mode 100644 index 0000000..b76ff63 --- /dev/null +++ b/src%2Fserver/xmake.lua @@ -0,0 +1,10 @@ +-- 瀹氫箟涓涓 xmake 鐩爣锛屽悕涓 server +target("server") + set_kind("object") -- 璁剧疆鐩爣绫诲瀷涓哄璞℃枃浠 + + add_files("*.cc") -- 娣诲姞褰撳墠鐩綍涓嬫墍鏈夌殑 .cc 鏂囦欢鍒扮洰鏍囦腑 + + -- 妫鏌ユ槸鍚︽湭瀹氫箟 MySQL 閰嶇疆 + if not has_config("mysql") then + remove_files("WFMySQLServer.cc") -- 濡傛灉鏈畾涔 MySQL 閰嶇疆锛屽垯绉婚櫎 MySQL 鏈嶅姟鍣ㄧ浉鍏崇殑婧愭枃浠 + end diff --git a/src/doc b/src/doc new file mode 100644 index 0000000..0d06029 Binary files /dev/null and b/src/doc differ diff --git a/src_252Fserver/WFHttpServer.h b/src_252Fserver/WFHttpServer.h new file mode 100644 index 0000000..940a708 --- /dev/null +++ b/src_252Fserver/WFHttpServer.h @@ -0,0 +1,71 @@ +/* + Copyright (c) 2019 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Authors: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _WFHTTPSERVER_H_ +#define _WFHTTPSERVER_H_ + +#include // 寮曞叆鏍囧噯搴撲腑鐨勫伐鍏峰嚱鏁 +#include "HttpMessage.h" // 寮曞叆 HTTP 娑堟伅鐩稿叧鐨勫ご鏂囦欢 +#include "WFServer.h" // 寮曞叆鏈嶅姟鍣ㄥ姛鑳界殑澶存枃浠 +#include "WFTaskFactory.h" // 寮曞叆浠诲姟宸ュ巶浠ュ垱寤 HTTP 浠诲姟 + +// 瀹氫箟涓涓 http_process_t 绫诲瀷锛屽畠鏄竴涓帴鍙 WFHttpTask 鎸囬拡骞惰繑鍥 void 鐨勫嚱鏁板璞 +using http_process_t = std::function; + +// 瀹氫箟 WFHttpServer 绫诲瀷锛屽熀浜 WFServer 妯℃澘绫伙紝浣跨敤 HttpRequest 鍜 HttpResponse 浣滀负鍗忚 +using WFHttpServer = WFServer; + +// 璁剧疆 HTTP 鏈嶅姟鍣ㄧ殑榛樿鍙傛暟 +static constexpr struct WFServerParams HTTP_SERVER_PARAMS_DEFAULT = +{ + .transport_type = TT_TCP, // 浼犺緭绫诲瀷涓 TCP + .max_connections = 2000, // 鏈澶ц繛鎺ユ暟 + .peer_response_timeout = 10 * 1000, // 瀵圭鍝嶅簲瓒呮椂鏃堕棿 + .receive_timeout = -1, // 鎺ユ敹瓒呮椂鏃堕棿锛-1 琛ㄧず涓嶈秴鏃 + .keep_alive_timeout = 60 * 1000, // 淇濇寔杩炴帴鐨勮秴鏃舵椂闂 + .request_size_limit = (size_t)-1, // 璇锋眰澶у皬闄愬埗锛-1 琛ㄧず鏃犻檺鍒 + .ssl_accept_timeout = 10 * 1000, // SSL 鎺ュ彈瓒呮椂鏃堕棿 +}; + +// WFHttpServer 鏋勯犲嚱鏁扮殑鐗瑰寲妯℃澘 +template<> inline +WFHttpServer::WFServer(http_process_t proc) : + WFServerBase(&HTTP_SERVER_PARAMS_DEFAULT), // 璋冪敤鍩虹被鏋勯犲嚱鏁帮紝浼犲叆榛樿鍙傛暟 + process(std::move(proc)) // 绉诲姩鏋勯犲嚱鏁颁互鎻愰珮鎬ц兘 +{ +} + +// 鍒涘缓鏂扮殑浼氳瘽锛岄噸鍐欏熀绫荤殑鏂规硶 +template<> inline +CommSession *WFHttpServer::new_session(long long seq, CommConnection *conn) +{ + WFHttpTask *task; // 澹版槑涓涓 HTTP 浠诲姟鎸囬拡 + + // 浣跨敤浠诲姟宸ュ巶鍒涘缓涓涓柊鐨 HTTP 浠诲姟 + task = WFServerTaskFactory::create_http_task(this, this->process); + // 璁剧疆淇濇寔杩炴帴鐨勮秴鏃舵椂闂 + task->set_keep_alive(this->params.keep_alive_timeout); + // 璁剧疆鎺ユ敹瓒呮椂鏃堕棿 + task->set_receive_timeout(this->params.receive_timeout); + // 璁剧疆璇锋眰澶у皬闄愬埗 + task->get_req()->set_size_limit(this->params.request_size_limit); + + return task; // 杩斿洖鍒涘缓鐨勪换鍔 +} + +#endif // _WFHTTPSERVER_H_